From 42abf56473a8b833be1bb69fdbf287c037a1350a Mon Sep 17 00:00:00 2001 From: Jonas Kaninda Date: Thu, 14 Nov 2024 00:26:21 +0100 Subject: [PATCH] refactor: improve error interceptor --- internal/config.go | 25 ++++-- internal/middleware/access-middleware.go | 12 +-- internal/middleware/block-common-exploits.go | 34 +++---- internal/middleware/error-interceptor.go | 12 +-- internal/middleware/helpers.go | 57 +++++++++++- internal/middleware/middleware.go | 90 +++---------------- internal/middleware/rate-limit.go | 26 ++---- .../middleware/route_error_interceptor.go | 58 ++++++++++++ internal/middleware/types.go | 48 +++++----- internal/proxy.go | 13 +-- internal/route.go | 14 ++- internal/types.go | 23 ++++- internal/var.go | 4 + pkg/error-interceptor/types.go | 31 +++++++ pkg/error-interceptor/var.go | 22 +++++ 15 files changed, 284 insertions(+), 185 deletions(-) create mode 100644 internal/middleware/route_error_interceptor.go create mode 100644 pkg/error-interceptor/types.go create mode 100644 pkg/error-interceptor/var.go diff --git a/internal/config.go b/internal/config.go index faf15e5..72b3896 100644 --- a/internal/config.go +++ b/internal/config.go @@ -18,6 +18,7 @@ limitations under the License. import ( "fmt" "github.com/jkaninda/goma-gateway/internal/middleware" + error_interceptor "github.com/jkaninda/goma-gateway/pkg/error-interceptor" "github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/util" "golang.org/x/oauth2" @@ -27,6 +28,7 @@ import ( "golang.org/x/oauth2/gitlab" "golang.org/x/oauth2/google" "gopkg.in/yaml.v3" + "net/http" "os" ) @@ -180,11 +182,24 @@ func initConfig(configFile string) error { Middlewares: []string{"basic-auth", "api-forbidden-paths"}, }, { - Path: "/", - Name: "Hostname and load balancing example", - Hosts: []string{"example.com", "example.localhost"}, - InterceptErrors: []int{404, 405, 500}, - RateLimit: 60, + Path: "/", + Name: "Hostname and load balancing example", + Hosts: []string{"example.com", "example.localhost"}, + //InterceptErrors: []int{404, 405, 500}, + ErrorInterceptor: error_interceptor.ErrorInterceptor{ + ContentType: applicationJson, + Errors: []error_interceptor.Error{ + { + Code: http.StatusUnauthorized, + Message: http.StatusText(http.StatusUnauthorized), + }, + { + Code: http.StatusInternalServerError, + Message: http.StatusText(http.StatusInternalServerError), + }, + }, + }, + RateLimit: 60, Backends: []string{ "https://example.com", "https://example2.com", diff --git a/internal/middleware/access-middleware.go b/internal/middleware/access-middleware.go index b581c73..4451c4b 100644 --- a/internal/middleware/access-middleware.go +++ b/internal/middleware/access-middleware.go @@ -16,7 +16,6 @@ See the License for the specific language governing permissions and limitations under the License. */ import ( - "encoding/json" "fmt" "github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/util" @@ -31,16 +30,7 @@ func (blockList AccessListMiddleware) AccessMiddleware(next http.Handler) http.H for _, block := range blockList.List { if isPathBlocked(r.URL.Path, util.ParseURLPath(blockList.Path+block)) { logger.Error("%s: %s access forbidden", getRealIP(r), r.URL.Path) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusForbidden) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: http.StatusForbidden, - Message: fmt.Sprintf("You do not have permission to access this resource"), - }) - if err != nil { - return - } + RespondWithError(w, http.StatusForbidden, fmt.Sprintf("%d you do not have permission to access this resource", http.StatusForbidden), blockList.ErrorInterceptor) return } } diff --git a/internal/middleware/block-common-exploits.go b/internal/middleware/block-common-exploits.go index 8a82534..9d87221 100644 --- a/internal/middleware/block-common-exploits.go +++ b/internal/middleware/block-common-exploits.go @@ -18,15 +18,19 @@ package middleware import ( - "encoding/json" "fmt" + errorinterceptor "github.com/jkaninda/goma-gateway/pkg/error-interceptor" "github.com/jkaninda/goma-gateway/pkg/logger" "net/http" "regexp" ) +type BlockCommon struct { + ErrorInterceptor errorinterceptor.ErrorInterceptor +} + // BlockExploitsMiddleware Middleware to block common exploits -func BlockExploitsMiddleware(next http.Handler) http.Handler { +func (blockCommon BlockCommon) BlockExploitsMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Patterns to detect SQL injection attempts sqlInjectionPattern := regexp.MustCompile(sqlPatterns) @@ -42,36 +46,18 @@ func BlockExploitsMiddleware(next http.Handler) http.Handler { pathTraversalPattern.MatchString(r.URL.Path) || xssPattern.MatchString(r.URL.RawQuery) { logger.Error("%s: %s Forbidden - Potential exploit detected", getRealIP(r), r.URL.Path) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusForbidden) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: http.StatusForbidden, - Message: fmt.Sprintf("Forbidden - Potential exploit detected"), - }) - if err != nil { - return - } + RespondWithError(w, http.StatusForbidden, fmt.Sprintf("%d Forbidden - Potential exploit detected", http.StatusForbidden), blockCommon.ErrorInterceptor) return - } + } // Check form data (for POST requests) if r.Method == http.MethodPost { if err := r.ParseForm(); err == nil { for _, values := range r.Form { for _, value := range values { if sqlInjectionPattern.MatchString(value) || xssPattern.MatchString(value) { - logger.Error("%s: %s Forbidden - Potential exploit detected", getRealIP(r), r.URL.Path) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusForbidden) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: http.StatusForbidden, - Message: fmt.Sprintf("Forbidden - Potential exploit detected"), - }) - if err != nil { - return - } + logger.Error("%s: %s %s Forbidden - Potential exploit detected", getRealIP(r), r.Method, r.URL.Path) + RespondWithError(w, http.StatusForbidden, fmt.Sprintf("%d Forbidden - Potential exploit detected", http.StatusForbidden), blockCommon.ErrorInterceptor) return } } diff --git a/internal/middleware/error-interceptor.go b/internal/middleware/error-interceptor.go index 76235c4..c71d20e 100644 --- a/internal/middleware/error-interceptor.go +++ b/internal/middleware/error-interceptor.go @@ -22,6 +22,7 @@ import ( "github.com/jkaninda/goma-gateway/pkg/logger" "io" "net/http" + "slices" ) func newResponseRecorder(w http.ResponseWriter) *responseRecorder { @@ -62,6 +63,7 @@ func (intercept InterceptErrors) ErrorInterceptor(next http.Handler) http.Handle if err != nil { return } + return } else { // No error: write buffered response to client w.WriteHeader(rec.statusCode) @@ -69,18 +71,12 @@ func (intercept InterceptErrors) ErrorInterceptor(next http.Handler) http.Handle if err != nil { return } + return } }) } func canIntercept(code int, errors []int) bool { - for _, er := range errors { - if er == code { - return true - } - continue - - } - return false + return slices.Contains(errors, code) } diff --git a/internal/middleware/helpers.go b/internal/middleware/helpers.go index 65b95c3..fcd0f35 100644 --- a/internal/middleware/helpers.go +++ b/internal/middleware/helpers.go @@ -17,7 +17,12 @@ package middleware -import "net/http" +import ( + "encoding/json" + "fmt" + errorinterceptor "github.com/jkaninda/goma-gateway/pkg/error-interceptor" + "net/http" +) func getRealIP(r *http.Request) string { if ip := r.Header.Get("X-Real-IP"); ip != "" { @@ -38,3 +43,53 @@ func allowedOrigin(origins []string, origin string) bool { return false } +func canInterceptError(code int, errors []errorinterceptor.Error) bool { + for _, er := range errors { + if er.Code == code { + return true + } + continue + + } + return false +} +func errMessage(code int, errors []errorinterceptor.Error) (string, error) { + for _, er := range errors { + if er.Code == code { + if len(er.Message) != 0 { + return er.Message, nil + } + continue + } + } + return "", fmt.Errorf("%d errors occurred", code) +} + +// RespondWithError is a helper function to handle error responses with flexible content type +func RespondWithError(w http.ResponseWriter, statusCode int, logMessage string, errorIntercept errorinterceptor.ErrorInterceptor) { + message, err := errMessage(statusCode, errorIntercept.Errors) + if err != nil { + message = logMessage + } + if errorIntercept.ContentType == errorinterceptor.ApplicationJson { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + err := json.NewEncoder(w).Encode(ProxyResponseError{ + Success: false, + Code: statusCode, + Message: message, + }) + if err != nil { + return + } + return + } else { + w.Header().Set("Content-Type", "plain/text;charset=utf-8") + w.WriteHeader(statusCode) + _, err2 := w.Write([]byte(message)) + if err2 != nil { + return + } + return + } +} diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index c8fa848..35c2c67 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -17,7 +17,6 @@ limitations under the License. */ import ( "encoding/base64" - "encoding/json" "github.com/jkaninda/goma-gateway/pkg/logger" "io" "net/http" @@ -38,48 +37,23 @@ func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler { if allowedOrigin(jwtAuth.Origins, r.Header.Get("Origin")) { w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) } - w.WriteHeader(http.StatusUnauthorized) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Message: http.StatusText(http.StatusUnauthorized), - Code: http.StatusUnauthorized, - Success: false, - }) - if err != nil { - return - } + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), jwtAuth.ErrorInterceptor) return + } } //token := r.Header.Get("Authorization") authURL, err := url.Parse(jwtAuth.AuthURL) if err != nil { logger.Error("Error parsing auth URL: %v", err) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusInternalServerError) - err = json.NewEncoder(w).Encode(ProxyResponseError{ - Message: "Internal Server Error", - Code: http.StatusInternalServerError, - Success: false, - }) - if err != nil { - return - } + RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError), jwtAuth.ErrorInterceptor) return } // Create a new request for /authentication authReq, err := http.NewRequest("GET", authURL.String(), nil) if err != nil { logger.Error("Proxy error creating authentication request: %v", err) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusInternalServerError) - err = json.NewEncoder(w).Encode(ProxyResponseError{ - Message: "Internal Server Error", - Code: http.StatusInternalServerError, - Success: false, - }) - if err != nil { - return - } + RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError), jwtAuth.ErrorInterceptor) return } logger.Trace("JWT Auth response headers: %v", authReq.Header) @@ -99,16 +73,7 @@ func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler { if err != nil || authResp.StatusCode != http.StatusOK { logger.Debug("%s %s %s %s", r.Method, getRealIP(r), r.URL, r.UserAgent()) logger.Debug("Proxy authentication error") - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusUnauthorized) - err = json.NewEncoder(w).Encode(ProxyResponseError{ - Message: "Unauthorized", - Code: http.StatusUnauthorized, - Success: false, - }) - if err != nil { - return - } + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), jwtAuth.ErrorInterceptor) return } defer func(Body io.ReadCloser) { @@ -146,31 +111,14 @@ func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler { if authHeader == "" { logger.Debug("Proxy error, missing Authorization header") w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusUnauthorized) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: http.StatusUnauthorized, - Message: http.StatusText(http.StatusUnauthorized), - }) - if err != nil { - return - } + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), basicAuth.ErrorInterceptor) return } // Check if the Authorization header contains "Basic" scheme if !strings.HasPrefix(authHeader, "Basic ") { logger.Error("Proxy error, missing Basic Authorization header") - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusUnauthorized) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: http.StatusUnauthorized, - Message: http.StatusText(http.StatusUnauthorized), - }) - if err != nil { - return - } + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), basicAuth.ErrorInterceptor) + return } @@ -178,16 +126,7 @@ func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler { payload, err := base64.StdEncoding.DecodeString(authHeader[len("Basic "):]) if err != nil { logger.Debug("Proxy error, missing Basic Authorization header") - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusUnauthorized) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: http.StatusUnauthorized, - Message: http.StatusText(http.StatusUnauthorized), - }) - if err != nil { - return - } + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), basicAuth.ErrorInterceptor) return } @@ -195,16 +134,7 @@ func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler { pair := strings.SplitN(string(payload), ":", 2) if len(pair) != 2 || pair[0] != basicAuth.Username || pair[1] != basicAuth.Password { w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusUnauthorized) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: http.StatusUnauthorized, - Message: http.StatusText(http.StatusUnauthorized), - }) - if err != nil { - return - } + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), basicAuth.ErrorInterceptor) return } diff --git a/internal/middleware/rate-limit.go b/internal/middleware/rate-limit.go index 125a10c..050ae76 100644 --- a/internal/middleware/rate-limit.go +++ b/internal/middleware/rate-limit.go @@ -16,7 +16,7 @@ See the License for the specific language governing permissions and limitations under the License. */ import ( - "encoding/json" + "fmt" "github.com/gorilla/mux" "github.com/jkaninda/goma-gateway/pkg/logger" "net/http" @@ -28,20 +28,17 @@ func (rl *TokenRateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !rl.Allow() { + logger.Error("Too many requests from IP: %s %s %s", getRealIP(r), r.URL, r.UserAgent()) + //RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), basicAuth.ErrorInterceptor) + // Rate limit exceeded, return a 429 Too Many Requests response - w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusTooManyRequests) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: http.StatusTooManyRequests, - Message: "Too many requests, API rate limit exceeded. Please try again later.", - }) + _, err := w.Write([]byte(fmt.Sprintf("%d Too many requests, API rate limit exceeded. Please try again later", http.StatusTooManyRequests))) if err != nil { return } return } - // Proceed to the next handler if rate limit is not exceeded next.ServeHTTP(w, r) }) @@ -66,21 +63,12 @@ func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { rl.mu.Unlock() if client.RequestCount > rl.Requests { - logger.Debug("Too many requests from IP: %s %s %s", clientID, r.URL, r.UserAgent()) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusTooManyRequests) + logger.Error("Too many requests from IP: %s %s %s", clientID, r.URL, r.UserAgent()) //Update Origin Cors Headers if allowedOrigin(rl.Origins, r.Header.Get("Origin")) { w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) } - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: http.StatusTooManyRequests, - Message: "Too many requests, API rate limit exceeded. Please try again later.", - }) - if err != nil { - return - } + RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API rate limit exceeded. Please try again later", http.StatusTooManyRequests), rl.ErrorInterceptor) return } // Proceed to the next handler if rate limit is not exceeded diff --git a/internal/middleware/route_error_interceptor.go b/internal/middleware/route_error_interceptor.go new file mode 100644 index 0000000..ec8c1e2 --- /dev/null +++ b/internal/middleware/route_error_interceptor.go @@ -0,0 +1,58 @@ +package middleware + +/* + * Copyright 2024 Jonas Kaninda + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +import ( + errorinterceptor "github.com/jkaninda/goma-gateway/pkg/error-interceptor" + "github.com/jkaninda/goma-gateway/pkg/logger" + "io" + "net/http" +) + +// RouteErrorInterceptor contains backend status code errors to intercept +type RouteErrorInterceptor struct { + Origins []string + ErrorInterceptor errorinterceptor.ErrorInterceptor +} + +// RouteErrorInterceptor Middleware intercepts backend route errors +func (intercept RouteErrorInterceptor) RouteErrorInterceptor(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rec := newResponseRecorder(w) + next.ServeHTTP(rec, r) + if canInterceptError(rec.statusCode, intercept.ErrorInterceptor.Errors) { + logger.Debug("Backend error") + logger.Error("An error occurred from the backend with the status code: %d", rec.statusCode) + //Update Origin Cors Headers + if allowedOrigin(intercept.Origins, r.Header.Get("Origin")) { + w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) + } + RespondWithError(w, rec.statusCode, http.StatusText(rec.statusCode), intercept.ErrorInterceptor) + return + } else { + // No error: write buffered response to client + w.WriteHeader(rec.statusCode) + _, err := io.Copy(w, rec.body) + if err != nil { + return + } + return + + } + + }) +} diff --git a/internal/middleware/types.go b/internal/middleware/types.go index 54bebee..15f695b 100644 --- a/internal/middleware/types.go +++ b/internal/middleware/types.go @@ -19,6 +19,7 @@ package middleware import ( "bytes" + errorinterceptor "github.com/jkaninda/goma-gateway/pkg/error-interceptor" "net/http" "sync" "time" @@ -26,11 +27,12 @@ import ( // RateLimiter defines rate limit properties. type RateLimiter struct { - Requests int - Window time.Duration - ClientMap map[string]*Client - mu sync.Mutex - Origins []string + Requests int + Window time.Duration + ClientMap map[string]*Client + mu sync.Mutex + Origins []string + ErrorInterceptor errorinterceptor.ErrorInterceptor } // Client stores request count and window expiration for each client. @@ -67,11 +69,12 @@ type ProxyResponseError struct { // JwtAuth stores JWT configuration type JwtAuth struct { - AuthURL string - RequiredHeaders []string - Headers map[string]string - Params map[string]string - Origins []string + AuthURL string + RequiredHeaders []string + Headers map[string]string + Params map[string]string + Origins []string + ErrorInterceptor errorinterceptor.ErrorInterceptor } // AuthenticationMiddleware Define struct @@ -82,17 +85,19 @@ type AuthenticationMiddleware struct { Params map[string]string } type AccessListMiddleware struct { - Path string - Destination string - List []string + Path string + Destination string + List []string + ErrorInterceptor errorinterceptor.ErrorInterceptor } // AuthBasic contains Basic auth configuration type AuthBasic struct { - Username string - Password string - Headers map[string]string - Params map[string]string + Username string + Password string + Headers map[string]string + Params map[string]string + ErrorInterceptor errorinterceptor.ErrorInterceptor } // InterceptErrors contains backend status code errors to intercept @@ -120,10 +125,11 @@ type Oauth struct { // Scope specifies optional requested permissions. Scopes []string // contains filtered or unexported fields - State string - Origins []string - JWTSecret string - Provider string + State string + Origins []string + JWTSecret string + Provider string + ErrorInterceptor errorinterceptor.ErrorInterceptor } type OauthEndpoint struct { AuthURL string diff --git a/internal/proxy.go b/internal/proxy.go index 6478078..bc6b90e 100644 --- a/internal/proxy.go +++ b/internal/proxy.go @@ -17,6 +17,7 @@ limitations under the License. */ import ( "fmt" + "github.com/jkaninda/goma-gateway/internal/middleware" "github.com/jkaninda/goma-gateway/pkg/logger" "net/http" "net/http/httputil" @@ -36,11 +37,7 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc { if len(proxyRoute.methods) > 0 { if !slices.Contains(proxyRoute.methods, r.Method) { logger.Error("%s Method is not allowed", r.Method) - w.WriteHeader(http.StatusMethodNotAllowed) - _, err := w.Write([]byte(fmt.Sprintf("%s method is not allowed", r.Method))) - if err != nil { - return - } + middleware.RespondWithError(w, http.StatusMethodNotAllowed, fmt.Sprintf("%d %s method is not allowed", http.StatusMethodNotAllowed, r.Method), proxyRoute.ErrorInterceptor) return } } @@ -63,11 +60,7 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc { targetURL, err := url.Parse(proxyRoute.destination) if err != nil { logger.Error("Error parsing backend URL: %s", err) - w.WriteHeader(http.StatusInternalServerError) - _, err := w.Write([]byte("Internal Server Error")) - if err != nil { - return - } + middleware.RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError), proxyRoute.ErrorInterceptor) return } r.Header.Set("X-Forwarded-Host", r.Header.Get("Host")) diff --git a/internal/route.go b/internal/route.go index c735aca..a0299b9 100644 --- a/internal/route.go +++ b/internal/route.go @@ -58,7 +58,8 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { // Enable common exploits if gateway.BlockCommonExploits { logger.Info("Block common exploits enabled") - r.Use(middleware.BlockExploitsMiddleware) + blockCommon := middleware.BlockCommon{} + r.Use(blockCommon.BlockExploitsMiddleware) } if gateway.RateLimit != 0 { //rateLimiter := middleware.NewRateLimiter(gateway.RateLimit, time.Minute) @@ -219,8 +220,11 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { // Apply common exploits to the route // Enable common exploits if route.BlockCommonExploits { + blockCommon := middleware.BlockCommon{ + ErrorInterceptor: route.ErrorInterceptor, + } logger.Info("Block common exploits enabled") - router.Use(middleware.BlockExploitsMiddleware) + router.Use(blockCommon.BlockExploitsMiddleware) } // Apply route rate limit if route.RateLimit > 0 { @@ -246,6 +250,12 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { // Prometheus endpoint router.Use(pr.prometheusMiddleware) } + // Apply route Error interceptor middleware + interceptErrors := middleware.RouteErrorInterceptor{ + Origins: gateway.Cors.Origins, + ErrorInterceptor: route.ErrorInterceptor, + } + r.Use(interceptErrors.RouteErrorInterceptor) } else { logger.Error("Error, path is empty in route %s", route.Name) logger.Error("Route path ignored: %s", route.Path) diff --git a/internal/types.go b/internal/types.go index 143b1ad..b1b600f 100644 --- a/internal/types.go +++ b/internal/types.go @@ -20,6 +20,7 @@ package pkg import ( "context" "github.com/gorilla/mux" + errorinterceptor "github.com/jkaninda/goma-gateway/pkg/error-interceptor" "time" ) @@ -161,12 +162,12 @@ type Route struct { // // It will not match the backend route DisableHostFording bool `yaml:"disableHostFording"` - // InterceptErrors intercepts backend errors based on the status codes - // - // Eg: [ 403, 405, 500 ] - InterceptErrors []int `yaml:"interceptErrors"` + // BlockCommonExploits enable, disable block common exploits BlockCommonExploits bool `yaml:"blockCommonExploits"` + // ErrorInterceptor intercepts backend errors based on the status codes and custom message + // + ErrorInterceptor errorinterceptor.ErrorInterceptor `yaml:"errorInterceptor"` // Middlewares Defines route middleware from Middleware names Middlewares []string `yaml:"middlewares"` } @@ -242,6 +243,7 @@ type ProxyRoute struct { methods []string cors Cors disableHostFording bool + ErrorInterceptor errorinterceptor.ErrorInterceptor } type RoutePath struct { route Route @@ -285,3 +287,16 @@ type Health struct { Interval string HealthyStatuses []int } + +//type ErrorInterceptor struct { +// // ContentType error response content type, application/json, plain/text +// ContentType string `yaml:"contentType"` +// //Errors contains error status code and custom message +// Errors []ErrorInterceptor `yaml:"errors"` +//} +//type ErrorInterceptor struct { +// // Code HTTP status code +// Code int `yaml:"code"` +// // Message custom message +// Message string `yaml:"message"` +//} diff --git a/internal/var.go b/internal/var.go index baf9516..b57bb0d 100644 --- a/internal/var.go +++ b/internal/var.go @@ -9,6 +9,10 @@ const AccessMiddleware = "access" // access middleware const BasicAuth = "basic" // basic authentication middleware const JWTAuth = "jwt" // JWT authentication middleware const OAuth = "oauth" // OAuth authentication middleware +const applicationJson = "application/json" +const textPlain = "text/plain" +const applicationXml = "application/xml" + // Round-robin counter var counter uint32 diff --git a/pkg/error-interceptor/types.go b/pkg/error-interceptor/types.go new file mode 100644 index 0000000..95d3bf4 --- /dev/null +++ b/pkg/error-interceptor/types.go @@ -0,0 +1,31 @@ +/* + * Copyright 2024 Jonas Kaninda + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package error_interceptor + +type ErrorInterceptor struct { + // ContentType error response content type, application/json, plain/text + ContentType string `yaml:"contentType"` + //Errors contains error status code and custom message + Errors []Error `yaml:"errors"` +} +type Error struct { + // Code HTTP status code + Code int `yaml:"code"` + // Message custom message + Message string `yaml:"message"` +} diff --git a/pkg/error-interceptor/var.go b/pkg/error-interceptor/var.go new file mode 100644 index 0000000..267abe7 --- /dev/null +++ b/pkg/error-interceptor/var.go @@ -0,0 +1,22 @@ +/* + * Copyright 2024 Jonas Kaninda + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package error_interceptor + +const TextPlain = "text/plain" +const ApplicationXml = "application/xml" +const ApplicationJson = "application/json"