chore: refactroing of code
Commenting code for enhancing readability
This commit is contained in:
@@ -71,6 +71,8 @@ func (health Health) Check() error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// routesHealthCheck creates healthcheck job
|
||||||
func routesHealthCheck(routes []Route) {
|
func routesHealthCheck(routes []Route) {
|
||||||
for _, health := range healthCheckRoutes(routes) {
|
for _, health := range healthCheckRoutes(routes) {
|
||||||
go func() {
|
go func() {
|
||||||
@@ -84,11 +86,14 @@ func routesHealthCheck(routes []Route) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// createHealthCheckJob create healthcheck job
|
||||||
func (health Health) createHealthCheckJob() error {
|
func (health Health) createHealthCheckJob() error {
|
||||||
interval := "30s"
|
interval := "30s"
|
||||||
if len(health.Interval) > 0 {
|
if len(health.Interval) > 0 {
|
||||||
interval = health.Interval
|
interval = health.Interval
|
||||||
}
|
}
|
||||||
|
// create cron expression
|
||||||
expression := fmt.Sprintf("@every %s", interval)
|
expression := fmt.Sprintf("@every %s", interval)
|
||||||
if !util.IsValidCronExpression(expression) {
|
if !util.IsValidCronExpression(expression) {
|
||||||
logger.Error("Health check interval is invalid: %s", interval)
|
logger.Error("Health check interval is invalid: %s", interval)
|
||||||
@@ -113,3 +118,45 @@ func (health Health) createHealthCheckJob() error {
|
|||||||
defer c.Stop()
|
defer c.Stop()
|
||||||
select {}
|
select {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// healthCheckRoutes creates and returns []Health
|
||||||
|
func healthCheckRoutes(routes []Route) []Health {
|
||||||
|
var healthRoutes []Health
|
||||||
|
for _, route := range routes {
|
||||||
|
if len(route.HealthCheck.Path) != 0 {
|
||||||
|
timeout, _ := util.ParseDuration("")
|
||||||
|
if len(route.HealthCheck.Timeout) > 0 {
|
||||||
|
d1, err1 := util.ParseDuration(route.HealthCheck.Timeout)
|
||||||
|
if err1 != nil {
|
||||||
|
logger.Error("Health check timeout is invalid: %s", route.HealthCheck.Timeout)
|
||||||
|
}
|
||||||
|
timeout = d1
|
||||||
|
}
|
||||||
|
if len(route.Backends) != 0 {
|
||||||
|
for index, backend := range route.Backends {
|
||||||
|
health := Health{
|
||||||
|
Name: fmt.Sprintf("%s - [%d]", route.Name, index),
|
||||||
|
URL: backend + route.HealthCheck.Path,
|
||||||
|
TimeOut: timeout,
|
||||||
|
HealthyStatuses: route.HealthCheck.HealthyStatuses,
|
||||||
|
InsecureSkipVerify: route.InsecureSkipVerify,
|
||||||
|
}
|
||||||
|
healthRoutes = append(healthRoutes, health)
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
health := Health{
|
||||||
|
Name: route.Name,
|
||||||
|
URL: route.Destination + route.HealthCheck.Path,
|
||||||
|
TimeOut: timeout,
|
||||||
|
HealthyStatuses: route.HealthCheck.HealthyStatuses,
|
||||||
|
InsecureSkipVerify: route.InsecureSkipVerify,
|
||||||
|
}
|
||||||
|
healthRoutes = append(healthRoutes, health)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.Debug("Route %s's healthCheck is undefined", route.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return healthRoutes
|
||||||
|
}
|
||||||
|
|||||||
@@ -11,25 +11,20 @@ You may get a copy of the License at
|
|||||||
*/
|
*/
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/golang-jwt/jwt"
|
|
||||||
"github.com/jedib0t/go-pretty/v6/table"
|
"github.com/jedib0t/go-pretty/v6/table"
|
||||||
"github.com/jkaninda/goma-gateway/pkg/logger"
|
|
||||||
"github.com/jkaninda/goma-gateway/util"
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// printRoute prints routes
|
// printRoute prints routes
|
||||||
func printRoute(routes []Route) {
|
func printRoute(routes []Route) {
|
||||||
t := table.NewWriter()
|
t := table.NewWriter()
|
||||||
t.AppendHeader(table.Row{"Name", "Route", "Rewrite", "Destination"})
|
t.AppendHeader(table.Row{"Name", "Path", "Rewrite", "Destination"})
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
if len(route.Backends) > 0 {
|
if len(route.Backends) != 0 {
|
||||||
t.AppendRow(table.Row{route.Name, route.Path, route.Rewrite, fmt.Sprintf("backends: [%d]", len(route.Backends))})
|
t.AppendRow(table.Row{route.Name, route.Path, route.Rewrite, fmt.Sprintf("backends: [%d]", len(route.Backends))})
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
@@ -50,21 +45,6 @@ func getRealIP(r *http.Request) string {
|
|||||||
return r.RemoteAddr
|
return r.RemoteAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadTLS loads TLS Certificate
|
|
||||||
func loadTLS(cert, key string) (*tls.Config, error) {
|
|
||||||
if cert == "" && key == "" {
|
|
||||||
return nil, fmt.Errorf("no certificate or key file provided")
|
|
||||||
}
|
|
||||||
serverCert, err := tls.LoadX509KeyPair(cert, key)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Error loading server certificate: %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
tlsConfig := &tls.Config{
|
|
||||||
Certificates: []tls.Certificate{serverCert},
|
|
||||||
}
|
|
||||||
return tlsConfig, nil
|
|
||||||
}
|
|
||||||
func (oauth *OauthRulerMiddleware) getUserInfo(token *oauth2.Token) (UserInfo, error) {
|
func (oauth *OauthRulerMiddleware) getUserInfo(token *oauth2.Token) (UserInfo, error) {
|
||||||
oauthConfig := oauth2Config(oauth)
|
oauthConfig := oauth2Config(oauth)
|
||||||
// Call the user info endpoint with the token
|
// Call the user info endpoint with the token
|
||||||
@@ -88,64 +68,3 @@ func (oauth *OauthRulerMiddleware) getUserInfo(token *oauth2.Token) (UserInfo, e
|
|||||||
|
|
||||||
return userInfo, nil
|
return userInfo, nil
|
||||||
}
|
}
|
||||||
func createJWT(email, jwtSecret string) (string, error) {
|
|
||||||
// Define JWT claims
|
|
||||||
claims := jwt.MapClaims{
|
|
||||||
"email": email,
|
|
||||||
"exp": jwt.TimeFunc().Add(time.Hour * 24).Unix(), // Token expiration
|
|
||||||
"iss": "Goma-Gateway", // Issuer claim
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a new token with HS256 signing method
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
||||||
|
|
||||||
// Sign the token with a secret
|
|
||||||
signedToken, err := token.SignedString([]byte(jwtSecret))
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return signedToken, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// healthCheckRoutes creates []Health
|
|
||||||
func healthCheckRoutes(routes []Route) []Health {
|
|
||||||
var healthRoutes []Health
|
|
||||||
for _, route := range routes {
|
|
||||||
if len(route.HealthCheck.Path) > 0 {
|
|
||||||
timeout, _ := util.ParseDuration("")
|
|
||||||
if len(route.HealthCheck.Timeout) > 0 {
|
|
||||||
d1, err1 := util.ParseDuration(route.HealthCheck.Timeout)
|
|
||||||
if err1 != nil {
|
|
||||||
logger.Error("Health check timeout is invalid: %s", route.HealthCheck.Timeout)
|
|
||||||
}
|
|
||||||
timeout = d1
|
|
||||||
}
|
|
||||||
if len(route.Backends) > 0 {
|
|
||||||
for index, backend := range route.Backends {
|
|
||||||
health := Health{
|
|
||||||
Name: fmt.Sprintf("%s - [%d]", route.Name, index),
|
|
||||||
URL: backend + route.HealthCheck.Path,
|
|
||||||
TimeOut: timeout,
|
|
||||||
HealthyStatuses: route.HealthCheck.HealthyStatuses,
|
|
||||||
InsecureSkipVerify: route.InsecureSkipVerify,
|
|
||||||
}
|
|
||||||
healthRoutes = append(healthRoutes, health)
|
|
||||||
}
|
|
||||||
|
|
||||||
} else {
|
|
||||||
health := Health{
|
|
||||||
Name: route.Name,
|
|
||||||
URL: route.Destination + route.HealthCheck.Path,
|
|
||||||
TimeOut: timeout,
|
|
||||||
HealthyStatuses: route.HealthCheck.HealthyStatuses,
|
|
||||||
InsecureSkipVerify: route.InsecureSkipVerify,
|
|
||||||
}
|
|
||||||
healthRoutes = append(healthRoutes, health)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
logger.Debug("Route %s's healthCheck is undefined", route.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return healthRoutes
|
|
||||||
}
|
|
||||||
|
|||||||
44
internal/jwt.go
Normal file
44
internal/jwt.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2024 Jonas Kaninda
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package pkg
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// createJWT create JWT token
|
||||||
|
func createJWT(email, jwtSecret string) (string, error) {
|
||||||
|
// Define JWT claims
|
||||||
|
claims := jwt.MapClaims{
|
||||||
|
"email": email,
|
||||||
|
"exp": jwt.TimeFunc().Add(time.Hour * 24).Unix(), // Token expiration
|
||||||
|
"iss": "Goma-Gateway", // Issuer claim
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new token with HS256 signing method
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
|
||||||
|
// Sign the token with a secret
|
||||||
|
signedToken, err := token.SignedString([]byte(jwtSecret))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return signedToken, nil
|
||||||
|
}
|
||||||
@@ -52,6 +52,7 @@ var HttpDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{
|
|||||||
Help: "Duration of HTTP requests.",
|
Help: "Duration of HTTP requests.",
|
||||||
}, []string{"name", "path"})
|
}, []string{"name", "path"})
|
||||||
|
|
||||||
|
// PrometheusMiddleware Prometheus http handler middleware, returns http.Handler
|
||||||
func (pr PrometheusRoute) PrometheusMiddleware(next http.Handler) http.Handler {
|
func (pr PrometheusRoute) PrometheusMiddleware(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
path := pr.Path
|
path := pr.Path
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import (
|
|||||||
"slices"
|
"slices"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// getRealIP returns user real IP
|
||||||
func getRealIP(r *http.Request) string {
|
func getRealIP(r *http.Request) string {
|
||||||
if ip := r.Header.Get("X-Real-IP"); ip != "" {
|
if ip := r.Header.Get("X-Real-IP"); ip != "" {
|
||||||
return ip
|
return ip
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import (
|
|||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// redisRateLimiter, handle rateLimit
|
||||||
func redisRateLimiter(clientIP string, rate int) error {
|
func redisRateLimiter(clientIP string, rate int) error {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc {
|
|||||||
r.Host = targetURL.Host
|
r.Host = targetURL.Host
|
||||||
}
|
}
|
||||||
backendURL, _ := url.Parse(proxyRoute.destination)
|
backendURL, _ := url.Parse(proxyRoute.destination)
|
||||||
if len(proxyRoute.backends) > 0 {
|
if len(proxyRoute.backends) != 0 {
|
||||||
// Select the next backend URL
|
// Select the next backend URL
|
||||||
backendURL = getNextBackend(proxyRoute.backends)
|
backendURL = getNextBackend(proxyRoute.backends)
|
||||||
}
|
}
|
||||||
@@ -87,8 +87,7 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc {
|
|||||||
InsecureSkipVerify: proxyRoute.insecureSkipVerify,
|
InsecureSkipVerify: proxyRoute.insecureSkipVerify,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
w.Header().Set("Proxied-By", gatewayName) // Set Server name
|
w.Header().Set("Proxied-By", gatewayName)
|
||||||
w.Header().Del("Server") // Remove the Server header
|
|
||||||
// Custom error handler for proxy errors
|
// Custom error handler for proxy errors
|
||||||
proxy.ErrorHandler = ProxyErrorHandler
|
proxy.ErrorHandler = ProxyErrorHandler
|
||||||
proxy.ServeHTTP(w, r)
|
proxy.ServeHTTP(w, r)
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// init initializes prometheus metrics
|
||||||
func init() {
|
func init() {
|
||||||
_ = prometheus.Register(metrics.TotalRequests)
|
_ = prometheus.Register(metrics.TotalRequests)
|
||||||
_ = prometheus.Register(metrics.ResponseStatus)
|
_ = prometheus.Register(metrics.ResponseStatus)
|
||||||
@@ -88,6 +89,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
|
|||||||
logger.Info("Block common exploits enabled")
|
logger.Info("Block common exploits enabled")
|
||||||
r.Use(middlewares.BlockExploitsMiddleware)
|
r.Use(middlewares.BlockExploitsMiddleware)
|
||||||
}
|
}
|
||||||
|
// check if RateLimit is set
|
||||||
if gateway.RateLimit != 0 {
|
if gateway.RateLimit != 0 {
|
||||||
// Add rate limit middlewares to all routes, if defined
|
// Add rate limit middlewares to all routes, if defined
|
||||||
rateLimit := middlewares.RateLimit{
|
rateLimit := middlewares.RateLimit{
|
||||||
@@ -20,6 +20,7 @@ package pkg
|
|||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/jkaninda/goma-gateway/pkg/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (gatewayServer GatewayServer) initTLS() (*tls.Config, bool, error) {
|
func (gatewayServer GatewayServer) initTLS() (*tls.Config, bool, error) {
|
||||||
@@ -34,3 +35,19 @@ func (gatewayServer GatewayServer) initTLS() (*tls.Config, bool, error) {
|
|||||||
}
|
}
|
||||||
return tlsConfig, true, nil
|
return tlsConfig, true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// loadTLS loads TLS Certificate
|
||||||
|
func loadTLS(cert, key string) (*tls.Config, error) {
|
||||||
|
if cert == "" && key == "" {
|
||||||
|
return nil, fmt.Errorf("no certificate or key file provided")
|
||||||
|
}
|
||||||
|
serverCert, err := tls.LoadX509KeyPair(cert, key)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error loading server certificate: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{serverCert},
|
||||||
|
}
|
||||||
|
return tlsConfig, nil
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user