diff --git a/cmd/server.go b/cmd/server.go index ec2c36b..7a731d0 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -37,13 +37,13 @@ var ServerCmd = &cobra.Command{ } ctx := context.Background() g := pkg.GatewayServer{} - gs, err := g.Config(configFile) + gs, err := g.Config(configFile, ctx) if err != nil { fmt.Printf("Could not load configuration: %v\n", err) os.Exit(1) } gs.SetEnv() - if err := gs.Start(ctx); err != nil { + if err := gs.Start(); err != nil { fmt.Printf("Could not start server: %v\n", err) os.Exit(1) diff --git a/internal/config.go b/internal/config.go index 7becabd..708a4b4 100644 --- a/internal/config.go +++ b/internal/config.go @@ -16,6 +16,7 @@ See the License for the specific language governing permissions and limitations under the License. */ import ( + "context" "fmt" "github.com/jkaninda/goma-gateway/internal/middlewares" "github.com/jkaninda/goma-gateway/pkg/logger" @@ -31,7 +32,7 @@ import ( ) // Config reads config file and returns Gateway -func (GatewayServer) Config(configFile string) (*GatewayServer, error) { +func (GatewayServer) Config(configFile string, ctx context.Context) (*GatewayServer, error) { if util.FileExists(configFile) { buf, err := os.ReadFile(configFile) if err != nil { @@ -44,7 +45,8 @@ func (GatewayServer) Config(configFile string) (*GatewayServer, error) { return nil, fmt.Errorf("parsing the configuration file %q: %w", configFile, err) } return &GatewayServer{ - ctx: nil, + ctx: ctx, + configFile: configFile, version: c.Version, gateway: c.GatewayConfig, middlewares: c.Middlewares, @@ -59,14 +61,15 @@ func (GatewayServer) Config(configFile string) (*GatewayServer, error) { } logger.Info("Using configuration file: %s", ConfigFile) - util.SetEnv("GOMA_CONFIG_FILE", configFile) + util.SetEnv("GOMA_CONFIG_FILE", ConfigFile) c := &GatewayConfig{} err = yaml.Unmarshal(buf, c) if err != nil { return nil, fmt.Errorf("parsing the configuration file %q: %w", ConfigFile, err) } return &GatewayServer{ - ctx: nil, + ctx: ctx, + configFile: ConfigFile, gateway: c.GatewayConfig, middlewares: c.Middlewares, }, nil @@ -98,7 +101,8 @@ func (GatewayServer) Config(configFile string) (*GatewayServer, error) { } logger.Info("Generating new configuration file...done") return &GatewayServer{ - ctx: nil, + ctx: ctx, + configFile: ConfigFile, gateway: c.GatewayConfig, middlewares: c.Middlewares, }, nil diff --git a/internal/server.go b/internal/server.go index 9f46c1d..5c65378 100644 --- a/internal/server.go +++ b/internal/server.go @@ -29,7 +29,7 @@ import ( ) // Start / Start starts the server -func (gatewayServer GatewayServer) Start(ctx context.Context) error { +func (gatewayServer GatewayServer) Start() error { logger.Info("Initializing routes...") route := gatewayServer.Initialize() logger.Debug("Routes count=%d, Middlewares count=%d", len(gatewayServer.gateway.Routes), len(gatewayServer.middlewares)) @@ -56,7 +56,7 @@ func (gatewayServer GatewayServer) Start(ctx context.Context) error { } // Handle graceful shutdown - return gatewayServer.shutdown(ctx, httpServer, httpsServer, listenWithTLS) + return gatewayServer.shutdown(httpServer, httpsServer, listenWithTLS) } func (gatewayServer GatewayServer) createServer(addr string, handler http.Handler, tlsConfig *tls.Config) *http.Server { @@ -90,13 +90,13 @@ func (gatewayServer GatewayServer) startServers(httpServer, httpsServer *http.Se return nil } -func (gatewayServer GatewayServer) shutdown(ctx context.Context, httpServer, httpsServer *http.Server, listenWithTLS bool) error { +func (gatewayServer GatewayServer) shutdown(httpServer, httpsServer *http.Server, listenWithTLS bool) error { quit := make(chan os.Signal, 1) signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) <-quit logger.Info("Shutting down Goma Gateway...") - shutdownCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + shutdownCtx, cancel := context.WithTimeout(gatewayServer.ctx, 10*time.Second) defer cancel() if err := httpServer.Shutdown(shutdownCtx); err != nil { diff --git a/internal/server_test.go b/internal/server_test.go index da298e9..23d752f 100644 --- a/internal/server_test.go +++ b/internal/server_test.go @@ -39,8 +39,9 @@ func TestStart(t *testing.T) { if err != nil { t.Fatalf("Error initializing config: %s", err.Error()) } + ctx := context.Background() g := GatewayServer{} - gatewayServer, err := g.Config(configFile) + gatewayServer, err := g.Config(configFile, ctx) if err != nil { t.Error(err) } @@ -54,9 +55,8 @@ func TestStart(t *testing.T) { t.Fatalf("expected a status code of 200, got %v", resp.StatusCode) } } - ctx := context.Background() go func() { - err = gatewayServer.Start(ctx) + err = gatewayServer.Start() if err != nil { t.Error(err) return diff --git a/internal/types.go b/internal/types.go index e358156..7180e48 100644 --- a/internal/types.go +++ b/internal/types.go @@ -113,6 +113,7 @@ type ErrorResponse struct { } type GatewayServer struct { ctx context.Context + configFile string version string gateway Gateway middlewares []Middleware