From f4e5bb3be251099a2e9c821ac80bc4c8be5945ee Mon Sep 17 00:00:00 2001 From: Jonas Kaninda Date: Sun, 24 Nov 2024 23:09:13 +0100 Subject: [PATCH] refactor: refactoring of rate limiting --- internal/middlewares/rate_limit.go | 4 +- internal/middlewares/types.go | 6 +++ internal/routes.go | 79 ++++++++++++++++++++---------- 3 files changed, 60 insertions(+), 29 deletions(-) diff --git a/internal/middlewares/rate_limit.go b/internal/middlewares/rate_limit.go index f2f3107..be627a4 100644 --- a/internal/middlewares/rate_limit.go +++ b/internal/middlewares/rate_limit.go @@ -53,13 +53,11 @@ func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { clientIP := getRealIP(r) clientID := fmt.Sprintf("%s-%s", rl.id, clientIP) // Generate client Id, ID+ route ID - logger.Debug("requests limiter: clientIP: %s, clientID: %s", clientIP, clientID) if rl.redisBased { err := redisRateLimiter(clientID, rl.unit, rl.requests) if err != nil { logger.Error("Redis Rate limiter error: %s", err.Error()) logger.Error("Too many requests from IP: %s %s %s", clientIP, r.URL, r.UserAgent()) - RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API requests limit exceeded. Please try again later", http.StatusTooManyRequests)) return } } else { @@ -82,8 +80,10 @@ func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) } RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API requests limit exceeded. Please try again later", http.StatusTooManyRequests)) + return } } + // Proceed to the next handler if the request limit is not exceeded next.ServeHTTP(w, r) }) diff --git a/internal/middlewares/types.go b/internal/middlewares/types.go index 59d78d8..c715d61 100644 --- a/internal/middlewares/types.go +++ b/internal/middlewares/types.go @@ -33,6 +33,8 @@ type RateLimiter struct { mu sync.Mutex origins []string redisBased bool + pathBased bool + paths []string } // Client stores request count and window expiration for each client. @@ -47,6 +49,8 @@ type RateLimit struct { Origins []string Hosts []string RedisBased bool + PathBased bool + Paths []string } // NewRateLimiterWindow creates a new RateLimiter. @@ -58,6 +62,8 @@ func (rateLimit RateLimit) NewRateLimiterWindow() *RateLimiter { clientMap: make(map[string]*Client), origins: rateLimit.Origins, redisBased: rateLimit.RedisBased, + pathBased: rateLimit.PathBased, + paths: rateLimit.Paths, } } diff --git a/internal/routes.go b/internal/routes.go index 8f7defd..4b736e8 100644 --- a/internal/routes.go +++ b/internal/routes.go @@ -23,6 +23,7 @@ import ( "github.com/jkaninda/goma-gateway/util" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" + "slices" ) // init initializes prometheus metrics @@ -127,6 +128,31 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { } // Apply middlewares to the route for _, middleware := range route.Middlewares { + // Apply common exploits to the route + // Enable common exploits + if route.BlockCommonExploits { + logger.Info("Block common exploits enabled") + router.Use(middlewares.BlockExploitsMiddleware) + } + id := string(rune(rIndex)) + if len(route.Name) != 0 { + // Use route name as ID + id = util.Slug(route.Name) + } + // Apply route rate limit + if route.RateLimit != 0 { + rateLimit := middlewares.RateLimit{ + Unit: "minute", + Id: id, // Use route index as ID + Requests: route.RateLimit, + Origins: route.Cors.Origins, + Hosts: route.Hosts, + RedisBased: redisBased, + } + limiter := rateLimit.NewRateLimiterWindow() + // Add rate limit middlewares + router.Use(limiter.RateLimitMiddleware()) + } if len(middleware) != 0 { // Get Access middlewares if it does exist accessMiddleware, err := getMiddleware([]string{middleware}, m) @@ -143,6 +169,31 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { } + // Apply Rate limit middleware + if slices.Contains(RateLimitMiddleware, accessMiddleware.Type) { + rateLimitMid, err := rateLimitMiddleware(accessMiddleware.Rule) + if err != nil { + logger.Error("Error: %v", err.Error()) + } + if rateLimitMid.RequestsPerUnit != 0 && route.RateLimit == 0 { + rateLimit := middlewares.RateLimit{ + Unit: rateLimitMid.Unit, + Id: id, // Use route index as ID + Requests: rateLimitMid.RequestsPerUnit, + Origins: route.Cors.Origins, + Hosts: route.Hosts, + RedisBased: redisBased, + PathBased: true, + Paths: util.AddPrefixPath(route.Path, accessMiddleware.Paths), + } + limiter := rateLimit.NewRateLimiterWindow() + // Add rate limit middlewares + router.Use(limiter.RateLimitMiddleware()) + + } + + } + } // Get route authentication middlewares if it does exist routeMiddleware, err := getMiddleware([]string{middleware}, m) @@ -158,31 +209,6 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { } } - // Apply common exploits to the route - // Enable common exploits - if route.BlockCommonExploits { - logger.Info("Block common exploits enabled") - router.Use(middlewares.BlockExploitsMiddleware) - } - id := string(rune(rIndex)) - if len(route.Name) != 0 { - // Use route name as ID - id = util.Slug(route.Name) - } - // Apply route rate limit - if route.RateLimit != 0 { - rateLimit := middlewares.RateLimit{ - Unit: "minute", - Id: id, // Use route index as ID - Requests: route.RateLimit, - Origins: route.Cors.Origins, - Hosts: route.Hosts, - RedisBased: redisBased, - } - limiter := rateLimit.NewRateLimiterWindow() - // Add rate limit middlewares - router.Use(limiter.RateLimitMiddleware()) - } // Apply route Cors router.Use(CORSHandler(route.Cors)) if len(route.Hosts) > 0 { @@ -208,8 +234,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { } router.Use(interceptErrors.ErrorInterceptor) } - //r.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler - //r.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler + } else { logger.Error("Error, path is empty in route %s", route.Name) logger.Error("Route path ignored: %s", route.Path)