refactor: refactoring of code
Add graceful shutdown server
This commit is contained in:
89
internal/middlewares/access-middleware.go
Normal file
89
internal/middlewares/access-middleware.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package middlewares
|
||||
|
||||
/*
|
||||
Copyright 2024 Jonas Kaninda
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/jkaninda/goma-gateway/pkg/logger"
|
||||
"github.com/jkaninda/goma-gateway/util"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AccessMiddleware checks if the request path is forbidden and returns 403 Forbidden
|
||||
func (blockList AccessListMiddleware) AccessMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
for _, block := range blockList.List {
|
||||
if isPathBlocked(r.URL.Path, util.ParseURLPath(blockList.Path+block)) {
|
||||
logger.Error("%s: %s access forbidden", getRealIP(r), r.URL.Path)
|
||||
RespondWithError(w, http.StatusForbidden, fmt.Sprintf("%d you do not have permission to access this resource", http.StatusForbidden))
|
||||
return
|
||||
}
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// Helper function to determine if the request path is blocked
|
||||
func isPathBlocked(requestPath, blockedPath string) bool {
|
||||
// Handle exact match
|
||||
if requestPath == blockedPath {
|
||||
return true
|
||||
}
|
||||
// Handle wildcard match (e.g., /admin/* should block /admin and any subpath)
|
||||
if strings.HasSuffix(blockedPath, "/*") {
|
||||
basePath := strings.TrimSuffix(blockedPath, "/*")
|
||||
if strings.HasPrefix(requestPath, basePath) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new requests limiter with the specified refill requests and token capacity
|
||||
func NewRateLimiter(maxTokens int, refillRate time.Duration) *TokenRateLimiter {
|
||||
return &TokenRateLimiter{
|
||||
tokens: maxTokens,
|
||||
maxTokens: maxTokens,
|
||||
refillRate: refillRate,
|
||||
lastRefill: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Allow checks if a request is allowed based on the current token bucket
|
||||
func (rl *TokenRateLimiter) Allow() bool {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
// Refill tokens based on the time elapsed
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(rl.lastRefill)
|
||||
tokensToAdd := int(elapsed / rl.refillRate)
|
||||
if tokensToAdd > 0 {
|
||||
rl.tokens = min(rl.maxTokens, rl.tokens+tokensToAdd)
|
||||
rl.lastRefill = now
|
||||
}
|
||||
|
||||
// Check if there are enough tokens to allow the request
|
||||
if rl.tokens > 0 {
|
||||
rl.tokens--
|
||||
return true
|
||||
}
|
||||
|
||||
// Reject request if no tokens are available
|
||||
return false
|
||||
}
|
||||
66
internal/middlewares/block-common-exploits.go
Normal file
66
internal/middlewares/block-common-exploits.go
Normal file
@@ -0,0 +1,66 @@
|
||||
/*
|
||||
* Copyright 2024 Jonas Kaninda
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/jkaninda/goma-gateway/pkg/logger"
|
||||
"net/http"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
// BlockExploitsMiddleware Middleware to block common exploits
|
||||
func BlockExploitsMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Patterns to detect SQL injection attempts
|
||||
sqlInjectionPattern := regexp.MustCompile(sqlPatterns)
|
||||
|
||||
// Pattern to detect path traversal attempts
|
||||
pathTraversalPattern := regexp.MustCompile(traversalPatterns)
|
||||
|
||||
// Pattern to detect simple XSS attempts
|
||||
xssPattern := regexp.MustCompile(xssPatterns)
|
||||
|
||||
// Check query strings
|
||||
if sqlInjectionPattern.MatchString(r.URL.RawQuery) ||
|
||||
pathTraversalPattern.MatchString(r.URL.Path) ||
|
||||
xssPattern.MatchString(r.URL.RawQuery) {
|
||||
logger.Error("%s: %s Forbidden - Potential exploit detected", getRealIP(r), r.URL.Path)
|
||||
RespondWithError(w, http.StatusForbidden, fmt.Sprintf("%d Forbidden", http.StatusForbidden))
|
||||
return
|
||||
|
||||
}
|
||||
// Check form data (for POST requests)
|
||||
if r.Method == http.MethodPost {
|
||||
if err := r.ParseForm(); err == nil {
|
||||
for _, values := range r.Form {
|
||||
for _, value := range values {
|
||||
if sqlInjectionPattern.MatchString(value) || xssPattern.MatchString(value) {
|
||||
logger.Error("%s: %s %s Forbidden - Potential exploit detected", getRealIP(r), r.Method, r.URL.Path)
|
||||
RespondWithError(w, http.StatusForbidden, fmt.Sprintf("%d Forbidden", http.StatusForbidden))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pass to the next handler if no exploit patterns were detected
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
59
internal/middlewares/config.go
Normal file
59
internal/middlewares/config.go
Normal file
@@ -0,0 +1,59 @@
|
||||
/*
|
||||
* Copyright 2024 Jonas Kaninda
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"github.com/jkaninda/goma-gateway/pkg/logger"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/amazon"
|
||||
"golang.org/x/oauth2/facebook"
|
||||
"golang.org/x/oauth2/github"
|
||||
"golang.org/x/oauth2/gitlab"
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
func oauth2Config(oauth Oauth) *oauth2.Config {
|
||||
config := &oauth2.Config{
|
||||
ClientID: oauth.ClientID,
|
||||
ClientSecret: oauth.ClientSecret,
|
||||
RedirectURL: oauth.RedirectURL,
|
||||
Scopes: oauth.Scopes,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: oauth.Endpoint.AuthURL,
|
||||
TokenURL: oauth.Endpoint.TokenURL,
|
||||
},
|
||||
}
|
||||
switch oauth.Provider {
|
||||
case "google":
|
||||
config.Endpoint = google.Endpoint
|
||||
case "amazon":
|
||||
config.Endpoint = amazon.Endpoint
|
||||
case "facebook":
|
||||
config.Endpoint = facebook.Endpoint
|
||||
case "github":
|
||||
config.Endpoint = github.Endpoint
|
||||
case "gitlab":
|
||||
config.Endpoint = gitlab.Endpoint
|
||||
default:
|
||||
if oauth.Provider != "custom" {
|
||||
logger.Error("Unknown provider: %s", oauth.Provider)
|
||||
}
|
||||
|
||||
}
|
||||
return config
|
||||
}
|
||||
74
internal/middlewares/error-interceptor.go
Normal file
74
internal/middlewares/error-interceptor.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package middlewares
|
||||
|
||||
/*
|
||||
* Copyright 2024 Jonas Kaninda
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/jkaninda/goma-gateway/pkg/logger"
|
||||
"io"
|
||||
"net/http"
|
||||
"slices"
|
||||
)
|
||||
|
||||
func newResponseRecorder(w http.ResponseWriter) *responseRecorder {
|
||||
return &responseRecorder{
|
||||
ResponseWriter: w,
|
||||
statusCode: http.StatusOK,
|
||||
body: &bytes.Buffer{},
|
||||
}
|
||||
}
|
||||
|
||||
func (rec *responseRecorder) WriteHeader(code int) {
|
||||
rec.statusCode = code
|
||||
}
|
||||
|
||||
func (rec *responseRecorder) Write(data []byte) (int, error) {
|
||||
return rec.body.Write(data)
|
||||
}
|
||||
|
||||
// ErrorInterceptor Middleware intercepts backend errors
|
||||
func (intercept InterceptErrors) ErrorInterceptor(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Check if the connection is a WebSocket
|
||||
if isWebSocketRequest(r) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
rec := newResponseRecorder(w)
|
||||
next.ServeHTTP(rec, r)
|
||||
if canIntercept(rec.statusCode, intercept.Errors) {
|
||||
logger.Error("Request to %s resulted in error with status code %d\n", r.URL.Path, rec.statusCode)
|
||||
RespondWithError(w, rec.statusCode, http.StatusText(rec.statusCode))
|
||||
return
|
||||
} else {
|
||||
// No error: write buffered response to client
|
||||
w.WriteHeader(rec.statusCode)
|
||||
_, err := io.Copy(w, rec.body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
func isWebSocketRequest(r *http.Request) bool {
|
||||
return r.Header.Get("Upgrade") == "websocket" && r.Header.Get("Connection") == "Upgrade"
|
||||
}
|
||||
func canIntercept(code int, errors []int) bool {
|
||||
return slices.Contains(errors, code)
|
||||
}
|
||||
56
internal/middlewares/helpers.go
Normal file
56
internal/middlewares/helpers.go
Normal file
@@ -0,0 +1,56 @@
|
||||
/*
|
||||
* Copyright 2024 Jonas Kaninda
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"slices"
|
||||
)
|
||||
|
||||
func getRealIP(r *http.Request) string {
|
||||
if ip := r.Header.Get("X-Real-IP"); ip != "" {
|
||||
return ip
|
||||
}
|
||||
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
|
||||
return ip
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
func allowedOrigin(origins []string, origin string) bool {
|
||||
return slices.Contains(origins, origin)
|
||||
}
|
||||
|
||||
// RespondWithError is a helper function to handle error responses with flexible content type
|
||||
func RespondWithError(w http.ResponseWriter, statusCode int, logMessage string) {
|
||||
message := http.StatusText(statusCode)
|
||||
if len(logMessage) != 0 {
|
||||
message = logMessage
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
err := json.NewEncoder(w).Encode(ProxyResponseError{
|
||||
Success: false,
|
||||
Code: statusCode,
|
||||
Message: message,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
145
internal/middlewares/middleware.go
Normal file
145
internal/middlewares/middleware.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package middlewares
|
||||
|
||||
/*
|
||||
Copyright 2024 Jonas Kaninda
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
import (
|
||||
"encoding/base64"
|
||||
"github.com/jkaninda/goma-gateway/pkg/logger"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// AuthMiddleware authenticate the client using JWT
|
||||
//
|
||||
// authorization based on the result of backend's response and continue the request when the client is authorized
|
||||
func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
for _, header := range jwtAuth.RequiredHeaders {
|
||||
if r.Header.Get(header) == "" {
|
||||
logger.Error("Proxy error, missing %s header", header)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
//check allowed origin
|
||||
if allowedOrigin(jwtAuth.Origins, r.Header.Get("Origin")) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
|
||||
}
|
||||
RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))
|
||||
return
|
||||
|
||||
}
|
||||
}
|
||||
//token := r.Header.Get("Authorization")
|
||||
authURL, err := url.Parse(jwtAuth.AuthURL)
|
||||
if err != nil {
|
||||
logger.Error("Error parsing auth URL: %v", err)
|
||||
RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
|
||||
return
|
||||
}
|
||||
// Create a new request for /authentication
|
||||
authReq, err := http.NewRequest("GET", authURL.String(), nil)
|
||||
if err != nil {
|
||||
logger.Error("Proxy error creating authentication request: %v", err)
|
||||
RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
|
||||
return
|
||||
}
|
||||
logger.Trace("JWT Auth response headers: %v", authReq.Header)
|
||||
// Copy headers from the original request to the new request
|
||||
for name, values := range r.Header {
|
||||
for _, value := range values {
|
||||
authReq.Header.Set(name, value)
|
||||
}
|
||||
}
|
||||
// Copy cookies from the original request to the new request
|
||||
for _, cookie := range r.Cookies() {
|
||||
authReq.AddCookie(cookie)
|
||||
}
|
||||
// Perform the request to the auth service
|
||||
client := &http.Client{}
|
||||
authResp, err := client.Do(authReq)
|
||||
if err != nil || authResp.StatusCode != http.StatusOK {
|
||||
logger.Debug("%s %s %s %s", r.Method, getRealIP(r), r.URL, r.UserAgent())
|
||||
logger.Debug("Proxy authentication error")
|
||||
RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))
|
||||
return
|
||||
}
|
||||
defer func(Body io.ReadCloser) {
|
||||
err := Body.Close()
|
||||
if err != nil {
|
||||
logger.Error("Error closing body: %v", err)
|
||||
}
|
||||
}(authResp.Body)
|
||||
// Inject specific header tp the current request's header
|
||||
// Add header to the next request from AuthRequest header, depending on your requirements
|
||||
if jwtAuth.Headers != nil {
|
||||
for k, v := range jwtAuth.Headers {
|
||||
r.Header.Set(v, authResp.Header.Get(k))
|
||||
}
|
||||
}
|
||||
query := r.URL.Query()
|
||||
// Add query parameters to the next request from AuthRequest header, depending on your requirements
|
||||
if jwtAuth.Params != nil {
|
||||
for k, v := range jwtAuth.Params {
|
||||
query.Set(v, authResp.Header.Get(k))
|
||||
}
|
||||
}
|
||||
r.URL.RawQuery = query.Encode()
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// AuthMiddleware checks for the Authorization header and verifies the credentials
|
||||
func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
logger.Trace("Basic-Auth request headers: %v", r.Header)
|
||||
// Get the Authorization header
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
logger.Debug("Proxy error, missing Authorization header")
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
|
||||
RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))
|
||||
return
|
||||
}
|
||||
// Check if the Authorization header contains "Basic" scheme
|
||||
if !strings.HasPrefix(authHeader, "Basic ") {
|
||||
logger.Error("Proxy error, missing Basic Authorization header")
|
||||
RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Decode the base64 encoded username:password string
|
||||
payload, err := base64.StdEncoding.DecodeString(authHeader[len("Basic "):])
|
||||
if err != nil {
|
||||
logger.Debug("Proxy error, missing Basic Authorization header")
|
||||
RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))
|
||||
return
|
||||
}
|
||||
|
||||
// Split the payload into username and password
|
||||
pair := strings.SplitN(string(payload), ":", 2)
|
||||
if len(pair) != 2 || pair[0] != basicAuth.Username || pair[1] != basicAuth.Password {
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
|
||||
RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))
|
||||
return
|
||||
}
|
||||
|
||||
// Continue to the next handler if the authentication is successful
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
|
||||
}
|
||||
86
internal/middlewares/oauth-middleware.go
Normal file
86
internal/middlewares/oauth-middleware.go
Normal file
@@ -0,0 +1,86 @@
|
||||
/*
|
||||
* Copyright 2024 Jonas Kaninda
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (oauth Oauth) AuthMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
oauthConf := oauth2Config(oauth)
|
||||
// Check if the user is authenticated
|
||||
token, err := r.Cookie("goma.oauth")
|
||||
if err != nil {
|
||||
// If no token, redirect to OAuth provider
|
||||
url := oauthConf.AuthCodeURL(oauth.State)
|
||||
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
ok, err := validateJWT(token.Value, oauth)
|
||||
if err != nil {
|
||||
// If no token, redirect to OAuth provider
|
||||
url := oauthConf.AuthCodeURL(oauth.State)
|
||||
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
// If no token, redirect to OAuth provider
|
||||
url := oauthConf.AuthCodeURL(oauth.State)
|
||||
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
// Token exists, proceed with request
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func validateJWT(signedToken string, oauth Oauth) (bool, error) {
|
||||
// Parse the JWT token and provide the key function
|
||||
token, err := jwt.Parse(signedToken, func(token *jwt.Token) (interface{}, error) {
|
||||
// Ensure the signing method is HMAC and specifically HS256
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
// Return the shared secret key for validation
|
||||
return []byte(oauth.JWTSecret), nil
|
||||
})
|
||||
|
||||
// If there's an error or token is invalid, return false
|
||||
if err != nil || !token.Valid {
|
||||
return false, fmt.Errorf("token is invalid: %v", err)
|
||||
}
|
||||
|
||||
// Check if token claims are valid
|
||||
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
|
||||
// Optional: Check token expiration
|
||||
if exp, ok := claims["exp"].(float64); ok {
|
||||
if time.Unix(int64(exp), 0).Before(time.Now()) {
|
||||
return false, fmt.Errorf("token has expired")
|
||||
}
|
||||
}
|
||||
|
||||
// Token is valid and not expired
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, fmt.Errorf("token is invalid or missing claims")
|
||||
}
|
||||
89
internal/middlewares/rate-limit.go
Normal file
89
internal/middlewares/rate-limit.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package middlewares
|
||||
|
||||
/*
|
||||
Copyright 2024 Jonas Kaninda
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/jkaninda/goma-gateway/pkg/logger"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RateLimitMiddleware limits request based on the number of tokens peer minutes.
|
||||
func (rl *TokenRateLimiter) RateLimitMiddleware() mux.MiddlewareFunc {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !rl.Allow() {
|
||||
logger.Error("Too many requests from IP: %s %s %s", getRealIP(r), r.URL, r.UserAgent())
|
||||
//RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), basicAuth.ErrorInterceptor)
|
||||
|
||||
// Rate limit exceeded, return a 429 Too Many Requests response
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_, err := w.Write([]byte(fmt.Sprintf("%d Too many requests, API requests limit exceeded. Please try again later", http.StatusTooManyRequests)))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
// Proceed to the next handler if requests limit is not exceeded
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimitMiddleware limits request based on the number of requests peer minutes.
|
||||
func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
clientIP := getRealIP(r)
|
||||
clientID := fmt.Sprintf("%s-%s", rl.id, clientIP) // Generate client Id, ID+ route ID
|
||||
logger.Debug("requests limiter: clientIP: %s, clientID: %s", clientIP, clientID)
|
||||
if rl.redisBased {
|
||||
err := redisRateLimiter(clientID, rl.requests)
|
||||
if err != nil {
|
||||
logger.Error("Redis Rate limiter error: %s", err.Error())
|
||||
logger.Error("Too many requests from IP: %s %s %s", clientIP, r.URL, r.UserAgent())
|
||||
RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API requests limit exceeded. Please try again later", http.StatusTooManyRequests))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
rl.mu.Lock()
|
||||
client, exists := rl.clientMap[clientID]
|
||||
if !exists || time.Now().After(client.ExpiresAt) {
|
||||
client = &Client{
|
||||
RequestCount: 0,
|
||||
ExpiresAt: time.Now().Add(rl.window),
|
||||
}
|
||||
rl.clientMap[clientID] = client
|
||||
}
|
||||
client.RequestCount++
|
||||
rl.mu.Unlock()
|
||||
|
||||
if client.RequestCount > rl.requests {
|
||||
logger.Error("Too many requests from IP: %s %s %s", clientIP, r.URL, r.UserAgent())
|
||||
//Update Origin Cors Headers
|
||||
if allowedOrigin(rl.origins, r.Header.Get("Origin")) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
|
||||
}
|
||||
RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API requests limit exceeded. Please try again later", http.StatusTooManyRequests))
|
||||
}
|
||||
}
|
||||
// Proceed to the next handler if requests limit is not exceeded
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
46
internal/middlewares/redis.go
Normal file
46
internal/middlewares/redis.go
Normal file
@@ -0,0 +1,46 @@
|
||||
/*
|
||||
* Copyright 2024 Jonas Kaninda
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/go-redis/redis_rate/v10"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
func redisRateLimiter(clientIP string, rate int) error {
|
||||
ctx := context.Background()
|
||||
|
||||
res, err := limiter.Allow(ctx, clientIP, redis_rate.PerMinute(rate))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if res.Remaining == 0 {
|
||||
return errors.New("requests limit exceeded")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
func InitRedis(addr, password string) {
|
||||
Rdb = redis.NewClient(&redis.Options{
|
||||
Addr: addr,
|
||||
Password: password,
|
||||
})
|
||||
limiter = redis_rate.NewLimiter(Rdb)
|
||||
}
|
||||
145
internal/middlewares/types.go
Normal file
145
internal/middlewares/types.go
Normal file
@@ -0,0 +1,145 @@
|
||||
/*
|
||||
* Copyright 2024 Jonas Kaninda
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RateLimiter defines requests limit properties.
|
||||
type RateLimiter struct {
|
||||
requests int
|
||||
id string
|
||||
window time.Duration
|
||||
clientMap map[string]*Client
|
||||
mu sync.Mutex
|
||||
origins []string
|
||||
//hosts []string
|
||||
redisBased bool
|
||||
}
|
||||
|
||||
// Client stores request count and window expiration for each client.
|
||||
type Client struct {
|
||||
RequestCount int
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
type RateLimit struct {
|
||||
Id string
|
||||
Requests int
|
||||
Window time.Duration
|
||||
Origins []string
|
||||
Hosts []string
|
||||
RedisBased bool
|
||||
}
|
||||
|
||||
// NewRateLimiterWindow creates a new RateLimiter.
|
||||
func (rateLimit RateLimit) NewRateLimiterWindow() *RateLimiter {
|
||||
return &RateLimiter{
|
||||
id: rateLimit.Id,
|
||||
requests: rateLimit.Requests,
|
||||
window: rateLimit.Window,
|
||||
clientMap: make(map[string]*Client),
|
||||
origins: rateLimit.Origins,
|
||||
redisBased: rateLimit.RedisBased,
|
||||
}
|
||||
}
|
||||
|
||||
// TokenRateLimiter stores tokenRate limit
|
||||
type TokenRateLimiter struct {
|
||||
tokens int
|
||||
maxTokens int
|
||||
refillRate time.Duration
|
||||
lastRefill time.Time
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// ProxyResponseError represents the structure of the JSON error response
|
||||
type ProxyResponseError struct {
|
||||
Success bool `json:"success"`
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// JwtAuth stores JWT configuration
|
||||
type JwtAuth struct {
|
||||
AuthURL string
|
||||
RequiredHeaders []string
|
||||
Headers map[string]string
|
||||
Params map[string]string
|
||||
Origins []string
|
||||
}
|
||||
|
||||
// AuthenticationMiddleware Define struct
|
||||
type AuthenticationMiddleware struct {
|
||||
AuthURL string
|
||||
RequiredHeaders []string
|
||||
Headers map[string]string
|
||||
Params map[string]string
|
||||
}
|
||||
type AccessListMiddleware struct {
|
||||
Path string
|
||||
Destination string
|
||||
List []string
|
||||
}
|
||||
|
||||
// AuthBasic contains Basic auth configuration
|
||||
type AuthBasic struct {
|
||||
Username string
|
||||
Password string
|
||||
Headers map[string]string
|
||||
Params map[string]string
|
||||
}
|
||||
|
||||
// InterceptErrors contains backend status code errors to intercept
|
||||
type InterceptErrors struct {
|
||||
Errors []int
|
||||
Origins []string
|
||||
}
|
||||
|
||||
// responseRecorder intercepts the response body and status code
|
||||
type responseRecorder struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
body *bytes.Buffer
|
||||
}
|
||||
type Oauth struct {
|
||||
// ClientID is the application's ID.
|
||||
ClientID string
|
||||
// ClientSecret is the application's secret.
|
||||
ClientSecret string
|
||||
// Endpoint contains the resource server's token endpoint
|
||||
Endpoint OauthEndpoint
|
||||
// RedirectURL is the URL to redirect users going through
|
||||
// the OAuth flow, after the resource owner's URLs.
|
||||
RedirectURL string
|
||||
// Scope specifies optional requested permissions.
|
||||
Scopes []string
|
||||
// contains filtered or unexported fields
|
||||
State string
|
||||
Origins []string
|
||||
JWTSecret string
|
||||
Provider string
|
||||
}
|
||||
type OauthEndpoint struct {
|
||||
AuthURL string
|
||||
TokenURL string
|
||||
UserInfoURL string
|
||||
}
|
||||
33
internal/middlewares/var.go
Normal file
33
internal/middlewares/var.go
Normal file
@@ -0,0 +1,33 @@
|
||||
/*
|
||||
* Copyright 2024 Jonas Kaninda
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"github.com/go-redis/redis_rate/v10"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// sqlPatterns contains SQL injections patters
|
||||
const sqlPatterns = `(?i)(union|select|drop|insert|delete|update|create|alter|exec|;|--)`
|
||||
const traversalPatterns = `\.\./`
|
||||
const xssPatterns = `(?i)<script|onerror|onload`
|
||||
|
||||
var (
|
||||
Rdb *redis.Client
|
||||
limiter *redis_rate.Limiter
|
||||
)
|
||||
Reference in New Issue
Block a user