Merge pull request #111 from jkaninda/refactor

refactor: refactoring of code to meet all golangci-lint requirements
This commit is contained in:
2024-11-17 05:30:27 +01:00
committed by GitHub
25 changed files with 290 additions and 298 deletions

44
.golangci.yml Normal file
View File

@@ -0,0 +1,44 @@
run:
timeout: 5m
allow-parallel-runners: true
issues:
# don't skip warning about doc comments
# don't exclude the default set of lint
exclude-use-default: false
# restore some of the defaults
# (fill in the rest as needed)
exclude-rules:
- path: "internal/*"
linters:
- dupl
- lll
linters:
disable-all: true
enable:
- dupl
- errcheck
- copyloopvar
- ginkgolinter
- goconst
- gocyclo
- gofmt
#- goimports
- gosimple
- govet
- ineffassign
- lll
- misspell
- nakedret
- prealloc
- revive
- staticcheck
- typecheck
- unconvert
- unparam
- unused
linters-settings:
revive:
rules:
- name: comment-spacings

View File

@@ -19,9 +19,10 @@ package config
import ( import (
"fmt" "fmt"
"os"
pkg "github.com/jkaninda/goma-gateway/internal" pkg "github.com/jkaninda/goma-gateway/internal"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"os"
) )
var CheckConfigCmd = &cobra.Command{ var CheckConfigCmd = &cobra.Command{

View File

@@ -18,8 +18,9 @@ package config
import ( import (
"fmt" "fmt"
"github.com/spf13/cobra"
"os" "os"
"github.com/spf13/cobra"
) )
var Cmd = &cobra.Command{ var Cmd = &cobra.Command{

View File

@@ -17,9 +17,10 @@ limitations under the License.
*/ */
import ( import (
"fmt" "fmt"
"github.com/jkaninda/goma-gateway/internal"
"github.com/spf13/cobra"
"os" "os"
pkg "github.com/jkaninda/goma-gateway/internal"
"github.com/spf13/cobra"
) )
var InitConfigCmd = &cobra.Command{ var InitConfigCmd = &cobra.Command{

View File

@@ -52,7 +52,7 @@ func CheckConfig(fileName string) error {
} }
} }
//Check middlewares // Check middlewares
for index, mid := range c.Middlewares { for index, mid := range c.Middlewares {
if util.HasWhitespace(mid.Name) { if util.HasWhitespace(mid.Name) {
fmt.Printf("Warning: Middleware contains whitespace: %s | index: [%d], please remove whitespace characters\n", mid.Name, index) fmt.Printf("Warning: Middleware contains whitespace: %s | index: [%d], please remove whitespace characters\n", mid.Name, index)

View File

@@ -76,7 +76,7 @@ func (GatewayServer) Config(configFile string, ctx context.Context) (*GatewaySer
} }
logger.Info("Generating new configuration file...") logger.Info("Generating new configuration file...")
//check if config directory does exist // check if config directory does exist
if !util.FolderExists(ConfigDir) { if !util.FolderExists(ConfigDir) {
err := os.MkdirAll(ConfigDir, os.ModePerm) err := os.MkdirAll(ConfigDir, os.ModePerm)
if err != nil { if err != nil {
@@ -145,14 +145,14 @@ func initConfig(configFile string) error {
Cors: Cors{ Cors: Cors{
Origins: []string{"http://localhost:8080", "https://example.com"}, Origins: []string{"http://localhost:8080", "https://example.com"},
Headers: map[string]string{ Headers: map[string]string{
"Access-Control-Allow-Headers": "Origin, Authorization, Accept, Content-Type, Access-Control-Allow-Headers, X-Client-Id, X-Session-Id", "Access-Control-Allow-Headers": "Origin, Authorization, Accept, Content-Type, Access-Control-Allow-Headers",
"Access-Control-Allow-Credentials": "true", "Access-Control-Allow-Credentials": "true",
"Access-Control-Max-Age": "1728000", "Access-Control-Max-Age": "1728000",
}, },
}, },
Routes: []Route{ Routes: []Route{
{ {
Name: "Public", Name: "Example",
Path: "/", Path: "/",
Methods: []string{"GET"}, Methods: []string{"GET"},
Destination: "https://example.com", Destination: "https://example.com",
@@ -163,12 +163,17 @@ func initConfig(configFile string) error {
Timeout: "10s", Timeout: "10s",
HealthyStatuses: []int{200, 404}, HealthyStatuses: []int{200, 404},
}, },
Middlewares: []string{"api-forbidden-paths"}, DisableHostFording: true,
Middlewares: []string{"block-access"},
}, },
{ {
Name: "Basic auth", Name: "Load balancer",
Path: "/protected", Path: "/protected",
Destination: "https://example.com", Backends: []string{
"https://example.com",
"https://example2.com",
"https://example3.com",
},
Rewrite: "/", Rewrite: "/",
HealthCheck: RouteHealthCheck{}, HealthCheck: RouteHealthCheck{},
Cors: Cors{ Cors: Cors{
@@ -179,38 +184,7 @@ func initConfig(configFile string) error {
"Access-Control-Max-Age": "1728000", "Access-Control-Max-Age": "1728000",
}, },
}, },
Middlewares: []string{"basic-auth", "api-forbidden-paths"}, Middlewares: []string{"basic-auth", "block-access"},
},
{
Path: "/",
Name: "Hostname and load balancing example",
Hosts: []string{"example.com", "example.localhost"},
InterceptErrors: []int{404, 405, 500},
RateLimit: 60,
Backends: []string{
"https://example.com",
"https://example2.com",
"https://example4.com",
},
Rewrite: "/",
HealthCheck: RouteHealthCheck{},
},
{
Path: "/",
Name: "loadBalancing example",
Hosts: []string{"example.com", "example.localhost"},
Backends: []string{
"https://example.com",
"https://example2.com",
"https://example4.com",
},
Rewrite: "/",
HealthCheck: RouteHealthCheck{
Path: "/health/live",
HealthyStatuses: []int{200, 404},
Interval: "30s",
Timeout: "10s",
},
}, },
}, },
}, },
@@ -225,24 +199,9 @@ func initConfig(configFile string) error {
Username: "admin", Username: "admin",
Password: "admin", Password: "admin",
}, },
}, {
Name: "jwt",
Type: JWTAuth,
Paths: []string{
"/protected-access",
"/example-of-jwt",
},
Rule: JWTRuleMiddleware{
URL: "https://example.com/auth/userinfo",
RequiredHeaders: []string{
"Authorization",
},
Headers: map[string]string{},
Params: map[string]string{},
},
}, },
{ {
Name: "api-forbidden-paths", Name: "block-access",
Type: AccessMiddleware, Type: AccessMiddleware,
Paths: []string{ Paths: []string{
"/swagger-ui/*", "/swagger-ui/*",
@@ -251,46 +210,6 @@ func initConfig(configFile string) error {
"/actuator/*", "/actuator/*",
}, },
}, },
{
Name: "oauth-google",
Type: OAuth,
Paths: []string{
"/protected",
"/example-of-oauth",
},
Rule: OauthRulerMiddleware{
ClientID: "xxx",
ClientSecret: "xxx",
Provider: "google",
JWTSecret: "your-strong-jwt-secret | It's optional",
RedirectURL: "http://localhost:8080/callback",
Scopes: []string{"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile"},
Endpoint: OauthEndpoint{},
State: "randomStateString",
},
},
{
Name: "oauth-authentik",
Type: OAuth,
Paths: []string{
"/*",
},
Rule: OauthRulerMiddleware{
ClientID: "xxxx",
ClientSecret: "xxxx",
RedirectURL: "http://localhost:8080/callback",
Provider: "custom",
Scopes: []string{"email", "openid"},
JWTSecret: "your-strong-jwt-secret | It's optional",
Endpoint: OauthEndpoint{
AuthURL: "https://authentik.example.com/application/o/authorize/",
TokenURL: "https://authentik.example.com/application/o/token/",
UserInfoURL: "https://authentik.example.com/application/o/userinfo/",
},
State: "randomStateString",
},
},
}, },
} }
yamlData, err := yaml.Marshal(&conf) yamlData, err := yaml.Marshal(&conf)

View File

@@ -19,16 +19,16 @@ package pkg
type Cors struct { type Cors struct {
// Cors Allowed origins, // Cors Allowed origins,
//e.g: // e.g:
// //
// - http://localhost:80 // - http://localhost:80
// //
// - https://example.com // - https://example.com
Origins []string `yaml:"origins"` Origins []string `yaml:"origins"`
// //
//e.g: // e.g:
// //
//Access-Control-Allow-Origin: '*' // Access-Control-Allow-Origin: '*'
// //
// Access-Control-Allow-Methods: 'GET, POST, PUT, DELETE, OPTIONS' // Access-Control-Allow-Methods: 'GET, POST, PUT, DELETE, OPTIONS'
// //

View File

@@ -42,7 +42,7 @@ type Gateway struct {
DisableHealthCheckStatus bool `yaml:"disableHealthCheckStatus"` DisableHealthCheckStatus bool `yaml:"disableHealthCheckStatus"`
// DisableRouteHealthCheckError allows enabling and disabling backend healthcheck errors // DisableRouteHealthCheckError allows enabling and disabling backend healthcheck errors
DisableRouteHealthCheckError bool `yaml:"disableRouteHealthCheckError"` DisableRouteHealthCheckError bool `yaml:"disableRouteHealthCheckError"`
//Disable allows enabling and disabling displaying routes on start // Disable allows enabling and disabling displaying routes on start
DisableDisplayRouteOnStart bool `yaml:"disableDisplayRouteOnStart"` DisableDisplayRouteOnStart bool `yaml:"disableDisplayRouteOnStart"`
// DisableKeepAlive allows enabling and disabling KeepALive server // DisableKeepAlive allows enabling and disabling KeepALive server
DisableKeepAlive bool `yaml:"disableKeepAlive"` DisableKeepAlive bool `yaml:"disableKeepAlive"`

View File

@@ -32,11 +32,11 @@ func CORSHandler(cors Cors) mux.MiddlewareFunc {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Set CORS headers from the cors config // Set CORS headers from the cors config
//Update Cors Headers // Update Cors Headers
for k, v := range cors.Headers { for k, v := range cors.Headers {
w.Header().Set(k, v) w.Header().Set(k, v)
} }
//Update Origin Cors Headers // Update Origin Cors Headers
if allowedOrigin(cors.Origins, r.Header.Get("Origin")) { if allowedOrigin(cors.Origins, r.Header.Get("Origin")) {
// Handle preflight requests (OPTIONS) // Handle preflight requests (OPTIONS)
if r.Method == "OPTIONS" { if r.Method == "OPTIONS" {
@@ -90,7 +90,7 @@ func (heathRoute HealthCheckRoute) HealthCheckHandler(w http.ResponseWriter, r *
} }
wg.Wait() // Wait for all requests to complete wg.Wait() // Wait for all requests to complete
response := HealthCheckResponse{ response := HealthCheckResponse{
Status: "healthy", //Goma proxy Status: "healthy", // Goma proxy
Routes: routes, // Routes health check Routes: routes, // Routes health check
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")

View File

@@ -98,7 +98,7 @@ func (health Health) createHealthCheckJob() error {
_, err := c.AddFunc(expression, func() { _, err := c.AddFunc(expression, func() {
err := health.Check() err := health.Check()
if err != nil { if err != nil {
logger.Error("Route %s is unhealthy: error %v", health.Name, err.Error()) logger.Error("Route %s is unhealthy: %v", health.Name, err.Error())
return return
} }
logger.Info("Route %s is healthy", health.Name) logger.Info("Route %s is healthy", health.Name)

View File

@@ -19,6 +19,7 @@ import (
"github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/pkg/logger"
"github.com/jkaninda/goma-gateway/util" "github.com/jkaninda/goma-gateway/util"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"io"
"net/http" "net/http"
"time" "time"
) )
@@ -72,7 +73,12 @@ func (oauth *OauthRulerMiddleware) getUserInfo(token *oauth2.Token) (UserInfo, e
if err != nil { if err != nil {
return UserInfo{}, err return UserInfo{}, err
} }
defer resp.Body.Close() defer func(Body io.ReadCloser) {
err := Body.Close()
if err != nil {
return
}
}(resp.Body)
// Parse the user info // Parse the user info
var userInfo UserInfo var userInfo UserInfo

View File

@@ -15,22 +15,23 @@
* *
*/ */
package pkg package metrics
import ( import (
"net/http"
"strconv"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promauto"
"net/http"
"strconv"
) )
type PrometheusRoute struct { type PrometheusRoute struct {
name string Name string
path string Path string
} }
var totalRequests = prometheus.NewCounterVec( var TotalRequests = prometheus.NewCounterVec(
prometheus.CounterOpts{ prometheus.CounterOpts{
Name: "http_requests_total", Name: "http_requests_total",
Help: "Number of get requests.", Help: "Number of get requests.",
@@ -38,7 +39,7 @@ var totalRequests = prometheus.NewCounterVec(
[]string{"name", "path"}, []string{"name", "path"},
) )
var responseStatus = prometheus.NewCounterVec( var ResponseStatus = prometheus.NewCounterVec(
prometheus.CounterOpts{ prometheus.CounterOpts{
Name: "response_status", Name: "response_status",
Help: "Status of HTTP response", Help: "Status of HTTP response",
@@ -46,22 +47,22 @@ var responseStatus = prometheus.NewCounterVec(
[]string{"status"}, []string{"status"},
) )
var httpDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{ var HttpDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{
Name: "http_response_time_seconds", Name: "http_response_time_seconds",
Help: "Duration of HTTP requests.", Help: "Duration of HTTP requests.",
}, []string{"name", "path"}) }, []string{"name", "path"})
func (pr PrometheusRoute) prometheusMiddleware(next http.Handler) http.Handler { func (pr PrometheusRoute) PrometheusMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
path := pr.path path := pr.Path
if len(path) == 0 { if len(path) == 0 {
route := mux.CurrentRoute(r) route := mux.CurrentRoute(r)
path, _ = route.GetPathTemplate() path, _ = route.GetPathTemplate()
} }
timer := prometheus.NewTimer(httpDuration.WithLabelValues(pr.name, path)) timer := prometheus.NewTimer(HttpDuration.WithLabelValues(pr.Name, path))
responseStatus.WithLabelValues(strconv.Itoa(http.StatusOK)).Inc() ResponseStatus.WithLabelValues(strconv.Itoa(http.StatusOK)).Inc()
totalRequests.WithLabelValues(pr.name, path).Inc() TotalRequests.WithLabelValues(pr.Name, path).Inc()
timer.ObserveDuration() timer.ObserveDuration()
next.ServeHTTP(w, r) next.ServeHTTP(w, r)

View File

@@ -19,7 +19,7 @@ package pkg
// Middleware defined the route middlewares // Middleware defined the route middlewares
type Middleware struct { type Middleware struct {
//Path contains the name of middlewares and must be unique // Path contains the name of middlewares and must be unique
Name string `yaml:"name"` Name string `yaml:"name"`
// Type contains authentication types // Type contains authentication types
// //

View File

@@ -33,7 +33,7 @@ func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler {
if r.Header.Get(header) == "" { if r.Header.Get(header) == "" {
logger.Error("Proxy error, missing %s header", header) logger.Error("Proxy error, missing %s header", header)
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
//check allowed origin // check allowed origin
if allowedOrigin(jwtAuth.Origins, r.Header.Get("Origin")) { if allowedOrigin(jwtAuth.Origins, r.Header.Get("Origin")) {
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
} }
@@ -42,7 +42,6 @@ func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler {
} }
} }
//token := r.Header.Get("Authorization")
authURL, err := url.Parse(jwtAuth.AuthURL) authURL, err := url.Parse(jwtAuth.AuthURL)
if err != nil { if err != nil {
logger.Error("Error parsing auth URL: %v", err) logger.Error("Error parsing auth URL: %v", err)

View File

@@ -29,8 +29,6 @@ func (rl *TokenRateLimiter) RateLimitMiddleware() mux.MiddlewareFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !rl.Allow() { if !rl.Allow() {
logger.Error("Too many requests from IP: %s %s %s", getRealIP(r), r.URL, r.UserAgent()) logger.Error("Too many requests from IP: %s %s %s", getRealIP(r), r.URL, r.UserAgent())
//RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), basicAuth.ErrorInterceptor)
// Rate limit exceeded, return a 429 Too Many Requests response // Rate limit exceeded, return a 429 Too Many Requests response
w.WriteHeader(http.StatusTooManyRequests) w.WriteHeader(http.StatusTooManyRequests)
_, err := w.Write([]byte(fmt.Sprintf("%d Too many requests, API requests limit exceeded. Please try again later", http.StatusTooManyRequests))) _, err := w.Write([]byte(fmt.Sprintf("%d Too many requests, API requests limit exceeded. Please try again later", http.StatusTooManyRequests)))
@@ -75,7 +73,7 @@ func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc {
if client.RequestCount > rl.requests { if client.RequestCount > rl.requests {
logger.Error("Too many requests from IP: %s %s %s", clientIP, r.URL, r.UserAgent()) logger.Error("Too many requests from IP: %s %s %s", clientIP, r.URL, r.UserAgent())
//Update Origin Cors Headers // Update Origin Cors Headers
if allowedOrigin(rl.origins, r.Header.Get("Origin")) { if allowedOrigin(rl.origins, r.Header.Get("Origin")) {
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
} }

View File

@@ -26,13 +26,12 @@ import (
// RateLimiter defines requests limit properties. // RateLimiter defines requests limit properties.
type RateLimiter struct { type RateLimiter struct {
requests int requests int
id string id string
window time.Duration window time.Duration
clientMap map[string]*Client clientMap map[string]*Client
mu sync.Mutex mu sync.Mutex
origins []string origins []string
//hosts []string
redisBased bool redisBased bool
} }

View File

@@ -20,6 +20,7 @@ import (
"fmt" "fmt"
"github.com/jkaninda/goma-gateway/internal/middlewares" "github.com/jkaninda/goma-gateway/internal/middlewares"
"github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/pkg/logger"
"github.com/jkaninda/goma-gateway/util"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"net/url" "net/url"
@@ -43,7 +44,7 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc {
} }
} }
// Set CORS headers from the cors config // Set CORS headers from the cors config
//Update Cors Headers // Update Cors Headers
for k, v := range proxyRoute.cors.Headers { for k, v := range proxyRoute.cors.Headers {
w.Header().Set(k, v) w.Header().Set(k, v)
} }
@@ -80,18 +81,13 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc {
// Create proxy // Create proxy
proxy := httputil.NewSingleHostReverseProxy(backendURL) proxy := httputil.NewSingleHostReverseProxy(backendURL)
// Rewrite // Rewrite
if proxyRoute.path != "" && proxyRoute.rewrite != "" { rewritePath(r, proxyRoute)
// Rewrite the path
if strings.HasPrefix(r.URL.Path, fmt.Sprintf("%s/", proxyRoute.path)) {
r.URL.Path = strings.Replace(r.URL.Path, fmt.Sprintf("%s/", proxyRoute.path), proxyRoute.rewrite, 1)
}
}
// Custom transport with InsecureSkipVerify // Custom transport with InsecureSkipVerify
proxy.Transport = &http.Transport{TLSClientConfig: &tls.Config{ proxy.Transport = &http.Transport{TLSClientConfig: &tls.Config{
InsecureSkipVerify: proxyRoute.insecureSkipVerify, InsecureSkipVerify: proxyRoute.insecureSkipVerify,
}, },
} }
w.Header().Set("Proxied-By", gatewayName) //Set Server name w.Header().Set("Proxied-By", gatewayName) // Set Server name
w.Header().Del("Server") // Remove the Server header w.Header().Del("Server") // Remove the Server header
// Custom error handler for proxy errors // Custom error handler for proxy errors
proxy.ErrorHandler = ProxyErrorHandler proxy.ErrorHandler = ProxyErrorHandler
@@ -105,3 +101,13 @@ func getNextBackend(backendURLs []string) *url.URL {
backendURL, _ := url.Parse(backendURLs[idx]) backendURL, _ := url.Parse(backendURLs[idx])
return backendURL return backendURL
} }
// rewritePath rewrites the path if it matches the prefix
func rewritePath(r *http.Request, proxyRoute ProxyRoute) {
if proxyRoute.path != "" && proxyRoute.rewrite != "" {
// Rewrite the path if it matches the prefix
if strings.HasPrefix(r.URL.Path, fmt.Sprintf("%s/", proxyRoute.path)) {
r.URL.Path = util.ParseURLPath(strings.Replace(r.URL.Path, fmt.Sprintf("%s/", proxyRoute.path), proxyRoute.rewrite, 1))
}
}
}

View File

@@ -22,13 +22,12 @@ import (
"github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/pkg/logger"
) )
func (gatewayServer GatewayServer) initRedis() error { func (gatewayServer GatewayServer) initRedis() {
if gatewayServer.gateway.Redis.Addr == "" { if len(gatewayServer.gateway.Redis.Addr) != 0 {
return nil logger.Info("Initializing Redis...")
middlewares.InitRedis(gatewayServer.gateway.Redis.Addr, gatewayServer.gateway.Redis.Password)
} }
logger.Info("Initializing Redis...")
middlewares.InitRedis(gatewayServer.gateway.Redis.Addr, gatewayServer.gateway.Redis.Password)
return nil
} }
func (gatewayServer GatewayServer) closeRedis() { func (gatewayServer GatewayServer) closeRedis() {

View File

@@ -17,6 +17,7 @@ limitations under the License.
*/ */
import ( import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/jkaninda/goma-gateway/internal/metrics"
"github.com/jkaninda/goma-gateway/internal/middlewares" "github.com/jkaninda/goma-gateway/internal/middlewares"
"github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/pkg/logger"
"github.com/jkaninda/goma-gateway/util" "github.com/jkaninda/goma-gateway/util"
@@ -26,12 +27,12 @@ import (
) )
func init() { func init() {
_ = prometheus.Register(totalRequests) _ = prometheus.Register(metrics.TotalRequests)
_ = prometheus.Register(responseStatus) _ = prometheus.Register(metrics.ResponseStatus)
_ = prometheus.Register(httpDuration) _ = prometheus.Register(metrics.HttpDuration)
} }
// Initialize the routes // Initialize initializes the routes
func (gatewayServer GatewayServer) Initialize() *mux.Router { func (gatewayServer GatewayServer) Initialize() *mux.Router {
gateway := gatewayServer.gateway gateway := gatewayServer.gateway
m := gatewayServer.middlewares m := gatewayServer.middlewares
@@ -39,8 +40,9 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
if len(gateway.Redis.Addr) != 0 { if len(gateway.Redis.Addr) != 0 {
redisBased = true redisBased = true
} }
//Routes background healthcheck // Routes background healthcheck
routesHealthCheck(gateway.Routes) routesHealthCheck(gateway.Routes)
r := mux.NewRouter() r := mux.NewRouter()
heath := HealthCheckRoute{ heath := HealthCheckRoute{
DisableRouteHealthCheckError: gateway.DisableRouteHealthCheckError, DisableRouteHealthCheckError: gateway.DisableRouteHealthCheckError,
@@ -64,10 +66,10 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
logger.Info("Block common exploits enabled") logger.Info("Block common exploits enabled")
r.Use(middlewares.BlockExploitsMiddleware) r.Use(middlewares.BlockExploitsMiddleware)
} }
if gateway.RateLimit > 0 { if gateway.RateLimit != 0 {
// Add rate limit middlewares to all routes, if defined // Add rate limit middlewares to all routes, if defined
rateLimit := middlewares.RateLimit{ rateLimit := middlewares.RateLimit{
Id: "global_rate", //Generate a unique ID for routes Id: "global_rate", // Generate a unique ID for routes
Requests: gateway.RateLimit, Requests: gateway.RateLimit,
Window: time.Minute, // requests per minute Window: time.Minute, // requests per minute
Origins: gateway.Cors.Origins, Origins: gateway.Cors.Origins,
@@ -79,16 +81,17 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
r.Use(limiter.RateLimitMiddleware()) r.Use(limiter.RateLimitMiddleware())
} }
for rIndex, route := range gateway.Routes { for rIndex, route := range gateway.Routes {
if route.Path != "" { if len(route.Path) != 0 {
if route.Destination == "" && len(route.Backends) == 0 { // Checks if route destination and backend are empty
if len(route.Destination) == 0 && len(route.Backends) == 0 {
logger.Fatal("Route %s : destination or backends should not be empty", route.Name) logger.Fatal("Route %s : destination or backends should not be empty", route.Name)
} }
// Apply middlewares to route // Apply middlewares to the route
for _, mid := range route.Middlewares { for _, middleware := range route.Middlewares {
if mid != "" { if middleware != "" {
// Get Access middlewares if it does exist // Get Access middlewares if it does exist
accessMiddleware, err := getMiddleware([]string{mid}, m) accessMiddleware, err := getMiddleware([]string{middleware}, m)
if err != nil { if err != nil {
logger.Error("Error: %v", err.Error()) logger.Error("Error: %v", err.Error())
} else { } else {
@@ -104,114 +107,12 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
} }
// Get route authentication middlewares if it does exist // Get route authentication middlewares if it does exist
rMiddleware, err := getMiddleware([]string{mid}, m) routeMiddleware, err := getMiddleware([]string{middleware}, m)
if err != nil { if err != nil {
//Error: middlewares not found // Error: middlewares not found
logger.Error("Error: %v", err.Error()) logger.Error("Error: %v", err.Error())
} else { } else {
for _, midPath := range rMiddleware.Paths { attachAuthMiddlewares(route, routeMiddleware, gateway, r)
proxyRoute := ProxyRoute{
path: route.Path,
rewrite: route.Rewrite,
destination: route.Destination,
backends: route.Backends,
disableHostFording: route.DisableHostFording,
methods: route.Methods,
cors: route.Cors,
insecureSkipVerify: route.InsecureSkipVerify,
}
secureRouter := r.PathPrefix(util.ParseRoutePath(route.Path, midPath)).Subrouter()
//callBackRouter := r.PathPrefix(util.ParseRoutePath(route.Path, "/callback")).Subrouter()
//Check Authentication middlewares
switch rMiddleware.Type {
case BasicAuth:
basicAuth, err := getBasicAuthMiddleware(rMiddleware.Rule)
if err != nil {
logger.Error("Error: %s", err.Error())
} else {
amw := middlewares.AuthBasic{
Username: basicAuth.Username,
Password: basicAuth.Password,
Headers: nil,
Params: nil,
}
// Apply JWT authentication middlewares
secureRouter.Use(amw.AuthMiddleware)
secureRouter.Use(CORSHandler(route.Cors))
secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler
secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler
}
case JWTAuth:
jwt, err := getJWTMiddleware(rMiddleware.Rule)
if err != nil {
logger.Error("Error: %s", err.Error())
} else {
amw := middlewares.JwtAuth{
AuthURL: jwt.URL,
RequiredHeaders: jwt.RequiredHeaders,
Headers: jwt.Headers,
Params: jwt.Params,
Origins: gateway.Cors.Origins,
}
// Apply JWT authentication middlewares
secureRouter.Use(amw.AuthMiddleware)
secureRouter.Use(CORSHandler(route.Cors))
secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler
secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler
}
case OAuth, "openid":
oauth, err := oAuthMiddleware(rMiddleware.Rule)
if err != nil {
logger.Error("Error: %s", err.Error())
} else {
redirectURL := "/callback" + route.Path
if oauth.RedirectURL != "" {
redirectURL = oauth.RedirectURL
}
amw := middlewares.Oauth{
ClientID: oauth.ClientID,
ClientSecret: oauth.ClientSecret,
RedirectURL: redirectURL,
Scopes: oauth.Scopes,
Endpoint: middlewares.OauthEndpoint{
AuthURL: oauth.Endpoint.AuthURL,
TokenURL: oauth.Endpoint.TokenURL,
UserInfoURL: oauth.Endpoint.UserInfoURL,
},
State: oauth.State,
Origins: gateway.Cors.Origins,
JWTSecret: oauth.JWTSecret,
Provider: oauth.Provider,
}
oauthRuler := oauthRulerMiddleware(amw)
// Check if a cookie path is defined
if oauthRuler.CookiePath == "" {
oauthRuler.CookiePath = route.Path
}
// Check if a RedirectPath is defined
if oauthRuler.RedirectPath == "" {
oauthRuler.RedirectPath = util.ParseRoutePath(route.Path, midPath)
}
if oauthRuler.Provider == "" {
oauthRuler.Provider = "custom"
}
secureRouter.Use(amw.AuthMiddleware)
secureRouter.Use(CORSHandler(route.Cors))
secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler
secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler
// Callback route
r.HandleFunc(util.UrlParsePath(redirectURL), oauthRuler.callbackHandler).Methods("GET")
}
default:
if !doesExist(rMiddleware.Type) {
logger.Error("Unknown middlewares type %s", rMiddleware.Type)
}
}
}
} }
} else { } else {
logger.Error("Error, middlewares path is empty") logger.Error("Error, middlewares path is empty")
@@ -242,7 +143,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
id = util.Slug(route.Name) id = util.Slug(route.Name)
} }
// Apply route rate limit // Apply route rate limit
if route.RateLimit > 0 { if route.RateLimit != 0 {
rateLimit := middlewares.RateLimit{ rateLimit := middlewares.RateLimit{
Id: id, // Use route index as ID Id: id, // Use route index as ID
Requests: route.RateLimit, Requests: route.RateLimit,
@@ -265,12 +166,12 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
router.PathPrefix("").Handler(proxyRoute.ProxyHandler()) router.PathPrefix("").Handler(proxyRoute.ProxyHandler())
} }
if gateway.EnableMetrics { if gateway.EnableMetrics {
pr := PrometheusRoute{ pr := metrics.PrometheusRoute{
name: route.Name, Name: route.Name,
path: route.Path, Path: route.Path,
} }
// Prometheus endpoint // Prometheus endpoint
router.Use(pr.prometheusMiddleware) router.Use(pr.PrometheusMiddleware)
} }
// Apply route Error interceptor middlewares // Apply route Error interceptor middlewares
if len(route.InterceptErrors) != 0 { if len(route.InterceptErrors) != 0 {
@@ -299,3 +200,107 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
return r return r
} }
func attachAuthMiddlewares(route Route, routeMiddleware Middleware, gateway Gateway, r *mux.Router) {
for _, middlewarePath := range routeMiddleware.Paths {
proxyRoute := ProxyRoute{
path: route.Path,
rewrite: route.Rewrite,
destination: route.Destination,
backends: route.Backends,
disableHostFording: route.DisableHostFording,
methods: route.Methods,
cors: route.Cors,
insecureSkipVerify: route.InsecureSkipVerify,
}
secureRouter := r.PathPrefix(util.ParseRoutePath(route.Path, middlewarePath)).Subrouter()
// Check Authentication middleware types
switch routeMiddleware.Type {
case BasicAuth:
basicAuth, err := getBasicAuthMiddleware(routeMiddleware.Rule)
if err != nil {
logger.Error("Error: %s", err.Error())
} else {
authBasic := middlewares.AuthBasic{
Username: basicAuth.Username,
Password: basicAuth.Password,
Headers: nil,
Params: nil,
}
// Apply JWT authentication middlewares
secureRouter.Use(authBasic.AuthMiddleware)
secureRouter.Use(CORSHandler(route.Cors))
secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler
secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler
}
case JWTAuth:
jwt, err := getJWTMiddleware(routeMiddleware.Rule)
if err != nil {
logger.Error("Error: %s", err.Error())
} else {
jwtAuth := middlewares.JwtAuth{
AuthURL: jwt.URL,
RequiredHeaders: jwt.RequiredHeaders,
Headers: jwt.Headers,
Params: jwt.Params,
Origins: gateway.Cors.Origins,
}
// Apply JWT authentication middlewares
secureRouter.Use(jwtAuth.AuthMiddleware)
secureRouter.Use(CORSHandler(route.Cors))
secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler
secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler
}
case OAuth:
oauth, err := oAuthMiddleware(routeMiddleware.Rule)
if err != nil {
logger.Error("Error: %s", err.Error())
} else {
redirectURL := "/callback" + route.Path
if oauth.RedirectURL != "" {
redirectURL = oauth.RedirectURL
}
amw := middlewares.Oauth{
ClientID: oauth.ClientID,
ClientSecret: oauth.ClientSecret,
RedirectURL: redirectURL,
Scopes: oauth.Scopes,
Endpoint: middlewares.OauthEndpoint{
AuthURL: oauth.Endpoint.AuthURL,
TokenURL: oauth.Endpoint.TokenURL,
UserInfoURL: oauth.Endpoint.UserInfoURL,
},
State: oauth.State,
Origins: gateway.Cors.Origins,
JWTSecret: oauth.JWTSecret,
Provider: oauth.Provider,
}
oauthRuler := oauthRulerMiddleware(amw)
// Check if a cookie path is defined
if oauthRuler.CookiePath == "" {
oauthRuler.CookiePath = route.Path
}
// Check if a RedirectPath is defined
if oauthRuler.RedirectPath == "" {
oauthRuler.RedirectPath = util.ParseRoutePath(route.Path, middlewarePath)
}
if oauthRuler.Provider == "" {
oauthRuler.Provider = "custom"
}
secureRouter.Use(amw.AuthMiddleware)
secureRouter.Use(CORSHandler(route.Cors))
secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler
secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler
// Callback route
r.HandleFunc(util.UrlParsePath(redirectURL), oauthRuler.callbackHandler).Methods("GET")
}
default:
if !doesExist(routeMiddleware.Type) {
logger.Error("Unknown middlewares type %s", routeMiddleware.Type)
}
}
}
}

View File

@@ -23,9 +23,9 @@ type Route struct {
Path string `yaml:"path"` Path string `yaml:"path"`
// Name defines route name // Name defines route name
Name string `yaml:"name"` Name string `yaml:"name"`
//Host Domain/host based request routing // Host Domain/host based request routing
//Host string `yaml:"host"` // Host string `yaml:"host"`
//Hosts Domains/hosts based request routing // Hosts Domains/hosts based request routing
Hosts []string `yaml:"hosts"` Hosts []string `yaml:"hosts"`
// Rewrite rewrites route path to desired path // Rewrite rewrites route path to desired path
// //

View File

@@ -19,7 +19,6 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt"
"github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/pkg/logger"
"net/http" "net/http"
"os" "os"
@@ -33,9 +32,7 @@ func (gatewayServer GatewayServer) Start() error {
logger.Info("Initializing routes...") logger.Info("Initializing routes...")
route := gatewayServer.Initialize() route := gatewayServer.Initialize()
logger.Debug("Routes count=%d, Middlewares count=%d", len(gatewayServer.gateway.Routes), len(gatewayServer.middlewares)) logger.Debug("Routes count=%d, Middlewares count=%d", len(gatewayServer.gateway.Routes), len(gatewayServer.middlewares))
if err := gatewayServer.initRedis(); err != nil { gatewayServer.initRedis()
return fmt.Errorf("failed to initialize Redis: %w", err)
}
defer gatewayServer.closeRedis() defer gatewayServer.closeRedis()
tlsConfig, listenWithTLS, err := gatewayServer.initTLS() tlsConfig, listenWithTLS, err := gatewayServer.initTLS()
@@ -51,9 +48,7 @@ func (gatewayServer GatewayServer) Start() error {
httpsServer := gatewayServer.createServer(":8443", route, tlsConfig) httpsServer := gatewayServer.createServer(":8443", route, tlsConfig)
// Start HTTP/HTTPS servers // Start HTTP/HTTPS servers
if err := gatewayServer.startServers(httpServer, httpsServer, listenWithTLS); err != nil { gatewayServer.startServers(httpServer, httpsServer, listenWithTLS)
return err
}
// Handle graceful shutdown // Handle graceful shutdown
return gatewayServer.shutdown(httpServer, httpsServer, listenWithTLS) return gatewayServer.shutdown(httpServer, httpsServer, listenWithTLS)
@@ -70,7 +65,7 @@ func (gatewayServer GatewayServer) createServer(addr string, handler http.Handle
} }
} }
func (gatewayServer GatewayServer) startServers(httpServer, httpsServer *http.Server, listenWithTLS bool) error { func (gatewayServer GatewayServer) startServers(httpServer, httpsServer *http.Server, listenWithTLS bool) {
go func() { go func() {
logger.Info("Starting HTTP server on 0.0.0.0:8080") logger.Info("Starting HTTP server on 0.0.0.0:8080")
if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
@@ -86,8 +81,6 @@ func (gatewayServer GatewayServer) startServers(httpServer, httpsServer *http.Se
} }
}() }()
} }
return nil
} }
func (gatewayServer GatewayServer) shutdown(httpServer, httpsServer *http.Server, listenWithTLS bool) error { func (gatewayServer GatewayServer) shutdown(httpServer, httpsServer *http.Server, listenWithTLS bool) error {

View File

@@ -47,7 +47,7 @@ type JWTRuleMiddleware struct {
// //
// In case you want to get headers from Authentication service and inject them to next request's params. // In case you want to get headers from Authentication service and inject them to next request's params.
// //
//e.g: Header X-Auth-UserId to query userId // e.g: Header X-Auth-UserId to query userId
Params map[string]string `yaml:"params"` Params map[string]string `yaml:"params"`
} }
type OauthRulerMiddleware struct { type OauthRulerMiddleware struct {
@@ -66,7 +66,7 @@ type OauthRulerMiddleware struct {
RedirectURL string `yaml:"redirectUrl"` RedirectURL string `yaml:"redirectUrl"`
// RedirectPath is the PATH to redirect users after authentication, e.g: /my-protected-path/dashboard // RedirectPath is the PATH to redirect users after authentication, e.g: /my-protected-path/dashboard
RedirectPath string `yaml:"redirectPath"` RedirectPath string `yaml:"redirectPath"`
//CookiePath e.g: /my-protected-path or / || by default is applied on a route path // CookiePath e.g: /my-protected-path or / || by default is applied on a route path
CookiePath string `yaml:"cookiePath"` CookiePath string `yaml:"cookiePath"`
// Scope specifies optional requested permissions. // Scope specifies optional requested permissions.
@@ -119,11 +119,10 @@ type GatewayServer struct {
middlewares []Middleware middlewares []Middleware
} }
type ProxyRoute struct { type ProxyRoute struct {
path string path string
rewrite string rewrite string
destination string destination string
backends []string backends []string
//healthCheck RouteHealthCheck
methods []string methods []string
cors Cors cors Cors
disableHostFording bool disableHostFording bool

View File

@@ -55,7 +55,7 @@ func Fatal(msg string, args ...interface{}) {
func Debug(msg string, args ...interface{}) { func Debug(msg string, args ...interface{}) {
log.SetOutput(getStd(util.GetStringEnv("GOMA_ACCESS_LOG", "/dev/stdout"))) log.SetOutput(getStd(util.GetStringEnv("GOMA_ACCESS_LOG", "/dev/stdout")))
logLevel := util.GetStringEnv("GOMA_LOG_LEVEL", "") logLevel := util.GetStringEnv("GOMA_LOG_LEVEL", "")
if strings.ToLower(logLevel) == "trace" || strings.ToLower(logLevel) == "debug" { if strings.ToLower(logLevel) == traceLog || strings.ToLower(logLevel) == "debug" {
logWithCaller("DEBUG", msg, args...) logWithCaller("DEBUG", msg, args...)
} }
@@ -63,7 +63,7 @@ func Debug(msg string, args ...interface{}) {
func Trace(msg string, args ...interface{}) { func Trace(msg string, args ...interface{}) {
log.SetOutput(getStd(util.GetStringEnv("GOMA_ACCESS_LOG", "/dev/stdout"))) log.SetOutput(getStd(util.GetStringEnv("GOMA_ACCESS_LOG", "/dev/stdout")))
logLevel := util.GetStringEnv("GOMA_LOG_LEVEL", "") logLevel := util.GetStringEnv("GOMA_LOG_LEVEL", "")
if strings.ToLower(logLevel) == "trace" { if strings.ToLower(logLevel) == traceLog {
logWithCaller("DEBUG", msg, args...) logWithCaller("DEBUG", msg, args...)
} }
@@ -86,7 +86,7 @@ func logWithCaller(level, msg string, args ...interface{}) {
// Log message with caller information if GOMA_LOG_LEVEL is trace // Log message with caller information if GOMA_LOG_LEVEL is trace
logLevel := util.GetStringEnv("GOMA_LOG_LEVEL", "") logLevel := util.GetStringEnv("GOMA_LOG_LEVEL", "")
if strings.ToLower(logLevel) != "off" { if strings.ToLower(logLevel) != "off" {
if strings.ToLower(logLevel) == "trace" { if strings.ToLower(logLevel) == traceLog {
log.Printf("%s: %s (File: %s, Line: %d)\n", level, formattedMessage, file, line) log.Printf("%s: %s (File: %s, Line: %d)\n", level, formattedMessage, file, line)
} else { } else {
log.Printf("%s: %s\n", level, formattedMessage) log.Printf("%s: %s\n", level, formattedMessage)

20
pkg/logger/var.go Normal file
View File

@@ -0,0 +1,20 @@
/*
* Copyright 2024 Jonas Kaninda
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package logger
const traceLog = "trace"

View File

@@ -10,13 +10,14 @@ You may get a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
*/ */
import ( import (
"github.com/robfig/cron/v3"
"net/url" "net/url"
"os" "os"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/robfig/cron/v3"
) )
// FileExists checks if the file does exist // FileExists checks if the file does exist
@@ -87,7 +88,7 @@ func MergeSlices(slice1, slice2 []string) []string {
return append(slice1, slice2...) return append(slice1, slice2...)
} }
// ParseURLPath returns a URL path // ParseURLPath removes duplicated [//]
func ParseURLPath(urlPath string) string { func ParseURLPath(urlPath string) string {
// Replace any double slashes with a single slash // Replace any double slashes with a single slash
urlPath = strings.ReplaceAll(urlPath, "//", "/") urlPath = strings.ReplaceAll(urlPath, "//", "/")
@@ -148,7 +149,7 @@ func Slug(text string) string {
text = strings.ToLower(text) text = strings.ToLower(text)
// Replace spaces and special characters with hyphens // Replace spaces and special characters with hyphens
re := regexp.MustCompile(`[^\w]+`) re := regexp.MustCompile(`\W+`)
text = re.ReplaceAllString(text, "-") text = re.ReplaceAllString(text, "-")
// Remove leading and trailing hyphens // Remove leading and trailing hyphens