From a874d141941c7bfe06583cacf008b51e0a470e8b Mon Sep 17 00:00:00 2001 From: Jonas Kaninda Date: Thu, 14 Nov 2024 11:38:36 +0100 Subject: [PATCH] feat: add Redis based rate limiting for multiple instances --- go.mod | 6 ++- go.sum | 8 ++++ internal/middleware/rate-limit.go | 72 +++++++++++++++++++++++-------- internal/middleware/types.go | 12 +++--- internal/middleware/var.go | 10 +++++ internal/route.go | 10 +++-- internal/server.go | 13 ++++++ internal/types.go | 8 ++++ 8 files changed, 112 insertions(+), 27 deletions(-) diff --git a/go.mod b/go.mod index 13ce334..8f22857 100644 --- a/go.mod +++ b/go.mod @@ -4,10 +4,14 @@ go 1.23.2 require ( github.com/common-nighthawk/go-figure v0.0.0-20210622060536-734e95fb86be + github.com/go-redis/redis_rate/v10 v10.0.1 github.com/golang-jwt/jwt v3.2.2+incompatible github.com/gorilla/mux v1.8.1 github.com/prometheus/client_golang v1.20.5 + github.com/redis/go-redis/v9 v9.7.0 + github.com/robfig/cron/v3 v3.0.1 github.com/spf13/cobra v1.8.1 + golang.org/x/net v0.26.0 golang.org/x/oauth2 v0.24.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -23,13 +27,13 @@ require ( cloud.google.com/go/compute/metadata v0.3.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/klauspost/compress v1.17.9 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.55.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect - github.com/robfig/cron/v3 v3.0.1 // indirect github.com/spf13/pflag v1.0.5 // indirect google.golang.org/protobuf v1.34.2 // indirect diff --git a/go.sum b/go.sum index 5c8caf9..4126899 100644 --- a/go.sum +++ b/go.sum @@ -9,6 +9,10 @@ github.com/common-nighthawk/go-figure v0.0.0-20210622060536-734e95fb86be/go.mod github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/go-redis/redis_rate/v10 v10.0.1 h1:calPxi7tVlxojKunJwQ72kwfozdy25RjA0bCj1h0MUo= +github.com/go-redis/redis_rate/v10 v10.0.1/go.mod h1:EMiuO9+cjRkR7UvdvwMO7vbgqJkltQHtwbdIQvaBKIU= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= @@ -37,6 +41,8 @@ github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= +github.com/redis/go-redis/v9 v9.7.0 h1:HhLSs+B6O021gwzl+locl0zEDnyNkxMtf/Z3NNBMa9E= +github.com/redis/go-redis/v9 v9.7.0/go.mod h1:f6zhXITC7JUJIlPEiBOTXxJgPLdZcA93GewI7inzyWw= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= @@ -50,6 +56,8 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= +golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs= golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE= diff --git a/internal/middleware/rate-limit.go b/internal/middleware/rate-limit.go index 050ae76..ef701f4 100644 --- a/internal/middleware/rate-limit.go +++ b/internal/middleware/rate-limit.go @@ -16,9 +16,13 @@ See the License for the specific language governing permissions and limitations under the License. */ import ( + "errors" "fmt" + "github.com/go-redis/redis_rate/v10" "github.com/gorilla/mux" "github.com/jkaninda/goma-gateway/pkg/logger" + "github.com/redis/go-redis/v9" + "golang.org/x/net/context" "net/http" "time" ) @@ -49,30 +53,62 @@ func (rl *TokenRateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - clientID := getRealIP(r) - rl.mu.Lock() - client, exists := rl.ClientMap[clientID] - if !exists || time.Now().After(client.ExpiresAt) { - client = &Client{ - RequestCount: 0, - ExpiresAt: time.Now().Add(rl.Window), + clientIP := getRealIP(r) + logger.Debug("rate limiter: clientID: %s, redisBased: %s", clientIP, rl.RedisBased) + if rl.RedisBased { + err := redisRateLimiter(clientIP, rl.Requests) + if err != nil { + logger.Error("Redis Rate limiter error: %s", err.Error()) + RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API rate limit exceeded. Please try again later", http.StatusTooManyRequests), rl.ErrorInterceptor) + return } - rl.ClientMap[clientID] = client - } - client.RequestCount++ - rl.mu.Unlock() + // Proceed to the next handler if rate limit is not exceeded + next.ServeHTTP(w, r) + } else { + rl.mu.Lock() + client, exists := rl.ClientMap[clientIP] + if !exists || time.Now().After(client.ExpiresAt) { + client = &Client{ + RequestCount: 0, + ExpiresAt: time.Now().Add(rl.Window), + } + rl.ClientMap[clientIP] = client + } + client.RequestCount++ + rl.mu.Unlock() - if client.RequestCount > rl.Requests { - logger.Error("Too many requests from IP: %s %s %s", clientID, r.URL, r.UserAgent()) - //Update Origin Cors Headers - if allowedOrigin(rl.Origins, r.Header.Get("Origin")) { - w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) + if client.RequestCount > rl.Requests { + logger.Error("Too many requests from IP: %s %s %s", clientIP, r.URL, r.UserAgent()) + //Update Origin Cors Headers + if allowedOrigin(rl.Origins, r.Header.Get("Origin")) { + w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) + } + RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API rate limit exceeded. Please try again later", http.StatusTooManyRequests), rl.ErrorInterceptor) + return } - RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API rate limit exceeded. Please try again later", http.StatusTooManyRequests), rl.ErrorInterceptor) - return } // Proceed to the next handler if rate limit is not exceeded next.ServeHTTP(w, r) }) } } +func redisRateLimiter(clientIP string, rate int) error { + ctx := context.Background() + + res, err := limiter.Allow(ctx, clientIP, redis_rate.PerMinute(rate)) + if err != nil { + return err + } + if res.Remaining == 0 { + return errors.New("rate limit exceeded") + } + + return nil +} +func InitRedis(addr, password string) { + Rdb = redis.NewClient(&redis.Options{ + Addr: addr, + Password: password, + }) + limiter = redis_rate.NewLimiter(Rdb) +} diff --git a/internal/middleware/types.go b/internal/middleware/types.go index d4f1a18..f775aff 100644 --- a/internal/middleware/types.go +++ b/internal/middleware/types.go @@ -33,6 +33,7 @@ type RateLimiter struct { mu sync.Mutex Origins []string ErrorInterceptor errorinterceptor.ErrorInterceptor + RedisBased bool } // Client stores request count and window expiration for each client. @@ -42,12 +43,13 @@ type Client struct { } // NewRateLimiterWindow creates a new RateLimiter. -func NewRateLimiterWindow(requests int, window time.Duration, origin []string) *RateLimiter { +func NewRateLimiterWindow(requests int, window time.Duration, redisBased bool, origin []string) *RateLimiter { return &RateLimiter{ - Requests: requests, - Window: window, - ClientMap: make(map[string]*Client), - Origins: origin, + Requests: requests, + Window: window, + ClientMap: make(map[string]*Client), + Origins: origin, + RedisBased: redisBased, } } diff --git a/internal/middleware/var.go b/internal/middleware/var.go index 4e8502c..5eb3266 100644 --- a/internal/middleware/var.go +++ b/internal/middleware/var.go @@ -17,7 +17,17 @@ package middleware +import ( + "github.com/go-redis/redis_rate/v10" + "github.com/redis/go-redis/v9" +) + // sqlPatterns contains SQL injections patters const sqlPatterns = `(?i)(union|select|drop|insert|delete|update|create|alter|exec|;|--)` const traversalPatterns = `\.\./` const xssPatterns = `(?i) 0 { //rateLimiter := middleware.NewRateLimiter(gateway.RateLimit, time.Minute) - limiter := middleware.NewRateLimiterWindow(gateway.RateLimit, time.Minute, gateway.Cors.Origins) // requests per minute + limiter := middleware.NewRateLimiterWindow(gateway.RateLimit, time.Minute, redisBased, gateway.Cors.Origins) // requests per minute // Add rate limit middleware to all routes, if defined r.Use(limiter.RateLimitMiddleware()) } @@ -229,7 +233,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { // Apply route rate limit if route.RateLimit > 0 { //rateLimiter := middleware.NewRateLimiter(gateway.RateLimit, time.Minute) - limiter := middleware.NewRateLimiterWindow(route.RateLimit, time.Minute, route.Cors.Origins) // requests per minute + limiter := middleware.NewRateLimiterWindow(route.RateLimit, time.Minute, redisBased, route.Cors.Origins) // requests per minute // Add rate limit middleware to all routes, if defined router.Use(limiter.RateLimitMiddleware()) } diff --git a/internal/server.go b/internal/server.go index 6b43c7b..e71d2ca 100644 --- a/internal/server.go +++ b/internal/server.go @@ -19,7 +19,9 @@ import ( "context" "crypto/tls" "fmt" + "github.com/jkaninda/goma-gateway/internal/middleware" "github.com/jkaninda/goma-gateway/pkg/logger" + "github.com/redis/go-redis/v9" "net/http" "os" "sync" @@ -30,8 +32,19 @@ import ( func (gatewayServer GatewayServer) Start(ctx context.Context) error { logger.Info("Initializing routes...") route := gatewayServer.Initialize() + gateway := gatewayServer.gateway logger.Debug("Routes count=%d Middlewares count=%d", len(gatewayServer.gateway.Routes), len(gatewayServer.middlewares)) logger.Info("Initializing routes...done") + if len(gateway.Redis.Addr) != 0 { + middleware.InitRedis(gateway.Redis.Addr, gateway.Redis.Password) + defer func(Rdb *redis.Client) { + err := Rdb.Close() + if err != nil { + logger.Error("Redis connection closed with error: %v", err) + } + }(middleware.Rdb) + } + tlsConfig := &tls.Config{} var listenWithTLS = false if cert := gatewayServer.gateway.SSLCertFile; cert != "" && gatewayServer.gateway.SSLKeyFile != "" { diff --git a/internal/types.go b/internal/types.go index 80163de..5c75e85 100644 --- a/internal/types.go +++ b/internal/types.go @@ -178,6 +178,8 @@ type Gateway struct { SSLCertFile string `yaml:"sslCertFile" env:"GOMA_SSL_CERT_FILE, overwrite"` // SSLKeyFile SSL Private key file SSLKeyFile string `yaml:"sslKeyFile" env:"GOMA_SSL_KEY_FILE, overwrite"` + // Redis contains redis database details + Redis Redis `yaml:"redis"` // WriteTimeout defines proxy write timeout WriteTimeout int `yaml:"writeTimeout" env:"GOMA_WRITE_TIMEOUT, overwrite"` // ReadTimeout defines proxy read timeout @@ -204,6 +206,7 @@ type Gateway struct { InterceptErrors []int `yaml:"interceptErrors"` // Cors holds proxy global cors Cors Cors `yaml:"cors"` + // Routes holds proxy routes Routes []Route `yaml:"routes"` } @@ -287,3 +290,8 @@ type Health struct { Interval string HealthyStatuses []int } +type Redis struct { + // Addr redis hostname and post number : + Addr string `yaml:"addr"` + Password string `yaml:"password"` +}