refactor: improvement of access policy middleware

This commit is contained in:
2024-12-09 18:19:24 +01:00
parent 36fb317367
commit 7e3489e201

View File

@@ -38,65 +38,40 @@ func (access AccessPolicy) AccessPolicyMiddleware(next http.Handler) http.Handle
RespondWithError(w, http.StatusUnauthorized, "Unable to parse IP address")
return
}
for index, entry := range access.SourceRanges {
// Check if the IP is in the blocklist
if access.Action == "DENY" {
if strings.Contains(entry, "-") {
// Handle IP range
startIP, endIP, err := parseIPRange(entry)
if err == nil && ipInRange(clientIP, startIP, endIP) {
logger.Error(" %s: IP address in the blocklist, access not allowed", getRealIP(r))
RespondWithError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden))
return
}
if index == len(access.SourceRanges)-1 {
next.ServeHTTP(w, r)
return
}
continue
} else {
// Handle single IP
if clientIP == entry {
logger.Error(" %s: IP address in the blocklist, access not allowed", getRealIP(r))
RespondWithError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden))
return
}
if index == len(access.SourceRanges)-1 {
next.ServeHTTP(w, r)
return
}
continue
}
} else {
// Check if the IP is in the allowlist
if strings.Contains(entry, "-") {
// Handle IP range
startIP, endIP, err := parseIPRange(entry)
if err == nil && ipInRange(clientIP, startIP, endIP) {
next.ServeHTTP(w, r)
return
}
continue
// Check IP against source ranges
isAllowed := access.Action != "DENY"
for _, entry := range access.SourceRanges {
if isIPAllowed(clientIP, entry) {
if isAllowed {
next.ServeHTTP(w, r)
} else {
// Handle single IP
if clientIP == entry {
next.ServeHTTP(w, r)
return
}
if index == len(access.SourceRanges)-1 {
next.ServeHTTP(w, r)
return
}
continue
logger.Error("%s: IP address in the blocklist, access not allowed", clientIP)
RespondWithError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden))
}
return
}
}
logger.Error("%s: IP address not allowed ", getRealIP(r))
RespondWithError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden))
return
})
// Final response for disallowed IPs
if isAllowed {
logger.Error("%s: IP address not allowed", clientIP)
RespondWithError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden))
} else {
next.ServeHTTP(w, r)
}
})
}
// isIPAllowed checks if a client IP matches an entry (range or single IP).
func isIPAllowed(clientIP, entry string) bool {
if strings.Contains(entry, "-") {
// Handle IP range
startIP, endIP, err := parseIPRange(entry)
return err == nil && ipInRange(clientIP, startIP, endIP)
}
// Handle single IP
return clientIP == entry
}
// / Parse a range string into start and end IPs