refactor: improvement of access policy middleware

This commit is contained in:
2024-12-09 18:19:24 +01:00
parent 36fb317367
commit 7e3489e201

View File

@@ -38,65 +38,40 @@ func (access AccessPolicy) AccessPolicyMiddleware(next http.Handler) http.Handle
RespondWithError(w, http.StatusUnauthorized, "Unable to parse IP address") RespondWithError(w, http.StatusUnauthorized, "Unable to parse IP address")
return return
} }
for index, entry := range access.SourceRanges {
// Check if the IP is in the blocklist // Check IP against source ranges
if access.Action == "DENY" { isAllowed := access.Action != "DENY"
if strings.Contains(entry, "-") { for _, entry := range access.SourceRanges {
// Handle IP range if isIPAllowed(clientIP, entry) {
startIP, endIP, err := parseIPRange(entry) if isAllowed {
if err == nil && ipInRange(clientIP, startIP, endIP) {
logger.Error(" %s: IP address in the blocklist, access not allowed", getRealIP(r))
RespondWithError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden))
return
}
if index == len(access.SourceRanges)-1 {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
return
}
continue
} else { } else {
// Handle single IP logger.Error("%s: IP address in the blocklist, access not allowed", clientIP)
if clientIP == entry {
logger.Error(" %s: IP address in the blocklist, access not allowed", getRealIP(r))
RespondWithError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden)) RespondWithError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden))
}
return return
} }
if index == len(access.SourceRanges)-1 {
next.ServeHTTP(w, r)
return
}
continue
} }
} else { // Final response for disallowed IPs
// Check if the IP is in the allowlist if isAllowed {
if strings.Contains(entry, "-") { logger.Error("%s: IP address not allowed", clientIP)
// Handle IP range
startIP, endIP, err := parseIPRange(entry)
if err == nil && ipInRange(clientIP, startIP, endIP) {
next.ServeHTTP(w, r)
return
}
continue
} else {
// Handle single IP
if clientIP == entry {
next.ServeHTTP(w, r)
return
}
if index == len(access.SourceRanges)-1 {
next.ServeHTTP(w, r)
return
}
continue
}
}
}
logger.Error("%s: IP address not allowed ", getRealIP(r))
RespondWithError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden)) RespondWithError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden))
return } else {
next.ServeHTTP(w, r)
}
}) })
}
// isIPAllowed checks if a client IP matches an entry (range or single IP).
func isIPAllowed(clientIP, entry string) bool {
if strings.Contains(entry, "-") {
// Handle IP range
startIP, endIP, err := parseIPRange(entry)
return err == nil && ipInRange(clientIP, startIP, endIP)
}
// Handle single IP
return clientIP == entry
} }
// / Parse a range string into start and end IPs // / Parse a range string into start and end IPs