Merge pull request #81 from jkaninda/feature/loadbalancer

Feature/loadbalancer
This commit is contained in:
2024-11-10 17:16:21 +01:00
committed by GitHub
8 changed files with 49 additions and 11 deletions

View File

@@ -118,7 +118,7 @@ func initConfig(configFile string) {
configFile = GetConfigPaths() configFile = GetConfigPaths()
} }
conf := &GatewayConfig{ conf := &GatewayConfig{
Version: util.Version, Version: util.ConfigVersion,
GatewayConfig: Gateway{ GatewayConfig: Gateway{
WriteTimeout: 15, WriteTimeout: 15,
ReadTimeout: 15, ReadTimeout: 15,
@@ -165,7 +165,7 @@ func initConfig(configFile string) {
}, },
{ {
Name: "Hostname example", Name: "Hostname example",
Host: "http://example.localhost", Hosts: []string{"example.com", "example.localhost"},
Path: "/", Path: "/",
Destination: "https://example.com", Destination: "https://example.com",
Rewrite: "/", Rewrite: "/",

View File

@@ -73,7 +73,7 @@ func (heathRoute HealthCheckRoute) HealthCheckHandler(w http.ResponseWriter, r *
go func() { go func() {
defer wg.Done() defer wg.Done()
if route.HealthCheck != "" { if route.HealthCheck != "" {
err := HealthCheck(route.Destination + route.HealthCheck) err := healthCheck(route.Destination + route.HealthCheck)
if err != nil { if err != nil {
if heathRoute.DisableRouteHealthCheckError { if heathRoute.DisableRouteHealthCheckError {
routes = append(routes, HealthCheckRouteResponse{Name: route.Name, Status: "unhealthy", Error: "Route healthcheck errors disabled"}) routes = append(routes, HealthCheckRouteResponse{Name: route.Name, Status: "unhealthy", Error: "Route healthcheck errors disabled"})

View File

@@ -23,7 +23,7 @@ import (
"net/url" "net/url"
) )
func HealthCheck(healthURL string) error { func healthCheck(healthURL string) error {
healthCheckURL, err := url.Parse(healthURL) healthCheckURL, err := url.Parse(healthURL)
if err != nil { if err != nil {
return fmt.Errorf("error parsing HealthCheck URL: %v ", err) return fmt.Errorf("error parsing HealthCheck URL: %v ", err)

View File

@@ -23,6 +23,7 @@ import (
"net/url" "net/url"
"slices" "slices"
"strings" "strings"
"sync/atomic"
) )
// ProxyHandler proxies requests to the backend // ProxyHandler proxies requests to the backend
@@ -76,8 +77,13 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc {
r.Header.Set("X-Real-IP", getRealIP(r)) r.Header.Set("X-Real-IP", getRealIP(r))
r.Host = targetURL.Host r.Host = targetURL.Host
} }
backendURL, _ := url.Parse(proxyRoute.destination)
if len(proxyRoute.backends) > 0 {
// Select the next backend URL
backendURL = getNextBackend(proxyRoute.backends)
}
// Create proxy // Create proxy
proxy := httputil.NewSingleHostReverseProxy(targetURL) proxy := httputil.NewSingleHostReverseProxy(backendURL)
// Rewrite // Rewrite
if proxyRoute.path != "" && proxyRoute.rewrite != "" { if proxyRoute.path != "" && proxyRoute.rewrite != "" {
// Rewrite the path // Rewrite the path
@@ -92,3 +98,10 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc {
proxy.ServeHTTP(w, r) proxy.ServeHTTP(w, r)
} }
} }
// getNextBackend selects the next backend in a round-robin fashion
func getNextBackend(backendURLs []string) *url.URL {
idx := atomic.AddUint32(&counter, 1) % uint32(len(backendURLs))
backendURL, _ := url.Parse(backendURLs[idx])
return backendURL
}

View File

@@ -53,7 +53,10 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
} }
for _, route := range gateway.Routes { for _, route := range gateway.Routes {
if route.Path != "" { if route.Path != "" {
if route.Destination == "" && len(route.Backends) == 0 {
logger.Fatal("Route %s : destination or backends should not be empty", route.Name)
}
// Apply middlewares to route // Apply middlewares to route
for _, mid := range route.Middlewares { for _, mid := range route.Middlewares {
if mid != "" { if mid != "" {
@@ -84,6 +87,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
path: route.Path, path: route.Path,
rewrite: route.Rewrite, rewrite: route.Rewrite,
destination: route.Destination, destination: route.Destination,
backends: route.Backends,
disableXForward: route.DisableHeaderXForward, disableXForward: route.DisableHeaderXForward,
methods: route.Methods, methods: route.Methods,
cors: route.Cors, cors: route.Cors,
@@ -190,6 +194,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
path: route.Path, path: route.Path,
rewrite: route.Rewrite, rewrite: route.Rewrite,
destination: route.Destination, destination: route.Destination,
backends: route.Backends,
methods: route.Methods, methods: route.Methods,
disableXForward: route.DisableHeaderXForward, disableXForward: route.DisableHeaderXForward,
cors: route.Cors, cors: route.Cors,
@@ -197,14 +202,16 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
router := r.PathPrefix(route.Path).Subrouter() router := r.PathPrefix(route.Path).Subrouter()
// Apply route Cors // Apply route Cors
router.Use(CORSHandler(route.Cors)) router.Use(CORSHandler(route.Cors))
if route.Host != "" { if len(route.Hosts) > 0 {
router.Host(route.Host).PathPrefix("").Handler(proxyRoute.ProxyHandler()) for _, host := range route.Hosts {
router.Host(host).PathPrefix("").Handler(proxyRoute.ProxyHandler())
}
} else { } else {
router.PathPrefix("").Handler(proxyRoute.ProxyHandler()) router.PathPrefix("").Handler(proxyRoute.ProxyHandler())
} }
} else { } else {
logger.Error("Error, path is empty in route %s", route.Name) logger.Error("Error, path is empty in route %s", route.Name)
logger.Debug("Route path ignored: %s", route.Path) logger.Error("Route path ignored: %s", route.Path)
} }
} }
// Apply global Cors middlewares // Apply global Cors middlewares

View File

@@ -131,18 +131,22 @@ type MiddlewareName struct {
// Route defines gateway route // Route defines gateway route
type Route struct { type Route struct {
// Path defines route path
Path string `yaml:"path"`
// Name defines route name // Name defines route name
Name string `yaml:"name"` Name string `yaml:"name"`
//Host Domain/host based request routing //Host Domain/host based request routing
Host string `yaml:"host"` //Host string `yaml:"host"`
// Path defines route path //Hosts Domains/hosts based request routing
Path string `yaml:"path"` Hosts []string `yaml:"hosts"`
// Rewrite rewrites route path to desired path // Rewrite rewrites route path to desired path
// //
// E.g. /cart to / => It will rewrite /cart path to / // E.g. /cart to / => It will rewrite /cart path to /
Rewrite string `yaml:"rewrite"` Rewrite string `yaml:"rewrite"`
// Destination Defines backend URL // Destination Defines backend URL
Destination string `yaml:"destination"` Destination string `yaml:"destination"`
//
Backends []string `yaml:"backends"`
// Cors contains the route cors headers // Cors contains the route cors headers
Cors Cors `yaml:"cors"` Cors Cors `yaml:"cors"`
//RateLimit int `yaml:"rateLimit"` //RateLimit int `yaml:"rateLimit"`
@@ -197,6 +201,14 @@ type Gateway struct {
// Routes holds proxy routes // Routes holds proxy routes
Routes []Route `yaml:"routes"` Routes []Route `yaml:"routes"`
} }
type RouteHealthCheck struct {
Path string `yaml:"path"`
Interval int `yaml:"interval"`
Timeout int `yaml:"timeout"`
HealthyStatuses []int `yaml:"healthyStatuses"`
UnhealthyStatuses []int `yaml:"unhealthyStatuses"`
}
type GatewayConfig struct { type GatewayConfig struct {
Version string `yaml:"version"` Version string `yaml:"version"`
// GatewayConfig holds Gateway config // GatewayConfig holds Gateway config
@@ -220,6 +232,8 @@ type ProxyRoute struct {
path string path string
rewrite string rewrite string
destination string destination string
backends []string
healthCheck RouteHealthCheck
methods []string methods []string
cors Cors cors Cors
disableXForward bool disableXForward bool

View File

@@ -9,3 +9,5 @@ const AccessMiddleware = "access" // access middleware
const BasicAuth = "basic" // basic authentication middleware const BasicAuth = "basic" // basic authentication middleware
const JWTAuth = "jwt" // JWT authentication middleware const JWTAuth = "jwt" // JWT authentication middleware
const OAuth = "oauth" // OAuth authentication middleware const OAuth = "oauth" // OAuth authentication middleware
// Round-robin counter
var counter uint32

View File

@@ -15,6 +15,8 @@ import (
var Version string var Version string
const ConfigVersion = "1.0"
func VERSION(def string) string { func VERSION(def string) string {
build := os.Getenv("VERSION") build := os.Getenv("VERSION")
if build == "" { if build == "" {