Initial commit

This commit is contained in:
Jonas Kaninda
2024-10-27 06:10:27 +01:00
commit 1923506e0a
35 changed files with 2592 additions and 0 deletions

362
pkg/config.go Normal file
View File

@@ -0,0 +1,362 @@
package pkg
/*
Copyright 2024 Jonas Kaninda
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import (
"context"
"fmt"
"github.com/jkaninda/goma-gateway/internal/logger"
"github.com/jkaninda/goma-gateway/util"
"github.com/spf13/cobra"
"gopkg.in/yaml.v3"
"os"
)
var cfg *Gateway
type Config struct {
file string
}
type BasicRule struct {
Username string `yaml:"username"`
Password string `yaml:"password"`
}
type Cors struct {
// Cors Allowed origins,
//e.g:
//
// - http://localhost:80
//
// - https://example.com
Origins []string `yaml:"origins"`
//
//e.g:
//
//Access-Control-Allow-Origin: '*'
//
// Access-Control-Allow-Methods: 'GET, POST, PUT, DELETE, OPTIONS'
//
// Access-Control-Allow-Cors: 'Content-Type, Authorization'
Headers map[string]string `yaml:"headers"`
}
// JWTRuler authentication using HTTP GET method
//
// JWTRuler contains the authentication details
type JWTRuler struct {
// URL contains the authentication URL, it supports HTTP GET method only.
URL string `yaml:"url"`
// RequiredHeaders , contains required before sending request to the backend.
RequiredHeaders []string `yaml:"requiredHeaders"`
// Headers Add header to the backend from Authentication request's header, depending on your requirements.
// Key is Http's response header Key, and value is the backend Request's header Key.
// In case you want to get headers from Authentication service and inject them to backend request's headers.
Headers map[string]string `yaml:"headers"`
// Params same as Headers, contains the request params.
//
// Gets authentication headers from authentication request and inject them as request params to the backend.
//
// Key is Http's response header Key, and value is the backend Request's request param Key.
//
// In case you want to get headers from Authentication service and inject them to next request's params.
//
//e.g: Header X-Auth-UserId to query userId
Params map[string]string `yaml:"params"`
}
// Middleware defined the route middleware
type Middleware struct {
//Path contains the name of middleware and must be unique
Name string `yaml:"name"`
// Type contains authentication types
//
// basic, jwt, auth0, rateLimit
Type string `yaml:"type"`
// Rule contains rule type of
Rule interface{} `yaml:"rule"`
}
type MiddlewareName struct {
name string `yaml:"name"`
}
type RouteMiddleware struct {
//Path contains the path to protect
Path string `yaml:"path"`
//Rules defines which specific middleware applies to a route path
Rules []string `yaml:"rules"`
}
// Route defines gateway route
type Route struct {
// Name defines route name
Name string `yaml:"name"`
// Path defines route path
Path string `yaml:"path"`
// Rewrite rewrites route path to desired path
//
// E.g. /cart to / => It will rewrite /cart path to /
Rewrite string `yaml:"rewrite"`
// Destination Defines backend URL
Destination string `yaml:"destination"`
// Cors contains the route cors headers
Cors Cors `yaml:"cors"`
// DisableHeaderXForward Disable X-forwarded header.
//
// [X-Forwarded-Host, X-Forwarded-For, Host, Scheme ]
//
// It will not match the backend route
DisableHeaderXForward bool `yaml:"disableHeaderXForward"`
// HealthCheck Defines the backend is health check PATH
HealthCheck string `yaml:"healthCheck"`
// Blocklist Defines route blacklist
Blocklist []string `yaml:"blocklist"`
// Middlewares Defines route middleware from Middleware names
Middlewares []RouteMiddleware `yaml:"middlewares"`
}
// Gateway contains Goma Proxy Gateway's configs
type Gateway struct {
// ListenAddr Defines the server listenAddr
//
//e.g: localhost:8080
ListenAddr string `yaml:"listenAddr" env:"GOMA_LISTEN_ADDR, overwrite"`
// WriteTimeout defines proxy write timeout
WriteTimeout int `yaml:"writeTimeout" env:"GOMA_WRITE_TIMEOUT, overwrite"`
// ReadTimeout defines proxy read timeout
ReadTimeout int `yaml:"readTimeout" env:"GOMA_READ_TIMEOUT, overwrite"`
// IdleTimeout defines proxy idle timeout
IdleTimeout int `yaml:"idleTimeout" env:"GOMA_IDLE_TIMEOUT, overwrite"`
// RateLimiter Defines number of request peer minute
RateLimiter int `yaml:"rateLimiter" env:"GOMA_RATE_LIMITER, overwrite"`
AccessLog string `yaml:"accessLog" env:"GOMA_ACCESS_LOG, overwrite"`
ErrorLog string `yaml:"errorLog" env:"GOMA_ERROR_LOG=, overwrite"`
DisableRouteHealthCheckError bool `yaml:"disableRouteHealthCheckError"`
//Disable dispelling routes on start
DisableDisplayRouteOnStart bool `yaml:"disableDisplayRouteOnStart"`
// Cors contains the proxy global cors
Cors Cors `yaml:"cors"`
// Routes defines the proxy routes
Routes []Route `yaml:"routes"`
}
type GatewayConfig struct {
GatewayConfig Gateway `yaml:"gateway"`
Middlewares []Middleware `yaml:"middlewares"`
}
// ErrorResponse represents the structure of the JSON error response
type ErrorResponse struct {
Success bool `json:"success"`
Code int `json:"code"`
Message string `json:"message"`
}
type GatewayServer struct {
ctx context.Context
gateway Gateway
middlewares []Middleware
}
// Config reads config file and returns Gateway
func (GatewayServer) Config(configFile string) (*GatewayServer, error) {
if util.FileExists(configFile) {
buf, err := os.ReadFile(configFile)
if err != nil {
return nil, err
}
util.SetEnv("GOMA_CONFIG_FILE", configFile)
c := &GatewayConfig{}
err = yaml.Unmarshal(buf, c)
if err != nil {
return nil, fmt.Errorf("in file %q: %w", configFile, err)
}
return &GatewayServer{
ctx: nil,
gateway: c.GatewayConfig,
middlewares: c.Middlewares,
}, nil
}
logger.Error("Configuration file not found: %v", configFile)
logger.Info("Generating new configuration file...")
initConfig(ConfigFile)
logger.Info("Server configuration file is available at %s", ConfigFile)
util.SetEnv("GOMA_CONFIG_FILE", ConfigFile)
buf, err := os.ReadFile(ConfigFile)
if err != nil {
return nil, err
}
c := &GatewayConfig{}
err = yaml.Unmarshal(buf, c)
if err != nil {
return nil, fmt.Errorf("in file %q: %w", ConfigFile, err)
}
logger.Info("Generating new configuration file...done")
logger.Info("Starting server with default configuration")
return &GatewayServer{
ctx: nil,
gateway: c.GatewayConfig,
middlewares: c.Middlewares,
}, nil
}
func GetConfigPaths() string {
return util.GetStringEnv("GOMAY_CONFIG_FILE", ConfigFile)
}
func InitConfig(cmd *cobra.Command) {
configFile, _ := cmd.Flags().GetString("output")
if configFile == "" {
configFile = GetConfigPaths()
}
initConfig(configFile)
return
}
func initConfig(configFile string) {
if configFile == "" {
configFile = GetConfigPaths()
}
conf := &GatewayConfig{
GatewayConfig: Gateway{
ListenAddr: "0.0.0.0:80",
WriteTimeout: 15,
ReadTimeout: 15,
IdleTimeout: 60,
AccessLog: "/dev/Stdout",
ErrorLog: "/dev/stderr",
DisableRouteHealthCheckError: false,
DisableDisplayRouteOnStart: false,
RateLimiter: 0,
Cors: Cors{
Origins: []string{"http://localhost:8080", "https://example.com"},
Headers: map[string]string{
"Access-Control-Allow-Headers": "Origin, Authorization, Accept, Content-Type, Access-Control-Allow-Headers, X-Client-Id, X-Session-Id",
"Access-Control-Allow-Credentials": "true",
"Access-Control-Max-Age": "1728000",
},
},
Routes: []Route{
{
Name: "HealthCheck",
Path: "/healthy",
Destination: "http://localhost:8080",
Rewrite: "/health",
HealthCheck: "",
Cors: Cors{
Headers: map[string]string{
"Access-Control-Allow-Headers": "Origin, Authorization, Accept, Content-Type, Access-Control-Allow-Headers, X-Client-Id, X-Session-Id",
"Access-Control-Allow-Credentials": "true",
"Access-Control-Max-Age": "1728000",
},
},
},
{
Name: "Basic auth",
Path: "/basic",
Destination: "http://localhost:8080",
Rewrite: "/health",
HealthCheck: "",
Blocklist: []string{},
Cors: Cors{},
Middlewares: []RouteMiddleware{
{
Path: "/basic/auth",
Rules: []string{"basic-auth", "google-jwt"},
},
},
},
},
},
Middlewares: []Middleware{
{
Name: "basic-auth",
Type: "basic",
Rule: BasicRule{
Username: "goma",
Password: "goma",
},
}, {
Name: "google-jwt",
Type: "jwt",
Rule: JWTRuler{
URL: "https://www.googleapis.com/auth/userinfo.email",
Headers: map[string]string{},
Params: map[string]string{},
},
},
},
}
yamlData, err := yaml.Marshal(&conf)
if err != nil {
logger.Fatal("Error serializing configuration %v", err.Error())
}
err = os.WriteFile(configFile, yamlData, 0644)
if err != nil {
logger.Fatal("Unable to write config file %s", err)
}
logger.Info("Configuration file has been initialized successfully")
}
func Get() *Gateway {
if cfg == nil {
c := &Gateway{}
c.Setup(GetConfigPaths())
cfg = c
}
return cfg
}
func (Gateway) Setup(conf string) *Gateway {
if util.FileExists(conf) {
buf, err := os.ReadFile(conf)
if err != nil {
return &Gateway{}
}
util.SetEnv("GOMA_CONFIG_FILE", conf)
c := &GatewayConfig{}
err = yaml.Unmarshal(buf, c)
if err != nil {
logger.Fatal("Error loading configuration %v", err.Error())
}
return &c.GatewayConfig
}
return &Gateway{}
}
func (middleware Middleware) name() {
}
func ToJWTRuler(input interface{}) (JWTRuler, error) {
jWTRuler := new(JWTRuler)
var bytes []byte
bytes, err := yaml.Marshal(input)
if err != nil {
return JWTRuler{}, fmt.Errorf("error marshalling yaml: %v", err)
}
err = yaml.Unmarshal(bytes, jWTRuler)
if err != nil {
return JWTRuler{}, fmt.Errorf("error unmarshalling yaml: %v", err)
}
return *jWTRuler, nil
}
func ToBasicAuth(input interface{}) (BasicRule, error) {
basicAuth := new(BasicRule)
var bytes []byte
bytes, err := yaml.Marshal(input)
if err != nil {
return BasicRule{}, fmt.Errorf("error marshalling yaml: %v", err)
}
err = yaml.Unmarshal(bytes, basicAuth)
if err != nil {
return BasicRule{}, fmt.Errorf("error unmarshalling yaml: %v", err)
}
return *basicAuth, nil
}

107
pkg/handler.go Normal file
View File

@@ -0,0 +1,107 @@
package pkg
/*
Copyright 2024 Jonas Kaninda
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import (
"encoding/json"
"github.com/gorilla/mux"
"github.com/jkaninda/goma-gateway/internal/logger"
"net/http"
)
// CORSHandler handles CORS headers for incoming requests
//
// Adds CORS headers to the response dynamically based on the provided headers map[string]string
func CORSHandler(cors Cors) mux.MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Set CORS headers from the cors config
//Update Cors Headers
for k, v := range cors.Headers {
w.Header().Set(k, v)
}
//Update Origin Cors Headers
for _, origin := range cors.Origins {
if origin == r.Header.Get("Origin") {
w.Header().Set("Access-Control-Allow-Origin", origin)
}
}
// Handle preflight requests (OPTIONS)
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusNoContent)
return
}
// Pass the request to the next handler
next.ServeHTTP(w, r)
})
}
}
// ProxyErrorHandler catches backend errors and returns a custom response
func ProxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
logger.Error("Proxy error: %v", err)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadGateway)
err = json.NewEncoder(w).Encode(map[string]interface{}{
"success": false,
"code": http.StatusBadGateway,
"message": "The service is currently unavailable. Please try again later.",
})
if err != nil {
return
}
return
}
// HealthCheckHandler handles health check of routes
func (heathRoute HealthCheckRoute) HealthCheckHandler(w http.ResponseWriter, r *http.Request) {
logger.Info("%s %s %s %s", r.Method, r.RemoteAddr, r.URL, r.UserAgent())
var routes []HealthCheckRouteResponse
for _, route := range heathRoute.Routes {
if route.HealthCheck != "" {
err := HealthCheck(route.Destination + route.HealthCheck)
if err != nil {
logger.Error("Route %s: %v", route.Name, err)
if heathRoute.DisableRouteHealthCheckError {
routes = append(routes, HealthCheckRouteResponse{Name: route.Name, Status: "unhealthy", Error: "Route healthcheck errors disabled"})
continue
}
routes = append(routes, HealthCheckRouteResponse{Name: route.Name, Status: "unhealthy", Error: err.Error()})
continue
} else {
logger.Info("Route %s is healthy", route.Name)
routes = append(routes, HealthCheckRouteResponse{Name: route.Name, Status: "healthy", Error: ""})
continue
}
} else {
logger.Error("Route %s's healthCheck is undefined", route.Name)
routes = append(routes, HealthCheckRouteResponse{Name: route.Name, Status: "undefined", Error: ""})
continue
}
}
response := HealthCheckResponse{
Status: "healthy",
Routes: routes,
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
err := json.NewEncoder(w).Encode(response)
if err != nil {
return
}
}

67
pkg/healthCheck.go Normal file
View File

@@ -0,0 +1,67 @@
package pkg
/*
Copyright 2024 Jonas Kaninda
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import (
"fmt"
"io"
"net/http"
"net/url"
)
type HealthCheckRoute struct {
DisableRouteHealthCheckError bool
Routes []Route
}
// HealthCheckResponse represents the health check response structure
type HealthCheckResponse struct {
Status string `json:"status"`
Routes []HealthCheckRouteResponse `json:"routes"`
}
type HealthCheckRouteResponse struct {
Name string `json:"name"`
Status string `json:"status"`
Error string `json:"error"`
}
func HealthCheck(healthURL string) error {
healthCheckURL, err := url.Parse(healthURL)
if err != nil {
return fmt.Errorf("error parsing HealthCheck URL: %v ", err)
}
// Create a new request for the route
healthReq, err := http.NewRequest("GET", healthCheckURL.String(), nil)
if err != nil {
return fmt.Errorf("error creating HealthCheck request: %v ", err)
}
// Perform the request to the route's healthcheck
client := &http.Client{}
healthResp, err := client.Do(healthReq)
if err != nil {
return fmt.Errorf("error performing HealthCheck request: %v ", err)
}
defer func(Body io.ReadCloser) {
err := Body.Close()
if err != nil {
}
}(healthResp.Body)
if healthResp.StatusCode >= 400 {
return fmt.Errorf("health check failed with status code %v", healthResp.StatusCode)
}
return nil
}

24
pkg/helpers.go Normal file
View File

@@ -0,0 +1,24 @@
package pkg
/*
Copyright 2024 Jonas Kaninda.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may get a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
*/
import (
"fmt"
"github.com/common-nighthawk/go-figure"
"github.com/jkaninda/goma-gateway/util"
)
func Intro() {
nameFigure := figure.NewFigure("Goma", "", true)
nameFigure.Print()
fmt.Printf("Version: %s\n", util.FullVersion())
fmt.Println("Copyright (c) 2024 Jonas Kaninda")
fmt.Println("Starting Goma server...")
}

38
pkg/middleware.go Normal file
View File

@@ -0,0 +1,38 @@
package pkg
import (
"errors"
"github.com/gorilla/mux"
"slices"
"strings"
)
func searchMiddleware(rules []string, middlewares []Middleware) (Middleware, error) {
for _, m := range middlewares {
if slices.Contains(rules, m.Name) {
return m, nil
}
continue
}
return Middleware{}, errors.New("no middleware found with name " + strings.Join(rules, ";"))
}
func getMiddleware(rule string, middlewares []Middleware) (Middleware, error) {
for _, m := range middlewares {
if strings.Contains(rule, m.Name) {
return m, nil
}
continue
}
return Middleware{}, errors.New("no middleware found with name " + rule)
}
type RoutePath struct {
route Route
path string
rules []string
middlewares []Middleware
router *mux.Router
}

View File

@@ -0,0 +1,99 @@
package middleware
/*
Copyright 2024 Jonas Kaninda
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import (
"encoding/json"
"fmt"
"github.com/jkaninda/goma-gateway/internal/logger"
"github.com/jkaninda/goma-gateway/util"
"net/http"
"strings"
"time"
)
// BlocklistMiddleware checks if the request path is forbidden and returns 403 Forbidden
func (blockList BlockListMiddleware) BlocklistMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for _, block := range blockList.List {
if isPathBlocked(r.URL.Path, util.ParseURLPath(blockList.Path+block)) {
logger.Error("Access to %s is forbidden", r.URL.Path)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusNotFound)
err := json.NewEncoder(w).Encode(ProxyResponseError{
Success: false,
Code: http.StatusNotFound,
Message: fmt.Sprintf("Not found: %s", r.URL.Path),
})
if err != nil {
return
}
return
}
}
next.ServeHTTP(w, r)
})
}
// Helper function to determine if the request path is blocked
func isPathBlocked(requestPath, blockedPath string) bool {
// Handle exact match
if requestPath == blockedPath {
return true
}
// Handle wildcard match (e.g., /admin/* should block /admin and any subpath)
if strings.HasSuffix(blockedPath, "/*") {
basePath := strings.TrimSuffix(blockedPath, "/*")
if strings.HasPrefix(requestPath, basePath) {
return true
}
}
return false
}
// NewRateLimiter creates a new rate limiter with the specified refill rate and token capacity
func NewRateLimiter(maxTokens int, refillRate time.Duration) *TokenRateLimiter {
return &TokenRateLimiter{
tokens: maxTokens,
maxTokens: maxTokens,
refillRate: refillRate,
lastRefill: time.Now(),
}
}
// Allow checks if a request is allowed based on the current token bucket
func (rl *TokenRateLimiter) Allow() bool {
rl.mu.Lock()
defer rl.mu.Unlock()
// Refill tokens based on the time elapsed
now := time.Now()
elapsed := now.Sub(rl.lastRefill)
tokensToAdd := int(elapsed / rl.refillRate)
if tokensToAdd > 0 {
rl.tokens = min(rl.maxTokens, rl.tokens+tokensToAdd)
rl.lastRefill = now
}
// Check if there are enough tokens to allow the request
if rl.tokens > 0 {
rl.tokens--
return true
}
// Reject request if no tokens are available
return false
}

View File

@@ -0,0 +1,276 @@
package middleware
/*
Copyright 2024 Jonas Kaninda
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import (
"encoding/base64"
"encoding/json"
"github.com/jkaninda/goma-gateway/internal/logger"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"
)
// RateLimiter defines rate limit properties.
type RateLimiter struct {
Requests int
Window time.Duration
ClientMap map[string]*Client
mu sync.Mutex
}
// Client stores request count and window expiration for each client.
type Client struct {
RequestCount int
ExpiresAt time.Time
}
// NewRateLimiterWindow creates a new RateLimiter.
func NewRateLimiterWindow(requests int, window time.Duration) *RateLimiter {
return &RateLimiter{
Requests: requests,
Window: window,
ClientMap: make(map[string]*Client),
}
}
type TokenRateLimiter struct {
tokens int
maxTokens int
refillRate time.Duration
lastRefill time.Time
mu sync.Mutex
}
// ProxyResponseError represents the structure of the JSON error response
type ProxyResponseError struct {
Success bool `json:"success"`
Code int `json:"code"`
Message string `json:"message"`
}
// AuthJWT Define struct
type AuthJWT struct {
AuthURL string
RequiredHeaders []string
Headers map[string]string
Params map[string]string
}
// AuthenticationMiddleware Define struct
type AuthenticationMiddleware struct {
AuthURL string
RequiredHeaders []string
Headers map[string]string
Params map[string]string
}
type BlockListMiddleware struct {
Path string
Destination string
List []string
}
// AuthBasic Define Basic auth
type AuthBasic struct {
Username string
Password string
Headers map[string]string
Params map[string]string
}
// AuthMiddleware function, which will be called for each request
func (amw AuthJWT) AuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for _, header := range amw.RequiredHeaders {
if r.Header.Get(header) == "" {
logger.Error("Proxy error, missing %s header", header)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusForbidden)
err := json.NewEncoder(w).Encode(ProxyResponseError{
Message: "Missing Authorization header",
Code: http.StatusForbidden,
Success: false,
})
if err != nil {
return
}
return
}
}
//token := r.Header.Get("Authorization")
authURL, err := url.Parse(amw.AuthURL)
if err != nil {
logger.Error("Error parsing auth URL: %v", err)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
err = json.NewEncoder(w).Encode(ProxyResponseError{
Message: "Internal Server Error",
Code: http.StatusInternalServerError,
Success: false,
})
if err != nil {
return
}
return
}
// Create a new request for /authentication
authReq, err := http.NewRequest("GET", authURL.String(), nil)
if err != nil {
logger.Error("Proxy error creating authentication request: %v", err)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
err = json.NewEncoder(w).Encode(ProxyResponseError{
Message: "Internal Server Error",
Code: http.StatusInternalServerError,
Success: false,
})
if err != nil {
return
}
return
}
// Copy headers from the original request to the new request
for name, values := range r.Header {
for _, value := range values {
authReq.Header.Set(name, value)
}
}
// Copy cookies from the original request to the new request
for _, cookie := range r.Cookies() {
authReq.AddCookie(cookie)
}
// Perform the request to the auth service
client := &http.Client{}
authResp, err := client.Do(authReq)
if err != nil || authResp.StatusCode != http.StatusOK {
logger.Info("%s %s %s %s", r.Method, r.RemoteAddr, r.URL, r.UserAgent())
logger.Error("Proxy authentication error")
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
err = json.NewEncoder(w).Encode(ProxyResponseError{
Message: "Unauthorized",
Code: http.StatusUnauthorized,
Success: false,
})
if err != nil {
return
}
return
}
defer func(Body io.ReadCloser) {
err := Body.Close()
if err != nil {
}
}(authResp.Body)
// Inject specific header tp the current request's header
// Add header to the next request from AuthRequest header, depending on your requirements
if amw.Headers != nil {
for k, v := range amw.Headers {
r.Header.Set(v, authResp.Header.Get(k))
}
}
query := r.URL.Query()
// Add query parameters to the next request from AuthRequest header, depending on your requirements
if amw.Params != nil {
for k, v := range amw.Params {
query.Set(v, authResp.Header.Get(k))
}
}
r.URL.RawQuery = query.Encode()
next.ServeHTTP(w, r)
})
}
// AuthMiddleware checks for the Authorization header and verifies the credentials
func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get the Authorization header
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
logger.Error("Proxy error, missing Authorization header")
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
err := json.NewEncoder(w).Encode(ProxyResponseError{
Success: false,
Code: http.StatusUnauthorized,
Message: "Unauthorized",
})
if err != nil {
return
}
return
}
// Check if the Authorization header contains "Basic" scheme
if !strings.HasPrefix(authHeader, "Basic ") {
logger.Error("Proxy error, missing Basic Authorization header")
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
err := json.NewEncoder(w).Encode(ProxyResponseError{
Success: false,
Code: http.StatusUnauthorized,
Message: "Unauthorized",
})
if err != nil {
return
}
return
}
// Decode the base64 encoded username:password string
payload, err := base64.StdEncoding.DecodeString(authHeader[len("Basic "):])
if err != nil {
logger.Error("Proxy error, missing Basic Authorization header")
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
err := json.NewEncoder(w).Encode(ProxyResponseError{
Success: false,
Code: http.StatusUnauthorized,
Message: "Unauthorized",
})
if err != nil {
return
}
return
}
// Split the payload into username and password
pair := strings.SplitN(string(payload), ":", 2)
if len(pair) != 2 || pair[0] != basicAuth.Username || pair[1] != basicAuth.Password {
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
err := json.NewEncoder(w).Encode(ProxyResponseError{
Success: false,
Code: http.StatusUnauthorized,
Message: "Unauthorized",
})
if err != nil {
return
}
return
}
// Continue to the next handler if the authentication is successful
next.ServeHTTP(w, r)
})
}

View File

@@ -0,0 +1,90 @@
package middleware
/*
Copyright 2024 Jonas Kaninda
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import (
"encoding/json"
"github.com/gorilla/mux"
"github.com/jkaninda/goma-gateway/internal/logger"
"net/http"
"time"
)
// RateLimitMiddleware limits request based on the number of tokens peer minutes.
func (rl *TokenRateLimiter) RateLimitMiddleware() mux.MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !rl.Allow() {
// Rate limit exceeded, return a 429 Too Many Requests response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusTooManyRequests)
err := json.NewEncoder(w).Encode(ProxyResponseError{
Success: false,
Code: http.StatusTooManyRequests,
Message: "Too many requests. Please try again later.",
})
if err != nil {
return
}
return
}
// Proceed to the next handler if rate limit is not exceeded
next.ServeHTTP(w, r)
})
}
}
// RateLimitMiddleware limits request based on the number of requests peer minutes.
func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
//TODO:
clientID := r.RemoteAddr
logger.Info(clientID)
rl.mu.Lock()
client, exists := rl.ClientMap[clientID]
if !exists || time.Now().After(client.ExpiresAt) {
client = &Client{
RequestCount: 0,
ExpiresAt: time.Now().Add(rl.Window),
}
rl.ClientMap[clientID] = client
}
client.RequestCount++
rl.mu.Unlock()
if client.RequestCount > rl.Requests {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusTooManyRequests)
err := json.NewEncoder(w).Encode(ProxyResponseError{
Success: false,
Code: http.StatusTooManyRequests,
Message: "Too many requests. Please try again later.",
})
if err != nil {
return
}
return
}
// Proceed to the next handler if rate limit is not exceeded
next.ServeHTTP(w, r)
})
}
}

110
pkg/middleware_test.go Normal file
View File

@@ -0,0 +1,110 @@
package pkg
/*
Copyright 2024 Jonas Kaninda
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import (
"fmt"
"gopkg.in/yaml.v3"
"log"
"os"
"testing"
)
const MidName = "google-jwt"
var rules = []string{"fake", "jwt", "google-jwt"}
func TestMiddleware(t *testing.T) {
TestInit(t)
middlewares := []Middleware{
{
Name: "basic-auth",
Type: "basic",
Rule: BasicRule{
Username: "goma",
Password: "goma",
},
}, {
Name: MidName,
Type: "jwt",
Rule: JWTRuler{
URL: "https://www.googleapis.com/auth/userinfo.email",
Headers: map[string]string{},
Params: map[string]string{},
},
},
}
yamlData, err := yaml.Marshal(&middlewares)
if err != nil {
t.Fatalf("Error serializing configuration %v", err.Error())
}
err = os.WriteFile(configFile, yamlData, 0644)
if err != nil {
t.Fatalf("Unable to write config file %s", err)
}
log.Printf("Config file written to %s", configFile)
}
func TestReadMiddleware(t *testing.T) {
TestMiddleware(t)
middlewares := getMiddlewares(t)
middleware, err := searchMiddleware(rules, middlewares)
if err != nil {
t.Fatalf("Error searching middleware %s", err.Error())
}
switch middleware.Type {
case "basic":
log.Println("Basic auth")
basicAuth, err := ToBasicAuth(middleware.Rule)
if err != nil {
log.Fatalln("error:", err)
}
log.Printf("Username: %s and password: %s\n", basicAuth.Username, basicAuth.Password)
case "jwt":
log.Println("JWT auth")
jwt, err := ToJWTRuler(middleware.Rule)
if err != nil {
log.Fatalln("error:", err)
}
log.Printf("JWT authentification URL is %s\n", jwt.URL)
default:
t.Errorf("Unknown middleware type %s", middleware.Type)
}
}
func TestFoundMiddleware(t *testing.T) {
middlewares := getMiddlewares(t)
middleware, err := searchMiddleware(rules, middlewares)
if err != nil {
t.Errorf("Error getting middleware %v", err)
}
fmt.Println(middleware.Type)
}
func getMiddlewares(t *testing.T) []Middleware {
buf, err := os.ReadFile(configFile)
if err != nil {
t.Fatalf("Unable to read config file %s", configFile)
}
c := &[]Middleware{}
err = yaml.Unmarshal(buf, c)
if err != nil {
t.Fatalf("Unable to parse config file %s", configFile)
}
return *c
}

113
pkg/proxy.go Normal file
View File

@@ -0,0 +1,113 @@
package pkg
/*
Copyright 2024 Jonas Kaninda
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import (
"encoding/json"
"fmt"
"github.com/jkaninda/goma-gateway/internal/logger"
"net/http"
"net/http/httputil"
"net/url"
"strings"
)
type ProxyRoute struct {
path string
rewrite string
destination string
cors Cors
disableXForward bool
}
// ProxyHandler proxies requests to the backend
func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
logger.Info("%s %s %s %s", r.Method, r.RemoteAddr, r.URL, r.UserAgent())
// Set CORS headers from the cors config
//Update Cors Headers
for k, v := range proxyRoute.cors.Headers {
w.Header().Set(k, v)
}
//Update Origin Cors Headers
for _, origin := range proxyRoute.cors.Origins {
if origin == r.Header.Get("Origin") {
w.Header().Set("Access-Control-Allow-Origin", origin)
}
}
// Handle preflight requests (OPTIONS)
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusNoContent)
return
}
// Parse the target backend URL
targetURL, err := url.Parse(proxyRoute.destination)
if err != nil {
logger.Error("Error parsing backend URL: %s", err)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
err := json.NewEncoder(w).Encode(ErrorResponse{
Message: "Internal server error",
Code: http.StatusInternalServerError,
Success: false,
})
if err != nil {
return
}
return
}
// Update the headers to allow for SSL redirection
if !proxyRoute.disableXForward {
r.URL.Host = targetURL.Host
r.URL.Scheme = targetURL.Scheme
r.Header.Set("X-Forwarded-Host", r.Header.Get("Host"))
r.Header.Set("X-Forwarded-For", r.RemoteAddr)
r.Header.Set("X-Real-IP", r.RemoteAddr)
r.Host = targetURL.Host
}
// Create proxy
proxy := httputil.NewSingleHostReverseProxy(targetURL)
// Rewrite
if proxyRoute.path != "" && proxyRoute.rewrite != "" {
// Rewrite the path
if strings.HasPrefix(r.URL.Path, fmt.Sprintf("%s/", proxyRoute.path)) {
r.URL.Path = strings.Replace(r.URL.Path, fmt.Sprintf("%s/", proxyRoute.path), proxyRoute.rewrite, 1)
}
}
proxy.ModifyResponse = func(response *http.Response) error {
if response.StatusCode < 200 || response.StatusCode >= 300 {
//TODO || Add override backend errors | user can enable or disable it
}
return nil
}
// Custom error handler for proxy errors
proxy.ErrorHandler = ProxyErrorHandler
proxy.ServeHTTP(w, r)
}
}
func isAllowed(cors []string, r *http.Request) bool {
for _, origin := range cors {
if origin == r.Header.Get("Origin") {
return true
}
continue
}
return false
}

133
pkg/route.go Normal file
View File

@@ -0,0 +1,133 @@
package pkg
/*
Copyright 2024 Jonas Kaninda
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import (
"fmt"
"github.com/gorilla/mux"
"github.com/jedib0t/go-pretty/v6/table"
"github.com/jkaninda/goma-gateway/internal/logger"
"github.com/jkaninda/goma-gateway/pkg/middleware"
"github.com/jkaninda/goma-gateway/util"
"time"
)
func (gatewayServer GatewayServer) Initialize() *mux.Router {
gateway := gatewayServer.gateway
middlewares := gatewayServer.middlewares
r := mux.NewRouter()
heath := HealthCheckRoute{
DisableRouteHealthCheckError: gateway.DisableRouteHealthCheckError,
Routes: gateway.Routes,
}
// Define the health check route
r.HandleFunc("/health", heath.HealthCheckHandler).Methods("GET")
r.HandleFunc("/healthz", heath.HealthCheckHandler).Methods("GET")
// Apply global Cors middlewares
r.Use(CORSHandler(gateway.Cors)) // Apply CORS middleware
if gateway.RateLimiter != 0 {
//rateLimiter := middleware.NewRateLimiter(gateway.RateLimiter, time.Minute)
limiter := middleware.NewRateLimiterWindow(gateway.RateLimiter, time.Minute) // requests per minute
// Add rate limit middleware to all routes, if defined
r.Use(limiter.RateLimitMiddleware())
}
for _, route := range gateway.Routes {
blM := middleware.BlockListMiddleware{
Path: route.Path,
List: route.Blocklist,
}
// Add block access middleware to all route, if defined
r.Use(blM.BlocklistMiddleware)
//if route.Middlewares != nil {
for _, mid := range route.Middlewares {
secureRouter := r.PathPrefix(util.ParseURLPath(route.Path + mid.Path)).Subrouter()
proxyRoute := ProxyRoute{
path: route.Path,
rewrite: route.Rewrite,
destination: route.Destination,
disableXForward: route.DisableHeaderXForward,
cors: route.Cors,
}
rMiddleware, err := searchMiddleware(mid.Rules, middlewares)
if err != nil {
logger.Error("Middleware name not found")
} else {
//Check Authentication middleware
switch rMiddleware.Type {
case "basic":
basicAuth, err := ToBasicAuth(rMiddleware.Rule)
if err != nil {
logger.Error("Error: %s", err.Error())
} else {
amw := middleware.AuthBasic{
Username: basicAuth.Username,
Password: basicAuth.Password,
Headers: nil,
Params: nil,
}
// Apply JWT authentication middleware
secureRouter.Use(amw.AuthMiddleware)
}
case "jwt":
jwt, err := ToJWTRuler(rMiddleware.Rule)
if err != nil {
} else {
amw := middleware.AuthJWT{
AuthURL: jwt.URL,
RequiredHeaders: jwt.RequiredHeaders,
Headers: jwt.Headers,
Params: jwt.Params,
}
// Apply JWT authentication middleware
secureRouter.Use(amw.AuthMiddleware)
}
default:
logger.Error("Unknown middleware type %s", rMiddleware.Type)
}
}
secureRouter.Use(CORSHandler(route.Cors))
secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler
secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler
}
proxyRoute := ProxyRoute{
path: route.Path,
rewrite: route.Rewrite,
destination: route.Destination,
disableXForward: route.DisableHeaderXForward,
cors: route.Cors,
}
router := r.PathPrefix(route.Path).Subrouter()
router.Use(CORSHandler(route.Cors))
router.PathPrefix("/").Handler(proxyRoute.ProxyHandler())
}
return r
}
func printRoute(routes []Route) {
t := table.NewWriter()
t.AppendHeader(table.Row{"Name", "Route", "Rewrite", "Destination"})
for _, route := range routes {
t.AppendRow(table.Row{route.Name, route.Path, route.Rewrite, route.Destination})
}
fmt.Println(t.Render())
}

68
pkg/server.go Normal file
View File

@@ -0,0 +1,68 @@
package pkg
/*
Copyright 2024 Jonas Kaninda
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import (
"context"
"fmt"
"github.com/jkaninda/goma-gateway/internal/logger"
"net/http"
"os"
"sync"
"time"
)
func (gatewayServer GatewayServer) Start(ctx context.Context) error {
logger.Info("Initializing routes...")
route := gatewayServer.Initialize()
logger.Info("Initializing routes...done")
srv := &http.Server{
Addr: gatewayServer.gateway.ListenAddr,
WriteTimeout: time.Second * time.Duration(gatewayServer.gateway.WriteTimeout),
ReadTimeout: time.Second * time.Duration(gatewayServer.gateway.ReadTimeout),
IdleTimeout: time.Second * time.Duration(gatewayServer.gateway.IdleTimeout),
Handler: route, // Pass our instance of gorilla/mux in.
}
if !gatewayServer.gateway.DisableDisplayRouteOnStart {
printRoute(gatewayServer.gateway.Routes)
}
go func() {
logger.Info("Started Goma Gateway server on %v", gatewayServer.gateway.ListenAddr)
if err := srv.ListenAndServe(); err != nil {
logger.Error("Error starting Goma Gateway server: %v", err)
}
}()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
<-ctx.Done()
shutdownCtx := context.Background()
shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second)
defer cancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
_, err := fmt.Fprintf(os.Stderr, "error shutting down Goma Gateway server: %s\n", err)
if err != nil {
return
}
}
}()
wg.Wait()
return nil
}

52
pkg/server_test.go Normal file
View File

@@ -0,0 +1,52 @@
package pkg
import (
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
)
const testPath = "./tests"
var configFile = filepath.Join(testPath, "goma.yml")
func TestInit(t *testing.T) {
err := os.MkdirAll(testPath, os.ModePerm)
if err != nil {
t.Error(err)
}
}
func TestStart(t *testing.T) {
TestInit(t)
initConfig(configFile)
g := GatewayServer{}
gatewayServer, err := g.Config(configFile)
if err != nil {
t.Error(err)
}
route := gatewayServer.Initialize()
route.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) {
_, err := rw.Write([]byte("Hello Goma Proxy"))
if err != nil {
t.Fatalf("Failed writing HTTP response: %v", err)
}
})
assertResponseBody := func(t *testing.T, s *httptest.Server, expectedBody string) {
resp, err := s.Client().Get(s.URL)
if err != nil {
t.Fatalf("unexpected error getting from server: %v", err)
}
if resp.StatusCode != 200 {
t.Fatalf("expected a status code of 200, got %v", resp.StatusCode)
}
}
t.Run("httpServer", func(t *testing.T) {
s := httptest.NewServer(route)
defer s.Close()
assertResponseBody(t, s, "Hello Goma Proxy")
})
}

3
pkg/var.go Normal file
View File

@@ -0,0 +1,3 @@
package pkg
const ConfigFile = "/config/goma.yml"