From f1af5c3ce6d1269f3abb9642b5a58b3fe511e6e3 Mon Sep 17 00:00:00 2001 From: Jonas Kaninda Date: Fri, 15 Nov 2024 14:24:35 +0100 Subject: [PATCH] refactor: refactoring of code Add graceful shutdown server --- examples/goma.yml | 4 +- internal/checkConfig.go | 2 +- internal/config.go | 10 +- internal/{healthCheck.go => healthcheck.go} | 0 internal/middleware.go | 4 +- internal/middleware_test.go | 10 +- internal/middleware_type.go | 4 +- .../access-middleware.go | 2 +- .../block-common-exploits.go | 2 +- .../{middleware => middlewares}/config.go | 2 +- .../error-interceptor.go | 2 +- .../{middleware => middlewares}/helpers.go | 2 +- .../{middleware => middlewares}/middleware.go | 2 +- .../oauth-middleware.go | 2 +- .../{middleware => middlewares}/rate-limit.go | 26 +-- internal/middlewares/redis.go | 46 +++++ internal/{middleware => middlewares}/types.go | 2 +- internal/{middleware => middlewares}/var.go | 2 +- internal/proxy.go | 6 +- internal/redis.go | 40 +++++ internal/route.go | 60 +++---- internal/route_type.go | 2 +- internal/server.go | 168 ++++++++---------- internal/tls.go | 36 ++++ internal/var.go | 8 +- main.go | 4 +- 26 files changed, 267 insertions(+), 181 deletions(-) rename internal/{healthCheck.go => healthcheck.go} (100%) rename internal/{middleware => middlewares}/access-middleware.go (99%) rename internal/{middleware => middlewares}/block-common-exploits.go (99%) rename internal/{middleware => middlewares}/config.go (98%) rename internal/{middleware => middlewares}/error-interceptor.go (99%) rename internal/{middleware => middlewares}/helpers.go (98%) rename internal/{middleware => middlewares}/middleware.go (99%) rename internal/{middleware => middlewares}/oauth-middleware.go (99%) rename internal/{middleware => middlewares}/rate-limit.go (85%) create mode 100644 internal/middlewares/redis.go rename internal/{middleware => middlewares}/types.go (99%) rename internal/{middleware => middlewares}/var.go (97%) create mode 100644 internal/redis.go create mode 100644 internal/tls.go diff --git a/examples/goma.yml b/examples/goma.yml index ab41911..7397595 100644 --- a/examples/goma.yml +++ b/examples/goma.yml @@ -76,7 +76,7 @@ gateway: Access-Control-Max-Age: 1728000 ##### Apply middlewares to the route ## The name must be unique - ## List of middleware name + ## List of middlewares name middlewares: - api-forbidden-paths # Example of a route | 2 @@ -103,7 +103,7 @@ gateway: - api-forbidden-paths - basic-auth #Defines proxy middlewares -# middleware name must be unique +# middlewares name must be unique middlewares: # Enable Basic auth authorization based - name: basic-auth diff --git a/internal/checkConfig.go b/internal/checkConfig.go index 4c99e28..2b54aab 100644 --- a/internal/checkConfig.go +++ b/internal/checkConfig.go @@ -52,7 +52,7 @@ func CheckConfig(fileName string) error { } } - //Check middleware + //Check middlewares for index, mid := range c.Middlewares { if util.HasWhitespace(mid.Name) { fmt.Printf("Warning: Middleware contains whitespace: %s | index: [%d], please remove whitespace characters\n", mid.Name, index) diff --git a/internal/config.go b/internal/config.go index 72ad4ff..7becabd 100644 --- a/internal/config.go +++ b/internal/config.go @@ -17,7 +17,7 @@ limitations under the License. */ import ( "fmt" - "github.com/jkaninda/goma-gateway/internal/middleware" + "github.com/jkaninda/goma-gateway/internal/middlewares" "github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/util" "golang.org/x/oauth2" @@ -330,7 +330,7 @@ func getJWTMiddleware(input interface{}) (JWTRuleMiddleware, error) { return JWTRuleMiddleware{}, fmt.Errorf("error parsing yaml: %v", err) } if jWTRuler.URL == "" { - return JWTRuleMiddleware{}, fmt.Errorf("error parsing yaml: empty url in jwt auth middleware") + return JWTRuleMiddleware{}, fmt.Errorf("error parsing yaml: empty url in jwt auth middlewares") } return *jWTRuler, nil @@ -349,7 +349,7 @@ func getBasicAuthMiddleware(input interface{}) (BasicRuleMiddleware, error) { return BasicRuleMiddleware{}, fmt.Errorf("error parsing yaml: %v", err) } if basicAuth.Username == "" || basicAuth.Password == "" { - return BasicRuleMiddleware{}, fmt.Errorf("error parsing yaml: empty username/password in %s middleware", basicAuth) + return BasicRuleMiddleware{}, fmt.Errorf("error parsing yaml: empty username/password in %s middlewares", basicAuth) } return *basicAuth, nil @@ -368,12 +368,12 @@ func oAuthMiddleware(input interface{}) (OauthRulerMiddleware, error) { return OauthRulerMiddleware{}, fmt.Errorf("error parsing yaml: %v", err) } if oauthRuler.ClientID == "" || oauthRuler.ClientSecret == "" || oauthRuler.RedirectURL == "" { - return OauthRulerMiddleware{}, fmt.Errorf("error parsing yaml: empty clientId/secretId in %s middleware", oauthRuler) + return OauthRulerMiddleware{}, fmt.Errorf("error parsing yaml: empty clientId/secretId in %s middlewares", oauthRuler) } return *oauthRuler, nil } -func oauthRulerMiddleware(oauth middleware.Oauth) *OauthRulerMiddleware { +func oauthRulerMiddleware(oauth middlewares.Oauth) *OauthRulerMiddleware { return &OauthRulerMiddleware{ ClientID: oauth.ClientID, ClientSecret: oauth.ClientSecret, diff --git a/internal/healthCheck.go b/internal/healthcheck.go similarity index 100% rename from internal/healthCheck.go rename to internal/healthcheck.go diff --git a/internal/middleware.go b/internal/middleware.go index aa7b146..b72fec4 100644 --- a/internal/middleware.go +++ b/internal/middleware.go @@ -14,7 +14,7 @@ func getMiddleware(rules []string, middlewares []Middleware) (Middleware, error) continue } - return Middleware{}, errors.New("middleware not found with name: [" + strings.Join(rules, ";") + "]") + return Middleware{}, errors.New("middlewares not found with name: [" + strings.Join(rules, ";") + "]") } func doesExist(tyName string) bool { @@ -30,5 +30,5 @@ func GetMiddleware(rule string, middlewares []Middleware) (Middleware, error) { continue } - return Middleware{}, errors.New("no middleware found with name " + rule) + return Middleware{}, errors.New("no middlewares found with name " + rule) } diff --git a/internal/middleware_test.go b/internal/middleware_test.go index 316d370..e7e8e73 100644 --- a/internal/middleware_test.go +++ b/internal/middleware_test.go @@ -105,7 +105,7 @@ func TestReadMiddleware(t *testing.T) { middlewares := getMiddlewares(t) m, err := getMiddleware(rules, middlewares) if err != nil { - t.Fatalf("Error searching middleware %s", err.Error()) + t.Fatalf("Error searching middlewares %s", err.Error()) } log.Printf("Middleware: %v\n", m) @@ -134,10 +134,10 @@ func TestReadMiddleware(t *testing.T) { } log.Printf("OAuth authentification: provider %s\n", oauth.Provider) case AccessMiddleware: - log.Println("Access middleware") - log.Printf("Access middleware: paths: [%s]\n", middleware.Paths) + log.Println("Access middlewares") + log.Printf("Access middlewares: paths: [%s]\n", middleware.Paths) default: - t.Errorf("Unknown middleware type %s", middleware.Type) + t.Errorf("Unknown middlewares type %s", middleware.Type) } } @@ -148,7 +148,7 @@ func TestFoundMiddleware(t *testing.T) { middlewares := getMiddlewares(t) middleware, err := GetMiddleware("jwt", middlewares) if err != nil { - t.Errorf("Error getting middleware %v", err) + t.Errorf("Error getting middlewares %v", err) } fmt.Println(middleware.Type) } diff --git a/internal/middleware_type.go b/internal/middleware_type.go index acd33a3..1bd9426 100644 --- a/internal/middleware_type.go +++ b/internal/middleware_type.go @@ -17,9 +17,9 @@ package pkg -// Middleware defined the route middleware +// Middleware defined the route middlewares type Middleware struct { - //Path contains the name of middleware and must be unique + //Path contains the name of middlewares and must be unique Name string `yaml:"name"` // Type contains authentication types // diff --git a/internal/middleware/access-middleware.go b/internal/middlewares/access-middleware.go similarity index 99% rename from internal/middleware/access-middleware.go rename to internal/middlewares/access-middleware.go index 11ddd5f..ca26295 100644 --- a/internal/middleware/access-middleware.go +++ b/internal/middlewares/access-middleware.go @@ -1,4 +1,4 @@ -package middleware +package middlewares /* Copyright 2024 Jonas Kaninda diff --git a/internal/middleware/block-common-exploits.go b/internal/middlewares/block-common-exploits.go similarity index 99% rename from internal/middleware/block-common-exploits.go rename to internal/middlewares/block-common-exploits.go index 7c09ffb..ff8b4b9 100644 --- a/internal/middleware/block-common-exploits.go +++ b/internal/middlewares/block-common-exploits.go @@ -15,7 +15,7 @@ * */ -package middleware +package middlewares import ( "fmt" diff --git a/internal/middleware/config.go b/internal/middlewares/config.go similarity index 98% rename from internal/middleware/config.go rename to internal/middlewares/config.go index c26f218..26ec0a0 100644 --- a/internal/middleware/config.go +++ b/internal/middlewares/config.go @@ -15,7 +15,7 @@ * */ -package middleware +package middlewares import ( "github.com/jkaninda/goma-gateway/pkg/logger" diff --git a/internal/middleware/error-interceptor.go b/internal/middlewares/error-interceptor.go similarity index 99% rename from internal/middleware/error-interceptor.go rename to internal/middlewares/error-interceptor.go index 36e2a5b..000010e 100644 --- a/internal/middleware/error-interceptor.go +++ b/internal/middlewares/error-interceptor.go @@ -1,4 +1,4 @@ -package middleware +package middlewares /* * Copyright 2024 Jonas Kaninda diff --git a/internal/middleware/helpers.go b/internal/middlewares/helpers.go similarity index 98% rename from internal/middleware/helpers.go rename to internal/middlewares/helpers.go index 5f78b7a..351f182 100644 --- a/internal/middleware/helpers.go +++ b/internal/middlewares/helpers.go @@ -15,7 +15,7 @@ * */ -package middleware +package middlewares import ( "encoding/json" diff --git a/internal/middleware/middleware.go b/internal/middlewares/middleware.go similarity index 99% rename from internal/middleware/middleware.go rename to internal/middlewares/middleware.go index 985bb95..ebc3ddd 100644 --- a/internal/middleware/middleware.go +++ b/internal/middlewares/middleware.go @@ -1,4 +1,4 @@ -package middleware +package middlewares /* Copyright 2024 Jonas Kaninda diff --git a/internal/middleware/oauth-middleware.go b/internal/middlewares/oauth-middleware.go similarity index 99% rename from internal/middleware/oauth-middleware.go rename to internal/middlewares/oauth-middleware.go index 1697ae1..f2d7407 100644 --- a/internal/middleware/oauth-middleware.go +++ b/internal/middlewares/oauth-middleware.go @@ -15,7 +15,7 @@ * */ -package middleware +package middlewares import ( "fmt" diff --git a/internal/middleware/rate-limit.go b/internal/middlewares/rate-limit.go similarity index 85% rename from internal/middleware/rate-limit.go rename to internal/middlewares/rate-limit.go index 2d2942a..3a8bfe3 100644 --- a/internal/middleware/rate-limit.go +++ b/internal/middlewares/rate-limit.go @@ -1,4 +1,4 @@ -package middleware +package middlewares /* Copyright 2024 Jonas Kaninda @@ -16,13 +16,9 @@ See the License for the specific language governing permissions and limitations under the License. */ import ( - "errors" "fmt" - "github.com/go-redis/redis_rate/v10" "github.com/gorilla/mux" "github.com/jkaninda/goma-gateway/pkg/logger" - "github.com/redis/go-redis/v9" - "golang.org/x/net/context" "net/http" "time" ) @@ -91,23 +87,3 @@ func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { }) } } -func redisRateLimiter(clientIP string, rate int) error { - ctx := context.Background() - - res, err := limiter.Allow(ctx, clientIP, redis_rate.PerMinute(rate)) - if err != nil { - return err - } - if res.Remaining == 0 { - return errors.New("requests limit exceeded") - } - - return nil -} -func InitRedis(addr, password string) { - Rdb = redis.NewClient(&redis.Options{ - Addr: addr, - Password: password, - }) - limiter = redis_rate.NewLimiter(Rdb) -} diff --git a/internal/middlewares/redis.go b/internal/middlewares/redis.go new file mode 100644 index 0000000..4f1dd0e --- /dev/null +++ b/internal/middlewares/redis.go @@ -0,0 +1,46 @@ +/* + * Copyright 2024 Jonas Kaninda + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package middlewares + +import ( + "context" + "errors" + "github.com/go-redis/redis_rate/v10" + "github.com/redis/go-redis/v9" +) + +func redisRateLimiter(clientIP string, rate int) error { + ctx := context.Background() + + res, err := limiter.Allow(ctx, clientIP, redis_rate.PerMinute(rate)) + if err != nil { + return err + } + if res.Remaining == 0 { + return errors.New("requests limit exceeded") + } + + return nil +} +func InitRedis(addr, password string) { + Rdb = redis.NewClient(&redis.Options{ + Addr: addr, + Password: password, + }) + limiter = redis_rate.NewLimiter(Rdb) +} diff --git a/internal/middleware/types.go b/internal/middlewares/types.go similarity index 99% rename from internal/middleware/types.go rename to internal/middlewares/types.go index 27d1cda..5bcfb7c 100644 --- a/internal/middleware/types.go +++ b/internal/middlewares/types.go @@ -15,7 +15,7 @@ * */ -package middleware +package middlewares import ( "bytes" diff --git a/internal/middleware/var.go b/internal/middlewares/var.go similarity index 97% rename from internal/middleware/var.go rename to internal/middlewares/var.go index 5eb3266..f3021aa 100644 --- a/internal/middleware/var.go +++ b/internal/middlewares/var.go @@ -15,7 +15,7 @@ * */ -package middleware +package middlewares import ( "github.com/go-redis/redis_rate/v10" diff --git a/internal/proxy.go b/internal/proxy.go index e1dbf3d..f21f68b 100644 --- a/internal/proxy.go +++ b/internal/proxy.go @@ -18,7 +18,7 @@ limitations under the License. import ( "crypto/tls" "fmt" - "github.com/jkaninda/goma-gateway/internal/middleware" + "github.com/jkaninda/goma-gateway/internal/middlewares" "github.com/jkaninda/goma-gateway/pkg/logger" "net/http" "net/http/httputil" @@ -38,7 +38,7 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc { if len(proxyRoute.methods) > 0 { if !slices.Contains(proxyRoute.methods, r.Method) { logger.Error("%s Method is not allowed", r.Method) - middleware.RespondWithError(w, http.StatusMethodNotAllowed, fmt.Sprintf("%d %s method is not allowed", http.StatusMethodNotAllowed, r.Method)) + middlewares.RespondWithError(w, http.StatusMethodNotAllowed, fmt.Sprintf("%d %s method is not allowed", http.StatusMethodNotAllowed, r.Method)) return } } @@ -61,7 +61,7 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc { targetURL, err := url.Parse(proxyRoute.destination) if err != nil { logger.Error("Error parsing backend URL: %s", err) - middleware.RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) + middlewares.RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) return } r.Header.Set("X-Forwarded-Host", r.Header.Get("Host")) diff --git a/internal/redis.go b/internal/redis.go new file mode 100644 index 0000000..8802e28 --- /dev/null +++ b/internal/redis.go @@ -0,0 +1,40 @@ +/* + * Copyright 2024 Jonas Kaninda + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package pkg + +import ( + "github.com/jkaninda/goma-gateway/internal/middlewares" + "github.com/jkaninda/goma-gateway/pkg/logger" +) + +func (gatewayServer GatewayServer) initRedis() error { + if gatewayServer.gateway.Redis.Addr == "" { + return nil + } + logger.Info("Initializing Redis...") + middlewares.InitRedis(gatewayServer.gateway.Redis.Addr, gatewayServer.gateway.Redis.Password) + return nil +} + +func (gatewayServer GatewayServer) closeRedis() { + if middlewares.Rdb != nil { + if err := middlewares.Rdb.Close(); err != nil { + logger.Error("Error closing Redis: %v", err) + } + } +} diff --git a/internal/route.go b/internal/route.go index dd9c077..6ef0b3e 100644 --- a/internal/route.go +++ b/internal/route.go @@ -17,7 +17,7 @@ limitations under the License. */ import ( "github.com/gorilla/mux" - "github.com/jkaninda/goma-gateway/internal/middleware" + "github.com/jkaninda/goma-gateway/internal/middlewares" "github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/util" "github.com/prometheus/client_golang/prometheus" @@ -34,7 +34,7 @@ func init() { // Initialize the routes func (gatewayServer GatewayServer) Initialize() *mux.Router { gateway := gatewayServer.gateway - middlewares := gatewayServer.middlewares + m := gatewayServer.middlewares redisBased := false if len(gateway.Redis.Addr) != 0 { redisBased = true @@ -62,11 +62,11 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { // Enable common exploits if gateway.BlockCommonExploits { logger.Info("Block common exploits enabled") - r.Use(middleware.BlockExploitsMiddleware) + r.Use(middlewares.BlockExploitsMiddleware) } if gateway.RateLimit > 0 { - // Add rate limit middleware to all routes, if defined - rateLimit := middleware.RateLimit{ + // Add rate limit middlewares to all routes, if defined + rateLimit := middlewares.RateLimit{ Id: "global_rate", //Generate a unique ID for routes Requests: gateway.RateLimit, Window: time.Minute, // requests per minute @@ -75,7 +75,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { RedisBased: redisBased, } limiter := rateLimit.NewRateLimiterWindow() - // Add rate limit middleware + // Add rate limit middlewares r.Use(limiter.RateLimitMiddleware()) } for rIndex, route := range gateway.Routes { @@ -87,14 +87,14 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { // Apply middlewares to route for _, mid := range route.Middlewares { if mid != "" { - // Get Access middleware if it does exist - accessMiddleware, err := getMiddleware([]string{mid}, middlewares) + // Get Access middlewares if it does exist + accessMiddleware, err := getMiddleware([]string{mid}, m) if err != nil { logger.Error("Error: %v", err.Error()) } else { - // Apply access middleware + // Apply access middlewares if accessMiddleware.Type == AccessMiddleware { - blM := middleware.AccessListMiddleware{ + blM := middlewares.AccessListMiddleware{ Path: route.Path, List: accessMiddleware.Paths, } @@ -103,10 +103,10 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { } } - // Get route authentication middleware if it does exist - rMiddleware, err := getMiddleware([]string{mid}, middlewares) + // Get route authentication middlewares if it does exist + rMiddleware, err := getMiddleware([]string{mid}, m) if err != nil { - //Error: middleware not found + //Error: middlewares not found logger.Error("Error: %v", err.Error()) } else { for _, midPath := range rMiddleware.Paths { @@ -122,20 +122,20 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { } secureRouter := r.PathPrefix(util.ParseRoutePath(route.Path, midPath)).Subrouter() //callBackRouter := r.PathPrefix(util.ParseRoutePath(route.Path, "/callback")).Subrouter() - //Check Authentication middleware + //Check Authentication middlewares switch rMiddleware.Type { case BasicAuth: basicAuth, err := getBasicAuthMiddleware(rMiddleware.Rule) if err != nil { logger.Error("Error: %s", err.Error()) } else { - amw := middleware.AuthBasic{ + amw := middlewares.AuthBasic{ Username: basicAuth.Username, Password: basicAuth.Password, Headers: nil, Params: nil, } - // Apply JWT authentication middleware + // Apply JWT authentication middlewares secureRouter.Use(amw.AuthMiddleware) secureRouter.Use(CORSHandler(route.Cors)) secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler @@ -146,14 +146,14 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { if err != nil { logger.Error("Error: %s", err.Error()) } else { - amw := middleware.JwtAuth{ + amw := middlewares.JwtAuth{ AuthURL: jwt.URL, RequiredHeaders: jwt.RequiredHeaders, Headers: jwt.Headers, Params: jwt.Params, Origins: gateway.Cors.Origins, } - // Apply JWT authentication middleware + // Apply JWT authentication middlewares secureRouter.Use(amw.AuthMiddleware) secureRouter.Use(CORSHandler(route.Cors)) secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler @@ -169,12 +169,12 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { if oauth.RedirectURL != "" { redirectURL = oauth.RedirectURL } - amw := middleware.Oauth{ + amw := middlewares.Oauth{ ClientID: oauth.ClientID, ClientSecret: oauth.ClientSecret, RedirectURL: redirectURL, Scopes: oauth.Scopes, - Endpoint: middleware.OauthEndpoint{ + Endpoint: middlewares.OauthEndpoint{ AuthURL: oauth.Endpoint.AuthURL, TokenURL: oauth.Endpoint.TokenURL, UserInfoURL: oauth.Endpoint.UserInfoURL, @@ -205,7 +205,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { } default: if !doesExist(rMiddleware.Type) { - logger.Error("Unknown middleware type %s", rMiddleware.Type) + logger.Error("Unknown middlewares type %s", rMiddleware.Type) } } @@ -214,7 +214,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { } } else { - logger.Error("Error, middleware path is empty") + logger.Error("Error, middlewares path is empty") logger.Error("Middleware ignored") } } @@ -234,7 +234,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { // Enable common exploits if route.BlockCommonExploits { logger.Info("Block common exploits enabled") - router.Use(middleware.BlockExploitsMiddleware) + router.Use(middlewares.BlockExploitsMiddleware) } id := string(rune(rIndex)) if len(route.Name) != 0 { @@ -243,7 +243,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { } // Apply route rate limit if route.RateLimit > 0 { - rateLimit := middleware.RateLimit{ + rateLimit := middlewares.RateLimit{ Id: id, // Use route index as ID Requests: route.RateLimit, Window: time.Minute, // requests per minute @@ -252,7 +252,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { RedisBased: redisBased, } limiter := rateLimit.NewRateLimiterWindow() - // Add rate limit middleware + // Add rate limit middlewares router.Use(limiter.RateLimitMiddleware()) } // Apply route Cors @@ -272,9 +272,9 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { // Prometheus endpoint router.Use(pr.prometheusMiddleware) } - // Apply route Error interceptor middleware + // Apply route Error interceptor middlewares if len(route.InterceptErrors) != 0 { - interceptErrors := middleware.InterceptErrors{ + interceptErrors := middlewares.InterceptErrors{ Origins: route.Cors.Origins, Errors: route.InterceptErrors, } @@ -286,10 +286,10 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { } } // Apply global Cors middlewares - r.Use(CORSHandler(gateway.Cors)) // Apply CORS middleware - // Apply errorInterceptor middleware + r.Use(CORSHandler(gateway.Cors)) // Apply CORS middlewares + // Apply errorInterceptor middlewares if len(gateway.InterceptErrors) != 0 { - interceptErrors := middleware.InterceptErrors{ + interceptErrors := middlewares.InterceptErrors{ Errors: gateway.InterceptErrors, Origins: gateway.Cors.Origins, } diff --git a/internal/route_type.go b/internal/route_type.go index 11bd21e..acc8210 100644 --- a/internal/route_type.go +++ b/internal/route_type.go @@ -53,6 +53,6 @@ type Route struct { InterceptErrors []int `yaml:"interceptErrors"` // BlockCommonExploits enable, disable block common exploits BlockCommonExploits bool `yaml:"blockCommonExploits"` - // Middlewares Defines route middleware from Middleware names + // Middlewares Defines route middlewares from Middleware names Middlewares []string `yaml:"middlewares"` } diff --git a/internal/server.go b/internal/server.go index e71d2ca..ce8c57c 100644 --- a/internal/server.go +++ b/internal/server.go @@ -18,111 +18,97 @@ limitations under the License. import ( "context" "crypto/tls" + "errors" "fmt" - "github.com/jkaninda/goma-gateway/internal/middleware" "github.com/jkaninda/goma-gateway/pkg/logger" - "github.com/redis/go-redis/v9" "net/http" "os" - "sync" + "os/signal" + "syscall" "time" ) -// Start starts the server +// Start / Start starts the server func (gatewayServer GatewayServer) Start(ctx context.Context) error { logger.Info("Initializing routes...") route := gatewayServer.Initialize() - gateway := gatewayServer.gateway - logger.Debug("Routes count=%d Middlewares count=%d", len(gatewayServer.gateway.Routes), len(gatewayServer.middlewares)) - logger.Info("Initializing routes...done") - if len(gateway.Redis.Addr) != 0 { - middleware.InitRedis(gateway.Redis.Addr, gateway.Redis.Password) - defer func(Rdb *redis.Client) { - err := Rdb.Close() - if err != nil { - logger.Error("Redis connection closed with error: %v", err) - } - }(middleware.Rdb) + logger.Debug("Routes count=%d, Middlewares count=%d", len(gatewayServer.gateway.Routes), len(gatewayServer.middlewares)) + if err := gatewayServer.initRedis(); err != nil { + return fmt.Errorf("failed to initialize Redis: %w", err) + } + defer gatewayServer.closeRedis() + + tlsConfig, listenWithTLS, err := gatewayServer.initTLS() + if err != nil { + return err } - tlsConfig := &tls.Config{} - var listenWithTLS = false - if cert := gatewayServer.gateway.SSLCertFile; cert != "" && gatewayServer.gateway.SSLKeyFile != "" { - tlsConf, err := loadTLS(cert, gatewayServer.gateway.SSLKeyFile) - if err != nil { - return err - } - tlsConfig = tlsConf - listenWithTLS = true - - } - // HTTP Server - httpServer := &http.Server{ - Addr: ":8080", - WriteTimeout: time.Second * time.Duration(gatewayServer.gateway.WriteTimeout), - ReadTimeout: time.Second * time.Duration(gatewayServer.gateway.ReadTimeout), - IdleTimeout: time.Second * time.Duration(gatewayServer.gateway.IdleTimeout), - Handler: route, // Pass our instance of gorilla/mux in. - } - // HTTPS Server - httpsServer := &http.Server{ - Addr: ":8443", - WriteTimeout: time.Second * time.Duration(gatewayServer.gateway.WriteTimeout), - ReadTimeout: time.Second * time.Duration(gatewayServer.gateway.ReadTimeout), - IdleTimeout: time.Second * time.Duration(gatewayServer.gateway.IdleTimeout), - Handler: route, // Pass our instance of gorilla/mux in. - TLSConfig: tlsConfig, - } if !gatewayServer.gateway.DisableDisplayRouteOnStart { printRoute(gatewayServer.gateway.Routes) } - // Set KeepAlive - httpServer.SetKeepAlivesEnabled(!gatewayServer.gateway.DisableKeepAlive) - go func() { - logger.Info("Starting HTTP server listen=0.0.0.0:8080") - if err := httpServer.ListenAndServe(); err != nil { - logger.Fatal("Error starting Goma Gateway HTTP server: %v", err) - } - }() - go func() { - if listenWithTLS { - logger.Info("Starting HTTPS server listen=0.0.0.0:8443") - if err := httpsServer.ListenAndServeTLS("", ""); err != nil { - logger.Fatal("Error starting Goma Gateway HTTPS server: %v", err) - } - } - }() - var wg sync.WaitGroup - wg.Add(2) - go func() { - defer wg.Done() - <-ctx.Done() - shutdownCtx := context.Background() - shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second) - defer cancel() - if err := httpServer.Shutdown(shutdownCtx); err != nil { - _, err := fmt.Fprintf(os.Stderr, "error shutting down HTTP server: %s\n", err) - if err != nil { - return - } - } - }() - go func() { - defer wg.Done() - <-ctx.Done() - shutdownCtx := context.Background() - shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second) - defer cancel() - if listenWithTLS { - if err := httpsServer.Shutdown(shutdownCtx); err != nil { - _, err := fmt.Fprintf(os.Stderr, "error shutting HTTPS server: %s\n", err) - if err != nil { - return - } - } - } - }() - wg.Wait() - return nil + httpServer := gatewayServer.createServer(":8080", route, nil) + httpsServer := gatewayServer.createServer(":8443", route, tlsConfig) + + // Start HTTP/HTTPS servers + if err := gatewayServer.startServers(httpServer, httpsServer, listenWithTLS); err != nil { + return err + } + + // Handle graceful shutdown + return gatewayServer.gracefulShutdown(ctx, httpServer, httpsServer, listenWithTLS) +} + +func (gatewayServer GatewayServer) createServer(addr string, handler http.Handler, tlsConfig *tls.Config) *http.Server { + return &http.Server{ + Addr: addr, + WriteTimeout: time.Second * time.Duration(gatewayServer.gateway.WriteTimeout), + ReadTimeout: time.Second * time.Duration(gatewayServer.gateway.ReadTimeout), + IdleTimeout: time.Second * time.Duration(gatewayServer.gateway.IdleTimeout), + Handler: handler, + TLSConfig: tlsConfig, + } +} + +func (gatewayServer GatewayServer) startServers(httpServer, httpsServer *http.Server, listenWithTLS bool) error { + go func() { + logger.Info("Starting HTTP server on 0.0.0.0:8080") + if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + logger.Fatal("HTTP server error: %v", err) + } + }() + + if listenWithTLS { + go func() { + logger.Info("Starting HTTPS server on 0.0.0.0:8443") + if err := httpsServer.ListenAndServeTLS("", ""); err != nil && !errors.Is(err, http.ErrServerClosed) { + logger.Fatal("HTTPS server error: %v", err) + } + }() + } + + return nil +} + +func (gatewayServer GatewayServer) gracefulShutdown(ctx context.Context, httpServer, httpsServer *http.Server, listenWithTLS bool) error { + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + <-quit + logger.Info("Shutting down Goma Gateway...") + + shutdownCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + if err := httpServer.Shutdown(shutdownCtx); err != nil { + logger.Error("Error shutting down HTTP server: %v", err) + } + + if listenWithTLS { + if err := httpsServer.Shutdown(shutdownCtx); err != nil { + logger.Error("Error shutting down HTTPS server: %v", err) + } + } + + logger.Info("Goma Gateway shut down successfully") + return nil } diff --git a/internal/tls.go b/internal/tls.go new file mode 100644 index 0000000..6ca0000 --- /dev/null +++ b/internal/tls.go @@ -0,0 +1,36 @@ +/* + * Copyright 2024 Jonas Kaninda + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package pkg + +import ( + "crypto/tls" + "fmt" +) + +func (gatewayServer GatewayServer) initTLS() (*tls.Config, bool, error) { + cert, key := gatewayServer.gateway.SSLCertFile, gatewayServer.gateway.SSLKeyFile + if cert == "" || key == "" { + return nil, false, nil + } + + tlsConfig, err := loadTLS(cert, key) + if err != nil { + return nil, false, fmt.Errorf("failed to load TLS config: %w", err) + } + return tlsConfig, true, nil +} diff --git a/internal/var.go b/internal/var.go index b4ded35..c83716a 100644 --- a/internal/var.go +++ b/internal/var.go @@ -4,10 +4,10 @@ const ConfigDir = "/etc/goma/" // Default config const ConfigFile = "/etc/goma/goma.yml" // Default configuration file const accessControlAllowOrigin = "Access-Control-Allow-Origin" // Cors const gatewayName = "Goma Gateway" -const AccessMiddleware = "access" // access middleware -const BasicAuth = "basic" // basic authentication middleware -const JWTAuth = "jwt" // JWT authentication middleware -const OAuth = "oauth" // OAuth authentication middleware +const AccessMiddleware = "access" // access middlewares +const BasicAuth = "basic" // basic authentication middlewares +const JWTAuth = "jwt" // JWT authentication middlewares +const OAuth = "oauth" // OAuth authentication middlewares // Round-robin counter var counter uint32 diff --git a/main.go b/main.go index f41279e..b27b37f 100644 --- a/main.go +++ b/main.go @@ -15,7 +15,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -import "github.com/jkaninda/goma-gateway/cmd" +import ( + "github.com/jkaninda/goma-gateway/cmd" +) func main() {