fix: fix authentification middlewares

This commit is contained in:
Jonas Kaninda
2024-11-24 22:13:26 +01:00
parent 6258b07c82
commit 3df8dce59b
7 changed files with 228 additions and 216 deletions

View File

@@ -53,6 +53,12 @@ func isPathBlocked(requestPath, blockedPath string) bool {
} }
return false return false
} }
func isProtectedPath(urlPath string, paths []string) bool {
for _, path := range paths {
return isPathBlocked(urlPath, util.ParseURLPath(path))
}
return false
}
// NewRateLimiter creates a new requests limiter with the specified refill requests and token capacity // NewRateLimiter creates a new requests limiter with the specified refill requests and token capacity
func NewRateLimiter(maxTokens int, refillRate time.Duration) *TokenRateLimiter { func NewRateLimiter(maxTokens int, refillRate time.Duration) *TokenRateLimiter {

View File

@@ -29,6 +29,7 @@ import (
// authorization based on the result of backend's response and continue the request when the client is authorized // authorization based on the result of backend's response and continue the request when the client is authorized
func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler { func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if isProtectedPath(r.URL.Path, jwtAuth.Paths) {
for _, header := range jwtAuth.RequiredHeaders { for _, header := range jwtAuth.RequiredHeaders {
if r.Header.Get(header) == "" { if r.Header.Get(header) == "" {
logger.Error("Proxy error, missing %s header", header) logger.Error("Proxy error, missing %s header", header)
@@ -96,6 +97,7 @@ func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler {
} }
} }
r.URL.RawQuery = query.Encode() r.URL.RawQuery = query.Encode()
}
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}) })
@@ -105,6 +107,7 @@ func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler {
func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler { func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logger.Trace("Basic-Auth request headers: %v", r.Header) logger.Trace("Basic-Auth request headers: %v", r.Header)
if isProtectedPath(r.URL.Path, basicAuth.Paths) {
// Get the Authorization header // Get the Authorization header
authHeader := r.Header.Get("Authorization") authHeader := r.Header.Get("Authorization")
if authHeader == "" { if authHeader == "" {
@@ -120,7 +123,6 @@ func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler {
return return
} }
// Decode the base64 encoded username:password string // Decode the base64 encoded username:password string
payload, err := base64.StdEncoding.DecodeString(authHeader[len("Basic "):]) payload, err := base64.StdEncoding.DecodeString(authHeader[len("Basic "):])
if err != nil { if err != nil {
@@ -128,7 +130,6 @@ func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler {
RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) RespondWithError(w, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))
return return
} }
// Split the payload into username and password // Split the payload into username and password
pair := strings.SplitN(string(payload), ":", 2) pair := strings.SplitN(string(payload), ":", 2)
if len(pair) != 2 || pair[0] != basicAuth.Username || pair[1] != basicAuth.Password { if len(pair) != 2 || pair[0] != basicAuth.Username || pair[1] != basicAuth.Password {
@@ -137,6 +138,8 @@ func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler {
return return
} }
}
// Continue to the next handler if the authentication is successful // Continue to the next handler if the authentication is successful
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}) })

View File

@@ -26,6 +26,7 @@ import (
func (oauth Oauth) AuthMiddleware(next http.Handler) http.Handler { func (oauth Oauth) AuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if isProtectedPath(r.URL.Path, oauth.Paths) {
oauthConf := oauth2Config(oauth) oauthConf := oauth2Config(oauth)
// Check if the user is authenticated // Check if the user is authenticated
token, err := r.Cookie("goma.oauth") token, err := r.Cookie("goma.oauth")
@@ -48,6 +49,7 @@ func (oauth Oauth) AuthMiddleware(next http.Handler) http.Handler {
http.Redirect(w, r, url, http.StatusTemporaryRedirect) http.Redirect(w, r, url, http.StatusTemporaryRedirect)
return return
} }
}
// Token exists, proceed with request // Token exists, proceed with request
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}) })

View File

@@ -79,6 +79,8 @@ type ProxyResponseError struct {
// JwtAuth stores JWT configuration // JwtAuth stores JWT configuration
type JwtAuth struct { type JwtAuth struct {
RoutePath string
Paths []string
AuthURL string AuthURL string
RequiredHeaders []string RequiredHeaders []string
Headers map[string]string Headers map[string]string
@@ -101,6 +103,7 @@ type AccessListMiddleware struct {
// AuthBasic contains Basic auth configuration // AuthBasic contains Basic auth configuration
type AuthBasic struct { type AuthBasic struct {
Paths []string
Username string Username string
Password string Password string
Headers map[string]string Headers map[string]string
@@ -120,6 +123,8 @@ type responseRecorder struct {
body *bytes.Buffer body *bytes.Buffer
} }
type Oauth struct { type Oauth struct {
// Route protected path
Paths []string
// ClientID is the application's ID. // ClientID is the application's ID.
ClientID string ClientID string
// ClientSecret is the application's secret. // ClientSecret is the application's secret.

View File

@@ -106,12 +106,25 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
r.Use(limiter.RateLimitMiddleware()) r.Use(limiter.RateLimitMiddleware())
} }
for rIndex, route := range dynamicRoutes { for rIndex, route := range dynamicRoutes {
// create route
router := r.PathPrefix(route.Path).Subrouter()
if len(route.Path) != 0 { if len(route.Path) != 0 {
// Checks if route destination and backend are empty // Checks if route destination and backend are empty
if len(route.Destination) == 0 && len(route.Backends) == 0 { if len(route.Destination) == 0 && len(route.Backends) == 0 {
logger.Fatal("Route %s : destination or backends should not be empty", route.Name) logger.Fatal("Route %s : destination or backends should not be empty", route.Name)
} }
proxyRoute := ProxyRoute{
path: route.Path,
rewrite: route.Rewrite,
destination: route.Destination,
backends: route.Backends,
methods: route.Methods,
disableHostFording: route.DisableHostFording,
cors: route.Cors,
insecureSkipVerify: route.InsecureSkipVerify,
}
// Apply middlewares to the route // Apply middlewares to the route
for _, middleware := range route.Middlewares { for _, middleware := range route.Middlewares {
if len(middleware) != 0 { if len(middleware) != 0 {
@@ -144,18 +157,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
logger.Error("Middleware ignored") logger.Error("Middleware ignored")
} }
} }
proxyRoute := ProxyRoute{
path: route.Path,
rewrite: route.Rewrite,
destination: route.Destination,
backends: route.Backends,
methods: route.Methods,
disableHostFording: route.DisableHostFording,
cors: route.Cors,
insecureSkipVerify: route.InsecureSkipVerify,
}
// create route
router := r.PathPrefix(route.Path).Subrouter()
// Apply common exploits to the route // Apply common exploits to the route
// Enable common exploits // Enable common exploits
if route.BlockCommonExploits { if route.BlockCommonExploits {
@@ -206,6 +208,8 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
} }
router.Use(interceptErrors.ErrorInterceptor) router.Use(interceptErrors.ErrorInterceptor)
} }
//r.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler
//r.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler
} else { } else {
logger.Error("Error, path is empty in route %s", route.Name) logger.Error("Error, path is empty in route %s", route.Name)
logger.Error("Route path ignored: %s", route.Path) logger.Error("Route path ignored: %s", route.Path)
@@ -221,6 +225,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
} }
r.Use(interceptErrors.ErrorInterceptor) r.Use(interceptErrors.ErrorInterceptor)
} }
} }
return r return r
@@ -228,18 +233,6 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
} }
func attachAuthMiddlewares(route Route, routeMiddleware Middleware, gateway Gateway, r *mux.Router) { func attachAuthMiddlewares(route Route, routeMiddleware Middleware, gateway Gateway, r *mux.Router) {
for _, middlewarePath := range routeMiddleware.Paths {
proxyRoute := ProxyRoute{
path: route.Path,
rewrite: route.Rewrite,
destination: route.Destination,
backends: route.Backends,
disableHostFording: route.DisableHostFording,
methods: route.Methods,
cors: route.Cors,
insecureSkipVerify: route.InsecureSkipVerify,
}
secureRouter := r.PathPrefix(util.ParseRoutePath(route.Path, middlewarePath)).Subrouter()
// Check Authentication middleware types // Check Authentication middleware types
switch routeMiddleware.Type { switch routeMiddleware.Type {
case BasicAuth: case BasicAuth:
@@ -248,16 +241,15 @@ func attachAuthMiddlewares(route Route, routeMiddleware Middleware, gateway Gate
logger.Error("Error: %s", err.Error()) logger.Error("Error: %s", err.Error())
} else { } else {
authBasic := middlewares.AuthBasic{ authBasic := middlewares.AuthBasic{
Paths: util.AddPrefixPath(route.Path, routeMiddleware.Paths),
Username: basicAuth.Username, Username: basicAuth.Username,
Password: basicAuth.Password, Password: basicAuth.Password,
Headers: nil, Headers: nil,
Params: nil, Params: nil,
} }
// Apply JWT authentication middlewares // Apply JWT authentication middlewares
secureRouter.Use(authBasic.AuthMiddleware) r.Use(authBasic.AuthMiddleware)
secureRouter.Use(CORSHandler(route.Cors)) r.Use(CORSHandler(route.Cors))
secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler
secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler
} }
case JWTAuth: case JWTAuth:
jwt, err := getJWTMiddleware(routeMiddleware.Rule) jwt, err := getJWTMiddleware(routeMiddleware.Rule)
@@ -265,6 +257,7 @@ func attachAuthMiddlewares(route Route, routeMiddleware Middleware, gateway Gate
logger.Error("Error: %s", err.Error()) logger.Error("Error: %s", err.Error())
} else { } else {
jwtAuth := middlewares.JwtAuth{ jwtAuth := middlewares.JwtAuth{
Paths: util.AddPrefixPath(route.Path, routeMiddleware.Paths),
AuthURL: jwt.URL, AuthURL: jwt.URL,
RequiredHeaders: jwt.RequiredHeaders, RequiredHeaders: jwt.RequiredHeaders,
Headers: jwt.Headers, Headers: jwt.Headers,
@@ -272,10 +265,8 @@ func attachAuthMiddlewares(route Route, routeMiddleware Middleware, gateway Gate
Origins: gateway.Cors.Origins, Origins: gateway.Cors.Origins,
} }
// Apply JWT authentication middlewares // Apply JWT authentication middlewares
secureRouter.Use(jwtAuth.AuthMiddleware) r.Use(jwtAuth.AuthMiddleware)
secureRouter.Use(CORSHandler(route.Cors)) r.Use(CORSHandler(route.Cors))
secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler
secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler
} }
case OAuth: case OAuth:
@@ -288,6 +279,7 @@ func attachAuthMiddlewares(route Route, routeMiddleware Middleware, gateway Gate
redirectURL = oauth.RedirectURL redirectURL = oauth.RedirectURL
} }
amw := middlewares.Oauth{ amw := middlewares.Oauth{
Paths: util.AddPrefixPath(route.Path, routeMiddleware.Paths),
ClientID: oauth.ClientID, ClientID: oauth.ClientID,
ClientSecret: oauth.ClientSecret, ClientSecret: oauth.ClientSecret,
RedirectURL: redirectURL, RedirectURL: redirectURL,
@@ -309,16 +301,13 @@ func attachAuthMiddlewares(route Route, routeMiddleware Middleware, gateway Gate
} }
// Check if a RedirectPath is defined // Check if a RedirectPath is defined
if oauthRuler.RedirectPath == "" { if oauthRuler.RedirectPath == "" {
oauthRuler.RedirectPath = util.ParseRoutePath(route.Path, middlewarePath) oauthRuler.RedirectPath = util.ParseRoutePath(route.Path, routeMiddleware.Paths[0])
} }
if oauthRuler.Provider == "" { if oauthRuler.Provider == "" {
oauthRuler.Provider = "custom" oauthRuler.Provider = "custom"
} }
secureRouter.Use(amw.AuthMiddleware) r.Use(amw.AuthMiddleware)
secureRouter.Use(CORSHandler(route.Cors)) r.Use(CORSHandler(route.Cors))
secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler
secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler
// Callback route
r.HandleFunc(util.UrlParsePath(redirectURL), oauthRuler.callbackHandler).Methods("GET") r.HandleFunc(util.UrlParsePath(redirectURL), oauthRuler.callbackHandler).Methods("GET")
} }
default: default:
@@ -328,5 +317,4 @@ func attachAuthMiddlewares(route Route, routeMiddleware Middleware, gateway Gate
} }
}
} }

View File

@@ -30,7 +30,7 @@ import (
// Start / Start starts the server // Start / Start starts the server
func (gatewayServer GatewayServer) Start() error { func (gatewayServer GatewayServer) Start() error {
logger.Info("Initializing routes...") logger.Info("Initializing routes...")
route := gatewayServer.Initialize() router := gatewayServer.Initialize()
logger.Debug("Routes count=%d, Middlewares count=%d", len(gatewayServer.gateway.Routes), len(gatewayServer.middlewares)) logger.Debug("Routes count=%d, Middlewares count=%d", len(gatewayServer.gateway.Routes), len(gatewayServer.middlewares))
gatewayServer.initRedis() gatewayServer.initRedis()
defer gatewayServer.closeRedis() defer gatewayServer.closeRedis()
@@ -44,8 +44,8 @@ func (gatewayServer GatewayServer) Start() error {
printRoute(dynamicRoutes) printRoute(dynamicRoutes)
} }
httpServer := gatewayServer.createServer(":8080", route, nil) httpServer := gatewayServer.createServer(":8080", router, nil)
httpsServer := gatewayServer.createServer(":8443", route, tlsConfig) httpsServer := gatewayServer.createServer(":8443", router, tlsConfig)
// Start HTTP/HTTPS servers // Start HTTP/HTTPS servers
gatewayServer.startServers(httpServer, httpsServer, listenWithTLS) gatewayServer.startServers(httpServer, httpsServer, listenWithTLS)

View File

@@ -157,3 +157,11 @@ func Slug(text string) string {
return text return text
} }
func AddPrefixPath(prefix string, paths []string) []string {
for i := range paths {
paths[i] = ParseURLPath(prefix + paths[i])
}
return paths
}