From af2b0cbce1747db48f247a0c617dabd781961c9a Mon Sep 17 00:00:00 2001 From: Jonas Kaninda Date: Mon, 9 Dec 2024 17:39:51 +0100 Subject: [PATCH] feat: add access middleware support ip range --- .../middlewares/access_policy_middleware.go | 115 +++++++++++++++--- internal/routes.go | 8 +- 2 files changed, 100 insertions(+), 23 deletions(-) diff --git a/internal/middlewares/access_policy_middleware.go b/internal/middlewares/access_policy_middleware.go index aaef055..48d6df2 100644 --- a/internal/middlewares/access_policy_middleware.go +++ b/internal/middlewares/access_policy_middleware.go @@ -21,6 +21,7 @@ import ( "github.com/jkaninda/goma-gateway/pkg/logger" "net" "net/http" + "strings" ) type AccessPolicy struct { @@ -30,33 +31,109 @@ type AccessPolicy struct { func (access AccessPolicy) AccessPolicyMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - iPs := make(map[string]struct{}) - for _, ip := range access.SourceRanges { - iPs[ip] = struct{}{} - } // Get the client's IP address - ip, _, err := net.SplitHostPort(getRealIP(r)) + clientIP, _, err := net.SplitHostPort(getRealIP(r)) if err != nil { logger.Error("Unable to parse IP address") RespondWithError(w, http.StatusUnauthorized, "Unable to parse IP address") return } - // Check if the IP is in the blocklist - if access.Action == "DENY" { - if _, ok := iPs[ip]; ok { - logger.Error(" %s: IP address in the blocklist, access not allowed", getRealIP(r)) - RespondWithError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden)) - return + for index, entry := range access.SourceRanges { + // Check if the IP is in the blocklist + if access.Action == "DENY" { + if strings.Contains(entry, "-") { + // Handle IP range + startIP, endIP, err := parseIPRange(entry) + if err == nil && ipInRange(clientIP, startIP, endIP) { + logger.Error(" %s: IP address in the blocklist, access not allowed", getRealIP(r)) + RespondWithError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden)) + return + } + continue + } else { + // Handle single IP + if clientIP == entry { + logger.Error(" %s: IP address in the blocklist, access not allowed", getRealIP(r)) + RespondWithError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden)) + return + } + if index == len(access.SourceRanges)-1 { + next.ServeHTTP(w, r) + return + } + continue + } + + } else { + // Check if the IP is in the allowlist + if strings.Contains(entry, "-") { + // Handle IP range + startIP, endIP, err := parseIPRange(entry) + if err == nil && ipInRange(clientIP, startIP, endIP) { + next.ServeHTTP(w, r) + return + } + continue + } else { + // Handle single IP + if clientIP == entry { + next.ServeHTTP(w, r) + return + } + if index == len(access.SourceRanges)-1 { + next.ServeHTTP(w, r) + return + } + continue + } } } - // Check if the IP is in the allowlist - if _, ok := iPs[ip]; !ok { - logger.Error("%s: IP address not allowed ", getRealIP(r)) - RespondWithError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden)) - return - } - // Continue to the next handler if the authentication is successful - next.ServeHTTP(w, r) + logger.Error("%s: IP address not allowed ", getRealIP(r)) + RespondWithError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden)) + return }) } + +// / Parse a range string into start and end IPs +func parseIPRange(rangeStr string) (string, string, error) { + parts := strings.Split(rangeStr, "-") + if len(parts) != 2 { + return "", "", http.ErrAbortHandler + } + + startIP := strings.TrimSpace(parts[0]) + endIP := strings.TrimSpace(parts[1]) + + if net.ParseIP(startIP) == nil || net.ParseIP(endIP) == nil { + return "", "", http.ErrAbortHandler + } + + return startIP, endIP, nil +} + +// Check if an IP is in range +func ipInRange(ipStr, startIP, endIP string) bool { + ip := net.ParseIP(ipStr) + start := net.ParseIP(startIP) + end := net.ParseIP(endIP) + + if ip == nil || start == nil || end == nil { + return false + } + + ipBytes := ip.To4() + startBytes := start.To4() + endBytes := end.To4() + + if ipBytes == nil || startBytes == nil || endBytes == nil { + return false + } + + for i := 0; i < 4; i++ { + if ipBytes[i] < startBytes[i] || ipBytes[i] > endBytes[i] { + return false + } + } + return true +} diff --git a/internal/routes.go b/internal/routes.go index 0fc73f0..f83b31c 100644 --- a/internal/routes.go +++ b/internal/routes.go @@ -250,14 +250,14 @@ func attachMiddlewares(rIndex int, route Route, gateway Gateway, router *mux.Rou } // AccessPolicy if accessPolicy == mid.Type { - accessPolicy, err := getAccessPoliciesMiddleware(mid.Rule) + a, err := getAccessPoliciesMiddleware(mid.Rule) if err != nil { logger.Error("Error: %v, middleware not applied", err.Error()) } - if len(accessPolicy.SourceRanges) != 0 { + if len(a.SourceRanges) != 0 { access := middlewares.AccessPolicy{ - SourceRanges: accessPolicy.SourceRanges, - Action: accessPolicy.Action, + SourceRanges: a.SourceRanges, + Action: a.Action, } router.Use(access.AccessPolicyMiddleware) }