diff --git a/docs/middleware/rate-limit.md b/docs/middleware/rate-limit.md index 5bc3f67..4e368f1 100644 --- a/docs/middleware/rate-limit.md +++ b/docs/middleware/rate-limit.md @@ -10,13 +10,35 @@ nav_order: 6 The RateLimit middleware ensures that services will receive a fair number of requests, and allows one to define what fair is. -Example of global rateLimit middleware +Example of rate limiting middleware + +```yaml +middlewares: + - name: rate-limit + type: ratelimit #or rateLimit + paths: + - /* + rule: + unit: minute # or hour + requestsPerUnit: 10 +``` + +Example of route rate limiting middleware + +```yaml +version: 0.1.7 +gateway: + routes: + - name: Example + rateLimit: 60 # peer minute +``` + +Example of global rate limiting middleware ```yaml version: 0.1.7 gateway: - # Proxy rate limit, it's In-Memory IP based rateLimit: 60 # peer minute routes: - name: Example -``` +``` \ No newline at end of file diff --git a/internal/config.go b/internal/config.go index 73aaf37..6a88905 100644 --- a/internal/config.go +++ b/internal/config.go @@ -226,6 +226,25 @@ func (Gateway) Setup(conf string) *Gateway { } +// rateLimitMiddleware returns RateLimitRuleMiddleware, error +func rateLimitMiddleware(input interface{}) (RateLimitRuleMiddleware, error) { + rateLimit := new(RateLimitRuleMiddleware) + var bytes []byte + bytes, err := yaml.Marshal(input) + if err != nil { + return RateLimitRuleMiddleware{}, fmt.Errorf("error parsing yaml: %v", err) + } + err = yaml.Unmarshal(bytes, rateLimit) + if err != nil { + return RateLimitRuleMiddleware{}, fmt.Errorf("error parsing yaml: %v", err) + } + if rateLimit.RequestsPerUnit == 0 { + return RateLimitRuleMiddleware{}, fmt.Errorf("requests per unit not defined") + + } + return *rateLimit, nil +} + // getJWTMiddleware returns JWTRuleMiddleware,error func getJWTMiddleware(input interface{}) (JWTRuleMiddleware, error) { jWTRuler := new(JWTRuleMiddleware) diff --git a/internal/middleware.go b/internal/middleware.go index 6f519bc..ddc55a8 100644 --- a/internal/middleware.go +++ b/internal/middleware.go @@ -22,6 +22,7 @@ func getMiddleware(rules []string, middlewares []Middleware) (Middleware, error) func doesExist(tyName string) bool { middlewareList := []string{BasicAuth, JWTAuth, AccessMiddleware} + middlewareList = append(middlewareList, RateLimitMiddleware...) return slices.Contains(middlewareList, tyName) } func GetMiddleware(rule string, middlewares []Middleware) (Middleware, error) { diff --git a/internal/middlewares/access_middleware.go b/internal/middlewares/access_middleware.go index ca26295..d6132e2 100644 --- a/internal/middlewares/access_middleware.go +++ b/internal/middlewares/access_middleware.go @@ -53,6 +53,12 @@ func isPathBlocked(requestPath, blockedPath string) bool { } return false } +func isProtectedPath(urlPath, prefix string, paths []string) bool { + for _, path := range paths { + return isPathBlocked(urlPath, util.ParseURLPath(prefix+path)) + } + return false +} // NewRateLimiter creates a new requests limiter with the specified refill requests and token capacity func NewRateLimiter(maxTokens int, refillRate time.Duration) *TokenRateLimiter { diff --git a/internal/middlewares/middleware.go b/internal/middlewares/middleware.go index f8fc1b8..5cb66e0 100644 --- a/internal/middlewares/middleware.go +++ b/internal/middlewares/middleware.go @@ -29,112 +29,115 @@ import ( // authorization based on the result of backend's response and continue the request when the client is authorized func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - for _, header := range jwtAuth.RequiredHeaders { - if r.Header.Get(header) == "" { - logger.Error("Proxy error, missing %s header", header) - w.Header().Set("Content-Type", "application/json") - // check allowed origin - if allowedOrigin(jwtAuth.Origins, r.Header.Get("Origin")) { - w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) + if isProtectedPath(r.URL.Path, jwtAuth.Path, jwtAuth.Paths) { + for _, header := range jwtAuth.RequiredHeaders { + if r.Header.Get(header) == "" { + logger.Error("Proxy error, missing %s header", header) + w.Header().Set("Content-Type", "application/json") + // check allowed origin + if allowedOrigin(jwtAuth.Origins, r.Header.Get("Origin")) { + w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) + } + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) + return + } + } + authURL, err := url.Parse(jwtAuth.AuthURL) + if err != nil { + logger.Error("Error parsing auth URL: %v", err) + RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) + return + } + // Create a new request for /authentication + authReq, err := http.NewRequest("GET", authURL.String(), nil) + if err != nil { + logger.Error("Proxy error creating authentication request: %v", err) + RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) + return + } + logger.Trace("JWT Auth response headers: %v", authReq.Header) + // Copy headers from the original request to the new request + for name, values := range r.Header { + for _, value := range values { + authReq.Header.Set(name, value) + } + } + // Copy cookies from the original request to the new request + for _, cookie := range r.Cookies() { + authReq.AddCookie(cookie) + } + // Perform the request to the auth service + client := &http.Client{} + authResp, err := client.Do(authReq) + if err != nil || authResp.StatusCode != http.StatusOK { + logger.Debug("%s %s %s %s", r.Method, getRealIP(r), r.URL, r.UserAgent()) + logger.Debug("Proxy authentication error") RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) return - } - } - authURL, err := url.Parse(jwtAuth.AuthURL) - if err != nil { - logger.Error("Error parsing auth URL: %v", err) - RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) - return - } - // Create a new request for /authentication - authReq, err := http.NewRequest("GET", authURL.String(), nil) - if err != nil { - logger.Error("Proxy error creating authentication request: %v", err) - RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) - return - } - logger.Trace("JWT Auth response headers: %v", authReq.Header) - // Copy headers from the original request to the new request - for name, values := range r.Header { - for _, value := range values { - authReq.Header.Set(name, value) + defer func(Body io.ReadCloser) { + err := Body.Close() + if err != nil { + logger.Error("Error closing body: %v", err) + } + }(authResp.Body) + // Inject specific header tp the current request's header + // Add header to the next request from AuthRequest header, depending on your requirements + if jwtAuth.Headers != nil { + for k, v := range jwtAuth.Headers { + r.Header.Set(v, authResp.Header.Get(k)) + } } - } - // Copy cookies from the original request to the new request - for _, cookie := range r.Cookies() { - authReq.AddCookie(cookie) - } - // Perform the request to the auth service - client := &http.Client{} - authResp, err := client.Do(authReq) - if err != nil || authResp.StatusCode != http.StatusOK { - logger.Debug("%s %s %s %s", r.Method, getRealIP(r), r.URL, r.UserAgent()) - logger.Debug("Proxy authentication error") - RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) - return - } - defer func(Body io.ReadCloser) { - err := Body.Close() - if err != nil { - logger.Error("Error closing body: %v", err) - } - }(authResp.Body) - // Inject specific header tp the current request's header - // Add header to the next request from AuthRequest header, depending on your requirements - if jwtAuth.Headers != nil { - for k, v := range jwtAuth.Headers { - r.Header.Set(v, authResp.Header.Get(k)) + query := r.URL.Query() + // Add query parameters to the next request from AuthRequest header, depending on your requirements + if jwtAuth.Params != nil { + for k, v := range jwtAuth.Params { + query.Set(v, authResp.Header.Get(k)) + } } + r.URL.RawQuery = query.Encode() } - query := r.URL.Query() - // Add query parameters to the next request from AuthRequest header, depending on your requirements - if jwtAuth.Params != nil { - for k, v := range jwtAuth.Params { - query.Set(v, authResp.Header.Get(k)) - } - } - r.URL.RawQuery = query.Encode() - next.ServeHTTP(w, r) }) + } // AuthMiddleware checks for the Authorization header and verifies the credentials func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { logger.Trace("Basic-Auth request headers: %v", r.Header) - // Get the Authorization header - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - logger.Debug("Proxy error, missing Authorization header") - w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) - RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) - return - } - // Check if the Authorization header contains "Basic" scheme - if !strings.HasPrefix(authHeader, "Basic ") { - logger.Error("Proxy error, missing Basic Authorization header") - RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) + if isProtectedPath(r.URL.Path, basicAuth.Path, basicAuth.Paths) { + // Get the Authorization header + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + logger.Debug("Proxy error, missing Authorization header") + w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) + return + } + // Check if the Authorization header contains "Basic" scheme + if !strings.HasPrefix(authHeader, "Basic ") { + logger.Error("Proxy error, missing Basic Authorization header") + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) - return - } + return + } + // Decode the base64 encoded username:password string + payload, err := base64.StdEncoding.DecodeString(authHeader[len("Basic "):]) + if err != nil { + logger.Debug("Proxy error, missing Basic Authorization header") + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) + return + } + // Split the payload into username and password + pair := strings.SplitN(string(payload), ":", 2) + if len(pair) != 2 || pair[0] != basicAuth.Username || pair[1] != basicAuth.Password { + w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) + return + } - // Decode the base64 encoded username:password string - payload, err := base64.StdEncoding.DecodeString(authHeader[len("Basic "):]) - if err != nil { - logger.Debug("Proxy error, missing Basic Authorization header") - RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) - return - } - - // Split the payload into username and password - pair := strings.SplitN(string(payload), ":", 2) - if len(pair) != 2 || pair[0] != basicAuth.Username || pair[1] != basicAuth.Password { - w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) - RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) - return } // Continue to the next handler if the authentication is successful diff --git a/internal/middlewares/oauth_middleware.go b/internal/middlewares/oauth_middleware.go index f2d7407..3157ea5 100644 --- a/internal/middlewares/oauth_middleware.go +++ b/internal/middlewares/oauth_middleware.go @@ -26,27 +26,29 @@ import ( func (oauth Oauth) AuthMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - oauthConf := oauth2Config(oauth) - // Check if the user is authenticated - token, err := r.Cookie("goma.oauth") - if err != nil { - // If no token, redirect to OAuth provider - url := oauthConf.AuthCodeURL(oauth.State) - http.Redirect(w, r, url, http.StatusTemporaryRedirect) - return - } - ok, err := validateJWT(token.Value, oauth) - if err != nil { - // If no token, redirect to OAuth provider - url := oauthConf.AuthCodeURL(oauth.State) - http.Redirect(w, r, url, http.StatusTemporaryRedirect) - return - } - if !ok { - // If no token, redirect to OAuth provider - url := oauthConf.AuthCodeURL(oauth.State) - http.Redirect(w, r, url, http.StatusTemporaryRedirect) - return + if isProtectedPath(r.URL.Path, oauth.Path, oauth.Paths) { + oauthConf := oauth2Config(oauth) + // Check if the user is authenticated + token, err := r.Cookie("goma.oauth") + if err != nil { + // If no token, redirect to OAuth provider + url := oauthConf.AuthCodeURL(oauth.State) + http.Redirect(w, r, url, http.StatusTemporaryRedirect) + return + } + ok, err := validateJWT(token.Value, oauth) + if err != nil { + // If no token, redirect to OAuth provider + url := oauthConf.AuthCodeURL(oauth.State) + http.Redirect(w, r, url, http.StatusTemporaryRedirect) + return + } + if !ok { + // If no token, redirect to OAuth provider + url := oauthConf.AuthCodeURL(oauth.State) + http.Redirect(w, r, url, http.StatusTemporaryRedirect) + return + } } // Token exists, proceed with request next.ServeHTTP(w, r) diff --git a/internal/middlewares/rate_limit.go b/internal/middlewares/rate_limit.go index e63cb04..be627a4 100644 --- a/internal/middlewares/rate_limit.go +++ b/internal/middlewares/rate_limit.go @@ -45,17 +45,19 @@ func (rl *TokenRateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { // RateLimitMiddleware limits request based on the number of requests peer minutes. func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { + window := time.Minute // requests per minute + if len(rl.unit) != 0 && rl.unit == "hour" { + window = time.Hour + } return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { clientIP := getRealIP(r) clientID := fmt.Sprintf("%s-%s", rl.id, clientIP) // Generate client Id, ID+ route ID - logger.Debug("requests limiter: clientIP: %s, clientID: %s", clientIP, clientID) if rl.redisBased { - err := redisRateLimiter(clientID, rl.requests) + err := redisRateLimiter(clientID, rl.unit, rl.requests) if err != nil { logger.Error("Redis Rate limiter error: %s", err.Error()) logger.Error("Too many requests from IP: %s %s %s", clientIP, r.URL, r.UserAgent()) - RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API requests limit exceeded. Please try again later", http.StatusTooManyRequests)) return } } else { @@ -64,7 +66,7 @@ func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { if !exists || time.Now().After(client.ExpiresAt) { client = &Client{ RequestCount: 0, - ExpiresAt: time.Now().Add(rl.window), + ExpiresAt: time.Now().Add(window), } rl.clientMap[clientID] = client } @@ -78,8 +80,10 @@ func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) } RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API requests limit exceeded. Please try again later", http.StatusTooManyRequests)) + return } } + // Proceed to the next handler if the request limit is not exceeded next.ServeHTTP(w, r) }) diff --git a/internal/middlewares/redis.go b/internal/middlewares/redis.go index 1f199d2..eb57867 100644 --- a/internal/middlewares/redis.go +++ b/internal/middlewares/redis.go @@ -25,10 +25,13 @@ import ( ) // redisRateLimiter, handle rateLimit -func redisRateLimiter(clientIP string, rate int) error { +func redisRateLimiter(clientIP, unit string, rate int) error { + limit := redis_rate.PerMinute(rate) + if len(unit) != 0 && unit == "hour" { + limit = redis_rate.PerHour(rate) + } ctx := context.Background() - - res, err := limiter.Allow(ctx, clientIP, redis_rate.PerMinute(rate)) + res, err := limiter.Allow(ctx, clientIP, limit) if err != nil { return err } diff --git a/internal/middlewares/types.go b/internal/middlewares/types.go index 49c9112..826d94a 100644 --- a/internal/middlewares/types.go +++ b/internal/middlewares/types.go @@ -27,12 +27,14 @@ import ( // RateLimiter defines requests limit properties. type RateLimiter struct { requests int + unit string id string - window time.Duration clientMap map[string]*Client mu sync.Mutex origins []string redisBased bool + pathBased bool + paths []string } // Client stores request count and window expiration for each client. @@ -42,22 +44,26 @@ type Client struct { } type RateLimit struct { Id string + Unit string Requests int - Window time.Duration Origins []string Hosts []string RedisBased bool + PathBased bool + Paths []string } // NewRateLimiterWindow creates a new RateLimiter. func (rateLimit RateLimit) NewRateLimiterWindow() *RateLimiter { return &RateLimiter{ id: rateLimit.Id, + unit: rateLimit.Unit, requests: rateLimit.Requests, - window: rateLimit.Window, clientMap: make(map[string]*Client), origins: rateLimit.Origins, redisBased: rateLimit.RedisBased, + pathBased: rateLimit.PathBased, + paths: rateLimit.Paths, } } @@ -79,6 +85,8 @@ type ProxyResponseError struct { // JwtAuth stores JWT configuration type JwtAuth struct { + Path string + Paths []string AuthURL string RequiredHeaders []string Headers map[string]string @@ -101,6 +109,9 @@ type AccessListMiddleware struct { // AuthBasic contains Basic auth configuration type AuthBasic struct { + // Route path + Path string + Paths []string Username string Password string Headers map[string]string @@ -120,6 +131,10 @@ type responseRecorder struct { body *bytes.Buffer } type Oauth struct { + // Route path + Path string + // Route protected path + Paths []string // ClientID is the application's ID. ClientID string // ClientSecret is the application's secret. diff --git a/internal/route_type.go b/internal/route_type.go index 2781037..d263c10 100644 --- a/internal/route_type.go +++ b/internal/route_type.go @@ -42,7 +42,7 @@ type Route struct { HealthCheck RouteHealthCheck `yaml:"healthCheck"` // Cors contains the route cors headers Cors Cors `yaml:"cors"` - RateLimit int `yaml:"rateLimit"` + RateLimit int `yaml:"rateLimit,omitempty"` // DisableHostFording Disable X-forwarded header. // // [X-Forwarded-Host, X-Forwarded-For, Host, Scheme ] diff --git a/internal/routes.go b/internal/routes.go index 592acde..aa3e4f0 100644 --- a/internal/routes.go +++ b/internal/routes.go @@ -23,7 +23,7 @@ import ( "github.com/jkaninda/goma-gateway/util" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" - "time" + "slices" ) // init initializes prometheus metrics @@ -61,8 +61,6 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { if err != nil { logger.Fatal("Error: %v", err) } - m := dynamicMiddlewares - redisBased := false if len(gateway.Redis.Addr) != 0 { redisBased = true } @@ -97,8 +95,8 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { // Add rate limit middlewares to all routes, if defined rateLimit := middlewares.RateLimit{ Id: "global_rate", // Generate a unique ID for routes + Unit: "minute", Requests: gateway.RateLimit, - Window: time.Minute, // requests per minute Origins: gateway.Cors.Origins, Hosts: []string{}, RedisBased: redisBased, @@ -108,44 +106,15 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { r.Use(limiter.RateLimitMiddleware()) } for rIndex, route := range dynamicRoutes { + + // create route + router := r.PathPrefix(route.Path).Subrouter() if len(route.Path) != 0 { // Checks if route destination and backend are empty if len(route.Destination) == 0 && len(route.Backends) == 0 { logger.Fatal("Route %s : destination or backends should not be empty", route.Name) } - // Apply middlewares to the route - for _, middleware := range route.Middlewares { - if middleware != "" { - // Get Access middlewares if it does exist - accessMiddleware, err := getMiddleware([]string{middleware}, m) - if err != nil { - logger.Error("Error: %v", err.Error()) - } else { - // Apply access middlewares - if accessMiddleware.Type == AccessMiddleware { - blM := middlewares.AccessListMiddleware{ - Path: route.Path, - List: accessMiddleware.Paths, - } - r.Use(blM.AccessMiddleware) - - } - - } - // Get route authentication middlewares if it does exist - routeMiddleware, err := getMiddleware([]string{middleware}, m) - if err != nil { - // Error: middlewares not found - logger.Error("Error: %v", err.Error()) - } else { - attachAuthMiddlewares(route, routeMiddleware, gateway, r) - } - } else { - logger.Error("Error, middlewares path is empty") - logger.Error("Middleware ignored") - } - } proxyRoute := ProxyRoute{ path: route.Path, rewrite: route.Rewrite, @@ -156,42 +125,9 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { cors: route.Cors, insecureSkipVerify: route.InsecureSkipVerify, } - // create route - router := r.PathPrefix(route.Path).Subrouter() - // Apply common exploits to the route - // Enable common exploits - if route.BlockCommonExploits { - logger.Info("Block common exploits enabled") - router.Use(middlewares.BlockExploitsMiddleware) - } - id := string(rune(rIndex)) - if len(route.Name) != 0 { - // Use route name as ID - id = util.Slug(route.Name) - } - // Apply route rate limit - if route.RateLimit != 0 { - rateLimit := middlewares.RateLimit{ - Id: id, // Use route index as ID - Requests: route.RateLimit, - Window: time.Minute, // requests per minute - Origins: route.Cors.Origins, - Hosts: route.Hosts, - RedisBased: redisBased, - } - limiter := rateLimit.NewRateLimiterWindow() - // Add rate limit middlewares - router.Use(limiter.RateLimitMiddleware()) - } + attachMiddlewares(rIndex, route, gateway, router) // Apply route Cors router.Use(CORSHandler(route.Cors)) - if len(route.Hosts) > 0 { - for _, host := range route.Hosts { - router.Host(host).PathPrefix("").Handler(proxyRoute.ProxyHandler()) - } - } else { - router.PathPrefix("").Handler(proxyRoute.ProxyHandler()) - } if gateway.EnableMetrics { pr := metrics.PrometheusRoute{ Name: route.Name, @@ -208,126 +144,208 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { } router.Use(interceptErrors.ErrorInterceptor) } + if len(route.Hosts) != 0 { + for _, host := range route.Hosts { + router.Host(host).PathPrefix("").Handler(proxyRoute.ProxyHandler()) + } + } else { + router.PathPrefix("").Handler(proxyRoute.ProxyHandler()) + } + } else { logger.Error("Error, path is empty in route %s", route.Name) logger.Error("Route path ignored: %s", route.Path) } - } - // Apply global Cors middlewares - r.Use(CORSHandler(gateway.Cors)) // Apply CORS middlewares - // Apply errorInterceptor middlewares - if len(gateway.InterceptErrors) != 0 { - interceptErrors := middlewares.InterceptErrors{ - Errors: gateway.InterceptErrors, - Origins: gateway.Cors.Origins, + + // Apply global Cors middlewares + r.Use(CORSHandler(gateway.Cors)) // Apply CORS middlewares + // Apply errorInterceptor middlewares + if len(gateway.InterceptErrors) != 0 { + interceptErrors := middlewares.InterceptErrors{ + Errors: gateway.InterceptErrors, + Origins: gateway.Cors.Origins, + } + r.Use(interceptErrors.ErrorInterceptor) } - r.Use(interceptErrors.ErrorInterceptor) + } return r } -func attachAuthMiddlewares(route Route, routeMiddleware Middleware, gateway Gateway, r *mux.Router) { - for _, middlewarePath := range routeMiddleware.Paths { - proxyRoute := ProxyRoute{ - path: route.Path, - rewrite: route.Rewrite, - destination: route.Destination, - backends: route.Backends, - disableHostFording: route.DisableHostFording, - methods: route.Methods, - cors: route.Cors, - insecureSkipVerify: route.InsecureSkipVerify, +// attachMiddlewares attach middlewares to the route +func attachMiddlewares(rIndex int, route Route, gateway Gateway, router *mux.Router) { + // Apply middlewares to the route + for _, middleware := range route.Middlewares { + // Apply common exploits to the route + // Enable common exploits + if route.BlockCommonExploits { + logger.Info("Block common exploits enabled") + router.Use(middlewares.BlockExploitsMiddleware) } - secureRouter := r.PathPrefix(util.ParseRoutePath(route.Path, middlewarePath)).Subrouter() - // Check Authentication middleware types - switch routeMiddleware.Type { - case BasicAuth: - basicAuth, err := getBasicAuthMiddleware(routeMiddleware.Rule) - if err != nil { - logger.Error("Error: %s", err.Error()) - } else { - authBasic := middlewares.AuthBasic{ - Username: basicAuth.Username, - Password: basicAuth.Password, - Headers: nil, - Params: nil, - } - // Apply JWT authentication middlewares - secureRouter.Use(authBasic.AuthMiddleware) - secureRouter.Use(CORSHandler(route.Cors)) - secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler - secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler + id := string(rune(rIndex)) + if len(route.Name) != 0 { + // Use route name as ID + id = util.Slug(route.Name) + } + // Apply route rate limit + if route.RateLimit != 0 { + rateLimit := middlewares.RateLimit{ + Unit: "minute", + Id: id, // Use route index as ID + Requests: route.RateLimit, + Origins: route.Cors.Origins, + Hosts: route.Hosts, + RedisBased: redisBased, } - case JWTAuth: - jwt, err := getJWTMiddleware(routeMiddleware.Rule) + limiter := rateLimit.NewRateLimiterWindow() + // Add rate limit middlewares + router.Use(limiter.RateLimitMiddleware()) + } + if len(middleware) != 0 { + // Get Access middlewares if it does exist + accessMiddleware, err := getMiddleware([]string{middleware}, dynamicMiddlewares) if err != nil { - logger.Error("Error: %s", err.Error()) + logger.Error("Error: %v", err.Error()) } else { - jwtAuth := middlewares.JwtAuth{ - AuthURL: jwt.URL, - RequiredHeaders: jwt.RequiredHeaders, - Headers: jwt.Headers, - Params: jwt.Params, - Origins: gateway.Cors.Origins, + // Apply access middlewares + if accessMiddleware.Type == AccessMiddleware { + blM := middlewares.AccessListMiddleware{ + Path: route.Path, + List: accessMiddleware.Paths, + } + router.Use(blM.AccessMiddleware) + } + + // Apply Rate limit middleware + if slices.Contains(RateLimitMiddleware, accessMiddleware.Type) { + rateLimitMid, err := rateLimitMiddleware(accessMiddleware.Rule) + if err != nil { + logger.Error("Error: %v", err.Error()) + } + if rateLimitMid.RequestsPerUnit != 0 && route.RateLimit == 0 { + rateLimit := middlewares.RateLimit{ + Unit: rateLimitMid.Unit, + Id: id, // Use route index as ID + Requests: rateLimitMid.RequestsPerUnit, + Origins: route.Cors.Origins, + Hosts: route.Hosts, + RedisBased: redisBased, + PathBased: true, + Paths: util.AddPrefixPath(route.Path, accessMiddleware.Paths), + } + limiter := rateLimit.NewRateLimiterWindow() + // Add rate limit middlewares + router.Use(limiter.RateLimitMiddleware()) + + } + } - // Apply JWT authentication middlewares - secureRouter.Use(jwtAuth.AuthMiddleware) - secureRouter.Use(CORSHandler(route.Cors)) - secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler - secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler } - case OAuth: - oauth, err := oAuthMiddleware(routeMiddleware.Rule) + // Get route authentication middlewares if it does exist + routeMiddleware, err := getMiddleware([]string{middleware}, dynamicMiddlewares) if err != nil { - logger.Error("Error: %s", err.Error()) + // Error: middlewares not found + logger.Error("Error: %v", err.Error()) } else { - redirectURL := "/callback" + route.Path - if oauth.RedirectURL != "" { - redirectURL = oauth.RedirectURL - } - amw := middlewares.Oauth{ - ClientID: oauth.ClientID, - ClientSecret: oauth.ClientSecret, - RedirectURL: redirectURL, - Scopes: oauth.Scopes, - Endpoint: middlewares.OauthEndpoint{ - AuthURL: oauth.Endpoint.AuthURL, - TokenURL: oauth.Endpoint.TokenURL, - UserInfoURL: oauth.Endpoint.UserInfoURL, - }, - State: oauth.State, - Origins: gateway.Cors.Origins, - JWTSecret: oauth.JWTSecret, - Provider: oauth.Provider, - } - oauthRuler := oauthRulerMiddleware(amw) - // Check if a cookie path is defined - if oauthRuler.CookiePath == "" { - oauthRuler.CookiePath = route.Path - } - // Check if a RedirectPath is defined - if oauthRuler.RedirectPath == "" { - oauthRuler.RedirectPath = util.ParseRoutePath(route.Path, middlewarePath) - } - if oauthRuler.Provider == "" { - oauthRuler.Provider = "custom" - } - secureRouter.Use(amw.AuthMiddleware) - secureRouter.Use(CORSHandler(route.Cors)) - secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler - secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler - // Callback route - r.HandleFunc(util.UrlParsePath(redirectURL), oauthRuler.callbackHandler).Methods("GET") - } - default: - if !doesExist(routeMiddleware.Type) { - logger.Error("Unknown middlewares type %s", routeMiddleware.Type) + attachAuthMiddlewares(route, routeMiddleware, gateway, router) } + } else { + logger.Error("Error, middlewares path is empty") + logger.Error("Middleware ignored") + } + } +} + +func attachAuthMiddlewares(route Route, routeMiddleware Middleware, gateway Gateway, r *mux.Router) { + // Check Authentication middleware types + switch routeMiddleware.Type { + case BasicAuth: + basicAuth, err := getBasicAuthMiddleware(routeMiddleware.Rule) + if err != nil { + logger.Error("Error: %s", err.Error()) + } else { + authBasic := middlewares.AuthBasic{ + Path: route.Path, + Paths: routeMiddleware.Paths, + Username: basicAuth.Username, + Password: basicAuth.Password, + Headers: nil, + Params: nil, + } + // Apply JWT authentication middlewares + r.Use(authBasic.AuthMiddleware) + r.Use(CORSHandler(route.Cors)) + } + case JWTAuth: + jwt, err := getJWTMiddleware(routeMiddleware.Rule) + if err != nil { + logger.Error("Error: %s", err.Error()) + } else { + jwtAuth := middlewares.JwtAuth{ + Path: route.Path, + Paths: routeMiddleware.Paths, + AuthURL: jwt.URL, + RequiredHeaders: jwt.RequiredHeaders, + Headers: jwt.Headers, + Params: jwt.Params, + Origins: gateway.Cors.Origins, + } + // Apply JWT authentication middlewares + r.Use(jwtAuth.AuthMiddleware) + r.Use(CORSHandler(route.Cors)) + + } + case OAuth: + oauth, err := oAuthMiddleware(routeMiddleware.Rule) + if err != nil { + logger.Error("Error: %s", err.Error()) + } else { + redirectURL := "/callback" + route.Path + if oauth.RedirectURL != "" { + redirectURL = oauth.RedirectURL + } + amw := middlewares.Oauth{ + Path: route.Path, + Paths: routeMiddleware.Paths, + ClientID: oauth.ClientID, + ClientSecret: oauth.ClientSecret, + RedirectURL: redirectURL, + Scopes: oauth.Scopes, + Endpoint: middlewares.OauthEndpoint{ + AuthURL: oauth.Endpoint.AuthURL, + TokenURL: oauth.Endpoint.TokenURL, + UserInfoURL: oauth.Endpoint.UserInfoURL, + }, + State: oauth.State, + Origins: gateway.Cors.Origins, + JWTSecret: oauth.JWTSecret, + Provider: oauth.Provider, + } + oauthRuler := oauthRulerMiddleware(amw) + // Check if a cookie path is defined + if oauthRuler.CookiePath == "" { + oauthRuler.CookiePath = route.Path + } + // Check if a RedirectPath is defined + if oauthRuler.RedirectPath == "" { + oauthRuler.RedirectPath = util.ParseRoutePath(route.Path, routeMiddleware.Paths[0]) + } + if oauthRuler.Provider == "" { + oauthRuler.Provider = "custom" + } + r.Use(amw.AuthMiddleware) + r.Use(CORSHandler(route.Cors)) + r.HandleFunc(util.UrlParsePath(redirectURL), oauthRuler.callbackHandler).Methods("GET") + } + default: + if !doesExist(routeMiddleware.Type) { + logger.Error("Unknown middlewares type %s", routeMiddleware.Type) } } + } diff --git a/internal/server.go b/internal/server.go index 340b177..a9be5ea 100644 --- a/internal/server.go +++ b/internal/server.go @@ -30,7 +30,7 @@ import ( // Start / Start starts the server func (gatewayServer GatewayServer) Start() error { logger.Info("Initializing routes...") - route := gatewayServer.Initialize() + router := gatewayServer.Initialize() logger.Debug("Routes count=%d, Middlewares count=%d", len(gatewayServer.gateway.Routes), len(gatewayServer.middlewares)) gatewayServer.initRedis() defer gatewayServer.closeRedis() @@ -44,8 +44,8 @@ func (gatewayServer GatewayServer) Start() error { printRoute(dynamicRoutes) } - httpServer := gatewayServer.createServer(":8080", route, nil) - httpsServer := gatewayServer.createServer(":8443", route, tlsConfig) + httpServer := gatewayServer.createServer(":8080", router, nil) + httpsServer := gatewayServer.createServer(":8443", router, tlsConfig) // Start HTTP/HTTPS servers gatewayServer.startServers(httpServer, httpsServer, listenWithTLS) diff --git a/internal/types.go b/internal/types.go index 8f05891..2da2283 100644 --- a/internal/types.go +++ b/internal/types.go @@ -80,13 +80,11 @@ type OauthEndpoint struct { TokenURL string `yaml:"tokenUrl"` UserInfoURL string `yaml:"userInfoUrl"` } -type RateLimiter struct { - // ipBased, tokenBased - Type string `yaml:"type"` - Rate float64 `yaml:"rate"` - Rule int `yaml:"rule"` -} +type RateLimitRuleMiddleware struct { + Unit string `yaml:"unit"` + RequestsPerUnit int `yaml:"requestsPerUnit"` +} type AccessRuleMiddleware struct { ResponseCode int `yaml:"responseCode"` // HTTP Response code } diff --git a/internal/var.go b/internal/var.go index 7698733..273dd6e 100644 --- a/internal/var.go +++ b/internal/var.go @@ -9,10 +9,13 @@ const AccessMiddleware = "access" // access middlewares const BasicAuth = "basic" // basic authentication middlewares const JWTAuth = "jwt" // JWT authentication middlewares const OAuth = "oauth" // OAuth authentication middlewares + var ( // Round-robin counter counter uint32 // dynamicRoutes routes - dynamicRoutes []Route - dynamicMiddlewares []Middleware + dynamicRoutes []Route + dynamicMiddlewares []Middleware + RateLimitMiddleware = []string{"ratelimit", "rateLimit"} // Rate Limit middlewares + redisBased = false ) diff --git a/util/helpers.go b/util/helpers.go index 3248312..a3bc174 100644 --- a/util/helpers.go +++ b/util/helpers.go @@ -157,3 +157,11 @@ func Slug(text string) string { return text } + +func AddPrefixPath(prefix string, paths []string) []string { + for i := range paths { + paths[i] = ParseURLPath(prefix + paths[i]) + } + return paths + +}