refactor: refactoring of code

Add graceful shutdown server
This commit is contained in:
Jonas Kaninda
2024-11-15 14:24:35 +01:00
parent 5665ee3dab
commit f1af5c3ce6
26 changed files with 267 additions and 181 deletions

View File

@@ -76,7 +76,7 @@ gateway:
Access-Control-Max-Age: 1728000 Access-Control-Max-Age: 1728000
##### Apply middlewares to the route ##### Apply middlewares to the route
## The name must be unique ## The name must be unique
## List of middleware name ## List of middlewares name
middlewares: middlewares:
- api-forbidden-paths - api-forbidden-paths
# Example of a route | 2 # Example of a route | 2
@@ -103,7 +103,7 @@ gateway:
- api-forbidden-paths - api-forbidden-paths
- basic-auth - basic-auth
#Defines proxy middlewares #Defines proxy middlewares
# middleware name must be unique # middlewares name must be unique
middlewares: middlewares:
# Enable Basic auth authorization based # Enable Basic auth authorization based
- name: basic-auth - name: basic-auth

View File

@@ -52,7 +52,7 @@ func CheckConfig(fileName string) error {
} }
} }
//Check middleware //Check middlewares
for index, mid := range c.Middlewares { for index, mid := range c.Middlewares {
if util.HasWhitespace(mid.Name) { if util.HasWhitespace(mid.Name) {
fmt.Printf("Warning: Middleware contains whitespace: %s | index: [%d], please remove whitespace characters\n", mid.Name, index) fmt.Printf("Warning: Middleware contains whitespace: %s | index: [%d], please remove whitespace characters\n", mid.Name, index)

View File

@@ -17,7 +17,7 @@ limitations under the License.
*/ */
import ( import (
"fmt" "fmt"
"github.com/jkaninda/goma-gateway/internal/middleware" "github.com/jkaninda/goma-gateway/internal/middlewares"
"github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/pkg/logger"
"github.com/jkaninda/goma-gateway/util" "github.com/jkaninda/goma-gateway/util"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@@ -330,7 +330,7 @@ func getJWTMiddleware(input interface{}) (JWTRuleMiddleware, error) {
return JWTRuleMiddleware{}, fmt.Errorf("error parsing yaml: %v", err) return JWTRuleMiddleware{}, fmt.Errorf("error parsing yaml: %v", err)
} }
if jWTRuler.URL == "" { if jWTRuler.URL == "" {
return JWTRuleMiddleware{}, fmt.Errorf("error parsing yaml: empty url in jwt auth middleware") return JWTRuleMiddleware{}, fmt.Errorf("error parsing yaml: empty url in jwt auth middlewares")
} }
return *jWTRuler, nil return *jWTRuler, nil
@@ -349,7 +349,7 @@ func getBasicAuthMiddleware(input interface{}) (BasicRuleMiddleware, error) {
return BasicRuleMiddleware{}, fmt.Errorf("error parsing yaml: %v", err) return BasicRuleMiddleware{}, fmt.Errorf("error parsing yaml: %v", err)
} }
if basicAuth.Username == "" || basicAuth.Password == "" { if basicAuth.Username == "" || basicAuth.Password == "" {
return BasicRuleMiddleware{}, fmt.Errorf("error parsing yaml: empty username/password in %s middleware", basicAuth) return BasicRuleMiddleware{}, fmt.Errorf("error parsing yaml: empty username/password in %s middlewares", basicAuth)
} }
return *basicAuth, nil return *basicAuth, nil
@@ -368,12 +368,12 @@ func oAuthMiddleware(input interface{}) (OauthRulerMiddleware, error) {
return OauthRulerMiddleware{}, fmt.Errorf("error parsing yaml: %v", err) return OauthRulerMiddleware{}, fmt.Errorf("error parsing yaml: %v", err)
} }
if oauthRuler.ClientID == "" || oauthRuler.ClientSecret == "" || oauthRuler.RedirectURL == "" { if oauthRuler.ClientID == "" || oauthRuler.ClientSecret == "" || oauthRuler.RedirectURL == "" {
return OauthRulerMiddleware{}, fmt.Errorf("error parsing yaml: empty clientId/secretId in %s middleware", oauthRuler) return OauthRulerMiddleware{}, fmt.Errorf("error parsing yaml: empty clientId/secretId in %s middlewares", oauthRuler)
} }
return *oauthRuler, nil return *oauthRuler, nil
} }
func oauthRulerMiddleware(oauth middleware.Oauth) *OauthRulerMiddleware { func oauthRulerMiddleware(oauth middlewares.Oauth) *OauthRulerMiddleware {
return &OauthRulerMiddleware{ return &OauthRulerMiddleware{
ClientID: oauth.ClientID, ClientID: oauth.ClientID,
ClientSecret: oauth.ClientSecret, ClientSecret: oauth.ClientSecret,

View File

@@ -14,7 +14,7 @@ func getMiddleware(rules []string, middlewares []Middleware) (Middleware, error)
continue continue
} }
return Middleware{}, errors.New("middleware not found with name: [" + strings.Join(rules, ";") + "]") return Middleware{}, errors.New("middlewares not found with name: [" + strings.Join(rules, ";") + "]")
} }
func doesExist(tyName string) bool { func doesExist(tyName string) bool {
@@ -30,5 +30,5 @@ func GetMiddleware(rule string, middlewares []Middleware) (Middleware, error) {
continue continue
} }
return Middleware{}, errors.New("no middleware found with name " + rule) return Middleware{}, errors.New("no middlewares found with name " + rule)
} }

View File

@@ -105,7 +105,7 @@ func TestReadMiddleware(t *testing.T) {
middlewares := getMiddlewares(t) middlewares := getMiddlewares(t)
m, err := getMiddleware(rules, middlewares) m, err := getMiddleware(rules, middlewares)
if err != nil { if err != nil {
t.Fatalf("Error searching middleware %s", err.Error()) t.Fatalf("Error searching middlewares %s", err.Error())
} }
log.Printf("Middleware: %v\n", m) log.Printf("Middleware: %v\n", m)
@@ -134,10 +134,10 @@ func TestReadMiddleware(t *testing.T) {
} }
log.Printf("OAuth authentification: provider %s\n", oauth.Provider) log.Printf("OAuth authentification: provider %s\n", oauth.Provider)
case AccessMiddleware: case AccessMiddleware:
log.Println("Access middleware") log.Println("Access middlewares")
log.Printf("Access middleware: paths: [%s]\n", middleware.Paths) log.Printf("Access middlewares: paths: [%s]\n", middleware.Paths)
default: default:
t.Errorf("Unknown middleware type %s", middleware.Type) t.Errorf("Unknown middlewares type %s", middleware.Type)
} }
} }
@@ -148,7 +148,7 @@ func TestFoundMiddleware(t *testing.T) {
middlewares := getMiddlewares(t) middlewares := getMiddlewares(t)
middleware, err := GetMiddleware("jwt", middlewares) middleware, err := GetMiddleware("jwt", middlewares)
if err != nil { if err != nil {
t.Errorf("Error getting middleware %v", err) t.Errorf("Error getting middlewares %v", err)
} }
fmt.Println(middleware.Type) fmt.Println(middleware.Type)
} }

View File

@@ -17,9 +17,9 @@
package pkg package pkg
// Middleware defined the route middleware // Middleware defined the route middlewares
type Middleware struct { type Middleware struct {
//Path contains the name of middleware and must be unique //Path contains the name of middlewares and must be unique
Name string `yaml:"name"` Name string `yaml:"name"`
// Type contains authentication types // Type contains authentication types
// //

View File

@@ -1,4 +1,4 @@
package middleware package middlewares
/* /*
Copyright 2024 Jonas Kaninda Copyright 2024 Jonas Kaninda

View File

@@ -15,7 +15,7 @@
* *
*/ */
package middleware package middlewares
import ( import (
"fmt" "fmt"

View File

@@ -15,7 +15,7 @@
* *
*/ */
package middleware package middlewares
import ( import (
"github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/pkg/logger"

View File

@@ -1,4 +1,4 @@
package middleware package middlewares
/* /*
* Copyright 2024 Jonas Kaninda * Copyright 2024 Jonas Kaninda

View File

@@ -15,7 +15,7 @@
* *
*/ */
package middleware package middlewares
import ( import (
"encoding/json" "encoding/json"

View File

@@ -1,4 +1,4 @@
package middleware package middlewares
/* /*
Copyright 2024 Jonas Kaninda Copyright 2024 Jonas Kaninda

View File

@@ -15,7 +15,7 @@
* *
*/ */
package middleware package middlewares
import ( import (
"fmt" "fmt"

View File

@@ -1,4 +1,4 @@
package middleware package middlewares
/* /*
Copyright 2024 Jonas Kaninda Copyright 2024 Jonas Kaninda
@@ -16,13 +16,9 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
import ( import (
"errors"
"fmt" "fmt"
"github.com/go-redis/redis_rate/v10"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/pkg/logger"
"github.com/redis/go-redis/v9"
"golang.org/x/net/context"
"net/http" "net/http"
"time" "time"
) )
@@ -91,23 +87,3 @@ func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc {
}) })
} }
} }
func redisRateLimiter(clientIP string, rate int) error {
ctx := context.Background()
res, err := limiter.Allow(ctx, clientIP, redis_rate.PerMinute(rate))
if err != nil {
return err
}
if res.Remaining == 0 {
return errors.New("requests limit exceeded")
}
return nil
}
func InitRedis(addr, password string) {
Rdb = redis.NewClient(&redis.Options{
Addr: addr,
Password: password,
})
limiter = redis_rate.NewLimiter(Rdb)
}

View File

@@ -0,0 +1,46 @@
/*
* Copyright 2024 Jonas Kaninda
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package middlewares
import (
"context"
"errors"
"github.com/go-redis/redis_rate/v10"
"github.com/redis/go-redis/v9"
)
func redisRateLimiter(clientIP string, rate int) error {
ctx := context.Background()
res, err := limiter.Allow(ctx, clientIP, redis_rate.PerMinute(rate))
if err != nil {
return err
}
if res.Remaining == 0 {
return errors.New("requests limit exceeded")
}
return nil
}
func InitRedis(addr, password string) {
Rdb = redis.NewClient(&redis.Options{
Addr: addr,
Password: password,
})
limiter = redis_rate.NewLimiter(Rdb)
}

View File

@@ -15,7 +15,7 @@
* *
*/ */
package middleware package middlewares
import ( import (
"bytes" "bytes"

View File

@@ -15,7 +15,7 @@
* *
*/ */
package middleware package middlewares
import ( import (
"github.com/go-redis/redis_rate/v10" "github.com/go-redis/redis_rate/v10"

View File

@@ -18,7 +18,7 @@ limitations under the License.
import ( import (
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"github.com/jkaninda/goma-gateway/internal/middleware" "github.com/jkaninda/goma-gateway/internal/middlewares"
"github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/pkg/logger"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
@@ -38,7 +38,7 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc {
if len(proxyRoute.methods) > 0 { if len(proxyRoute.methods) > 0 {
if !slices.Contains(proxyRoute.methods, r.Method) { if !slices.Contains(proxyRoute.methods, r.Method) {
logger.Error("%s Method is not allowed", r.Method) logger.Error("%s Method is not allowed", r.Method)
middleware.RespondWithError(w, http.StatusMethodNotAllowed, fmt.Sprintf("%d %s method is not allowed", http.StatusMethodNotAllowed, r.Method)) middlewares.RespondWithError(w, http.StatusMethodNotAllowed, fmt.Sprintf("%d %s method is not allowed", http.StatusMethodNotAllowed, r.Method))
return return
} }
} }
@@ -61,7 +61,7 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc {
targetURL, err := url.Parse(proxyRoute.destination) targetURL, err := url.Parse(proxyRoute.destination)
if err != nil { if err != nil {
logger.Error("Error parsing backend URL: %s", err) logger.Error("Error parsing backend URL: %s", err)
middleware.RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) middlewares.RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
return return
} }
r.Header.Set("X-Forwarded-Host", r.Header.Get("Host")) r.Header.Set("X-Forwarded-Host", r.Header.Get("Host"))

40
internal/redis.go Normal file
View File

@@ -0,0 +1,40 @@
/*
* Copyright 2024 Jonas Kaninda
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package pkg
import (
"github.com/jkaninda/goma-gateway/internal/middlewares"
"github.com/jkaninda/goma-gateway/pkg/logger"
)
func (gatewayServer GatewayServer) initRedis() error {
if gatewayServer.gateway.Redis.Addr == "" {
return nil
}
logger.Info("Initializing Redis...")
middlewares.InitRedis(gatewayServer.gateway.Redis.Addr, gatewayServer.gateway.Redis.Password)
return nil
}
func (gatewayServer GatewayServer) closeRedis() {
if middlewares.Rdb != nil {
if err := middlewares.Rdb.Close(); err != nil {
logger.Error("Error closing Redis: %v", err)
}
}
}

View File

@@ -17,7 +17,7 @@ limitations under the License.
*/ */
import ( import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/jkaninda/goma-gateway/internal/middleware" "github.com/jkaninda/goma-gateway/internal/middlewares"
"github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/pkg/logger"
"github.com/jkaninda/goma-gateway/util" "github.com/jkaninda/goma-gateway/util"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
@@ -34,7 +34,7 @@ func init() {
// Initialize the routes // Initialize the routes
func (gatewayServer GatewayServer) Initialize() *mux.Router { func (gatewayServer GatewayServer) Initialize() *mux.Router {
gateway := gatewayServer.gateway gateway := gatewayServer.gateway
middlewares := gatewayServer.middlewares m := gatewayServer.middlewares
redisBased := false redisBased := false
if len(gateway.Redis.Addr) != 0 { if len(gateway.Redis.Addr) != 0 {
redisBased = true redisBased = true
@@ -62,11 +62,11 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
// Enable common exploits // Enable common exploits
if gateway.BlockCommonExploits { if gateway.BlockCommonExploits {
logger.Info("Block common exploits enabled") logger.Info("Block common exploits enabled")
r.Use(middleware.BlockExploitsMiddleware) r.Use(middlewares.BlockExploitsMiddleware)
} }
if gateway.RateLimit > 0 { if gateway.RateLimit > 0 {
// Add rate limit middleware to all routes, if defined // Add rate limit middlewares to all routes, if defined
rateLimit := middleware.RateLimit{ rateLimit := middlewares.RateLimit{
Id: "global_rate", //Generate a unique ID for routes Id: "global_rate", //Generate a unique ID for routes
Requests: gateway.RateLimit, Requests: gateway.RateLimit,
Window: time.Minute, // requests per minute Window: time.Minute, // requests per minute
@@ -75,7 +75,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
RedisBased: redisBased, RedisBased: redisBased,
} }
limiter := rateLimit.NewRateLimiterWindow() limiter := rateLimit.NewRateLimiterWindow()
// Add rate limit middleware // Add rate limit middlewares
r.Use(limiter.RateLimitMiddleware()) r.Use(limiter.RateLimitMiddleware())
} }
for rIndex, route := range gateway.Routes { for rIndex, route := range gateway.Routes {
@@ -87,14 +87,14 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
// Apply middlewares to route // Apply middlewares to route
for _, mid := range route.Middlewares { for _, mid := range route.Middlewares {
if mid != "" { if mid != "" {
// Get Access middleware if it does exist // Get Access middlewares if it does exist
accessMiddleware, err := getMiddleware([]string{mid}, middlewares) accessMiddleware, err := getMiddleware([]string{mid}, m)
if err != nil { if err != nil {
logger.Error("Error: %v", err.Error()) logger.Error("Error: %v", err.Error())
} else { } else {
// Apply access middleware // Apply access middlewares
if accessMiddleware.Type == AccessMiddleware { if accessMiddleware.Type == AccessMiddleware {
blM := middleware.AccessListMiddleware{ blM := middlewares.AccessListMiddleware{
Path: route.Path, Path: route.Path,
List: accessMiddleware.Paths, List: accessMiddleware.Paths,
} }
@@ -103,10 +103,10 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
} }
} }
// Get route authentication middleware if it does exist // Get route authentication middlewares if it does exist
rMiddleware, err := getMiddleware([]string{mid}, middlewares) rMiddleware, err := getMiddleware([]string{mid}, m)
if err != nil { if err != nil {
//Error: middleware not found //Error: middlewares not found
logger.Error("Error: %v", err.Error()) logger.Error("Error: %v", err.Error())
} else { } else {
for _, midPath := range rMiddleware.Paths { for _, midPath := range rMiddleware.Paths {
@@ -122,20 +122,20 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
} }
secureRouter := r.PathPrefix(util.ParseRoutePath(route.Path, midPath)).Subrouter() secureRouter := r.PathPrefix(util.ParseRoutePath(route.Path, midPath)).Subrouter()
//callBackRouter := r.PathPrefix(util.ParseRoutePath(route.Path, "/callback")).Subrouter() //callBackRouter := r.PathPrefix(util.ParseRoutePath(route.Path, "/callback")).Subrouter()
//Check Authentication middleware //Check Authentication middlewares
switch rMiddleware.Type { switch rMiddleware.Type {
case BasicAuth: case BasicAuth:
basicAuth, err := getBasicAuthMiddleware(rMiddleware.Rule) basicAuth, err := getBasicAuthMiddleware(rMiddleware.Rule)
if err != nil { if err != nil {
logger.Error("Error: %s", err.Error()) logger.Error("Error: %s", err.Error())
} else { } else {
amw := middleware.AuthBasic{ amw := middlewares.AuthBasic{
Username: basicAuth.Username, Username: basicAuth.Username,
Password: basicAuth.Password, Password: basicAuth.Password,
Headers: nil, Headers: nil,
Params: nil, Params: nil,
} }
// Apply JWT authentication middleware // Apply JWT authentication middlewares
secureRouter.Use(amw.AuthMiddleware) secureRouter.Use(amw.AuthMiddleware)
secureRouter.Use(CORSHandler(route.Cors)) secureRouter.Use(CORSHandler(route.Cors))
secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler
@@ -146,14 +146,14 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
if err != nil { if err != nil {
logger.Error("Error: %s", err.Error()) logger.Error("Error: %s", err.Error())
} else { } else {
amw := middleware.JwtAuth{ amw := middlewares.JwtAuth{
AuthURL: jwt.URL, AuthURL: jwt.URL,
RequiredHeaders: jwt.RequiredHeaders, RequiredHeaders: jwt.RequiredHeaders,
Headers: jwt.Headers, Headers: jwt.Headers,
Params: jwt.Params, Params: jwt.Params,
Origins: gateway.Cors.Origins, Origins: gateway.Cors.Origins,
} }
// Apply JWT authentication middleware // Apply JWT authentication middlewares
secureRouter.Use(amw.AuthMiddleware) secureRouter.Use(amw.AuthMiddleware)
secureRouter.Use(CORSHandler(route.Cors)) secureRouter.Use(CORSHandler(route.Cors))
secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler
@@ -169,12 +169,12 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
if oauth.RedirectURL != "" { if oauth.RedirectURL != "" {
redirectURL = oauth.RedirectURL redirectURL = oauth.RedirectURL
} }
amw := middleware.Oauth{ amw := middlewares.Oauth{
ClientID: oauth.ClientID, ClientID: oauth.ClientID,
ClientSecret: oauth.ClientSecret, ClientSecret: oauth.ClientSecret,
RedirectURL: redirectURL, RedirectURL: redirectURL,
Scopes: oauth.Scopes, Scopes: oauth.Scopes,
Endpoint: middleware.OauthEndpoint{ Endpoint: middlewares.OauthEndpoint{
AuthURL: oauth.Endpoint.AuthURL, AuthURL: oauth.Endpoint.AuthURL,
TokenURL: oauth.Endpoint.TokenURL, TokenURL: oauth.Endpoint.TokenURL,
UserInfoURL: oauth.Endpoint.UserInfoURL, UserInfoURL: oauth.Endpoint.UserInfoURL,
@@ -205,7 +205,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
} }
default: default:
if !doesExist(rMiddleware.Type) { if !doesExist(rMiddleware.Type) {
logger.Error("Unknown middleware type %s", rMiddleware.Type) logger.Error("Unknown middlewares type %s", rMiddleware.Type)
} }
} }
@@ -214,7 +214,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
} }
} else { } else {
logger.Error("Error, middleware path is empty") logger.Error("Error, middlewares path is empty")
logger.Error("Middleware ignored") logger.Error("Middleware ignored")
} }
} }
@@ -234,7 +234,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
// Enable common exploits // Enable common exploits
if route.BlockCommonExploits { if route.BlockCommonExploits {
logger.Info("Block common exploits enabled") logger.Info("Block common exploits enabled")
router.Use(middleware.BlockExploitsMiddleware) router.Use(middlewares.BlockExploitsMiddleware)
} }
id := string(rune(rIndex)) id := string(rune(rIndex))
if len(route.Name) != 0 { if len(route.Name) != 0 {
@@ -243,7 +243,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
} }
// Apply route rate limit // Apply route rate limit
if route.RateLimit > 0 { if route.RateLimit > 0 {
rateLimit := middleware.RateLimit{ rateLimit := middlewares.RateLimit{
Id: id, // Use route index as ID Id: id, // Use route index as ID
Requests: route.RateLimit, Requests: route.RateLimit,
Window: time.Minute, // requests per minute Window: time.Minute, // requests per minute
@@ -252,7 +252,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
RedisBased: redisBased, RedisBased: redisBased,
} }
limiter := rateLimit.NewRateLimiterWindow() limiter := rateLimit.NewRateLimiterWindow()
// Add rate limit middleware // Add rate limit middlewares
router.Use(limiter.RateLimitMiddleware()) router.Use(limiter.RateLimitMiddleware())
} }
// Apply route Cors // Apply route Cors
@@ -272,9 +272,9 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
// Prometheus endpoint // Prometheus endpoint
router.Use(pr.prometheusMiddleware) router.Use(pr.prometheusMiddleware)
} }
// Apply route Error interceptor middleware // Apply route Error interceptor middlewares
if len(route.InterceptErrors) != 0 { if len(route.InterceptErrors) != 0 {
interceptErrors := middleware.InterceptErrors{ interceptErrors := middlewares.InterceptErrors{
Origins: route.Cors.Origins, Origins: route.Cors.Origins,
Errors: route.InterceptErrors, Errors: route.InterceptErrors,
} }
@@ -286,10 +286,10 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
} }
} }
// Apply global Cors middlewares // Apply global Cors middlewares
r.Use(CORSHandler(gateway.Cors)) // Apply CORS middleware r.Use(CORSHandler(gateway.Cors)) // Apply CORS middlewares
// Apply errorInterceptor middleware // Apply errorInterceptor middlewares
if len(gateway.InterceptErrors) != 0 { if len(gateway.InterceptErrors) != 0 {
interceptErrors := middleware.InterceptErrors{ interceptErrors := middlewares.InterceptErrors{
Errors: gateway.InterceptErrors, Errors: gateway.InterceptErrors,
Origins: gateway.Cors.Origins, Origins: gateway.Cors.Origins,
} }

View File

@@ -53,6 +53,6 @@ type Route struct {
InterceptErrors []int `yaml:"interceptErrors"` InterceptErrors []int `yaml:"interceptErrors"`
// BlockCommonExploits enable, disable block common exploits // BlockCommonExploits enable, disable block common exploits
BlockCommonExploits bool `yaml:"blockCommonExploits"` BlockCommonExploits bool `yaml:"blockCommonExploits"`
// Middlewares Defines route middleware from Middleware names // Middlewares Defines route middlewares from Middleware names
Middlewares []string `yaml:"middlewares"` Middlewares []string `yaml:"middlewares"`
} }

View File

@@ -18,111 +18,97 @@ limitations under the License.
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors"
"fmt" "fmt"
"github.com/jkaninda/goma-gateway/internal/middleware"
"github.com/jkaninda/goma-gateway/pkg/logger" "github.com/jkaninda/goma-gateway/pkg/logger"
"github.com/redis/go-redis/v9"
"net/http" "net/http"
"os" "os"
"sync" "os/signal"
"syscall"
"time" "time"
) )
// Start starts the server // Start / Start starts the server
func (gatewayServer GatewayServer) Start(ctx context.Context) error { func (gatewayServer GatewayServer) Start(ctx context.Context) error {
logger.Info("Initializing routes...") logger.Info("Initializing routes...")
route := gatewayServer.Initialize() route := gatewayServer.Initialize()
gateway := gatewayServer.gateway logger.Debug("Routes count=%d, Middlewares count=%d", len(gatewayServer.gateway.Routes), len(gatewayServer.middlewares))
logger.Debug("Routes count=%d Middlewares count=%d", len(gatewayServer.gateway.Routes), len(gatewayServer.middlewares)) if err := gatewayServer.initRedis(); err != nil {
logger.Info("Initializing routes...done") return fmt.Errorf("failed to initialize Redis: %w", err)
if len(gateway.Redis.Addr) != 0 { }
middleware.InitRedis(gateway.Redis.Addr, gateway.Redis.Password) defer gatewayServer.closeRedis()
defer func(Rdb *redis.Client) {
err := Rdb.Close() tlsConfig, listenWithTLS, err := gatewayServer.initTLS()
if err != nil { if err != nil {
logger.Error("Redis connection closed with error: %v", err) return err
}
}(middleware.Rdb)
} }
tlsConfig := &tls.Config{}
var listenWithTLS = false
if cert := gatewayServer.gateway.SSLCertFile; cert != "" && gatewayServer.gateway.SSLKeyFile != "" {
tlsConf, err := loadTLS(cert, gatewayServer.gateway.SSLKeyFile)
if err != nil {
return err
}
tlsConfig = tlsConf
listenWithTLS = true
}
// HTTP Server
httpServer := &http.Server{
Addr: ":8080",
WriteTimeout: time.Second * time.Duration(gatewayServer.gateway.WriteTimeout),
ReadTimeout: time.Second * time.Duration(gatewayServer.gateway.ReadTimeout),
IdleTimeout: time.Second * time.Duration(gatewayServer.gateway.IdleTimeout),
Handler: route, // Pass our instance of gorilla/mux in.
}
// HTTPS Server
httpsServer := &http.Server{
Addr: ":8443",
WriteTimeout: time.Second * time.Duration(gatewayServer.gateway.WriteTimeout),
ReadTimeout: time.Second * time.Duration(gatewayServer.gateway.ReadTimeout),
IdleTimeout: time.Second * time.Duration(gatewayServer.gateway.IdleTimeout),
Handler: route, // Pass our instance of gorilla/mux in.
TLSConfig: tlsConfig,
}
if !gatewayServer.gateway.DisableDisplayRouteOnStart { if !gatewayServer.gateway.DisableDisplayRouteOnStart {
printRoute(gatewayServer.gateway.Routes) printRoute(gatewayServer.gateway.Routes)
} }
// Set KeepAlive
httpServer.SetKeepAlivesEnabled(!gatewayServer.gateway.DisableKeepAlive)
go func() {
logger.Info("Starting HTTP server listen=0.0.0.0:8080")
if err := httpServer.ListenAndServe(); err != nil {
logger.Fatal("Error starting Goma Gateway HTTP server: %v", err)
}
}()
go func() {
if listenWithTLS {
logger.Info("Starting HTTPS server listen=0.0.0.0:8443")
if err := httpsServer.ListenAndServeTLS("", ""); err != nil {
logger.Fatal("Error starting Goma Gateway HTTPS server: %v", err)
}
}
}()
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
<-ctx.Done()
shutdownCtx := context.Background()
shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second)
defer cancel()
if err := httpServer.Shutdown(shutdownCtx); err != nil {
_, err := fmt.Fprintf(os.Stderr, "error shutting down HTTP server: %s\n", err)
if err != nil {
return
}
}
}()
go func() {
defer wg.Done()
<-ctx.Done()
shutdownCtx := context.Background()
shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second)
defer cancel()
if listenWithTLS {
if err := httpsServer.Shutdown(shutdownCtx); err != nil {
_, err := fmt.Fprintf(os.Stderr, "error shutting HTTPS server: %s\n", err)
if err != nil {
return
}
}
}
}()
wg.Wait()
return nil
httpServer := gatewayServer.createServer(":8080", route, nil)
httpsServer := gatewayServer.createServer(":8443", route, tlsConfig)
// Start HTTP/HTTPS servers
if err := gatewayServer.startServers(httpServer, httpsServer, listenWithTLS); err != nil {
return err
}
// Handle graceful shutdown
return gatewayServer.gracefulShutdown(ctx, httpServer, httpsServer, listenWithTLS)
}
func (gatewayServer GatewayServer) createServer(addr string, handler http.Handler, tlsConfig *tls.Config) *http.Server {
return &http.Server{
Addr: addr,
WriteTimeout: time.Second * time.Duration(gatewayServer.gateway.WriteTimeout),
ReadTimeout: time.Second * time.Duration(gatewayServer.gateway.ReadTimeout),
IdleTimeout: time.Second * time.Duration(gatewayServer.gateway.IdleTimeout),
Handler: handler,
TLSConfig: tlsConfig,
}
}
func (gatewayServer GatewayServer) startServers(httpServer, httpsServer *http.Server, listenWithTLS bool) error {
go func() {
logger.Info("Starting HTTP server on 0.0.0.0:8080")
if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
logger.Fatal("HTTP server error: %v", err)
}
}()
if listenWithTLS {
go func() {
logger.Info("Starting HTTPS server on 0.0.0.0:8443")
if err := httpsServer.ListenAndServeTLS("", ""); err != nil && !errors.Is(err, http.ErrServerClosed) {
logger.Fatal("HTTPS server error: %v", err)
}
}()
}
return nil
}
func (gatewayServer GatewayServer) gracefulShutdown(ctx context.Context, httpServer, httpsServer *http.Server, listenWithTLS bool) error {
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
logger.Info("Shutting down Goma Gateway...")
shutdownCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
if err := httpServer.Shutdown(shutdownCtx); err != nil {
logger.Error("Error shutting down HTTP server: %v", err)
}
if listenWithTLS {
if err := httpsServer.Shutdown(shutdownCtx); err != nil {
logger.Error("Error shutting down HTTPS server: %v", err)
}
}
logger.Info("Goma Gateway shut down successfully")
return nil
} }

36
internal/tls.go Normal file
View File

@@ -0,0 +1,36 @@
/*
* Copyright 2024 Jonas Kaninda
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package pkg
import (
"crypto/tls"
"fmt"
)
func (gatewayServer GatewayServer) initTLS() (*tls.Config, bool, error) {
cert, key := gatewayServer.gateway.SSLCertFile, gatewayServer.gateway.SSLKeyFile
if cert == "" || key == "" {
return nil, false, nil
}
tlsConfig, err := loadTLS(cert, key)
if err != nil {
return nil, false, fmt.Errorf("failed to load TLS config: %w", err)
}
return tlsConfig, true, nil
}

View File

@@ -4,10 +4,10 @@ const ConfigDir = "/etc/goma/" // Default config
const ConfigFile = "/etc/goma/goma.yml" // Default configuration file const ConfigFile = "/etc/goma/goma.yml" // Default configuration file
const accessControlAllowOrigin = "Access-Control-Allow-Origin" // Cors const accessControlAllowOrigin = "Access-Control-Allow-Origin" // Cors
const gatewayName = "Goma Gateway" const gatewayName = "Goma Gateway"
const AccessMiddleware = "access" // access middleware const AccessMiddleware = "access" // access middlewares
const BasicAuth = "basic" // basic authentication middleware const BasicAuth = "basic" // basic authentication middlewares
const JWTAuth = "jwt" // JWT authentication middleware const JWTAuth = "jwt" // JWT authentication middlewares
const OAuth = "oauth" // OAuth authentication middleware const OAuth = "oauth" // OAuth authentication middlewares
// Round-robin counter // Round-robin counter
var counter uint32 var counter uint32

View File

@@ -15,7 +15,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
import "github.com/jkaninda/goma-gateway/cmd" import (
"github.com/jkaninda/goma-gateway/cmd"
)
func main() { func main() {