refactor: refactoring of auth middlewares

This commit is contained in:
Jonas Kaninda
2024-11-25 07:38:49 +01:00
parent f4e5bb3be2
commit dbd0974388
5 changed files with 25 additions and 18 deletions

View File

@@ -53,9 +53,9 @@ func isPathBlocked(requestPath, blockedPath string) bool {
} }
return false return false
} }
func isProtectedPath(urlPath string, paths []string) bool { func isProtectedPath(urlPath, prefix string, paths []string) bool {
for _, path := range paths { for _, path := range paths {
return isPathBlocked(urlPath, util.ParseURLPath(path)) return isPathBlocked(urlPath, util.ParseURLPath(prefix+path))
} }
return false return false
} }

View File

@@ -29,7 +29,7 @@ import (
// authorization based on the result of backend's response and continue the request when the client is authorized // authorization based on the result of backend's response and continue the request when the client is authorized
func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler { func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if isProtectedPath(r.URL.Path, jwtAuth.Paths) { if isProtectedPath(r.URL.Path, jwtAuth.Path, jwtAuth.Paths) {
for _, header := range jwtAuth.RequiredHeaders { for _, header := range jwtAuth.RequiredHeaders {
if r.Header.Get(header) == "" { if r.Header.Get(header) == "" {
logger.Error("Proxy error, missing %s header", header) logger.Error("Proxy error, missing %s header", header)
@@ -98,16 +98,16 @@ func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler {
} }
r.URL.RawQuery = query.Encode() r.URL.RawQuery = query.Encode()
} }
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}) })
} }
// AuthMiddleware checks for the Authorization header and verifies the credentials // AuthMiddleware checks for the Authorization header and verifies the credentials
func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler { func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logger.Trace("Basic-Auth request headers: %v", r.Header) logger.Trace("Basic-Auth request headers: %v", r.Header)
if isProtectedPath(r.URL.Path, basicAuth.Paths) { if isProtectedPath(r.URL.Path, basicAuth.Path, basicAuth.Paths) {
// Get the Authorization header // Get the Authorization header
authHeader := r.Header.Get("Authorization") authHeader := r.Header.Get("Authorization")
if authHeader == "" { if authHeader == "" {

View File

@@ -26,7 +26,7 @@ import (
func (oauth Oauth) AuthMiddleware(next http.Handler) http.Handler { func (oauth Oauth) AuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if isProtectedPath(r.URL.Path, oauth.Paths) { if isProtectedPath(r.URL.Path, oauth.Path, oauth.Paths) {
oauthConf := oauth2Config(oauth) oauthConf := oauth2Config(oauth)
// Check if the user is authenticated // Check if the user is authenticated
token, err := r.Cookie("goma.oauth") token, err := r.Cookie("goma.oauth")

View File

@@ -85,7 +85,7 @@ type ProxyResponseError struct {
// JwtAuth stores JWT configuration // JwtAuth stores JWT configuration
type JwtAuth struct { type JwtAuth struct {
RoutePath string Path string
Paths []string Paths []string
AuthURL string AuthURL string
RequiredHeaders []string RequiredHeaders []string
@@ -109,6 +109,8 @@ type AccessListMiddleware struct {
// AuthBasic contains Basic auth configuration // AuthBasic contains Basic auth configuration
type AuthBasic struct { type AuthBasic struct {
// Route path
Path string
Paths []string Paths []string
Username string Username string
Password string Password string
@@ -129,6 +131,8 @@ type responseRecorder struct {
body *bytes.Buffer body *bytes.Buffer
} }
type Oauth struct { type Oauth struct {
// Route path
Path string
// Route protected path // Route protected path
Paths []string Paths []string
// ClientID is the application's ID. // ClientID is the application's ID.

View File

@@ -201,7 +201,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
// Error: middlewares not found // Error: middlewares not found
logger.Error("Error: %v", err.Error()) logger.Error("Error: %v", err.Error())
} else { } else {
attachAuthMiddlewares(route, routeMiddleware, gateway, r) attachAuthMiddlewares(route, routeMiddleware, gateway, router)
} }
} else { } else {
logger.Error("Error, middlewares path is empty") logger.Error("Error, middlewares path is empty")
@@ -211,13 +211,6 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
// Apply route Cors // Apply route Cors
router.Use(CORSHandler(route.Cors)) router.Use(CORSHandler(route.Cors))
if len(route.Hosts) > 0 {
for _, host := range route.Hosts {
router.Host(host).PathPrefix("").Handler(proxyRoute.ProxyHandler())
}
} else {
router.PathPrefix("").Handler(proxyRoute.ProxyHandler())
}
if gateway.EnableMetrics { if gateway.EnableMetrics {
pr := metrics.PrometheusRoute{ pr := metrics.PrometheusRoute{
Name: route.Name, Name: route.Name,
@@ -234,6 +227,13 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
} }
router.Use(interceptErrors.ErrorInterceptor) router.Use(interceptErrors.ErrorInterceptor)
} }
if len(route.Hosts) != 0 {
for _, host := range route.Hosts {
router.Host(host).PathPrefix("").Handler(proxyRoute.ProxyHandler())
}
} else {
router.PathPrefix("").Handler(proxyRoute.ProxyHandler())
}
} else { } else {
logger.Error("Error, path is empty in route %s", route.Name) logger.Error("Error, path is empty in route %s", route.Name)
@@ -266,7 +266,8 @@ func attachAuthMiddlewares(route Route, routeMiddleware Middleware, gateway Gate
logger.Error("Error: %s", err.Error()) logger.Error("Error: %s", err.Error())
} else { } else {
authBasic := middlewares.AuthBasic{ authBasic := middlewares.AuthBasic{
Paths: util.AddPrefixPath(route.Path, routeMiddleware.Paths), Path: route.Path,
Paths: routeMiddleware.Paths,
Username: basicAuth.Username, Username: basicAuth.Username,
Password: basicAuth.Password, Password: basicAuth.Password,
Headers: nil, Headers: nil,
@@ -282,7 +283,8 @@ func attachAuthMiddlewares(route Route, routeMiddleware Middleware, gateway Gate
logger.Error("Error: %s", err.Error()) logger.Error("Error: %s", err.Error())
} else { } else {
jwtAuth := middlewares.JwtAuth{ jwtAuth := middlewares.JwtAuth{
Paths: util.AddPrefixPath(route.Path, routeMiddleware.Paths), Path: route.Path,
Paths: routeMiddleware.Paths,
AuthURL: jwt.URL, AuthURL: jwt.URL,
RequiredHeaders: jwt.RequiredHeaders, RequiredHeaders: jwt.RequiredHeaders,
Headers: jwt.Headers, Headers: jwt.Headers,
@@ -304,7 +306,8 @@ func attachAuthMiddlewares(route Route, routeMiddleware Middleware, gateway Gate
redirectURL = oauth.RedirectURL redirectURL = oauth.RedirectURL
} }
amw := middlewares.Oauth{ amw := middlewares.Oauth{
Paths: util.AddPrefixPath(route.Path, routeMiddleware.Paths), Path: route.Path,
Paths: routeMiddleware.Paths,
ClientID: oauth.ClientID, ClientID: oauth.ClientID,
ClientSecret: oauth.ClientSecret, ClientSecret: oauth.ClientSecret,
RedirectURL: redirectURL, RedirectURL: redirectURL,