feat: add oauth token validity verification

This commit is contained in:
2024-11-08 12:03:52 +01:00
parent d6e7791cb4
commit bd20895306
12 changed files with 326 additions and 119 deletions

View File

@@ -22,6 +22,11 @@ import (
"github.com/jkaninda/goma-gateway/util"
"github.com/spf13/cobra"
"golang.org/x/oauth2"
"golang.org/x/oauth2/amazon"
"golang.org/x/oauth2/facebook"
"golang.org/x/oauth2/github"
"golang.org/x/oauth2/gitlab"
"golang.org/x/oauth2/google"
"gopkg.in/yaml.v3"
"os"
)
@@ -179,7 +184,7 @@ func initConfig(configFile string) {
"/example-of-jwt",
},
Rule: JWTRuleMiddleware{
URL: "https://www.googleapis.com/auth/userinfo.email",
URL: "https://example.com/auth/userinfo",
RequiredHeaders: []string{
"Authorization",
},
@@ -199,20 +204,41 @@ func initConfig(configFile string) {
},
},
{
Name: "oauth",
Name: "oauth-google",
Type: OAuth,
Paths: []string{
"/protected",
"/example-of-oauth",
},
Rule: OauthRulerMiddleware{
ClientID: "",
ClientSecret: "",
RedirectURL: "",
Scopes: []string{"user"},
ClientID: "xxx",
ClientSecret: "xxx",
Provider: "google",
JWTSecret: "your-strong-jwt-secret | It's optional",
RedirectURL: "http://localhost:8080/callback",
Scopes: []string{"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile"},
Endpoint: OauthEndpoint{},
State: "randomStateString",
},
},
{
Name: "oauth-authentik",
Type: OAuth,
Paths: []string{
"/protected",
"/example-of-oauth",
},
Rule: OauthRulerMiddleware{
ClientID: "xxx",
ClientSecret: "xxx",
RedirectURL: "http://localhost:8080/callback",
Scopes: []string{"email", "openid"},
JWTSecret: "your-strong-jwt-secret | It's optional",
Endpoint: OauthEndpoint{
AuthURL: "https://accounts.google.com/o/oauth2/auth",
TokenURL: "https://oauth2.googleapis.com/token",
AuthURL: "https://authentik.example.com/application/o/authorize/",
TokenURL: "https://authentik.example.com/application/o/token/",
UserInfoURL: "https://authentik.example.com/application/o/userinfo/",
},
State: "randomStateString",
},
@@ -311,21 +337,6 @@ func oAuthMiddleware(input interface{}) (OauthRulerMiddleware, error) {
}
return *oauthRuler, nil
}
func oauth2Config(oauth OauthRulerMiddleware) *oauth2.Config {
return &oauth2.Config{
ClientID: oauth.ClientID,
ClientSecret: oauth.ClientSecret,
RedirectURL: oauth.RedirectURL,
Scopes: oauth.Scopes,
Endpoint: oauth2.Endpoint{
AuthURL: oauth.Endpoint.AuthURL,
TokenURL: oauth.Endpoint.TokenURL,
DeviceAuthURL: oauth.Endpoint.DeviceAuthURL,
},
}
}
func oauthRulerMiddleware(oauth middleware.Oauth) *OauthRulerMiddleware {
return &OauthRulerMiddleware{
ClientID: oauth.ClientID,
@@ -333,10 +344,51 @@ func oauthRulerMiddleware(oauth middleware.Oauth) *OauthRulerMiddleware {
RedirectURL: oauth.RedirectURL,
State: oauth.State,
Scopes: oauth.Scopes,
JWTSecret: oauth.JWTSecret,
Provider: oauth.Provider,
Endpoint: OauthEndpoint{
AuthURL: oauth.Endpoint.AuthURL,
TokenURL: oauth.Endpoint.TokenURL,
DeviceAuthURL: oauth.Endpoint.DeviceAuthURL,
AuthURL: oauth.Endpoint.AuthURL,
TokenURL: oauth.Endpoint.TokenURL,
UserInfoURL: oauth.Endpoint.UserInfoURL,
},
}
}
func oauth2Config(oauth *OauthRulerMiddleware) *oauth2.Config {
conf := &oauth2.Config{
ClientID: oauth.ClientID,
ClientSecret: oauth.ClientSecret,
RedirectURL: oauth.RedirectURL,
Scopes: oauth.Scopes,
Endpoint: oauth2.Endpoint{
AuthURL: oauth.Endpoint.AuthURL,
TokenURL: oauth.Endpoint.TokenURL,
},
}
switch oauth.Provider {
case "google":
conf.Endpoint = google.Endpoint
if oauth.Endpoint.UserInfoURL == "" {
oauth.Endpoint.UserInfoURL = "https://www.googleapis.com/oauth2/v2/userinfo"
}
case "amazon":
conf.Endpoint = amazon.Endpoint
case "facebook":
conf.Endpoint = facebook.Endpoint
if oauth.Endpoint.UserInfoURL == "" {
oauth.Endpoint.UserInfoURL = "https://graph.facebook.com/me"
}
case "github":
conf.Endpoint = github.Endpoint
if oauth.Endpoint.UserInfoURL == "" {
oauth.Endpoint.UserInfoURL = "https://api.github.com/user/repo"
}
case "gitlab":
conf.Endpoint = gitlab.Endpoint
default:
if oauth.Provider != "custom" {
logger.Error("Unknown provider: %s", oauth.Provider)
}
}
return conf
}

View File

@@ -55,13 +55,8 @@ func CORSHandler(cors Cors) mux.MiddlewareFunc {
// ProxyErrorHandler catches backend errors and returns a custom response
func ProxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
logger.Error("Proxy error: %v", err)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadGateway)
err = json.NewEncoder(w).Encode(map[string]interface{}{
"success": false,
"code": http.StatusBadGateway,
"message": "The service is currently unavailable. Please try again later.",
})
_, err = w.Write([]byte("Bad Gateway"))
if err != nil {
return
}
@@ -131,27 +126,42 @@ func allowedOrigin(origins []string, origin string) bool {
return false
}
func (oauth OauthRulerMiddleware) callbackHandler(w http.ResponseWriter, r *http.Request) {
// callbackHandler handles oauth callback
func (oauth *OauthRulerMiddleware) callbackHandler(w http.ResponseWriter, r *http.Request) {
oauthConfig := oauth2Config(oauth)
logger.Info("URL State: %s", r.URL.Query().Get("state"))
// Verify the state to protect against CSRF
if r.URL.Query().Get("state") != oauth.State {
http.Error(w, "Invalid state", http.StatusBadRequest)
return
}
// Exchange the authorization code for an access token
code := r.URL.Query().Get("code")
token, err := oauthConfig.Exchange(context.Background(), code)
if err != nil {
http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError)
logger.Error("Failed to exchange token: %v", err.Error())
http.Error(w, "Failed to exchange token", http.StatusInternalServerError)
return
}
// Get user info from the token
userInfo, err := oauth.getUserInfo(token)
if err != nil {
logger.Error("Error getting user info: %v", err)
http.Error(w, "Error getting user info: ", http.StatusInternalServerError)
return
}
// Generate JWT with user's email
jwtToken, err := createJWT(userInfo.Email, oauth.JWTSecret)
if err != nil {
logger.Error("Error creating JWT: %v", err)
http.Error(w, "Error creating JWT ", http.StatusInternalServerError)
return
}
// Save token to a cookie for simplicity
http.SetCookie(w, &http.Cookie{
Name: "oauth-token",
Value: token.AccessToken,
Name: "goma.JWT",
Value: jwtToken,
Path: oauth.CookiePath,
})

View File

@@ -10,11 +10,16 @@ You may get a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
*/
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"github.com/golang-jwt/jwt"
"github.com/jedib0t/go-pretty/v6/table"
"github.com/jkaninda/goma-gateway/pkg/logger"
"golang.org/x/oauth2"
"net/http"
"time"
)
// printRoute prints routes
@@ -53,3 +58,40 @@ func loadTLS(cert, key string) (*tls.Config, error) {
}
return tlsConfig, nil
}
func (oauth *OauthRulerMiddleware) getUserInfo(token *oauth2.Token) (UserInfo, error) {
oauthConfig := oauth2Config(oauth)
// Call the user info endpoint with the token
client := oauthConfig.Client(context.Background(), token)
resp, err := client.Get(oauth.Endpoint.UserInfoURL)
if err != nil {
return UserInfo{}, err
}
defer resp.Body.Close()
// Parse the user info
var userInfo UserInfo
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
return UserInfo{}, err
}
return userInfo, nil
}
func createJWT(email, jwtSecret string) (string, error) {
// Define JWT claims
claims := jwt.MapClaims{
"email": email,
"exp": jwt.TimeFunc().Add(time.Hour * 24).Unix(), // Token expiration
"iss": "Goma-Gateway", // Issuer claim
}
// Create a new token with HS256 signing method
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
// Sign the token with a secret
signedToken, err := token.SignedString([]byte(jwtSecret))
if err != nil {
return "", err
}
return signedToken, nil
}

View File

@@ -0,0 +1,59 @@
/*
* Copyright 2024 Jonas Kaninda
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package middleware
import (
"github.com/jkaninda/goma-gateway/pkg/logger"
"golang.org/x/oauth2"
"golang.org/x/oauth2/amazon"
"golang.org/x/oauth2/facebook"
"golang.org/x/oauth2/github"
"golang.org/x/oauth2/gitlab"
"golang.org/x/oauth2/google"
)
func oauth2Config(oauth Oauth) *oauth2.Config {
config := &oauth2.Config{
ClientID: oauth.ClientID,
ClientSecret: oauth.ClientSecret,
RedirectURL: oauth.RedirectURL,
Scopes: oauth.Scopes,
Endpoint: oauth2.Endpoint{
AuthURL: oauth.Endpoint.AuthURL,
TokenURL: oauth.Endpoint.TokenURL,
},
}
switch oauth.Provider {
case "google":
config.Endpoint = google.Endpoint
case "amazon":
config.Endpoint = amazon.Endpoint
case "facebook":
config.Endpoint = facebook.Endpoint
case "github":
config.Endpoint = github.Endpoint
case "gitlab":
config.Endpoint = gitlab.Endpoint
default:
if oauth.Provider != "custom" {
logger.Error("Unknown provider: %s", oauth.Provider)
}
}
return config
}

View File

@@ -18,39 +18,71 @@
package middleware
import (
"fmt"
"github.com/golang-jwt/jwt"
"github.com/jkaninda/goma-gateway/pkg/logger"
"golang.org/x/oauth2"
"net/http"
"time"
)
func oauth2Config(oauth Oauth) *oauth2.Config {
return &oauth2.Config{
ClientID: oauth.ClientID,
ClientSecret: oauth.ClientSecret,
RedirectURL: oauth.RedirectURL,
Scopes: oauth.Scopes,
Endpoint: oauth2.Endpoint{
AuthURL: oauth.Endpoint.AuthURL,
TokenURL: oauth.Endpoint.TokenURL,
DeviceAuthURL: oauth.Endpoint.DeviceAuthURL,
},
}
}
func (oauth Oauth) AuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logger.Info("%s: %s Oauth", getRealIP(r), r.URL.Path)
oauthConfig := oauth2Config(oauth)
oauthConf := oauth2Config(oauth)
// Check if the user is authenticated
_, err := r.Cookie("oauth-token")
token, err := r.Cookie("goma.JWT")
if err != nil {
// If no token, redirect to OAuth provider
url := oauthConfig.AuthCodeURL(oauth.State)
url := oauthConf.AuthCodeURL(oauth.State)
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
return
}
ok, err := validateJWT(token.Value, oauth)
if err != nil {
// If no token, redirect to OAuth provider
url := oauthConf.AuthCodeURL(oauth.State)
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
return
}
if !ok {
// If no token, redirect to OAuth provider
url := oauthConf.AuthCodeURL(oauth.State)
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
return
}
//TODO: Check if the token stored in the cookie is valid
// Token exists, proceed with request
next.ServeHTTP(w, r)
})
}
func validateJWT(signedToken string, oauth Oauth) (bool, error) {
// Parse the JWT token and provide the key function
token, err := jwt.Parse(signedToken, func(token *jwt.Token) (interface{}, error) {
// Ensure the signing method is HMAC and specifically HS256
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
// Return the shared secret key for validation
return []byte(oauth.JWTSecret), nil
})
// If there's an error or token is invalid, return false
if err != nil || !token.Valid {
return false, fmt.Errorf("token is invalid: %v", err)
}
// Check if token claims are valid
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
// Optional: Check token expiration
if exp, ok := claims["exp"].(float64); ok {
if time.Unix(int64(exp), 0).Before(time.Now()) {
return false, fmt.Errorf("token has expired")
}
}
// Token is valid and not expired
return true, nil
}
return false, fmt.Errorf("token is invalid or missing claims")
}

View File

@@ -120,11 +120,13 @@ type Oauth struct {
// Scope specifies optional requested permissions.
Scopes []string
// contains filtered or unexported fields
State string
Origins []string
State string
Origins []string
JWTSecret string
Provider string
}
type OauthEndpoint struct {
AuthURL string
TokenURL string
DeviceAuthURL string
AuthURL string
TokenURL string
UserInfoURL string
}

View File

@@ -16,7 +16,6 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
import (
"encoding/json"
"fmt"
"github.com/jkaninda/goma-gateway/pkg/logger"
"net/http"
@@ -44,18 +43,12 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc {
w.Header().Set(accessControlAllowOrigin, r.Header.Get("Origin"))
}
}
// Parse the target backend URL
targetURL, err := url.Parse(proxyRoute.destination)
if err != nil {
logger.Error("Error parsing backend URL: %s", err)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
err := json.NewEncoder(w).Encode(ErrorResponse{
Message: "Internal server error",
Code: http.StatusInternalServerError,
Success: false,
})
_, err := w.Write([]byte("Internal Server Error"))
if err != nil {
return
}

View File

@@ -132,18 +132,24 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
if err != nil {
logger.Error("Error: %s", err.Error())
} else {
redirectURL := "/callback" + route.Path
if oauth.RedirectURL != "" {
redirectURL = oauth.RedirectURL
}
amw := middleware.Oauth{
ClientID: oauth.ClientID,
ClientSecret: oauth.ClientSecret,
RedirectURL: oauth.RedirectURL + route.Path,
RedirectURL: redirectURL,
Scopes: oauth.Scopes,
Endpoint: middleware.OauthEndpoint{
AuthURL: oauth.Endpoint.AuthURL,
TokenURL: oauth.Endpoint.TokenURL,
DeviceAuthURL: oauth.Endpoint.DeviceAuthURL,
AuthURL: oauth.Endpoint.AuthURL,
TokenURL: oauth.Endpoint.TokenURL,
UserInfoURL: oauth.Endpoint.UserInfoURL,
},
State: oauth.State,
Origins: gateway.Cors.Origins,
State: oauth.State,
Origins: gateway.Cors.Origins,
JWTSecret: oauth.JWTSecret,
Provider: oauth.Provider,
}
oauthRuler := oauthRulerMiddleware(amw)
// Check if a cookie path is defined
@@ -154,12 +160,15 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
if oauthRuler.RedirectPath == "" {
oauthRuler.RedirectPath = util.ParseRoutePath(route.Path, midPath)
}
if oauthRuler.Provider == "" {
oauthRuler.Provider = "custom"
}
secureRouter.Use(amw.AuthMiddleware)
secureRouter.Use(CORSHandler(route.Cors))
secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler
secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler
// Callback route
r.HandleFunc("/callback"+route.Path, oauthRuler.callbackHandler).Methods("GET")
r.HandleFunc(util.UrlParsePath(redirectURL), oauthRuler.callbackHandler).Methods("GET")
}
default:
if !doesExist(rMiddleware.Type) {

View File

@@ -78,7 +78,8 @@ type OauthRulerMiddleware struct {
// ClientSecret is the application's secret.
ClientSecret string `yaml:"clientSecret"`
// oauth provider google, gitlab, github, amazon, facebook, custom
Provider string `yaml:"provider"`
// Endpoint contains the resource server's token endpoint
Endpoint OauthEndpoint `yaml:"endpoint"`
@@ -93,12 +94,13 @@ type OauthRulerMiddleware struct {
// Scope specifies optional requested permissions.
Scopes []string `yaml:"scopes"`
// contains filtered or unexported fields
State string `yaml:"state"`
State string `yaml:"state"`
JWTSecret string `yaml:"jwtSecret"`
}
type OauthEndpoint struct {
AuthURL string `yaml:"authUrl"`
TokenURL string `yaml:"tokenUrl"`
DeviceAuthURL string `yaml:"deviceAuthUrl"`
AuthURL string `yaml:"authUrl"`
TokenURL string `yaml:"tokenUrl"`
UserInfoURL string `yaml:"userInfoUrl"`
}
type RateLimiter struct {
// ipBased, tokenBased
@@ -242,3 +244,11 @@ type HealthCheckRouteResponse struct {
Status string `json:"status"`
Error string `json:"error"`
}
type UserInfo struct {
Email string `json:"email"`
}
type JWTSecret struct {
ISS string `yaml:"iss"`
Secret string `yaml:"secret"`
}