From 6258b07c8236842c9f00d9e4caa6f7c7d7504929 Mon Sep 17 00:00:00 2001 From: Jonas Kaninda Date: Sun, 24 Nov 2024 15:59:47 +0100 Subject: [PATCH 1/6] refacor: improvement of rate limiting --- internal/config.go | 19 +++++++++++++++++++ internal/middleware.go | 1 + internal/middlewares/rate_limit.go | 8 ++++++-- internal/middlewares/redis.go | 9 ++++++--- internal/middlewares/types.go | 6 +++--- internal/routes.go | 27 +++++++++++++-------------- internal/types.go | 10 ++++------ internal/var.go | 7 +++++-- 8 files changed, 57 insertions(+), 30 deletions(-) diff --git a/internal/config.go b/internal/config.go index 73aaf37..6a88905 100644 --- a/internal/config.go +++ b/internal/config.go @@ -226,6 +226,25 @@ func (Gateway) Setup(conf string) *Gateway { } +// rateLimitMiddleware returns RateLimitRuleMiddleware, error +func rateLimitMiddleware(input interface{}) (RateLimitRuleMiddleware, error) { + rateLimit := new(RateLimitRuleMiddleware) + var bytes []byte + bytes, err := yaml.Marshal(input) + if err != nil { + return RateLimitRuleMiddleware{}, fmt.Errorf("error parsing yaml: %v", err) + } + err = yaml.Unmarshal(bytes, rateLimit) + if err != nil { + return RateLimitRuleMiddleware{}, fmt.Errorf("error parsing yaml: %v", err) + } + if rateLimit.RequestsPerUnit == 0 { + return RateLimitRuleMiddleware{}, fmt.Errorf("requests per unit not defined") + + } + return *rateLimit, nil +} + // getJWTMiddleware returns JWTRuleMiddleware,error func getJWTMiddleware(input interface{}) (JWTRuleMiddleware, error) { jWTRuler := new(JWTRuleMiddleware) diff --git a/internal/middleware.go b/internal/middleware.go index 6f519bc..ddc55a8 100644 --- a/internal/middleware.go +++ b/internal/middleware.go @@ -22,6 +22,7 @@ func getMiddleware(rules []string, middlewares []Middleware) (Middleware, error) func doesExist(tyName string) bool { middlewareList := []string{BasicAuth, JWTAuth, AccessMiddleware} + middlewareList = append(middlewareList, RateLimitMiddleware...) return slices.Contains(middlewareList, tyName) } func GetMiddleware(rule string, middlewares []Middleware) (Middleware, error) { diff --git a/internal/middlewares/rate_limit.go b/internal/middlewares/rate_limit.go index e63cb04..f2f3107 100644 --- a/internal/middlewares/rate_limit.go +++ b/internal/middlewares/rate_limit.go @@ -45,13 +45,17 @@ func (rl *TokenRateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { // RateLimitMiddleware limits request based on the number of requests peer minutes. func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { + window := time.Minute // requests per minute + if len(rl.unit) != 0 && rl.unit == "hour" { + window = time.Hour + } return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { clientIP := getRealIP(r) clientID := fmt.Sprintf("%s-%s", rl.id, clientIP) // Generate client Id, ID+ route ID logger.Debug("requests limiter: clientIP: %s, clientID: %s", clientIP, clientID) if rl.redisBased { - err := redisRateLimiter(clientID, rl.requests) + err := redisRateLimiter(clientID, rl.unit, rl.requests) if err != nil { logger.Error("Redis Rate limiter error: %s", err.Error()) logger.Error("Too many requests from IP: %s %s %s", clientIP, r.URL, r.UserAgent()) @@ -64,7 +68,7 @@ func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { if !exists || time.Now().After(client.ExpiresAt) { client = &Client{ RequestCount: 0, - ExpiresAt: time.Now().Add(rl.window), + ExpiresAt: time.Now().Add(window), } rl.clientMap[clientID] = client } diff --git a/internal/middlewares/redis.go b/internal/middlewares/redis.go index 1f199d2..eb57867 100644 --- a/internal/middlewares/redis.go +++ b/internal/middlewares/redis.go @@ -25,10 +25,13 @@ import ( ) // redisRateLimiter, handle rateLimit -func redisRateLimiter(clientIP string, rate int) error { +func redisRateLimiter(clientIP, unit string, rate int) error { + limit := redis_rate.PerMinute(rate) + if len(unit) != 0 && unit == "hour" { + limit = redis_rate.PerHour(rate) + } ctx := context.Background() - - res, err := limiter.Allow(ctx, clientIP, redis_rate.PerMinute(rate)) + res, err := limiter.Allow(ctx, clientIP, limit) if err != nil { return err } diff --git a/internal/middlewares/types.go b/internal/middlewares/types.go index 49c9112..bedf131 100644 --- a/internal/middlewares/types.go +++ b/internal/middlewares/types.go @@ -27,8 +27,8 @@ import ( // RateLimiter defines requests limit properties. type RateLimiter struct { requests int + unit string id string - window time.Duration clientMap map[string]*Client mu sync.Mutex origins []string @@ -42,8 +42,8 @@ type Client struct { } type RateLimit struct { Id string + Unit string Requests int - Window time.Duration Origins []string Hosts []string RedisBased bool @@ -53,8 +53,8 @@ type RateLimit struct { func (rateLimit RateLimit) NewRateLimiterWindow() *RateLimiter { return &RateLimiter{ id: rateLimit.Id, + unit: rateLimit.Unit, requests: rateLimit.Requests, - window: rateLimit.Window, clientMap: make(map[string]*Client), origins: rateLimit.Origins, redisBased: rateLimit.RedisBased, diff --git a/internal/routes.go b/internal/routes.go index 592acde..6dd0ee9 100644 --- a/internal/routes.go +++ b/internal/routes.go @@ -23,7 +23,6 @@ import ( "github.com/jkaninda/goma-gateway/util" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" - "time" ) // init initializes prometheus metrics @@ -62,7 +61,6 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { logger.Fatal("Error: %v", err) } m := dynamicMiddlewares - redisBased := false if len(gateway.Redis.Addr) != 0 { redisBased = true } @@ -97,8 +95,8 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { // Add rate limit middlewares to all routes, if defined rateLimit := middlewares.RateLimit{ Id: "global_rate", // Generate a unique ID for routes + Unit: "minute", Requests: gateway.RateLimit, - Window: time.Minute, // requests per minute Origins: gateway.Cors.Origins, Hosts: []string{}, RedisBased: redisBased, @@ -116,7 +114,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { } // Apply middlewares to the route for _, middleware := range route.Middlewares { - if middleware != "" { + if len(middleware) != 0 { // Get Access middlewares if it does exist accessMiddleware, err := getMiddleware([]string{middleware}, m) if err != nil { @@ -172,9 +170,9 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { // Apply route rate limit if route.RateLimit != 0 { rateLimit := middlewares.RateLimit{ + Unit: "minute", Id: id, // Use route index as ID Requests: route.RateLimit, - Window: time.Minute, // requests per minute Origins: route.Cors.Origins, Hosts: route.Hosts, RedisBased: redisBased, @@ -212,16 +210,17 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { logger.Error("Error, path is empty in route %s", route.Name) logger.Error("Route path ignored: %s", route.Path) } - } - // Apply global Cors middlewares - r.Use(CORSHandler(gateway.Cors)) // Apply CORS middlewares - // Apply errorInterceptor middlewares - if len(gateway.InterceptErrors) != 0 { - interceptErrors := middlewares.InterceptErrors{ - Errors: gateway.InterceptErrors, - Origins: gateway.Cors.Origins, + + // Apply global Cors middlewares + r.Use(CORSHandler(gateway.Cors)) // Apply CORS middlewares + // Apply errorInterceptor middlewares + if len(gateway.InterceptErrors) != 0 { + interceptErrors := middlewares.InterceptErrors{ + Errors: gateway.InterceptErrors, + Origins: gateway.Cors.Origins, + } + r.Use(interceptErrors.ErrorInterceptor) } - r.Use(interceptErrors.ErrorInterceptor) } return r diff --git a/internal/types.go b/internal/types.go index 8f05891..2da2283 100644 --- a/internal/types.go +++ b/internal/types.go @@ -80,13 +80,11 @@ type OauthEndpoint struct { TokenURL string `yaml:"tokenUrl"` UserInfoURL string `yaml:"userInfoUrl"` } -type RateLimiter struct { - // ipBased, tokenBased - Type string `yaml:"type"` - Rate float64 `yaml:"rate"` - Rule int `yaml:"rule"` -} +type RateLimitRuleMiddleware struct { + Unit string `yaml:"unit"` + RequestsPerUnit int `yaml:"requestsPerUnit"` +} type AccessRuleMiddleware struct { ResponseCode int `yaml:"responseCode"` // HTTP Response code } diff --git a/internal/var.go b/internal/var.go index 7698733..273dd6e 100644 --- a/internal/var.go +++ b/internal/var.go @@ -9,10 +9,13 @@ const AccessMiddleware = "access" // access middlewares const BasicAuth = "basic" // basic authentication middlewares const JWTAuth = "jwt" // JWT authentication middlewares const OAuth = "oauth" // OAuth authentication middlewares + var ( // Round-robin counter counter uint32 // dynamicRoutes routes - dynamicRoutes []Route - dynamicMiddlewares []Middleware + dynamicRoutes []Route + dynamicMiddlewares []Middleware + RateLimitMiddleware = []string{"ratelimit", "rateLimit"} // Rate Limit middlewares + redisBased = false ) From 3df8dce59b6a12128633e56e21e640fbc5b7fbec Mon Sep 17 00:00:00 2001 From: Jonas Kaninda Date: Sun, 24 Nov 2024 22:13:26 +0100 Subject: [PATCH 2/6] fix: fix authentification middlewares --- internal/middlewares/access_middleware.go | 6 + internal/middlewares/middleware.go | 179 ++++++++++---------- internal/middlewares/oauth_middleware.go | 44 ++--- internal/middlewares/types.go | 5 + internal/routes.go | 196 ++++++++++------------ internal/server.go | 6 +- util/helpers.go | 8 + 7 files changed, 228 insertions(+), 216 deletions(-) diff --git a/internal/middlewares/access_middleware.go b/internal/middlewares/access_middleware.go index ca26295..5b18181 100644 --- a/internal/middlewares/access_middleware.go +++ b/internal/middlewares/access_middleware.go @@ -53,6 +53,12 @@ func isPathBlocked(requestPath, blockedPath string) bool { } return false } +func isProtectedPath(urlPath string, paths []string) bool { + for _, path := range paths { + return isPathBlocked(urlPath, util.ParseURLPath(path)) + } + return false +} // NewRateLimiter creates a new requests limiter with the specified refill requests and token capacity func NewRateLimiter(maxTokens int, refillRate time.Duration) *TokenRateLimiter { diff --git a/internal/middlewares/middleware.go b/internal/middlewares/middleware.go index f8fc1b8..da52614 100644 --- a/internal/middlewares/middleware.go +++ b/internal/middlewares/middleware.go @@ -29,73 +29,75 @@ import ( // authorization based on the result of backend's response and continue the request when the client is authorized func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - for _, header := range jwtAuth.RequiredHeaders { - if r.Header.Get(header) == "" { - logger.Error("Proxy error, missing %s header", header) - w.Header().Set("Content-Type", "application/json") - // check allowed origin - if allowedOrigin(jwtAuth.Origins, r.Header.Get("Origin")) { - w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) + if isProtectedPath(r.URL.Path, jwtAuth.Paths) { + for _, header := range jwtAuth.RequiredHeaders { + if r.Header.Get(header) == "" { + logger.Error("Proxy error, missing %s header", header) + w.Header().Set("Content-Type", "application/json") + // check allowed origin + if allowedOrigin(jwtAuth.Origins, r.Header.Get("Origin")) { + w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) + } + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) + return + } + } + authURL, err := url.Parse(jwtAuth.AuthURL) + if err != nil { + logger.Error("Error parsing auth URL: %v", err) + RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) + return + } + // Create a new request for /authentication + authReq, err := http.NewRequest("GET", authURL.String(), nil) + if err != nil { + logger.Error("Proxy error creating authentication request: %v", err) + RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) + return + } + logger.Trace("JWT Auth response headers: %v", authReq.Header) + // Copy headers from the original request to the new request + for name, values := range r.Header { + for _, value := range values { + authReq.Header.Set(name, value) + } + } + // Copy cookies from the original request to the new request + for _, cookie := range r.Cookies() { + authReq.AddCookie(cookie) + } + // Perform the request to the auth service + client := &http.Client{} + authResp, err := client.Do(authReq) + if err != nil || authResp.StatusCode != http.StatusOK { + logger.Debug("%s %s %s %s", r.Method, getRealIP(r), r.URL, r.UserAgent()) + logger.Debug("Proxy authentication error") RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) return - } - } - authURL, err := url.Parse(jwtAuth.AuthURL) - if err != nil { - logger.Error("Error parsing auth URL: %v", err) - RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) - return - } - // Create a new request for /authentication - authReq, err := http.NewRequest("GET", authURL.String(), nil) - if err != nil { - logger.Error("Proxy error creating authentication request: %v", err) - RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) - return - } - logger.Trace("JWT Auth response headers: %v", authReq.Header) - // Copy headers from the original request to the new request - for name, values := range r.Header { - for _, value := range values { - authReq.Header.Set(name, value) + defer func(Body io.ReadCloser) { + err := Body.Close() + if err != nil { + logger.Error("Error closing body: %v", err) + } + }(authResp.Body) + // Inject specific header tp the current request's header + // Add header to the next request from AuthRequest header, depending on your requirements + if jwtAuth.Headers != nil { + for k, v := range jwtAuth.Headers { + r.Header.Set(v, authResp.Header.Get(k)) + } } - } - // Copy cookies from the original request to the new request - for _, cookie := range r.Cookies() { - authReq.AddCookie(cookie) - } - // Perform the request to the auth service - client := &http.Client{} - authResp, err := client.Do(authReq) - if err != nil || authResp.StatusCode != http.StatusOK { - logger.Debug("%s %s %s %s", r.Method, getRealIP(r), r.URL, r.UserAgent()) - logger.Debug("Proxy authentication error") - RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) - return - } - defer func(Body io.ReadCloser) { - err := Body.Close() - if err != nil { - logger.Error("Error closing body: %v", err) - } - }(authResp.Body) - // Inject specific header tp the current request's header - // Add header to the next request from AuthRequest header, depending on your requirements - if jwtAuth.Headers != nil { - for k, v := range jwtAuth.Headers { - r.Header.Set(v, authResp.Header.Get(k)) + query := r.URL.Query() + // Add query parameters to the next request from AuthRequest header, depending on your requirements + if jwtAuth.Params != nil { + for k, v := range jwtAuth.Params { + query.Set(v, authResp.Header.Get(k)) + } } + r.URL.RawQuery = query.Encode() } - query := r.URL.Query() - // Add query parameters to the next request from AuthRequest header, depending on your requirements - if jwtAuth.Params != nil { - for k, v := range jwtAuth.Params { - query.Set(v, authResp.Header.Get(k)) - } - } - r.URL.RawQuery = query.Encode() next.ServeHTTP(w, r) }) @@ -105,36 +107,37 @@ func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler { func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { logger.Trace("Basic-Auth request headers: %v", r.Header) - // Get the Authorization header - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - logger.Debug("Proxy error, missing Authorization header") - w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) - RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) - return - } - // Check if the Authorization header contains "Basic" scheme - if !strings.HasPrefix(authHeader, "Basic ") { - logger.Error("Proxy error, missing Basic Authorization header") - RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) + if isProtectedPath(r.URL.Path, basicAuth.Paths) { + // Get the Authorization header + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + logger.Debug("Proxy error, missing Authorization header") + w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) + return + } + // Check if the Authorization header contains "Basic" scheme + if !strings.HasPrefix(authHeader, "Basic ") { + logger.Error("Proxy error, missing Basic Authorization header") + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) - return - } + return + } + // Decode the base64 encoded username:password string + payload, err := base64.StdEncoding.DecodeString(authHeader[len("Basic "):]) + if err != nil { + logger.Debug("Proxy error, missing Basic Authorization header") + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) + return + } + // Split the payload into username and password + pair := strings.SplitN(string(payload), ":", 2) + if len(pair) != 2 || pair[0] != basicAuth.Username || pair[1] != basicAuth.Password { + w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) + return + } - // Decode the base64 encoded username:password string - payload, err := base64.StdEncoding.DecodeString(authHeader[len("Basic "):]) - if err != nil { - logger.Debug("Proxy error, missing Basic Authorization header") - RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) - return - } - - // Split the payload into username and password - pair := strings.SplitN(string(payload), ":", 2) - if len(pair) != 2 || pair[0] != basicAuth.Username || pair[1] != basicAuth.Password { - w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) - RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) - return } // Continue to the next handler if the authentication is successful diff --git a/internal/middlewares/oauth_middleware.go b/internal/middlewares/oauth_middleware.go index f2d7407..74b4089 100644 --- a/internal/middlewares/oauth_middleware.go +++ b/internal/middlewares/oauth_middleware.go @@ -26,27 +26,29 @@ import ( func (oauth Oauth) AuthMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - oauthConf := oauth2Config(oauth) - // Check if the user is authenticated - token, err := r.Cookie("goma.oauth") - if err != nil { - // If no token, redirect to OAuth provider - url := oauthConf.AuthCodeURL(oauth.State) - http.Redirect(w, r, url, http.StatusTemporaryRedirect) - return - } - ok, err := validateJWT(token.Value, oauth) - if err != nil { - // If no token, redirect to OAuth provider - url := oauthConf.AuthCodeURL(oauth.State) - http.Redirect(w, r, url, http.StatusTemporaryRedirect) - return - } - if !ok { - // If no token, redirect to OAuth provider - url := oauthConf.AuthCodeURL(oauth.State) - http.Redirect(w, r, url, http.StatusTemporaryRedirect) - return + if isProtectedPath(r.URL.Path, oauth.Paths) { + oauthConf := oauth2Config(oauth) + // Check if the user is authenticated + token, err := r.Cookie("goma.oauth") + if err != nil { + // If no token, redirect to OAuth provider + url := oauthConf.AuthCodeURL(oauth.State) + http.Redirect(w, r, url, http.StatusTemporaryRedirect) + return + } + ok, err := validateJWT(token.Value, oauth) + if err != nil { + // If no token, redirect to OAuth provider + url := oauthConf.AuthCodeURL(oauth.State) + http.Redirect(w, r, url, http.StatusTemporaryRedirect) + return + } + if !ok { + // If no token, redirect to OAuth provider + url := oauthConf.AuthCodeURL(oauth.State) + http.Redirect(w, r, url, http.StatusTemporaryRedirect) + return + } } // Token exists, proceed with request next.ServeHTTP(w, r) diff --git a/internal/middlewares/types.go b/internal/middlewares/types.go index bedf131..59d78d8 100644 --- a/internal/middlewares/types.go +++ b/internal/middlewares/types.go @@ -79,6 +79,8 @@ type ProxyResponseError struct { // JwtAuth stores JWT configuration type JwtAuth struct { + RoutePath string + Paths []string AuthURL string RequiredHeaders []string Headers map[string]string @@ -101,6 +103,7 @@ type AccessListMiddleware struct { // AuthBasic contains Basic auth configuration type AuthBasic struct { + Paths []string Username string Password string Headers map[string]string @@ -120,6 +123,8 @@ type responseRecorder struct { body *bytes.Buffer } type Oauth struct { + // Route protected path + Paths []string // ClientID is the application's ID. ClientID string // ClientSecret is the application's secret. diff --git a/internal/routes.go b/internal/routes.go index 6dd0ee9..8f7defd 100644 --- a/internal/routes.go +++ b/internal/routes.go @@ -106,12 +106,25 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { r.Use(limiter.RateLimitMiddleware()) } for rIndex, route := range dynamicRoutes { + + // create route + router := r.PathPrefix(route.Path).Subrouter() if len(route.Path) != 0 { // Checks if route destination and backend are empty if len(route.Destination) == 0 && len(route.Backends) == 0 { logger.Fatal("Route %s : destination or backends should not be empty", route.Name) } + proxyRoute := ProxyRoute{ + path: route.Path, + rewrite: route.Rewrite, + destination: route.Destination, + backends: route.Backends, + methods: route.Methods, + disableHostFording: route.DisableHostFording, + cors: route.Cors, + insecureSkipVerify: route.InsecureSkipVerify, + } // Apply middlewares to the route for _, middleware := range route.Middlewares { if len(middleware) != 0 { @@ -144,18 +157,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { logger.Error("Middleware ignored") } } - proxyRoute := ProxyRoute{ - path: route.Path, - rewrite: route.Rewrite, - destination: route.Destination, - backends: route.Backends, - methods: route.Methods, - disableHostFording: route.DisableHostFording, - cors: route.Cors, - insecureSkipVerify: route.InsecureSkipVerify, - } - // create route - router := r.PathPrefix(route.Path).Subrouter() + // Apply common exploits to the route // Enable common exploits if route.BlockCommonExploits { @@ -206,6 +208,8 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { } router.Use(interceptErrors.ErrorInterceptor) } + //r.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler + //r.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler } else { logger.Error("Error, path is empty in route %s", route.Name) logger.Error("Route path ignored: %s", route.Path) @@ -221,6 +225,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { } r.Use(interceptErrors.ErrorInterceptor) } + } return r @@ -228,105 +233,88 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { } func attachAuthMiddlewares(route Route, routeMiddleware Middleware, gateway Gateway, r *mux.Router) { - for _, middlewarePath := range routeMiddleware.Paths { - proxyRoute := ProxyRoute{ - path: route.Path, - rewrite: route.Rewrite, - destination: route.Destination, - backends: route.Backends, - disableHostFording: route.DisableHostFording, - methods: route.Methods, - cors: route.Cors, - insecureSkipVerify: route.InsecureSkipVerify, + // Check Authentication middleware types + switch routeMiddleware.Type { + case BasicAuth: + basicAuth, err := getBasicAuthMiddleware(routeMiddleware.Rule) + if err != nil { + logger.Error("Error: %s", err.Error()) + } else { + authBasic := middlewares.AuthBasic{ + Paths: util.AddPrefixPath(route.Path, routeMiddleware.Paths), + Username: basicAuth.Username, + Password: basicAuth.Password, + Headers: nil, + Params: nil, + } + // Apply JWT authentication middlewares + r.Use(authBasic.AuthMiddleware) + r.Use(CORSHandler(route.Cors)) } - secureRouter := r.PathPrefix(util.ParseRoutePath(route.Path, middlewarePath)).Subrouter() - // Check Authentication middleware types - switch routeMiddleware.Type { - case BasicAuth: - basicAuth, err := getBasicAuthMiddleware(routeMiddleware.Rule) - if err != nil { - logger.Error("Error: %s", err.Error()) - } else { - authBasic := middlewares.AuthBasic{ - Username: basicAuth.Username, - Password: basicAuth.Password, - Headers: nil, - Params: nil, - } - // Apply JWT authentication middlewares - secureRouter.Use(authBasic.AuthMiddleware) - secureRouter.Use(CORSHandler(route.Cors)) - secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler - secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler + case JWTAuth: + jwt, err := getJWTMiddleware(routeMiddleware.Rule) + if err != nil { + logger.Error("Error: %s", err.Error()) + } else { + jwtAuth := middlewares.JwtAuth{ + Paths: util.AddPrefixPath(route.Path, routeMiddleware.Paths), + AuthURL: jwt.URL, + RequiredHeaders: jwt.RequiredHeaders, + Headers: jwt.Headers, + Params: jwt.Params, + Origins: gateway.Cors.Origins, } - case JWTAuth: - jwt, err := getJWTMiddleware(routeMiddleware.Rule) - if err != nil { - logger.Error("Error: %s", err.Error()) - } else { - jwtAuth := middlewares.JwtAuth{ - AuthURL: jwt.URL, - RequiredHeaders: jwt.RequiredHeaders, - Headers: jwt.Headers, - Params: jwt.Params, - Origins: gateway.Cors.Origins, - } - // Apply JWT authentication middlewares - secureRouter.Use(jwtAuth.AuthMiddleware) - secureRouter.Use(CORSHandler(route.Cors)) - secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler - secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler + // Apply JWT authentication middlewares + r.Use(jwtAuth.AuthMiddleware) + r.Use(CORSHandler(route.Cors)) + } + case OAuth: + oauth, err := oAuthMiddleware(routeMiddleware.Rule) + if err != nil { + logger.Error("Error: %s", err.Error()) + } else { + redirectURL := "/callback" + route.Path + if oauth.RedirectURL != "" { + redirectURL = oauth.RedirectURL } - case OAuth: - oauth, err := oAuthMiddleware(routeMiddleware.Rule) - if err != nil { - logger.Error("Error: %s", err.Error()) - } else { - redirectURL := "/callback" + route.Path - if oauth.RedirectURL != "" { - redirectURL = oauth.RedirectURL - } - amw := middlewares.Oauth{ - ClientID: oauth.ClientID, - ClientSecret: oauth.ClientSecret, - RedirectURL: redirectURL, - Scopes: oauth.Scopes, - Endpoint: middlewares.OauthEndpoint{ - AuthURL: oauth.Endpoint.AuthURL, - TokenURL: oauth.Endpoint.TokenURL, - UserInfoURL: oauth.Endpoint.UserInfoURL, - }, - State: oauth.State, - Origins: gateway.Cors.Origins, - JWTSecret: oauth.JWTSecret, - Provider: oauth.Provider, - } - oauthRuler := oauthRulerMiddleware(amw) - // Check if a cookie path is defined - if oauthRuler.CookiePath == "" { - oauthRuler.CookiePath = route.Path - } - // Check if a RedirectPath is defined - if oauthRuler.RedirectPath == "" { - oauthRuler.RedirectPath = util.ParseRoutePath(route.Path, middlewarePath) - } - if oauthRuler.Provider == "" { - oauthRuler.Provider = "custom" - } - secureRouter.Use(amw.AuthMiddleware) - secureRouter.Use(CORSHandler(route.Cors)) - secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler - secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler - // Callback route - r.HandleFunc(util.UrlParsePath(redirectURL), oauthRuler.callbackHandler).Methods("GET") + amw := middlewares.Oauth{ + Paths: util.AddPrefixPath(route.Path, routeMiddleware.Paths), + ClientID: oauth.ClientID, + ClientSecret: oauth.ClientSecret, + RedirectURL: redirectURL, + Scopes: oauth.Scopes, + Endpoint: middlewares.OauthEndpoint{ + AuthURL: oauth.Endpoint.AuthURL, + TokenURL: oauth.Endpoint.TokenURL, + UserInfoURL: oauth.Endpoint.UserInfoURL, + }, + State: oauth.State, + Origins: gateway.Cors.Origins, + JWTSecret: oauth.JWTSecret, + Provider: oauth.Provider, } - default: - if !doesExist(routeMiddleware.Type) { - logger.Error("Unknown middlewares type %s", routeMiddleware.Type) + oauthRuler := oauthRulerMiddleware(amw) + // Check if a cookie path is defined + if oauthRuler.CookiePath == "" { + oauthRuler.CookiePath = route.Path } - + // Check if a RedirectPath is defined + if oauthRuler.RedirectPath == "" { + oauthRuler.RedirectPath = util.ParseRoutePath(route.Path, routeMiddleware.Paths[0]) + } + if oauthRuler.Provider == "" { + oauthRuler.Provider = "custom" + } + r.Use(amw.AuthMiddleware) + r.Use(CORSHandler(route.Cors)) + r.HandleFunc(util.UrlParsePath(redirectURL), oauthRuler.callbackHandler).Methods("GET") + } + default: + if !doesExist(routeMiddleware.Type) { + logger.Error("Unknown middlewares type %s", routeMiddleware.Type) } } + } diff --git a/internal/server.go b/internal/server.go index 340b177..a9be5ea 100644 --- a/internal/server.go +++ b/internal/server.go @@ -30,7 +30,7 @@ import ( // Start / Start starts the server func (gatewayServer GatewayServer) Start() error { logger.Info("Initializing routes...") - route := gatewayServer.Initialize() + router := gatewayServer.Initialize() logger.Debug("Routes count=%d, Middlewares count=%d", len(gatewayServer.gateway.Routes), len(gatewayServer.middlewares)) gatewayServer.initRedis() defer gatewayServer.closeRedis() @@ -44,8 +44,8 @@ func (gatewayServer GatewayServer) Start() error { printRoute(dynamicRoutes) } - httpServer := gatewayServer.createServer(":8080", route, nil) - httpsServer := gatewayServer.createServer(":8443", route, tlsConfig) + httpServer := gatewayServer.createServer(":8080", router, nil) + httpsServer := gatewayServer.createServer(":8443", router, tlsConfig) // Start HTTP/HTTPS servers gatewayServer.startServers(httpServer, httpsServer, listenWithTLS) diff --git a/util/helpers.go b/util/helpers.go index 3248312..a3bc174 100644 --- a/util/helpers.go +++ b/util/helpers.go @@ -157,3 +157,11 @@ func Slug(text string) string { return text } + +func AddPrefixPath(prefix string, paths []string) []string { + for i := range paths { + paths[i] = ParseURLPath(prefix + paths[i]) + } + return paths + +} From f4e5bb3be251099a2e9c821ac80bc4c8be5945ee Mon Sep 17 00:00:00 2001 From: Jonas Kaninda Date: Sun, 24 Nov 2024 23:09:13 +0100 Subject: [PATCH 3/6] refactor: refactoring of rate limiting --- internal/middlewares/rate_limit.go | 4 +- internal/middlewares/types.go | 6 +++ internal/routes.go | 79 ++++++++++++++++++++---------- 3 files changed, 60 insertions(+), 29 deletions(-) diff --git a/internal/middlewares/rate_limit.go b/internal/middlewares/rate_limit.go index f2f3107..be627a4 100644 --- a/internal/middlewares/rate_limit.go +++ b/internal/middlewares/rate_limit.go @@ -53,13 +53,11 @@ func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { clientIP := getRealIP(r) clientID := fmt.Sprintf("%s-%s", rl.id, clientIP) // Generate client Id, ID+ route ID - logger.Debug("requests limiter: clientIP: %s, clientID: %s", clientIP, clientID) if rl.redisBased { err := redisRateLimiter(clientID, rl.unit, rl.requests) if err != nil { logger.Error("Redis Rate limiter error: %s", err.Error()) logger.Error("Too many requests from IP: %s %s %s", clientIP, r.URL, r.UserAgent()) - RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API requests limit exceeded. Please try again later", http.StatusTooManyRequests)) return } } else { @@ -82,8 +80,10 @@ func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) } RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API requests limit exceeded. Please try again later", http.StatusTooManyRequests)) + return } } + // Proceed to the next handler if the request limit is not exceeded next.ServeHTTP(w, r) }) diff --git a/internal/middlewares/types.go b/internal/middlewares/types.go index 59d78d8..c715d61 100644 --- a/internal/middlewares/types.go +++ b/internal/middlewares/types.go @@ -33,6 +33,8 @@ type RateLimiter struct { mu sync.Mutex origins []string redisBased bool + pathBased bool + paths []string } // Client stores request count and window expiration for each client. @@ -47,6 +49,8 @@ type RateLimit struct { Origins []string Hosts []string RedisBased bool + PathBased bool + Paths []string } // NewRateLimiterWindow creates a new RateLimiter. @@ -58,6 +62,8 @@ func (rateLimit RateLimit) NewRateLimiterWindow() *RateLimiter { clientMap: make(map[string]*Client), origins: rateLimit.Origins, redisBased: rateLimit.RedisBased, + pathBased: rateLimit.PathBased, + paths: rateLimit.Paths, } } diff --git a/internal/routes.go b/internal/routes.go index 8f7defd..4b736e8 100644 --- a/internal/routes.go +++ b/internal/routes.go @@ -23,6 +23,7 @@ import ( "github.com/jkaninda/goma-gateway/util" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" + "slices" ) // init initializes prometheus metrics @@ -127,6 +128,31 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { } // Apply middlewares to the route for _, middleware := range route.Middlewares { + // Apply common exploits to the route + // Enable common exploits + if route.BlockCommonExploits { + logger.Info("Block common exploits enabled") + router.Use(middlewares.BlockExploitsMiddleware) + } + id := string(rune(rIndex)) + if len(route.Name) != 0 { + // Use route name as ID + id = util.Slug(route.Name) + } + // Apply route rate limit + if route.RateLimit != 0 { + rateLimit := middlewares.RateLimit{ + Unit: "minute", + Id: id, // Use route index as ID + Requests: route.RateLimit, + Origins: route.Cors.Origins, + Hosts: route.Hosts, + RedisBased: redisBased, + } + limiter := rateLimit.NewRateLimiterWindow() + // Add rate limit middlewares + router.Use(limiter.RateLimitMiddleware()) + } if len(middleware) != 0 { // Get Access middlewares if it does exist accessMiddleware, err := getMiddleware([]string{middleware}, m) @@ -143,6 +169,31 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { } + // Apply Rate limit middleware + if slices.Contains(RateLimitMiddleware, accessMiddleware.Type) { + rateLimitMid, err := rateLimitMiddleware(accessMiddleware.Rule) + if err != nil { + logger.Error("Error: %v", err.Error()) + } + if rateLimitMid.RequestsPerUnit != 0 && route.RateLimit == 0 { + rateLimit := middlewares.RateLimit{ + Unit: rateLimitMid.Unit, + Id: id, // Use route index as ID + Requests: rateLimitMid.RequestsPerUnit, + Origins: route.Cors.Origins, + Hosts: route.Hosts, + RedisBased: redisBased, + PathBased: true, + Paths: util.AddPrefixPath(route.Path, accessMiddleware.Paths), + } + limiter := rateLimit.NewRateLimiterWindow() + // Add rate limit middlewares + router.Use(limiter.RateLimitMiddleware()) + + } + + } + } // Get route authentication middlewares if it does exist routeMiddleware, err := getMiddleware([]string{middleware}, m) @@ -158,31 +209,6 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { } } - // Apply common exploits to the route - // Enable common exploits - if route.BlockCommonExploits { - logger.Info("Block common exploits enabled") - router.Use(middlewares.BlockExploitsMiddleware) - } - id := string(rune(rIndex)) - if len(route.Name) != 0 { - // Use route name as ID - id = util.Slug(route.Name) - } - // Apply route rate limit - if route.RateLimit != 0 { - rateLimit := middlewares.RateLimit{ - Unit: "minute", - Id: id, // Use route index as ID - Requests: route.RateLimit, - Origins: route.Cors.Origins, - Hosts: route.Hosts, - RedisBased: redisBased, - } - limiter := rateLimit.NewRateLimiterWindow() - // Add rate limit middlewares - router.Use(limiter.RateLimitMiddleware()) - } // Apply route Cors router.Use(CORSHandler(route.Cors)) if len(route.Hosts) > 0 { @@ -208,8 +234,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { } router.Use(interceptErrors.ErrorInterceptor) } - //r.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler - //r.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler + } else { logger.Error("Error, path is empty in route %s", route.Name) logger.Error("Route path ignored: %s", route.Path) From dbd09743889f64e185a62da842afedb9bd5a4fcb Mon Sep 17 00:00:00 2001 From: Jonas Kaninda Date: Mon, 25 Nov 2024 07:38:49 +0100 Subject: [PATCH 4/6] refactor: refactoring of auth middlewares --- internal/middlewares/access_middleware.go | 4 ++-- internal/middlewares/middleware.go | 6 +++--- internal/middlewares/oauth_middleware.go | 2 +- internal/middlewares/types.go | 6 +++++- internal/routes.go | 25 +++++++++++++---------- 5 files changed, 25 insertions(+), 18 deletions(-) diff --git a/internal/middlewares/access_middleware.go b/internal/middlewares/access_middleware.go index 5b18181..d6132e2 100644 --- a/internal/middlewares/access_middleware.go +++ b/internal/middlewares/access_middleware.go @@ -53,9 +53,9 @@ func isPathBlocked(requestPath, blockedPath string) bool { } return false } -func isProtectedPath(urlPath string, paths []string) bool { +func isProtectedPath(urlPath, prefix string, paths []string) bool { for _, path := range paths { - return isPathBlocked(urlPath, util.ParseURLPath(path)) + return isPathBlocked(urlPath, util.ParseURLPath(prefix+path)) } return false } diff --git a/internal/middlewares/middleware.go b/internal/middlewares/middleware.go index da52614..5cb66e0 100644 --- a/internal/middlewares/middleware.go +++ b/internal/middlewares/middleware.go @@ -29,7 +29,7 @@ import ( // authorization based on the result of backend's response and continue the request when the client is authorized func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if isProtectedPath(r.URL.Path, jwtAuth.Paths) { + if isProtectedPath(r.URL.Path, jwtAuth.Path, jwtAuth.Paths) { for _, header := range jwtAuth.RequiredHeaders { if r.Header.Get(header) == "" { logger.Error("Proxy error, missing %s header", header) @@ -98,16 +98,16 @@ func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler { } r.URL.RawQuery = query.Encode() } - next.ServeHTTP(w, r) }) + } // AuthMiddleware checks for the Authorization header and verifies the credentials func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { logger.Trace("Basic-Auth request headers: %v", r.Header) - if isProtectedPath(r.URL.Path, basicAuth.Paths) { + if isProtectedPath(r.URL.Path, basicAuth.Path, basicAuth.Paths) { // Get the Authorization header authHeader := r.Header.Get("Authorization") if authHeader == "" { diff --git a/internal/middlewares/oauth_middleware.go b/internal/middlewares/oauth_middleware.go index 74b4089..3157ea5 100644 --- a/internal/middlewares/oauth_middleware.go +++ b/internal/middlewares/oauth_middleware.go @@ -26,7 +26,7 @@ import ( func (oauth Oauth) AuthMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if isProtectedPath(r.URL.Path, oauth.Paths) { + if isProtectedPath(r.URL.Path, oauth.Path, oauth.Paths) { oauthConf := oauth2Config(oauth) // Check if the user is authenticated token, err := r.Cookie("goma.oauth") diff --git a/internal/middlewares/types.go b/internal/middlewares/types.go index c715d61..826d94a 100644 --- a/internal/middlewares/types.go +++ b/internal/middlewares/types.go @@ -85,7 +85,7 @@ type ProxyResponseError struct { // JwtAuth stores JWT configuration type JwtAuth struct { - RoutePath string + Path string Paths []string AuthURL string RequiredHeaders []string @@ -109,6 +109,8 @@ type AccessListMiddleware struct { // AuthBasic contains Basic auth configuration type AuthBasic struct { + // Route path + Path string Paths []string Username string Password string @@ -129,6 +131,8 @@ type responseRecorder struct { body *bytes.Buffer } type Oauth struct { + // Route path + Path string // Route protected path Paths []string // ClientID is the application's ID. diff --git a/internal/routes.go b/internal/routes.go index 4b736e8..8536ffb 100644 --- a/internal/routes.go +++ b/internal/routes.go @@ -201,7 +201,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { // Error: middlewares not found logger.Error("Error: %v", err.Error()) } else { - attachAuthMiddlewares(route, routeMiddleware, gateway, r) + attachAuthMiddlewares(route, routeMiddleware, gateway, router) } } else { logger.Error("Error, middlewares path is empty") @@ -211,13 +211,6 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { // Apply route Cors router.Use(CORSHandler(route.Cors)) - if len(route.Hosts) > 0 { - for _, host := range route.Hosts { - router.Host(host).PathPrefix("").Handler(proxyRoute.ProxyHandler()) - } - } else { - router.PathPrefix("").Handler(proxyRoute.ProxyHandler()) - } if gateway.EnableMetrics { pr := metrics.PrometheusRoute{ Name: route.Name, @@ -234,6 +227,13 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { } router.Use(interceptErrors.ErrorInterceptor) } + if len(route.Hosts) != 0 { + for _, host := range route.Hosts { + router.Host(host).PathPrefix("").Handler(proxyRoute.ProxyHandler()) + } + } else { + router.PathPrefix("").Handler(proxyRoute.ProxyHandler()) + } } else { logger.Error("Error, path is empty in route %s", route.Name) @@ -266,7 +266,8 @@ func attachAuthMiddlewares(route Route, routeMiddleware Middleware, gateway Gate logger.Error("Error: %s", err.Error()) } else { authBasic := middlewares.AuthBasic{ - Paths: util.AddPrefixPath(route.Path, routeMiddleware.Paths), + Path: route.Path, + Paths: routeMiddleware.Paths, Username: basicAuth.Username, Password: basicAuth.Password, Headers: nil, @@ -282,7 +283,8 @@ func attachAuthMiddlewares(route Route, routeMiddleware Middleware, gateway Gate logger.Error("Error: %s", err.Error()) } else { jwtAuth := middlewares.JwtAuth{ - Paths: util.AddPrefixPath(route.Path, routeMiddleware.Paths), + Path: route.Path, + Paths: routeMiddleware.Paths, AuthURL: jwt.URL, RequiredHeaders: jwt.RequiredHeaders, Headers: jwt.Headers, @@ -304,7 +306,8 @@ func attachAuthMiddlewares(route Route, routeMiddleware Middleware, gateway Gate redirectURL = oauth.RedirectURL } amw := middlewares.Oauth{ - Paths: util.AddPrefixPath(route.Path, routeMiddleware.Paths), + Path: route.Path, + Paths: routeMiddleware.Paths, ClientID: oauth.ClientID, ClientSecret: oauth.ClientSecret, RedirectURL: redirectURL, From 42292bb53dca7d11c2539e4c69bda945f1043d64 Mon Sep 17 00:00:00 2001 From: Jonas Kaninda Date: Mon, 25 Nov 2024 07:48:00 +0100 Subject: [PATCH 5/6] refactor: to meet all go lint requirement --- internal/routes.go | 171 +++++++++++++++++++++++---------------------- 1 file changed, 87 insertions(+), 84 deletions(-) diff --git a/internal/routes.go b/internal/routes.go index 8536ffb..aa3e4f0 100644 --- a/internal/routes.go +++ b/internal/routes.go @@ -61,7 +61,6 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { if err != nil { logger.Fatal("Error: %v", err) } - m := dynamicMiddlewares if len(gateway.Redis.Addr) != 0 { redisBased = true } @@ -126,89 +125,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { cors: route.Cors, insecureSkipVerify: route.InsecureSkipVerify, } - // Apply middlewares to the route - for _, middleware := range route.Middlewares { - // Apply common exploits to the route - // Enable common exploits - if route.BlockCommonExploits { - logger.Info("Block common exploits enabled") - router.Use(middlewares.BlockExploitsMiddleware) - } - id := string(rune(rIndex)) - if len(route.Name) != 0 { - // Use route name as ID - id = util.Slug(route.Name) - } - // Apply route rate limit - if route.RateLimit != 0 { - rateLimit := middlewares.RateLimit{ - Unit: "minute", - Id: id, // Use route index as ID - Requests: route.RateLimit, - Origins: route.Cors.Origins, - Hosts: route.Hosts, - RedisBased: redisBased, - } - limiter := rateLimit.NewRateLimiterWindow() - // Add rate limit middlewares - router.Use(limiter.RateLimitMiddleware()) - } - if len(middleware) != 0 { - // Get Access middlewares if it does exist - accessMiddleware, err := getMiddleware([]string{middleware}, m) - if err != nil { - logger.Error("Error: %v", err.Error()) - } else { - // Apply access middlewares - if accessMiddleware.Type == AccessMiddleware { - blM := middlewares.AccessListMiddleware{ - Path: route.Path, - List: accessMiddleware.Paths, - } - r.Use(blM.AccessMiddleware) - - } - - // Apply Rate limit middleware - if slices.Contains(RateLimitMiddleware, accessMiddleware.Type) { - rateLimitMid, err := rateLimitMiddleware(accessMiddleware.Rule) - if err != nil { - logger.Error("Error: %v", err.Error()) - } - if rateLimitMid.RequestsPerUnit != 0 && route.RateLimit == 0 { - rateLimit := middlewares.RateLimit{ - Unit: rateLimitMid.Unit, - Id: id, // Use route index as ID - Requests: rateLimitMid.RequestsPerUnit, - Origins: route.Cors.Origins, - Hosts: route.Hosts, - RedisBased: redisBased, - PathBased: true, - Paths: util.AddPrefixPath(route.Path, accessMiddleware.Paths), - } - limiter := rateLimit.NewRateLimiterWindow() - // Add rate limit middlewares - router.Use(limiter.RateLimitMiddleware()) - - } - - } - - } - // Get route authentication middlewares if it does exist - routeMiddleware, err := getMiddleware([]string{middleware}, m) - if err != nil { - // Error: middlewares not found - logger.Error("Error: %v", err.Error()) - } else { - attachAuthMiddlewares(route, routeMiddleware, gateway, router) - } - } else { - logger.Error("Error, middlewares path is empty") - logger.Error("Middleware ignored") - } - } - + attachMiddlewares(rIndex, route, gateway, router) // Apply route Cors router.Use(CORSHandler(route.Cors)) if gateway.EnableMetrics { @@ -257,6 +174,92 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { } +// attachMiddlewares attach middlewares to the route +func attachMiddlewares(rIndex int, route Route, gateway Gateway, router *mux.Router) { + // Apply middlewares to the route + for _, middleware := range route.Middlewares { + // Apply common exploits to the route + // Enable common exploits + if route.BlockCommonExploits { + logger.Info("Block common exploits enabled") + router.Use(middlewares.BlockExploitsMiddleware) + } + id := string(rune(rIndex)) + if len(route.Name) != 0 { + // Use route name as ID + id = util.Slug(route.Name) + } + // Apply route rate limit + if route.RateLimit != 0 { + rateLimit := middlewares.RateLimit{ + Unit: "minute", + Id: id, // Use route index as ID + Requests: route.RateLimit, + Origins: route.Cors.Origins, + Hosts: route.Hosts, + RedisBased: redisBased, + } + limiter := rateLimit.NewRateLimiterWindow() + // Add rate limit middlewares + router.Use(limiter.RateLimitMiddleware()) + } + if len(middleware) != 0 { + // Get Access middlewares if it does exist + accessMiddleware, err := getMiddleware([]string{middleware}, dynamicMiddlewares) + if err != nil { + logger.Error("Error: %v", err.Error()) + } else { + // Apply access middlewares + if accessMiddleware.Type == AccessMiddleware { + blM := middlewares.AccessListMiddleware{ + Path: route.Path, + List: accessMiddleware.Paths, + } + router.Use(blM.AccessMiddleware) + } + + // Apply Rate limit middleware + if slices.Contains(RateLimitMiddleware, accessMiddleware.Type) { + rateLimitMid, err := rateLimitMiddleware(accessMiddleware.Rule) + if err != nil { + logger.Error("Error: %v", err.Error()) + } + if rateLimitMid.RequestsPerUnit != 0 && route.RateLimit == 0 { + rateLimit := middlewares.RateLimit{ + Unit: rateLimitMid.Unit, + Id: id, // Use route index as ID + Requests: rateLimitMid.RequestsPerUnit, + Origins: route.Cors.Origins, + Hosts: route.Hosts, + RedisBased: redisBased, + PathBased: true, + Paths: util.AddPrefixPath(route.Path, accessMiddleware.Paths), + } + limiter := rateLimit.NewRateLimiterWindow() + // Add rate limit middlewares + router.Use(limiter.RateLimitMiddleware()) + + } + + } + + } + // Get route authentication middlewares if it does exist + routeMiddleware, err := getMiddleware([]string{middleware}, dynamicMiddlewares) + if err != nil { + // Error: middlewares not found + logger.Error("Error: %v", err.Error()) + } else { + attachAuthMiddlewares(route, routeMiddleware, gateway, router) + } + } else { + logger.Error("Error, middlewares path is empty") + logger.Error("Middleware ignored") + } + } + +} + func attachAuthMiddlewares(route Route, routeMiddleware Middleware, gateway Gateway, r *mux.Router) { // Check Authentication middleware types switch routeMiddleware.Type { From f0f5dea2a3d4e8050150a083d12e9106956fac7d Mon Sep 17 00:00:00 2001 From: Jonas Kaninda Date: Mon, 25 Nov 2024 08:38:03 +0100 Subject: [PATCH 6/6] docs: update rate limiting --- docs/middleware/rate-limit.md | 28 +++++++++++++++++++++++++--- internal/route_type.go | 2 +- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/docs/middleware/rate-limit.md b/docs/middleware/rate-limit.md index 5bc3f67..4e368f1 100644 --- a/docs/middleware/rate-limit.md +++ b/docs/middleware/rate-limit.md @@ -10,13 +10,35 @@ nav_order: 6 The RateLimit middleware ensures that services will receive a fair number of requests, and allows one to define what fair is. -Example of global rateLimit middleware +Example of rate limiting middleware + +```yaml +middlewares: + - name: rate-limit + type: ratelimit #or rateLimit + paths: + - /* + rule: + unit: minute # or hour + requestsPerUnit: 10 +``` + +Example of route rate limiting middleware + +```yaml +version: 0.1.7 +gateway: + routes: + - name: Example + rateLimit: 60 # peer minute +``` + +Example of global rate limiting middleware ```yaml version: 0.1.7 gateway: - # Proxy rate limit, it's In-Memory IP based rateLimit: 60 # peer minute routes: - name: Example -``` +``` \ No newline at end of file diff --git a/internal/route_type.go b/internal/route_type.go index 2781037..d263c10 100644 --- a/internal/route_type.go +++ b/internal/route_type.go @@ -42,7 +42,7 @@ type Route struct { HealthCheck RouteHealthCheck `yaml:"healthCheck"` // Cors contains the route cors headers Cors Cors `yaml:"cors"` - RateLimit int `yaml:"rateLimit"` + RateLimit int `yaml:"rateLimit,omitempty"` // DisableHostFording Disable X-forwarded header. // // [X-Forwarded-Host, X-Forwarded-For, Host, Scheme ]