fix: Cors when returns errors
This commit is contained in:
@@ -49,6 +49,10 @@ func (intercept InterceptErrors) ErrorInterceptor(next http.Handler) http.Handle
|
|||||||
logger.Error("Backend error")
|
logger.Error("Backend error")
|
||||||
logger.Error("An error occurred from the backend with the status code: %d", rec.statusCode)
|
logger.Error("An error occurred from the backend with the status code: %d", rec.statusCode)
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
//Update Origin Cors Headers
|
||||||
|
if allowedOrigin(intercept.Origins, r.Header.Get("Origin")) {
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
|
||||||
|
}
|
||||||
w.WriteHeader(rec.statusCode)
|
w.WriteHeader(rec.statusCode)
|
||||||
err := json.NewEncoder(w).Encode(ProxyResponseError{
|
err := json.NewEncoder(w).Encode(ProxyResponseError{
|
||||||
Success: false,
|
Success: false,
|
||||||
|
|||||||
40
internal/middleware/helpers.go
Normal file
40
internal/middleware/helpers.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2024 Jonas Kaninda
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
func getRealIP(r *http.Request) string {
|
||||||
|
if ip := r.Header.Get("X-Real-IP"); ip != "" {
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
return r.RemoteAddr
|
||||||
|
}
|
||||||
|
func allowedOrigin(origins []string, origin string) bool {
|
||||||
|
for _, o := range origins {
|
||||||
|
if o == origin {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
|
||||||
|
}
|
||||||
@@ -34,6 +34,10 @@ func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler {
|
|||||||
if r.Header.Get(header) == "" {
|
if r.Header.Get(header) == "" {
|
||||||
logger.Error("Proxy error, missing %s header", header)
|
logger.Error("Proxy error, missing %s header", header)
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
//Update Origin Cors Headers
|
||||||
|
if allowedOrigin(jwtAuth.Origins, r.Header.Get("Origin")) {
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
|
||||||
|
}
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
err := json.NewEncoder(w).Encode(ProxyResponseError{
|
err := json.NewEncoder(w).Encode(ProxyResponseError{
|
||||||
Message: http.StatusText(http.StatusUnauthorized),
|
Message: http.StatusText(http.StatusUnauthorized),
|
||||||
|
|||||||
@@ -69,6 +69,10 @@ func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc {
|
|||||||
logger.Error("Too many requests from IP: %s %s %s", clientID, r.URL, r.UserAgent())
|
logger.Error("Too many requests from IP: %s %s %s", clientID, r.URL, r.UserAgent())
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusTooManyRequests)
|
w.WriteHeader(http.StatusTooManyRequests)
|
||||||
|
//Update Origin Cors Headers
|
||||||
|
if allowedOrigin(rl.Origins, r.Header.Get("Origin")) {
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
|
||||||
|
}
|
||||||
err := json.NewEncoder(w).Encode(ProxyResponseError{
|
err := json.NewEncoder(w).Encode(ProxyResponseError{
|
||||||
Success: false,
|
Success: false,
|
||||||
Code: http.StatusTooManyRequests,
|
Code: http.StatusTooManyRequests,
|
||||||
@@ -84,12 +88,3 @@ func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
func getRealIP(r *http.Request) string {
|
|
||||||
if ip := r.Header.Get("X-Real-IP"); ip != "" {
|
|
||||||
return ip
|
|
||||||
}
|
|
||||||
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
|
|
||||||
return ip
|
|
||||||
}
|
|
||||||
return r.RemoteAddr
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ type RateLimiter struct {
|
|||||||
Window time.Duration
|
Window time.Duration
|
||||||
ClientMap map[string]*Client
|
ClientMap map[string]*Client
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
Origins []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client stores request count and window expiration for each client.
|
// Client stores request count and window expiration for each client.
|
||||||
@@ -39,11 +40,12 @@ type Client struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewRateLimiterWindow creates a new RateLimiter.
|
// NewRateLimiterWindow creates a new RateLimiter.
|
||||||
func NewRateLimiterWindow(requests int, window time.Duration) *RateLimiter {
|
func NewRateLimiterWindow(requests int, window time.Duration, origin []string) *RateLimiter {
|
||||||
return &RateLimiter{
|
return &RateLimiter{
|
||||||
Requests: requests,
|
Requests: requests,
|
||||||
Window: window,
|
Window: window,
|
||||||
ClientMap: make(map[string]*Client),
|
ClientMap: make(map[string]*Client),
|
||||||
|
Origins: origin,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -69,6 +71,7 @@ type JwtAuth struct {
|
|||||||
RequiredHeaders []string
|
RequiredHeaders []string
|
||||||
Headers map[string]string
|
Headers map[string]string
|
||||||
Params map[string]string
|
Params map[string]string
|
||||||
|
Origins []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthenticationMiddleware Define struct
|
// AuthenticationMiddleware Define struct
|
||||||
@@ -95,6 +98,7 @@ type AuthBasic struct {
|
|||||||
// InterceptErrors contains backend status code errors to intercept
|
// InterceptErrors contains backend status code errors to intercept
|
||||||
type InterceptErrors struct {
|
type InterceptErrors struct {
|
||||||
Errors []int
|
Errors []int
|
||||||
|
Origins []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// responseRecorder intercepts the response body and status code
|
// responseRecorder intercepts the response body and status code
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
|
|||||||
|
|
||||||
if gateway.RateLimiter != 0 {
|
if gateway.RateLimiter != 0 {
|
||||||
//rateLimiter := middleware.NewRateLimiter(gateway.RateLimiter, time.Minute)
|
//rateLimiter := middleware.NewRateLimiter(gateway.RateLimiter, time.Minute)
|
||||||
limiter := middleware.NewRateLimiterWindow(gateway.RateLimiter, time.Minute) // requests per minute
|
limiter := middleware.NewRateLimiterWindow(gateway.RateLimiter, time.Minute, gateway.Cors.Origins) // requests per minute
|
||||||
// Add rate limit middleware to all routes, if defined
|
// Add rate limit middleware to all routes, if defined
|
||||||
r.Use(limiter.RateLimitMiddleware())
|
r.Use(limiter.RateLimitMiddleware())
|
||||||
}
|
}
|
||||||
@@ -113,6 +113,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
|
|||||||
RequiredHeaders: jwt.RequiredHeaders,
|
RequiredHeaders: jwt.RequiredHeaders,
|
||||||
Headers: jwt.Headers,
|
Headers: jwt.Headers,
|
||||||
Params: jwt.Params,
|
Params: jwt.Params,
|
||||||
|
Origins: gateway.Cors.Origins,
|
||||||
}
|
}
|
||||||
// Apply JWT authentication middleware
|
// Apply JWT authentication middleware
|
||||||
secureRouter.Use(amw.AuthMiddleware)
|
secureRouter.Use(amw.AuthMiddleware)
|
||||||
@@ -164,6 +165,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
|
|||||||
// Apply errorInterceptor middleware
|
// Apply errorInterceptor middleware
|
||||||
interceptErrors := middleware.InterceptErrors{
|
interceptErrors := middleware.InterceptErrors{
|
||||||
Errors: gateway.InterceptErrors,
|
Errors: gateway.InterceptErrors,
|
||||||
|
Origins: gateway.Cors.Origins,
|
||||||
}
|
}
|
||||||
r.Use(interceptErrors.ErrorInterceptor)
|
r.Use(interceptErrors.ErrorInterceptor)
|
||||||
return r
|
return r
|
||||||
|
|||||||
Reference in New Issue
Block a user