2024-10-27 06:10:27 +01:00
package middleware
/*
Copyright 2024 Jonas Kaninda
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import (
2024-11-14 11:38:36 +01:00
"errors"
2024-11-14 00:26:21 +01:00
"fmt"
2024-11-14 11:38:36 +01:00
"github.com/go-redis/redis_rate/v10"
2024-10-27 06:10:27 +01:00
"github.com/gorilla/mux"
2024-11-04 08:48:38 +01:00
"github.com/jkaninda/goma-gateway/pkg/logger"
2024-11-14 11:38:36 +01:00
"github.com/redis/go-redis/v9"
"golang.org/x/net/context"
2024-10-27 06:10:27 +01:00
"net/http"
"time"
)
// RateLimitMiddleware limits request based on the number of tokens peer minutes.
func ( rl * TokenRateLimiter ) RateLimitMiddleware ( ) mux . MiddlewareFunc {
return func ( next http . Handler ) http . Handler {
return http . HandlerFunc ( func ( w http . ResponseWriter , r * http . Request ) {
if ! rl . Allow ( ) {
2024-11-14 00:26:21 +01:00
logger . Error ( "Too many requests from IP: %s %s %s" , getRealIP ( r ) , r . URL , r . UserAgent ( ) )
//RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), basicAuth.ErrorInterceptor)
2024-10-27 06:10:27 +01:00
// Rate limit exceeded, return a 429 Too Many Requests response
w . WriteHeader ( http . StatusTooManyRequests )
2024-11-14 13:17:28 +01:00
_ , err := w . Write ( [ ] byte ( fmt . Sprintf ( "%d Too many requests, API requests limit exceeded. Please try again later" , http . StatusTooManyRequests ) ) )
2024-10-27 06:10:27 +01:00
if err != nil {
return
}
return
}
2024-11-14 13:17:28 +01:00
// Proceed to the next handler if requests limit is not exceeded
2024-10-27 06:10:27 +01:00
next . ServeHTTP ( w , r )
} )
}
}
// RateLimitMiddleware limits request based on the number of requests peer minutes.
func ( rl * RateLimiter ) RateLimitMiddleware ( ) mux . MiddlewareFunc {
return func ( next http . Handler ) http . Handler {
return http . HandlerFunc ( func ( w http . ResponseWriter , r * http . Request ) {
2024-11-14 11:38:36 +01:00
clientIP := getRealIP ( r )
2024-11-14 14:46:18 +01:00
clientID := fmt . Sprintf ( "%s-%s" , rl . id , clientIP ) // Generate client Id, ID+ route ID
logger . Debug ( "requests limiter: clientIP: %s, clientID: %s" , clientIP , clientID )
2024-11-14 13:17:28 +01:00
if rl . redisBased {
err := redisRateLimiter ( clientID , rl . requests )
2024-11-14 11:38:36 +01:00
if err != nil {
logger . Error ( "Redis Rate limiter error: %s" , err . Error ( ) )
2024-11-14 13:17:28 +01:00
logger . Error ( "Too many requests from IP: %s %s %s" , clientIP , r . URL , r . UserAgent ( ) )
RespondWithError ( w , http . StatusTooManyRequests , fmt . Sprintf ( "%d Too many requests, API requests limit exceeded. Please try again later" , http . StatusTooManyRequests ) )
2024-11-14 11:38:36 +01:00
return
2024-10-27 06:10:27 +01:00
}
2024-11-14 11:38:36 +01:00
} else {
rl . mu . Lock ( )
2024-11-14 13:17:28 +01:00
client , exists := rl . clientMap [ clientID ]
2024-11-14 11:38:36 +01:00
if ! exists || time . Now ( ) . After ( client . ExpiresAt ) {
client = & Client {
RequestCount : 0 ,
2024-11-14 13:17:28 +01:00
ExpiresAt : time . Now ( ) . Add ( rl . window ) ,
2024-11-14 11:38:36 +01:00
}
2024-11-14 13:17:28 +01:00
rl . clientMap [ clientID ] = client
2024-11-14 11:38:36 +01:00
}
client . RequestCount ++
rl . mu . Unlock ( )
2024-10-27 06:10:27 +01:00
2024-11-14 13:17:28 +01:00
if client . RequestCount > rl . requests {
2024-11-14 11:38:36 +01:00
logger . Error ( "Too many requests from IP: %s %s %s" , clientIP , r . URL , r . UserAgent ( ) )
//Update Origin Cors Headers
2024-11-14 13:17:28 +01:00
if allowedOrigin ( rl . origins , r . Header . Get ( "Origin" ) ) {
2024-11-14 11:38:36 +01:00
w . Header ( ) . Set ( "Access-Control-Allow-Origin" , r . Header . Get ( "Origin" ) )
}
2024-11-14 13:17:28 +01:00
RespondWithError ( w , http . StatusTooManyRequests , fmt . Sprintf ( "%d Too many requests, API requests limit exceeded. Please try again later" , http . StatusTooManyRequests ) )
2024-11-05 20:44:06 +01:00
}
2024-10-27 06:10:27 +01:00
}
2024-11-14 13:17:28 +01:00
// Proceed to the next handler if requests limit is not exceeded
2024-10-27 06:10:27 +01:00
next . ServeHTTP ( w , r )
} )
}
}
2024-11-14 11:38:36 +01:00
func redisRateLimiter ( clientIP string , rate int ) error {
ctx := context . Background ( )
res , err := limiter . Allow ( ctx , clientIP , redis_rate . PerMinute ( rate ) )
if err != nil {
return err
}
if res . Remaining == 0 {
2024-11-14 13:17:28 +01:00
return errors . New ( "requests limit exceeded" )
2024-11-14 11:38:36 +01:00
}
return nil
}
// InitRedis sets up the package-level Redis state used by the rate limiter.
//
// It assigns a new go-redis client for addr/password to Rdb, and a
// redis_rate limiter built on that client to limiter.
func InitRedis(addr, password string) {
	opts := &redis.Options{Addr: addr, Password: password}
	Rdb = redis.NewClient(opts)
	limiter = redis_rate.NewLimiter(Rdb)
}