feat: add access middleware support ip range

This commit is contained in:
2024-12-09 17:39:51 +01:00
parent 0fc5ef52ff
commit af2b0cbce1
2 changed files with 100 additions and 23 deletions

View File

@@ -21,6 +21,7 @@ import (
"github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/pkg/logger"
"net" "net"
"net/http" "net/http"
"strings"
) )
type AccessPolicy struct { type AccessPolicy struct {
@@ -30,33 +31,109 @@ type AccessPolicy struct {
func (access AccessPolicy) AccessPolicyMiddleware(next http.Handler) http.Handler { func (access AccessPolicy) AccessPolicyMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
iPs := make(map[string]struct{})
for _, ip := range access.SourceRanges {
iPs[ip] = struct{}{}
}
// Get the client's IP address // Get the client's IP address
ip, _, err := net.SplitHostPort(getRealIP(r)) clientIP, _, err := net.SplitHostPort(getRealIP(r))
if err != nil { if err != nil {
logger.Error("Unable to parse IP address") logger.Error("Unable to parse IP address")
RespondWithError(w, http.StatusUnauthorized, "Unable to parse IP address") RespondWithError(w, http.StatusUnauthorized, "Unable to parse IP address")
return return
} }
// Check if the IP is in the blocklist for index, entry := range access.SourceRanges {
if access.Action == "DENY" { // Check if the IP is in the blocklist
if _, ok := iPs[ip]; ok { if access.Action == "DENY" {
logger.Error(" %s: IP address in the blocklist, access not allowed", getRealIP(r)) if strings.Contains(entry, "-") {
RespondWithError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden)) // Handle IP range
return startIP, endIP, err := parseIPRange(entry)
if err == nil && ipInRange(clientIP, startIP, endIP) {
logger.Error(" %s: IP address in the blocklist, access not allowed", getRealIP(r))
RespondWithError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden))
return
}
continue
} else {
// Handle single IP
if clientIP == entry {
logger.Error(" %s: IP address in the blocklist, access not allowed", getRealIP(r))
RespondWithError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden))
return
}
if index == len(access.SourceRanges)-1 {
next.ServeHTTP(w, r)
return
}
continue
}
} else {
// Check if the IP is in the allowlist
if strings.Contains(entry, "-") {
// Handle IP range
startIP, endIP, err := parseIPRange(entry)
if err == nil && ipInRange(clientIP, startIP, endIP) {
next.ServeHTTP(w, r)
return
}
continue
} else {
// Handle single IP
if clientIP == entry {
next.ServeHTTP(w, r)
return
}
if index == len(access.SourceRanges)-1 {
next.ServeHTTP(w, r)
return
}
continue
}
} }
} }
// Check if the IP is in the allowlist logger.Error("%s: IP address not allowed ", getRealIP(r))
if _, ok := iPs[ip]; !ok { RespondWithError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden))
logger.Error("%s: IP address not allowed ", getRealIP(r)) return
RespondWithError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden))
return
}
// Continue to the next handler if the authentication is successful
next.ServeHTTP(w, r)
}) })
} }
// / Parse a range string into start and end IPs
func parseIPRange(rangeStr string) (string, string, error) {
parts := strings.Split(rangeStr, "-")
if len(parts) != 2 {
return "", "", http.ErrAbortHandler
}
startIP := strings.TrimSpace(parts[0])
endIP := strings.TrimSpace(parts[1])
if net.ParseIP(startIP) == nil || net.ParseIP(endIP) == nil {
return "", "", http.ErrAbortHandler
}
return startIP, endIP, nil
}
// Check if an IP is in range
func ipInRange(ipStr, startIP, endIP string) bool {
ip := net.ParseIP(ipStr)
start := net.ParseIP(startIP)
end := net.ParseIP(endIP)
if ip == nil || start == nil || end == nil {
return false
}
ipBytes := ip.To4()
startBytes := start.To4()
endBytes := end.To4()
if ipBytes == nil || startBytes == nil || endBytes == nil {
return false
}
for i := 0; i < 4; i++ {
if ipBytes[i] < startBytes[i] || ipBytes[i] > endBytes[i] {
return false
}
}
return true
}

View File

@@ -250,14 +250,14 @@ func attachMiddlewares(rIndex int, route Route, gateway Gateway, router *mux.Rou
} }
// AccessPolicy // AccessPolicy
if accessPolicy == mid.Type { if accessPolicy == mid.Type {
accessPolicy, err := getAccessPoliciesMiddleware(mid.Rule) a, err := getAccessPoliciesMiddleware(mid.Rule)
if err != nil { if err != nil {
logger.Error("Error: %v, middleware not applied", err.Error()) logger.Error("Error: %v, middleware not applied", err.Error())
} }
if len(accessPolicy.SourceRanges) != 0 { if len(a.SourceRanges) != 0 {
access := middlewares.AccessPolicy{ access := middlewares.AccessPolicy{
SourceRanges: accessPolicy.SourceRanges, SourceRanges: a.SourceRanges,
Action: accessPolicy.Action, Action: a.Action,
} }
router.Use(access.AccessPolicyMiddleware) router.Use(access.AccessPolicyMiddleware)
} }