Initial commit
This commit is contained in:
362
pkg/config.go
Normal file
362
pkg/config.go
Normal file
@@ -0,0 +1,362 @@
|
||||
package pkg
|
||||
|
||||
/*
|
||||
Copyright 2024 Jonas Kaninda
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/jkaninda/goma-gateway/internal/logger"
|
||||
"github.com/jkaninda/goma-gateway/util"
|
||||
"github.com/spf13/cobra"
|
||||
"gopkg.in/yaml.v3"
|
||||
"os"
|
||||
)
|
||||
|
||||
var cfg *Gateway
|
||||
|
||||
type Config struct {
|
||||
file string
|
||||
}
|
||||
type BasicRule struct {
|
||||
Username string `yaml:"username"`
|
||||
Password string `yaml:"password"`
|
||||
}
|
||||
|
||||
type Cors struct {
|
||||
// Cors Allowed origins,
|
||||
//e.g:
|
||||
//
|
||||
// - http://localhost:80
|
||||
//
|
||||
// - https://example.com
|
||||
Origins []string `yaml:"origins"`
|
||||
//
|
||||
//e.g:
|
||||
//
|
||||
//Access-Control-Allow-Origin: '*'
|
||||
//
|
||||
// Access-Control-Allow-Methods: 'GET, POST, PUT, DELETE, OPTIONS'
|
||||
//
|
||||
// Access-Control-Allow-Cors: 'Content-Type, Authorization'
|
||||
Headers map[string]string `yaml:"headers"`
|
||||
}
|
||||
|
||||
// JWTRuler authentication using HTTP GET method
|
||||
//
|
||||
// JWTRuler contains the authentication details
|
||||
type JWTRuler struct {
|
||||
// URL contains the authentication URL, it supports HTTP GET method only.
|
||||
URL string `yaml:"url"`
|
||||
// RequiredHeaders , contains required before sending request to the backend.
|
||||
RequiredHeaders []string `yaml:"requiredHeaders"`
|
||||
// Headers Add header to the backend from Authentication request's header, depending on your requirements.
|
||||
// Key is Http's response header Key, and value is the backend Request's header Key.
|
||||
// In case you want to get headers from Authentication service and inject them to backend request's headers.
|
||||
Headers map[string]string `yaml:"headers"`
|
||||
// Params same as Headers, contains the request params.
|
||||
//
|
||||
// Gets authentication headers from authentication request and inject them as request params to the backend.
|
||||
//
|
||||
// Key is Http's response header Key, and value is the backend Request's request param Key.
|
||||
//
|
||||
// In case you want to get headers from Authentication service and inject them to next request's params.
|
||||
//
|
||||
//e.g: Header X-Auth-UserId to query userId
|
||||
Params map[string]string `yaml:"params"`
|
||||
}
|
||||
|
||||
// Middleware defined the route middleware
|
||||
type Middleware struct {
|
||||
//Path contains the name of middleware and must be unique
|
||||
Name string `yaml:"name"`
|
||||
// Type contains authentication types
|
||||
//
|
||||
// basic, jwt, auth0, rateLimit
|
||||
Type string `yaml:"type"`
|
||||
// Rule contains rule type of
|
||||
Rule interface{} `yaml:"rule"`
|
||||
}
|
||||
type MiddlewareName struct {
|
||||
name string `yaml:"name"`
|
||||
}
|
||||
type RouteMiddleware struct {
|
||||
//Path contains the path to protect
|
||||
Path string `yaml:"path"`
|
||||
//Rules defines which specific middleware applies to a route path
|
||||
Rules []string `yaml:"rules"`
|
||||
}
|
||||
|
||||
// Route defines gateway route
|
||||
type Route struct {
|
||||
// Name defines route name
|
||||
Name string `yaml:"name"`
|
||||
// Path defines route path
|
||||
Path string `yaml:"path"`
|
||||
// Rewrite rewrites route path to desired path
|
||||
//
|
||||
// E.g. /cart to / => It will rewrite /cart path to /
|
||||
Rewrite string `yaml:"rewrite"`
|
||||
// Destination Defines backend URL
|
||||
Destination string `yaml:"destination"`
|
||||
// Cors contains the route cors headers
|
||||
Cors Cors `yaml:"cors"`
|
||||
// DisableHeaderXForward Disable X-forwarded header.
|
||||
//
|
||||
// [X-Forwarded-Host, X-Forwarded-For, Host, Scheme ]
|
||||
//
|
||||
// It will not match the backend route
|
||||
DisableHeaderXForward bool `yaml:"disableHeaderXForward"`
|
||||
// HealthCheck Defines the backend is health check PATH
|
||||
HealthCheck string `yaml:"healthCheck"`
|
||||
// Blocklist Defines route blacklist
|
||||
Blocklist []string `yaml:"blocklist"`
|
||||
// Middlewares Defines route middleware from Middleware names
|
||||
Middlewares []RouteMiddleware `yaml:"middlewares"`
|
||||
}
|
||||
|
||||
// Gateway contains Goma Proxy Gateway's configs
|
||||
type Gateway struct {
|
||||
// ListenAddr Defines the server listenAddr
|
||||
//
|
||||
//e.g: localhost:8080
|
||||
ListenAddr string `yaml:"listenAddr" env:"GOMA_LISTEN_ADDR, overwrite"`
|
||||
// WriteTimeout defines proxy write timeout
|
||||
WriteTimeout int `yaml:"writeTimeout" env:"GOMA_WRITE_TIMEOUT, overwrite"`
|
||||
// ReadTimeout defines proxy read timeout
|
||||
ReadTimeout int `yaml:"readTimeout" env:"GOMA_READ_TIMEOUT, overwrite"`
|
||||
// IdleTimeout defines proxy idle timeout
|
||||
IdleTimeout int `yaml:"idleTimeout" env:"GOMA_IDLE_TIMEOUT, overwrite"`
|
||||
// RateLimiter Defines number of request peer minute
|
||||
RateLimiter int `yaml:"rateLimiter" env:"GOMA_RATE_LIMITER, overwrite"`
|
||||
AccessLog string `yaml:"accessLog" env:"GOMA_ACCESS_LOG, overwrite"`
|
||||
ErrorLog string `yaml:"errorLog" env:"GOMA_ERROR_LOG=, overwrite"`
|
||||
DisableRouteHealthCheckError bool `yaml:"disableRouteHealthCheckError"`
|
||||
//Disable dispelling routes on start
|
||||
DisableDisplayRouteOnStart bool `yaml:"disableDisplayRouteOnStart"`
|
||||
// Cors contains the proxy global cors
|
||||
Cors Cors `yaml:"cors"`
|
||||
// Routes defines the proxy routes
|
||||
Routes []Route `yaml:"routes"`
|
||||
}
|
||||
type GatewayConfig struct {
|
||||
GatewayConfig Gateway `yaml:"gateway"`
|
||||
Middlewares []Middleware `yaml:"middlewares"`
|
||||
}
|
||||
|
||||
// ErrorResponse represents the structure of the JSON error response
|
||||
type ErrorResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
type GatewayServer struct {
|
||||
ctx context.Context
|
||||
gateway Gateway
|
||||
middlewares []Middleware
|
||||
}
|
||||
|
||||
// Config reads config file and returns Gateway
|
||||
func (GatewayServer) Config(configFile string) (*GatewayServer, error) {
|
||||
if util.FileExists(configFile) {
|
||||
buf, err := os.ReadFile(configFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
util.SetEnv("GOMA_CONFIG_FILE", configFile)
|
||||
c := &GatewayConfig{}
|
||||
err = yaml.Unmarshal(buf, c)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("in file %q: %w", configFile, err)
|
||||
}
|
||||
return &GatewayServer{
|
||||
ctx: nil,
|
||||
gateway: c.GatewayConfig,
|
||||
middlewares: c.Middlewares,
|
||||
}, nil
|
||||
}
|
||||
logger.Error("Configuration file not found: %v", configFile)
|
||||
logger.Info("Generating new configuration file...")
|
||||
initConfig(ConfigFile)
|
||||
logger.Info("Server configuration file is available at %s", ConfigFile)
|
||||
util.SetEnv("GOMA_CONFIG_FILE", ConfigFile)
|
||||
buf, err := os.ReadFile(ConfigFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c := &GatewayConfig{}
|
||||
err = yaml.Unmarshal(buf, c)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("in file %q: %w", ConfigFile, err)
|
||||
}
|
||||
logger.Info("Generating new configuration file...done")
|
||||
logger.Info("Starting server with default configuration")
|
||||
return &GatewayServer{
|
||||
ctx: nil,
|
||||
gateway: c.GatewayConfig,
|
||||
middlewares: c.Middlewares,
|
||||
}, nil
|
||||
}
|
||||
func GetConfigPaths() string {
|
||||
return util.GetStringEnv("GOMAY_CONFIG_FILE", ConfigFile)
|
||||
}
|
||||
func InitConfig(cmd *cobra.Command) {
|
||||
configFile, _ := cmd.Flags().GetString("output")
|
||||
if configFile == "" {
|
||||
configFile = GetConfigPaths()
|
||||
}
|
||||
initConfig(configFile)
|
||||
return
|
||||
|
||||
}
|
||||
func initConfig(configFile string) {
|
||||
if configFile == "" {
|
||||
configFile = GetConfigPaths()
|
||||
}
|
||||
conf := &GatewayConfig{
|
||||
GatewayConfig: Gateway{
|
||||
ListenAddr: "0.0.0.0:80",
|
||||
WriteTimeout: 15,
|
||||
ReadTimeout: 15,
|
||||
IdleTimeout: 60,
|
||||
AccessLog: "/dev/Stdout",
|
||||
ErrorLog: "/dev/stderr",
|
||||
DisableRouteHealthCheckError: false,
|
||||
DisableDisplayRouteOnStart: false,
|
||||
RateLimiter: 0,
|
||||
Cors: Cors{
|
||||
Origins: []string{"http://localhost:8080", "https://example.com"},
|
||||
Headers: map[string]string{
|
||||
"Access-Control-Allow-Headers": "Origin, Authorization, Accept, Content-Type, Access-Control-Allow-Headers, X-Client-Id, X-Session-Id",
|
||||
"Access-Control-Allow-Credentials": "true",
|
||||
"Access-Control-Max-Age": "1728000",
|
||||
},
|
||||
},
|
||||
Routes: []Route{
|
||||
{
|
||||
Name: "HealthCheck",
|
||||
Path: "/healthy",
|
||||
Destination: "http://localhost:8080",
|
||||
Rewrite: "/health",
|
||||
HealthCheck: "",
|
||||
Cors: Cors{
|
||||
Headers: map[string]string{
|
||||
"Access-Control-Allow-Headers": "Origin, Authorization, Accept, Content-Type, Access-Control-Allow-Headers, X-Client-Id, X-Session-Id",
|
||||
"Access-Control-Allow-Credentials": "true",
|
||||
"Access-Control-Max-Age": "1728000",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Basic auth",
|
||||
Path: "/basic",
|
||||
Destination: "http://localhost:8080",
|
||||
Rewrite: "/health",
|
||||
HealthCheck: "",
|
||||
Blocklist: []string{},
|
||||
Cors: Cors{},
|
||||
Middlewares: []RouteMiddleware{
|
||||
{
|
||||
Path: "/basic/auth",
|
||||
Rules: []string{"basic-auth", "google-jwt"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Middlewares: []Middleware{
|
||||
{
|
||||
Name: "basic-auth",
|
||||
Type: "basic",
|
||||
Rule: BasicRule{
|
||||
Username: "goma",
|
||||
Password: "goma",
|
||||
},
|
||||
}, {
|
||||
Name: "google-jwt",
|
||||
Type: "jwt",
|
||||
Rule: JWTRuler{
|
||||
URL: "https://www.googleapis.com/auth/userinfo.email",
|
||||
Headers: map[string]string{},
|
||||
Params: map[string]string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
yamlData, err := yaml.Marshal(&conf)
|
||||
if err != nil {
|
||||
logger.Fatal("Error serializing configuration %v", err.Error())
|
||||
}
|
||||
err = os.WriteFile(configFile, yamlData, 0644)
|
||||
if err != nil {
|
||||
logger.Fatal("Unable to write config file %s", err)
|
||||
}
|
||||
logger.Info("Configuration file has been initialized successfully")
|
||||
}
|
||||
func Get() *Gateway {
|
||||
if cfg == nil {
|
||||
c := &Gateway{}
|
||||
c.Setup(GetConfigPaths())
|
||||
cfg = c
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
func (Gateway) Setup(conf string) *Gateway {
|
||||
if util.FileExists(conf) {
|
||||
buf, err := os.ReadFile(conf)
|
||||
if err != nil {
|
||||
return &Gateway{}
|
||||
}
|
||||
util.SetEnv("GOMA_CONFIG_FILE", conf)
|
||||
c := &GatewayConfig{}
|
||||
err = yaml.Unmarshal(buf, c)
|
||||
if err != nil {
|
||||
logger.Fatal("Error loading configuration %v", err.Error())
|
||||
}
|
||||
return &c.GatewayConfig
|
||||
}
|
||||
return &Gateway{}
|
||||
|
||||
}
|
||||
func (middleware Middleware) name() {
|
||||
|
||||
}
|
||||
func ToJWTRuler(input interface{}) (JWTRuler, error) {
|
||||
jWTRuler := new(JWTRuler)
|
||||
var bytes []byte
|
||||
bytes, err := yaml.Marshal(input)
|
||||
if err != nil {
|
||||
return JWTRuler{}, fmt.Errorf("error marshalling yaml: %v", err)
|
||||
}
|
||||
err = yaml.Unmarshal(bytes, jWTRuler)
|
||||
if err != nil {
|
||||
return JWTRuler{}, fmt.Errorf("error unmarshalling yaml: %v", err)
|
||||
}
|
||||
return *jWTRuler, nil
|
||||
}
|
||||
|
||||
func ToBasicAuth(input interface{}) (BasicRule, error) {
|
||||
basicAuth := new(BasicRule)
|
||||
var bytes []byte
|
||||
bytes, err := yaml.Marshal(input)
|
||||
if err != nil {
|
||||
return BasicRule{}, fmt.Errorf("error marshalling yaml: %v", err)
|
||||
}
|
||||
err = yaml.Unmarshal(bytes, basicAuth)
|
||||
if err != nil {
|
||||
return BasicRule{}, fmt.Errorf("error unmarshalling yaml: %v", err)
|
||||
}
|
||||
return *basicAuth, nil
|
||||
}
|
||||
107
pkg/handler.go
Normal file
107
pkg/handler.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package pkg
|
||||
|
||||
/*
|
||||
Copyright 2024 Jonas Kaninda
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/jkaninda/goma-gateway/internal/logger"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// CORSHandler handles CORS headers for incoming requests
|
||||
//
|
||||
// Adds CORS headers to the response dynamically based on the provided headers map[string]string
|
||||
func CORSHandler(cors Cors) mux.MiddlewareFunc {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Set CORS headers from the cors config
|
||||
//Update Cors Headers
|
||||
for k, v := range cors.Headers {
|
||||
w.Header().Set(k, v)
|
||||
}
|
||||
//Update Origin Cors Headers
|
||||
for _, origin := range cors.Origins {
|
||||
if origin == r.Header.Get("Origin") {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
|
||||
}
|
||||
}
|
||||
// Handle preflight requests (OPTIONS)
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
// Pass the request to the next handler
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ProxyErrorHandler catches backend errors and returns a custom response
|
||||
func ProxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
|
||||
logger.Error("Proxy error: %v", err)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadGateway)
|
||||
err = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"success": false,
|
||||
"code": http.StatusBadGateway,
|
||||
"message": "The service is currently unavailable. Please try again later.",
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// HealthCheckHandler handles health check of routes
|
||||
func (heathRoute HealthCheckRoute) HealthCheckHandler(w http.ResponseWriter, r *http.Request) {
|
||||
logger.Info("%s %s %s %s", r.Method, r.RemoteAddr, r.URL, r.UserAgent())
|
||||
var routes []HealthCheckRouteResponse
|
||||
for _, route := range heathRoute.Routes {
|
||||
if route.HealthCheck != "" {
|
||||
err := HealthCheck(route.Destination + route.HealthCheck)
|
||||
if err != nil {
|
||||
logger.Error("Route %s: %v", route.Name, err)
|
||||
if heathRoute.DisableRouteHealthCheckError {
|
||||
routes = append(routes, HealthCheckRouteResponse{Name: route.Name, Status: "unhealthy", Error: "Route healthcheck errors disabled"})
|
||||
continue
|
||||
}
|
||||
routes = append(routes, HealthCheckRouteResponse{Name: route.Name, Status: "unhealthy", Error: err.Error()})
|
||||
continue
|
||||
} else {
|
||||
logger.Info("Route %s is healthy", route.Name)
|
||||
routes = append(routes, HealthCheckRouteResponse{Name: route.Name, Status: "healthy", Error: ""})
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
logger.Error("Route %s's healthCheck is undefined", route.Name)
|
||||
routes = append(routes, HealthCheckRouteResponse{Name: route.Name, Status: "undefined", Error: ""})
|
||||
continue
|
||||
|
||||
}
|
||||
}
|
||||
response := HealthCheckResponse{
|
||||
Status: "healthy",
|
||||
Routes: routes,
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
err := json.NewEncoder(w).Encode(response)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
67
pkg/healthCheck.go
Normal file
67
pkg/healthCheck.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package pkg
|
||||
|
||||
/*
|
||||
Copyright 2024 Jonas Kaninda
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
type HealthCheckRoute struct {
|
||||
DisableRouteHealthCheckError bool
|
||||
Routes []Route
|
||||
}
|
||||
|
||||
// HealthCheckResponse represents the health check response structure
|
||||
type HealthCheckResponse struct {
|
||||
Status string `json:"status"`
|
||||
Routes []HealthCheckRouteResponse `json:"routes"`
|
||||
}
|
||||
type HealthCheckRouteResponse struct {
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
func HealthCheck(healthURL string) error {
|
||||
healthCheckURL, err := url.Parse(healthURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing HealthCheck URL: %v ", err)
|
||||
}
|
||||
// Create a new request for the route
|
||||
healthReq, err := http.NewRequest("GET", healthCheckURL.String(), nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating HealthCheck request: %v ", err)
|
||||
}
|
||||
// Perform the request to the route's healthcheck
|
||||
client := &http.Client{}
|
||||
healthResp, err := client.Do(healthReq)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error performing HealthCheck request: %v ", err)
|
||||
}
|
||||
defer func(Body io.ReadCloser) {
|
||||
err := Body.Close()
|
||||
if err != nil {
|
||||
}
|
||||
}(healthResp.Body)
|
||||
|
||||
if healthResp.StatusCode >= 400 {
|
||||
return fmt.Errorf("health check failed with status code %v", healthResp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
24
pkg/helpers.go
Normal file
24
pkg/helpers.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package pkg
|
||||
|
||||
/*
|
||||
Copyright 2024 Jonas Kaninda.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may get a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
*/
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/common-nighthawk/go-figure"
|
||||
"github.com/jkaninda/goma-gateway/util"
|
||||
)
|
||||
|
||||
func Intro() {
|
||||
nameFigure := figure.NewFigure("Goma", "", true)
|
||||
nameFigure.Print()
|
||||
fmt.Printf("Version: %s\n", util.FullVersion())
|
||||
fmt.Println("Copyright (c) 2024 Jonas Kaninda")
|
||||
fmt.Println("Starting Goma server...")
|
||||
}
|
||||
38
pkg/middleware.go
Normal file
38
pkg/middleware.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package pkg
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/gorilla/mux"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func searchMiddleware(rules []string, middlewares []Middleware) (Middleware, error) {
|
||||
for _, m := range middlewares {
|
||||
if slices.Contains(rules, m.Name) {
|
||||
return m, nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
return Middleware{}, errors.New("no middleware found with name " + strings.Join(rules, ";"))
|
||||
}
|
||||
func getMiddleware(rule string, middlewares []Middleware) (Middleware, error) {
|
||||
for _, m := range middlewares {
|
||||
if strings.Contains(rule, m.Name) {
|
||||
|
||||
return m, nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
return Middleware{}, errors.New("no middleware found with name " + rule)
|
||||
}
|
||||
|
||||
type RoutePath struct {
|
||||
route Route
|
||||
path string
|
||||
rules []string
|
||||
middlewares []Middleware
|
||||
router *mux.Router
|
||||
}
|
||||
99
pkg/middleware/bloclist.go
Normal file
99
pkg/middleware/bloclist.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package middleware
|
||||
|
||||
/*
|
||||
Copyright 2024 Jonas Kaninda
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/jkaninda/goma-gateway/internal/logger"
|
||||
"github.com/jkaninda/goma-gateway/util"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// BlocklistMiddleware checks if the request path is forbidden and returns 403 Forbidden
|
||||
func (blockList BlockListMiddleware) BlocklistMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
for _, block := range blockList.List {
|
||||
if isPathBlocked(r.URL.Path, util.ParseURLPath(blockList.Path+block)) {
|
||||
logger.Error("Access to %s is forbidden", r.URL.Path)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
err := json.NewEncoder(w).Encode(ProxyResponseError{
|
||||
Success: false,
|
||||
Code: http.StatusNotFound,
|
||||
Message: fmt.Sprintf("Not found: %s", r.URL.Path),
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// Helper function to determine if the request path is blocked
|
||||
func isPathBlocked(requestPath, blockedPath string) bool {
|
||||
// Handle exact match
|
||||
if requestPath == blockedPath {
|
||||
return true
|
||||
}
|
||||
// Handle wildcard match (e.g., /admin/* should block /admin and any subpath)
|
||||
if strings.HasSuffix(blockedPath, "/*") {
|
||||
basePath := strings.TrimSuffix(blockedPath, "/*")
|
||||
if strings.HasPrefix(requestPath, basePath) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new rate limiter with the specified refill rate and token capacity
|
||||
func NewRateLimiter(maxTokens int, refillRate time.Duration) *TokenRateLimiter {
|
||||
return &TokenRateLimiter{
|
||||
tokens: maxTokens,
|
||||
maxTokens: maxTokens,
|
||||
refillRate: refillRate,
|
||||
lastRefill: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Allow checks if a request is allowed based on the current token bucket
|
||||
func (rl *TokenRateLimiter) Allow() bool {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
// Refill tokens based on the time elapsed
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(rl.lastRefill)
|
||||
tokensToAdd := int(elapsed / rl.refillRate)
|
||||
if tokensToAdd > 0 {
|
||||
rl.tokens = min(rl.maxTokens, rl.tokens+tokensToAdd)
|
||||
rl.lastRefill = now
|
||||
}
|
||||
|
||||
// Check if there are enough tokens to allow the request
|
||||
if rl.tokens > 0 {
|
||||
rl.tokens--
|
||||
return true
|
||||
}
|
||||
|
||||
// Reject request if no tokens are available
|
||||
return false
|
||||
}
|
||||
276
pkg/middleware/middleware.go
Normal file
276
pkg/middleware/middleware.go
Normal file
@@ -0,0 +1,276 @@
|
||||
package middleware
|
||||
|
||||
/*
|
||||
Copyright 2024 Jonas Kaninda
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"github.com/jkaninda/goma-gateway/internal/logger"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RateLimiter defines rate limit properties.
|
||||
type RateLimiter struct {
|
||||
Requests int
|
||||
Window time.Duration
|
||||
ClientMap map[string]*Client
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// Client stores request count and window expiration for each client.
|
||||
type Client struct {
|
||||
RequestCount int
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// NewRateLimiterWindow creates a new RateLimiter.
|
||||
func NewRateLimiterWindow(requests int, window time.Duration) *RateLimiter {
|
||||
return &RateLimiter{
|
||||
Requests: requests,
|
||||
Window: window,
|
||||
ClientMap: make(map[string]*Client),
|
||||
}
|
||||
}
|
||||
|
||||
type TokenRateLimiter struct {
|
||||
tokens int
|
||||
maxTokens int
|
||||
refillRate time.Duration
|
||||
lastRefill time.Time
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// ProxyResponseError represents the structure of the JSON error response
|
||||
type ProxyResponseError struct {
|
||||
Success bool `json:"success"`
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// AuthJWT Define struct
|
||||
type AuthJWT struct {
|
||||
AuthURL string
|
||||
RequiredHeaders []string
|
||||
Headers map[string]string
|
||||
Params map[string]string
|
||||
}
|
||||
|
||||
// AuthenticationMiddleware Define struct
|
||||
type AuthenticationMiddleware struct {
|
||||
AuthURL string
|
||||
RequiredHeaders []string
|
||||
Headers map[string]string
|
||||
Params map[string]string
|
||||
}
|
||||
type BlockListMiddleware struct {
|
||||
Path string
|
||||
Destination string
|
||||
List []string
|
||||
}
|
||||
|
||||
// AuthBasic Define Basic auth
|
||||
type AuthBasic struct {
|
||||
Username string
|
||||
Password string
|
||||
Headers map[string]string
|
||||
Params map[string]string
|
||||
}
|
||||
|
||||
// AuthMiddleware function, which will be called for each request
|
||||
func (amw AuthJWT) AuthMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
for _, header := range amw.RequiredHeaders {
|
||||
if r.Header.Get(header) == "" {
|
||||
logger.Error("Proxy error, missing %s header", header)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
err := json.NewEncoder(w).Encode(ProxyResponseError{
|
||||
Message: "Missing Authorization header",
|
||||
Code: http.StatusForbidden,
|
||||
Success: false,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
//token := r.Header.Get("Authorization")
|
||||
authURL, err := url.Parse(amw.AuthURL)
|
||||
if err != nil {
|
||||
logger.Error("Error parsing auth URL: %v", err)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
err = json.NewEncoder(w).Encode(ProxyResponseError{
|
||||
Message: "Internal Server Error",
|
||||
Code: http.StatusInternalServerError,
|
||||
Success: false,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
// Create a new request for /authentication
|
||||
authReq, err := http.NewRequest("GET", authURL.String(), nil)
|
||||
if err != nil {
|
||||
logger.Error("Proxy error creating authentication request: %v", err)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
err = json.NewEncoder(w).Encode(ProxyResponseError{
|
||||
Message: "Internal Server Error",
|
||||
Code: http.StatusInternalServerError,
|
||||
Success: false,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
// Copy headers from the original request to the new request
|
||||
for name, values := range r.Header {
|
||||
for _, value := range values {
|
||||
authReq.Header.Set(name, value)
|
||||
}
|
||||
}
|
||||
// Copy cookies from the original request to the new request
|
||||
for _, cookie := range r.Cookies() {
|
||||
authReq.AddCookie(cookie)
|
||||
}
|
||||
// Perform the request to the auth service
|
||||
client := &http.Client{}
|
||||
authResp, err := client.Do(authReq)
|
||||
if err != nil || authResp.StatusCode != http.StatusOK {
|
||||
logger.Info("%s %s %s %s", r.Method, r.RemoteAddr, r.URL, r.UserAgent())
|
||||
logger.Error("Proxy authentication error")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
err = json.NewEncoder(w).Encode(ProxyResponseError{
|
||||
Message: "Unauthorized",
|
||||
Code: http.StatusUnauthorized,
|
||||
Success: false,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
defer func(Body io.ReadCloser) {
|
||||
err := Body.Close()
|
||||
if err != nil {
|
||||
|
||||
}
|
||||
}(authResp.Body)
|
||||
// Inject specific header tp the current request's header
|
||||
// Add header to the next request from AuthRequest header, depending on your requirements
|
||||
if amw.Headers != nil {
|
||||
for k, v := range amw.Headers {
|
||||
r.Header.Set(v, authResp.Header.Get(k))
|
||||
}
|
||||
}
|
||||
query := r.URL.Query()
|
||||
// Add query parameters to the next request from AuthRequest header, depending on your requirements
|
||||
if amw.Params != nil {
|
||||
for k, v := range amw.Params {
|
||||
query.Set(v, authResp.Header.Get(k))
|
||||
}
|
||||
}
|
||||
r.URL.RawQuery = query.Encode()
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// AuthMiddleware checks for the Authorization header and verifies the credentials
|
||||
func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Get the Authorization header
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
logger.Error("Proxy error, missing Authorization header")
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
err := json.NewEncoder(w).Encode(ProxyResponseError{
|
||||
Success: false,
|
||||
Code: http.StatusUnauthorized,
|
||||
Message: "Unauthorized",
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
// Check if the Authorization header contains "Basic" scheme
|
||||
if !strings.HasPrefix(authHeader, "Basic ") {
|
||||
logger.Error("Proxy error, missing Basic Authorization header")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
err := json.NewEncoder(w).Encode(ProxyResponseError{
|
||||
Success: false,
|
||||
Code: http.StatusUnauthorized,
|
||||
Message: "Unauthorized",
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Decode the base64 encoded username:password string
|
||||
payload, err := base64.StdEncoding.DecodeString(authHeader[len("Basic "):])
|
||||
if err != nil {
|
||||
logger.Error("Proxy error, missing Basic Authorization header")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
err := json.NewEncoder(w).Encode(ProxyResponseError{
|
||||
Success: false,
|
||||
Code: http.StatusUnauthorized,
|
||||
Message: "Unauthorized",
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Split the payload into username and password
|
||||
pair := strings.SplitN(string(payload), ":", 2)
|
||||
if len(pair) != 2 || pair[0] != basicAuth.Username || pair[1] != basicAuth.Password {
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
err := json.NewEncoder(w).Encode(ProxyResponseError{
|
||||
Success: false,
|
||||
Code: http.StatusUnauthorized,
|
||||
Message: "Unauthorized",
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Continue to the next handler if the authentication is successful
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
|
||||
}
|
||||
90
pkg/middleware/rate_limiter.go
Normal file
90
pkg/middleware/rate_limiter.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package middleware
|
||||
|
||||
/*
|
||||
Copyright 2024 Jonas Kaninda
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/jkaninda/goma-gateway/internal/logger"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RateLimitMiddleware limits request based on the number of tokens peer minutes.
|
||||
func (rl *TokenRateLimiter) RateLimitMiddleware() mux.MiddlewareFunc {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !rl.Allow() {
|
||||
// Rate limit exceeded, return a 429 Too Many Requests response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
err := json.NewEncoder(w).Encode(ProxyResponseError{
|
||||
Success: false,
|
||||
Code: http.StatusTooManyRequests,
|
||||
Message: "Too many requests. Please try again later.",
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Proceed to the next handler if rate limit is not exceeded
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimitMiddleware limits request based on the number of requests peer minutes.
|
||||
func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
//TODO:
|
||||
clientID := r.RemoteAddr
|
||||
logger.Info(clientID)
|
||||
|
||||
rl.mu.Lock()
|
||||
client, exists := rl.ClientMap[clientID]
|
||||
if !exists || time.Now().After(client.ExpiresAt) {
|
||||
client = &Client{
|
||||
RequestCount: 0,
|
||||
ExpiresAt: time.Now().Add(rl.Window),
|
||||
}
|
||||
rl.ClientMap[clientID] = client
|
||||
}
|
||||
client.RequestCount++
|
||||
rl.mu.Unlock()
|
||||
|
||||
if client.RequestCount > rl.Requests {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
err := json.NewEncoder(w).Encode(ProxyResponseError{
|
||||
Success: false,
|
||||
Code: http.StatusTooManyRequests,
|
||||
Message: "Too many requests. Please try again later.",
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Proceed to the next handler if rate limit is not exceeded
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
110
pkg/middleware_test.go
Normal file
110
pkg/middleware_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package pkg
|
||||
|
||||
/*
|
||||
Copyright 2024 Jonas Kaninda
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
import (
|
||||
"fmt"
|
||||
"gopkg.in/yaml.v3"
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const MidName = "google-jwt"
|
||||
|
||||
var rules = []string{"fake", "jwt", "google-jwt"}
|
||||
|
||||
func TestMiddleware(t *testing.T) {
|
||||
TestInit(t)
|
||||
middlewares := []Middleware{
|
||||
{
|
||||
Name: "basic-auth",
|
||||
Type: "basic",
|
||||
Rule: BasicRule{
|
||||
Username: "goma",
|
||||
Password: "goma",
|
||||
},
|
||||
}, {
|
||||
Name: MidName,
|
||||
Type: "jwt",
|
||||
Rule: JWTRuler{
|
||||
URL: "https://www.googleapis.com/auth/userinfo.email",
|
||||
Headers: map[string]string{},
|
||||
Params: map[string]string{},
|
||||
},
|
||||
},
|
||||
}
|
||||
yamlData, err := yaml.Marshal(&middlewares)
|
||||
if err != nil {
|
||||
t.Fatalf("Error serializing configuration %v", err.Error())
|
||||
}
|
||||
err = os.WriteFile(configFile, yamlData, 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to write config file %s", err)
|
||||
}
|
||||
log.Printf("Config file written to %s", configFile)
|
||||
}
|
||||
|
||||
func TestReadMiddleware(t *testing.T) {
|
||||
TestMiddleware(t)
|
||||
middlewares := getMiddlewares(t)
|
||||
middleware, err := searchMiddleware(rules, middlewares)
|
||||
if err != nil {
|
||||
t.Fatalf("Error searching middleware %s", err.Error())
|
||||
}
|
||||
switch middleware.Type {
|
||||
case "basic":
|
||||
log.Println("Basic auth")
|
||||
basicAuth, err := ToBasicAuth(middleware.Rule)
|
||||
if err != nil {
|
||||
log.Fatalln("error:", err)
|
||||
}
|
||||
log.Printf("Username: %s and password: %s\n", basicAuth.Username, basicAuth.Password)
|
||||
case "jwt":
|
||||
log.Println("JWT auth")
|
||||
jwt, err := ToJWTRuler(middleware.Rule)
|
||||
if err != nil {
|
||||
log.Fatalln("error:", err)
|
||||
}
|
||||
log.Printf("JWT authentification URL is %s\n", jwt.URL)
|
||||
default:
|
||||
t.Errorf("Unknown middleware type %s", middleware.Type)
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestFoundMiddleware(t *testing.T) {
|
||||
middlewares := getMiddlewares(t)
|
||||
middleware, err := searchMiddleware(rules, middlewares)
|
||||
if err != nil {
|
||||
t.Errorf("Error getting middleware %v", err)
|
||||
}
|
||||
fmt.Println(middleware.Type)
|
||||
}
|
||||
|
||||
func getMiddlewares(t *testing.T) []Middleware {
|
||||
buf, err := os.ReadFile(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to read config file %s", configFile)
|
||||
}
|
||||
c := &[]Middleware{}
|
||||
err = yaml.Unmarshal(buf, c)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to parse config file %s", configFile)
|
||||
}
|
||||
return *c
|
||||
}
|
||||
113
pkg/proxy.go
Normal file
113
pkg/proxy.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package pkg
|
||||
|
||||
/*
|
||||
Copyright 2024 Jonas Kaninda
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/jkaninda/goma-gateway/internal/logger"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ProxyRoute struct {
|
||||
path string
|
||||
rewrite string
|
||||
destination string
|
||||
cors Cors
|
||||
disableXForward bool
|
||||
}
|
||||
|
||||
// ProxyHandler proxies requests to the backend
|
||||
func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
logger.Info("%s %s %s %s", r.Method, r.RemoteAddr, r.URL, r.UserAgent())
|
||||
// Set CORS headers from the cors config
|
||||
//Update Cors Headers
|
||||
for k, v := range proxyRoute.cors.Headers {
|
||||
w.Header().Set(k, v)
|
||||
}
|
||||
|
||||
//Update Origin Cors Headers
|
||||
for _, origin := range proxyRoute.cors.Origins {
|
||||
if origin == r.Header.Get("Origin") {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
|
||||
}
|
||||
}
|
||||
// Handle preflight requests (OPTIONS)
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
// Parse the target backend URL
|
||||
targetURL, err := url.Parse(proxyRoute.destination)
|
||||
if err != nil {
|
||||
logger.Error("Error parsing backend URL: %s", err)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
err := json.NewEncoder(w).Encode(ErrorResponse{
|
||||
Message: "Internal server error",
|
||||
Code: http.StatusInternalServerError,
|
||||
Success: false,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
// Update the headers to allow for SSL redirection
|
||||
if !proxyRoute.disableXForward {
|
||||
r.URL.Host = targetURL.Host
|
||||
r.URL.Scheme = targetURL.Scheme
|
||||
r.Header.Set("X-Forwarded-Host", r.Header.Get("Host"))
|
||||
r.Header.Set("X-Forwarded-For", r.RemoteAddr)
|
||||
r.Header.Set("X-Real-IP", r.RemoteAddr)
|
||||
r.Host = targetURL.Host
|
||||
}
|
||||
// Create proxy
|
||||
proxy := httputil.NewSingleHostReverseProxy(targetURL)
|
||||
// Rewrite
|
||||
if proxyRoute.path != "" && proxyRoute.rewrite != "" {
|
||||
// Rewrite the path
|
||||
if strings.HasPrefix(r.URL.Path, fmt.Sprintf("%s/", proxyRoute.path)) {
|
||||
r.URL.Path = strings.Replace(r.URL.Path, fmt.Sprintf("%s/", proxyRoute.path), proxyRoute.rewrite, 1)
|
||||
}
|
||||
}
|
||||
proxy.ModifyResponse = func(response *http.Response) error {
|
||||
if response.StatusCode < 200 || response.StatusCode >= 300 {
|
||||
//TODO || Add override backend errors | user can enable or disable it
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// Custom error handler for proxy errors
|
||||
proxy.ErrorHandler = ProxyErrorHandler
|
||||
proxy.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func isAllowed(cors []string, r *http.Request) bool {
|
||||
for _, origin := range cors {
|
||||
if origin == r.Header.Get("Origin") {
|
||||
return true
|
||||
}
|
||||
continue
|
||||
}
|
||||
return false
|
||||
|
||||
}
|
||||
133
pkg/route.go
Normal file
133
pkg/route.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package pkg
|
||||
|
||||
/*
|
||||
Copyright 2024 Jonas Kaninda
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/jedib0t/go-pretty/v6/table"
|
||||
"github.com/jkaninda/goma-gateway/internal/logger"
|
||||
"github.com/jkaninda/goma-gateway/pkg/middleware"
|
||||
"github.com/jkaninda/goma-gateway/util"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (gatewayServer GatewayServer) Initialize() *mux.Router {
|
||||
gateway := gatewayServer.gateway
|
||||
middlewares := gatewayServer.middlewares
|
||||
r := mux.NewRouter()
|
||||
heath := HealthCheckRoute{
|
||||
DisableRouteHealthCheckError: gateway.DisableRouteHealthCheckError,
|
||||
Routes: gateway.Routes,
|
||||
}
|
||||
// Define the health check route
|
||||
r.HandleFunc("/health", heath.HealthCheckHandler).Methods("GET")
|
||||
r.HandleFunc("/healthz", heath.HealthCheckHandler).Methods("GET")
|
||||
// Apply global Cors middlewares
|
||||
r.Use(CORSHandler(gateway.Cors)) // Apply CORS middleware
|
||||
if gateway.RateLimiter != 0 {
|
||||
//rateLimiter := middleware.NewRateLimiter(gateway.RateLimiter, time.Minute)
|
||||
limiter := middleware.NewRateLimiterWindow(gateway.RateLimiter, time.Minute) // requests per minute
|
||||
// Add rate limit middleware to all routes, if defined
|
||||
r.Use(limiter.RateLimitMiddleware())
|
||||
}
|
||||
for _, route := range gateway.Routes {
|
||||
blM := middleware.BlockListMiddleware{
|
||||
Path: route.Path,
|
||||
List: route.Blocklist,
|
||||
}
|
||||
// Add block access middleware to all route, if defined
|
||||
r.Use(blM.BlocklistMiddleware)
|
||||
//if route.Middlewares != nil {
|
||||
for _, mid := range route.Middlewares {
|
||||
secureRouter := r.PathPrefix(util.ParseURLPath(route.Path + mid.Path)).Subrouter()
|
||||
proxyRoute := ProxyRoute{
|
||||
path: route.Path,
|
||||
rewrite: route.Rewrite,
|
||||
destination: route.Destination,
|
||||
disableXForward: route.DisableHeaderXForward,
|
||||
cors: route.Cors,
|
||||
}
|
||||
rMiddleware, err := searchMiddleware(mid.Rules, middlewares)
|
||||
if err != nil {
|
||||
logger.Error("Middleware name not found")
|
||||
} else {
|
||||
//Check Authentication middleware
|
||||
switch rMiddleware.Type {
|
||||
case "basic":
|
||||
basicAuth, err := ToBasicAuth(rMiddleware.Rule)
|
||||
if err != nil {
|
||||
|
||||
logger.Error("Error: %s", err.Error())
|
||||
} else {
|
||||
amw := middleware.AuthBasic{
|
||||
Username: basicAuth.Username,
|
||||
Password: basicAuth.Password,
|
||||
Headers: nil,
|
||||
Params: nil,
|
||||
}
|
||||
// Apply JWT authentication middleware
|
||||
secureRouter.Use(amw.AuthMiddleware)
|
||||
}
|
||||
case "jwt":
|
||||
jwt, err := ToJWTRuler(rMiddleware.Rule)
|
||||
if err != nil {
|
||||
|
||||
} else {
|
||||
amw := middleware.AuthJWT{
|
||||
AuthURL: jwt.URL,
|
||||
RequiredHeaders: jwt.RequiredHeaders,
|
||||
Headers: jwt.Headers,
|
||||
Params: jwt.Params,
|
||||
}
|
||||
// Apply JWT authentication middleware
|
||||
secureRouter.Use(amw.AuthMiddleware)
|
||||
|
||||
}
|
||||
default:
|
||||
logger.Error("Unknown middleware type %s", rMiddleware.Type)
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
secureRouter.Use(CORSHandler(route.Cors))
|
||||
secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler
|
||||
secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler
|
||||
}
|
||||
proxyRoute := ProxyRoute{
|
||||
path: route.Path,
|
||||
rewrite: route.Rewrite,
|
||||
destination: route.Destination,
|
||||
disableXForward: route.DisableHeaderXForward,
|
||||
cors: route.Cors,
|
||||
}
|
||||
|
||||
router := r.PathPrefix(route.Path).Subrouter()
|
||||
router.Use(CORSHandler(route.Cors))
|
||||
router.PathPrefix("/").Handler(proxyRoute.ProxyHandler())
|
||||
}
|
||||
return r
|
||||
|
||||
}
|
||||
|
||||
func printRoute(routes []Route) {
|
||||
t := table.NewWriter()
|
||||
t.AppendHeader(table.Row{"Name", "Route", "Rewrite", "Destination"})
|
||||
for _, route := range routes {
|
||||
t.AppendRow(table.Row{route.Name, route.Path, route.Rewrite, route.Destination})
|
||||
}
|
||||
fmt.Println(t.Render())
|
||||
}
|
||||
68
pkg/server.go
Normal file
68
pkg/server.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package pkg
|
||||
|
||||
/*
|
||||
Copyright 2024 Jonas Kaninda
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/jkaninda/goma-gateway/internal/logger"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (gatewayServer GatewayServer) Start(ctx context.Context) error {
|
||||
logger.Info("Initializing routes...")
|
||||
route := gatewayServer.Initialize()
|
||||
logger.Info("Initializing routes...done")
|
||||
srv := &http.Server{
|
||||
Addr: gatewayServer.gateway.ListenAddr,
|
||||
WriteTimeout: time.Second * time.Duration(gatewayServer.gateway.WriteTimeout),
|
||||
ReadTimeout: time.Second * time.Duration(gatewayServer.gateway.ReadTimeout),
|
||||
IdleTimeout: time.Second * time.Duration(gatewayServer.gateway.IdleTimeout),
|
||||
Handler: route, // Pass our instance of gorilla/mux in.
|
||||
}
|
||||
if !gatewayServer.gateway.DisableDisplayRouteOnStart {
|
||||
printRoute(gatewayServer.gateway.Routes)
|
||||
}
|
||||
go func() {
|
||||
|
||||
logger.Info("Started Goma Gateway server on %v", gatewayServer.gateway.ListenAddr)
|
||||
if err := srv.ListenAndServe(); err != nil {
|
||||
logger.Error("Error starting Goma Gateway server: %v", err)
|
||||
}
|
||||
}()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-ctx.Done()
|
||||
shutdownCtx := context.Background()
|
||||
shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second)
|
||||
defer cancel()
|
||||
if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||
_, err := fmt.Fprintf(os.Stderr, "error shutting down Goma Gateway server: %s\n", err)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
return nil
|
||||
|
||||
}
|
||||
52
pkg/server_test.go
Normal file
52
pkg/server_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package pkg
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const testPath = "./tests"
|
||||
|
||||
var configFile = filepath.Join(testPath, "goma.yml")
|
||||
|
||||
func TestInit(t *testing.T) {
|
||||
err := os.MkdirAll(testPath, os.ModePerm)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStart(t *testing.T) {
|
||||
TestInit(t)
|
||||
initConfig(configFile)
|
||||
g := GatewayServer{}
|
||||
gatewayServer, err := g.Config(configFile)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
route := gatewayServer.Initialize()
|
||||
route.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) {
|
||||
_, err := rw.Write([]byte("Hello Goma Proxy"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed writing HTTP response: %v", err)
|
||||
}
|
||||
})
|
||||
assertResponseBody := func(t *testing.T, s *httptest.Server, expectedBody string) {
|
||||
resp, err := s.Client().Get(s.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error getting from server: %v", err)
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
t.Fatalf("expected a status code of 200, got %v", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
t.Run("httpServer", func(t *testing.T) {
|
||||
s := httptest.NewServer(route)
|
||||
defer s.Close()
|
||||
assertResponseBody(t, s, "Hello Goma Proxy")
|
||||
})
|
||||
|
||||
}
|
||||
3
pkg/var.go
Normal file
3
pkg/var.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package pkg
|
||||
|
||||
const ConfigFile = "/config/goma.yml"
|
||||
Reference in New Issue
Block a user