feat: add access middleware support ip range

This commit is contained in:
2024-12-09 17:39:51 +01:00
parent 0fc5ef52ff
commit af2b0cbce1
2 changed files with 100 additions and 23 deletions

View File

@@ -21,6 +21,7 @@ import (
"github.com/jkaninda/goma-gateway/pkg/logger"
"net"
"net/http"
"strings"
)
type AccessPolicy struct {
@@ -30,33 +31,109 @@ type AccessPolicy struct {
func (access AccessPolicy) AccessPolicyMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
iPs := make(map[string]struct{})
for _, ip := range access.SourceRanges {
iPs[ip] = struct{}{}
}
// Get the client's IP address
ip, _, err := net.SplitHostPort(getRealIP(r))
clientIP, _, err := net.SplitHostPort(getRealIP(r))
if err != nil {
logger.Error("Unable to parse IP address")
RespondWithError(w, http.StatusUnauthorized, "Unable to parse IP address")
return
}
for index, entry := range access.SourceRanges {
// Check if the IP is in the blocklist
if access.Action == "DENY" {
if _, ok := iPs[ip]; ok {
if strings.Contains(entry, "-") {
// Handle IP range
startIP, endIP, err := parseIPRange(entry)
if err == nil && ipInRange(clientIP, startIP, endIP) {
logger.Error(" %s: IP address in the blocklist, access not allowed", getRealIP(r))
RespondWithError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden))
return
}
}
// Check if the IP is in the allowlist
if _, ok := iPs[ip]; !ok {
logger.Error("%s: IP address not allowed ", getRealIP(r))
continue
} else {
// Handle single IP
if clientIP == entry {
logger.Error(" %s: IP address in the blocklist, access not allowed", getRealIP(r))
RespondWithError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden))
return
}
// Continue to the next handler if the authentication is successful
if index == len(access.SourceRanges)-1 {
next.ServeHTTP(w, r)
return
}
continue
}
} else {
// Check if the IP is in the allowlist
if strings.Contains(entry, "-") {
// Handle IP range
startIP, endIP, err := parseIPRange(entry)
if err == nil && ipInRange(clientIP, startIP, endIP) {
next.ServeHTTP(w, r)
return
}
continue
} else {
// Handle single IP
if clientIP == entry {
next.ServeHTTP(w, r)
return
}
if index == len(access.SourceRanges)-1 {
next.ServeHTTP(w, r)
return
}
continue
}
}
}
logger.Error("%s: IP address not allowed ", getRealIP(r))
RespondWithError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden))
return
})
}
// / Parse a range string into start and end IPs
func parseIPRange(rangeStr string) (string, string, error) {
parts := strings.Split(rangeStr, "-")
if len(parts) != 2 {
return "", "", http.ErrAbortHandler
}
startIP := strings.TrimSpace(parts[0])
endIP := strings.TrimSpace(parts[1])
if net.ParseIP(startIP) == nil || net.ParseIP(endIP) == nil {
return "", "", http.ErrAbortHandler
}
return startIP, endIP, nil
}
// Check if an IP is in range
func ipInRange(ipStr, startIP, endIP string) bool {
ip := net.ParseIP(ipStr)
start := net.ParseIP(startIP)
end := net.ParseIP(endIP)
if ip == nil || start == nil || end == nil {
return false
}
ipBytes := ip.To4()
startBytes := start.To4()
endBytes := end.To4()
if ipBytes == nil || startBytes == nil || endBytes == nil {
return false
}
for i := 0; i < 4; i++ {
if ipBytes[i] < startBytes[i] || ipBytes[i] > endBytes[i] {
return false
}
}
return true
}

View File

@@ -250,14 +250,14 @@ func attachMiddlewares(rIndex int, route Route, gateway Gateway, router *mux.Rou
}
// AccessPolicy
if accessPolicy == mid.Type {
accessPolicy, err := getAccessPoliciesMiddleware(mid.Rule)
a, err := getAccessPoliciesMiddleware(mid.Rule)
if err != nil {
logger.Error("Error: %v, middleware not applied", err.Error())
}
if len(accessPolicy.SourceRanges) != 0 {
if len(a.SourceRanges) != 0 {
access := middlewares.AccessPolicy{
SourceRanges: accessPolicy.SourceRanges,
Action: accessPolicy.Action,
SourceRanges: a.SourceRanges,
Action: a.Action,
}
router.Use(access.AccessPolicyMiddleware)
}