refactor: refactoring of code

Add graceful shutdown server
This commit is contained in:
Jonas Kaninda
2024-11-15 14:24:35 +01:00
parent 5665ee3dab
commit f1af5c3ce6
26 changed files with 267 additions and 181 deletions

View File

@@ -76,7 +76,7 @@ gateway:
Access-Control-Max-Age: 1728000
##### Apply middlewares to the route
## The name must be unique
## List of middleware name
## List of middlewares name
middlewares:
- api-forbidden-paths
# Example of a route | 2
@@ -103,7 +103,7 @@ gateway:
- api-forbidden-paths
- basic-auth
#Defines proxy middlewares
# middleware name must be unique
# middlewares name must be unique
middlewares:
# Enable Basic auth authorization based
- name: basic-auth

View File

@@ -52,7 +52,7 @@ func CheckConfig(fileName string) error {
}
}
//Check middleware
//Check middlewares
for index, mid := range c.Middlewares {
if util.HasWhitespace(mid.Name) {
fmt.Printf("Warning: Middleware contains whitespace: %s | index: [%d], please remove whitespace characters\n", mid.Name, index)

View File

@@ -17,7 +17,7 @@ limitations under the License.
*/
import (
"fmt"
"github.com/jkaninda/goma-gateway/internal/middleware"
"github.com/jkaninda/goma-gateway/internal/middlewares"
"github.com/jkaninda/goma-gateway/pkg/logger"
"github.com/jkaninda/goma-gateway/util"
"golang.org/x/oauth2"
@@ -330,7 +330,7 @@ func getJWTMiddleware(input interface{}) (JWTRuleMiddleware, error) {
return JWTRuleMiddleware{}, fmt.Errorf("error parsing yaml: %v", err)
}
if jWTRuler.URL == "" {
return JWTRuleMiddleware{}, fmt.Errorf("error parsing yaml: empty url in jwt auth middleware")
return JWTRuleMiddleware{}, fmt.Errorf("error parsing yaml: empty url in jwt auth middlewares")
}
return *jWTRuler, nil
@@ -349,7 +349,7 @@ func getBasicAuthMiddleware(input interface{}) (BasicRuleMiddleware, error) {
return BasicRuleMiddleware{}, fmt.Errorf("error parsing yaml: %v", err)
}
if basicAuth.Username == "" || basicAuth.Password == "" {
return BasicRuleMiddleware{}, fmt.Errorf("error parsing yaml: empty username/password in %s middleware", basicAuth)
return BasicRuleMiddleware{}, fmt.Errorf("error parsing yaml: empty username/password in %s middlewares", basicAuth)
}
return *basicAuth, nil
@@ -368,12 +368,12 @@ func oAuthMiddleware(input interface{}) (OauthRulerMiddleware, error) {
return OauthRulerMiddleware{}, fmt.Errorf("error parsing yaml: %v", err)
}
if oauthRuler.ClientID == "" || oauthRuler.ClientSecret == "" || oauthRuler.RedirectURL == "" {
return OauthRulerMiddleware{}, fmt.Errorf("error parsing yaml: empty clientId/secretId in %s middleware", oauthRuler)
return OauthRulerMiddleware{}, fmt.Errorf("error parsing yaml: empty clientId/secretId in %s middlewares", oauthRuler)
}
return *oauthRuler, nil
}
func oauthRulerMiddleware(oauth middleware.Oauth) *OauthRulerMiddleware {
func oauthRulerMiddleware(oauth middlewares.Oauth) *OauthRulerMiddleware {
return &OauthRulerMiddleware{
ClientID: oauth.ClientID,
ClientSecret: oauth.ClientSecret,

View File

@@ -14,7 +14,7 @@ func getMiddleware(rules []string, middlewares []Middleware) (Middleware, error)
continue
}
return Middleware{}, errors.New("middleware not found with name: [" + strings.Join(rules, ";") + "]")
return Middleware{}, errors.New("middlewares not found with name: [" + strings.Join(rules, ";") + "]")
}
func doesExist(tyName string) bool {
@@ -30,5 +30,5 @@ func GetMiddleware(rule string, middlewares []Middleware) (Middleware, error) {
continue
}
return Middleware{}, errors.New("no middleware found with name " + rule)
return Middleware{}, errors.New("no middlewares found with name " + rule)
}

View File

@@ -105,7 +105,7 @@ func TestReadMiddleware(t *testing.T) {
middlewares := getMiddlewares(t)
m, err := getMiddleware(rules, middlewares)
if err != nil {
t.Fatalf("Error searching middleware %s", err.Error())
t.Fatalf("Error searching middlewares %s", err.Error())
}
log.Printf("Middleware: %v\n", m)
@@ -134,10 +134,10 @@ func TestReadMiddleware(t *testing.T) {
}
log.Printf("OAuth authentification: provider %s\n", oauth.Provider)
case AccessMiddleware:
log.Println("Access middleware")
log.Printf("Access middleware: paths: [%s]\n", middleware.Paths)
log.Println("Access middlewares")
log.Printf("Access middlewares: paths: [%s]\n", middleware.Paths)
default:
t.Errorf("Unknown middleware type %s", middleware.Type)
t.Errorf("Unknown middlewares type %s", middleware.Type)
}
}
@@ -148,7 +148,7 @@ func TestFoundMiddleware(t *testing.T) {
middlewares := getMiddlewares(t)
middleware, err := GetMiddleware("jwt", middlewares)
if err != nil {
t.Errorf("Error getting middleware %v", err)
t.Errorf("Error getting middlewares %v", err)
}
fmt.Println(middleware.Type)
}

View File

@@ -17,9 +17,9 @@
package pkg
// Middleware defined the route middleware
// Middleware defined the route middlewares
type Middleware struct {
//Path contains the name of middleware and must be unique
//Path contains the name of middlewares and must be unique
Name string `yaml:"name"`
// Type contains authentication types
//

View File

@@ -1,4 +1,4 @@
package middleware
package middlewares
/*
Copyright 2024 Jonas Kaninda

View File

@@ -15,7 +15,7 @@
*
*/
package middleware
package middlewares
import (
"fmt"

View File

@@ -15,7 +15,7 @@
*
*/
package middleware
package middlewares
import (
"github.com/jkaninda/goma-gateway/pkg/logger"

View File

@@ -1,4 +1,4 @@
package middleware
package middlewares
/*
* Copyright 2024 Jonas Kaninda

View File

@@ -15,7 +15,7 @@
*
*/
package middleware
package middlewares
import (
"encoding/json"

View File

@@ -1,4 +1,4 @@
package middleware
package middlewares
/*
Copyright 2024 Jonas Kaninda

View File

@@ -15,7 +15,7 @@
*
*/
package middleware
package middlewares
import (
"fmt"

View File

@@ -1,4 +1,4 @@
package middleware
package middlewares
/*
Copyright 2024 Jonas Kaninda
@@ -16,13 +16,9 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
import (
"errors"
"fmt"
"github.com/go-redis/redis_rate/v10"
"github.com/gorilla/mux"
"github.com/jkaninda/goma-gateway/pkg/logger"
"github.com/redis/go-redis/v9"
"golang.org/x/net/context"
"net/http"
"time"
)
@@ -91,23 +87,3 @@ func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc {
})
}
}
func redisRateLimiter(clientIP string, rate int) error {
ctx := context.Background()
res, err := limiter.Allow(ctx, clientIP, redis_rate.PerMinute(rate))
if err != nil {
return err
}
if res.Remaining == 0 {
return errors.New("requests limit exceeded")
}
return nil
}
func InitRedis(addr, password string) {
Rdb = redis.NewClient(&redis.Options{
Addr: addr,
Password: password,
})
limiter = redis_rate.NewLimiter(Rdb)
}

View File

@@ -0,0 +1,46 @@
/*
* Copyright 2024 Jonas Kaninda
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package middlewares
import (
"context"
"errors"
"github.com/go-redis/redis_rate/v10"
"github.com/redis/go-redis/v9"
)
func redisRateLimiter(clientIP string, rate int) error {
ctx := context.Background()
res, err := limiter.Allow(ctx, clientIP, redis_rate.PerMinute(rate))
if err != nil {
return err
}
if res.Remaining == 0 {
return errors.New("requests limit exceeded")
}
return nil
}
func InitRedis(addr, password string) {
Rdb = redis.NewClient(&redis.Options{
Addr: addr,
Password: password,
})
limiter = redis_rate.NewLimiter(Rdb)
}

View File

@@ -15,7 +15,7 @@
*
*/
package middleware
package middlewares
import (
"bytes"

View File

@@ -15,7 +15,7 @@
*
*/
package middleware
package middlewares
import (
"github.com/go-redis/redis_rate/v10"

View File

@@ -18,7 +18,7 @@ limitations under the License.
import (
"crypto/tls"
"fmt"
"github.com/jkaninda/goma-gateway/internal/middleware"
"github.com/jkaninda/goma-gateway/internal/middlewares"
"github.com/jkaninda/goma-gateway/pkg/logger"
"net/http"
"net/http/httputil"
@@ -38,7 +38,7 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc {
if len(proxyRoute.methods) > 0 {
if !slices.Contains(proxyRoute.methods, r.Method) {
logger.Error("%s Method is not allowed", r.Method)
middleware.RespondWithError(w, http.StatusMethodNotAllowed, fmt.Sprintf("%d %s method is not allowed", http.StatusMethodNotAllowed, r.Method))
middlewares.RespondWithError(w, http.StatusMethodNotAllowed, fmt.Sprintf("%d %s method is not allowed", http.StatusMethodNotAllowed, r.Method))
return
}
}
@@ -61,7 +61,7 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc {
targetURL, err := url.Parse(proxyRoute.destination)
if err != nil {
logger.Error("Error parsing backend URL: %s", err)
middleware.RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
middlewares.RespondWithError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
return
}
r.Header.Set("X-Forwarded-Host", r.Header.Get("Host"))

40
internal/redis.go Normal file
View File

@@ -0,0 +1,40 @@
/*
* Copyright 2024 Jonas Kaninda
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package pkg
import (
"github.com/jkaninda/goma-gateway/internal/middlewares"
"github.com/jkaninda/goma-gateway/pkg/logger"
)
func (gatewayServer GatewayServer) initRedis() error {
if gatewayServer.gateway.Redis.Addr == "" {
return nil
}
logger.Info("Initializing Redis...")
middlewares.InitRedis(gatewayServer.gateway.Redis.Addr, gatewayServer.gateway.Redis.Password)
return nil
}
func (gatewayServer GatewayServer) closeRedis() {
if middlewares.Rdb != nil {
if err := middlewares.Rdb.Close(); err != nil {
logger.Error("Error closing Redis: %v", err)
}
}
}

View File

@@ -17,7 +17,7 @@ limitations under the License.
*/
import (
"github.com/gorilla/mux"
"github.com/jkaninda/goma-gateway/internal/middleware"
"github.com/jkaninda/goma-gateway/internal/middlewares"
"github.com/jkaninda/goma-gateway/pkg/logger"
"github.com/jkaninda/goma-gateway/util"
"github.com/prometheus/client_golang/prometheus"
@@ -34,7 +34,7 @@ func init() {
// Initialize the routes
func (gatewayServer GatewayServer) Initialize() *mux.Router {
gateway := gatewayServer.gateway
middlewares := gatewayServer.middlewares
m := gatewayServer.middlewares
redisBased := false
if len(gateway.Redis.Addr) != 0 {
redisBased = true
@@ -62,11 +62,11 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
// Enable common exploits
if gateway.BlockCommonExploits {
logger.Info("Block common exploits enabled")
r.Use(middleware.BlockExploitsMiddleware)
r.Use(middlewares.BlockExploitsMiddleware)
}
if gateway.RateLimit > 0 {
// Add rate limit middleware to all routes, if defined
rateLimit := middleware.RateLimit{
// Add rate limit middlewares to all routes, if defined
rateLimit := middlewares.RateLimit{
Id: "global_rate", //Generate a unique ID for routes
Requests: gateway.RateLimit,
Window: time.Minute, // requests per minute
@@ -75,7 +75,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
RedisBased: redisBased,
}
limiter := rateLimit.NewRateLimiterWindow()
// Add rate limit middleware
// Add rate limit middlewares
r.Use(limiter.RateLimitMiddleware())
}
for rIndex, route := range gateway.Routes {
@@ -87,14 +87,14 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
// Apply middlewares to route
for _, mid := range route.Middlewares {
if mid != "" {
// Get Access middleware if it does exist
accessMiddleware, err := getMiddleware([]string{mid}, middlewares)
// Get Access middlewares if it does exist
accessMiddleware, err := getMiddleware([]string{mid}, m)
if err != nil {
logger.Error("Error: %v", err.Error())
} else {
// Apply access middleware
// Apply access middlewares
if accessMiddleware.Type == AccessMiddleware {
blM := middleware.AccessListMiddleware{
blM := middlewares.AccessListMiddleware{
Path: route.Path,
List: accessMiddleware.Paths,
}
@@ -103,10 +103,10 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
}
}
// Get route authentication middleware if it does exist
rMiddleware, err := getMiddleware([]string{mid}, middlewares)
// Get route authentication middlewares if it does exist
rMiddleware, err := getMiddleware([]string{mid}, m)
if err != nil {
//Error: middleware not found
//Error: middlewares not found
logger.Error("Error: %v", err.Error())
} else {
for _, midPath := range rMiddleware.Paths {
@@ -122,20 +122,20 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
}
secureRouter := r.PathPrefix(util.ParseRoutePath(route.Path, midPath)).Subrouter()
//callBackRouter := r.PathPrefix(util.ParseRoutePath(route.Path, "/callback")).Subrouter()
//Check Authentication middleware
//Check Authentication middlewares
switch rMiddleware.Type {
case BasicAuth:
basicAuth, err := getBasicAuthMiddleware(rMiddleware.Rule)
if err != nil {
logger.Error("Error: %s", err.Error())
} else {
amw := middleware.AuthBasic{
amw := middlewares.AuthBasic{
Username: basicAuth.Username,
Password: basicAuth.Password,
Headers: nil,
Params: nil,
}
// Apply JWT authentication middleware
// Apply JWT authentication middlewares
secureRouter.Use(amw.AuthMiddleware)
secureRouter.Use(CORSHandler(route.Cors))
secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler
@@ -146,14 +146,14 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
if err != nil {
logger.Error("Error: %s", err.Error())
} else {
amw := middleware.JwtAuth{
amw := middlewares.JwtAuth{
AuthURL: jwt.URL,
RequiredHeaders: jwt.RequiredHeaders,
Headers: jwt.Headers,
Params: jwt.Params,
Origins: gateway.Cors.Origins,
}
// Apply JWT authentication middleware
// Apply JWT authentication middlewares
secureRouter.Use(amw.AuthMiddleware)
secureRouter.Use(CORSHandler(route.Cors))
secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler
@@ -169,12 +169,12 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
if oauth.RedirectURL != "" {
redirectURL = oauth.RedirectURL
}
amw := middleware.Oauth{
amw := middlewares.Oauth{
ClientID: oauth.ClientID,
ClientSecret: oauth.ClientSecret,
RedirectURL: redirectURL,
Scopes: oauth.Scopes,
Endpoint: middleware.OauthEndpoint{
Endpoint: middlewares.OauthEndpoint{
AuthURL: oauth.Endpoint.AuthURL,
TokenURL: oauth.Endpoint.TokenURL,
UserInfoURL: oauth.Endpoint.UserInfoURL,
@@ -205,7 +205,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
}
default:
if !doesExist(rMiddleware.Type) {
logger.Error("Unknown middleware type %s", rMiddleware.Type)
logger.Error("Unknown middlewares type %s", rMiddleware.Type)
}
}
@@ -214,7 +214,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
}
} else {
logger.Error("Error, middleware path is empty")
logger.Error("Error, middlewares path is empty")
logger.Error("Middleware ignored")
}
}
@@ -234,7 +234,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
// Enable common exploits
if route.BlockCommonExploits {
logger.Info("Block common exploits enabled")
router.Use(middleware.BlockExploitsMiddleware)
router.Use(middlewares.BlockExploitsMiddleware)
}
id := string(rune(rIndex))
if len(route.Name) != 0 {
@@ -243,7 +243,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
}
// Apply route rate limit
if route.RateLimit > 0 {
rateLimit := middleware.RateLimit{
rateLimit := middlewares.RateLimit{
Id: id, // Use route index as ID
Requests: route.RateLimit,
Window: time.Minute, // requests per minute
@@ -252,7 +252,7 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
RedisBased: redisBased,
}
limiter := rateLimit.NewRateLimiterWindow()
// Add rate limit middleware
// Add rate limit middlewares
router.Use(limiter.RateLimitMiddleware())
}
// Apply route Cors
@@ -272,9 +272,9 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
// Prometheus endpoint
router.Use(pr.prometheusMiddleware)
}
// Apply route Error interceptor middleware
// Apply route Error interceptor middlewares
if len(route.InterceptErrors) != 0 {
interceptErrors := middleware.InterceptErrors{
interceptErrors := middlewares.InterceptErrors{
Origins: route.Cors.Origins,
Errors: route.InterceptErrors,
}
@@ -286,10 +286,10 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router {
}
}
// Apply global Cors middlewares
r.Use(CORSHandler(gateway.Cors)) // Apply CORS middleware
// Apply errorInterceptor middleware
r.Use(CORSHandler(gateway.Cors)) // Apply CORS middlewares
// Apply errorInterceptor middlewares
if len(gateway.InterceptErrors) != 0 {
interceptErrors := middleware.InterceptErrors{
interceptErrors := middlewares.InterceptErrors{
Errors: gateway.InterceptErrors,
Origins: gateway.Cors.Origins,
}

View File

@@ -53,6 +53,6 @@ type Route struct {
InterceptErrors []int `yaml:"interceptErrors"`
// BlockCommonExploits enable, disable block common exploits
BlockCommonExploits bool `yaml:"blockCommonExploits"`
// Middlewares Defines route middleware from Middleware names
// Middlewares Defines route middlewares from Middleware names
Middlewares []string `yaml:"middlewares"`
}

View File

@@ -18,111 +18,97 @@ limitations under the License.
import (
"context"
"crypto/tls"
"errors"
"fmt"
"github.com/jkaninda/goma-gateway/internal/middleware"
"github.com/jkaninda/goma-gateway/pkg/logger"
"github.com/redis/go-redis/v9"
"net/http"
"os"
"sync"
"os/signal"
"syscall"
"time"
)
// Start starts the server
// Start / Start starts the server
func (gatewayServer GatewayServer) Start(ctx context.Context) error {
logger.Info("Initializing routes...")
route := gatewayServer.Initialize()
gateway := gatewayServer.gateway
logger.Debug("Routes count=%d Middlewares count=%d", len(gatewayServer.gateway.Routes), len(gatewayServer.middlewares))
logger.Info("Initializing routes...done")
if len(gateway.Redis.Addr) != 0 {
middleware.InitRedis(gateway.Redis.Addr, gateway.Redis.Password)
defer func(Rdb *redis.Client) {
err := Rdb.Close()
if err != nil {
logger.Error("Redis connection closed with error: %v", err)
}
}(middleware.Rdb)
logger.Debug("Routes count=%d, Middlewares count=%d", len(gatewayServer.gateway.Routes), len(gatewayServer.middlewares))
if err := gatewayServer.initRedis(); err != nil {
return fmt.Errorf("failed to initialize Redis: %w", err)
}
defer gatewayServer.closeRedis()
tlsConfig, listenWithTLS, err := gatewayServer.initTLS()
if err != nil {
return err
}
tlsConfig := &tls.Config{}
var listenWithTLS = false
if cert := gatewayServer.gateway.SSLCertFile; cert != "" && gatewayServer.gateway.SSLKeyFile != "" {
tlsConf, err := loadTLS(cert, gatewayServer.gateway.SSLKeyFile)
if err != nil {
return err
}
tlsConfig = tlsConf
listenWithTLS = true
}
// HTTP Server
httpServer := &http.Server{
Addr: ":8080",
WriteTimeout: time.Second * time.Duration(gatewayServer.gateway.WriteTimeout),
ReadTimeout: time.Second * time.Duration(gatewayServer.gateway.ReadTimeout),
IdleTimeout: time.Second * time.Duration(gatewayServer.gateway.IdleTimeout),
Handler: route, // Pass our instance of gorilla/mux in.
}
// HTTPS Server
httpsServer := &http.Server{
Addr: ":8443",
WriteTimeout: time.Second * time.Duration(gatewayServer.gateway.WriteTimeout),
ReadTimeout: time.Second * time.Duration(gatewayServer.gateway.ReadTimeout),
IdleTimeout: time.Second * time.Duration(gatewayServer.gateway.IdleTimeout),
Handler: route, // Pass our instance of gorilla/mux in.
TLSConfig: tlsConfig,
}
if !gatewayServer.gateway.DisableDisplayRouteOnStart {
printRoute(gatewayServer.gateway.Routes)
}
// Set KeepAlive
httpServer.SetKeepAlivesEnabled(!gatewayServer.gateway.DisableKeepAlive)
go func() {
logger.Info("Starting HTTP server listen=0.0.0.0:8080")
if err := httpServer.ListenAndServe(); err != nil {
logger.Fatal("Error starting Goma Gateway HTTP server: %v", err)
}
}()
go func() {
if listenWithTLS {
logger.Info("Starting HTTPS server listen=0.0.0.0:8443")
if err := httpsServer.ListenAndServeTLS("", ""); err != nil {
logger.Fatal("Error starting Goma Gateway HTTPS server: %v", err)
}
}
}()
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
<-ctx.Done()
shutdownCtx := context.Background()
shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second)
defer cancel()
if err := httpServer.Shutdown(shutdownCtx); err != nil {
_, err := fmt.Fprintf(os.Stderr, "error shutting down HTTP server: %s\n", err)
if err != nil {
return
}
}
}()
go func() {
defer wg.Done()
<-ctx.Done()
shutdownCtx := context.Background()
shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second)
defer cancel()
if listenWithTLS {
if err := httpsServer.Shutdown(shutdownCtx); err != nil {
_, err := fmt.Fprintf(os.Stderr, "error shutting HTTPS server: %s\n", err)
if err != nil {
return
}
}
}
}()
wg.Wait()
return nil
httpServer := gatewayServer.createServer(":8080", route, nil)
httpsServer := gatewayServer.createServer(":8443", route, tlsConfig)
// Start HTTP/HTTPS servers
if err := gatewayServer.startServers(httpServer, httpsServer, listenWithTLS); err != nil {
return err
}
// Handle graceful shutdown
return gatewayServer.gracefulShutdown(ctx, httpServer, httpsServer, listenWithTLS)
}
func (gatewayServer GatewayServer) createServer(addr string, handler http.Handler, tlsConfig *tls.Config) *http.Server {
return &http.Server{
Addr: addr,
WriteTimeout: time.Second * time.Duration(gatewayServer.gateway.WriteTimeout),
ReadTimeout: time.Second * time.Duration(gatewayServer.gateway.ReadTimeout),
IdleTimeout: time.Second * time.Duration(gatewayServer.gateway.IdleTimeout),
Handler: handler,
TLSConfig: tlsConfig,
}
}
func (gatewayServer GatewayServer) startServers(httpServer, httpsServer *http.Server, listenWithTLS bool) error {
go func() {
logger.Info("Starting HTTP server on 0.0.0.0:8080")
if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
logger.Fatal("HTTP server error: %v", err)
}
}()
if listenWithTLS {
go func() {
logger.Info("Starting HTTPS server on 0.0.0.0:8443")
if err := httpsServer.ListenAndServeTLS("", ""); err != nil && !errors.Is(err, http.ErrServerClosed) {
logger.Fatal("HTTPS server error: %v", err)
}
}()
}
return nil
}
func (gatewayServer GatewayServer) gracefulShutdown(ctx context.Context, httpServer, httpsServer *http.Server, listenWithTLS bool) error {
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
logger.Info("Shutting down Goma Gateway...")
shutdownCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
if err := httpServer.Shutdown(shutdownCtx); err != nil {
logger.Error("Error shutting down HTTP server: %v", err)
}
if listenWithTLS {
if err := httpsServer.Shutdown(shutdownCtx); err != nil {
logger.Error("Error shutting down HTTPS server: %v", err)
}
}
logger.Info("Goma Gateway shut down successfully")
return nil
}

36
internal/tls.go Normal file
View File

@@ -0,0 +1,36 @@
/*
* Copyright 2024 Jonas Kaninda
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package pkg
import (
"crypto/tls"
"fmt"
)
func (gatewayServer GatewayServer) initTLS() (*tls.Config, bool, error) {
cert, key := gatewayServer.gateway.SSLCertFile, gatewayServer.gateway.SSLKeyFile
if cert == "" || key == "" {
return nil, false, nil
}
tlsConfig, err := loadTLS(cert, key)
if err != nil {
return nil, false, fmt.Errorf("failed to load TLS config: %w", err)
}
return tlsConfig, true, nil
}

View File

@@ -4,10 +4,10 @@ const ConfigDir = "/etc/goma/" // Default config
const ConfigFile = "/etc/goma/goma.yml" // Default configuration file
const accessControlAllowOrigin = "Access-Control-Allow-Origin" // Cors
const gatewayName = "Goma Gateway"
const AccessMiddleware = "access" // access middleware
const BasicAuth = "basic" // basic authentication middleware
const JWTAuth = "jwt" // JWT authentication middleware
const OAuth = "oauth" // OAuth authentication middleware
const AccessMiddleware = "access" // access middlewares
const BasicAuth = "basic" // basic authentication middlewares
const JWTAuth = "jwt" // JWT authentication middlewares
const OAuth = "oauth" // OAuth authentication middlewares
// Round-robin counter
var counter uint32

View File

@@ -15,7 +15,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import "github.com/jkaninda/goma-gateway/cmd"
import (
"github.com/jkaninda/goma-gateway/cmd"
)
func main() {