diff --git a/internal/middlewares/access_middleware.go b/internal/middlewares/access_middleware.go index ca26295..5b18181 100644 --- a/internal/middlewares/access_middleware.go +++ b/internal/middlewares/access_middleware.go @@ -53,6 +53,12 @@ func isPathBlocked(requestPath, blockedPath string) bool { } return false } +func isProtectedPath(urlPath string, paths []string) bool { + for _, path := range paths { + return isPathBlocked(urlPath, util.ParseURLPath(path)) + } + return false +} // NewRateLimiter creates a new requests limiter with the specified refill requests and token capacity func NewRateLimiter(maxTokens int, refillRate time.Duration) *TokenRateLimiter { diff --git a/internal/middlewares/middleware.go b/internal/middlewares/middleware.go index f8fc1b8..da52614 100644 --- a/internal/middlewares/middleware.go +++ b/internal/middlewares/middleware.go @@ -29,73 +29,75 @@ import ( // authorization based on the result of backend's response and continue the request when the client is authorized func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - for _, header := range jwtAuth.RequiredHeaders { - if r.Header.Get(header) == "" { - logger.Error("Proxy error, missing %s header", header) - w.Header().Set("Content-Type", "application/json") - // check allowed origin - if allowedOrigin(jwtAuth.Origins, r.Header.Get("Origin")) { - w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) + if isProtectedPath(r.URL.Path, jwtAuth.Paths) { + for _, header := range jwtAuth.RequiredHeaders { + if r.Header.Get(header) == "" { + logger.Error("Proxy error, missing %s header", header) + w.Header().Set("Content-Type", "application/json") + // check allowed origin + if allowedOrigin(jwtAuth.Origins, r.Header.Get("Origin")) { + w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) + } + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) + return + } + } + authURL, err := url.Parse(jwtAuth.AuthURL) + if err != nil { + logger.Error("Error parsing auth URL: %v", err) + RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) + return + } + // Create a new request for /authentication + authReq, err := http.NewRequest("GET", authURL.String(), nil) + if err != nil { + logger.Error("Proxy error creating authentication request: %v", err) + RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) + return + } + logger.Trace("JWT Auth response headers: %v", authReq.Header) + // Copy headers from the original request to the new request + for name, values := range r.Header { + for _, value := range values { + authReq.Header.Set(name, value) + } + } + // Copy cookies from the original request to the new request + for _, cookie := range r.Cookies() { + authReq.AddCookie(cookie) + } + // Perform the request to the auth service + client := &http.Client{} + authResp, err := client.Do(authReq) + if err != nil || authResp.StatusCode != http.StatusOK { + logger.Debug("%s %s %s %s", r.Method, getRealIP(r), r.URL, r.UserAgent()) + logger.Debug("Proxy authentication error") RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) return - } - } - authURL, err := url.Parse(jwtAuth.AuthURL) - if err != nil { - logger.Error("Error parsing auth URL: %v", err) - RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) - return - } - // Create a new request for /authentication - authReq, err := http.NewRequest("GET", authURL.String(), nil) - if err != nil { - logger.Error("Proxy error creating authentication request: %v", err) - RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) - return - } - logger.Trace("JWT Auth response headers: %v", authReq.Header) - // Copy headers from the original request to the new request - for name, values := range r.Header { - for _, value := range values { - authReq.Header.Set(name, value) + defer func(Body io.ReadCloser) { + err := Body.Close() + if err != nil { + logger.Error("Error closing body: %v", err) + } + }(authResp.Body) + // Inject specific header tp the current request's header + // Add header to the next request from AuthRequest header, depending on your requirements + if jwtAuth.Headers != nil { + for k, v := range jwtAuth.Headers { + r.Header.Set(v, authResp.Header.Get(k)) + } } - } - // Copy cookies from the original request to the new request - for _, cookie := range r.Cookies() { - authReq.AddCookie(cookie) - } - // Perform the request to the auth service - client := &http.Client{} - authResp, err := client.Do(authReq) - if err != nil || authResp.StatusCode != http.StatusOK { - logger.Debug("%s %s %s %s", r.Method, getRealIP(r), r.URL, r.UserAgent()) - logger.Debug("Proxy authentication error") - RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) - return - } - defer func(Body io.ReadCloser) { - err := Body.Close() - if err != nil { - logger.Error("Error closing body: %v", err) - } - }(authResp.Body) - // Inject specific header tp the current request's header - // Add header to the next request from AuthRequest header, depending on your requirements - if jwtAuth.Headers != nil { - for k, v := range jwtAuth.Headers { - r.Header.Set(v, authResp.Header.Get(k)) + query := r.URL.Query() + // Add query parameters to the next request from AuthRequest header, depending on your requirements + if jwtAuth.Params != nil { + for k, v := range jwtAuth.Params { + query.Set(v, authResp.Header.Get(k)) + } } + r.URL.RawQuery = query.Encode() } - query := r.URL.Query() - // Add query parameters to the next request from AuthRequest header, depending on your requirements - if jwtAuth.Params != nil { - for k, v := range jwtAuth.Params { - query.Set(v, authResp.Header.Get(k)) - } - } - r.URL.RawQuery = query.Encode() next.ServeHTTP(w, r) }) @@ -105,36 +107,37 @@ func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler { func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { logger.Trace("Basic-Auth request headers: %v", r.Header) - // Get the Authorization header - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - logger.Debug("Proxy error, missing Authorization header") - w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) - RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) - return - } - // Check if the Authorization header contains "Basic" scheme - if !strings.HasPrefix(authHeader, "Basic ") { - logger.Error("Proxy error, missing Basic Authorization header") - RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) + if isProtectedPath(r.URL.Path, basicAuth.Paths) { + // Get the Authorization header + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + logger.Debug("Proxy error, missing Authorization header") + w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) + return + } + // Check if the Authorization header contains "Basic" scheme + if !strings.HasPrefix(authHeader, "Basic ") { + logger.Error("Proxy error, missing Basic Authorization header") + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) - return - } + return + } + // Decode the base64 encoded username:password string + payload, err := base64.StdEncoding.DecodeString(authHeader[len("Basic "):]) + if err != nil { + logger.Debug("Proxy error, missing Basic Authorization header") + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) + return + } + // Split the payload into username and password + pair := strings.SplitN(string(payload), ":", 2) + if len(pair) != 2 || pair[0] != basicAuth.Username || pair[1] != basicAuth.Password { + w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) + return + } - // Decode the base64 encoded username:password string - payload, err := base64.StdEncoding.DecodeString(authHeader[len("Basic "):]) - if err != nil { - logger.Debug("Proxy error, missing Basic Authorization header") - RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) - return - } - - // Split the payload into username and password - pair := strings.SplitN(string(payload), ":", 2) - if len(pair) != 2 || pair[0] != basicAuth.Username || pair[1] != basicAuth.Password { - w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) - RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) - return } // Continue to the next handler if the authentication is successful diff --git a/internal/middlewares/oauth_middleware.go b/internal/middlewares/oauth_middleware.go index f2d7407..74b4089 100644 --- a/internal/middlewares/oauth_middleware.go +++ b/internal/middlewares/oauth_middleware.go @@ -26,27 +26,29 @@ import ( func (oauth Oauth) AuthMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - oauthConf := oauth2Config(oauth) - // Check if the user is authenticated - token, err := r.Cookie("goma.oauth") - if err != nil { - // If no token, redirect to OAuth provider - url := oauthConf.AuthCodeURL(oauth.State) - http.Redirect(w, r, url, http.StatusTemporaryRedirect) - return - } - ok, err := validateJWT(token.Value, oauth) - if err != nil { - // If no token, redirect to OAuth provider - url := oauthConf.AuthCodeURL(oauth.State) - http.Redirect(w, r, url, http.StatusTemporaryRedirect) - return - } - if !ok { - // If no token, redirect to OAuth provider - url := oauthConf.AuthCodeURL(oauth.State) - http.Redirect(w, r, url, http.StatusTemporaryRedirect) - return + if isProtectedPath(r.URL.Path, oauth.Paths) { + oauthConf := oauth2Config(oauth) + // Check if the user is authenticated + token, err := r.Cookie("goma.oauth") + if err != nil { + // If no token, redirect to OAuth provider + url := oauthConf.AuthCodeURL(oauth.State) + http.Redirect(w, r, url, http.StatusTemporaryRedirect) + return + } + ok, err := validateJWT(token.Value, oauth) + if err != nil { + // If no token, redirect to OAuth provider + url := oauthConf.AuthCodeURL(oauth.State) + http.Redirect(w, r, url, http.StatusTemporaryRedirect) + return + } + if !ok { + // If no token, redirect to OAuth provider + url := oauthConf.AuthCodeURL(oauth.State) + http.Redirect(w, r, url, http.StatusTemporaryRedirect) + return + } } // Token exists, proceed with request next.ServeHTTP(w, r) diff --git a/internal/middlewares/types.go b/internal/middlewares/types.go index bedf131..59d78d8 100644 --- a/internal/middlewares/types.go +++ b/internal/middlewares/types.go @@ -79,6 +79,8 @@ type ProxyResponseError struct { // JwtAuth stores JWT configuration type JwtAuth struct { + RoutePath string + Paths []string AuthURL string RequiredHeaders []string Headers map[string]string @@ -101,6 +103,7 @@ type AccessListMiddleware struct { // AuthBasic contains Basic auth configuration type AuthBasic struct { + Paths []string Username string Password string Headers map[string]string @@ -120,6 +123,8 @@ type responseRecorder struct { body *bytes.Buffer } type Oauth struct { + // Route protected path + Paths []string // ClientID is the application's ID. ClientID string // ClientSecret is the application's secret. diff --git a/internal/routes.go b/internal/routes.go index 6dd0ee9..8f7defd 100644 --- a/internal/routes.go +++ b/internal/routes.go @@ -106,12 +106,25 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { r.Use(limiter.RateLimitMiddleware()) } for rIndex, route := range dynamicRoutes { + + // create route + router := r.PathPrefix(route.Path).Subrouter() if len(route.Path) != 0 { // Checks if route destination and backend are empty if len(route.Destination) == 0 && len(route.Backends) == 0 { logger.Fatal("Route %s : destination or backends should not be empty", route.Name) } + proxyRoute := ProxyRoute{ + path: route.Path, + rewrite: route.Rewrite, + destination: route.Destination, + backends: route.Backends, + methods: route.Methods, + disableHostFording: route.DisableHostFording, + cors: route.Cors, + insecureSkipVerify: route.InsecureSkipVerify, + } // Apply middlewares to the route for _, middleware := range route.Middlewares { if len(middleware) != 0 { @@ -144,18 +157,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { logger.Error("Middleware ignored") } } - proxyRoute := ProxyRoute{ - path: route.Path, - rewrite: route.Rewrite, - destination: route.Destination, - backends: route.Backends, - methods: route.Methods, - disableHostFording: route.DisableHostFording, - cors: route.Cors, - insecureSkipVerify: route.InsecureSkipVerify, - } - // create route - router := r.PathPrefix(route.Path).Subrouter() + // Apply common exploits to the route // Enable common exploits if route.BlockCommonExploits { @@ -206,6 +208,8 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { } router.Use(interceptErrors.ErrorInterceptor) } + //r.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler + //r.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler } else { logger.Error("Error, path is empty in route %s", route.Name) logger.Error("Route path ignored: %s", route.Path) @@ -221,6 +225,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { } r.Use(interceptErrors.ErrorInterceptor) } + } return r @@ -228,105 +233,88 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { } func attachAuthMiddlewares(route Route, routeMiddleware Middleware, gateway Gateway, r *mux.Router) { - for _, middlewarePath := range routeMiddleware.Paths { - proxyRoute := ProxyRoute{ - path: route.Path, - rewrite: route.Rewrite, - destination: route.Destination, - backends: route.Backends, - disableHostFording: route.DisableHostFording, - methods: route.Methods, - cors: route.Cors, - insecureSkipVerify: route.InsecureSkipVerify, + // Check Authentication middleware types + switch routeMiddleware.Type { + case BasicAuth: + basicAuth, err := getBasicAuthMiddleware(routeMiddleware.Rule) + if err != nil { + logger.Error("Error: %s", err.Error()) + } else { + authBasic := middlewares.AuthBasic{ + Paths: util.AddPrefixPath(route.Path, routeMiddleware.Paths), + Username: basicAuth.Username, + Password: basicAuth.Password, + Headers: nil, + Params: nil, + } + // Apply JWT authentication middlewares + r.Use(authBasic.AuthMiddleware) + r.Use(CORSHandler(route.Cors)) } - secureRouter := r.PathPrefix(util.ParseRoutePath(route.Path, middlewarePath)).Subrouter() - // Check Authentication middleware types - switch routeMiddleware.Type { - case BasicAuth: - basicAuth, err := getBasicAuthMiddleware(routeMiddleware.Rule) - if err != nil { - logger.Error("Error: %s", err.Error()) - } else { - authBasic := middlewares.AuthBasic{ - Username: basicAuth.Username, - Password: basicAuth.Password, - Headers: nil, - Params: nil, - } - // Apply JWT authentication middlewares - secureRouter.Use(authBasic.AuthMiddleware) - secureRouter.Use(CORSHandler(route.Cors)) - secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler - secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler + case JWTAuth: + jwt, err := getJWTMiddleware(routeMiddleware.Rule) + if err != nil { + logger.Error("Error: %s", err.Error()) + } else { + jwtAuth := middlewares.JwtAuth{ + Paths: util.AddPrefixPath(route.Path, routeMiddleware.Paths), + AuthURL: jwt.URL, + RequiredHeaders: jwt.RequiredHeaders, + Headers: jwt.Headers, + Params: jwt.Params, + Origins: gateway.Cors.Origins, } - case JWTAuth: - jwt, err := getJWTMiddleware(routeMiddleware.Rule) - if err != nil { - logger.Error("Error: %s", err.Error()) - } else { - jwtAuth := middlewares.JwtAuth{ - AuthURL: jwt.URL, - RequiredHeaders: jwt.RequiredHeaders, - Headers: jwt.Headers, - Params: jwt.Params, - Origins: gateway.Cors.Origins, - } - // Apply JWT authentication middlewares - secureRouter.Use(jwtAuth.AuthMiddleware) - secureRouter.Use(CORSHandler(route.Cors)) - secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler - secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler + // Apply JWT authentication middlewares + r.Use(jwtAuth.AuthMiddleware) + r.Use(CORSHandler(route.Cors)) + } + case OAuth: + oauth, err := oAuthMiddleware(routeMiddleware.Rule) + if err != nil { + logger.Error("Error: %s", err.Error()) + } else { + redirectURL := "/callback" + route.Path + if oauth.RedirectURL != "" { + redirectURL = oauth.RedirectURL } - case OAuth: - oauth, err := oAuthMiddleware(routeMiddleware.Rule) - if err != nil { - logger.Error("Error: %s", err.Error()) - } else { - redirectURL := "/callback" + route.Path - if oauth.RedirectURL != "" { - redirectURL = oauth.RedirectURL - } - amw := middlewares.Oauth{ - ClientID: oauth.ClientID, - ClientSecret: oauth.ClientSecret, - RedirectURL: redirectURL, - Scopes: oauth.Scopes, - Endpoint: middlewares.OauthEndpoint{ - AuthURL: oauth.Endpoint.AuthURL, - TokenURL: oauth.Endpoint.TokenURL, - UserInfoURL: oauth.Endpoint.UserInfoURL, - }, - State: oauth.State, - Origins: gateway.Cors.Origins, - JWTSecret: oauth.JWTSecret, - Provider: oauth.Provider, - } - oauthRuler := oauthRulerMiddleware(amw) - // Check if a cookie path is defined - if oauthRuler.CookiePath == "" { - oauthRuler.CookiePath = route.Path - } - // Check if a RedirectPath is defined - if oauthRuler.RedirectPath == "" { - oauthRuler.RedirectPath = util.ParseRoutePath(route.Path, middlewarePath) - } - if oauthRuler.Provider == "" { - oauthRuler.Provider = "custom" - } - secureRouter.Use(amw.AuthMiddleware) - secureRouter.Use(CORSHandler(route.Cors)) - secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler - secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler - // Callback route - r.HandleFunc(util.UrlParsePath(redirectURL), oauthRuler.callbackHandler).Methods("GET") + amw := middlewares.Oauth{ + Paths: util.AddPrefixPath(route.Path, routeMiddleware.Paths), + ClientID: oauth.ClientID, + ClientSecret: oauth.ClientSecret, + RedirectURL: redirectURL, + Scopes: oauth.Scopes, + Endpoint: middlewares.OauthEndpoint{ + AuthURL: oauth.Endpoint.AuthURL, + TokenURL: oauth.Endpoint.TokenURL, + UserInfoURL: oauth.Endpoint.UserInfoURL, + }, + State: oauth.State, + Origins: gateway.Cors.Origins, + JWTSecret: oauth.JWTSecret, + Provider: oauth.Provider, } - default: - if !doesExist(routeMiddleware.Type) { - logger.Error("Unknown middlewares type %s", routeMiddleware.Type) + oauthRuler := oauthRulerMiddleware(amw) + // Check if a cookie path is defined + if oauthRuler.CookiePath == "" { + oauthRuler.CookiePath = route.Path } - + // Check if a RedirectPath is defined + if oauthRuler.RedirectPath == "" { + oauthRuler.RedirectPath = util.ParseRoutePath(route.Path, routeMiddleware.Paths[0]) + } + if oauthRuler.Provider == "" { + oauthRuler.Provider = "custom" + } + r.Use(amw.AuthMiddleware) + r.Use(CORSHandler(route.Cors)) + r.HandleFunc(util.UrlParsePath(redirectURL), oauthRuler.callbackHandler).Methods("GET") + } + default: + if !doesExist(routeMiddleware.Type) { + logger.Error("Unknown middlewares type %s", routeMiddleware.Type) } } + } diff --git a/internal/server.go b/internal/server.go index 340b177..a9be5ea 100644 --- a/internal/server.go +++ b/internal/server.go @@ -30,7 +30,7 @@ import ( // Start / Start starts the server func (gatewayServer GatewayServer) Start() error { logger.Info("Initializing routes...") - route := gatewayServer.Initialize() + router := gatewayServer.Initialize() logger.Debug("Routes count=%d, Middlewares count=%d", len(gatewayServer.gateway.Routes), len(gatewayServer.middlewares)) gatewayServer.initRedis() defer gatewayServer.closeRedis() @@ -44,8 +44,8 @@ func (gatewayServer GatewayServer) Start() error { printRoute(dynamicRoutes) } - httpServer := gatewayServer.createServer(":8080", route, nil) - httpsServer := gatewayServer.createServer(":8443", route, tlsConfig) + httpServer := gatewayServer.createServer(":8080", router, nil) + httpsServer := gatewayServer.createServer(":8443", router, tlsConfig) // Start HTTP/HTTPS servers gatewayServer.startServers(httpServer, httpsServer, listenWithTLS) diff --git a/util/helpers.go b/util/helpers.go index 3248312..a3bc174 100644 --- a/util/helpers.go +++ b/util/helpers.go @@ -157,3 +157,11 @@ func Slug(text string) string { return text } + +func AddPrefixPath(prefix string, paths []string) []string { + for i := range paths { + paths[i] = ParseURLPath(prefix + paths[i]) + } + return paths + +}