diff --git a/internal/healthcheck.go b/internal/healthcheck.go index b50c237..cc38baa 100644 --- a/internal/healthcheck.go +++ b/internal/healthcheck.go @@ -71,6 +71,8 @@ func (health Health) Check() error { } return nil } + +// routesHealthCheck creates healthcheck job func routesHealthCheck(routes []Route) { for _, health := range healthCheckRoutes(routes) { go func() { @@ -84,11 +86,14 @@ func routesHealthCheck(routes []Route) { } } + +// createHealthCheckJob create healthcheck job func (health Health) createHealthCheckJob() error { interval := "30s" if len(health.Interval) > 0 { interval = health.Interval } + // create cron expression expression := fmt.Sprintf("@every %s", interval) if !util.IsValidCronExpression(expression) { logger.Error("Health check interval is invalid: %s", interval) @@ -113,3 +118,45 @@ func (health Health) createHealthCheckJob() error { defer c.Stop() select {} } + +// healthCheckRoutes creates and returns []Health +func healthCheckRoutes(routes []Route) []Health { + var healthRoutes []Health + for _, route := range routes { + if len(route.HealthCheck.Path) != 0 { + timeout, _ := util.ParseDuration("") + if len(route.HealthCheck.Timeout) > 0 { + d1, err1 := util.ParseDuration(route.HealthCheck.Timeout) + if err1 != nil { + logger.Error("Health check timeout is invalid: %s", route.HealthCheck.Timeout) + } + timeout = d1 + } + if len(route.Backends) != 0 { + for index, backend := range route.Backends { + health := Health{ + Name: fmt.Sprintf("%s - [%d]", route.Name, index), + URL: backend + route.HealthCheck.Path, + TimeOut: timeout, + HealthyStatuses: route.HealthCheck.HealthyStatuses, + InsecureSkipVerify: route.InsecureSkipVerify, + } + healthRoutes = append(healthRoutes, health) + } + + } else { + health := Health{ + Name: route.Name, + URL: route.Destination + route.HealthCheck.Path, + TimeOut: timeout, + HealthyStatuses: route.HealthCheck.HealthyStatuses, + InsecureSkipVerify: route.InsecureSkipVerify, + } + healthRoutes = append(healthRoutes, health) + } + } else { + logger.Debug("Route %s's healthCheck is undefined", route.Name) + } + } + return healthRoutes +} diff --git a/internal/helpers.go b/internal/helpers.go index 992b0f7..60e53b4 100644 --- a/internal/helpers.go +++ b/internal/helpers.go @@ -11,25 +11,20 @@ You may get a copy of the License at */ import ( "context" - "crypto/tls" "encoding/json" "fmt" - "github.com/golang-jwt/jwt" "github.com/jedib0t/go-pretty/v6/table" - "github.com/jkaninda/goma-gateway/pkg/logger" - "github.com/jkaninda/goma-gateway/util" "golang.org/x/oauth2" "io" "net/http" - "time" ) // printRoute prints routes func printRoute(routes []Route) { t := table.NewWriter() - t.AppendHeader(table.Row{"Name", "Route", "Rewrite", "Destination"}) + t.AppendHeader(table.Row{"Name", "Path", "Rewrite", "Destination"}) for _, route := range routes { - if len(route.Backends) > 0 { + if len(route.Backends) != 0 { t.AppendRow(table.Row{route.Name, route.Path, route.Rewrite, fmt.Sprintf("backends: [%d]", len(route.Backends))}) } else { @@ -50,21 +45,6 @@ func getRealIP(r *http.Request) string { return r.RemoteAddr } -// loadTLS loads TLS Certificate -func loadTLS(cert, key string) (*tls.Config, error) { - if cert == "" && key == "" { - return nil, fmt.Errorf("no certificate or key file provided") - } - serverCert, err := tls.LoadX509KeyPair(cert, key) - if err != nil { - logger.Error("Error loading server certificate: %v", err) - return nil, err - } - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{serverCert}, - } - return tlsConfig, nil -} func (oauth *OauthRulerMiddleware) getUserInfo(token *oauth2.Token) (UserInfo, error) { oauthConfig := oauth2Config(oauth) // Call the user info endpoint with the token @@ -88,64 +68,3 @@ func (oauth *OauthRulerMiddleware) getUserInfo(token *oauth2.Token) (UserInfo, e return userInfo, nil } -func createJWT(email, jwtSecret string) (string, error) { - // Define JWT claims - claims := jwt.MapClaims{ - "email": email, - "exp": jwt.TimeFunc().Add(time.Hour * 24).Unix(), // Token expiration - "iss": "Goma-Gateway", // Issuer claim - } - - // Create a new token with HS256 signing method - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - - // Sign the token with a secret - signedToken, err := token.SignedString([]byte(jwtSecret)) - if err != nil { - return "", err - } - - return signedToken, nil -} - -// healthCheckRoutes creates []Health -func healthCheckRoutes(routes []Route) []Health { - var healthRoutes []Health - for _, route := range routes { - if len(route.HealthCheck.Path) > 0 { - timeout, _ := util.ParseDuration("") - if len(route.HealthCheck.Timeout) > 0 { - d1, err1 := util.ParseDuration(route.HealthCheck.Timeout) - if err1 != nil { - logger.Error("Health check timeout is invalid: %s", route.HealthCheck.Timeout) - } - timeout = d1 - } - if len(route.Backends) > 0 { - for index, backend := range route.Backends { - health := Health{ - Name: fmt.Sprintf("%s - [%d]", route.Name, index), - URL: backend + route.HealthCheck.Path, - TimeOut: timeout, - HealthyStatuses: route.HealthCheck.HealthyStatuses, - InsecureSkipVerify: route.InsecureSkipVerify, - } - healthRoutes = append(healthRoutes, health) - } - - } else { - health := Health{ - Name: route.Name, - URL: route.Destination + route.HealthCheck.Path, - TimeOut: timeout, - HealthyStatuses: route.HealthCheck.HealthyStatuses, - InsecureSkipVerify: route.InsecureSkipVerify, - } - healthRoutes = append(healthRoutes, health) - } - } else { - logger.Debug("Route %s's healthCheck is undefined", route.Name) - } - } - return healthRoutes -} diff --git a/internal/jwt.go b/internal/jwt.go new file mode 100644 index 0000000..a6d477d --- /dev/null +++ b/internal/jwt.go @@ -0,0 +1,44 @@ +/* + * Copyright 2024 Jonas Kaninda + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package pkg + +import ( + "github.com/golang-jwt/jwt" + "time" +) + +// createJWT create JWT token +func createJWT(email, jwtSecret string) (string, error) { + // Define JWT claims + claims := jwt.MapClaims{ + "email": email, + "exp": jwt.TimeFunc().Add(time.Hour * 24).Unix(), // Token expiration + "iss": "Goma-Gateway", // Issuer claim + } + + // Create a new token with HS256 signing method + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + + // Sign the token with a secret + signedToken, err := token.SignedString([]byte(jwtSecret)) + if err != nil { + return "", err + } + + return signedToken, nil +} diff --git a/internal/metrics/prometheus.go b/internal/metrics/prometheus.go index 65ec31e..5c32e27 100644 --- a/internal/metrics/prometheus.go +++ b/internal/metrics/prometheus.go @@ -52,6 +52,7 @@ var HttpDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{ Help: "Duration of HTTP requests.", }, []string{"name", "path"}) +// PrometheusMiddleware Prometheus http handler middleware, returns http.Handler func (pr PrometheusRoute) PrometheusMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { path := pr.Path diff --git a/internal/middlewares/access-middleware.go b/internal/middlewares/access_middleware.go similarity index 100% rename from internal/middlewares/access-middleware.go rename to internal/middlewares/access_middleware.go diff --git a/internal/middlewares/block-common-exploits.go b/internal/middlewares/block_common_exploits.go similarity index 100% rename from internal/middlewares/block-common-exploits.go rename to internal/middlewares/block_common_exploits.go diff --git a/internal/middlewares/error-interceptor.go b/internal/middlewares/error_interceptor.go similarity index 100% rename from internal/middlewares/error-interceptor.go rename to internal/middlewares/error_interceptor.go diff --git a/internal/middlewares/helpers.go b/internal/middlewares/helpers.go index 351f182..a308b84 100644 --- a/internal/middlewares/helpers.go +++ b/internal/middlewares/helpers.go @@ -23,6 +23,7 @@ import ( "slices" ) +// getRealIP returns user real IP func getRealIP(r *http.Request) string { if ip := r.Header.Get("X-Real-IP"); ip != "" { return ip diff --git a/internal/middlewares/oauth-middleware.go b/internal/middlewares/oauth_middleware.go similarity index 100% rename from internal/middlewares/oauth-middleware.go rename to internal/middlewares/oauth_middleware.go diff --git a/internal/middlewares/rate-limit.go b/internal/middlewares/rate_limit.go similarity index 100% rename from internal/middlewares/rate-limit.go rename to internal/middlewares/rate_limit.go diff --git a/internal/middlewares/redis.go b/internal/middlewares/redis.go index 4f1dd0e..1f199d2 100644 --- a/internal/middlewares/redis.go +++ b/internal/middlewares/redis.go @@ -24,6 +24,7 @@ import ( "github.com/redis/go-redis/v9" ) +// redisRateLimiter, handle rateLimit func redisRateLimiter(clientIP string, rate int) error { ctx := context.Background() diff --git a/internal/proxy.go b/internal/proxy.go index dac8b20..1f6a472 100644 --- a/internal/proxy.go +++ b/internal/proxy.go @@ -74,7 +74,7 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc { r.Host = targetURL.Host } backendURL, _ := url.Parse(proxyRoute.destination) - if len(proxyRoute.backends) > 0 { + if len(proxyRoute.backends) != 0 { // Select the next backend URL backendURL = getNextBackend(proxyRoute.backends) } @@ -87,8 +87,7 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc { InsecureSkipVerify: proxyRoute.insecureSkipVerify, }, } - w.Header().Set("Proxied-By", gatewayName) // Set Server name - w.Header().Del("Server") // Remove the Server header + w.Header().Set("Proxied-By", gatewayName) // Custom error handler for proxy errors proxy.ErrorHandler = ProxyErrorHandler proxy.ServeHTTP(w, r) diff --git a/internal/route.go b/internal/routes.go similarity index 99% rename from internal/route.go rename to internal/routes.go index 867bd81..084b36e 100644 --- a/internal/route.go +++ b/internal/routes.go @@ -26,6 +26,7 @@ import ( "time" ) +// init initializes prometheus metrics func init() { _ = prometheus.Register(metrics.TotalRequests) _ = prometheus.Register(metrics.ResponseStatus) @@ -88,6 +89,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { logger.Info("Block common exploits enabled") r.Use(middlewares.BlockExploitsMiddleware) } + // check if RateLimit is set if gateway.RateLimit != 0 { // Add rate limit middlewares to all routes, if defined rateLimit := middlewares.RateLimit{ diff --git a/internal/route_test.go b/internal/routes_test.go similarity index 100% rename from internal/route_test.go rename to internal/routes_test.go diff --git a/internal/tls.go b/internal/tls.go index 6ca0000..fda25f1 100644 --- a/internal/tls.go +++ b/internal/tls.go @@ -20,6 +20,7 @@ package pkg import ( "crypto/tls" "fmt" + "github.com/jkaninda/goma-gateway/pkg/logger" ) func (gatewayServer GatewayServer) initTLS() (*tls.Config, bool, error) { @@ -34,3 +35,19 @@ func (gatewayServer GatewayServer) initTLS() (*tls.Config, bool, error) { } return tlsConfig, true, nil } + +// loadTLS loads TLS Certificate +func loadTLS(cert, key string) (*tls.Config, error) { + if cert == "" && key == "" { + return nil, fmt.Errorf("no certificate or key file provided") + } + serverCert, err := tls.LoadX509KeyPair(cert, key) + if err != nil { + logger.Error("Error loading server certificate: %v", err) + return nil, err + } + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + } + return tlsConfig, nil +}