fix: fix authentification middlewares

This commit is contained in:
Jonas Kaninda
2024-11-24 22:13:26 +01:00
parent 6258b07c82
commit 3df8dce59b
7 changed files with 228 additions and 216 deletions

View File

@@ -53,6 +53,12 @@ func isPathBlocked(requestPath, blockedPath string) bool {
} }
return false return false
} }
func isProtectedPath(urlPath string, paths []string) bool {
for _, path := range paths {
return isPathBlocked(urlPath, util.ParseURLPath(path))
}
return false
}
// NewRateLimiter creates a new requests limiter with the specified refill requests and token capacity // NewRateLimiter creates a new requests limiter with the specified refill requests and token capacity
func NewRateLimiter(maxTokens int, refillRate time.Duration) *TokenRateLimiter { func NewRateLimiter(maxTokens int, refillRate time.Duration) *TokenRateLimiter {

View File

@@ -29,73 +29,75 @@ import (
// authorization based on the result of backend's response and continue the request when the client is authorized // authorization based on the result of backend's response and continue the request when the client is authorized
func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler { func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for _, header := range jwtAuth.RequiredHeaders { if isProtectedPath(r.URL.Path, jwtAuth.Paths) {
if r.Header.Get(header) == "" { for _, header := range jwtAuth.RequiredHeaders {
logger.Error("Proxy error, missing %s header", header) if r.Header.Get(header) == "" {
w.Header().Set("Content-Type", "application/json") logger.Error("Proxy error, missing %s header", header)
// check allowed origin w.Header().Set("Content-Type", "application/json")
if allowedOrigin(jwtAuth.Origins, r.Header.Get("Origin")) { // check allowed origin
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) if allowedOrigin(jwtAuth.Origins, r.Header.Get("Origin")) {
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
}
RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))
return
} }
}
authURL, err := url.Parse(jwtAuth.AuthURL)
if err != nil {
logger.Error("Error parsing auth URL: %v", err)
RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
return
}
// Create a new request for /authentication
authReq, err := http.NewRequest("GET", authURL.String(), nil)
if err != nil {
logger.Error("Proxy error creating authentication request: %v", err)
RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
return
}
logger.Trace("JWT Auth response headers: %v", authReq.Header)
// Copy headers from the original request to the new request
for name, values := range r.Header {
for _, value := range values {
authReq.Header.Set(name, value)
}
}
// Copy cookies from the original request to the new request
for _, cookie := range r.Cookies() {
authReq.AddCookie(cookie)
}
// Perform the request to the auth service
client := &http.Client{}
authResp, err := client.Do(authReq)
if err != nil || authResp.StatusCode != http.StatusOK {
logger.Debug("%s %s %s %s", r.Method, getRealIP(r), r.URL, r.UserAgent())
logger.Debug("Proxy authentication error")
RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))
return return
} }
} defer func(Body io.ReadCloser) {
authURL, err := url.Parse(jwtAuth.AuthURL) err := Body.Close()
if err != nil { if err != nil {
logger.Error("Error parsing auth URL: %v", err) logger.Error("Error closing body: %v", err)
RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) }
return }(authResp.Body)
} // Inject specific header tp the current request's header
// Create a new request for /authentication // Add header to the next request from AuthRequest header, depending on your requirements
authReq, err := http.NewRequest("GET", authURL.String(), nil) if jwtAuth.Headers != nil {
if err != nil { for k, v := range jwtAuth.Headers {
logger.Error("Proxy error creating authentication request: %v", err) r.Header.Set(v, authResp.Header.Get(k))
RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) }
return
}
logger.Trace("JWT Auth response headers: %v", authReq.Header)
// Copy headers from the original request to the new request
for name, values := range r.Header {
for _, value := range values {
authReq.Header.Set(name, value)
} }
} query := r.URL.Query()
// Copy cookies from the original request to the new request // Add query parameters to the next request from AuthRequest header, depending on your requirements
for _, cookie := range r.Cookies() { if jwtAuth.Params != nil {
authReq.AddCookie(cookie) for k, v := range jwtAuth.Params {
} query.Set(v, authResp.Header.Get(k))
// Perform the request to the auth service }
client := &http.Client{}
authResp, err := client.Do(authReq)
if err != nil || authResp.StatusCode != http.StatusOK {
logger.Debug("%s %s %s %s", r.Method, getRealIP(r), r.URL, r.UserAgent())
logger.Debug("Proxy authentication error")
RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))
return
}
defer func(Body io.ReadCloser) {
err := Body.Close()
if err != nil {
logger.Error("Error closing body: %v", err)
}
}(authResp.Body)
// Inject specific header tp the current request's header
// Add header to the next request from AuthRequest header, depending on your requirements
if jwtAuth.Headers != nil {
for k, v := range jwtAuth.Headers {
r.Header.Set(v, authResp.Header.Get(k))
} }
r.URL.RawQuery = query.Encode()
} }
query := r.URL.Query()
// Add query parameters to the next request from AuthRequest header, depending on your requirements
if jwtAuth.Params != nil {
for k, v := range jwtAuth.Params {
query.Set(v, authResp.Header.Get(k))
}
}
r.URL.RawQuery = query.Encode()
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}) })
@@ -105,36 +107,37 @@ func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler {
func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler { func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logger.Trace("Basic-Auth request headers: %v", r.Header) logger.Trace("Basic-Auth request headers: %v", r.Header)
// Get the Authorization header if isProtectedPath(r.URL.Path, basicAuth.Paths) {
authHeader := r.Header.Get("Authorization") // Get the Authorization header
if authHeader == "" { authHeader := r.Header.Get("Authorization")
logger.Debug("Proxy error, missing Authorization header") if authHeader == "" {
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) logger.Debug("Proxy error, missing Authorization header")
RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
return RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))
} return
// Check if the Authorization header contains "Basic" scheme }
if !strings.HasPrefix(authHeader, "Basic ") { // Check if the Authorization header contains "Basic" scheme
logger.Error("Proxy error, missing Basic Authorization header") if !strings.HasPrefix(authHeader, "Basic ") {
RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) logger.Error("Proxy error, missing Basic Authorization header")
RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))
return return
} }
// Decode the base64 encoded username:password string
payload, err := base64.StdEncoding.DecodeString(authHeader[len("Basic "):])
if err != nil {
logger.Debug("Proxy error, missing Basic Authorization header")
RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))
return
}
// Split the payload into username and password
pair := strings.SplitN(string(payload), ":", 2)
if len(pair) != 2 || pair[0] != basicAuth.Username || pair[1] != basicAuth.Password {
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))
return
}
// Decode the base64 encoded username:password string
payload, err := base64.StdEncoding.DecodeString(authHeader[len("Basic "):])
if err != nil {
logger.Debug("Proxy error, missing Basic Authorization header")
RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))
return
}
// Split the payload into username and password
pair := strings.SplitN(string(payload), ":", 2)
if len(pair) != 2 || pair[0] != basicAuth.Username || pair[1] != basicAuth.Password {
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))
return
} }
// Continue to the next handler if the authentication is successful // Continue to the next handler if the authentication is successful

View File

@@ -26,27 +26,29 @@ import (
func (oauth Oauth) AuthMiddleware(next http.Handler) http.Handler { func (oauth Oauth) AuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
oauthConf := oauth2Config(oauth) if isProtectedPath(r.URL.Path, oauth.Paths) {
// Check if the user is authenticated oauthConf := oauth2Config(oauth)
token, err := r.Cookie("goma.oauth") // Check if the user is authenticated
if err != nil { token, err := r.Cookie("goma.oauth")
// If no token, redirect to OAuth provider if err != nil {
url := oauthConf.AuthCodeURL(oauth.State) // If no token, redirect to OAuth provider
http.Redirect(w, r, url, http.StatusTemporaryRedirect) url := oauthConf.AuthCodeURL(oauth.State)
return http.Redirect(w, r, url, http.StatusTemporaryRedirect)
} return
ok, err := validateJWT(token.Value, oauth) }
if err != nil { ok, err := validateJWT(token.Value, oauth)
// If no token, redirect to OAuth provider if err != nil {
url := oauthConf.AuthCodeURL(oauth.State) // If no token, redirect to OAuth provider
http.Redirect(w, r, url, http.StatusTemporaryRedirect) url := oauthConf.AuthCodeURL(oauth.State)
return http.Redirect(w, r, url, http.StatusTemporaryRedirect)
} return
if !ok { }
// If no token, redirect to OAuth provider if !ok {
url := oauthConf.AuthCodeURL(oauth.State) // If no token, redirect to OAuth provider
http.Redirect(w, r, url, http.StatusTemporaryRedirect) url := oauthConf.AuthCodeURL(oauth.State)
return http.Redirect(w, r, url, http.StatusTemporaryRedirect)
return
}
} }
// Token exists, proceed with request // Token exists, proceed with request
next.ServeHTTP(w, r) next.ServeHTTP(w, r)

View File

@@ -79,6 +79,8 @@ type ProxyResponseError struct {
// JwtAuth stores JWT configuration // JwtAuth stores JWT configuration
type JwtAuth struct { type JwtAuth struct {
RoutePath string
Paths []string
AuthURL string AuthURL string
RequiredHeaders []string RequiredHeaders []string
Headers map[string]string Headers map[string]string
@@ -101,6 +103,7 @@ type AccessListMiddleware struct {
// AuthBasic contains Basic auth configuration // AuthBasic contains Basic auth configuration
type AuthBasic struct { type AuthBasic struct {
Paths []string
Username string Username string
Password string Password string
Headers map[string]string Headers map[string]string
@@ -120,6 +123,8 @@ type responseRecorder struct {
body *bytes.Buffer body *bytes.Buffer
} }
type Oauth struct { type Oauth struct {
// Route protected path
Paths []string
// ClientID is the application's ID. // ClientID is the application's ID.
ClientID string ClientID string
// ClientSecret is the application's secret. // ClientSecret is the application's secret.

View File

@@ -106,12 +106,25 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
r.Use(limiter.RateLimitMiddleware()) r.Use(limiter.RateLimitMiddleware())
} }
for rIndex, route := range dynamicRoutes { for rIndex, route := range dynamicRoutes {
// create route
router := r.PathPrefix(route.Path).Subrouter()
if len(route.Path) != 0 { if len(route.Path) != 0 {
// Checks if route destination and backend are empty // Checks if route destination and backend are empty
if len(route.Destination) == 0 && len(route.Backends) == 0 { if len(route.Destination) == 0 && len(route.Backends) == 0 {
logger.Fatal("Route %s : destination or backends should not be empty", route.Name) logger.Fatal("Route %s : destination or backends should not be empty", route.Name)
} }
proxyRoute := ProxyRoute{
path: route.Path,
rewrite: route.Rewrite,
destination: route.Destination,
backends: route.Backends,
methods: route.Methods,
disableHostFording: route.DisableHostFording,
cors: route.Cors,
insecureSkipVerify: route.InsecureSkipVerify,
}
// Apply middlewares to the route // Apply middlewares to the route
for _, middleware := range route.Middlewares { for _, middleware := range route.Middlewares {
if len(middleware) != 0 { if len(middleware) != 0 {
@@ -144,18 +157,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
logger.Error("Middleware ignored") logger.Error("Middleware ignored")
} }
} }
proxyRoute := ProxyRoute{
path: route.Path,
rewrite: route.Rewrite,
destination: route.Destination,
backends: route.Backends,
methods: route.Methods,
disableHostFording: route.DisableHostFording,
cors: route.Cors,
insecureSkipVerify: route.InsecureSkipVerify,
}
// create route
router := r.PathPrefix(route.Path).Subrouter()
// Apply common exploits to the route // Apply common exploits to the route
// Enable common exploits // Enable common exploits
if route.BlockCommonExploits { if route.BlockCommonExploits {
@@ -206,6 +208,8 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
} }
router.Use(interceptErrors.ErrorInterceptor) router.Use(interceptErrors.ErrorInterceptor)
} }
//r.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler
//r.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler
} else { } else {
logger.Error("Error, path is empty in route %s", route.Name) logger.Error("Error, path is empty in route %s", route.Name)
logger.Error("Route path ignored: %s", route.Path) logger.Error("Route path ignored: %s", route.Path)
@@ -221,6 +225,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
} }
r.Use(interceptErrors.ErrorInterceptor) r.Use(interceptErrors.ErrorInterceptor)
} }
} }
return r return r
@@ -228,105 +233,88 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
} }
func attachAuthMiddlewares(route Route, routeMiddleware Middleware, gateway Gateway, r *mux.Router) { func attachAuthMiddlewares(route Route, routeMiddleware Middleware, gateway Gateway, r *mux.Router) {
for _, middlewarePath := range routeMiddleware.Paths { // Check Authentication middleware types
proxyRoute := ProxyRoute{ switch routeMiddleware.Type {
path: route.Path, case BasicAuth:
rewrite: route.Rewrite, basicAuth, err := getBasicAuthMiddleware(routeMiddleware.Rule)
destination: route.Destination, if err != nil {
backends: route.Backends, logger.Error("Error: %s", err.Error())
disableHostFording: route.DisableHostFording, } else {
methods: route.Methods, authBasic := middlewares.AuthBasic{
cors: route.Cors, Paths: util.AddPrefixPath(route.Path, routeMiddleware.Paths),
insecureSkipVerify: route.InsecureSkipVerify, Username: basicAuth.Username,
Password: basicAuth.Password,
Headers: nil,
Params: nil,
}
// Apply JWT authentication middlewares
r.Use(authBasic.AuthMiddleware)
r.Use(CORSHandler(route.Cors))
} }
secureRouter := r.PathPrefix(util.ParseRoutePath(route.Path, middlewarePath)).Subrouter() case JWTAuth:
// Check Authentication middleware types jwt, err := getJWTMiddleware(routeMiddleware.Rule)
switch routeMiddleware.Type { if err != nil {
case BasicAuth: logger.Error("Error: %s", err.Error())
basicAuth, err := getBasicAuthMiddleware(routeMiddleware.Rule) } else {
if err != nil { jwtAuth := middlewares.JwtAuth{
logger.Error("Error: %s", err.Error()) Paths: util.AddPrefixPath(route.Path, routeMiddleware.Paths),
} else { AuthURL: jwt.URL,
authBasic := middlewares.AuthBasic{ RequiredHeaders: jwt.RequiredHeaders,
Username: basicAuth.Username, Headers: jwt.Headers,
Password: basicAuth.Password, Params: jwt.Params,
Headers: nil, Origins: gateway.Cors.Origins,
Params: nil,
}
// Apply JWT authentication middlewares
secureRouter.Use(authBasic.AuthMiddleware)
secureRouter.Use(CORSHandler(route.Cors))
secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler
secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler
} }
case JWTAuth: // Apply JWT authentication middlewares
jwt, err := getJWTMiddleware(routeMiddleware.Rule) r.Use(jwtAuth.AuthMiddleware)
if err != nil { r.Use(CORSHandler(route.Cors))
logger.Error("Error: %s", err.Error())
} else {
jwtAuth := middlewares.JwtAuth{
AuthURL: jwt.URL,
RequiredHeaders: jwt.RequiredHeaders,
Headers: jwt.Headers,
Params: jwt.Params,
Origins: gateway.Cors.Origins,
}
// Apply JWT authentication middlewares
secureRouter.Use(jwtAuth.AuthMiddleware)
secureRouter.Use(CORSHandler(route.Cors))
secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler
secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler
}
case OAuth:
oauth, err := oAuthMiddleware(routeMiddleware.Rule)
if err != nil {
logger.Error("Error: %s", err.Error())
} else {
redirectURL := "/callback" + route.Path
if oauth.RedirectURL != "" {
redirectURL = oauth.RedirectURL
} }
case OAuth: amw := middlewares.Oauth{
oauth, err := oAuthMiddleware(routeMiddleware.Rule) Paths: util.AddPrefixPath(route.Path, routeMiddleware.Paths),
if err != nil { ClientID: oauth.ClientID,
logger.Error("Error: %s", err.Error()) ClientSecret: oauth.ClientSecret,
} else { RedirectURL: redirectURL,
redirectURL := "/callback" + route.Path Scopes: oauth.Scopes,
if oauth.RedirectURL != "" { Endpoint: middlewares.OauthEndpoint{
redirectURL = oauth.RedirectURL AuthURL: oauth.Endpoint.AuthURL,
} TokenURL: oauth.Endpoint.TokenURL,
amw := middlewares.Oauth{ UserInfoURL: oauth.Endpoint.UserInfoURL,
ClientID: oauth.ClientID, },
ClientSecret: oauth.ClientSecret, State: oauth.State,
RedirectURL: redirectURL, Origins: gateway.Cors.Origins,
Scopes: oauth.Scopes, JWTSecret: oauth.JWTSecret,
Endpoint: middlewares.OauthEndpoint{ Provider: oauth.Provider,
AuthURL: oauth.Endpoint.AuthURL,
TokenURL: oauth.Endpoint.TokenURL,
UserInfoURL: oauth.Endpoint.UserInfoURL,
},
State: oauth.State,
Origins: gateway.Cors.Origins,
JWTSecret: oauth.JWTSecret,
Provider: oauth.Provider,
}
oauthRuler := oauthRulerMiddleware(amw)
// Check if a cookie path is defined
if oauthRuler.CookiePath == "" {
oauthRuler.CookiePath = route.Path
}
// Check if a RedirectPath is defined
if oauthRuler.RedirectPath == "" {
oauthRuler.RedirectPath = util.ParseRoutePath(route.Path, middlewarePath)
}
if oauthRuler.Provider == "" {
oauthRuler.Provider = "custom"
}
secureRouter.Use(amw.AuthMiddleware)
secureRouter.Use(CORSHandler(route.Cors))
secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler
secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler
// Callback route
r.HandleFunc(util.UrlParsePath(redirectURL), oauthRuler.callbackHandler).Methods("GET")
} }
default: oauthRuler := oauthRulerMiddleware(amw)
if !doesExist(routeMiddleware.Type) { // Check if a cookie path is defined
logger.Error("Unknown middlewares type %s", routeMiddleware.Type) if oauthRuler.CookiePath == "" {
oauthRuler.CookiePath = route.Path
} }
// Check if a RedirectPath is defined
if oauthRuler.RedirectPath == "" {
oauthRuler.RedirectPath = util.ParseRoutePath(route.Path, routeMiddleware.Paths[0])
}
if oauthRuler.Provider == "" {
oauthRuler.Provider = "custom"
}
r.Use(amw.AuthMiddleware)
r.Use(CORSHandler(route.Cors))
r.HandleFunc(util.UrlParsePath(redirectURL), oauthRuler.callbackHandler).Methods("GET")
}
default:
if !doesExist(routeMiddleware.Type) {
logger.Error("Unknown middlewares type %s", routeMiddleware.Type)
} }
} }
} }

View File

@@ -30,7 +30,7 @@ import (
// Start / Start starts the server // Start / Start starts the server
func (gatewayServer GatewayServer) Start() error { func (gatewayServer GatewayServer) Start() error {
logger.Info("Initializing routes...") logger.Info("Initializing routes...")
route := gatewayServer.Initialize() router := gatewayServer.Initialize()
logger.Debug("Routes count=%d, Middlewares count=%d", len(gatewayServer.gateway.Routes), len(gatewayServer.middlewares)) logger.Debug("Routes count=%d, Middlewares count=%d", len(gatewayServer.gateway.Routes), len(gatewayServer.middlewares))
gatewayServer.initRedis() gatewayServer.initRedis()
defer gatewayServer.closeRedis() defer gatewayServer.closeRedis()
@@ -44,8 +44,8 @@ func (gatewayServer GatewayServer) Start() error {
printRoute(dynamicRoutes) printRoute(dynamicRoutes)
} }
httpServer := gatewayServer.createServer(":8080", route, nil) httpServer := gatewayServer.createServer(":8080", router, nil)
httpsServer := gatewayServer.createServer(":8443", route, tlsConfig) httpsServer := gatewayServer.createServer(":8443", router, tlsConfig)
// Start HTTP/HTTPS servers // Start HTTP/HTTPS servers
gatewayServer.startServers(httpServer, httpsServer, listenWithTLS) gatewayServer.startServers(httpServer, httpsServer, listenWithTLS)

View File

@@ -157,3 +157,11 @@ func Slug(text string) string {
return text return text
} }
func AddPrefixPath(prefix string, paths []string) []string {
for i := range paths {
paths[i] = ParseURLPath(prefix + paths[i])
}
return paths
}