fix: backend error interceptor

This commit is contained in:
Jonas Kaninda
2024-11-14 14:41:10 +01:00
parent 5951616153
commit 949667cc60
6 changed files with 12 additions and 26 deletions

View File

@@ -30,7 +30,7 @@ func (blockList AccessListMiddleware) AccessMiddleware(next http.Handler) http.H
for _, block := range blockList.List {
if isPathBlocked(r.URL.Path, util.ParseURLPath(blockList.Path+block)) {
logger.Error("%s: %s access forbidden", getRealIP(r), r.URL.Path)
RespondWithError(w, http.StatusForbidden, fmt.Sprintf("%d you do not have permission to access this resource"))
RespondWithError(w, http.StatusForbidden, fmt.Sprintf("%d you do not have permission to access this resource", http.StatusForbidden))
return
}
}

View File

@@ -19,18 +19,13 @@ package middleware
import (
"fmt"
errorinterceptor "github.com/jkaninda/goma-gateway/pkg/errorinterceptor"
"github.com/jkaninda/goma-gateway/pkg/logger"
"net/http"
"regexp"
)
type BlockCommon struct {
ErrorInterceptor errorinterceptor.ErrorInterceptor
}
// BlockExploitsMiddleware Middleware to block common exploits
func (blockCommon BlockCommon) BlockExploitsMiddleware(next http.Handler) http.Handler {
func BlockExploitsMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Patterns to detect SQL injection attempts
sqlInjectionPattern := regexp.MustCompile(sqlPatterns)

View File

@@ -45,15 +45,12 @@ func (intercept InterceptErrors) ErrorInterceptor(next http.Handler) http.Handle
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rec := newResponseRecorder(w)
next.ServeHTTP(rec, r)
w.Header().Set("Proxied-By", "Goma Gateway")
w.Header().Del("Server") //Delete server name
if canIntercept(rec.statusCode, intercept.Errors) {
logger.Debug("Backend error")
logger.Error("An error occurred from the backend with the status code: %d", rec.statusCode)
//Update Origin Cors Headers
if allowedOrigin(intercept.Origins, r.Header.Get("Origin")) {
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
}
logger.Debug("An error occurred in the backend, %d", rec.statusCode)
logger.Error("Backend error: %d", rec.statusCode)
RespondWithError(w, rec.statusCode, http.StatusText(rec.statusCode))
return
} else {
// No error: write buffered response to client
w.WriteHeader(rec.statusCode)
@@ -61,7 +58,6 @@ func (intercept InterceptErrors) ErrorInterceptor(next http.Handler) http.Handle
if err != nil {
return
}
return
}

View File

@@ -75,6 +75,5 @@ func RespondWithError(w http.ResponseWriter, statusCode int, logMessage string)
if err != nil {
return
}
return
}

View File

@@ -27,7 +27,7 @@ import (
// RateLimiter defines requests limit properties.
type RateLimiter struct {
requests int
id int
id string
window time.Duration
clientMap map[string]*Client
mu sync.Mutex
@@ -42,7 +42,7 @@ type Client struct {
ExpiresAt time.Time
}
type RateLimit struct {
Id int
Id string
Requests int
Window time.Duration
Origins []string

View File

@@ -62,13 +62,12 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
// Enable common exploits
if gateway.BlockCommonExploits {
logger.Info("Block common exploits enabled")
blockCommon := middleware.BlockCommon{}
r.Use(blockCommon.BlockExploitsMiddleware)
r.Use(middleware.BlockExploitsMiddleware)
}
if gateway.RateLimit > 0 {
// Add rate limit middleware to all routes, if defined
rateLimit := middleware.RateLimit{
Id: 1,
Id: "global_rate", //Generate a unique ID for routes
Requests: gateway.RateLimit,
Window: time.Minute, // requests per minute
Origins: gateway.Cors.Origins,
@@ -232,16 +231,13 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
// Apply common exploits to the route
// Enable common exploits
if route.BlockCommonExploits {
blockCommon := middleware.BlockCommon{
ErrorInterceptor: route.ErrorInterceptor,
}
logger.Info("Block common exploits enabled")
router.Use(blockCommon.BlockExploitsMiddleware)
router.Use(middleware.BlockExploitsMiddleware)
}
// Apply route rate limit
if route.RateLimit > 0 {
rateLimit := middleware.RateLimit{
Id: rIndex,
Id: string(rune(rIndex)), // Use route index as ID
Requests: route.RateLimit,
Window: time.Minute, // requests per minute
Origins: route.Cors.Origins,