refactor: refactoring of rate limiting
This commit is contained in:
@@ -53,13 +53,11 @@ func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc {
|
|||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
clientIP := getRealIP(r)
|
clientIP := getRealIP(r)
|
||||||
clientID := fmt.Sprintf("%s-%s", rl.id, clientIP) // Generate client Id, ID+ route ID
|
clientID := fmt.Sprintf("%s-%s", rl.id, clientIP) // Generate client Id, ID+ route ID
|
||||||
logger.Debug("requests limiter: clientIP: %s, clientID: %s", clientIP, clientID)
|
|
||||||
if rl.redisBased {
|
if rl.redisBased {
|
||||||
err := redisRateLimiter(clientID, rl.unit, rl.requests)
|
err := redisRateLimiter(clientID, rl.unit, rl.requests)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Redis Rate limiter error: %s", err.Error())
|
logger.Error("Redis Rate limiter error: %s", err.Error())
|
||||||
logger.Error("Too many requests from IP: %s %s %s", clientIP, r.URL, r.UserAgent())
|
logger.Error("Too many requests from IP: %s %s %s", clientIP, r.URL, r.UserAgent())
|
||||||
RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API requests limit exceeded. Please try again later", http.StatusTooManyRequests))
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -82,8 +80,10 @@ func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc {
|
|||||||
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
|
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
|
||||||
}
|
}
|
||||||
RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API requests limit exceeded. Please try again later", http.StatusTooManyRequests))
|
RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API requests limit exceeded. Please try again later", http.StatusTooManyRequests))
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Proceed to the next handler if the request limit is not exceeded
|
// Proceed to the next handler if the request limit is not exceeded
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -33,6 +33,8 @@ type RateLimiter struct {
|
|||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
origins []string
|
origins []string
|
||||||
redisBased bool
|
redisBased bool
|
||||||
|
pathBased bool
|
||||||
|
paths []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client stores request count and window expiration for each client.
|
// Client stores request count and window expiration for each client.
|
||||||
@@ -47,6 +49,8 @@ type RateLimit struct {
|
|||||||
Origins []string
|
Origins []string
|
||||||
Hosts []string
|
Hosts []string
|
||||||
RedisBased bool
|
RedisBased bool
|
||||||
|
PathBased bool
|
||||||
|
Paths []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRateLimiterWindow creates a new RateLimiter.
|
// NewRateLimiterWindow creates a new RateLimiter.
|
||||||
@@ -58,6 +62,8 @@ func (rateLimit RateLimit) NewRateLimiterWindow() *RateLimiter {
|
|||||||
clientMap: make(map[string]*Client),
|
clientMap: make(map[string]*Client),
|
||||||
origins: rateLimit.Origins,
|
origins: rateLimit.Origins,
|
||||||
redisBased: rateLimit.RedisBased,
|
redisBased: rateLimit.RedisBased,
|
||||||
|
pathBased: rateLimit.PathBased,
|
||||||
|
paths: rateLimit.Paths,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import (
|
|||||||
"github.com/jkaninda/goma-gateway/util"
|
"github.com/jkaninda/goma-gateway/util"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
|
"slices"
|
||||||
)
|
)
|
||||||
|
|
||||||
// init initializes prometheus metrics
|
// init initializes prometheus metrics
|
||||||
@@ -127,37 +128,6 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
|
|||||||
}
|
}
|
||||||
// Apply middlewares to the route
|
// Apply middlewares to the route
|
||||||
for _, middleware := range route.Middlewares {
|
for _, middleware := range route.Middlewares {
|
||||||
if len(middleware) != 0 {
|
|
||||||
// Get Access middlewares if it does exist
|
|
||||||
accessMiddleware, err := getMiddleware([]string{middleware}, m)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Error: %v", err.Error())
|
|
||||||
} else {
|
|
||||||
// Apply access middlewares
|
|
||||||
if accessMiddleware.Type == AccessMiddleware {
|
|
||||||
blM := middlewares.AccessListMiddleware{
|
|
||||||
Path: route.Path,
|
|
||||||
List: accessMiddleware.Paths,
|
|
||||||
}
|
|
||||||
r.Use(blM.AccessMiddleware)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
// Get route authentication middlewares if it does exist
|
|
||||||
routeMiddleware, err := getMiddleware([]string{middleware}, m)
|
|
||||||
if err != nil {
|
|
||||||
// Error: middlewares not found
|
|
||||||
logger.Error("Error: %v", err.Error())
|
|
||||||
} else {
|
|
||||||
attachAuthMiddlewares(route, routeMiddleware, gateway, r)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
logger.Error("Error, middlewares path is empty")
|
|
||||||
logger.Error("Middleware ignored")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply common exploits to the route
|
// Apply common exploits to the route
|
||||||
// Enable common exploits
|
// Enable common exploits
|
||||||
if route.BlockCommonExploits {
|
if route.BlockCommonExploits {
|
||||||
@@ -183,6 +153,62 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
|
|||||||
// Add rate limit middlewares
|
// Add rate limit middlewares
|
||||||
router.Use(limiter.RateLimitMiddleware())
|
router.Use(limiter.RateLimitMiddleware())
|
||||||
}
|
}
|
||||||
|
if len(middleware) != 0 {
|
||||||
|
// Get Access middlewares if it does exist
|
||||||
|
accessMiddleware, err := getMiddleware([]string{middleware}, m)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error: %v", err.Error())
|
||||||
|
} else {
|
||||||
|
// Apply access middlewares
|
||||||
|
if accessMiddleware.Type == AccessMiddleware {
|
||||||
|
blM := middlewares.AccessListMiddleware{
|
||||||
|
Path: route.Path,
|
||||||
|
List: accessMiddleware.Paths,
|
||||||
|
}
|
||||||
|
r.Use(blM.AccessMiddleware)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply Rate limit middleware
|
||||||
|
if slices.Contains(RateLimitMiddleware, accessMiddleware.Type) {
|
||||||
|
rateLimitMid, err := rateLimitMiddleware(accessMiddleware.Rule)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error: %v", err.Error())
|
||||||
|
}
|
||||||
|
if rateLimitMid.RequestsPerUnit != 0 && route.RateLimit == 0 {
|
||||||
|
rateLimit := middlewares.RateLimit{
|
||||||
|
Unit: rateLimitMid.Unit,
|
||||||
|
Id: id, // Use route index as ID
|
||||||
|
Requests: rateLimitMid.RequestsPerUnit,
|
||||||
|
Origins: route.Cors.Origins,
|
||||||
|
Hosts: route.Hosts,
|
||||||
|
RedisBased: redisBased,
|
||||||
|
PathBased: true,
|
||||||
|
Paths: util.AddPrefixPath(route.Path, accessMiddleware.Paths),
|
||||||
|
}
|
||||||
|
limiter := rateLimit.NewRateLimiterWindow()
|
||||||
|
// Add rate limit middlewares
|
||||||
|
router.Use(limiter.RateLimitMiddleware())
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
// Get route authentication middlewares if it does exist
|
||||||
|
routeMiddleware, err := getMiddleware([]string{middleware}, m)
|
||||||
|
if err != nil {
|
||||||
|
// Error: middlewares not found
|
||||||
|
logger.Error("Error: %v", err.Error())
|
||||||
|
} else {
|
||||||
|
attachAuthMiddlewares(route, routeMiddleware, gateway, r)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.Error("Error, middlewares path is empty")
|
||||||
|
logger.Error("Middleware ignored")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Apply route Cors
|
// Apply route Cors
|
||||||
router.Use(CORSHandler(route.Cors))
|
router.Use(CORSHandler(route.Cors))
|
||||||
if len(route.Hosts) > 0 {
|
if len(route.Hosts) > 0 {
|
||||||
@@ -208,8 +234,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
|
|||||||
}
|
}
|
||||||
router.Use(interceptErrors.ErrorInterceptor)
|
router.Use(interceptErrors.ErrorInterceptor)
|
||||||
}
|
}
|
||||||
//r.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler
|
|
||||||
//r.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler
|
|
||||||
} else {
|
} else {
|
||||||
logger.Error("Error, path is empty in route %s", route.Name)
|
logger.Error("Error, path is empty in route %s", route.Name)
|
||||||
logger.Error("Route path ignored: %s", route.Path)
|
logger.Error("Route path ignored: %s", route.Path)
|
||||||
|
|||||||
Reference in New Issue
Block a user