diff --git a/go.mod b/go.mod index 8387f1d..b188526 100644 --- a/go.mod +++ b/go.mod @@ -20,8 +20,10 @@ require ( github.com/go-redis/redis v6.15.9+incompatible // indirect github.com/go-redis/redis_rate v6.5.0+incompatible // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jinzhu/copier v0.4.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/spf13/pflag v1.0.5 // indirect + golang.org/x/oauth2 v0.23.0 // indirect golang.org/x/text v0.14.0 // indirect golang.org/x/time v0.7.0 // indirect diff --git a/go.sum b/go.sum index 957e728..15c56be 100644 --- a/go.sum +++ b/go.sum @@ -17,8 +17,12 @@ github.com/jedib0t/go-pretty v4.3.0+incompatible h1:CGs8AVhEKg/n9YbUenWmNStRW2PH github.com/jedib0t/go-pretty v4.3.0+incompatible/go.mod h1:XemHduiw8R651AF9Pt4FwCTKeG3oo7hrHJAoznj9nag= github.com/jedib0t/go-pretty/v6 v6.6.1 h1:iJ65Xjb680rHcikRj6DSIbzCex2huitmc7bDtxYVWyc= github.com/jedib0t/go-pretty/v6 v6.6.1/go.mod h1:zbn98qrYlh95FIhwwsbIip0LYpwSG8SUOScs+v9/t0E= +github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8= +github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg= github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= @@ -31,6 +35,8 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs= +golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= diff --git a/internal/config.go b/internal/config.go index ffdbed9..9ba86e5 100644 --- a/internal/config.go +++ b/internal/config.go @@ -17,9 +17,11 @@ limitations under the License. */ import ( "fmt" + "github.com/jkaninda/goma-gateway/internal/middleware" "github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/util" "github.com/spf13/cobra" + "golang.org/x/oauth2" "gopkg.in/yaml.v3" "os" ) @@ -175,6 +177,25 @@ func initConfig(configFile string) { "/actuator/*", }, }, + { + Name: "oauth", + Type: OAuth, + Paths: []string{ + "/protected", + "/example-of-oauth", + }, + Rule: OauthRulerMiddleware{ + ClientID: "", + ClientSecret: "", + RedirectURL: "", + Scopes: []string{"user"}, + Endpoint: OauthEndpoint{ + AuthURL: "https://accounts.google.com/o/oauth2/auth", + TokenURL: "https://oauth2.googleapis.com/token", + }, + State: "randomStateString", + }, + }, }, } yamlData, err := yaml.Marshal(&conf) @@ -250,3 +271,51 @@ func getBasicAuthMiddleware(input interface{}) (BasicRuleMiddleware, error) { } return *basicAuth, nil } + +// oAuthMiddleware returns OauthRulerMiddleware, error +func oAuthMiddleware(input interface{}) (OauthRulerMiddleware, error) { + oauthRuler := new(OauthRulerMiddleware) + var bytes []byte + bytes, err := yaml.Marshal(input) + if err != nil { + return OauthRulerMiddleware{}, fmt.Errorf("error parsing yaml: %v", err) + } + err = yaml.Unmarshal(bytes, oauthRuler) + if err != nil { + return OauthRulerMiddleware{}, fmt.Errorf("error parsing yaml: %v", err) + } + if oauthRuler.ClientID == "" || oauthRuler.ClientSecret == "" || oauthRuler.RedirectURL == "" { + return OauthRulerMiddleware{}, fmt.Errorf("error parsing yaml: empty clientId/secretId in %s middleware", oauthRuler) + + } + return *oauthRuler, nil +} + +func oauth2Config(oauth OauthRulerMiddleware) *oauth2.Config { + return &oauth2.Config{ + ClientID: oauth.ClientID, + ClientSecret: oauth.ClientSecret, + RedirectURL: oauth.RedirectURL, + Scopes: oauth.Scopes, + Endpoint: oauth2.Endpoint{ + AuthURL: oauth.Endpoint.AuthURL, + TokenURL: oauth.Endpoint.TokenURL, + DeviceAuthURL: oauth.Endpoint.DeviceAuthURL, + }, + } +} + +func oauthRulerMiddleware(oauth middleware.Oauth) *OauthRulerMiddleware { + return &OauthRulerMiddleware{ + ClientID: oauth.ClientID, + ClientSecret: oauth.ClientSecret, + RedirectURL: oauth.RedirectURL, + State: oauth.State, + Scopes: oauth.Scopes, + Endpoint: OauthEndpoint{ + AuthURL: oauth.Endpoint.AuthURL, + TokenURL: oauth.Endpoint.TokenURL, + DeviceAuthURL: oauth.Endpoint.DeviceAuthURL, + }, + } +} diff --git a/internal/handler.go b/internal/handler.go index f18317c..84e690d 100644 --- a/internal/handler.go +++ b/internal/handler.go @@ -16,6 +16,7 @@ See the License for the specific language governing permissions and limitations under the License. */ import ( + "context" "encoding/json" "github.com/gorilla/mux" "github.com/jkaninda/goma-gateway/pkg/logger" @@ -130,3 +131,30 @@ func allowedOrigin(origins []string, origin string) bool { return false } +func (oauth OauthRulerMiddleware) callbackHandler(w http.ResponseWriter, r *http.Request) { + oauthConfig := oauth2Config(oauth) + logger.Info("URL State: %s", r.URL.Query().Get("state")) + // Verify the state to protect against CSRF + if r.URL.Query().Get("state") != oauth.State { + http.Error(w, "Invalid state", http.StatusBadRequest) + return + } + + // Exchange the authorization code for an access token + code := r.URL.Query().Get("code") + token, err := oauthConfig.Exchange(context.Background(), code) + if err != nil { + http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError) + return + } + + // Save token to a cookie for simplicity + http.SetCookie(w, &http.Cookie{ + Name: "oauth-token", + Value: token.AccessToken, + Path: oauth.CookiePath, + }) + + // Redirect to the home page or another protected route + http.Redirect(w, r, oauth.RedirectPath, http.StatusSeeOther) +} diff --git a/internal/middleware/oauth-middleware.go b/internal/middleware/oauth-middleware.go new file mode 100644 index 0000000..16c977c --- /dev/null +++ b/internal/middleware/oauth-middleware.go @@ -0,0 +1,56 @@ +/* + * Copyright 2024 Jonas Kaninda + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package middleware + +import ( + "github.com/jkaninda/goma-gateway/pkg/logger" + "golang.org/x/oauth2" + "net/http" +) + +func oauth2Config(oauth Oauth) *oauth2.Config { + return &oauth2.Config{ + ClientID: oauth.ClientID, + ClientSecret: oauth.ClientSecret, + RedirectURL: oauth.RedirectURL, + Scopes: oauth.Scopes, + Endpoint: oauth2.Endpoint{ + AuthURL: oauth.Endpoint.AuthURL, + TokenURL: oauth.Endpoint.TokenURL, + DeviceAuthURL: oauth.Endpoint.DeviceAuthURL, + }, + } +} +func (oauth Oauth) AuthMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + logger.Info("%s: %s Oauth", getRealIP(r), r.URL.Path) + oauthConfig := oauth2Config(oauth) + // Check if the user is authenticated + _, err := r.Cookie("oauth-token") + if err != nil { + // If no token, redirect to OAuth provider + url := oauthConfig.AuthCodeURL(oauth.State) + http.Redirect(w, r, url, http.StatusTemporaryRedirect) + return + } + //TODO: Check if the token stored in the cookie is valid + + // Token exists, proceed with request + next.ServeHTTP(w, r) + }) +} diff --git a/internal/middleware/types.go b/internal/middleware/types.go index 86a5b29..94b1d71 100644 --- a/internal/middleware/types.go +++ b/internal/middleware/types.go @@ -107,3 +107,24 @@ type responseRecorder struct { statusCode int body *bytes.Buffer } +type Oauth struct { + // ClientID is the application's ID. + ClientID string + // ClientSecret is the application's secret. + ClientSecret string + // Endpoint contains the resource server's token endpoint + Endpoint OauthEndpoint + // RedirectURL is the URL to redirect users going through + // the OAuth flow, after the resource owner's URLs. + RedirectURL string + // Scope specifies optional requested permissions. + Scopes []string + // contains filtered or unexported fields + State string + Origins []string +} +type OauthEndpoint struct { + AuthURL string + TokenURL string + DeviceAuthURL string +} diff --git a/internal/route.go b/internal/route.go index 5f4a60f..7c70188 100644 --- a/internal/route.go +++ b/internal/route.go @@ -88,6 +88,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { cors: route.Cors, } secureRouter := r.PathPrefix(util.ParseRoutePath(route.Path, midPath)).Subrouter() + //callBackRouter := r.PathPrefix(util.ParseRoutePath(route.Path, "/callback")).Subrouter() //Check Authentication middleware switch rMiddleware.Type { case BasicAuth: @@ -126,9 +127,40 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler } - case "OAuth": - logger.Error("OAuth is not yet implemented") - logger.Info("Auth middleware ignored") + case OAuth, "openid": + oauth, err := oAuthMiddleware(rMiddleware.Rule) + if err != nil { + logger.Error("Error: %s", err.Error()) + } else { + amw := middleware.Oauth{ + ClientID: oauth.ClientID, + ClientSecret: oauth.ClientSecret, + RedirectURL: oauth.RedirectURL + route.Path, + Scopes: oauth.Scopes, + Endpoint: middleware.OauthEndpoint{ + AuthURL: oauth.Endpoint.AuthURL, + TokenURL: oauth.Endpoint.TokenURL, + DeviceAuthURL: oauth.Endpoint.DeviceAuthURL, + }, + State: oauth.State, + Origins: gateway.Cors.Origins, + } + oauthRuler := oauthRulerMiddleware(amw) + // Check if a cookie path is defined + if oauthRuler.CookiePath == "" { + oauthRuler.CookiePath = route.Path + } + // Check if a RedirectPath is defined + if oauthRuler.RedirectPath == "" { + oauthRuler.RedirectPath = util.ParseRoutePath(route.Path, midPath) + } + secureRouter.Use(amw.AuthMiddleware) + secureRouter.Use(CORSHandler(route.Cors)) + secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler + secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler + // Callback route + r.HandleFunc("/callback"+route.Path, oauthRuler.callbackHandler).Methods("GET") + } default: if !doesExist(rMiddleware.Type) { logger.Error("Unknown middleware type %s", rMiddleware.Type) diff --git a/internal/types.go b/internal/types.go index 9564441..f36ea97 100644 --- a/internal/types.go +++ b/internal/types.go @@ -72,6 +72,34 @@ type JWTRuleMiddleware struct { //e.g: Header X-Auth-UserId to query userId Params map[string]string `yaml:"params"` } +type OauthRulerMiddleware struct { + // ClientID is the application's ID. + ClientID string `yaml:"clientId"` + + // ClientSecret is the application's secret. + ClientSecret string `yaml:"clientSecret"` + + // Endpoint contains the resource server's token endpoint + Endpoint OauthEndpoint `yaml:"endpoint"` + + // RedirectURL is the URL to redirect users going through + // the OAuth flow, after the resource owner's URLs. + RedirectURL string `yaml:"redirectUrl"` + // RedirectPath is the PATH to redirect users after authentication, e.g: /my-protected-path/dashboard + RedirectPath string `yaml:"redirectPath"` + //CookiePath e.g: /my-protected-path or / || by default is applied on a route path + CookiePath string `yaml:"cookiePath"` + + // Scope specifies optional requested permissions. + Scopes []string `yaml:"scopes"` + // contains filtered or unexported fields + State string `yaml:"state"` +} +type OauthEndpoint struct { + AuthURL string `yaml:"authUrl"` + TokenURL string `yaml:"tokenUrl"` + DeviceAuthURL string `yaml:"deviceAuthUrl"` +} type RateLimiter struct { // ipBased, tokenBased Type string `yaml:"type"` diff --git a/internal/var.go b/internal/var.go index 8eb1c1c..c696632 100644 --- a/internal/var.go +++ b/internal/var.go @@ -7,4 +7,4 @@ const gatewayName = "Goma Gateway" const AccessMiddleware = "access" // access middleware const BasicAuth = "basic" // basic authentication middleware const JWTAuth = "jwt" // JWT authentication middleware -const OAuth = "OAuth" // OAuth authentication middleware +const OAuth = "oauth" // OAuth authentication middleware