From 6258b07c8236842c9f00d9e4caa6f7c7d7504929 Mon Sep 17 00:00:00 2001 From: Jonas Kaninda Date: Sun, 24 Nov 2024 15:59:47 +0100 Subject: [PATCH] refacor: improvement of rate limiting --- internal/config.go | 19 +++++++++++++++++++ internal/middleware.go | 1 + internal/middlewares/rate_limit.go | 8 ++++++-- internal/middlewares/redis.go | 9 ++++++--- internal/middlewares/types.go | 6 +++--- internal/routes.go | 27 +++++++++++++-------------- internal/types.go | 10 ++++------ internal/var.go | 7 +++++-- 8 files changed, 57 insertions(+), 30 deletions(-) diff --git a/internal/config.go b/internal/config.go index 73aaf37..6a88905 100644 --- a/internal/config.go +++ b/internal/config.go @@ -226,6 +226,25 @@ func (Gateway) Setup(conf string) *Gateway { } +// rateLimitMiddleware returns RateLimitRuleMiddleware, error +func rateLimitMiddleware(input interface{}) (RateLimitRuleMiddleware, error) { + rateLimit := new(RateLimitRuleMiddleware) + var bytes []byte + bytes, err := yaml.Marshal(input) + if err != nil { + return RateLimitRuleMiddleware{}, fmt.Errorf("error parsing yaml: %v", err) + } + err = yaml.Unmarshal(bytes, rateLimit) + if err != nil { + return RateLimitRuleMiddleware{}, fmt.Errorf("error parsing yaml: %v", err) + } + if rateLimit.RequestsPerUnit == 0 { + return RateLimitRuleMiddleware{}, fmt.Errorf("requests per unit not defined") + + } + return *rateLimit, nil +} + // getJWTMiddleware returns JWTRuleMiddleware,error func getJWTMiddleware(input interface{}) (JWTRuleMiddleware, error) { jWTRuler := new(JWTRuleMiddleware) diff --git a/internal/middleware.go b/internal/middleware.go index 6f519bc..ddc55a8 100644 --- a/internal/middleware.go +++ b/internal/middleware.go @@ -22,6 +22,7 @@ func getMiddleware(rules []string, middlewares []Middleware) (Middleware, error) func doesExist(tyName string) bool { middlewareList := []string{BasicAuth, JWTAuth, AccessMiddleware} + middlewareList = append(middlewareList, RateLimitMiddleware...) return slices.Contains(middlewareList, tyName) } func GetMiddleware(rule string, middlewares []Middleware) (Middleware, error) { diff --git a/internal/middlewares/rate_limit.go b/internal/middlewares/rate_limit.go index e63cb04..f2f3107 100644 --- a/internal/middlewares/rate_limit.go +++ b/internal/middlewares/rate_limit.go @@ -45,13 +45,17 @@ func (rl *TokenRateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { // RateLimitMiddleware limits request based on the number of requests peer minutes. func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { + window := time.Minute // requests per minute + if len(rl.unit) != 0 && rl.unit == "hour" { + window = time.Hour + } return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { clientIP := getRealIP(r) clientID := fmt.Sprintf("%s-%s", rl.id, clientIP) // Generate client Id, ID+ route ID logger.Debug("requests limiter: clientIP: %s, clientID: %s", clientIP, clientID) if rl.redisBased { - err := redisRateLimiter(clientID, rl.requests) + err := redisRateLimiter(clientID, rl.unit, rl.requests) if err != nil { logger.Error("Redis Rate limiter error: %s", err.Error()) logger.Error("Too many requests from IP: %s %s %s", clientIP, r.URL, r.UserAgent()) @@ -64,7 +68,7 @@ func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { if !exists || time.Now().After(client.ExpiresAt) { client = &Client{ RequestCount: 0, - ExpiresAt: time.Now().Add(rl.window), + ExpiresAt: time.Now().Add(window), } rl.clientMap[clientID] = client } diff --git a/internal/middlewares/redis.go b/internal/middlewares/redis.go index 1f199d2..eb57867 100644 --- a/internal/middlewares/redis.go +++ b/internal/middlewares/redis.go @@ -25,10 +25,13 @@ import ( ) // redisRateLimiter, handle rateLimit -func redisRateLimiter(clientIP string, rate int) error { +func redisRateLimiter(clientIP, unit string, rate int) error { + limit := redis_rate.PerMinute(rate) + if len(unit) != 0 && unit == "hour" { + limit = redis_rate.PerHour(rate) + } ctx := context.Background() - - res, err := limiter.Allow(ctx, clientIP, redis_rate.PerMinute(rate)) + res, err := limiter.Allow(ctx, clientIP, limit) if err != nil { return err } diff --git a/internal/middlewares/types.go b/internal/middlewares/types.go index 49c9112..bedf131 100644 --- a/internal/middlewares/types.go +++ b/internal/middlewares/types.go @@ -27,8 +27,8 @@ import ( // RateLimiter defines requests limit properties. type RateLimiter struct { requests int + unit string id string - window time.Duration clientMap map[string]*Client mu sync.Mutex origins []string @@ -42,8 +42,8 @@ type Client struct { } type RateLimit struct { Id string + Unit string Requests int - Window time.Duration Origins []string Hosts []string RedisBased bool @@ -53,8 +53,8 @@ type RateLimit struct { func (rateLimit RateLimit) NewRateLimiterWindow() *RateLimiter { return &RateLimiter{ id: rateLimit.Id, + unit: rateLimit.Unit, requests: rateLimit.Requests, - window: rateLimit.Window, clientMap: make(map[string]*Client), origins: rateLimit.Origins, redisBased: rateLimit.RedisBased, diff --git a/internal/routes.go b/internal/routes.go index 592acde..6dd0ee9 100644 --- a/internal/routes.go +++ b/internal/routes.go @@ -23,7 +23,6 @@ import ( "github.com/jkaninda/goma-gateway/util" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" - "time" ) // init initializes prometheus metrics @@ -62,7 +61,6 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { logger.Fatal("Error: %v", err) } m := dynamicMiddlewares - redisBased := false if len(gateway.Redis.Addr) != 0 { redisBased = true } @@ -97,8 +95,8 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { // Add rate limit middlewares to all routes, if defined rateLimit := middlewares.RateLimit{ Id: "global_rate", // Generate a unique ID for routes + Unit: "minute", Requests: gateway.RateLimit, - Window: time.Minute, // requests per minute Origins: gateway.Cors.Origins, Hosts: []string{}, RedisBased: redisBased, @@ -116,7 +114,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { } // Apply middlewares to the route for _, middleware := range route.Middlewares { - if middleware != "" { + if len(middleware) != 0 { // Get Access middlewares if it does exist accessMiddleware, err := getMiddleware([]string{middleware}, m) if err != nil { @@ -172,9 +170,9 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { // Apply route rate limit if route.RateLimit != 0 { rateLimit := middlewares.RateLimit{ + Unit: "minute", Id: id, // Use route index as ID Requests: route.RateLimit, - Window: time.Minute, // requests per minute Origins: route.Cors.Origins, Hosts: route.Hosts, RedisBased: redisBased, @@ -212,16 +210,17 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { logger.Error("Error, path is empty in route %s", route.Name) logger.Error("Route path ignored: %s", route.Path) } - } - // Apply global Cors middlewares - r.Use(CORSHandler(gateway.Cors)) // Apply CORS middlewares - // Apply errorInterceptor middlewares - if len(gateway.InterceptErrors) != 0 { - interceptErrors := middlewares.InterceptErrors{ - Errors: gateway.InterceptErrors, - Origins: gateway.Cors.Origins, + + // Apply global Cors middlewares + r.Use(CORSHandler(gateway.Cors)) // Apply CORS middlewares + // Apply errorInterceptor middlewares + if len(gateway.InterceptErrors) != 0 { + interceptErrors := middlewares.InterceptErrors{ + Errors: gateway.InterceptErrors, + Origins: gateway.Cors.Origins, + } + r.Use(interceptErrors.ErrorInterceptor) } - r.Use(interceptErrors.ErrorInterceptor) } return r diff --git a/internal/types.go b/internal/types.go index 8f05891..2da2283 100644 --- a/internal/types.go +++ b/internal/types.go @@ -80,13 +80,11 @@ type OauthEndpoint struct { TokenURL string `yaml:"tokenUrl"` UserInfoURL string `yaml:"userInfoUrl"` } -type RateLimiter struct { - // ipBased, tokenBased - Type string `yaml:"type"` - Rate float64 `yaml:"rate"` - Rule int `yaml:"rule"` -} +type RateLimitRuleMiddleware struct { + Unit string `yaml:"unit"` + RequestsPerUnit int `yaml:"requestsPerUnit"` +} type AccessRuleMiddleware struct { ResponseCode int `yaml:"responseCode"` // HTTP Response code } diff --git a/internal/var.go b/internal/var.go index 7698733..273dd6e 100644 --- a/internal/var.go +++ b/internal/var.go @@ -9,10 +9,13 @@ const AccessMiddleware = "access" // access middlewares const BasicAuth = "basic" // basic authentication middlewares const JWTAuth = "jwt" // JWT authentication middlewares const OAuth = "oauth" // OAuth authentication middlewares + var ( // Round-robin counter counter uint32 // dynamicRoutes routes - dynamicRoutes []Route - dynamicMiddlewares []Middleware + dynamicRoutes []Route + dynamicMiddlewares []Middleware + RateLimitMiddleware = []string{"ratelimit", "rateLimit"} // Rate Limit middlewares + redisBased = false )