chore: refactroing of code

Commenting code for enhancing readability
This commit is contained in:
Jonas Kaninda
2024-11-19 18:18:58 +01:00
parent c54ae4bd34
commit 1c0097d8e4
15 changed files with 117 additions and 86 deletions

View File

@@ -71,6 +71,8 @@ func (health Health) Check() error {
}
return nil
}
// routesHealthCheck creates healthcheck job
func routesHealthCheck(routes []Route) {
for _, health := range healthCheckRoutes(routes) {
go func() {
@@ -84,11 +86,14 @@ func routesHealthCheck(routes []Route) {
}
}
// createHealthCheckJob create healthcheck job
func (health Health) createHealthCheckJob() error {
interval := "30s"
if len(health.Interval) > 0 {
interval = health.Interval
}
// create cron expression
expression := fmt.Sprintf("@every %s", interval)
if !util.IsValidCronExpression(expression) {
logger.Error("Health check interval is invalid: %s", interval)
@@ -113,3 +118,45 @@ func (health Health) createHealthCheckJob() error {
defer c.Stop()
select {}
}
// healthCheckRoutes creates and returns []Health
func healthCheckRoutes(routes []Route) []Health {
var healthRoutes []Health
for _, route := range routes {
if len(route.HealthCheck.Path) != 0 {
timeout, _ := util.ParseDuration("")
if len(route.HealthCheck.Timeout) > 0 {
d1, err1 := util.ParseDuration(route.HealthCheck.Timeout)
if err1 != nil {
logger.Error("Health check timeout is invalid: %s", route.HealthCheck.Timeout)
}
timeout = d1
}
if len(route.Backends) != 0 {
for index, backend := range route.Backends {
health := Health{
Name: fmt.Sprintf("%s - [%d]", route.Name, index),
URL: backend + route.HealthCheck.Path,
TimeOut: timeout,
HealthyStatuses: route.HealthCheck.HealthyStatuses,
InsecureSkipVerify: route.InsecureSkipVerify,
}
healthRoutes = append(healthRoutes, health)
}
} else {
health := Health{
Name: route.Name,
URL: route.Destination + route.HealthCheck.Path,
TimeOut: timeout,
HealthyStatuses: route.HealthCheck.HealthyStatuses,
InsecureSkipVerify: route.InsecureSkipVerify,
}
healthRoutes = append(healthRoutes, health)
}
} else {
logger.Debug("Route %s's healthCheck is undefined", route.Name)
}
}
return healthRoutes
}

View File

@@ -11,25 +11,20 @@ You may get a copy of the License at
*/
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"github.com/golang-jwt/jwt"
"github.com/jedib0t/go-pretty/v6/table"
"github.com/jkaninda/goma-gateway/pkg/logger"
"github.com/jkaninda/goma-gateway/util"
"golang.org/x/oauth2"
"io"
"net/http"
"time"
)
// printRoute prints routes
func printRoute(routes []Route) {
t := table.NewWriter()
t.AppendHeader(table.Row{"Name", "Route", "Rewrite", "Destination"})
t.AppendHeader(table.Row{"Name", "Path", "Rewrite", "Destination"})
for _, route := range routes {
if len(route.Backends) > 0 {
if len(route.Backends) != 0 {
t.AppendRow(table.Row{route.Name, route.Path, route.Rewrite, fmt.Sprintf("backends: [%d]", len(route.Backends))})
} else {
@@ -50,21 +45,6 @@ func getRealIP(r *http.Request) string {
return r.RemoteAddr
}
// loadTLS loads TLS Certificate
func loadTLS(cert, key string) (*tls.Config, error) {
if cert == "" && key == "" {
return nil, fmt.Errorf("no certificate or key file provided")
}
serverCert, err := tls.LoadX509KeyPair(cert, key)
if err != nil {
logger.Error("Error loading server certificate: %v", err)
return nil, err
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{serverCert},
}
return tlsConfig, nil
}
func (oauth *OauthRulerMiddleware) getUserInfo(token *oauth2.Token) (UserInfo, error) {
oauthConfig := oauth2Config(oauth)
// Call the user info endpoint with the token
@@ -88,64 +68,3 @@ func (oauth *OauthRulerMiddleware) getUserInfo(token *oauth2.Token) (UserInfo, e
return userInfo, nil
}
func createJWT(email, jwtSecret string) (string, error) {
// Define JWT claims
claims := jwt.MapClaims{
"email": email,
"exp": jwt.TimeFunc().Add(time.Hour * 24).Unix(), // Token expiration
"iss": "Goma-Gateway", // Issuer claim
}
// Create a new token with HS256 signing method
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
// Sign the token with a secret
signedToken, err := token.SignedString([]byte(jwtSecret))
if err != nil {
return "", err
}
return signedToken, nil
}
// healthCheckRoutes creates []Health
func healthCheckRoutes(routes []Route) []Health {
var healthRoutes []Health
for _, route := range routes {
if len(route.HealthCheck.Path) > 0 {
timeout, _ := util.ParseDuration("")
if len(route.HealthCheck.Timeout) > 0 {
d1, err1 := util.ParseDuration(route.HealthCheck.Timeout)
if err1 != nil {
logger.Error("Health check timeout is invalid: %s", route.HealthCheck.Timeout)
}
timeout = d1
}
if len(route.Backends) > 0 {
for index, backend := range route.Backends {
health := Health{
Name: fmt.Sprintf("%s - [%d]", route.Name, index),
URL: backend + route.HealthCheck.Path,
TimeOut: timeout,
HealthyStatuses: route.HealthCheck.HealthyStatuses,
InsecureSkipVerify: route.InsecureSkipVerify,
}
healthRoutes = append(healthRoutes, health)
}
} else {
health := Health{
Name: route.Name,
URL: route.Destination + route.HealthCheck.Path,
TimeOut: timeout,
HealthyStatuses: route.HealthCheck.HealthyStatuses,
InsecureSkipVerify: route.InsecureSkipVerify,
}
healthRoutes = append(healthRoutes, health)
}
} else {
logger.Debug("Route %s's healthCheck is undefined", route.Name)
}
}
return healthRoutes
}

44
internal/jwt.go Normal file
View File

@@ -0,0 +1,44 @@
/*
* Copyright 2024 Jonas Kaninda
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package pkg
import (
"github.com/golang-jwt/jwt"
"time"
)
// createJWT create JWT token
func createJWT(email, jwtSecret string) (string, error) {
// Define JWT claims
claims := jwt.MapClaims{
"email": email,
"exp": jwt.TimeFunc().Add(time.Hour * 24).Unix(), // Token expiration
"iss": "Goma-Gateway", // Issuer claim
}
// Create a new token with HS256 signing method
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
// Sign the token with a secret
signedToken, err := token.SignedString([]byte(jwtSecret))
if err != nil {
return "", err
}
return signedToken, nil
}

View File

@@ -52,6 +52,7 @@ var HttpDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{
Help: "Duration of HTTP requests.",
}, []string{"name", "path"})
// PrometheusMiddleware Prometheus http handler middleware, returns http.Handler
func (pr PrometheusRoute) PrometheusMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
path := pr.Path

View File

@@ -23,6 +23,7 @@ import (
"slices"
)
// getRealIP returns user real IP
func getRealIP(r *http.Request) string {
if ip := r.Header.Get("X-Real-IP"); ip != "" {
return ip

View File

@@ -24,6 +24,7 @@ import (
"github.com/redis/go-redis/v9"
)
// redisRateLimiter, handle rateLimit
func redisRateLimiter(clientIP string, rate int) error {
ctx := context.Background()

View File

@@ -74,7 +74,7 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc {
r.Host = targetURL.Host
}
backendURL, _ := url.Parse(proxyRoute.destination)
if len(proxyRoute.backends) > 0 {
if len(proxyRoute.backends) != 0 {
// Select the next backend URL
backendURL = getNextBackend(proxyRoute.backends)
}
@@ -87,8 +87,7 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc {
InsecureSkipVerify: proxyRoute.insecureSkipVerify,
},
}
w.Header().Set("Proxied-By", gatewayName) // Set Server name
w.Header().Del("Server") // Remove the Server header
w.Header().Set("Proxied-By", gatewayName)
// Custom error handler for proxy errors
proxy.ErrorHandler = ProxyErrorHandler
proxy.ServeHTTP(w, r)

View File

@@ -26,6 +26,7 @@ import (
"time"
)
// init initializes prometheus metrics
func init() {
_ = prometheus.Register(metrics.TotalRequests)
_ = prometheus.Register(metrics.ResponseStatus)
@@ -88,6 +89,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
logger.Info("Block common exploits enabled")
r.Use(middlewares.BlockExploitsMiddleware)
}
// check if RateLimit is set
if gateway.RateLimit != 0 {
// Add rate limit middlewares to all routes, if defined
rateLimit := middlewares.RateLimit{

View File

@@ -20,6 +20,7 @@ package pkg
import (
"crypto/tls"
"fmt"
"github.com/jkaninda/goma-gateway/pkg/logger"
)
func (gatewayServer GatewayServer) initTLS() (*tls.Config, bool, error) {
@@ -34,3 +35,19 @@ func (gatewayServer GatewayServer) initTLS() (*tls.Config, bool, error) {
}
return tlsConfig, true, nil
}
// loadTLS loads TLS Certificate
func loadTLS(cert, key string) (*tls.Config, error) {
if cert == "" && key == "" {
return nil, fmt.Errorf("no certificate or key file provided")
}
serverCert, err := tls.LoadX509KeyPair(cert, key)
if err != nil {
logger.Error("Error loading server certificate: %v", err)
return nil, err
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{serverCert},
}
return tlsConfig, nil
}