From 59516161535e5247349f99861d34323058003a18 Mon Sep 17 00:00:00 2001 From: Jonas Kaninda Date: Thu, 14 Nov 2024 13:17:28 +0100 Subject: [PATCH] feat: add Redis based rate limiting for multiple instances --- internal/middleware/access-middleware.go | 4 +- internal/middleware/block-common-exploits.go | 4 +- internal/middleware/error-interceptor.go | 12 +-- internal/middleware/helpers.go | 8 +- internal/middleware/middleware.go | 16 ++-- internal/middleware/rate-limit.go | 33 ++++---- .../middleware/route_error_interceptor.go | 58 -------------- internal/middleware/types.go | 75 ++++++++++--------- internal/proxy.go | 4 +- internal/route.go | 34 ++++++--- internal/types.go | 1 - 11 files changed, 99 insertions(+), 150 deletions(-) delete mode 100644 internal/middleware/route_error_interceptor.go diff --git a/internal/middleware/access-middleware.go b/internal/middleware/access-middleware.go index 4451c4b..3f86702 100644 --- a/internal/middleware/access-middleware.go +++ b/internal/middleware/access-middleware.go @@ -30,7 +30,7 @@ func (blockList AccessListMiddleware) AccessMiddleware(next http.Handler) http.H for _, block := range blockList.List { if isPathBlocked(r.URL.Path, util.ParseURLPath(blockList.Path+block)) { logger.Error("%s: %s access forbidden", getRealIP(r), r.URL.Path) - RespondWithError(w, http.StatusForbidden, fmt.Sprintf("%d you do not have permission to access this resource", http.StatusForbidden), blockList.ErrorInterceptor) + RespondWithError(w, http.StatusForbidden, fmt.Sprintf("%d you do not have permission to access this resource")) return } } @@ -54,7 +54,7 @@ func isPathBlocked(requestPath, blockedPath string) bool { return false } -// NewRateLimiter creates a new rate limiter with the specified refill rate and token capacity +// NewRateLimiter creates a new requests limiter with the specified refill requests and token capacity func NewRateLimiter(maxTokens int, refillRate time.Duration) *TokenRateLimiter { return &TokenRateLimiter{ tokens: maxTokens, diff --git a/internal/middleware/block-common-exploits.go b/internal/middleware/block-common-exploits.go index 18154de..62fe39c 100644 --- a/internal/middleware/block-common-exploits.go +++ b/internal/middleware/block-common-exploits.go @@ -46,7 +46,7 @@ func (blockCommon BlockCommon) BlockExploitsMiddleware(next http.Handler) http.H pathTraversalPattern.MatchString(r.URL.Path) || xssPattern.MatchString(r.URL.RawQuery) { logger.Error("%s: %s Forbidden - Potential exploit detected", getRealIP(r), r.URL.Path) - RespondWithError(w, http.StatusForbidden, fmt.Sprintf("%d Forbidden - Potential exploit detected", http.StatusForbidden), blockCommon.ErrorInterceptor) + RespondWithError(w, http.StatusForbidden, fmt.Sprintf("%d Forbidden - Potential exploit detected", http.StatusForbidden)) return } @@ -57,7 +57,7 @@ func (blockCommon BlockCommon) BlockExploitsMiddleware(next http.Handler) http.H for _, value := range values { if sqlInjectionPattern.MatchString(value) || xssPattern.MatchString(value) { logger.Error("%s: %s %s Forbidden - Potential exploit detected", getRealIP(r), r.Method, r.URL.Path) - RespondWithError(w, http.StatusForbidden, fmt.Sprintf("%d Forbidden - Potential exploit detected", http.StatusForbidden), blockCommon.ErrorInterceptor) + RespondWithError(w, http.StatusForbidden, fmt.Sprintf("%d Forbidden - Potential exploit detected", http.StatusForbidden)) return } } diff --git a/internal/middleware/error-interceptor.go b/internal/middleware/error-interceptor.go index c71d20e..b154d0b 100644 --- a/internal/middleware/error-interceptor.go +++ b/internal/middleware/error-interceptor.go @@ -18,7 +18,6 @@ package middleware */ import ( "bytes" - "encoding/json" "github.com/jkaninda/goma-gateway/pkg/logger" "io" "net/http" @@ -49,20 +48,11 @@ func (intercept InterceptErrors) ErrorInterceptor(next http.Handler) http.Handle if canIntercept(rec.statusCode, intercept.Errors) { logger.Debug("Backend error") logger.Error("An error occurred from the backend with the status code: %d", rec.statusCode) - w.Header().Set("Content-Type", "application/json") //Update Origin Cors Headers if allowedOrigin(intercept.Origins, r.Header.Get("Origin")) { w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) } - w.WriteHeader(rec.statusCode) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: rec.statusCode, - Message: http.StatusText(rec.statusCode), - }) - if err != nil { - return - } + RespondWithError(w, rec.statusCode, http.StatusText(rec.statusCode)) return } else { // No error: write buffered response to client diff --git a/internal/middleware/helpers.go b/internal/middleware/helpers.go index 219c13c..1cfc12b 100644 --- a/internal/middleware/helpers.go +++ b/internal/middleware/helpers.go @@ -60,14 +60,14 @@ func errMessage(code int, errors []errorinterceptor.Error) (string, error) { } // RespondWithError is a helper function to handle error responses with flexible content type -func RespondWithError(w http.ResponseWriter, statusCode int, logMessage string, errorIntercept errorinterceptor.ErrorInterceptor) { - message, err := errMessage(statusCode, errorIntercept.Errors) - if err != nil { +func RespondWithError(w http.ResponseWriter, statusCode int, logMessage string) { + message := http.StatusText(statusCode) + if len(logMessage) != 0 { message = logMessage } w.Header().Set("Content-Type", "application/json") w.WriteHeader(statusCode) - err = json.NewEncoder(w).Encode(ProxyResponseError{ + err := json.NewEncoder(w).Encode(ProxyResponseError{ Success: false, Code: statusCode, Message: message, diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index 35c2c67..c2e6461 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -37,7 +37,7 @@ func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler { if allowedOrigin(jwtAuth.Origins, r.Header.Get("Origin")) { w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) } - RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), jwtAuth.ErrorInterceptor) + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) return } @@ -46,14 +46,14 @@ func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler { authURL, err := url.Parse(jwtAuth.AuthURL) if err != nil { logger.Error("Error parsing auth URL: %v", err) - RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError), jwtAuth.ErrorInterceptor) + RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) return } // Create a new request for /authentication authReq, err := http.NewRequest("GET", authURL.String(), nil) if err != nil { logger.Error("Proxy error creating authentication request: %v", err) - RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError), jwtAuth.ErrorInterceptor) + RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) return } logger.Trace("JWT Auth response headers: %v", authReq.Header) @@ -73,7 +73,7 @@ func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler { if err != nil || authResp.StatusCode != http.StatusOK { logger.Debug("%s %s %s %s", r.Method, getRealIP(r), r.URL, r.UserAgent()) logger.Debug("Proxy authentication error") - RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), jwtAuth.ErrorInterceptor) + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) return } defer func(Body io.ReadCloser) { @@ -111,13 +111,13 @@ func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler { if authHeader == "" { logger.Debug("Proxy error, missing Authorization header") w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) - RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), basicAuth.ErrorInterceptor) + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) return } // Check if the Authorization header contains "Basic" scheme if !strings.HasPrefix(authHeader, "Basic ") { logger.Error("Proxy error, missing Basic Authorization header") - RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), basicAuth.ErrorInterceptor) + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) return } @@ -126,7 +126,7 @@ func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler { payload, err := base64.StdEncoding.DecodeString(authHeader[len("Basic "):]) if err != nil { logger.Debug("Proxy error, missing Basic Authorization header") - RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), basicAuth.ErrorInterceptor) + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) return } @@ -134,7 +134,7 @@ func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler { pair := strings.SplitN(string(payload), ":", 2) if len(pair) != 2 || pair[0] != basicAuth.Username || pair[1] != basicAuth.Password { w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) - RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), basicAuth.ErrorInterceptor) + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) return } diff --git a/internal/middleware/rate-limit.go b/internal/middleware/rate-limit.go index ef701f4..992a558 100644 --- a/internal/middleware/rate-limit.go +++ b/internal/middleware/rate-limit.go @@ -37,13 +37,13 @@ func (rl *TokenRateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { // Rate limit exceeded, return a 429 Too Many Requests response w.WriteHeader(http.StatusTooManyRequests) - _, err := w.Write([]byte(fmt.Sprintf("%d Too many requests, API rate limit exceeded. Please try again later", http.StatusTooManyRequests))) + _, err := w.Write([]byte(fmt.Sprintf("%d Too many requests, API requests limit exceeded. Please try again later", http.StatusTooManyRequests))) if err != nil { return } return } - // Proceed to the next handler if rate limit is not exceeded + // Proceed to the next handler if requests limit is not exceeded next.ServeHTTP(w, r) }) } @@ -54,40 +54,39 @@ func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { clientIP := getRealIP(r) - logger.Debug("rate limiter: clientID: %s, redisBased: %s", clientIP, rl.RedisBased) - if rl.RedisBased { - err := redisRateLimiter(clientIP, rl.Requests) + clientID := fmt.Sprintf("%d-%s", rl.id, clientIP) // Generate client Id, ID+ route ID + logger.Debug("requests limiter: clientIP: %s, redisBased: %s", clientIP, rl.redisBased) + if rl.redisBased { + err := redisRateLimiter(clientID, rl.requests) if err != nil { logger.Error("Redis Rate limiter error: %s", err.Error()) - RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API rate limit exceeded. Please try again later", http.StatusTooManyRequests), rl.ErrorInterceptor) + logger.Error("Too many requests from IP: %s %s %s", clientIP, r.URL, r.UserAgent()) + RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API requests limit exceeded. Please try again later", http.StatusTooManyRequests)) return } - // Proceed to the next handler if rate limit is not exceeded - next.ServeHTTP(w, r) } else { rl.mu.Lock() - client, exists := rl.ClientMap[clientIP] + client, exists := rl.clientMap[clientID] if !exists || time.Now().After(client.ExpiresAt) { client = &Client{ RequestCount: 0, - ExpiresAt: time.Now().Add(rl.Window), + ExpiresAt: time.Now().Add(rl.window), } - rl.ClientMap[clientIP] = client + rl.clientMap[clientID] = client } client.RequestCount++ rl.mu.Unlock() - if client.RequestCount > rl.Requests { + if client.RequestCount > rl.requests { logger.Error("Too many requests from IP: %s %s %s", clientIP, r.URL, r.UserAgent()) //Update Origin Cors Headers - if allowedOrigin(rl.Origins, r.Header.Get("Origin")) { + if allowedOrigin(rl.origins, r.Header.Get("Origin")) { w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) } - RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API rate limit exceeded. Please try again later", http.StatusTooManyRequests), rl.ErrorInterceptor) - return + RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API requests limit exceeded. Please try again later", http.StatusTooManyRequests)) } } - // Proceed to the next handler if rate limit is not exceeded + // Proceed to the next handler if requests limit is not exceeded next.ServeHTTP(w, r) }) } @@ -100,7 +99,7 @@ func redisRateLimiter(clientIP string, rate int) error { return err } if res.Remaining == 0 { - return errors.New("rate limit exceeded") + return errors.New("requests limit exceeded") } return nil diff --git a/internal/middleware/route_error_interceptor.go b/internal/middleware/route_error_interceptor.go deleted file mode 100644 index 9dea9ee..0000000 --- a/internal/middleware/route_error_interceptor.go +++ /dev/null @@ -1,58 +0,0 @@ -package middleware - -/* - * Copyright 2024 Jonas Kaninda - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ -import ( - "github.com/jkaninda/goma-gateway/pkg/errorinterceptor" - "github.com/jkaninda/goma-gateway/pkg/logger" - "io" - "net/http" -) - -// RouteErrorInterceptor contains backend status code errors to intercept -type RouteErrorInterceptor struct { - Origins []string - ErrorInterceptor errorinterceptor.ErrorInterceptor -} - -// RouteErrorInterceptor Middleware intercepts backend route errors -func (intercept RouteErrorInterceptor) RouteErrorInterceptor(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - rec := newResponseRecorder(w) - next.ServeHTTP(rec, r) - if canInterceptError(rec.statusCode, intercept.ErrorInterceptor.Errors) { - logger.Debug("Backend error") - logger.Error("An error occurred from the backend with the status code: %d", rec.statusCode) - //Update Origin Cors Headers - if allowedOrigin(intercept.Origins, r.Header.Get("Origin")) { - w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) - } - RespondWithError(w, rec.statusCode, http.StatusText(rec.statusCode), intercept.ErrorInterceptor) - return - } else { - // No error: write buffered response to client - w.WriteHeader(rec.statusCode) - _, err := io.Copy(w, rec.body) - if err != nil { - return - } - return - - } - - }) -} diff --git a/internal/middleware/types.go b/internal/middleware/types.go index f775aff..7d813cc 100644 --- a/internal/middleware/types.go +++ b/internal/middleware/types.go @@ -19,21 +19,21 @@ package middleware import ( "bytes" - errorinterceptor "github.com/jkaninda/goma-gateway/pkg/errorinterceptor" "net/http" "sync" "time" ) -// RateLimiter defines rate limit properties. +// RateLimiter defines requests limit properties. type RateLimiter struct { - Requests int - Window time.Duration - ClientMap map[string]*Client - mu sync.Mutex - Origins []string - ErrorInterceptor errorinterceptor.ErrorInterceptor - RedisBased bool + requests int + id int + window time.Duration + clientMap map[string]*Client + mu sync.Mutex + origins []string + hosts []string + redisBased bool } // Client stores request count and window expiration for each client. @@ -41,15 +41,24 @@ type Client struct { RequestCount int ExpiresAt time.Time } +type RateLimit struct { + Id int + Requests int + Window time.Duration + Origins []string + Hosts []string + RedisBased bool +} // NewRateLimiterWindow creates a new RateLimiter. -func NewRateLimiterWindow(requests int, window time.Duration, redisBased bool, origin []string) *RateLimiter { +func (rateLimit RateLimit) NewRateLimiterWindow() *RateLimiter { return &RateLimiter{ - Requests: requests, - Window: window, - ClientMap: make(map[string]*Client), - Origins: origin, - RedisBased: redisBased, + id: rateLimit.Id, + requests: rateLimit.Requests, + window: rateLimit.Window, + clientMap: make(map[string]*Client), + origins: rateLimit.Origins, + redisBased: rateLimit.RedisBased, } } @@ -71,12 +80,11 @@ type ProxyResponseError struct { // JwtAuth stores JWT configuration type JwtAuth struct { - AuthURL string - RequiredHeaders []string - Headers map[string]string - Params map[string]string - Origins []string - ErrorInterceptor errorinterceptor.ErrorInterceptor + AuthURL string + RequiredHeaders []string + Headers map[string]string + Params map[string]string + Origins []string } // AuthenticationMiddleware Define struct @@ -87,19 +95,17 @@ type AuthenticationMiddleware struct { Params map[string]string } type AccessListMiddleware struct { - Path string - Destination string - List []string - ErrorInterceptor errorinterceptor.ErrorInterceptor + Path string + Destination string + List []string } // AuthBasic contains Basic auth configuration type AuthBasic struct { - Username string - Password string - Headers map[string]string - Params map[string]string - ErrorInterceptor errorinterceptor.ErrorInterceptor + Username string + Password string + Headers map[string]string + Params map[string]string } // InterceptErrors contains backend status code errors to intercept @@ -127,11 +133,10 @@ type Oauth struct { // Scope specifies optional requested permissions. Scopes []string // contains filtered or unexported fields - State string - Origins []string - JWTSecret string - Provider string - ErrorInterceptor errorinterceptor.ErrorInterceptor + State string + Origins []string + JWTSecret string + Provider string } type OauthEndpoint struct { AuthURL string diff --git a/internal/proxy.go b/internal/proxy.go index bc6b90e..6d6d352 100644 --- a/internal/proxy.go +++ b/internal/proxy.go @@ -37,7 +37,7 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc { if len(proxyRoute.methods) > 0 { if !slices.Contains(proxyRoute.methods, r.Method) { logger.Error("%s Method is not allowed", r.Method) - middleware.RespondWithError(w, http.StatusMethodNotAllowed, fmt.Sprintf("%d %s method is not allowed", http.StatusMethodNotAllowed, r.Method), proxyRoute.ErrorInterceptor) + middleware.RespondWithError(w, http.StatusMethodNotAllowed, fmt.Sprintf("%d %s method is not allowed", http.StatusMethodNotAllowed, r.Method)) return } } @@ -60,7 +60,7 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc { targetURL, err := url.Parse(proxyRoute.destination) if err != nil { logger.Error("Error parsing backend URL: %s", err) - middleware.RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError), proxyRoute.ErrorInterceptor) + middleware.RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) return } r.Header.Set("X-Forwarded-Host", r.Header.Get("Host")) diff --git a/internal/route.go b/internal/route.go index cd5f050..c64c1b5 100644 --- a/internal/route.go +++ b/internal/route.go @@ -66,12 +66,20 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { r.Use(blockCommon.BlockExploitsMiddleware) } if gateway.RateLimit > 0 { - //rateLimiter := middleware.NewRateLimiter(gateway.RateLimit, time.Minute) - limiter := middleware.NewRateLimiterWindow(gateway.RateLimit, time.Minute, redisBased, gateway.Cors.Origins) // requests per minute // Add rate limit middleware to all routes, if defined + rateLimit := middleware.RateLimit{ + Id: 1, + Requests: gateway.RateLimit, + Window: time.Minute, // requests per minute + Origins: gateway.Cors.Origins, + Hosts: []string{}, + RedisBased: redisBased, + } + limiter := rateLimit.NewRateLimiterWindow() + // Add rate limit middleware r.Use(limiter.RateLimitMiddleware()) } - for _, route := range gateway.Routes { + for rIndex, route := range gateway.Routes { if route.Path != "" { if route.Destination == "" && len(route.Backends) == 0 { logger.Fatal("Route %s : destination or backends should not be empty", route.Name) @@ -232,9 +240,16 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { } // Apply route rate limit if route.RateLimit > 0 { - //rateLimiter := middleware.NewRateLimiter(gateway.RateLimit, time.Minute) - limiter := middleware.NewRateLimiterWindow(route.RateLimit, time.Minute, redisBased, route.Cors.Origins) // requests per minute - // Add rate limit middleware to all routes, if defined + rateLimit := middleware.RateLimit{ + Id: rIndex, + Requests: route.RateLimit, + Window: time.Minute, // requests per minute + Origins: route.Cors.Origins, + Hosts: route.Hosts, + RedisBased: redisBased, + } + limiter := rateLimit.NewRateLimiterWindow() + // Add rate limit middleware router.Use(limiter.RateLimitMiddleware()) } // Apply route Cors @@ -255,11 +270,10 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { router.Use(pr.prometheusMiddleware) } // Apply route Error interceptor middleware - interceptErrors := middleware.RouteErrorInterceptor{ - Origins: gateway.Cors.Origins, - ErrorInterceptor: route.ErrorInterceptor, + interceptErrors := middleware.InterceptErrors{ + Origins: gateway.Cors.Origins, } - r.Use(interceptErrors.RouteErrorInterceptor) + router.Use(interceptErrors.ErrorInterceptor) } else { logger.Error("Error, path is empty in route %s", route.Name) logger.Error("Route path ignored: %s", route.Path) diff --git a/internal/types.go b/internal/types.go index 5c75e85..c3c85e2 100644 --- a/internal/types.go +++ b/internal/types.go @@ -246,7 +246,6 @@ type ProxyRoute struct { methods []string cors Cors disableHostFording bool - ErrorInterceptor errorinterceptor.ErrorInterceptor } type RoutePath struct { route Route