fix: Cors when returns errors

This commit is contained in:
2024-11-05 20:44:06 +01:00
parent 28931ca306
commit 453508688e
6 changed files with 62 additions and 13 deletions

View File

@@ -49,6 +49,10 @@ func (intercept InterceptErrors) ErrorInterceptor(next http.Handler) http.Handle
logger.Error("Backend error") logger.Error("Backend error")
logger.Error("An error occurred from the backend with the status code: %d", rec.statusCode) logger.Error("An error occurred from the backend with the status code: %d", rec.statusCode)
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
//Update Origin Cors Headers
if allowedOrigin(intercept.Origins, r.Header.Get("Origin")) {
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
}
w.WriteHeader(rec.statusCode) w.WriteHeader(rec.statusCode)
err := json.NewEncoder(w).Encode(ProxyResponseError{ err := json.NewEncoder(w).Encode(ProxyResponseError{
Success: false, Success: false,

View File

@@ -0,0 +1,40 @@
/*
* Copyright 2024 Jonas Kaninda
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package middleware
import "net/http"
func getRealIP(r *http.Request) string {
if ip := r.Header.Get("X-Real-IP"); ip != "" {
return ip
}
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
return ip
}
return r.RemoteAddr
}
func allowedOrigin(origins []string, origin string) bool {
for _, o := range origins {
if o == origin {
return true
}
continue
}
return false
}

View File

@@ -34,6 +34,10 @@ func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler {
if r.Header.Get(header) == "" { if r.Header.Get(header) == "" {
logger.Error("Proxy error, missing %s header", header) logger.Error("Proxy error, missing %s header", header)
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
//Update Origin Cors Headers
if allowedOrigin(jwtAuth.Origins, r.Header.Get("Origin")) {
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
}
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
err := json.NewEncoder(w).Encode(ProxyResponseError{ err := json.NewEncoder(w).Encode(ProxyResponseError{
Message: http.StatusText(http.StatusUnauthorized), Message: http.StatusText(http.StatusUnauthorized),

View File

@@ -69,6 +69,10 @@ func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc {
logger.Error("Too many requests from IP: %s %s %s", clientID, r.URL, r.UserAgent()) logger.Error("Too many requests from IP: %s %s %s", clientID, r.URL, r.UserAgent())
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusTooManyRequests) w.WriteHeader(http.StatusTooManyRequests)
//Update Origin Cors Headers
if allowedOrigin(rl.Origins, r.Header.Get("Origin")) {
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
}
err := json.NewEncoder(w).Encode(ProxyResponseError{ err := json.NewEncoder(w).Encode(ProxyResponseError{
Success: false, Success: false,
Code: http.StatusTooManyRequests, Code: http.StatusTooManyRequests,
@@ -84,12 +88,3 @@ func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc {
}) })
} }
} }
func getRealIP(r *http.Request) string {
if ip := r.Header.Get("X-Real-IP"); ip != "" {
return ip
}
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
return ip
}
return r.RemoteAddr
}

View File

@@ -30,6 +30,7 @@ type RateLimiter struct {
Window time.Duration Window time.Duration
ClientMap map[string]*Client ClientMap map[string]*Client
mu sync.Mutex mu sync.Mutex
Origins []string
} }
// Client stores request count and window expiration for each client. // Client stores request count and window expiration for each client.
@@ -39,11 +40,12 @@ type Client struct {
} }
// NewRateLimiterWindow creates a new RateLimiter. // NewRateLimiterWindow creates a new RateLimiter.
func NewRateLimiterWindow(requests int, window time.Duration) *RateLimiter { func NewRateLimiterWindow(requests int, window time.Duration, origin []string) *RateLimiter {
return &RateLimiter{ return &RateLimiter{
Requests: requests, Requests: requests,
Window: window, Window: window,
ClientMap: make(map[string]*Client), ClientMap: make(map[string]*Client),
Origins: origin,
} }
} }
@@ -69,6 +71,7 @@ type JwtAuth struct {
RequiredHeaders []string RequiredHeaders []string
Headers map[string]string Headers map[string]string
Params map[string]string Params map[string]string
Origins []string
} }
// AuthenticationMiddleware Define struct // AuthenticationMiddleware Define struct
@@ -95,6 +98,7 @@ type AuthBasic struct {
// InterceptErrors contains backend status code errors to intercept // InterceptErrors contains backend status code errors to intercept
type InterceptErrors struct { type InterceptErrors struct {
Errors []int Errors []int
Origins []string
} }
// responseRecorder intercepts the response body and status code // responseRecorder intercepts the response body and status code

View File

@@ -43,7 +43,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
if gateway.RateLimiter != 0 { if gateway.RateLimiter != 0 {
//rateLimiter := middleware.NewRateLimiter(gateway.RateLimiter, time.Minute) //rateLimiter := middleware.NewRateLimiter(gateway.RateLimiter, time.Minute)
limiter := middleware.NewRateLimiterWindow(gateway.RateLimiter, time.Minute) // requests per minute limiter := middleware.NewRateLimiterWindow(gateway.RateLimiter, time.Minute, gateway.Cors.Origins) // requests per minute
// Add rate limit middleware to all routes, if defined // Add rate limit middleware to all routes, if defined
r.Use(limiter.RateLimitMiddleware()) r.Use(limiter.RateLimitMiddleware())
} }
@@ -113,6 +113,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
RequiredHeaders: jwt.RequiredHeaders, RequiredHeaders: jwt.RequiredHeaders,
Headers: jwt.Headers, Headers: jwt.Headers,
Params: jwt.Params, Params: jwt.Params,
Origins: gateway.Cors.Origins,
} }
// Apply JWT authentication middleware // Apply JWT authentication middleware
secureRouter.Use(amw.AuthMiddleware) secureRouter.Use(amw.AuthMiddleware)
@@ -164,6 +165,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
// Apply errorInterceptor middleware // Apply errorInterceptor middleware
interceptErrors := middleware.InterceptErrors{ interceptErrors := middleware.InterceptErrors{
Errors: gateway.InterceptErrors, Errors: gateway.InterceptErrors,
Origins: gateway.Cors.Origins,
} }
r.Use(interceptErrors.ErrorInterceptor) r.Use(interceptErrors.ErrorInterceptor)
return r return r