diff --git a/internal/middleware/error-interceptor.go b/internal/middleware/error-interceptor.go index 70a1dbe..d83f37b 100644 --- a/internal/middleware/error-interceptor.go +++ b/internal/middleware/error-interceptor.go @@ -49,6 +49,10 @@ func (intercept InterceptErrors) ErrorInterceptor(next http.Handler) http.Handle logger.Error("Backend error") logger.Error("An error occurred from the backend with the status code: %d", rec.statusCode) w.Header().Set("Content-Type", "application/json") + //Update Origin Cors Headers + if allowedOrigin(intercept.Origins, r.Header.Get("Origin")) { + w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) + } w.WriteHeader(rec.statusCode) err := json.NewEncoder(w).Encode(ProxyResponseError{ Success: false, diff --git a/internal/middleware/helpers.go b/internal/middleware/helpers.go new file mode 100644 index 0000000..65b95c3 --- /dev/null +++ b/internal/middleware/helpers.go @@ -0,0 +1,40 @@ +/* + * Copyright 2024 Jonas Kaninda + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package middleware + +import "net/http" + +func getRealIP(r *http.Request) string { + if ip := r.Header.Get("X-Real-IP"); ip != "" { + return ip + } + if ip := r.Header.Get("X-Forwarded-For"); ip != "" { + return ip + } + return r.RemoteAddr +} +func allowedOrigin(origins []string, origin string) bool { + for _, o := range origins { + if o == origin { + return true + } + continue + } + return false + +} diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index ff68ebe..a6c44f8 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -34,6 +34,10 @@ func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler { if r.Header.Get(header) == "" { logger.Error("Proxy error, missing %s header", header) w.Header().Set("Content-Type", "application/json") + //Update Origin Cors Headers + if allowedOrigin(jwtAuth.Origins, r.Header.Get("Origin")) { + w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) + } w.WriteHeader(http.StatusUnauthorized) err := json.NewEncoder(w).Encode(ProxyResponseError{ Message: http.StatusText(http.StatusUnauthorized), diff --git a/internal/middleware/rate-limit.go b/internal/middleware/rate-limit.go index c4d7fad..200b21b 100644 --- a/internal/middleware/rate-limit.go +++ b/internal/middleware/rate-limit.go @@ -69,6 +69,10 @@ func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { logger.Error("Too many requests from IP: %s %s %s", clientID, r.URL, r.UserAgent()) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusTooManyRequests) + //Update Origin Cors Headers + if allowedOrigin(rl.Origins, r.Header.Get("Origin")) { + w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) + } err := json.NewEncoder(w).Encode(ProxyResponseError{ Success: false, Code: http.StatusTooManyRequests, @@ -84,12 +88,3 @@ func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { }) } } -func getRealIP(r *http.Request) string { - if ip := r.Header.Get("X-Real-IP"); ip != "" { - return ip - } - if ip := r.Header.Get("X-Forwarded-For"); ip != "" { - return ip - } - return r.RemoteAddr -} diff --git a/internal/middleware/types.go b/internal/middleware/types.go index 7d2695d..86a5b29 100644 --- a/internal/middleware/types.go +++ b/internal/middleware/types.go @@ -30,6 +30,7 @@ type RateLimiter struct { Window time.Duration ClientMap map[string]*Client mu sync.Mutex + Origins []string } // Client stores request count and window expiration for each client. @@ -39,11 +40,12 @@ type Client struct { } // NewRateLimiterWindow creates a new RateLimiter. -func NewRateLimiterWindow(requests int, window time.Duration) *RateLimiter { +func NewRateLimiterWindow(requests int, window time.Duration, origin []string) *RateLimiter { return &RateLimiter{ Requests: requests, Window: window, ClientMap: make(map[string]*Client), + Origins: origin, } } @@ -69,6 +71,7 @@ type JwtAuth struct { RequiredHeaders []string Headers map[string]string Params map[string]string + Origins []string } // AuthenticationMiddleware Define struct @@ -94,7 +97,8 @@ type AuthBasic struct { // InterceptErrors contains backend status code errors to intercept type InterceptErrors struct { - Errors []int + Errors []int + Origins []string } // responseRecorder intercepts the response body and status code diff --git a/internal/route.go b/internal/route.go index e810eeb..1ac5912 100644 --- a/internal/route.go +++ b/internal/route.go @@ -43,7 +43,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { if gateway.RateLimiter != 0 { //rateLimiter := middleware.NewRateLimiter(gateway.RateLimiter, time.Minute) - limiter := middleware.NewRateLimiterWindow(gateway.RateLimiter, time.Minute) // requests per minute + limiter := middleware.NewRateLimiterWindow(gateway.RateLimiter, time.Minute, gateway.Cors.Origins) // requests per minute // Add rate limit middleware to all routes, if defined r.Use(limiter.RateLimitMiddleware()) } @@ -113,6 +113,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { RequiredHeaders: jwt.RequiredHeaders, Headers: jwt.Headers, Params: jwt.Params, + Origins: gateway.Cors.Origins, } // Apply JWT authentication middleware secureRouter.Use(amw.AuthMiddleware) @@ -163,7 +164,8 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { r.Use(CORSHandler(gateway.Cors)) // Apply CORS middleware // Apply errorInterceptor middleware interceptErrors := middleware.InterceptErrors{ - Errors: gateway.InterceptErrors, + Errors: gateway.InterceptErrors, + Origins: gateway.Cors.Origins, } r.Use(interceptErrors.ErrorInterceptor) return r