refactor: refactoring of rate limiting

This commit is contained in:
Jonas Kaninda
2024-11-24 23:09:13 +01:00
parent 3df8dce59b
commit f4e5bb3be2
3 changed files with 60 additions and 29 deletions

View File

@@ -53,13 +53,11 @@ func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
clientIP := getRealIP(r)
clientID := fmt.Sprintf("%s-%s", rl.id, clientIP) // Generate client Id, ID+ route ID
logger.Debug("requests limiter: clientIP: %s, clientID: %s", clientIP, clientID)
if rl.redisBased {
err := redisRateLimiter(clientID, rl.unit, rl.requests)
if err != nil {
logger.Error("Redis Rate limiter error: %s", err.Error())
logger.Error("Too many requests from IP: %s %s %s", clientIP, r.URL, r.UserAgent())
RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API requests limit exceeded. Please try again later", http.StatusTooManyRequests))
return
}
} else {
@@ -82,8 +80,10 @@ func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc {
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
}
RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API requests limit exceeded. Please try again later", http.StatusTooManyRequests))
return
}
}
// Proceed to the next handler if the request limit is not exceeded
next.ServeHTTP(w, r)
})

View File

@@ -33,6 +33,8 @@ type RateLimiter struct {
mu sync.Mutex
origins []string
redisBased bool
pathBased bool
paths []string
}
// Client stores request count and window expiration for each client.
@@ -47,6 +49,8 @@ type RateLimit struct {
Origins []string
Hosts []string
RedisBased bool
PathBased bool
Paths []string
}
// NewRateLimiterWindow creates a new RateLimiter.
@@ -58,6 +62,8 @@ func (rateLimit RateLimit) NewRateLimiterWindow() *RateLimiter {
clientMap: make(map[string]*Client),
origins: rateLimit.Origins,
redisBased: rateLimit.RedisBased,
pathBased: rateLimit.PathBased,
paths: rateLimit.Paths,
}
}

View File

@@ -23,6 +23,7 @@ import (
"github.com/jkaninda/goma-gateway/util"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"slices"
)
// init initializes prometheus metrics
@@ -127,6 +128,31 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
}
// Apply middlewares to the route
for _, middleware := range route.Middlewares {
// Apply common exploits to the route
// Enable common exploits
if route.BlockCommonExploits {
logger.Info("Block common exploits enabled")
router.Use(middlewares.BlockExploitsMiddleware)
}
id := string(rune(rIndex))
if len(route.Name) != 0 {
// Use route name as ID
id = util.Slug(route.Name)
}
// Apply route rate limit
if route.RateLimit != 0 {
rateLimit := middlewares.RateLimit{
Unit: "minute",
Id: id, // Use route index as ID
Requests: route.RateLimit,
Origins: route.Cors.Origins,
Hosts: route.Hosts,
RedisBased: redisBased,
}
limiter := rateLimit.NewRateLimiterWindow()
// Add rate limit middlewares
router.Use(limiter.RateLimitMiddleware())
}
if len(middleware) != 0 {
// Get Access middlewares if it does exist
accessMiddleware, err := getMiddleware([]string{middleware}, m)
@@ -143,6 +169,31 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
}
// Apply Rate limit middleware
if slices.Contains(RateLimitMiddleware, accessMiddleware.Type) {
rateLimitMid, err := rateLimitMiddleware(accessMiddleware.Rule)
if err != nil {
logger.Error("Error: %v", err.Error())
}
if rateLimitMid.RequestsPerUnit != 0 && route.RateLimit == 0 {
rateLimit := middlewares.RateLimit{
Unit: rateLimitMid.Unit,
Id: id, // Use route index as ID
Requests: rateLimitMid.RequestsPerUnit,
Origins: route.Cors.Origins,
Hosts: route.Hosts,
RedisBased: redisBased,
PathBased: true,
Paths: util.AddPrefixPath(route.Path, accessMiddleware.Paths),
}
limiter := rateLimit.NewRateLimiterWindow()
// Add rate limit middlewares
router.Use(limiter.RateLimitMiddleware())
}
}
}
// Get route authentication middlewares if it does exist
routeMiddleware, err := getMiddleware([]string{middleware}, m)
@@ -158,31 +209,6 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
}
}
// Apply common exploits to the route
// Enable common exploits
if route.BlockCommonExploits {
logger.Info("Block common exploits enabled")
router.Use(middlewares.BlockExploitsMiddleware)
}
id := string(rune(rIndex))
if len(route.Name) != 0 {
// Use route name as ID
id = util.Slug(route.Name)
}
// Apply route rate limit
if route.RateLimit != 0 {
rateLimit := middlewares.RateLimit{
Unit: "minute",
Id: id, // Use route index as ID
Requests: route.RateLimit,
Origins: route.Cors.Origins,
Hosts: route.Hosts,
RedisBased: redisBased,
}
limiter := rateLimit.NewRateLimiterWindow()
// Add rate limit middlewares
router.Use(limiter.RateLimitMiddleware())
}
// Apply route Cors
router.Use(CORSHandler(route.Cors))
if len(route.Hosts) > 0 {
@@ -208,8 +234,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
}
router.Use(interceptErrors.ErrorInterceptor)
}
//r.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler
//r.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler
} else {
logger.Error("Error, path is empty in route %s", route.Name)
logger.Error("Route path ignored: %s", route.Path)