diff --git a/go.mod b/go.mod index 78ea970..dceeb00 100644 --- a/go.mod +++ b/go.mod @@ -28,6 +28,7 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jinzhu/copier v0.4.0 // indirect github.com/klauspost/compress v1.17.9 // indirect github.com/kr/text v0.2.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect diff --git a/go.sum b/go.sum index 2fc8a68..4a5105a 100644 --- a/go.sum +++ b/go.sum @@ -28,6 +28,8 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jedib0t/go-pretty/v6 v6.6.2 h1:27bLj3nRODzaiA7tPIxy9UVWHoPspFfME9XxgwiiNsM= github.com/jedib0t/go-pretty/v6 v6.6.2/go.mod h1:zbn98qrYlh95FIhwwsbIip0LYpwSG8SUOScs+v9/t0E= +github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8= +github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg= github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= diff --git a/internal/config.go b/internal/config.go index 6a88905..348824b 100644 --- a/internal/config.go +++ b/internal/config.go @@ -282,6 +282,41 @@ func getBasicAuthMiddleware(input interface{}) (BasicRuleMiddleware, error) { } return *basicAuth, nil } +func getAccessPoliciesMiddleware(input interface{}) (AccessPolicyRuleMiddleware, error) { + a := new(AccessPolicyRuleMiddleware) + var bytes []byte + bytes, err := yaml.Marshal(input) + if err != nil { + return AccessPolicyRuleMiddleware{}, fmt.Errorf("error parsing yaml: %v", err) + } + err = yaml.Unmarshal(bytes, a) + if err != nil { + return AccessPolicyRuleMiddleware{}, fmt.Errorf("error parsing yaml: %v", err) + } + if len(a.SourceRanges) == 0 { + return AccessPolicyRuleMiddleware{}, fmt.Errorf("empty sourceRanges") + + } + for _, ip := range a.SourceRanges { + isIP, isCIDR := isIPOrCIDR(ip) + if isIP { + if !validateIPAddress(ip) { + return AccessPolicyRuleMiddleware{}, fmt.Errorf("invalid ip address") + } + } + if isCIDR { + if !validateCIDR(ip) { + return AccessPolicyRuleMiddleware{}, fmt.Errorf("invalid cidr address") + } + if validateCIDR(ip) { + return AccessPolicyRuleMiddleware{}, fmt.Errorf("cidr is not yet supported") + + } + } + + } + return *a, nil +} // oAuthMiddleware returns OauthRulerMiddleware, error func oAuthMiddleware(input interface{}) (OauthRulerMiddleware, error) { diff --git a/internal/helpers.go b/internal/helpers.go index 60e53b4..ed36dd0 100644 --- a/internal/helpers.go +++ b/internal/helpers.go @@ -1,14 +1,22 @@ +/* + * Copyright 2024 Jonas Kaninda + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + package pkg -/* -Copyright 2024 Jonas Kaninda. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may get a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 -*/ import ( "context" "encoding/json" @@ -16,6 +24,7 @@ import ( "github.com/jedib0t/go-pretty/v6/table" "golang.org/x/oauth2" "io" + "net" "net/http" ) @@ -45,6 +54,7 @@ func getRealIP(r *http.Request) string { return r.RemoteAddr } +// getUserInfo returns struct of UserInfo func (oauth *OauthRulerMiddleware) getUserInfo(token *oauth2.Token) (UserInfo, error) { oauthConfig := oauth2Config(oauth) // Call the user info endpoint with the token @@ -68,3 +78,30 @@ func (oauth *OauthRulerMiddleware) getUserInfo(token *oauth2.Token) (UserInfo, e return userInfo, nil } + +// validateIPAddress checks if the input is a valid IP address (IPv4 or IPv6) +func validateIPAddress(ip string) bool { + return net.ParseIP(ip) != nil +} + +// validateCIDR checks if the input is a valid CIDR notation +func validateCIDR(cidr string) bool { + _, _, err := net.ParseCIDR(cidr) + return err == nil +} + +// isIPOrCIDR determines whether the input is an IP address or a CIDR +func isIPOrCIDR(input string) (isIP bool, isCIDR bool) { + // Check if it's a valid IP address + if net.ParseIP(input) != nil { + return true, false + } + + // Check if it's a valid CIDR + if _, _, err := net.ParseCIDR(input); err == nil { + return false, true + } + + // Neither IP nor CIDR + return false, false +} diff --git a/internal/middleware.go b/internal/middleware.go index ddc55a8..a8c5c82 100644 --- a/internal/middleware.go +++ b/internal/middleware.go @@ -21,7 +21,7 @@ func getMiddleware(rules []string, middlewares []Middleware) (Middleware, error) } func doesExist(tyName string) bool { - middlewareList := []string{BasicAuth, JWTAuth, AccessMiddleware} + middlewareList := []string{BasicAuth, JWTAuth, AccessMiddleware, accessPolicy} middlewareList = append(middlewareList, RateLimitMiddleware...) return slices.Contains(middlewareList, tyName) } diff --git a/internal/middlewares/access_policy_middleware.go b/internal/middlewares/access_policy_middleware.go new file mode 100644 index 0000000..aaef055 --- /dev/null +++ b/internal/middlewares/access_policy_middleware.go @@ -0,0 +1,62 @@ +/* + * Copyright 2024 Jonas Kaninda + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package middlewares + +import ( + "github.com/jkaninda/goma-gateway/pkg/logger" + "net" + "net/http" +) + +type AccessPolicy struct { + Action string + SourceRanges []string +} + +func (access AccessPolicy) AccessPolicyMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + iPs := make(map[string]struct{}) + for _, ip := range access.SourceRanges { + iPs[ip] = struct{}{} + } + // Get the client's IP address + ip, _, err := net.SplitHostPort(getRealIP(r)) + if err != nil { + logger.Error("Unable to parse IP address") + RespondWithError(w, http.StatusUnauthorized, "Unable to parse IP address") + return + } + // Check if the IP is in the blocklist + if access.Action == "DENY" { + if _, ok := iPs[ip]; ok { + logger.Error(" %s: IP address in the blocklist, access not allowed", getRealIP(r)) + RespondWithError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden)) + return + } + } + // Check if the IP is in the allowlist + if _, ok := iPs[ip]; !ok { + logger.Error("%s: IP address not allowed ", getRealIP(r)) + RespondWithError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden)) + return + } + // Continue to the next handler if the authentication is successful + next.ServeHTTP(w, r) + }) + +} diff --git a/internal/routes.go b/internal/routes.go index 0562d96..1ceb709 100644 --- a/internal/routes.go +++ b/internal/routes.go @@ -210,22 +210,22 @@ func attachMiddlewares(rIndex int, route Route, gateway Gateway, router *mux.Rou } if len(middleware) != 0 { // Get Access middlewares if it does exist - accessMiddleware, err := getMiddleware([]string{middleware}, dynamicMiddlewares) + mid, err := getMiddleware([]string{middleware}, dynamicMiddlewares) if err != nil { logger.Error("Error: %v", err.Error()) } else { // Apply access middlewares - if accessMiddleware.Type == AccessMiddleware { + if mid.Type == AccessMiddleware { blM := middlewares.AccessListMiddleware{ Path: route.Path, - List: accessMiddleware.Paths, + List: mid.Paths, } router.Use(blM.AccessMiddleware) } // Apply Rate limit middleware - if slices.Contains(RateLimitMiddleware, accessMiddleware.Type) { - rateLimitMid, err := rateLimitMiddleware(accessMiddleware.Rule) + if slices.Contains(RateLimitMiddleware, mid.Type) { + rateLimitMid, err := rateLimitMiddleware(mid.Rule) if err != nil { logger.Error("Error: %v", err.Error()) } @@ -238,17 +238,33 @@ func attachMiddlewares(rIndex int, route Route, gateway Gateway, router *mux.Rou Hosts: route.Hosts, RedisBased: redisBased, PathBased: true, - Paths: util.AddPrefixPath(route.Path, accessMiddleware.Paths), + Paths: util.AddPrefixPath(route.Path, mid.Paths), } limiter := rateLimit.NewRateLimiterWindow() - // Add rate limit middlewares + // Apply rate limiter middlewares router.Use(limiter.RateLimitMiddleware()) } } + // AccessPolicy + if accessPolicy == mid.Type { + accessPolicy, err := getAccessPoliciesMiddleware(mid.Rule) + if err != nil { + logger.Error("Error: %v, middleware not applied", err.Error()) + } + if len(accessPolicy.SourceRanges) != 0 { + logger.Info("Ips: %v", accessPolicy.SourceRanges) + access := middlewares.AccessPolicy{ + SourceRanges: accessPolicy.SourceRanges, + Action: accessPolicy.Action, + } + router.Use(access.AccessPolicyMiddleware) + } + } } + // Get route authentication middlewares if it does exist routeMiddleware, err := getMiddleware([]string{middleware}, dynamicMiddlewares) if err != nil { diff --git a/internal/types.go b/internal/types.go index 2da2283..d4a2c38 100644 --- a/internal/types.go +++ b/internal/types.go @@ -167,7 +167,14 @@ type Redis struct { Password string `yaml:"password"` } +// ExtraRouteConfig contains additional routes and middlewares directory type ExtraRouteConfig struct { Directory string `yaml:"directory"` Watch bool `yaml:"watch"` } + +// AccessPolicyRuleMiddleware access policy +type AccessPolicyRuleMiddleware struct { + Action string `yaml:"action,omitempty"` // action, ALLOW or DENY + SourceRanges []string `yaml:"sourceRanges"` // list of Ips +} diff --git a/internal/var.go b/internal/var.go index 273dd6e..854295f 100644 --- a/internal/var.go +++ b/internal/var.go @@ -9,6 +9,7 @@ const AccessMiddleware = "access" // access middlewares const BasicAuth = "basic" // basic authentication middlewares const JWTAuth = "jwt" // JWT authentication middlewares const OAuth = "oauth" // OAuth authentication middlewares +const accessPolicy = "accessPolicy" var ( // Round-robin counter