diff --git a/go.mod b/go.mod index 13ce334..8f22857 100644 --- a/go.mod +++ b/go.mod @@ -4,10 +4,14 @@ go 1.23.2 require ( github.com/common-nighthawk/go-figure v0.0.0-20210622060536-734e95fb86be + github.com/go-redis/redis_rate/v10 v10.0.1 github.com/golang-jwt/jwt v3.2.2+incompatible github.com/gorilla/mux v1.8.1 github.com/prometheus/client_golang v1.20.5 + github.com/redis/go-redis/v9 v9.7.0 + github.com/robfig/cron/v3 v3.0.1 github.com/spf13/cobra v1.8.1 + golang.org/x/net v0.26.0 golang.org/x/oauth2 v0.24.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -23,13 +27,13 @@ require ( cloud.google.com/go/compute/metadata v0.3.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/klauspost/compress v1.17.9 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.55.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect - github.com/robfig/cron/v3 v3.0.1 // indirect github.com/spf13/pflag v1.0.5 // indirect google.golang.org/protobuf v1.34.2 // indirect diff --git a/go.sum b/go.sum index 5c8caf9..4126899 100644 --- a/go.sum +++ b/go.sum @@ -9,6 +9,10 @@ github.com/common-nighthawk/go-figure v0.0.0-20210622060536-734e95fb86be/go.mod github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/go-redis/redis_rate/v10 v10.0.1 h1:calPxi7tVlxojKunJwQ72kwfozdy25RjA0bCj1h0MUo= +github.com/go-redis/redis_rate/v10 v10.0.1/go.mod h1:EMiuO9+cjRkR7UvdvwMO7vbgqJkltQHtwbdIQvaBKIU= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= @@ -37,6 +41,8 @@ github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= +github.com/redis/go-redis/v9 v9.7.0 h1:HhLSs+B6O021gwzl+locl0zEDnyNkxMtf/Z3NNBMa9E= +github.com/redis/go-redis/v9 v9.7.0/go.mod h1:f6zhXITC7JUJIlPEiBOTXxJgPLdZcA93GewI7inzyWw= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= @@ -50,6 +56,8 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= +golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs= golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE= diff --git a/internal/config.go b/internal/config.go index faf15e5..6e9fa2c 100644 --- a/internal/config.go +++ b/internal/config.go @@ -18,6 +18,7 @@ limitations under the License. import ( "fmt" "github.com/jkaninda/goma-gateway/internal/middleware" + "github.com/jkaninda/goma-gateway/pkg/errorinterceptor" "github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/util" "golang.org/x/oauth2" @@ -27,6 +28,7 @@ import ( "golang.org/x/oauth2/gitlab" "golang.org/x/oauth2/google" "gopkg.in/yaml.v3" + "net/http" "os" ) @@ -180,11 +182,23 @@ func initConfig(configFile string) error { Middlewares: []string{"basic-auth", "api-forbidden-paths"}, }, { - Path: "/", - Name: "Hostname and load balancing example", - Hosts: []string{"example.com", "example.localhost"}, - InterceptErrors: []int{404, 405, 500}, - RateLimit: 60, + Path: "/", + Name: "Hostname and load balancing example", + Hosts: []string{"example.com", "example.localhost"}, + //ErrorIntercept: []int{404, 405, 500}, + ErrorInterceptor: errorinterceptor.ErrorInterceptor{ + Errors: []errorinterceptor.Error{ + { + Code: http.StatusUnauthorized, + Message: http.StatusText(http.StatusUnauthorized), + }, + { + Code: http.StatusInternalServerError, + Message: http.StatusText(http.StatusInternalServerError), + }, + }, + }, + RateLimit: 60, Backends: []string{ "https://example.com", "https://example2.com", diff --git a/internal/middleware/access-middleware.go b/internal/middleware/access-middleware.go index b581c73..11ddd5f 100644 --- a/internal/middleware/access-middleware.go +++ b/internal/middleware/access-middleware.go @@ -16,7 +16,6 @@ See the License for the specific language governing permissions and limitations under the License. */ import ( - "encoding/json" "fmt" "github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/util" @@ -31,16 +30,7 @@ func (blockList AccessListMiddleware) AccessMiddleware(next http.Handler) http.H for _, block := range blockList.List { if isPathBlocked(r.URL.Path, util.ParseURLPath(blockList.Path+block)) { logger.Error("%s: %s access forbidden", getRealIP(r), r.URL.Path) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusForbidden) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: http.StatusForbidden, - Message: fmt.Sprintf("You do not have permission to access this resource"), - }) - if err != nil { - return - } + RespondWithError(w, http.StatusForbidden, fmt.Sprintf("%d you do not have permission to access this resource", http.StatusForbidden)) return } } @@ -64,7 +54,7 @@ func isPathBlocked(requestPath, blockedPath string) bool { return false } -// NewRateLimiter creates a new rate limiter with the specified refill rate and token capacity +// NewRateLimiter creates a new requests limiter with the specified refill requests and token capacity func NewRateLimiter(maxTokens int, refillRate time.Duration) *TokenRateLimiter { return &TokenRateLimiter{ tokens: maxTokens, diff --git a/internal/middleware/block-common-exploits.go b/internal/middleware/block-common-exploits.go index 8a82534..b3f8f2b 100644 --- a/internal/middleware/block-common-exploits.go +++ b/internal/middleware/block-common-exploits.go @@ -18,7 +18,6 @@ package middleware import ( - "encoding/json" "fmt" "github.com/jkaninda/goma-gateway/pkg/logger" "net/http" @@ -42,36 +41,18 @@ func BlockExploitsMiddleware(next http.Handler) http.Handler { pathTraversalPattern.MatchString(r.URL.Path) || xssPattern.MatchString(r.URL.RawQuery) { logger.Error("%s: %s Forbidden - Potential exploit detected", getRealIP(r), r.URL.Path) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusForbidden) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: http.StatusForbidden, - Message: fmt.Sprintf("Forbidden - Potential exploit detected"), - }) - if err != nil { - return - } + RespondWithError(w, http.StatusForbidden, fmt.Sprintf("%d Forbidden - Potential exploit detected", http.StatusForbidden)) return - } + } // Check form data (for POST requests) if r.Method == http.MethodPost { if err := r.ParseForm(); err == nil { for _, values := range r.Form { for _, value := range values { if sqlInjectionPattern.MatchString(value) || xssPattern.MatchString(value) { - logger.Error("%s: %s Forbidden - Potential exploit detected", getRealIP(r), r.URL.Path) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusForbidden) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: http.StatusForbidden, - Message: fmt.Sprintf("Forbidden - Potential exploit detected"), - }) - if err != nil { - return - } + logger.Error("%s: %s %s Forbidden - Potential exploit detected", getRealIP(r), r.Method, r.URL.Path) + RespondWithError(w, http.StatusForbidden, fmt.Sprintf("%d Forbidden - Potential exploit detected", http.StatusForbidden)) return } } diff --git a/internal/middleware/error-interceptor.go b/internal/middleware/error-interceptor.go index 76235c4..9c25127 100644 --- a/internal/middleware/error-interceptor.go +++ b/internal/middleware/error-interceptor.go @@ -18,10 +18,10 @@ package middleware */ import ( "bytes" - "encoding/json" "github.com/jkaninda/goma-gateway/pkg/logger" "io" "net/http" + "slices" ) func newResponseRecorder(w http.ResponseWriter) *responseRecorder { @@ -45,23 +45,12 @@ func (intercept InterceptErrors) ErrorInterceptor(next http.Handler) http.Handle return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { rec := newResponseRecorder(w) next.ServeHTTP(rec, r) + w.Header().Set("Proxied-By", "Goma Gateway") + w.Header().Del("Server") //Delete server name if canIntercept(rec.statusCode, intercept.Errors) { - logger.Debug("Backend error") - logger.Error("An error occurred from the backend with the status code: %d", rec.statusCode) - w.Header().Set("Content-Type", "application/json") - //Update Origin Cors Headers - if allowedOrigin(intercept.Origins, r.Header.Get("Origin")) { - w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) - } - w.WriteHeader(rec.statusCode) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: rec.statusCode, - Message: http.StatusText(rec.statusCode), - }) - if err != nil { - return - } + logger.Debug("An error occurred in the backend, %d", rec.statusCode) + logger.Error("Backend error: %d", rec.statusCode) + RespondWithError(w, rec.statusCode, http.StatusText(rec.statusCode)) } else { // No error: write buffered response to client w.WriteHeader(rec.statusCode) @@ -75,12 +64,5 @@ func (intercept InterceptErrors) ErrorInterceptor(next http.Handler) http.Handle }) } func canIntercept(code int, errors []int) bool { - for _, er := range errors { - if er == code { - return true - } - continue - - } - return false + return slices.Contains(errors, code) } diff --git a/internal/middleware/helpers.go b/internal/middleware/helpers.go index 65b95c3..1594165 100644 --- a/internal/middleware/helpers.go +++ b/internal/middleware/helpers.go @@ -17,7 +17,13 @@ package middleware -import "net/http" +import ( + "encoding/json" + "fmt" + "github.com/jkaninda/goma-gateway/pkg/errorinterceptor" + "net/http" + "slices" +) func getRealIP(r *http.Request) string { if ip := r.Header.Get("X-Real-IP"); ip != "" { @@ -29,12 +35,45 @@ func getRealIP(r *http.Request) string { return r.RemoteAddr } func allowedOrigin(origins []string, origin string) bool { - for _, o := range origins { - if o == origin { + return slices.Contains(origins, origin) +} +func canInterceptError(code int, errors []errorinterceptor.Error) bool { + for _, er := range errors { + if er.Code == code { return true } continue + } return false +} +func errMessage(code int, errors []errorinterceptor.Error) (string, error) { + for _, er := range errors { + if er.Code == code { + if len(er.Message) != 0 { + return er.Message, nil + } + continue + } + } + return "", fmt.Errorf("%d errors occurred", code) +} + +// RespondWithError is a helper function to handle error responses with flexible content type +func RespondWithError(w http.ResponseWriter, statusCode int, logMessage string) { + message := http.StatusText(statusCode) + if len(logMessage) != 0 { + message = logMessage + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + err := json.NewEncoder(w).Encode(ProxyResponseError{ + Success: false, + Code: statusCode, + Message: message, + }) + if err != nil { + return + } } diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index c8fa848..c2e6461 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -17,7 +17,6 @@ limitations under the License. */ import ( "encoding/base64" - "encoding/json" "github.com/jkaninda/goma-gateway/pkg/logger" "io" "net/http" @@ -38,48 +37,23 @@ func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler { if allowedOrigin(jwtAuth.Origins, r.Header.Get("Origin")) { w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) } - w.WriteHeader(http.StatusUnauthorized) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Message: http.StatusText(http.StatusUnauthorized), - Code: http.StatusUnauthorized, - Success: false, - }) - if err != nil { - return - } + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) return + } } //token := r.Header.Get("Authorization") authURL, err := url.Parse(jwtAuth.AuthURL) if err != nil { logger.Error("Error parsing auth URL: %v", err) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusInternalServerError) - err = json.NewEncoder(w).Encode(ProxyResponseError{ - Message: "Internal Server Error", - Code: http.StatusInternalServerError, - Success: false, - }) - if err != nil { - return - } + RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) return } // Create a new request for /authentication authReq, err := http.NewRequest("GET", authURL.String(), nil) if err != nil { logger.Error("Proxy error creating authentication request: %v", err) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusInternalServerError) - err = json.NewEncoder(w).Encode(ProxyResponseError{ - Message: "Internal Server Error", - Code: http.StatusInternalServerError, - Success: false, - }) - if err != nil { - return - } + RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) return } logger.Trace("JWT Auth response headers: %v", authReq.Header) @@ -99,16 +73,7 @@ func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler { if err != nil || authResp.StatusCode != http.StatusOK { logger.Debug("%s %s %s %s", r.Method, getRealIP(r), r.URL, r.UserAgent()) logger.Debug("Proxy authentication error") - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusUnauthorized) - err = json.NewEncoder(w).Encode(ProxyResponseError{ - Message: "Unauthorized", - Code: http.StatusUnauthorized, - Success: false, - }) - if err != nil { - return - } + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) return } defer func(Body io.ReadCloser) { @@ -146,31 +111,14 @@ func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler { if authHeader == "" { logger.Debug("Proxy error, missing Authorization header") w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusUnauthorized) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: http.StatusUnauthorized, - Message: http.StatusText(http.StatusUnauthorized), - }) - if err != nil { - return - } + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) return } // Check if the Authorization header contains "Basic" scheme if !strings.HasPrefix(authHeader, "Basic ") { logger.Error("Proxy error, missing Basic Authorization header") - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusUnauthorized) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: http.StatusUnauthorized, - Message: http.StatusText(http.StatusUnauthorized), - }) - if err != nil { - return - } + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) + return } @@ -178,16 +126,7 @@ func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler { payload, err := base64.StdEncoding.DecodeString(authHeader[len("Basic "):]) if err != nil { logger.Debug("Proxy error, missing Basic Authorization header") - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusUnauthorized) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: http.StatusUnauthorized, - Message: http.StatusText(http.StatusUnauthorized), - }) - if err != nil { - return - } + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) return } @@ -195,16 +134,7 @@ func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler { pair := strings.SplitN(string(payload), ":", 2) if len(pair) != 2 || pair[0] != basicAuth.Username || pair[1] != basicAuth.Password { w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusUnauthorized) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: http.StatusUnauthorized, - Message: http.StatusText(http.StatusUnauthorized), - }) - if err != nil { - return - } + RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) return } diff --git a/internal/middleware/rate-limit.go b/internal/middleware/rate-limit.go index 125a10c..2d2942a 100644 --- a/internal/middleware/rate-limit.go +++ b/internal/middleware/rate-limit.go @@ -16,9 +16,13 @@ See the License for the specific language governing permissions and limitations under the License. */ import ( - "encoding/json" + "errors" + "fmt" + "github.com/go-redis/redis_rate/v10" "github.com/gorilla/mux" "github.com/jkaninda/goma-gateway/pkg/logger" + "github.com/redis/go-redis/v9" + "golang.org/x/net/context" "net/http" "time" ) @@ -28,21 +32,18 @@ func (rl *TokenRateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !rl.Allow() { + logger.Error("Too many requests from IP: %s %s %s", getRealIP(r), r.URL, r.UserAgent()) + //RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), basicAuth.ErrorInterceptor) + // Rate limit exceeded, return a 429 Too Many Requests response - w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusTooManyRequests) - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: http.StatusTooManyRequests, - Message: "Too many requests, API rate limit exceeded. Please try again later.", - }) + _, err := w.Write([]byte(fmt.Sprintf("%d Too many requests, API requests limit exceeded. Please try again later", http.StatusTooManyRequests))) if err != nil { return } return } - - // Proceed to the next handler if rate limit is not exceeded + // Proceed to the next handler if requests limit is not exceeded next.ServeHTTP(w, r) }) } @@ -52,39 +53,61 @@ func (rl *TokenRateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - clientID := getRealIP(r) - rl.mu.Lock() - client, exists := rl.ClientMap[clientID] - if !exists || time.Now().After(client.ExpiresAt) { - client = &Client{ - RequestCount: 0, - ExpiresAt: time.Now().Add(rl.Window), - } - rl.ClientMap[clientID] = client - } - client.RequestCount++ - rl.mu.Unlock() - - if client.RequestCount > rl.Requests { - logger.Debug("Too many requests from IP: %s %s %s", clientID, r.URL, r.UserAgent()) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusTooManyRequests) - //Update Origin Cors Headers - if allowedOrigin(rl.Origins, r.Header.Get("Origin")) { - w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) - } - err := json.NewEncoder(w).Encode(ProxyResponseError{ - Success: false, - Code: http.StatusTooManyRequests, - Message: "Too many requests, API rate limit exceeded. Please try again later.", - }) + clientIP := getRealIP(r) + clientID := fmt.Sprintf("%s-%s", rl.id, clientIP) // Generate client Id, ID+ route ID + logger.Debug("requests limiter: clientIP: %s, clientID: %s", clientIP, clientID) + if rl.redisBased { + err := redisRateLimiter(clientID, rl.requests) if err != nil { + logger.Error("Redis Rate limiter error: %s", err.Error()) + logger.Error("Too many requests from IP: %s %s %s", clientIP, r.URL, r.UserAgent()) + RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API requests limit exceeded. Please try again later", http.StatusTooManyRequests)) return } - return + } else { + rl.mu.Lock() + client, exists := rl.clientMap[clientID] + if !exists || time.Now().After(client.ExpiresAt) { + client = &Client{ + RequestCount: 0, + ExpiresAt: time.Now().Add(rl.window), + } + rl.clientMap[clientID] = client + } + client.RequestCount++ + rl.mu.Unlock() + + if client.RequestCount > rl.requests { + logger.Error("Too many requests from IP: %s %s %s", clientIP, r.URL, r.UserAgent()) + //Update Origin Cors Headers + if allowedOrigin(rl.origins, r.Header.Get("Origin")) { + w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) + } + RespondWithError(w, http.StatusTooManyRequests, fmt.Sprintf("%d Too many requests, API requests limit exceeded. Please try again later", http.StatusTooManyRequests)) + } } - // Proceed to the next handler if rate limit is not exceeded + // Proceed to the next handler if requests limit is not exceeded next.ServeHTTP(w, r) }) } } +func redisRateLimiter(clientIP string, rate int) error { + ctx := context.Background() + + res, err := limiter.Allow(ctx, clientIP, redis_rate.PerMinute(rate)) + if err != nil { + return err + } + if res.Remaining == 0 { + return errors.New("requests limit exceeded") + } + + return nil +} +func InitRedis(addr, password string) { + Rdb = redis.NewClient(&redis.Options{ + Addr: addr, + Password: password, + }) + limiter = redis_rate.NewLimiter(Rdb) +} diff --git a/internal/middleware/types.go b/internal/middleware/types.go index 54bebee..ad3d613 100644 --- a/internal/middleware/types.go +++ b/internal/middleware/types.go @@ -24,13 +24,16 @@ import ( "time" ) -// RateLimiter defines rate limit properties. +// RateLimiter defines requests limit properties. type RateLimiter struct { - Requests int - Window time.Duration - ClientMap map[string]*Client - mu sync.Mutex - Origins []string + requests int + id string + window time.Duration + clientMap map[string]*Client + mu sync.Mutex + origins []string + hosts []string + redisBased bool } // Client stores request count and window expiration for each client. @@ -38,14 +41,24 @@ type Client struct { RequestCount int ExpiresAt time.Time } +type RateLimit struct { + Id string + Requests int + Window time.Duration + Origins []string + Hosts []string + RedisBased bool +} // NewRateLimiterWindow creates a new RateLimiter. -func NewRateLimiterWindow(requests int, window time.Duration, origin []string) *RateLimiter { +func (rateLimit RateLimit) NewRateLimiterWindow() *RateLimiter { return &RateLimiter{ - Requests: requests, - Window: window, - ClientMap: make(map[string]*Client), - Origins: origin, + id: rateLimit.Id, + requests: rateLimit.Requests, + window: rateLimit.Window, + clientMap: make(map[string]*Client), + origins: rateLimit.Origins, + redisBased: rateLimit.RedisBased, } } diff --git a/internal/middleware/var.go b/internal/middleware/var.go index 4e8502c..5eb3266 100644 --- a/internal/middleware/var.go +++ b/internal/middleware/var.go @@ -17,7 +17,17 @@ package middleware +import ( + "github.com/go-redis/redis_rate/v10" + "github.com/redis/go-redis/v9" +) + // sqlPatterns contains SQL injections patters const sqlPatterns = `(?i)(union|select|drop|insert|delete|update|create|alter|exec|;|--)` const traversalPatterns = `\.\./` const xssPatterns = `(?i) 0 { if !slices.Contains(proxyRoute.methods, r.Method) { logger.Error("%s Method is not allowed", r.Method) - w.WriteHeader(http.StatusMethodNotAllowed) - _, err := w.Write([]byte(fmt.Sprintf("%s method is not allowed", r.Method))) - if err != nil { - return - } + middleware.RespondWithError(w, http.StatusMethodNotAllowed, fmt.Sprintf("%d %s method is not allowed", http.StatusMethodNotAllowed, r.Method)) return } } @@ -63,11 +60,7 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc { targetURL, err := url.Parse(proxyRoute.destination) if err != nil { logger.Error("Error parsing backend URL: %s", err) - w.WriteHeader(http.StatusInternalServerError) - _, err := w.Write([]byte("Internal Server Error")) - if err != nil { - return - } + middleware.RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) return } r.Header.Set("X-Forwarded-Host", r.Header.Get("Host")) diff --git a/internal/route.go b/internal/route.go index c735aca..a6fcbad 100644 --- a/internal/route.go +++ b/internal/route.go @@ -35,6 +35,10 @@ func init() { func (gatewayServer GatewayServer) Initialize() *mux.Router { gateway := gatewayServer.gateway middlewares := gatewayServer.middlewares + redisBased := false + if len(gateway.Redis.Addr) != 0 { + redisBased = true + } //Routes background healthcheck routesHealthCheck(gateway.Routes) r := mux.NewRouter() @@ -60,13 +64,21 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { logger.Info("Block common exploits enabled") r.Use(middleware.BlockExploitsMiddleware) } - if gateway.RateLimit != 0 { - //rateLimiter := middleware.NewRateLimiter(gateway.RateLimit, time.Minute) - limiter := middleware.NewRateLimiterWindow(gateway.RateLimit, time.Minute, gateway.Cors.Origins) // requests per minute + if gateway.RateLimit > 0 { // Add rate limit middleware to all routes, if defined + rateLimit := middleware.RateLimit{ + Id: "global_rate", //Generate a unique ID for routes + Requests: gateway.RateLimit, + Window: time.Minute, // requests per minute + Origins: gateway.Cors.Origins, + Hosts: []string{}, + RedisBased: redisBased, + } + limiter := rateLimit.NewRateLimiterWindow() + // Add rate limit middleware r.Use(limiter.RateLimitMiddleware()) } - for _, route := range gateway.Routes { + for rIndex, route := range gateway.Routes { if route.Path != "" { if route.Destination == "" && len(route.Backends) == 0 { logger.Fatal("Route %s : destination or backends should not be empty", route.Name) @@ -224,9 +236,16 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { } // Apply route rate limit if route.RateLimit > 0 { - //rateLimiter := middleware.NewRateLimiter(gateway.RateLimit, time.Minute) - limiter := middleware.NewRateLimiterWindow(route.RateLimit, time.Minute, route.Cors.Origins) // requests per minute - // Add rate limit middleware to all routes, if defined + rateLimit := middleware.RateLimit{ + Id: string(rune(rIndex)), // Use route index as ID + Requests: route.RateLimit, + Window: time.Minute, // requests per minute + Origins: route.Cors.Origins, + Hosts: route.Hosts, + RedisBased: redisBased, + } + limiter := rateLimit.NewRateLimiterWindow() + // Add rate limit middleware router.Use(limiter.RateLimitMiddleware()) } // Apply route Cors @@ -246,6 +265,11 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { // Prometheus endpoint router.Use(pr.prometheusMiddleware) } + // Apply route Error interceptor middleware + interceptErrors := middleware.InterceptErrors{ + Origins: gateway.Cors.Origins, + } + router.Use(interceptErrors.ErrorInterceptor) } else { logger.Error("Error, path is empty in route %s", route.Name) logger.Error("Route path ignored: %s", route.Path) diff --git a/internal/server.go b/internal/server.go index 6b43c7b..e71d2ca 100644 --- a/internal/server.go +++ b/internal/server.go @@ -19,7 +19,9 @@ import ( "context" "crypto/tls" "fmt" + "github.com/jkaninda/goma-gateway/internal/middleware" "github.com/jkaninda/goma-gateway/pkg/logger" + "github.com/redis/go-redis/v9" "net/http" "os" "sync" @@ -30,8 +32,19 @@ import ( func (gatewayServer GatewayServer) Start(ctx context.Context) error { logger.Info("Initializing routes...") route := gatewayServer.Initialize() + gateway := gatewayServer.gateway logger.Debug("Routes count=%d Middlewares count=%d", len(gatewayServer.gateway.Routes), len(gatewayServer.middlewares)) logger.Info("Initializing routes...done") + if len(gateway.Redis.Addr) != 0 { + middleware.InitRedis(gateway.Redis.Addr, gateway.Redis.Password) + defer func(Rdb *redis.Client) { + err := Rdb.Close() + if err != nil { + logger.Error("Redis connection closed with error: %v", err) + } + }(middleware.Rdb) + } + tlsConfig := &tls.Config{} var listenWithTLS = false if cert := gatewayServer.gateway.SSLCertFile; cert != "" && gatewayServer.gateway.SSLKeyFile != "" { diff --git a/internal/types.go b/internal/types.go index 143b1ad..c3c85e2 100644 --- a/internal/types.go +++ b/internal/types.go @@ -20,6 +20,7 @@ package pkg import ( "context" "github.com/gorilla/mux" + errorinterceptor "github.com/jkaninda/goma-gateway/pkg/errorinterceptor" "time" ) @@ -161,12 +162,12 @@ type Route struct { // // It will not match the backend route DisableHostFording bool `yaml:"disableHostFording"` - // InterceptErrors intercepts backend errors based on the status codes - // - // Eg: [ 403, 405, 500 ] - InterceptErrors []int `yaml:"interceptErrors"` + // BlockCommonExploits enable, disable block common exploits BlockCommonExploits bool `yaml:"blockCommonExploits"` + // ErrorInterceptor intercepts backend errors based on the status codes and custom message + // + ErrorInterceptor errorinterceptor.ErrorInterceptor `yaml:"errorInterceptor"` // Middlewares Defines route middleware from Middleware names Middlewares []string `yaml:"middlewares"` } @@ -177,6 +178,8 @@ type Gateway struct { SSLCertFile string `yaml:"sslCertFile" env:"GOMA_SSL_CERT_FILE, overwrite"` // SSLKeyFile SSL Private key file SSLKeyFile string `yaml:"sslKeyFile" env:"GOMA_SSL_KEY_FILE, overwrite"` + // Redis contains redis database details + Redis Redis `yaml:"redis"` // WriteTimeout defines proxy write timeout WriteTimeout int `yaml:"writeTimeout" env:"GOMA_WRITE_TIMEOUT, overwrite"` // ReadTimeout defines proxy read timeout @@ -203,6 +206,7 @@ type Gateway struct { InterceptErrors []int `yaml:"interceptErrors"` // Cors holds proxy global cors Cors Cors `yaml:"cors"` + // Routes holds proxy routes Routes []Route `yaml:"routes"` } @@ -285,3 +289,8 @@ type Health struct { Interval string HealthyStatuses []int } +type Redis struct { + // Addr redis hostname and post number : + Addr string `yaml:"addr"` + Password string `yaml:"password"` +} diff --git a/internal/var.go b/internal/var.go index baf9516..b57bb0d 100644 --- a/internal/var.go +++ b/internal/var.go @@ -9,6 +9,10 @@ const AccessMiddleware = "access" // access middleware const BasicAuth = "basic" // basic authentication middleware const JWTAuth = "jwt" // JWT authentication middleware const OAuth = "oauth" // OAuth authentication middleware +const applicationJson = "application/json" +const textPlain = "text/plain" +const applicationXml = "application/xml" + // Round-robin counter var counter uint32 diff --git a/pkg/errorinterceptor/types.go b/pkg/errorinterceptor/types.go new file mode 100644 index 0000000..23b0ffc --- /dev/null +++ b/pkg/errorinterceptor/types.go @@ -0,0 +1,31 @@ +/* + * Copyright 2024 Jonas Kaninda + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package errorinterceptor + +type ErrorInterceptor struct { + // ContentType error response content type, application/json, plain/text + //ContentType string `yaml:"contentType"` + //Errors contains error status code and custom message + Errors []Error `yaml:"errors"` +} +type Error struct { + // Code HTTP status code + Code int `yaml:"code"` + // Message Error custom response message + Message string `yaml:"message"` +} diff --git a/pkg/errorinterceptor/var.go b/pkg/errorinterceptor/var.go new file mode 100644 index 0000000..d924846 --- /dev/null +++ b/pkg/errorinterceptor/var.go @@ -0,0 +1,22 @@ +/* + * Copyright 2024 Jonas Kaninda + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package errorinterceptor + +const TextPlain = "text/plain" +const ApplicationXml = "application/xml" +const ApplicationJson = "application/json"