From 42abf56473a8b833be1bb69fdbf287c037a1350a Mon Sep 17 00:00:00 2001 From: Jonas Kaninda Date: Thu, 14 Nov 2024 00:26:21 +0100 Subject: [PATCH 1/6] refactor: improve error interceptor --- internal/config.go | 25 ++++-- internal/middleware/access-middleware.go | 12 +-- internal/middleware/block-common-exploits.go | 34 +++---- internal/middleware/error-interceptor.go | 12 +-- internal/middleware/helpers.go | 57 +++++++++++- internal/middleware/middleware.go | 90 +++---------------- internal/middleware/rate-limit.go | 26 ++---- .../middleware/route_error_interceptor.go | 58 ++++++++++++ internal/middleware/types.go | 48 +++++----- internal/proxy.go | 13 +-- internal/route.go | 14 ++- internal/types.go | 23 ++++- internal/var.go | 4 + pkg/error-interceptor/types.go | 31 +++++++ pkg/error-interceptor/var.go | 22 +++++ 15 files changed, 284 insertions(+), 185 deletions(-) create mode 100644 internal/middleware/route_error_interceptor.go create mode 100644 pkg/error-interceptor/types.go create mode 100644 pkg/error-interceptor/var.go diff --git a/internal/config.go b/internal/config.go index faf15e5..72b3896 100644 --- a/internal/config.go +++ b/internal/config.go @@ -18,6 +18,7 @@ limitations under the License. import ( "fmt" "github.com/jkaninda/goma-gateway/internal/middleware" + error_interceptor "github.com/jkaninda/goma-gateway/pkg/error-interceptor" "github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/util" "golang.org/x/oauth2" @@ -27,6 +28,7 @@ import ( "golang.org/x/oauth2/gitlab" "golang.org/x/oauth2/google" "gopkg.in/yaml.v3" + "net/http" "os" ) @@ -180,11 +182,24 @@ func initConfig(configFile string) error { Middlewares: []string{"basic-auth", "api-forbidden-paths"}, }, { - Path: "/", - Name: "Hostname and load balancing example", - Hosts: []string{"example.com", "example.localhost"}, - InterceptErrors: []int{404, 405, 500}, - RateLimit: 60, + Path: "/", + Name: "Hostname and load balancing example", + Hosts: []string{"example.com", "example.localhost"}, + //InterceptErrors: []int{404, 405, 500}, + ErrorInterceptor: error_interceptor.ErrorInterceptor{ + ContentType: applicationJson, + Errors: []error_interceptor.Error{ + { + Code: http.StatusUnauthorized, + Message: http.StatusText(http.StatusUnauthorized), + }, + { + Code: http.StatusInternalServerError, + Message: http.StatusText(http.StatusInternalServerError), + }, + }, + }, + RateLimit: 60, Backends: []string{ "https://example.com", "https://example2.com", diff --git a/internal/middleware/access-middleware.go b/internal/middleware/access-middleware.go index b581c73..4451c4b 100644 --- a/internal/middleware/access-middleware.go +++ b/internal/middleware/access-middleware.go @@ -16,7 +16,6 @@ See the License for the specific language governing permissions and limitations under the License. */ import ( - "encoding/json" "fmt" "github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/util" @@ -31,16 +30,7 @@ func (blockList AccessListMiddleware) AccessMiddleware(next http.Handler) http.H for _, block := range blockList.List { if isPathBlocked(r.URL.Path, util.ParseURLPath(blockList.Path+block)) { logger.Error("%s: %s access forbidden", getRealIP(r), r.URL.Path) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusForbidden) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: http.StatusForbidden, - Message: fmt.Sprintf("You do not have permission to access this resource"), - }) - if err != nil { - return - } + RespondWithError(w, http.StatusForbidden, fmt.Sprintf("%d you do not have permission to access this resource", http.StatusForbidden), blockList.ErrorInterceptor) return } } diff --git a/internal/middleware/block-common-exploits.go b/internal/middleware/block-common-exploits.go index 8a82534..9d87221 100644 --- a/internal/middleware/block-common-exploits.go +++ b/internal/middleware/block-common-exploits.go @@ -18,15 +18,19 @@ package middleware import ( - "encoding/json" "fmt" + errorinterceptor "github.com/jkaninda/goma-gateway/pkg/error-interceptor" "github.com/jkaninda/goma-gateway/pkg/logger" "net/http" "regexp" ) +type BlockCommon struct { + ErrorInterceptor errorinterceptor.ErrorInterceptor +} + // BlockExploitsMiddleware Middleware to block common exploits -func BlockExploitsMiddleware(next http.Handler) http.Handler { +func (blockCommon BlockCommon) BlockExploitsMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Patterns to detect SQL injection attempts sqlInjectionPattern := regexp.MustCompile(sqlPatterns) @@ -42,36 +46,18 @@ func BlockExploitsMiddleware(next http.Handler) http.Handler { pathTraversalPattern.MatchString(r.URL.Path) || xssPattern.MatchString(r.URL.RawQuery) { logger.Error("%s: %s Forbidden - Potential exploit detected", getRealIP(r), r.URL.Path) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusForbidden) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: http.StatusForbidden, - Message: fmt.Sprintf("Forbidden - Potential exploit detected"), - }) - if err != nil { - return - } + RespondWithError(w, http.StatusForbidden, fmt.Sprintf("%d Forbidden - Potential exploit detected", http.StatusForbidden), blockCommon.ErrorInterceptor) return - } + } // Check form data (for POST requests) if r.Method == http.MethodPost { if err := r.ParseForm(); err == nil { for _, values := range r.Form { for _, value := range values { if sqlInjectionPattern.MatchString(value) || xssPattern.MatchString(value) { - logger.Error("%s: %s Forbidden - Potential exploit detected", getRealIP(r), r.URL.Path) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusForbidden) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: http.StatusForbidden, - Message: fmt.Sprintf("Forbidden - Potential exploit detected"), - }) - if err != nil { - return - } + logger.Error("%s: %s %s Forbidden - Potential exploit detected", getRealIP(r), r.Method, r.URL.Path) + RespondWithError(w, http.StatusForbidden, fmt.Sprintf("%d Forbidden - Potential exploit detected", http.StatusForbidden), blockCommon.ErrorInterceptor) return } } diff --git a/internal/middleware/error-interceptor.go b/internal/middleware/error-interceptor.go index 76235c4..c71d20e 100644 --- a/internal/middleware/error-interceptor.go +++ b/internal/middleware/error-interceptor.go @@ -22,6 +22,7 @@ import ( "github.com/jkaninda/goma-gateway/pkg/logger" "io" "net/http" + "slices" ) func newResponseRecorder(w http.ResponseWriter) *responseRecorder { @@ -62,6 +63,7 @@ func (intercept InterceptErrors) ErrorInterceptor(next http.Handler) http.Handle if err != nil { return } + return } else { // No error: write buffered response to client w.WriteHeader(rec.statusCode) @@ -69,18 +71,12 @@ func (intercept InterceptErrors) ErrorInterceptor(next http.Handler) http.Handle if err != nil { return } + return } }) } func canIntercept(code int, errors []int) bool { - for _, er := range errors { - if er == code { - return true - } - continue - - } - return false + return slices.Contains(errors, code) } diff --git a/internal/middleware/helpers.go b/internal/middleware/helpers.go index 65b95c3..fcd0f35 100644 --- a/internal/middleware/helpers.go +++ b/internal/middleware/helpers.go @@ -17,7 +17,12 @@ package middleware -import "net/http" +import ( + "encoding/json" + "fmt" + errorinterceptor "github.com/jkaninda/goma-gateway/pkg/error-interceptor" + "net/http" +) func getRealIP(r *http.Request) string { if ip := r.Header.Get("X-Real-IP"); ip != "" { @@ -38,3 +43,53 @@ func allowedOrigin(origins []string, origin string) bool { return false } +func canInterceptError(code int, errors []errorinterceptor.Error) bool { + for _, er := range errors { + if er.Code == code { + return true + } + continue + + } + return false +} +func errMessage(code int, errors []errorinterceptor.Error) (string, error) { + for _, er := range errors { + if er.Code == code { + if len(er.Message) != 0 { + return er.Message, nil + } + continue + } + } + return "", fmt.Errorf("%d errors occurred", code) +} + +// RespondWithError is a helper function to handle error responses with flexible content type +func RespondWithError(w http.ResponseWriter, statusCode int, logMessage string, errorIntercept errorinterceptor.ErrorInterceptor) { + message, err := errMessage(statusCode, errorIntercept.Errors) + if err != nil { + message = logMessage + } + if errorIntercept.ContentType == errorinterceptor.ApplicationJson { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + err := json.NewEncoder(w).Encode(ProxyResponseError{ + Success: false, + Code: statusCode, + Message: message, + }) + if err != nil { + return + } + return + } else { + w.Header().Set("Content-Type", "plain/text;charset=utf-8") + w.WriteHeader(statusCode) + _, err2 := w.Write([]byte(message)) + if err2 != nil { + return + } + return + } +} diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index c8fa848..35c2c67 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -17,7 +17,6 @@ limitations under the License. */ import ( "encoding/base64" - "encoding/json" "github.com/jkaninda/goma-gateway/pkg/logger" "io" "net/http" @@ -38,48 +37,23 @@ func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler { if allowedOrigin(jwtAuth.Origins, r.Header.Get("Origin")) { w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) } - w.WriteHeader(http.StatusUnauthorized) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Message: http.StatusText(http.StatusUnauthorized), - Code: http.StatusUnauthorized, - Success: false, - }) - if err != nil { - return - } + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), jwtAuth.ErrorInterceptor) return + } } //token := r.Header.Get("Authorization") authURL, err := url.Parse(jwtAuth.AuthURL) if err != nil { logger.Error("Error parsing auth URL: %v", err) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusInternalServerError) - err = json.NewEncoder(w).Encode(ProxyResponseError{ - Message: "Internal Server Error", - Code: http.StatusInternalServerError, - Success: false, - }) - if err != nil { - return - } + RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError), jwtAuth.ErrorInterceptor) return } // Create a new request for /authentication authReq, err := http.NewRequest("GET", authURL.String(), nil) if err != nil { logger.Error("Proxy error creating authentication request: %v", err) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusInternalServerError) - err = json.NewEncoder(w).Encode(ProxyResponseError{ - Message: "Internal Server Error", - Code: http.StatusInternalServerError, - Success: false, - }) - if err != nil { - return - } + RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError), jwtAuth.ErrorInterceptor) return } logger.Trace("JWT Auth response headers: %v", authReq.Header) @@ -99,16 +73,7 @@ func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler { if err != nil || authResp.StatusCode != http.StatusOK { logger.Debug("%s %s %s %s", r.Method, getRealIP(r), r.URL, r.UserAgent()) logger.Debug("Proxy authentication error") - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusUnauthorized) - err = json.NewEncoder(w).Encode(ProxyResponseError{ - Message: "Unauthorized", - Code: http.StatusUnauthorized, - Success: false, - }) - if err != nil { - return - } + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), jwtAuth.ErrorInterceptor) return } defer func(Body io.ReadCloser) { @@ -146,31 +111,14 @@ func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler { if authHeader == "" { logger.Debug("Proxy error, missing Authorization header") w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusUnauthorized) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: http.StatusUnauthorized, - Message: http.StatusText(http.StatusUnauthorized), - }) - if err != nil { - return - } + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), basicAuth.ErrorInterceptor) return } // Check if the Authorization header contains "Basic" scheme if !strings.HasPrefix(authHeader, "Basic ") { logger.Error("Proxy error, missing Basic Authorization header") - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusUnauthorized) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: http.StatusUnauthorized, - Message: http.StatusText(http.StatusUnauthorized), - }) - if err != nil { - return - } + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), basicAuth.ErrorInterceptor) + return } @@ -178,16 +126,7 @@ func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler { payload, err := base64.StdEncoding.DecodeString(authHeader[len("Basic "):]) if err != nil { logger.Debug("Proxy error, missing Basic Authorization header") - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusUnauthorized) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: http.StatusUnauthorized, - Message: http.StatusText(http.StatusUnauthorized), - }) - if err != nil { - return - } + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), basicAuth.ErrorInterceptor) return } @@ -195,16 +134,7 @@ func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler { pair := strings.SplitN(string(payload), ":", 2) if len(pair) != 2 || pair[0] != basicAuth.Username || pair[1] != basicAuth.Password { w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusUnauthorized) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: http.StatusUnauthorized, - Message: http.StatusText(http.StatusUnauthorized), - }) - if err != nil { - return - } + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), basicAuth.ErrorInterceptor) return } diff --git a/internal/middleware/rate-limit.go b/internal/middleware/rate-limit.go index 125a10c..050ae76 100644 --- a/internal/middleware/rate-limit.go +++ b/internal/middleware/rate-limit.go @@ -16,7 +16,7 @@ See the License for the specific language governing permissions and limitations under the License. */ import ( - "encoding/json" + "fmt" "github.com/gorilla/mux" "github.com/jkaninda/goma-gateway/pkg/logger" "net/http" @@ -28,20 +28,17 @@ func (rl *TokenRateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !rl.Allow() { + logger.Error("Too many requests from IP: %s %s %s", getRealIP(r), r.URL, r.UserAgent()) + //RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), basicAuth.ErrorInterceptor) + // Rate limit exceeded, return a 429 Too Many Requests response - w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusTooManyRequests) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: http.StatusTooManyRequests, - Message: "Too many requests, API rate limit exceeded. Please try again later.", - }) + _, err := w.Write([]byte(fmt.Sprintf("%d Too many requests, API rate limit exceeded. Please try again later", http.StatusTooManyRequests))) if err != nil { return } return } - // Proceed to the next handler if rate limit is not exceeded next.ServeHTTP(w, r) }) @@ -66,21 +63,12 @@ func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { rl.mu.Unlock() if client.RequestCount > rl.Requests { - logger.Debug("Too many requests from IP: %s %s %s", clientID, r.URL, r.UserAgent()) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusTooManyRequests) + logger.Error("Too many requests from IP: %s %s %s", clientID, r.URL, r.UserAgent()) //Update Origin Cors Headers if allowedOrigin(rl.Origins, r.Header.Get("Origin")) { w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) } - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: http.StatusTooManyRequests, - Message: "Too many requests, API rate limit exceeded. Please try again later.", - }) - if err != nil { - return - } + RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API rate limit exceeded. Please try again later", http.StatusTooManyRequests), rl.ErrorInterceptor) return } // Proceed to the next handler if rate limit is not exceeded diff --git a/internal/middleware/route_error_interceptor.go b/internal/middleware/route_error_interceptor.go new file mode 100644 index 0000000..ec8c1e2 --- /dev/null +++ b/internal/middleware/route_error_interceptor.go @@ -0,0 +1,58 @@ +package middleware + +/* + * Copyright 2024 Jonas Kaninda + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +import ( + errorinterceptor "github.com/jkaninda/goma-gateway/pkg/error-interceptor" + "github.com/jkaninda/goma-gateway/pkg/logger" + "io" + "net/http" +) + +// RouteErrorInterceptor contains backend status code errors to intercept +type RouteErrorInterceptor struct { + Origins []string + ErrorInterceptor errorinterceptor.ErrorInterceptor +} + +// RouteErrorInterceptor Middleware intercepts backend route errors +func (intercept RouteErrorInterceptor) RouteErrorInterceptor(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rec := newResponseRecorder(w) + next.ServeHTTP(rec, r) + if canInterceptError(rec.statusCode, intercept.ErrorInterceptor.Errors) { + logger.Debug("Backend error") + logger.Error("An error occurred from the backend with the status code: %d", rec.statusCode) + //Update Origin Cors Headers + if allowedOrigin(intercept.Origins, r.Header.Get("Origin")) { + w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) + } + RespondWithError(w, rec.statusCode, http.StatusText(rec.statusCode), intercept.ErrorInterceptor) + return + } else { + // No error: write buffered response to client + w.WriteHeader(rec.statusCode) + _, err := io.Copy(w, rec.body) + if err != nil { + return + } + return + + } + + }) +} diff --git a/internal/middleware/types.go b/internal/middleware/types.go index 54bebee..15f695b 100644 --- a/internal/middleware/types.go +++ b/internal/middleware/types.go @@ -19,6 +19,7 @@ package middleware import ( "bytes" + errorinterceptor "github.com/jkaninda/goma-gateway/pkg/error-interceptor" "net/http" "sync" "time" @@ -26,11 +27,12 @@ import ( // RateLimiter defines rate limit properties. type RateLimiter struct { - Requests int - Window time.Duration - ClientMap map[string]*Client - mu sync.Mutex - Origins []string + Requests int + Window time.Duration + ClientMap map[string]*Client + mu sync.Mutex + Origins []string + ErrorInterceptor errorinterceptor.ErrorInterceptor } // Client stores request count and window expiration for each client. @@ -67,11 +69,12 @@ type ProxyResponseError struct { // JwtAuth stores JWT configuration type JwtAuth struct { - AuthURL string - RequiredHeaders []string - Headers map[string]string - Params map[string]string - Origins []string + AuthURL string + RequiredHeaders []string + Headers map[string]string + Params map[string]string + Origins []string + ErrorInterceptor errorinterceptor.ErrorInterceptor } // AuthenticationMiddleware Define struct @@ -82,17 +85,19 @@ type AuthenticationMiddleware struct { Params map[string]string } type AccessListMiddleware struct { - Path string - Destination string - List []string + Path string + Destination string + List []string + ErrorInterceptor errorinterceptor.ErrorInterceptor } // AuthBasic contains Basic auth configuration type AuthBasic struct { - Username string - Password string - Headers map[string]string - Params map[string]string + Username string + Password string + Headers map[string]string + Params map[string]string + ErrorInterceptor errorinterceptor.ErrorInterceptor } // InterceptErrors contains backend status code errors to intercept @@ -120,10 +125,11 @@ type Oauth struct { // Scope specifies optional requested permissions. Scopes []string // contains filtered or unexported fields - State string - Origins []string - JWTSecret string - Provider string + State string + Origins []string + JWTSecret string + Provider string + ErrorInterceptor errorinterceptor.ErrorInterceptor } type OauthEndpoint struct { AuthURL string diff --git a/internal/proxy.go b/internal/proxy.go index 6478078..bc6b90e 100644 --- a/internal/proxy.go +++ b/internal/proxy.go @@ -17,6 +17,7 @@ limitations under the License. */ import ( "fmt" + "github.com/jkaninda/goma-gateway/internal/middleware" "github.com/jkaninda/goma-gateway/pkg/logger" "net/http" "net/http/httputil" @@ -36,11 +37,7 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc { if len(proxyRoute.methods) > 0 { if !slices.Contains(proxyRoute.methods, r.Method) { logger.Error("%s Method is not allowed", r.Method) - w.WriteHeader(http.StatusMethodNotAllowed) - _, err := w.Write([]byte(fmt.Sprintf("%s method is not allowed", r.Method))) - if err != nil { - return - } + middleware.RespondWithError(w, http.StatusMethodNotAllowed, fmt.Sprintf("%d %s method is not allowed", http.StatusMethodNotAllowed, r.Method), proxyRoute.ErrorInterceptor) return } } @@ -63,11 +60,7 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc { targetURL, err := url.Parse(proxyRoute.destination) if err != nil { logger.Error("Error parsing backend URL: %s", err) - w.WriteHeader(http.StatusInternalServerError) - _, err := w.Write([]byte("Internal Server Error")) - if err != nil { - return - } + middleware.RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError), proxyRoute.ErrorInterceptor) return } r.Header.Set("X-Forwarded-Host", r.Header.Get("Host")) diff --git a/internal/route.go b/internal/route.go index c735aca..a0299b9 100644 --- a/internal/route.go +++ b/internal/route.go @@ -58,7 +58,8 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { // Enable common exploits if gateway.BlockCommonExploits { logger.Info("Block common exploits enabled") - r.Use(middleware.BlockExploitsMiddleware) + blockCommon := middleware.BlockCommon{} + r.Use(blockCommon.BlockExploitsMiddleware) } if gateway.RateLimit != 0 { //rateLimiter := middleware.NewRateLimiter(gateway.RateLimit, time.Minute) @@ -219,8 +220,11 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { // Apply common exploits to the route // Enable common exploits if route.BlockCommonExploits { + blockCommon := middleware.BlockCommon{ + ErrorInterceptor: route.ErrorInterceptor, + } logger.Info("Block common exploits enabled") - router.Use(middleware.BlockExploitsMiddleware) + router.Use(blockCommon.BlockExploitsMiddleware) } // Apply route rate limit if route.RateLimit > 0 { @@ -246,6 +250,12 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { // Prometheus endpoint router.Use(pr.prometheusMiddleware) } + // Apply route Error interceptor middleware + interceptErrors := middleware.RouteErrorInterceptor{ + Origins: gateway.Cors.Origins, + ErrorInterceptor: route.ErrorInterceptor, + } + r.Use(interceptErrors.RouteErrorInterceptor) } else { logger.Error("Error, path is empty in route %s", route.Name) logger.Error("Route path ignored: %s", route.Path) diff --git a/internal/types.go b/internal/types.go index 143b1ad..b1b600f 100644 --- a/internal/types.go +++ b/internal/types.go @@ -20,6 +20,7 @@ package pkg import ( "context" "github.com/gorilla/mux" + errorinterceptor "github.com/jkaninda/goma-gateway/pkg/error-interceptor" "time" ) @@ -161,12 +162,12 @@ type Route struct { // // It will not match the backend route DisableHostFording bool `yaml:"disableHostFording"` - // InterceptErrors intercepts backend errors based on the status codes - // - // Eg: [ 403, 405, 500 ] - InterceptErrors []int `yaml:"interceptErrors"` + // BlockCommonExploits enable, disable block common exploits BlockCommonExploits bool `yaml:"blockCommonExploits"` + // ErrorInterceptor intercepts backend errors based on the status codes and custom message + // + ErrorInterceptor errorinterceptor.ErrorInterceptor `yaml:"errorInterceptor"` // Middlewares Defines route middleware from Middleware names Middlewares []string `yaml:"middlewares"` } @@ -242,6 +243,7 @@ type ProxyRoute struct { methods []string cors Cors disableHostFording bool + ErrorInterceptor errorinterceptor.ErrorInterceptor } type RoutePath struct { route Route @@ -285,3 +287,16 @@ type Health struct { Interval string HealthyStatuses []int } + +//type ErrorInterceptor struct { +// // ContentType error response content type, application/json, plain/text +// ContentType string `yaml:"contentType"` +// //Errors contains error status code and custom message +// Errors []ErrorInterceptor `yaml:"errors"` +//} +//type ErrorInterceptor struct { +// // Code HTTP status code +// Code int `yaml:"code"` +// // Message custom message +// Message string `yaml:"message"` +//} diff --git a/internal/var.go b/internal/var.go index baf9516..b57bb0d 100644 --- a/internal/var.go +++ b/internal/var.go @@ -9,6 +9,10 @@ const AccessMiddleware = "access" // access middleware const BasicAuth = "basic" // basic authentication middleware const JWTAuth = "jwt" // JWT authentication middleware const OAuth = "oauth" // OAuth authentication middleware +const applicationJson = "application/json" +const textPlain = "text/plain" +const applicationXml = "application/xml" + // Round-robin counter var counter uint32 diff --git a/pkg/error-interceptor/types.go b/pkg/error-interceptor/types.go new file mode 100644 index 0000000..95d3bf4 --- /dev/null +++ b/pkg/error-interceptor/types.go @@ -0,0 +1,31 @@ +/* + * Copyright 2024 Jonas Kaninda + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package error_interceptor + +type ErrorInterceptor struct { + // ContentType error response content type, application/json, plain/text + ContentType string `yaml:"contentType"` + //Errors contains error status code and custom message + Errors []Error `yaml:"errors"` +} +type Error struct { + // Code HTTP status code + Code int `yaml:"code"` + // Message custom message + Message string `yaml:"message"` +} diff --git a/pkg/error-interceptor/var.go b/pkg/error-interceptor/var.go new file mode 100644 index 0000000..267abe7 --- /dev/null +++ b/pkg/error-interceptor/var.go @@ -0,0 +1,22 @@ +/* + * Copyright 2024 Jonas Kaninda + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package error_interceptor + +const TextPlain = "text/plain" +const ApplicationXml = "application/xml" +const ApplicationJson = "application/json" From 3c4920ec9a0b8c6c00a5a1cff0f03d0c374da0ae Mon Sep 17 00:00:00 2001 From: Jonas Kaninda Date: Thu, 14 Nov 2024 09:49:18 +0100 Subject: [PATCH 2/6] refatcor: improve error route interceptor --- internal/config.go | 9 ++-- internal/middleware/block-common-exploits.go | 2 +- internal/middleware/helpers.go | 41 ++++++------------- .../middleware/route_error_interceptor.go | 2 +- internal/middleware/types.go | 2 +- internal/types.go | 15 +------ .../types.go | 6 +-- .../var.go | 2 +- 8 files changed, 25 insertions(+), 54 deletions(-) rename pkg/{error-interceptor => errorinterceptor}/types.go (89%) rename pkg/{error-interceptor => errorinterceptor}/var.go (96%) diff --git a/internal/config.go b/internal/config.go index 72b3896..6e9fa2c 100644 --- a/internal/config.go +++ b/internal/config.go @@ -18,7 +18,7 @@ limitations under the License. import ( "fmt" "github.com/jkaninda/goma-gateway/internal/middleware" - error_interceptor "github.com/jkaninda/goma-gateway/pkg/error-interceptor" + "github.com/jkaninda/goma-gateway/pkg/errorinterceptor" "github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/util" "golang.org/x/oauth2" @@ -185,10 +185,9 @@ func initConfig(configFile string) error { Path: "/", Name: "Hostname and load balancing example", Hosts: []string{"example.com", "example.localhost"}, - //InterceptErrors: []int{404, 405, 500}, - ErrorInterceptor: error_interceptor.ErrorInterceptor{ - ContentType: applicationJson, - Errors: []error_interceptor.Error{ + //ErrorIntercept: []int{404, 405, 500}, + ErrorInterceptor: errorinterceptor.ErrorInterceptor{ + Errors: []errorinterceptor.Error{ { Code: http.StatusUnauthorized, Message: http.StatusText(http.StatusUnauthorized), diff --git a/internal/middleware/block-common-exploits.go b/internal/middleware/block-common-exploits.go index 9d87221..18154de 100644 --- a/internal/middleware/block-common-exploits.go +++ b/internal/middleware/block-common-exploits.go @@ -19,7 +19,7 @@ package middleware import ( "fmt" - errorinterceptor "github.com/jkaninda/goma-gateway/pkg/error-interceptor" + errorinterceptor "github.com/jkaninda/goma-gateway/pkg/errorinterceptor" "github.com/jkaninda/goma-gateway/pkg/logger" "net/http" "regexp" diff --git a/internal/middleware/helpers.go b/internal/middleware/helpers.go index fcd0f35..219c13c 100644 --- a/internal/middleware/helpers.go +++ b/internal/middleware/helpers.go @@ -20,8 +20,9 @@ package middleware import ( "encoding/json" "fmt" - errorinterceptor "github.com/jkaninda/goma-gateway/pkg/error-interceptor" + "github.com/jkaninda/goma-gateway/pkg/errorinterceptor" "net/http" + "slices" ) func getRealIP(r *http.Request) string { @@ -34,14 +35,7 @@ func getRealIP(r *http.Request) string { return r.RemoteAddr } func allowedOrigin(origins []string, origin string) bool { - for _, o := range origins { - if o == origin { - return true - } - continue - } - return false - + return slices.Contains(origins, origin) } func canInterceptError(code int, errors []errorinterceptor.Error) bool { for _, er := range errors { @@ -71,25 +65,16 @@ func RespondWithError(w http.ResponseWriter, statusCode int, logMessage string, if err != nil { message = logMessage } - if errorIntercept.ContentType == errorinterceptor.ApplicationJson { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(statusCode) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: statusCode, - Message: message, - }) - if err != nil { - return - } - return - } else { - w.Header().Set("Content-Type", "plain/text;charset=utf-8") - w.WriteHeader(statusCode) - _, err2 := w.Write([]byte(message)) - if err2 != nil { - return - } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + err = json.NewEncoder(w).Encode(ProxyResponseError{ + Success: false, + Code: statusCode, + Message: message, + }) + if err != nil { return } + return + } diff --git a/internal/middleware/route_error_interceptor.go b/internal/middleware/route_error_interceptor.go index ec8c1e2..9dea9ee 100644 --- a/internal/middleware/route_error_interceptor.go +++ b/internal/middleware/route_error_interceptor.go @@ -17,7 +17,7 @@ package middleware * */ import ( - errorinterceptor "github.com/jkaninda/goma-gateway/pkg/error-interceptor" + "github.com/jkaninda/goma-gateway/pkg/errorinterceptor" "github.com/jkaninda/goma-gateway/pkg/logger" "io" "net/http" diff --git a/internal/middleware/types.go b/internal/middleware/types.go index 15f695b..d4f1a18 100644 --- a/internal/middleware/types.go +++ b/internal/middleware/types.go @@ -19,7 +19,7 @@ package middleware import ( "bytes" - errorinterceptor "github.com/jkaninda/goma-gateway/pkg/error-interceptor" + errorinterceptor "github.com/jkaninda/goma-gateway/pkg/errorinterceptor" "net/http" "sync" "time" diff --git a/internal/types.go b/internal/types.go index b1b600f..80163de 100644 --- a/internal/types.go +++ b/internal/types.go @@ -20,7 +20,7 @@ package pkg import ( "context" "github.com/gorilla/mux" - errorinterceptor "github.com/jkaninda/goma-gateway/pkg/error-interceptor" + errorinterceptor "github.com/jkaninda/goma-gateway/pkg/errorinterceptor" "time" ) @@ -287,16 +287,3 @@ type Health struct { Interval string HealthyStatuses []int } - -//type ErrorInterceptor struct { -// // ContentType error response content type, application/json, plain/text -// ContentType string `yaml:"contentType"` -// //Errors contains error status code and custom message -// Errors []ErrorInterceptor `yaml:"errors"` -//} -//type ErrorInterceptor struct { -// // Code HTTP status code -// Code int `yaml:"code"` -// // Message custom message -// Message string `yaml:"message"` -//} diff --git a/pkg/error-interceptor/types.go b/pkg/errorinterceptor/types.go similarity index 89% rename from pkg/error-interceptor/types.go rename to pkg/errorinterceptor/types.go index 95d3bf4..23b0ffc 100644 --- a/pkg/error-interceptor/types.go +++ b/pkg/errorinterceptor/types.go @@ -15,17 +15,17 @@ * */ -package error_interceptor +package errorinterceptor type ErrorInterceptor struct { // ContentType error response content type, application/json, plain/text - ContentType string `yaml:"contentType"` + //ContentType string `yaml:"contentType"` //Errors contains error status code and custom message Errors []Error `yaml:"errors"` } type Error struct { // Code HTTP status code Code int `yaml:"code"` - // Message custom message + // Message Error custom response message Message string `yaml:"message"` } diff --git a/pkg/error-interceptor/var.go b/pkg/errorinterceptor/var.go similarity index 96% rename from pkg/error-interceptor/var.go rename to pkg/errorinterceptor/var.go index 267abe7..d924846 100644 --- a/pkg/error-interceptor/var.go +++ b/pkg/errorinterceptor/var.go @@ -15,7 +15,7 @@ * */ -package error_interceptor +package errorinterceptor const TextPlain = "text/plain" const ApplicationXml = "application/xml" From a874d141941c7bfe06583cacf008b51e0a470e8b Mon Sep 17 00:00:00 2001 From: Jonas Kaninda Date: Thu, 14 Nov 2024 11:38:36 +0100 Subject: [PATCH 3/6] feat: add Redis based rate limiting for multiple instances --- go.mod | 6 ++- go.sum | 8 ++++ internal/middleware/rate-limit.go | 72 +++++++++++++++++++++++-------- internal/middleware/types.go | 12 +++--- internal/middleware/var.go | 10 +++++ internal/route.go | 10 +++-- internal/server.go | 13 ++++++ internal/types.go | 8 ++++ 8 files changed, 112 insertions(+), 27 deletions(-) diff --git a/go.mod b/go.mod index 13ce334..8f22857 100644 --- a/go.mod +++ b/go.mod @@ -4,10 +4,14 @@ go 1.23.2 require ( github.com/common-nighthawk/go-figure v0.0.0-20210622060536-734e95fb86be + github.com/go-redis/redis_rate/v10 v10.0.1 github.com/golang-jwt/jwt v3.2.2+incompatible github.com/gorilla/mux v1.8.1 github.com/prometheus/client_golang v1.20.5 + github.com/redis/go-redis/v9 v9.7.0 + github.com/robfig/cron/v3 v3.0.1 github.com/spf13/cobra v1.8.1 + golang.org/x/net v0.26.0 golang.org/x/oauth2 v0.24.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -23,13 +27,13 @@ require ( cloud.google.com/go/compute/metadata v0.3.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/klauspost/compress v1.17.9 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.55.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect - github.com/robfig/cron/v3 v3.0.1 // indirect github.com/spf13/pflag v1.0.5 // indirect google.golang.org/protobuf v1.34.2 // indirect diff --git a/go.sum b/go.sum index 5c8caf9..4126899 100644 --- a/go.sum +++ b/go.sum @@ -9,6 +9,10 @@ github.com/common-nighthawk/go-figure v0.0.0-20210622060536-734e95fb86be/go.mod github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/go-redis/redis_rate/v10 v10.0.1 h1:calPxi7tVlxojKunJwQ72kwfozdy25RjA0bCj1h0MUo= +github.com/go-redis/redis_rate/v10 v10.0.1/go.mod h1:EMiuO9+cjRkR7UvdvwMO7vbgqJkltQHtwbdIQvaBKIU= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= @@ -37,6 +41,8 @@ github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= +github.com/redis/go-redis/v9 v9.7.0 h1:HhLSs+B6O021gwzl+locl0zEDnyNkxMtf/Z3NNBMa9E= +github.com/redis/go-redis/v9 v9.7.0/go.mod h1:f6zhXITC7JUJIlPEiBOTXxJgPLdZcA93GewI7inzyWw= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= @@ -50,6 +56,8 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= +golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs= golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE= diff --git a/internal/middleware/rate-limit.go b/internal/middleware/rate-limit.go index 050ae76..ef701f4 100644 --- a/internal/middleware/rate-limit.go +++ b/internal/middleware/rate-limit.go @@ -16,9 +16,13 @@ See the License for the specific language governing permissions and limitations under the License. */ import ( + "errors" "fmt" + "github.com/go-redis/redis_rate/v10" "github.com/gorilla/mux" "github.com/jkaninda/goma-gateway/pkg/logger" + "github.com/redis/go-redis/v9" + "golang.org/x/net/context" "net/http" "time" ) @@ -49,30 +53,62 @@ func (rl *TokenRateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - clientID := getRealIP(r) - rl.mu.Lock() - client, exists := rl.ClientMap[clientID] - if !exists || time.Now().After(client.ExpiresAt) { - client = &Client{ - RequestCount: 0, - ExpiresAt: time.Now().Add(rl.Window), + clientIP := getRealIP(r) + logger.Debug("rate limiter: clientID: %s, redisBased: %s", clientIP, rl.RedisBased) + if rl.RedisBased { + err := redisRateLimiter(clientIP, rl.Requests) + if err != nil { + logger.Error("Redis Rate limiter error: %s", err.Error()) + RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API rate limit exceeded. Please try again later", http.StatusTooManyRequests), rl.ErrorInterceptor) + return } - rl.ClientMap[clientID] = client - } - client.RequestCount++ - rl.mu.Unlock() + // Proceed to the next handler if rate limit is not exceeded + next.ServeHTTP(w, r) + } else { + rl.mu.Lock() + client, exists := rl.ClientMap[clientIP] + if !exists || time.Now().After(client.ExpiresAt) { + client = &Client{ + RequestCount: 0, + ExpiresAt: time.Now().Add(rl.Window), + } + rl.ClientMap[clientIP] = client + } + client.RequestCount++ + rl.mu.Unlock() - if client.RequestCount > rl.Requests { - logger.Error("Too many requests from IP: %s %s %s", clientID, r.URL, r.UserAgent()) - //Update Origin Cors Headers - if allowedOrigin(rl.Origins, r.Header.Get("Origin")) { - w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) + if client.RequestCount > rl.Requests { + logger.Error("Too many requests from IP: %s %s %s", clientIP, r.URL, r.UserAgent()) + //Update Origin Cors Headers + if allowedOrigin(rl.Origins, r.Header.Get("Origin")) { + w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) + } + RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API rate limit exceeded. Please try again later", http.StatusTooManyRequests), rl.ErrorInterceptor) + return } - RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API rate limit exceeded. Please try again later", http.StatusTooManyRequests), rl.ErrorInterceptor) - return } // Proceed to the next handler if rate limit is not exceeded next.ServeHTTP(w, r) }) } } +func redisRateLimiter(clientIP string, rate int) error { + ctx := context.Background() + + res, err := limiter.Allow(ctx, clientIP, redis_rate.PerMinute(rate)) + if err != nil { + return err + } + if res.Remaining == 0 { + return errors.New("rate limit exceeded") + } + + return nil +} +func InitRedis(addr, password string) { + Rdb = redis.NewClient(&redis.Options{ + Addr: addr, + Password: password, + }) + limiter = redis_rate.NewLimiter(Rdb) +} diff --git a/internal/middleware/types.go b/internal/middleware/types.go index d4f1a18..f775aff 100644 --- a/internal/middleware/types.go +++ b/internal/middleware/types.go @@ -33,6 +33,7 @@ type RateLimiter struct { mu sync.Mutex Origins []string ErrorInterceptor errorinterceptor.ErrorInterceptor + RedisBased bool } // Client stores request count and window expiration for each client. @@ -42,12 +43,13 @@ type Client struct { } // NewRateLimiterWindow creates a new RateLimiter. -func NewRateLimiterWindow(requests int, window time.Duration, origin []string) *RateLimiter { +func NewRateLimiterWindow(requests int, window time.Duration, redisBased bool, origin []string) *RateLimiter { return &RateLimiter{ - Requests: requests, - Window: window, - ClientMap: make(map[string]*Client), - Origins: origin, + Requests: requests, + Window: window, + ClientMap: make(map[string]*Client), + Origins: origin, + RedisBased: redisBased, } } diff --git a/internal/middleware/var.go b/internal/middleware/var.go index 4e8502c..5eb3266 100644 --- a/internal/middleware/var.go +++ b/internal/middleware/var.go @@ -17,7 +17,17 @@ package middleware +import ( + "github.com/go-redis/redis_rate/v10" + "github.com/redis/go-redis/v9" +) + // sqlPatterns contains SQL injections patters const sqlPatterns = `(?i)(union|select|drop|insert|delete|update|create|alter|exec|;|--)` const traversalPatterns = `\.\./` const xssPatterns = `(?i) 0 { //rateLimiter := middleware.NewRateLimiter(gateway.RateLimit, time.Minute) - limiter := middleware.NewRateLimiterWindow(gateway.RateLimit, time.Minute, gateway.Cors.Origins) // requests per minute + limiter := middleware.NewRateLimiterWindow(gateway.RateLimit, time.Minute, redisBased, gateway.Cors.Origins) // requests per minute // Add rate limit middleware to all routes, if defined r.Use(limiter.RateLimitMiddleware()) } @@ -229,7 +233,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { // Apply route rate limit if route.RateLimit > 0 { //rateLimiter := middleware.NewRateLimiter(gateway.RateLimit, time.Minute) - limiter := middleware.NewRateLimiterWindow(route.RateLimit, time.Minute, route.Cors.Origins) // requests per minute + limiter := middleware.NewRateLimiterWindow(route.RateLimit, time.Minute, redisBased, route.Cors.Origins) // requests per minute // Add rate limit middleware to all routes, if defined router.Use(limiter.RateLimitMiddleware()) } diff --git a/internal/server.go b/internal/server.go index 6b43c7b..e71d2ca 100644 --- a/internal/server.go +++ b/internal/server.go @@ -19,7 +19,9 @@ import ( "context" "crypto/tls" "fmt" + "github.com/jkaninda/goma-gateway/internal/middleware" "github.com/jkaninda/goma-gateway/pkg/logger" + "github.com/redis/go-redis/v9" "net/http" "os" "sync" @@ -30,8 +32,19 @@ import ( func (gatewayServer GatewayServer) Start(ctx context.Context) error { logger.Info("Initializing routes...") route := gatewayServer.Initialize() + gateway := gatewayServer.gateway logger.Debug("Routes count=%d Middlewares count=%d", len(gatewayServer.gateway.Routes), len(gatewayServer.middlewares)) logger.Info("Initializing routes...done") + if len(gateway.Redis.Addr) != 0 { + middleware.InitRedis(gateway.Redis.Addr, gateway.Redis.Password) + defer func(Rdb *redis.Client) { + err := Rdb.Close() + if err != nil { + logger.Error("Redis connection closed with error: %v", err) + } + }(middleware.Rdb) + } + tlsConfig := &tls.Config{} var listenWithTLS = false if cert := gatewayServer.gateway.SSLCertFile; cert != "" && gatewayServer.gateway.SSLKeyFile != "" { diff --git a/internal/types.go b/internal/types.go index 80163de..5c75e85 100644 --- a/internal/types.go +++ b/internal/types.go @@ -178,6 +178,8 @@ type Gateway struct { SSLCertFile string `yaml:"sslCertFile" env:"GOMA_SSL_CERT_FILE, overwrite"` // SSLKeyFile SSL Private key file SSLKeyFile string `yaml:"sslKeyFile" env:"GOMA_SSL_KEY_FILE, overwrite"` + // Redis contains redis database details + Redis Redis `yaml:"redis"` // WriteTimeout defines proxy write timeout WriteTimeout int `yaml:"writeTimeout" env:"GOMA_WRITE_TIMEOUT, overwrite"` // ReadTimeout defines proxy read timeout @@ -204,6 +206,7 @@ type Gateway struct { InterceptErrors []int `yaml:"interceptErrors"` // Cors holds proxy global cors Cors Cors `yaml:"cors"` + // Routes holds proxy routes Routes []Route `yaml:"routes"` } @@ -287,3 +290,8 @@ type Health struct { Interval string HealthyStatuses []int } +type Redis struct { + // Addr redis hostname and post number : + Addr string `yaml:"addr"` + Password string `yaml:"password"` +} From 59516161535e5247349f99861d34323058003a18 Mon Sep 17 00:00:00 2001 From: Jonas Kaninda Date: Thu, 14 Nov 2024 13:17:28 +0100 Subject: [PATCH 4/6] feat: add Redis based rate limiting for multiple instances --- internal/middleware/access-middleware.go | 4 +- internal/middleware/block-common-exploits.go | 4 +- internal/middleware/error-interceptor.go | 12 +-- internal/middleware/helpers.go | 8 +- internal/middleware/middleware.go | 16 ++-- internal/middleware/rate-limit.go | 33 ++++---- .../middleware/route_error_interceptor.go | 58 -------------- internal/middleware/types.go | 75 ++++++++++--------- internal/proxy.go | 4 +- internal/route.go | 34 ++++++--- internal/types.go | 1 - 11 files changed, 99 insertions(+), 150 deletions(-) delete mode 100644 internal/middleware/route_error_interceptor.go diff --git a/internal/middleware/access-middleware.go b/internal/middleware/access-middleware.go index 4451c4b..3f86702 100644 --- a/internal/middleware/access-middleware.go +++ b/internal/middleware/access-middleware.go @@ -30,7 +30,7 @@ func (blockList AccessListMiddleware) AccessMiddleware(next http.Handler) http.H for _, block := range blockList.List { if isPathBlocked(r.URL.Path, util.ParseURLPath(blockList.Path+block)) { logger.Error("%s: %s access forbidden", getRealIP(r), r.URL.Path) - RespondWithError(w, http.StatusForbidden, fmt.Sprintf("%d you do not have permission to access this resource", http.StatusForbidden), blockList.ErrorInterceptor) + RespondWithError(w, http.StatusForbidden, fmt.Sprintf("%d you do not have permission to access this resource")) return } } @@ -54,7 +54,7 @@ func isPathBlocked(requestPath, blockedPath string) bool { return false } -// NewRateLimiter creates a new rate limiter with the specified refill rate and token capacity +// NewRateLimiter creates a new requests limiter with the specified refill requests and token capacity func NewRateLimiter(maxTokens int, refillRate time.Duration) *TokenRateLimiter { return &TokenRateLimiter{ tokens: maxTokens, diff --git a/internal/middleware/block-common-exploits.go b/internal/middleware/block-common-exploits.go index 18154de..62fe39c 100644 --- a/internal/middleware/block-common-exploits.go +++ b/internal/middleware/block-common-exploits.go @@ -46,7 +46,7 @@ func (blockCommon BlockCommon) BlockExploitsMiddleware(next http.Handler) http.H pathTraversalPattern.MatchString(r.URL.Path) || xssPattern.MatchString(r.URL.RawQuery) { logger.Error("%s: %s Forbidden - Potential exploit detected", getRealIP(r), r.URL.Path) - RespondWithError(w, http.StatusForbidden, fmt.Sprintf("%d Forbidden - Potential exploit detected", http.StatusForbidden), blockCommon.ErrorInterceptor) + RespondWithError(w, http.StatusForbidden, fmt.Sprintf("%d Forbidden - Potential exploit detected", http.StatusForbidden)) return } @@ -57,7 +57,7 @@ func (blockCommon BlockCommon) BlockExploitsMiddleware(next http.Handler) http.H for _, value := range values { if sqlInjectionPattern.MatchString(value) || xssPattern.MatchString(value) { logger.Error("%s: %s %s Forbidden - Potential exploit detected", getRealIP(r), r.Method, r.URL.Path) - RespondWithError(w, http.StatusForbidden, fmt.Sprintf("%d Forbidden - Potential exploit detected", http.StatusForbidden), blockCommon.ErrorInterceptor) + RespondWithError(w, http.StatusForbidden, fmt.Sprintf("%d Forbidden - Potential exploit detected", http.StatusForbidden)) return } } diff --git a/internal/middleware/error-interceptor.go b/internal/middleware/error-interceptor.go index c71d20e..b154d0b 100644 --- a/internal/middleware/error-interceptor.go +++ b/internal/middleware/error-interceptor.go @@ -18,7 +18,6 @@ package middleware */ import ( "bytes" - "encoding/json" "github.com/jkaninda/goma-gateway/pkg/logger" "io" "net/http" @@ -49,20 +48,11 @@ func (intercept InterceptErrors) ErrorInterceptor(next http.Handler) http.Handle if canIntercept(rec.statusCode, intercept.Errors) { logger.Debug("Backend error") logger.Error("An error occurred from the backend with the status code: %d", rec.statusCode) - w.Header().Set("Content-Type", "application/json") //Update Origin Cors Headers if allowedOrigin(intercept.Origins, r.Header.Get("Origin")) { w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) } - w.WriteHeader(rec.statusCode) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: rec.statusCode, - Message: http.StatusText(rec.statusCode), - }) - if err != nil { - return - } + RespondWithError(w, rec.statusCode, http.StatusText(rec.statusCode)) return } else { // No error: write buffered response to client diff --git a/internal/middleware/helpers.go b/internal/middleware/helpers.go index 219c13c..1cfc12b 100644 --- a/internal/middleware/helpers.go +++ b/internal/middleware/helpers.go @@ -60,14 +60,14 @@ func errMessage(code int, errors []errorinterceptor.Error) (string, error) { } // RespondWithError is a helper function to handle error responses with flexible content type -func RespondWithError(w http.ResponseWriter, statusCode int, logMessage string, errorIntercept errorinterceptor.ErrorInterceptor) { - message, err := errMessage(statusCode, errorIntercept.Errors) - if err != nil { +func RespondWithError(w http.ResponseWriter, statusCode int, logMessage string) { + message := http.StatusText(statusCode) + if len(logMessage) != 0 { message = logMessage } w.Header().Set("Content-Type", "application/json") w.WriteHeader(statusCode) - err = json.NewEncoder(w).Encode(ProxyResponseError{ + err := json.NewEncoder(w).Encode(ProxyResponseError{ Success: false, Code: statusCode, Message: message, diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index 35c2c67..c2e6461 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -37,7 +37,7 @@ func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler { if allowedOrigin(jwtAuth.Origins, r.Header.Get("Origin")) { w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) } - RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), jwtAuth.ErrorInterceptor) + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) return } @@ -46,14 +46,14 @@ func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler { authURL, err := url.Parse(jwtAuth.AuthURL) if err != nil { logger.Error("Error parsing auth URL: %v", err) - RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError), jwtAuth.ErrorInterceptor) + RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) return } // Create a new request for /authentication authReq, err := http.NewRequest("GET", authURL.String(), nil) if err != nil { logger.Error("Proxy error creating authentication request: %v", err) - RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError), jwtAuth.ErrorInterceptor) + RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) return } logger.Trace("JWT Auth response headers: %v", authReq.Header) @@ -73,7 +73,7 @@ func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler { if err != nil || authResp.StatusCode != http.StatusOK { logger.Debug("%s %s %s %s", r.Method, getRealIP(r), r.URL, r.UserAgent()) logger.Debug("Proxy authentication error") - RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), jwtAuth.ErrorInterceptor) + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) return } defer func(Body io.ReadCloser) { @@ -111,13 +111,13 @@ func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler { if authHeader == "" { logger.Debug("Proxy error, missing Authorization header") w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) - RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), basicAuth.ErrorInterceptor) + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) return } // Check if the Authorization header contains "Basic" scheme if !strings.HasPrefix(authHeader, "Basic ") { logger.Error("Proxy error, missing Basic Authorization header") - RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), basicAuth.ErrorInterceptor) + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) return } @@ -126,7 +126,7 @@ func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler { payload, err := base64.StdEncoding.DecodeString(authHeader[len("Basic "):]) if err != nil { logger.Debug("Proxy error, missing Basic Authorization header") - RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), basicAuth.ErrorInterceptor) + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) return } @@ -134,7 +134,7 @@ func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler { pair := strings.SplitN(string(payload), ":", 2) if len(pair) != 2 || pair[0] != basicAuth.Username || pair[1] != basicAuth.Password { w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) - RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), basicAuth.ErrorInterceptor) + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) return } diff --git a/internal/middleware/rate-limit.go b/internal/middleware/rate-limit.go index ef701f4..992a558 100644 --- a/internal/middleware/rate-limit.go +++ b/internal/middleware/rate-limit.go @@ -37,13 +37,13 @@ func (rl *TokenRateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { // Rate limit exceeded, return a 429 Too Many Requests response w.WriteHeader(http.StatusTooManyRequests) - _, err := w.Write([]byte(fmt.Sprintf("%d Too many requests, API rate limit exceeded. Please try again later", http.StatusTooManyRequests))) + _, err := w.Write([]byte(fmt.Sprintf("%d Too many requests, API requests limit exceeded. Please try again later", http.StatusTooManyRequests))) if err != nil { return } return } - // Proceed to the next handler if rate limit is not exceeded + // Proceed to the next handler if requests limit is not exceeded next.ServeHTTP(w, r) }) } @@ -54,40 +54,39 @@ func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { clientIP := getRealIP(r) - logger.Debug("rate limiter: clientID: %s, redisBased: %s", clientIP, rl.RedisBased) - if rl.RedisBased { - err := redisRateLimiter(clientIP, rl.Requests) + clientID := fmt.Sprintf("%d-%s", rl.id, clientIP) // Generate client Id, ID+ route ID + logger.Debug("requests limiter: clientIP: %s, redisBased: %s", clientIP, rl.redisBased) + if rl.redisBased { + err := redisRateLimiter(clientID, rl.requests) if err != nil { logger.Error("Redis Rate limiter error: %s", err.Error()) - RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API rate limit exceeded. Please try again later", http.StatusTooManyRequests), rl.ErrorInterceptor) + logger.Error("Too many requests from IP: %s %s %s", clientIP, r.URL, r.UserAgent()) + RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API requests limit exceeded. Please try again later", http.StatusTooManyRequests)) return } - // Proceed to the next handler if rate limit is not exceeded - next.ServeHTTP(w, r) } else { rl.mu.Lock() - client, exists := rl.ClientMap[clientIP] + client, exists := rl.clientMap[clientID] if !exists || time.Now().After(client.ExpiresAt) { client = &Client{ RequestCount: 0, - ExpiresAt: time.Now().Add(rl.Window), + ExpiresAt: time.Now().Add(rl.window), } - rl.ClientMap[clientIP] = client + rl.clientMap[clientID] = client } client.RequestCount++ rl.mu.Unlock() - if client.RequestCount > rl.Requests { + if client.RequestCount > rl.requests { logger.Error("Too many requests from IP: %s %s %s", clientIP, r.URL, r.UserAgent()) //Update Origin Cors Headers - if allowedOrigin(rl.Origins, r.Header.Get("Origin")) { + if allowedOrigin(rl.origins, r.Header.Get("Origin")) { w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) } - RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API rate limit exceeded. Please try again later", http.StatusTooManyRequests), rl.ErrorInterceptor) - return + RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API requests limit exceeded. Please try again later", http.StatusTooManyRequests)) } } - // Proceed to the next handler if rate limit is not exceeded + // Proceed to the next handler if requests limit is not exceeded next.ServeHTTP(w, r) }) } @@ -100,7 +99,7 @@ func redisRateLimiter(clientIP string, rate int) error { return err } if res.Remaining == 0 { - return errors.New("rate limit exceeded") + return errors.New("requests limit exceeded") } return nil diff --git a/internal/middleware/route_error_interceptor.go b/internal/middleware/route_error_interceptor.go deleted file mode 100644 index 9dea9ee..0000000 --- a/internal/middleware/route_error_interceptor.go +++ /dev/null @@ -1,58 +0,0 @@ -package middleware - -/* - * Copyright 2024 Jonas Kaninda - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ -import ( - "github.com/jkaninda/goma-gateway/pkg/errorinterceptor" - "github.com/jkaninda/goma-gateway/pkg/logger" - "io" - "net/http" -) - -// RouteErrorInterceptor contains backend status code errors to intercept -type RouteErrorInterceptor struct { - Origins []string - ErrorInterceptor errorinterceptor.ErrorInterceptor -} - -// RouteErrorInterceptor Middleware intercepts backend route errors -func (intercept RouteErrorInterceptor) RouteErrorInterceptor(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - rec := newResponseRecorder(w) - next.ServeHTTP(rec, r) - if canInterceptError(rec.statusCode, intercept.ErrorInterceptor.Errors) { - logger.Debug("Backend error") - logger.Error("An error occurred from the backend with the status code: %d", rec.statusCode) - //Update Origin Cors Headers - if allowedOrigin(intercept.Origins, r.Header.Get("Origin")) { - w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) - } - RespondWithError(w, rec.statusCode, http.StatusText(rec.statusCode), intercept.ErrorInterceptor) - return - } else { - // No error: write buffered response to client - w.WriteHeader(rec.statusCode) - _, err := io.Copy(w, rec.body) - if err != nil { - return - } - return - - } - - }) -} diff --git a/internal/middleware/types.go b/internal/middleware/types.go index f775aff..7d813cc 100644 --- a/internal/middleware/types.go +++ b/internal/middleware/types.go @@ -19,21 +19,21 @@ package middleware import ( "bytes" - errorinterceptor "github.com/jkaninda/goma-gateway/pkg/errorinterceptor" "net/http" "sync" "time" ) -// RateLimiter defines rate limit properties. +// RateLimiter defines requests limit properties. type RateLimiter struct { - Requests int - Window time.Duration - ClientMap map[string]*Client - mu sync.Mutex - Origins []string - ErrorInterceptor errorinterceptor.ErrorInterceptor - RedisBased bool + requests int + id int + window time.Duration + clientMap map[string]*Client + mu sync.Mutex + origins []string + hosts []string + redisBased bool } // Client stores request count and window expiration for each client. @@ -41,15 +41,24 @@ type Client struct { RequestCount int ExpiresAt time.Time } +type RateLimit struct { + Id int + Requests int + Window time.Duration + Origins []string + Hosts []string + RedisBased bool +} // NewRateLimiterWindow creates a new RateLimiter. -func NewRateLimiterWindow(requests int, window time.Duration, redisBased bool, origin []string) *RateLimiter { +func (rateLimit RateLimit) NewRateLimiterWindow() *RateLimiter { return &RateLimiter{ - Requests: requests, - Window: window, - ClientMap: make(map[string]*Client), - Origins: origin, - RedisBased: redisBased, + id: rateLimit.Id, + requests: rateLimit.Requests, + window: rateLimit.Window, + clientMap: make(map[string]*Client), + origins: rateLimit.Origins, + redisBased: rateLimit.RedisBased, } } @@ -71,12 +80,11 @@ type ProxyResponseError struct { // JwtAuth stores JWT configuration type JwtAuth struct { - AuthURL string - RequiredHeaders []string - Headers map[string]string - Params map[string]string - Origins []string - ErrorInterceptor errorinterceptor.ErrorInterceptor + AuthURL string + RequiredHeaders []string + Headers map[string]string + Params map[string]string + Origins []string } // AuthenticationMiddleware Define struct @@ -87,19 +95,17 @@ type AuthenticationMiddleware struct { Params map[string]string } type AccessListMiddleware struct { - Path string - Destination string - List []string - ErrorInterceptor errorinterceptor.ErrorInterceptor + Path string + Destination string + List []string } // AuthBasic contains Basic auth configuration type AuthBasic struct { - Username string - Password string - Headers map[string]string - Params map[string]string - ErrorInterceptor errorinterceptor.ErrorInterceptor + Username string + Password string + Headers map[string]string + Params map[string]string } // InterceptErrors contains backend status code errors to intercept @@ -127,11 +133,10 @@ type Oauth struct { // Scope specifies optional requested permissions. Scopes []string // contains filtered or unexported fields - State string - Origins []string - JWTSecret string - Provider string - ErrorInterceptor errorinterceptor.ErrorInterceptor + State string + Origins []string + JWTSecret string + Provider string } type OauthEndpoint struct { AuthURL string diff --git a/internal/proxy.go b/internal/proxy.go index bc6b90e..6d6d352 100644 --- a/internal/proxy.go +++ b/internal/proxy.go @@ -37,7 +37,7 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc { if len(proxyRoute.methods) > 0 { if !slices.Contains(proxyRoute.methods, r.Method) { logger.Error("%s Method is not allowed", r.Method) - middleware.RespondWithError(w, http.StatusMethodNotAllowed, fmt.Sprintf("%d %s method is not allowed", http.StatusMethodNotAllowed, r.Method), proxyRoute.ErrorInterceptor) + middleware.RespondWithError(w, http.StatusMethodNotAllowed, fmt.Sprintf("%d %s method is not allowed", http.StatusMethodNotAllowed, r.Method)) return } } @@ -60,7 +60,7 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc { targetURL, err := url.Parse(proxyRoute.destination) if err != nil { logger.Error("Error parsing backend URL: %s", err) - middleware.RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError), proxyRoute.ErrorInterceptor) + middleware.RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) return } r.Header.Set("X-Forwarded-Host", r.Header.Get("Host")) diff --git a/internal/route.go b/internal/route.go index cd5f050..c64c1b5 100644 --- a/internal/route.go +++ b/internal/route.go @@ -66,12 +66,20 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { r.Use(blockCommon.BlockExploitsMiddleware) } if gateway.RateLimit > 0 { - //rateLimiter := middleware.NewRateLimiter(gateway.RateLimit, time.Minute) - limiter := middleware.NewRateLimiterWindow(gateway.RateLimit, time.Minute, redisBased, gateway.Cors.Origins) // requests per minute // Add rate limit middleware to all routes, if defined + rateLimit := middleware.RateLimit{ + Id: 1, + Requests: gateway.RateLimit, + Window: time.Minute, // requests per minute + Origins: gateway.Cors.Origins, + Hosts: []string{}, + RedisBased: redisBased, + } + limiter := rateLimit.NewRateLimiterWindow() + // Add rate limit middleware r.Use(limiter.RateLimitMiddleware()) } - for _, route := range gateway.Routes { + for rIndex, route := range gateway.Routes { if route.Path != "" { if route.Destination == "" && len(route.Backends) == 0 { logger.Fatal("Route %s : destination or backends should not be empty", route.Name) @@ -232,9 +240,16 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { } // Apply route rate limit if route.RateLimit > 0 { - //rateLimiter := middleware.NewRateLimiter(gateway.RateLimit, time.Minute) - limiter := middleware.NewRateLimiterWindow(route.RateLimit, time.Minute, redisBased, route.Cors.Origins) // requests per minute - // Add rate limit middleware to all routes, if defined + rateLimit := middleware.RateLimit{ + Id: rIndex, + Requests: route.RateLimit, + Window: time.Minute, // requests per minute + Origins: route.Cors.Origins, + Hosts: route.Hosts, + RedisBased: redisBased, + } + limiter := rateLimit.NewRateLimiterWindow() + // Add rate limit middleware router.Use(limiter.RateLimitMiddleware()) } // Apply route Cors @@ -255,11 +270,10 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { router.Use(pr.prometheusMiddleware) } // Apply route Error interceptor middleware - interceptErrors := middleware.RouteErrorInterceptor{ - Origins: gateway.Cors.Origins, - ErrorInterceptor: route.ErrorInterceptor, + interceptErrors := middleware.InterceptErrors{ + Origins: gateway.Cors.Origins, } - r.Use(interceptErrors.RouteErrorInterceptor) + router.Use(interceptErrors.ErrorInterceptor) } else { logger.Error("Error, path is empty in route %s", route.Name) logger.Error("Route path ignored: %s", route.Path) diff --git a/internal/types.go b/internal/types.go index 5c75e85..c3c85e2 100644 --- a/internal/types.go +++ b/internal/types.go @@ -246,7 +246,6 @@ type ProxyRoute struct { methods []string cors Cors disableHostFording bool - ErrorInterceptor errorinterceptor.ErrorInterceptor } type RoutePath struct { route Route From 949667cc6057e30fe9a71469ff20d98f6ff34281 Mon Sep 17 00:00:00 2001 From: Jonas Kaninda Date: Thu, 14 Nov 2024 14:41:10 +0100 Subject: [PATCH 5/6] fix: backend error interceptor --- internal/middleware/access-middleware.go | 2 +- internal/middleware/block-common-exploits.go | 7 +------ internal/middleware/error-interceptor.go | 12 ++++-------- internal/middleware/helpers.go | 1 - internal/middleware/types.go | 4 ++-- internal/route.go | 12 ++++-------- 6 files changed, 12 insertions(+), 26 deletions(-) diff --git a/internal/middleware/access-middleware.go b/internal/middleware/access-middleware.go index 3f86702..11ddd5f 100644 --- a/internal/middleware/access-middleware.go +++ b/internal/middleware/access-middleware.go @@ -30,7 +30,7 @@ func (blockList AccessListMiddleware) AccessMiddleware(next http.Handler) http.H for _, block := range blockList.List { if isPathBlocked(r.URL.Path, util.ParseURLPath(blockList.Path+block)) { logger.Error("%s: %s access forbidden", getRealIP(r), r.URL.Path) - RespondWithError(w, http.StatusForbidden, fmt.Sprintf("%d you do not have permission to access this resource")) + RespondWithError(w, http.StatusForbidden, fmt.Sprintf("%d you do not have permission to access this resource", http.StatusForbidden)) return } } diff --git a/internal/middleware/block-common-exploits.go b/internal/middleware/block-common-exploits.go index 62fe39c..b3f8f2b 100644 --- a/internal/middleware/block-common-exploits.go +++ b/internal/middleware/block-common-exploits.go @@ -19,18 +19,13 @@ package middleware import ( "fmt" - errorinterceptor "github.com/jkaninda/goma-gateway/pkg/errorinterceptor" "github.com/jkaninda/goma-gateway/pkg/logger" "net/http" "regexp" ) -type BlockCommon struct { - ErrorInterceptor errorinterceptor.ErrorInterceptor -} - // BlockExploitsMiddleware Middleware to block common exploits -func (blockCommon BlockCommon) BlockExploitsMiddleware(next http.Handler) http.Handler { +func BlockExploitsMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Patterns to detect SQL injection attempts sqlInjectionPattern := regexp.MustCompile(sqlPatterns) diff --git a/internal/middleware/error-interceptor.go b/internal/middleware/error-interceptor.go index b154d0b..9c25127 100644 --- a/internal/middleware/error-interceptor.go +++ b/internal/middleware/error-interceptor.go @@ -45,15 +45,12 @@ func (intercept InterceptErrors) ErrorInterceptor(next http.Handler) http.Handle return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { rec := newResponseRecorder(w) next.ServeHTTP(rec, r) + w.Header().Set("Proxied-By", "Goma Gateway") + w.Header().Del("Server") //Delete server name if canIntercept(rec.statusCode, intercept.Errors) { - logger.Debug("Backend error") - logger.Error("An error occurred from the backend with the status code: %d", rec.statusCode) - //Update Origin Cors Headers - if allowedOrigin(intercept.Origins, r.Header.Get("Origin")) { - w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) - } + logger.Debug("An error occurred in the backend, %d", rec.statusCode) + logger.Error("Backend error: %d", rec.statusCode) RespondWithError(w, rec.statusCode, http.StatusText(rec.statusCode)) - return } else { // No error: write buffered response to client w.WriteHeader(rec.statusCode) @@ -61,7 +58,6 @@ func (intercept InterceptErrors) ErrorInterceptor(next http.Handler) http.Handle if err != nil { return } - return } diff --git a/internal/middleware/helpers.go b/internal/middleware/helpers.go index 1cfc12b..1594165 100644 --- a/internal/middleware/helpers.go +++ b/internal/middleware/helpers.go @@ -75,6 +75,5 @@ func RespondWithError(w http.ResponseWriter, statusCode int, logMessage string) if err != nil { return } - return } diff --git a/internal/middleware/types.go b/internal/middleware/types.go index 7d813cc..ad3d613 100644 --- a/internal/middleware/types.go +++ b/internal/middleware/types.go @@ -27,7 +27,7 @@ import ( // RateLimiter defines requests limit properties. type RateLimiter struct { requests int - id int + id string window time.Duration clientMap map[string]*Client mu sync.Mutex @@ -42,7 +42,7 @@ type Client struct { ExpiresAt time.Time } type RateLimit struct { - Id int + Id string Requests int Window time.Duration Origins []string diff --git a/internal/route.go b/internal/route.go index c64c1b5..a6fcbad 100644 --- a/internal/route.go +++ b/internal/route.go @@ -62,13 +62,12 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { // Enable common exploits if gateway.BlockCommonExploits { logger.Info("Block common exploits enabled") - blockCommon := middleware.BlockCommon{} - r.Use(blockCommon.BlockExploitsMiddleware) + r.Use(middleware.BlockExploitsMiddleware) } if gateway.RateLimit > 0 { // Add rate limit middleware to all routes, if defined rateLimit := middleware.RateLimit{ - Id: 1, + Id: "global_rate", //Generate a unique ID for routes Requests: gateway.RateLimit, Window: time.Minute, // requests per minute Origins: gateway.Cors.Origins, @@ -232,16 +231,13 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { // Apply common exploits to the route // Enable common exploits if route.BlockCommonExploits { - blockCommon := middleware.BlockCommon{ - ErrorInterceptor: route.ErrorInterceptor, - } logger.Info("Block common exploits enabled") - router.Use(blockCommon.BlockExploitsMiddleware) + router.Use(middleware.BlockExploitsMiddleware) } // Apply route rate limit if route.RateLimit > 0 { rateLimit := middleware.RateLimit{ - Id: rIndex, + Id: string(rune(rIndex)), // Use route index as ID Requests: route.RateLimit, Window: time.Minute, // requests per minute Origins: route.Cors.Origins, From 2fd0159eb4148e54d75ded4c2f53b390a86fe66e Mon Sep 17 00:00:00 2001 From: Jonas Kaninda Date: Thu, 14 Nov 2024 14:46:18 +0100 Subject: [PATCH 6/6] fix go test --- internal/middleware/rate-limit.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/middleware/rate-limit.go b/internal/middleware/rate-limit.go index 992a558..2d2942a 100644 --- a/internal/middleware/rate-limit.go +++ b/internal/middleware/rate-limit.go @@ -54,8 +54,8 @@ func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { clientIP := getRealIP(r) - clientID := fmt.Sprintf("%d-%s", rl.id, clientIP) // Generate client Id, ID+ route ID - logger.Debug("requests limiter: clientIP: %s, redisBased: %s", clientIP, rl.redisBased) + clientID := fmt.Sprintf("%s-%s", rl.id, clientIP) // Generate client Id, ID+ route ID + logger.Debug("requests limiter: clientIP: %s, clientID: %s", clientIP, clientID) if rl.redisBased { err := redisRateLimiter(clientID, rl.requests) if err != nil {