From 7e3489e201bcc457177ddc751c517bb72015b8f4 Mon Sep 17 00:00:00 2001 From: Jonas Kaninda Date: Mon, 9 Dec 2024 18:19:24 +0100 Subject: [PATCH] refactor: improvement of access policy middleware --- .../middlewares/access_policy_middleware.go | 81 +++++++------------ 1 file changed, 28 insertions(+), 53 deletions(-) diff --git a/internal/middlewares/access_policy_middleware.go b/internal/middlewares/access_policy_middleware.go index d979ab0..7858b97 100644 --- a/internal/middlewares/access_policy_middleware.go +++ b/internal/middlewares/access_policy_middleware.go @@ -38,65 +38,40 @@ func (access AccessPolicy) AccessPolicyMiddleware(next http.Handler) http.Handle RespondWithError(w, http.StatusUnauthorized, "Unable to parse IP address") return } - for index, entry := range access.SourceRanges { - // Check if the IP is in the blocklist - if access.Action == "DENY" { - if strings.Contains(entry, "-") { - // Handle IP range - startIP, endIP, err := parseIPRange(entry) - if err == nil && ipInRange(clientIP, startIP, endIP) { - logger.Error(" %s: IP address in the blocklist, access not allowed", getRealIP(r)) - RespondWithError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden)) - return - } - if index == len(access.SourceRanges)-1 { - next.ServeHTTP(w, r) - return - } - continue - } else { - // Handle single IP - if clientIP == entry { - logger.Error(" %s: IP address in the blocklist, access not allowed", getRealIP(r)) - RespondWithError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden)) - return - } - if index == len(access.SourceRanges)-1 { - next.ServeHTTP(w, r) - return - } - continue - } - } else { - // Check if the IP is in the allowlist - if strings.Contains(entry, "-") { - // Handle IP range - startIP, endIP, err := parseIPRange(entry) - if err == nil && ipInRange(clientIP, startIP, endIP) { - next.ServeHTTP(w, r) - return - } - continue + // Check IP against source ranges + isAllowed := access.Action != "DENY" + for _, entry := range access.SourceRanges { + if isIPAllowed(clientIP, entry) { + if isAllowed { + next.ServeHTTP(w, r) } else { - // Handle single IP - if clientIP == entry { - next.ServeHTTP(w, r) - return - } - if index == len(access.SourceRanges)-1 { - next.ServeHTTP(w, r) - return - } - continue + logger.Error("%s: IP address in the blocklist, access not allowed", clientIP) + RespondWithError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden)) } + return } } - logger.Error("%s: IP address not allowed ", getRealIP(r)) - RespondWithError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden)) - return - }) + // Final response for disallowed IPs + if isAllowed { + logger.Error("%s: IP address not allowed", clientIP) + RespondWithError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden)) + } else { + next.ServeHTTP(w, r) + } + }) +} + +// isIPAllowed checks if a client IP matches an entry (range or single IP). +func isIPAllowed(clientIP, entry string) bool { + if strings.Contains(entry, "-") { + // Handle IP range + startIP, endIP, err := parseIPRange(entry) + return err == nil && ipInRange(clientIP, startIP, endIP) + } + // Handle single IP + return clientIP == entry } // / Parse a range string into start and end IPs