Merge pull request #109 from jkaninda/refactor

Refactor
This commit is contained in:
2024-11-15 15:49:24 +01:00
committed by GitHub
28 changed files with 279 additions and 190 deletions

View File

@@ -37,13 +37,13 @@ var ServerCmd = &cobra.Command{
} }
ctx := context.Background() ctx := context.Background()
g := pkg.GatewayServer{} g := pkg.GatewayServer{}
gs, err := g.Config(configFile) gs, err := g.Config(configFile, ctx)
if err != nil { if err != nil {
fmt.Printf("Could not load configuration: %v\n", err) fmt.Printf("Could not load configuration: %v\n", err)
os.Exit(1) os.Exit(1)
} }
gs.SetEnv() gs.SetEnv()
if err := gs.Start(ctx); err != nil { if err := gs.Start(); err != nil {
fmt.Printf("Could not start server: %v\n", err) fmt.Printf("Could not start server: %v\n", err)
os.Exit(1) os.Exit(1)

View File

@@ -52,7 +52,7 @@ func CheckConfig(fileName string) error {
} }
} }
//Check middleware //Check middlewares
for index, mid := range c.Middlewares { for index, mid := range c.Middlewares {
if util.HasWhitespace(mid.Name) { if util.HasWhitespace(mid.Name) {
fmt.Printf("Warning: Middleware contains whitespace: %s | index: [%d], please remove whitespace characters\n", mid.Name, index) fmt.Printf("Warning: Middleware contains whitespace: %s | index: [%d], please remove whitespace characters\n", mid.Name, index)

View File

@@ -16,8 +16,9 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
import ( import (
"context"
"fmt" "fmt"
"github.com/jkaninda/goma-gateway/internal/middleware" "github.com/jkaninda/goma-gateway/internal/middlewares"
"github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/pkg/logger"
"github.com/jkaninda/goma-gateway/util" "github.com/jkaninda/goma-gateway/util"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@@ -31,7 +32,7 @@ import (
) )
// Config reads config file and returns Gateway // Config reads config file and returns Gateway
func (GatewayServer) Config(configFile string) (*GatewayServer, error) { func (GatewayServer) Config(configFile string, ctx context.Context) (*GatewayServer, error) {
if util.FileExists(configFile) { if util.FileExists(configFile) {
buf, err := os.ReadFile(configFile) buf, err := os.ReadFile(configFile)
if err != nil { if err != nil {
@@ -44,7 +45,8 @@ func (GatewayServer) Config(configFile string) (*GatewayServer, error) {
return nil, fmt.Errorf("parsing the configuration file %q: %w", configFile, err) return nil, fmt.Errorf("parsing the configuration file %q: %w", configFile, err)
} }
return &GatewayServer{ return &GatewayServer{
ctx: nil, ctx: ctx,
configFile: configFile,
version: c.Version, version: c.Version,
gateway: c.GatewayConfig, gateway: c.GatewayConfig,
middlewares: c.Middlewares, middlewares: c.Middlewares,
@@ -59,14 +61,15 @@ func (GatewayServer) Config(configFile string) (*GatewayServer, error) {
} }
logger.Info("Using configuration file: %s", ConfigFile) logger.Info("Using configuration file: %s", ConfigFile)
util.SetEnv("GOMA_CONFIG_FILE", configFile) util.SetEnv("GOMA_CONFIG_FILE", ConfigFile)
c := &GatewayConfig{} c := &GatewayConfig{}
err = yaml.Unmarshal(buf, c) err = yaml.Unmarshal(buf, c)
if err != nil { if err != nil {
return nil, fmt.Errorf("parsing the configuration file %q: %w", ConfigFile, err) return nil, fmt.Errorf("parsing the configuration file %q: %w", ConfigFile, err)
} }
return &GatewayServer{ return &GatewayServer{
ctx: nil, ctx: ctx,
configFile: ConfigFile,
gateway: c.GatewayConfig, gateway: c.GatewayConfig,
middlewares: c.Middlewares, middlewares: c.Middlewares,
}, nil }, nil
@@ -98,7 +101,8 @@ func (GatewayServer) Config(configFile string) (*GatewayServer, error) {
} }
logger.Info("Generating new configuration file...done") logger.Info("Generating new configuration file...done")
return &GatewayServer{ return &GatewayServer{
ctx: nil, ctx: ctx,
configFile: ConfigFile,
gateway: c.GatewayConfig, gateway: c.GatewayConfig,
middlewares: c.Middlewares, middlewares: c.Middlewares,
}, nil }, nil
@@ -330,7 +334,7 @@ func getJWTMiddleware(input interface{}) (JWTRuleMiddleware, error) {
return JWTRuleMiddleware{}, fmt.Errorf("error parsing yaml: %v", err) return JWTRuleMiddleware{}, fmt.Errorf("error parsing yaml: %v", err)
} }
if jWTRuler.URL == "" { if jWTRuler.URL == "" {
return JWTRuleMiddleware{}, fmt.Errorf("error parsing yaml: empty url in jwt auth middleware") return JWTRuleMiddleware{}, fmt.Errorf("error parsing yaml: empty url in jwt auth middlewares")
} }
return *jWTRuler, nil return *jWTRuler, nil
@@ -349,7 +353,7 @@ func getBasicAuthMiddleware(input interface{}) (BasicRuleMiddleware, error) {
return BasicRuleMiddleware{}, fmt.Errorf("error parsing yaml: %v", err) return BasicRuleMiddleware{}, fmt.Errorf("error parsing yaml: %v", err)
} }
if basicAuth.Username == "" || basicAuth.Password == "" { if basicAuth.Username == "" || basicAuth.Password == "" {
return BasicRuleMiddleware{}, fmt.Errorf("error parsing yaml: empty username/password in %s middleware", basicAuth) return BasicRuleMiddleware{}, fmt.Errorf("error parsing yaml: empty username/password in %s middlewares", basicAuth)
} }
return *basicAuth, nil return *basicAuth, nil
@@ -368,12 +372,12 @@ func oAuthMiddleware(input interface{}) (OauthRulerMiddleware, error) {
return OauthRulerMiddleware{}, fmt.Errorf("error parsing yaml: %v", err) return OauthRulerMiddleware{}, fmt.Errorf("error parsing yaml: %v", err)
} }
if oauthRuler.ClientID == "" || oauthRuler.ClientSecret == "" || oauthRuler.RedirectURL == "" { if oauthRuler.ClientID == "" || oauthRuler.ClientSecret == "" || oauthRuler.RedirectURL == "" {
return OauthRulerMiddleware{}, fmt.Errorf("error parsing yaml: empty clientId/secretId in %s middleware", oauthRuler) return OauthRulerMiddleware{}, fmt.Errorf("error parsing yaml: empty clientId/secretId in %s middlewares", oauthRuler)
} }
return *oauthRuler, nil return *oauthRuler, nil
} }
func oauthRulerMiddleware(oauth middleware.Oauth) *OauthRulerMiddleware { func oauthRulerMiddleware(oauth middlewares.Oauth) *OauthRulerMiddleware {
return &OauthRulerMiddleware{ return &OauthRulerMiddleware{
ClientID: oauth.ClientID, ClientID: oauth.ClientID,
ClientSecret: oauth.ClientSecret, ClientSecret: oauth.ClientSecret,

View File

@@ -14,7 +14,7 @@ func getMiddleware(rules []string, middlewares []Middleware) (Middleware, error)
continue continue
} }
return Middleware{}, errors.New("middleware not found with name: [" + strings.Join(rules, ";") + "]") return Middleware{}, errors.New("middlewares not found with name: [" + strings.Join(rules, ";") + "]")
} }
func doesExist(tyName string) bool { func doesExist(tyName string) bool {
@@ -30,5 +30,5 @@ func GetMiddleware(rule string, middlewares []Middleware) (Middleware, error) {
continue continue
} }
return Middleware{}, errors.New("no middleware found with name " + rule) return Middleware{}, errors.New("no middlewares found with name " + rule)
} }

View File

@@ -105,7 +105,7 @@ func TestReadMiddleware(t *testing.T) {
middlewares := getMiddlewares(t) middlewares := getMiddlewares(t)
m, err := getMiddleware(rules, middlewares) m, err := getMiddleware(rules, middlewares)
if err != nil { if err != nil {
t.Fatalf("Error searching middleware %s", err.Error()) t.Fatalf("Error searching middlewares %s", err.Error())
} }
log.Printf("Middleware: %v\n", m) log.Printf("Middleware: %v\n", m)
@@ -134,10 +134,10 @@ func TestReadMiddleware(t *testing.T) {
} }
log.Printf("OAuth authentification: provider %s\n", oauth.Provider) log.Printf("OAuth authentification: provider %s\n", oauth.Provider)
case AccessMiddleware: case AccessMiddleware:
log.Println("Access middleware") log.Println("Access middlewares")
log.Printf("Access middleware: paths: [%s]\n", middleware.Paths) log.Printf("Access middlewares: paths: [%s]\n", middleware.Paths)
default: default:
t.Errorf("Unknown middleware type %s", middleware.Type) t.Errorf("Unknown middlewares type %s", middleware.Type)
} }
} }
@@ -148,7 +148,7 @@ func TestFoundMiddleware(t *testing.T) {
middlewares := getMiddlewares(t) middlewares := getMiddlewares(t)
middleware, err := GetMiddleware("jwt", middlewares) middleware, err := GetMiddleware("jwt", middlewares)
if err != nil { if err != nil {
t.Errorf("Error getting middleware %v", err) t.Errorf("Error getting middlewares %v", err)
} }
fmt.Println(middleware.Type) fmt.Println(middleware.Type)
} }

View File

@@ -17,9 +17,9 @@
package pkg package pkg
// Middleware defined the route middleware // Middleware defined the route middlewares
type Middleware struct { type Middleware struct {
//Path contains the name of middleware and must be unique //Path contains the name of middlewares and must be unique
Name string `yaml:"name"` Name string `yaml:"name"`
// Type contains authentication types // Type contains authentication types
// //

View File

@@ -1,4 +1,4 @@
package middleware package middlewares
/* /*
Copyright 2024 Jonas Kaninda Copyright 2024 Jonas Kaninda

View File

@@ -15,7 +15,7 @@
* *
*/ */
package middleware package middlewares
import ( import (
"fmt" "fmt"

View File

@@ -15,7 +15,7 @@
* *
*/ */
package middleware package middlewares
import ( import (
"github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/pkg/logger"

View File

@@ -1,4 +1,4 @@
package middleware package middlewares
/* /*
* Copyright 2024 Jonas Kaninda * Copyright 2024 Jonas Kaninda

View File

@@ -15,7 +15,7 @@
* *
*/ */
package middleware package middlewares
import ( import (
"encoding/json" "encoding/json"

View File

@@ -1,4 +1,4 @@
package middleware package middlewares
/* /*
Copyright 2024 Jonas Kaninda Copyright 2024 Jonas Kaninda

View File

@@ -15,7 +15,7 @@
* *
*/ */
package middleware package middlewares
import ( import (
"fmt" "fmt"

View File

@@ -1,4 +1,4 @@
package middleware package middlewares
/* /*
Copyright 2024 Jonas Kaninda Copyright 2024 Jonas Kaninda
@@ -16,13 +16,9 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
import ( import (
"errors"
"fmt" "fmt"
"github.com/go-redis/redis_rate/v10"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/pkg/logger"
"github.com/redis/go-redis/v9"
"golang.org/x/net/context"
"net/http" "net/http"
"time" "time"
) )
@@ -91,23 +87,3 @@ func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc {
}) })
} }
} }
func redisRateLimiter(clientIP string, rate int) error {
ctx := context.Background()
res, err := limiter.Allow(ctx, clientIP, redis_rate.PerMinute(rate))
if err != nil {
return err
}
if res.Remaining == 0 {
return errors.New("requests limit exceeded")
}
return nil
}
func InitRedis(addr, password string) {
Rdb = redis.NewClient(&redis.Options{
Addr: addr,
Password: password,
})
limiter = redis_rate.NewLimiter(Rdb)
}

View File

@@ -0,0 +1,46 @@
/*
* Copyright 2024 Jonas Kaninda
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package middlewares
import (
"context"
"errors"
"github.com/go-redis/redis_rate/v10"
"github.com/redis/go-redis/v9"
)
func redisRateLimiter(clientIP string, rate int) error {
ctx := context.Background()
res, err := limiter.Allow(ctx, clientIP, redis_rate.PerMinute(rate))
if err != nil {
return err
}
if res.Remaining == 0 {
return errors.New("requests limit exceeded")
}
return nil
}
func InitRedis(addr, password string) {
Rdb = redis.NewClient(&redis.Options{
Addr: addr,
Password: password,
})
limiter = redis_rate.NewLimiter(Rdb)
}

View File

@@ -15,7 +15,7 @@
* *
*/ */
package middleware package middlewares
import ( import (
"bytes" "bytes"

View File

@@ -15,7 +15,7 @@
* *
*/ */
package middleware package middlewares
import ( import (
"github.com/go-redis/redis_rate/v10" "github.com/go-redis/redis_rate/v10"

View File

@@ -18,7 +18,7 @@ limitations under the License.
import ( import (
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"github.com/jkaninda/goma-gateway/internal/middleware" "github.com/jkaninda/goma-gateway/internal/middlewares"
"github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/pkg/logger"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
@@ -38,7 +38,7 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc {
if len(proxyRoute.methods) > 0 { if len(proxyRoute.methods) > 0 {
if !slices.Contains(proxyRoute.methods, r.Method) { if !slices.Contains(proxyRoute.methods, r.Method) {
logger.Error("%s Method is not allowed", r.Method) logger.Error("%s Method is not allowed", r.Method)
middleware.RespondWithError(w, http.StatusMethodNotAllowed, fmt.Sprintf("%d %s method is not allowed", http.StatusMethodNotAllowed, r.Method)) middlewares.RespondWithError(w, http.StatusMethodNotAllowed, fmt.Sprintf("%d %s method is not allowed", http.StatusMethodNotAllowed, r.Method))
return return
} }
} }
@@ -61,7 +61,7 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc {
targetURL, err := url.Parse(proxyRoute.destination) targetURL, err := url.Parse(proxyRoute.destination)
if err != nil { if err != nil {
logger.Error("Error parsing backend URL: %s", err) logger.Error("Error parsing backend URL: %s", err)
middleware.RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) middlewares.RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
return return
} }
r.Header.Set("X-Forwarded-Host", r.Header.Get("Host")) r.Header.Set("X-Forwarded-Host", r.Header.Get("Host"))

40
internal/redis.go Normal file
View File

@@ -0,0 +1,40 @@
/*
* Copyright 2024 Jonas Kaninda
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package pkg
import (
"github.com/jkaninda/goma-gateway/internal/middlewares"
"github.com/jkaninda/goma-gateway/pkg/logger"
)
func (gatewayServer GatewayServer) initRedis() error {
if gatewayServer.gateway.Redis.Addr == "" {
return nil
}
logger.Info("Initializing Redis...")
middlewares.InitRedis(gatewayServer.gateway.Redis.Addr, gatewayServer.gateway.Redis.Password)
return nil
}
func (gatewayServer GatewayServer) closeRedis() {
if middlewares.Rdb != nil {
if err := middlewares.Rdb.Close(); err != nil {
logger.Error("Error closing Redis: %v", err)
}
}
}

View File

@@ -17,7 +17,7 @@ limitations under the License.
*/ */
import ( import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/jkaninda/goma-gateway/internal/middleware" "github.com/jkaninda/goma-gateway/internal/middlewares"
"github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/pkg/logger"
"github.com/jkaninda/goma-gateway/util" "github.com/jkaninda/goma-gateway/util"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
@@ -34,7 +34,7 @@ func init() {
// Initialize the routes // Initialize the routes
func (gatewayServer GatewayServer) Initialize() *mux.Router { func (gatewayServer GatewayServer) Initialize() *mux.Router {
gateway := gatewayServer.gateway gateway := gatewayServer.gateway
middlewares := gatewayServer.middlewares m := gatewayServer.middlewares
redisBased := false redisBased := false
if len(gateway.Redis.Addr) != 0 { if len(gateway.Redis.Addr) != 0 {
redisBased = true redisBased = true
@@ -62,11 +62,11 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
// Enable common exploits // Enable common exploits
if gateway.BlockCommonExploits { if gateway.BlockCommonExploits {
logger.Info("Block common exploits enabled") logger.Info("Block common exploits enabled")
r.Use(middleware.BlockExploitsMiddleware) r.Use(middlewares.BlockExploitsMiddleware)
} }
if gateway.RateLimit > 0 { if gateway.RateLimit > 0 {
// Add rate limit middleware to all routes, if defined // Add rate limit middlewares to all routes, if defined
rateLimit := middleware.RateLimit{ rateLimit := middlewares.RateLimit{
Id: "global_rate", //Generate a unique ID for routes Id: "global_rate", //Generate a unique ID for routes
Requests: gateway.RateLimit, Requests: gateway.RateLimit,
Window: time.Minute, // requests per minute Window: time.Minute, // requests per minute
@@ -75,7 +75,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
RedisBased: redisBased, RedisBased: redisBased,
} }
limiter := rateLimit.NewRateLimiterWindow() limiter := rateLimit.NewRateLimiterWindow()
// Add rate limit middleware // Add rate limit middlewares
r.Use(limiter.RateLimitMiddleware()) r.Use(limiter.RateLimitMiddleware())
} }
for rIndex, route := range gateway.Routes { for rIndex, route := range gateway.Routes {
@@ -87,14 +87,14 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
// Apply middlewares to route // Apply middlewares to route
for _, mid := range route.Middlewares { for _, mid := range route.Middlewares {
if mid != "" { if mid != "" {
// Get Access middleware if it does exist // Get Access middlewares if it does exist
accessMiddleware, err := getMiddleware([]string{mid}, middlewares) accessMiddleware, err := getMiddleware([]string{mid}, m)
if err != nil { if err != nil {
logger.Error("Error: %v", err.Error()) logger.Error("Error: %v", err.Error())
} else { } else {
// Apply access middleware // Apply access middlewares
if accessMiddleware.Type == AccessMiddleware { if accessMiddleware.Type == AccessMiddleware {
blM := middleware.AccessListMiddleware{ blM := middlewares.AccessListMiddleware{
Path: route.Path, Path: route.Path,
List: accessMiddleware.Paths, List: accessMiddleware.Paths,
} }
@@ -103,10 +103,10 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
} }
} }
// Get route authentication middleware if it does exist // Get route authentication middlewares if it does exist
rMiddleware, err := getMiddleware([]string{mid}, middlewares) rMiddleware, err := getMiddleware([]string{mid}, m)
if err != nil { if err != nil {
//Error: middleware not found //Error: middlewares not found
logger.Error("Error: %v", err.Error()) logger.Error("Error: %v", err.Error())
} else { } else {
for _, midPath := range rMiddleware.Paths { for _, midPath := range rMiddleware.Paths {
@@ -122,20 +122,20 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
} }
secureRouter := r.PathPrefix(util.ParseRoutePath(route.Path, midPath)).Subrouter() secureRouter := r.PathPrefix(util.ParseRoutePath(route.Path, midPath)).Subrouter()
//callBackRouter := r.PathPrefix(util.ParseRoutePath(route.Path, "/callback")).Subrouter() //callBackRouter := r.PathPrefix(util.ParseRoutePath(route.Path, "/callback")).Subrouter()
//Check Authentication middleware //Check Authentication middlewares
switch rMiddleware.Type { switch rMiddleware.Type {
case BasicAuth: case BasicAuth:
basicAuth, err := getBasicAuthMiddleware(rMiddleware.Rule) basicAuth, err := getBasicAuthMiddleware(rMiddleware.Rule)
if err != nil { if err != nil {
logger.Error("Error: %s", err.Error()) logger.Error("Error: %s", err.Error())
} else { } else {
amw := middleware.AuthBasic{ amw := middlewares.AuthBasic{
Username: basicAuth.Username, Username: basicAuth.Username,
Password: basicAuth.Password, Password: basicAuth.Password,
Headers: nil, Headers: nil,
Params: nil, Params: nil,
} }
// Apply JWT authentication middleware // Apply JWT authentication middlewares
secureRouter.Use(amw.AuthMiddleware) secureRouter.Use(amw.AuthMiddleware)
secureRouter.Use(CORSHandler(route.Cors)) secureRouter.Use(CORSHandler(route.Cors))
secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler
@@ -146,14 +146,14 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
if err != nil { if err != nil {
logger.Error("Error: %s", err.Error()) logger.Error("Error: %s", err.Error())
} else { } else {
amw := middleware.JwtAuth{ amw := middlewares.JwtAuth{
AuthURL: jwt.URL, AuthURL: jwt.URL,
RequiredHeaders: jwt.RequiredHeaders, RequiredHeaders: jwt.RequiredHeaders,
Headers: jwt.Headers, Headers: jwt.Headers,
Params: jwt.Params, Params: jwt.Params,
Origins: gateway.Cors.Origins, Origins: gateway.Cors.Origins,
} }
// Apply JWT authentication middleware // Apply JWT authentication middlewares
secureRouter.Use(amw.AuthMiddleware) secureRouter.Use(amw.AuthMiddleware)
secureRouter.Use(CORSHandler(route.Cors)) secureRouter.Use(CORSHandler(route.Cors))
secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler
@@ -169,12 +169,12 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
if oauth.RedirectURL != "" { if oauth.RedirectURL != "" {
redirectURL = oauth.RedirectURL redirectURL = oauth.RedirectURL
} }
amw := middleware.Oauth{ amw := middlewares.Oauth{
ClientID: oauth.ClientID, ClientID: oauth.ClientID,
ClientSecret: oauth.ClientSecret, ClientSecret: oauth.ClientSecret,
RedirectURL: redirectURL, RedirectURL: redirectURL,
Scopes: oauth.Scopes, Scopes: oauth.Scopes,
Endpoint: middleware.OauthEndpoint{ Endpoint: middlewares.OauthEndpoint{
AuthURL: oauth.Endpoint.AuthURL, AuthURL: oauth.Endpoint.AuthURL,
TokenURL: oauth.Endpoint.TokenURL, TokenURL: oauth.Endpoint.TokenURL,
UserInfoURL: oauth.Endpoint.UserInfoURL, UserInfoURL: oauth.Endpoint.UserInfoURL,
@@ -205,7 +205,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
} }
default: default:
if !doesExist(rMiddleware.Type) { if !doesExist(rMiddleware.Type) {
logger.Error("Unknown middleware type %s", rMiddleware.Type) logger.Error("Unknown middlewares type %s", rMiddleware.Type)
} }
} }
@@ -214,7 +214,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
} }
} else { } else {
logger.Error("Error, middleware path is empty") logger.Error("Error, middlewares path is empty")
logger.Error("Middleware ignored") logger.Error("Middleware ignored")
} }
} }
@@ -234,7 +234,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
// Enable common exploits // Enable common exploits
if route.BlockCommonExploits { if route.BlockCommonExploits {
logger.Info("Block common exploits enabled") logger.Info("Block common exploits enabled")
router.Use(middleware.BlockExploitsMiddleware) router.Use(middlewares.BlockExploitsMiddleware)
} }
id := string(rune(rIndex)) id := string(rune(rIndex))
if len(route.Name) != 0 { if len(route.Name) != 0 {
@@ -243,7 +243,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
} }
// Apply route rate limit // Apply route rate limit
if route.RateLimit > 0 { if route.RateLimit > 0 {
rateLimit := middleware.RateLimit{ rateLimit := middlewares.RateLimit{
Id: id, // Use route index as ID Id: id, // Use route index as ID
Requests: route.RateLimit, Requests: route.RateLimit,
Window: time.Minute, // requests per minute Window: time.Minute, // requests per minute
@@ -252,7 +252,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
RedisBased: redisBased, RedisBased: redisBased,
} }
limiter := rateLimit.NewRateLimiterWindow() limiter := rateLimit.NewRateLimiterWindow()
// Add rate limit middleware // Add rate limit middlewares
router.Use(limiter.RateLimitMiddleware()) router.Use(limiter.RateLimitMiddleware())
} }
// Apply route Cors // Apply route Cors
@@ -272,9 +272,9 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
// Prometheus endpoint // Prometheus endpoint
router.Use(pr.prometheusMiddleware) router.Use(pr.prometheusMiddleware)
} }
// Apply route Error interceptor middleware // Apply route Error interceptor middlewares
if len(route.InterceptErrors) != 0 { if len(route.InterceptErrors) != 0 {
interceptErrors := middleware.InterceptErrors{ interceptErrors := middlewares.InterceptErrors{
Origins: route.Cors.Origins, Origins: route.Cors.Origins,
Errors: route.InterceptErrors, Errors: route.InterceptErrors,
} }
@@ -286,10 +286,10 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
} }
} }
// Apply global Cors middlewares // Apply global Cors middlewares
r.Use(CORSHandler(gateway.Cors)) // Apply CORS middleware r.Use(CORSHandler(gateway.Cors)) // Apply CORS middlewares
// Apply errorInterceptor middleware // Apply errorInterceptor middlewares
if len(gateway.InterceptErrors) != 0 { if len(gateway.InterceptErrors) != 0 {
interceptErrors := middleware.InterceptErrors{ interceptErrors := middlewares.InterceptErrors{
Errors: gateway.InterceptErrors, Errors: gateway.InterceptErrors,
Origins: gateway.Cors.Origins, Origins: gateway.Cors.Origins,
} }

View File

@@ -53,6 +53,6 @@ type Route struct {
InterceptErrors []int `yaml:"interceptErrors"` InterceptErrors []int `yaml:"interceptErrors"`
// BlockCommonExploits enable, disable block common exploits // BlockCommonExploits enable, disable block common exploits
BlockCommonExploits bool `yaml:"blockCommonExploits"` BlockCommonExploits bool `yaml:"blockCommonExploits"`
// Middlewares Defines route middleware from Middleware names // Middlewares Defines route middlewares from Middleware names
Middlewares []string `yaml:"middlewares"` Middlewares []string `yaml:"middlewares"`
} }

View File

@@ -18,111 +18,95 @@ limitations under the License.
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors"
"fmt" "fmt"
"github.com/jkaninda/goma-gateway/internal/middleware"
"github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/pkg/logger"
"github.com/redis/go-redis/v9"
"net/http" "net/http"
"os" "os"
"sync" "os/signal"
"syscall"
"time" "time"
) )
// Start starts the server // Start / Start starts the server
func (gatewayServer GatewayServer) Start(ctx context.Context) error { func (gatewayServer GatewayServer) Start() error {
logger.Info("Initializing routes...") logger.Info("Initializing routes...")
route := gatewayServer.Initialize() route := gatewayServer.Initialize()
gateway := gatewayServer.gateway logger.Debug("Routes count=%d, Middlewares count=%d", len(gatewayServer.gateway.Routes), len(gatewayServer.middlewares))
logger.Debug("Routes count=%d Middlewares count=%d", len(gatewayServer.gateway.Routes), len(gatewayServer.middlewares)) if err := gatewayServer.initRedis(); err != nil {
logger.Info("Initializing routes...done") return fmt.Errorf("failed to initialize Redis: %w", err)
if len(gateway.Redis.Addr) != 0 { }
middleware.InitRedis(gateway.Redis.Addr, gateway.Redis.Password) defer gatewayServer.closeRedis()
defer func(Rdb *redis.Client) {
err := Rdb.Close() tlsConfig, listenWithTLS, err := gatewayServer.initTLS()
if err != nil { if err != nil {
logger.Error("Redis connection closed with error: %v", err) return err
}
}(middleware.Rdb)
} }
tlsConfig := &tls.Config{}
var listenWithTLS = false
if cert := gatewayServer.gateway.SSLCertFile; cert != "" && gatewayServer.gateway.SSLKeyFile != "" {
tlsConf, err := loadTLS(cert, gatewayServer.gateway.SSLKeyFile)
if err != nil {
return err
}
tlsConfig = tlsConf
listenWithTLS = true
}
// HTTP Server
httpServer := &http.Server{
Addr: ":8080",
WriteTimeout: time.Second * time.Duration(gatewayServer.gateway.WriteTimeout),
ReadTimeout: time.Second * time.Duration(gatewayServer.gateway.ReadTimeout),
IdleTimeout: time.Second * time.Duration(gatewayServer.gateway.IdleTimeout),
Handler: route, // Pass our instance of gorilla/mux in.
}
// HTTPS Server
httpsServer := &http.Server{
Addr: ":8443",
WriteTimeout: time.Second * time.Duration(gatewayServer.gateway.WriteTimeout),
ReadTimeout: time.Second * time.Duration(gatewayServer.gateway.ReadTimeout),
IdleTimeout: time.Second * time.Duration(gatewayServer.gateway.IdleTimeout),
Handler: route, // Pass our instance of gorilla/mux in.
TLSConfig: tlsConfig,
}
if !gatewayServer.gateway.DisableDisplayRouteOnStart { if !gatewayServer.gateway.DisableDisplayRouteOnStart {
printRoute(gatewayServer.gateway.Routes) printRoute(gatewayServer.gateway.Routes)
} }
// Set KeepAlive
httpServer.SetKeepAlivesEnabled(!gatewayServer.gateway.DisableKeepAlive)
go func() {
logger.Info("Starting HTTP server listen=0.0.0.0:8080")
if err := httpServer.ListenAndServe(); err != nil {
logger.Fatal("Error starting Goma Gateway HTTP server: %v", err)
}
}()
go func() {
if listenWithTLS {
logger.Info("Starting HTTPS server listen=0.0.0.0:8443")
if err := httpsServer.ListenAndServeTLS("", ""); err != nil {
logger.Fatal("Error starting Goma Gateway HTTPS server: %v", err)
}
}
}()
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
<-ctx.Done()
shutdownCtx := context.Background()
shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second)
defer cancel()
if err := httpServer.Shutdown(shutdownCtx); err != nil {
_, err := fmt.Fprintf(os.Stderr, "error shutting down HTTP server: %s\n", err)
if err != nil {
return
}
}
}()
go func() {
defer wg.Done()
<-ctx.Done()
shutdownCtx := context.Background()
shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second)
defer cancel()
if listenWithTLS {
if err := httpsServer.Shutdown(shutdownCtx); err != nil {
_, err := fmt.Fprintf(os.Stderr, "error shutting HTTPS server: %s\n", err)
if err != nil {
return
}
}
}
}()
wg.Wait()
return nil
httpServer := gatewayServer.createServer(":8080", route, nil)
httpsServer := gatewayServer.createServer(":8443", route, tlsConfig)
// Start HTTP/HTTPS servers
if err := gatewayServer.startServers(httpServer, httpsServer, listenWithTLS); err != nil {
return err
}
// Handle graceful shutdown
return gatewayServer.shutdown(httpServer, httpsServer, listenWithTLS)
}
func (gatewayServer GatewayServer) createServer(addr string, handler http.Handler, tlsConfig *tls.Config) *http.Server {
return &http.Server{
Addr: addr,
WriteTimeout: time.Second * time.Duration(gatewayServer.gateway.WriteTimeout),
ReadTimeout: time.Second * time.Duration(gatewayServer.gateway.ReadTimeout),
IdleTimeout: time.Second * time.Duration(gatewayServer.gateway.IdleTimeout),
Handler: handler,
TLSConfig: tlsConfig,
}
}
func (gatewayServer GatewayServer) startServers(httpServer, httpsServer *http.Server, listenWithTLS bool) error {
go func() {
logger.Info("Starting HTTP server on 0.0.0.0:8080")
if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
logger.Fatal("HTTP server error: %v", err)
}
}()
if listenWithTLS {
go func() {
logger.Info("Starting HTTPS server on 0.0.0.0:8443")
if err := httpsServer.ListenAndServeTLS("", ""); err != nil && !errors.Is(err, http.ErrServerClosed) {
logger.Fatal("HTTPS server error: %v", err)
}
}()
}
return nil
}
func (gatewayServer GatewayServer) shutdown(httpServer, httpsServer *http.Server, listenWithTLS bool) error {
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
logger.Info("Shutting down Goma Gateway...")
shutdownCtx, cancel := context.WithTimeout(gatewayServer.ctx, 10*time.Second)
defer cancel()
if err := httpServer.Shutdown(shutdownCtx); err != nil {
logger.Error("Error shutting down HTTP server: %v", err)
}
if listenWithTLS {
if err := httpsServer.Shutdown(shutdownCtx); err != nil {
logger.Error("Error shutting down HTTPS server: %v", err)
}
}
return nil
} }

View File

@@ -39,8 +39,9 @@ func TestStart(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Error initializing config: %s", err.Error()) t.Fatalf("Error initializing config: %s", err.Error())
} }
ctx := context.Background()
g := GatewayServer{} g := GatewayServer{}
gatewayServer, err := g.Config(configFile) gatewayServer, err := g.Config(configFile, ctx)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@@ -54,9 +55,8 @@ func TestStart(t *testing.T) {
t.Fatalf("expected a status code of 200, got %v", resp.StatusCode) t.Fatalf("expected a status code of 200, got %v", resp.StatusCode)
} }
} }
ctx := context.Background()
go func() { go func() {
err = gatewayServer.Start(ctx) err = gatewayServer.Start()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return

36
internal/tls.go Normal file
View File

@@ -0,0 +1,36 @@
/*
* Copyright 2024 Jonas Kaninda
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package pkg
import (
"crypto/tls"
"fmt"
)
func (gatewayServer GatewayServer) initTLS() (*tls.Config, bool, error) {
cert, key := gatewayServer.gateway.SSLCertFile, gatewayServer.gateway.SSLKeyFile
if cert == "" || key == "" {
return nil, false, nil
}
tlsConfig, err := loadTLS(cert, key)
if err != nil {
return nil, false, fmt.Errorf("failed to load TLS config: %w", err)
}
return tlsConfig, true, nil
}

View File

@@ -113,6 +113,7 @@ type ErrorResponse struct {
} }
type GatewayServer struct { type GatewayServer struct {
ctx context.Context ctx context.Context
configFile string
version string version string
gateway Gateway gateway Gateway
middlewares []Middleware middlewares []Middleware

View File

@@ -4,10 +4,10 @@ const ConfigDir = "/etc/goma/" // Default config
const ConfigFile = "/etc/goma/goma.yml" // Default configuration file const ConfigFile = "/etc/goma/goma.yml" // Default configuration file
const accessControlAllowOrigin = "Access-Control-Allow-Origin" // Cors const accessControlAllowOrigin = "Access-Control-Allow-Origin" // Cors
const gatewayName = "Goma Gateway" const gatewayName = "Goma Gateway"
const AccessMiddleware = "access" // access middleware const AccessMiddleware = "access" // access middlewares
const BasicAuth = "basic" // basic authentication middleware const BasicAuth = "basic" // basic authentication middlewares
const JWTAuth = "jwt" // JWT authentication middleware const JWTAuth = "jwt" // JWT authentication middlewares
const OAuth = "oauth" // OAuth authentication middleware const OAuth = "oauth" // OAuth authentication middlewares
// Round-robin counter // Round-robin counter
var counter uint32 var counter uint32

View File

@@ -15,7 +15,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
import "github.com/jkaninda/goma-gateway/cmd" import (
"github.com/jkaninda/goma-gateway/cmd"
)
func main() { func main() {