diff --git a/internal/routes.go b/internal/routes.go index 8536ffb..aa3e4f0 100644 --- a/internal/routes.go +++ b/internal/routes.go @@ -61,7 +61,6 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { if err != nil { logger.Fatal("Error: %v", err) } - m := dynamicMiddlewares if len(gateway.Redis.Addr) != 0 { redisBased = true } @@ -126,89 +125,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { cors: route.Cors, insecureSkipVerify: route.InsecureSkipVerify, } - // Apply middlewares to the route - for _, middleware := range route.Middlewares { - // Apply common exploits to the route - // Enable common exploits - if route.BlockCommonExploits { - logger.Info("Block common exploits enabled") - router.Use(middlewares.BlockExploitsMiddleware) - } - id := string(rune(rIndex)) - if len(route.Name) != 0 { - // Use route name as ID - id = util.Slug(route.Name) - } - // Apply route rate limit - if route.RateLimit != 0 { - rateLimit := middlewares.RateLimit{ - Unit: "minute", - Id: id, // Use route index as ID - Requests: route.RateLimit, - Origins: route.Cors.Origins, - Hosts: route.Hosts, - RedisBased: redisBased, - } - limiter := rateLimit.NewRateLimiterWindow() - // Add rate limit middlewares - router.Use(limiter.RateLimitMiddleware()) - } - if len(middleware) != 0 { - // Get Access middlewares if it does exist - accessMiddleware, err := getMiddleware([]string{middleware}, m) - if err != nil { - logger.Error("Error: %v", err.Error()) - } else { - // Apply access middlewares - if accessMiddleware.Type == AccessMiddleware { - blM := middlewares.AccessListMiddleware{ - Path: route.Path, - List: accessMiddleware.Paths, - } - r.Use(blM.AccessMiddleware) - - } - - // Apply Rate limit middleware - if slices.Contains(RateLimitMiddleware, accessMiddleware.Type) { - rateLimitMid, err := rateLimitMiddleware(accessMiddleware.Rule) - if err != nil { - logger.Error("Error: %v", err.Error()) - } - if rateLimitMid.RequestsPerUnit != 0 && route.RateLimit == 0 { - rateLimit := middlewares.RateLimit{ - Unit: rateLimitMid.Unit, - Id: id, // Use route index as ID - Requests: rateLimitMid.RequestsPerUnit, - Origins: route.Cors.Origins, - Hosts: route.Hosts, - RedisBased: redisBased, - PathBased: true, - Paths: util.AddPrefixPath(route.Path, accessMiddleware.Paths), - } - limiter := rateLimit.NewRateLimiterWindow() - // Add rate limit middlewares - router.Use(limiter.RateLimitMiddleware()) - - } - - } - - } - // Get route authentication middlewares if it does exist - routeMiddleware, err := getMiddleware([]string{middleware}, m) - if err != nil { - // Error: middlewares not found - logger.Error("Error: %v", err.Error()) - } else { - attachAuthMiddlewares(route, routeMiddleware, gateway, router) - } - } else { - logger.Error("Error, middlewares path is empty") - logger.Error("Middleware ignored") - } - } - + attachMiddlewares(rIndex, route, gateway, router) // Apply route Cors router.Use(CORSHandler(route.Cors)) if gateway.EnableMetrics { @@ -257,6 +174,92 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { } +// attachMiddlewares attach middlewares to the route +func attachMiddlewares(rIndex int, route Route, gateway Gateway, router *mux.Router) { + // Apply middlewares to the route + for _, middleware := range route.Middlewares { + // Apply common exploits to the route + // Enable common exploits + if route.BlockCommonExploits { + logger.Info("Block common exploits enabled") + router.Use(middlewares.BlockExploitsMiddleware) + } + id := string(rune(rIndex)) + if len(route.Name) != 0 { + // Use route name as ID + id = util.Slug(route.Name) + } + // Apply route rate limit + if route.RateLimit != 0 { + rateLimit := middlewares.RateLimit{ + Unit: "minute", + Id: id, // Use route index as ID + Requests: route.RateLimit, + Origins: route.Cors.Origins, + Hosts: route.Hosts, + RedisBased: redisBased, + } + limiter := rateLimit.NewRateLimiterWindow() + // Add rate limit middlewares + router.Use(limiter.RateLimitMiddleware()) + } + if len(middleware) != 0 { + // Get Access middlewares if it does exist + accessMiddleware, err := getMiddleware([]string{middleware}, dynamicMiddlewares) + if err != nil { + logger.Error("Error: %v", err.Error()) + } else { + // Apply access middlewares + if accessMiddleware.Type == AccessMiddleware { + blM := middlewares.AccessListMiddleware{ + Path: route.Path, + List: accessMiddleware.Paths, + } + router.Use(blM.AccessMiddleware) + } + + // Apply Rate limit middleware + if slices.Contains(RateLimitMiddleware, accessMiddleware.Type) { + rateLimitMid, err := rateLimitMiddleware(accessMiddleware.Rule) + if err != nil { + logger.Error("Error: %v", err.Error()) + } + if rateLimitMid.RequestsPerUnit != 0 && route.RateLimit == 0 { + rateLimit := middlewares.RateLimit{ + Unit: rateLimitMid.Unit, + Id: id, // Use route index as ID + Requests: rateLimitMid.RequestsPerUnit, + Origins: route.Cors.Origins, + Hosts: route.Hosts, + RedisBased: redisBased, + PathBased: true, + Paths: util.AddPrefixPath(route.Path, accessMiddleware.Paths), + } + limiter := rateLimit.NewRateLimiterWindow() + // Add rate limit middlewares + router.Use(limiter.RateLimitMiddleware()) + + } + + } + + } + // Get route authentication middlewares if it does exist + routeMiddleware, err := getMiddleware([]string{middleware}, dynamicMiddlewares) + if err != nil { + // Error: middlewares not found + logger.Error("Error: %v", err.Error()) + } else { + attachAuthMiddlewares(route, routeMiddleware, gateway, router) + } + } else { + logger.Error("Error, middlewares path is empty") + logger.Error("Middleware ignored") + } + } + +} + func attachAuthMiddlewares(route Route, routeMiddleware Middleware, gateway Gateway, r *mux.Router) { // Check Authentication middleware types switch routeMiddleware.Type {