refactor: Restructure project files for better organization, readability, and maintainability

This commit is contained in:
2024-11-04 08:48:38 +01:00
parent 096290bcb8
commit 4e49102433
22 changed files with 17 additions and 17 deletions

253
internal/config.go Normal file
View File

@@ -0,0 +1,253 @@
package pkg
/*
Copyright 2024 Jonas Kaninda
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import (
"fmt"
"github.com/jkaninda/goma-gateway/pkg/logger"
"github.com/jkaninda/goma-gateway/util"
"github.com/spf13/cobra"
"gopkg.in/yaml.v3"
"os"
)
var cfg *Gateway
// Config reads config file and returns Gateway
func (GatewayServer) Config(configFile string) (*GatewayServer, error) {
if util.FileExists(configFile) {
buf, err := os.ReadFile(configFile)
if err != nil {
return nil, err
}
util.SetEnv("GOMA_CONFIG_FILE", configFile)
c := &GatewayConfig{}
err = yaml.Unmarshal(buf, c)
if err != nil {
return nil, fmt.Errorf("error parsing yaml %q: %w", configFile, err)
}
return &GatewayServer{
ctx: nil,
gateway: c.GatewayConfig,
middlewares: c.Middlewares,
}, nil
}
logger.Error("Configuration file not found: %v", configFile)
logger.Info("Generating new configuration file...")
initConfig(ConfigFile)
logger.Info("Server configuration file is available at %s", ConfigFile)
util.SetEnv("GOMA_CONFIG_FILE", ConfigFile)
buf, err := os.ReadFile(ConfigFile)
if err != nil {
return nil, err
}
c := &GatewayConfig{}
err = yaml.Unmarshal(buf, c)
if err != nil {
return nil, fmt.Errorf("in file %q: %w", ConfigFile, err)
}
logger.Info("Generating new configuration file...done")
logger.Info("Starting server with default configuration")
return &GatewayServer{
ctx: nil,
gateway: c.GatewayConfig,
middlewares: c.Middlewares,
}, nil
}
func GetConfigPaths() string {
return util.GetStringEnv("GOMAY_CONFIG_FILE", ConfigFile)
}
func InitConfig(cmd *cobra.Command) {
configFile, _ := cmd.Flags().GetString("output")
if configFile == "" {
configFile = GetConfigPaths()
}
initConfig(configFile)
return
}
func initConfig(configFile string) {
if configFile == "" {
configFile = GetConfigPaths()
}
conf := &GatewayConfig{
GatewayConfig: Gateway{
ListenAddr: "0.0.0.0:80",
WriteTimeout: 15,
ReadTimeout: 15,
IdleTimeout: 60,
AccessLog: "/dev/Stdout",
ErrorLog: "/dev/stderr",
DisableRouteHealthCheckError: false,
DisableDisplayRouteOnStart: false,
RateLimiter: 0,
InterceptErrors: []int{405, 500},
Cors: Cors{
Origins: []string{"http://localhost:8080", "https://example.com"},
Headers: map[string]string{
"Access-Control-Allow-Headers": "Origin, Authorization, Accept, Content-Type, Access-Control-Allow-Headers, X-Client-Id, X-Session-Id",
"Access-Control-Allow-Credentials": "true",
"Access-Control-Max-Age": "1728000",
},
},
Routes: []Route{
{
Name: "Public",
Path: "/public",
Destination: "https://example.com",
Rewrite: "/",
HealthCheck: "",
Middlewares: []string{"api-forbidden-paths"},
},
{
Name: "Basic auth",
Path: "/protected",
Destination: "https://example.com",
Rewrite: "/",
HealthCheck: "",
Cors: Cors{
Origins: []string{"http://localhost:3000", "https://dev.example.com"},
Headers: map[string]string{
"Access-Control-Allow-Headers": "Origin, Authorization",
"Access-Control-Allow-Credentials": "true",
"Access-Control-Max-Age": "1728000",
},
},
Middlewares: []string{"basic-auth", "api-forbidden-paths"},
},
{
Name: "Hostname example",
Host: "http://example.localhost",
Path: "/",
Destination: "https://example.com",
Rewrite: "/",
HealthCheck: "",
},
},
},
Middlewares: []Middleware{
{
Name: "basic-auth",
Type: BasicAuth,
Paths: []string{
"/*",
},
Rule: BasicRuleMiddleware{
Username: "admin",
Password: "admin",
},
}, {
Name: "jwt",
Type: JWTAuth,
Paths: []string{
"/protected-access",
"/example-of-jwt",
},
Rule: JWTRuleMiddleware{
URL: "https://www.googleapis.com/auth/userinfo.email",
RequiredHeaders: []string{
"Authorization",
},
Headers: map[string]string{},
Params: map[string]string{},
},
},
{
Name: "api-forbidden-paths",
Type: AccessMiddleware,
Paths: []string{
"/swagger-ui/*",
"/v2/swagger-ui/*",
"/api-docs/*",
"/internal/*",
"/actuator/*",
},
},
},
}
yamlData, err := yaml.Marshal(&conf)
if err != nil {
logger.Fatal("Error serializing configuration %v", err.Error())
}
err = os.WriteFile(configFile, yamlData, 0644)
if err != nil {
logger.Fatal("Unable to write config file %s", err)
}
logger.Info("Configuration file has been initialized successfully")
}
func Get() *Gateway {
if cfg == nil {
c := &Gateway{}
c.Setup(GetConfigPaths())
cfg = c
}
return cfg
}
func (Gateway) Setup(conf string) *Gateway {
if util.FileExists(conf) {
buf, err := os.ReadFile(conf)
if err != nil {
return &Gateway{}
}
util.SetEnv("GOMA_CONFIG_FILE", conf)
c := &GatewayConfig{}
err = yaml.Unmarshal(buf, c)
if err != nil {
logger.Fatal("Error loading configuration %v", err.Error())
}
return &c.GatewayConfig
}
return &Gateway{}
}
// getJWTMiddleware returns JWTRuleMiddleware,error
func getJWTMiddleware(input interface{}) (JWTRuleMiddleware, error) {
jWTRuler := new(JWTRuleMiddleware)
var bytes []byte
bytes, err := yaml.Marshal(input)
if err != nil {
return JWTRuleMiddleware{}, fmt.Errorf("error parsing yaml: %v", err)
}
err = yaml.Unmarshal(bytes, jWTRuler)
if err != nil {
return JWTRuleMiddleware{}, fmt.Errorf("error parsing yaml: %v", err)
}
if jWTRuler.URL == "" {
return JWTRuleMiddleware{}, fmt.Errorf("error parsing yaml: empty url in jwt auth middleware")
}
return *jWTRuler, nil
}
// getBasicAuthMiddleware returns BasicRuleMiddleware,error
func getBasicAuthMiddleware(input interface{}) (BasicRuleMiddleware, error) {
basicAuth := new(BasicRuleMiddleware)
var bytes []byte
bytes, err := yaml.Marshal(input)
if err != nil {
return BasicRuleMiddleware{}, fmt.Errorf("error parsing yaml: %v", err)
}
err = yaml.Unmarshal(bytes, basicAuth)
if err != nil {
return BasicRuleMiddleware{}, fmt.Errorf("error parsing yaml: %v", err)
}
if basicAuth.Username == "" || basicAuth.Password == "" {
return BasicRuleMiddleware{}, fmt.Errorf("error parsing yaml: empty username/password in %s middleware", basicAuth)
}
return *basicAuth, nil
}

132
internal/handler.go Normal file
View File

@@ -0,0 +1,132 @@
package pkg
/*
Copyright 2024 Jonas Kaninda
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import (
"encoding/json"
"github.com/gorilla/mux"
"github.com/jkaninda/goma-gateway/pkg/logger"
"net/http"
"sync"
)
// CORSHandler handles CORS headers for incoming requests
//
// Adds CORS headers to the response dynamically based on the provided headers map[string]string
func CORSHandler(cors Cors) mux.MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Set CORS headers from the cors config
//Update Cors Headers
for k, v := range cors.Headers {
w.Header().Set(k, v)
}
//Update Origin Cors Headers
if allowedOrigin(cors.Origins, r.Header.Get("Origin")) {
// Handle preflight requests (OPTIONS)
if r.Method == "OPTIONS" {
w.Header().Set(accessControlAllowOrigin, r.Header.Get("Origin"))
w.WriteHeader(http.StatusNoContent)
return
} else {
w.Header().Set(accessControlAllowOrigin, r.Header.Get("Origin"))
}
}
// Pass the request to the next handler
next.ServeHTTP(w, r)
})
}
}
// ProxyErrorHandler catches backend errors and returns a custom response
func ProxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
logger.Error("Proxy error: %v", err)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadGateway)
err = json.NewEncoder(w).Encode(map[string]interface{}{
"success": false,
"code": http.StatusBadGateway,
"message": "The service is currently unavailable. Please try again later.",
})
if err != nil {
return
}
return
}
// HealthCheckHandler handles health check of routes
func (heathRoute HealthCheckRoute) HealthCheckHandler(w http.ResponseWriter, r *http.Request) {
logger.Info("%s %s %s %s", r.Method, r.RemoteAddr, r.URL, r.UserAgent())
wg := sync.WaitGroup{}
wg.Add(len(heathRoute.Routes))
var routes []HealthCheckRouteResponse
for _, route := range heathRoute.Routes {
go func() {
if route.HealthCheck != "" {
err := HealthCheck(route.Destination + route.HealthCheck)
if err != nil {
if heathRoute.DisableRouteHealthCheckError {
routes = append(routes, HealthCheckRouteResponse{Name: route.Name, Status: "unhealthy", Error: "Route healthcheck errors disabled"})
}
routes = append(routes, HealthCheckRouteResponse{Name: route.Name, Status: "unhealthy", Error: "Error: " + err.Error()})
} else {
logger.Info("Route %s is healthy", route.Name)
routes = append(routes, HealthCheckRouteResponse{Name: route.Name, Status: "healthy", Error: ""})
}
} else {
logger.Warn("Route %s's healthCheck is undefined", route.Name)
routes = append(routes, HealthCheckRouteResponse{Name: route.Name, Status: "undefined", Error: ""})
}
defer wg.Done()
}()
}
wg.Wait() // Wait for all requests to complete
response := HealthCheckResponse{
Status: "healthy", //Goma proxy
Routes: routes, // Routes health check
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
err := json.NewEncoder(w).Encode(response)
if err != nil {
return
}
}
func (heathRoute HealthCheckRoute) HealthReadyHandler(w http.ResponseWriter, r *http.Request) {
logger.Info("%s %s %s %s", r.Method, r.RemoteAddr, r.URL, r.UserAgent())
response := HealthCheckRouteResponse{
Name: "Goma Gateway",
Status: "healthy",
Error: "",
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
err := json.NewEncoder(w).Encode(response)
if err != nil {
return
}
}
func allowedOrigin(origins []string, origin string) bool {
for _, o := range origins {
if o == origin {
return true
}
continue
}
return false
}

54
internal/healthCheck.go Normal file
View File

@@ -0,0 +1,54 @@
package pkg
/*
Copyright 2024 Jonas Kaninda
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import (
"fmt"
"github.com/jkaninda/goma-gateway/pkg/logger"
"io"
"net/http"
"net/url"
)
func HealthCheck(healthURL string) error {
healthCheckURL, err := url.Parse(healthURL)
if err != nil {
return fmt.Errorf("error parsing HealthCheck URL: %v ", err)
}
// Create a new request for the route
healthReq, err := http.NewRequest("GET", healthCheckURL.String(), nil)
if err != nil {
return fmt.Errorf("error creating HealthCheck request: %v ", err)
}
// Perform the request to the route's healthcheck
client := &http.Client{}
healthResp, err := client.Do(healthReq)
if err != nil {
logger.Error("Error performing HealthCheck request: %v ", err)
return fmt.Errorf("error performing HealthCheck request: %v ", err)
}
defer func(Body io.ReadCloser) {
err := Body.Close()
if err != nil {
}
}(healthResp.Body)
if healthResp.StatusCode >= 400 {
logger.Debug("Error performing HealthCheck request: %v ", err)
return fmt.Errorf("health check failed with status code %v", healthResp.StatusCode)
}
return nil
}

34
internal/helpers.go Normal file
View File

@@ -0,0 +1,34 @@
package pkg
/*
Copyright 2024 Jonas Kaninda.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may get a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
*/
import (
"fmt"
"github.com/jedib0t/go-pretty/v6/table"
"net/http"
)
func printRoute(routes []Route) {
t := table.NewWriter()
t.AppendHeader(table.Row{"Name", "Route", "Rewrite", "Destination"})
for _, route := range routes {
t.AppendRow(table.Row{route.Name, route.Path, route.Rewrite, route.Destination})
}
fmt.Println(t.Render())
}
func getRealIP(r *http.Request) string {
if ip := r.Header.Get("X-Real-IP"); ip != "" {
return ip
}
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
return ip
}
return r.RemoteAddr
}

View File

@@ -1,97 +0,0 @@
package logger
/*
/*
Copyright 2024 Jonas Kaninda
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import (
"fmt"
"github.com/jkaninda/goma-gateway/util"
"log"
"os"
)
type Logger struct {
msg string
args interface{}
}
// Info returns info log
func Info(msg string, args ...interface{}) {
log.SetOutput(getStd(util.GetStringEnv("GOMA_ACCESS_LOG", "/dev/stdout")))
formattedMessage := fmt.Sprintf(msg, args...)
if len(args) == 0 {
log.Printf("INFO: %s\n", msg)
} else {
log.Printf("INFO: %s\n", formattedMessage)
}
}
// Warn returns warning log
func Warn(msg string, args ...interface{}) {
log.SetOutput(getStd(util.GetStringEnv("GOMA_ACCESS_LOG", "/dev/stdout")))
formattedMessage := fmt.Sprintf(msg, args...)
if len(args) == 0 {
log.Printf("WARN: %s\n", msg)
} else {
log.Printf("WARN: %s\n", formattedMessage)
}
}
// Error error message
func Error(msg string, args ...interface{}) {
log.SetOutput(getStd(util.GetStringEnv("GOMA_ERROR_LOG", "/dev/stderr")))
formattedMessage := fmt.Sprintf(msg, args...)
if len(args) == 0 {
log.Printf("ERROR: %s\n", msg)
} else {
log.Printf("ERROR: %s\n", formattedMessage)
}
}
func Fatal(msg string, args ...interface{}) {
log.SetOutput(os.Stdout)
formattedMessage := fmt.Sprintf(msg, args...)
if len(args) == 0 {
log.Printf("ERROR: %s\n", msg)
} else {
log.Printf("ERROR: %s\n", formattedMessage)
}
os.Exit(1)
}
func Debug(msg string, args ...interface{}) {
log.SetOutput(getStd(util.GetStringEnv("GOMA_ACCESS_LOG", "/dev/stdout")))
formattedMessage := fmt.Sprintf(msg, args...)
if len(args) == 0 {
log.Printf("DEBUG: %s\n", msg)
} else {
log.Printf("DEBUG: %s\n", formattedMessage)
}
}
func getStd(out string) *os.File {
switch out {
case "/dev/stdout":
return os.Stdout
case "/dev/stderr":
return os.Stderr
case "/dev/stdin":
return os.Stdin
default:
return os.Stdout
}
}

38
internal/middleware.go Normal file
View File

@@ -0,0 +1,38 @@
package pkg
import (
"errors"
"slices"
"strings"
)
func getMiddleware(rules []string, middlewares []Middleware) (Middleware, error) {
for _, m := range middlewares {
if slices.Contains(rules, m.Name) {
return m, nil
}
continue
}
return Middleware{}, errors.New("middleware not found with name: [" + strings.Join(rules, ";") + "]")
}
func doesExist(tyName string) bool {
middlewareList := []string{BasicAuth, JWTAuth, AccessMiddleware}
if slices.Contains(middlewareList, tyName) {
return true
}
return false
}
func GetMiddleware(rule string, middlewares []Middleware) (Middleware, error) {
for _, m := range middlewares {
if strings.Contains(rule, m.Name) {
return m, nil
}
continue
}
return Middleware{}, errors.New("no middleware found with name " + rule)
}

View File

@@ -0,0 +1,99 @@
package middleware
/*
Copyright 2024 Jonas Kaninda
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import (
"encoding/json"
"fmt"
"github.com/jkaninda/goma-gateway/pkg/logger"
"github.com/jkaninda/goma-gateway/util"
"net/http"
"strings"
"time"
)
// AccessMiddleware checks if the request path is forbidden and returns 403 Forbidden
func (blockList AccessListMiddleware) AccessMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for _, block := range blockList.List {
if isPathBlocked(r.URL.Path, util.ParseURLPath(blockList.Path+block)) {
logger.Error("%s: %s access forbidden", getRealIP(r), r.URL.Path)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusForbidden)
err := json.NewEncoder(w).Encode(ProxyResponseError{
Success: false,
Code: http.StatusForbidden,
Message: fmt.Sprintf("You do not have permission to access this resource"),
})
if err != nil {
return
}
return
}
}
next.ServeHTTP(w, r)
})
}
// Helper function to determine if the request path is blocked
func isPathBlocked(requestPath, blockedPath string) bool {
// Handle exact match
if requestPath == blockedPath {
return true
}
// Handle wildcard match (e.g., /admin/* should block /admin and any subpath)
if strings.HasSuffix(blockedPath, "/*") {
basePath := strings.TrimSuffix(blockedPath, "/*")
if strings.HasPrefix(requestPath, basePath) {
return true
}
}
return false
}
// NewRateLimiter creates a new rate limiter with the specified refill rate and token capacity
func NewRateLimiter(maxTokens int, refillRate time.Duration) *TokenRateLimiter {
return &TokenRateLimiter{
tokens: maxTokens,
maxTokens: maxTokens,
refillRate: refillRate,
lastRefill: time.Now(),
}
}
// Allow checks if a request is allowed based on the current token bucket
func (rl *TokenRateLimiter) Allow() bool {
rl.mu.Lock()
defer rl.mu.Unlock()
// Refill tokens based on the time elapsed
now := time.Now()
elapsed := now.Sub(rl.lastRefill)
tokensToAdd := int(elapsed / rl.refillRate)
if tokensToAdd > 0 {
rl.tokens = min(rl.maxTokens, rl.tokens+tokensToAdd)
rl.lastRefill = now
}
// Check if there are enough tokens to allow the request
if rl.tokens > 0 {
rl.tokens--
return true
}
// Reject request if no tokens are available
return false
}

View File

@@ -0,0 +1,82 @@
package middleware
/*
* Copyright 2024 Jonas Kaninda
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
import (
"bytes"
"encoding/json"
"github.com/jkaninda/goma-gateway/pkg/logger"
"io"
"net/http"
)
func newResponseRecorder(w http.ResponseWriter) *responseRecorder {
return &responseRecorder{
ResponseWriter: w,
statusCode: http.StatusOK,
body: &bytes.Buffer{},
}
}
func (rec *responseRecorder) WriteHeader(code int) {
rec.statusCode = code
}
func (rec *responseRecorder) Write(data []byte) (int, error) {
return rec.body.Write(data)
}
// ErrorInterceptor Middleware intercepts backend errors
func (intercept InterceptErrors) ErrorInterceptor(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rec := newResponseRecorder(w)
next.ServeHTTP(rec, r)
if canIntercept(rec.statusCode, intercept.Errors) {
logger.Error("Backend error")
logger.Error("An error occurred from the backend with the status code: %d", rec.statusCode)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(rec.statusCode)
err := json.NewEncoder(w).Encode(ProxyResponseError{
Success: false,
Code: rec.statusCode,
Message: http.StatusText(rec.statusCode),
})
if err != nil {
return
}
} else {
// No error: write buffered response to client
w.WriteHeader(rec.statusCode)
_, err := io.Copy(w, rec.body)
if err != nil {
return
}
}
})
}
func canIntercept(code int, errors []int) bool {
for _, er := range errors {
if er == code {
return true
}
continue
}
return false
}

View File

@@ -0,0 +1,209 @@
package middleware
/*
Copyright 2024 Jonas Kaninda
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import (
"encoding/base64"
"encoding/json"
"github.com/jkaninda/goma-gateway/pkg/logger"
"io"
"net/http"
"net/url"
"strings"
)
// AuthMiddleware authenticate the client using JWT
//
// authorization based on the result of backend's response and continue the request when the client is authorized
func (jwtAuth JwtAuth) AuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for _, header := range jwtAuth.RequiredHeaders {
if r.Header.Get(header) == "" {
logger.Error("Proxy error, missing %s header", header)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
err := json.NewEncoder(w).Encode(ProxyResponseError{
Message: http.StatusText(http.StatusUnauthorized),
Code: http.StatusUnauthorized,
Success: false,
})
if err != nil {
return
}
return
}
}
//token := r.Header.Get("Authorization")
authURL, err := url.Parse(jwtAuth.AuthURL)
if err != nil {
logger.Error("Error parsing auth URL: %v", err)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
err = json.NewEncoder(w).Encode(ProxyResponseError{
Message: "Internal Server Error",
Code: http.StatusInternalServerError,
Success: false,
})
if err != nil {
return
}
return
}
// Create a new request for /authentication
authReq, err := http.NewRequest("GET", authURL.String(), nil)
if err != nil {
logger.Error("Proxy error creating authentication request: %v", err)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
err = json.NewEncoder(w).Encode(ProxyResponseError{
Message: "Internal Server Error",
Code: http.StatusInternalServerError,
Success: false,
})
if err != nil {
return
}
return
}
// Copy headers from the original request to the new request
for name, values := range r.Header {
for _, value := range values {
authReq.Header.Set(name, value)
}
}
// Copy cookies from the original request to the new request
for _, cookie := range r.Cookies() {
authReq.AddCookie(cookie)
}
// Perform the request to the auth service
client := &http.Client{}
authResp, err := client.Do(authReq)
if err != nil || authResp.StatusCode != http.StatusOK {
logger.Info("%s %s %s %s", r.Method, getRealIP(r), r.URL, r.UserAgent())
logger.Warn("Proxy authentication error")
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
err = json.NewEncoder(w).Encode(ProxyResponseError{
Message: "Unauthorized",
Code: http.StatusUnauthorized,
Success: false,
})
if err != nil {
return
}
return
}
defer func(Body io.ReadCloser) {
err := Body.Close()
if err != nil {
}
}(authResp.Body)
// Inject specific header tp the current request's header
// Add header to the next request from AuthRequest header, depending on your requirements
if jwtAuth.Headers != nil {
for k, v := range jwtAuth.Headers {
r.Header.Set(v, authResp.Header.Get(k))
}
}
query := r.URL.Query()
// Add query parameters to the next request from AuthRequest header, depending on your requirements
if jwtAuth.Params != nil {
for k, v := range jwtAuth.Params {
query.Set(v, authResp.Header.Get(k))
}
}
r.URL.RawQuery = query.Encode()
next.ServeHTTP(w, r)
})
}
// AuthMiddleware checks for the Authorization header and verifies the credentials
func (basicAuth AuthBasic) AuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get the Authorization header
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
logger.Error("Proxy error, missing Authorization header")
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
err := json.NewEncoder(w).Encode(ProxyResponseError{
Success: false,
Code: http.StatusUnauthorized,
Message: http.StatusText(http.StatusUnauthorized),
})
if err != nil {
return
}
return
}
// Check if the Authorization header contains "Basic" scheme
if !strings.HasPrefix(authHeader, "Basic ") {
logger.Error("Proxy error, missing Basic Authorization header")
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
err := json.NewEncoder(w).Encode(ProxyResponseError{
Success: false,
Code: http.StatusUnauthorized,
Message: http.StatusText(http.StatusUnauthorized),
})
if err != nil {
return
}
return
}
// Decode the base64 encoded username:password string
payload, err := base64.StdEncoding.DecodeString(authHeader[len("Basic "):])
if err != nil {
logger.Error("Proxy error, missing Basic Authorization header")
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
err := json.NewEncoder(w).Encode(ProxyResponseError{
Success: false,
Code: http.StatusUnauthorized,
Message: http.StatusText(http.StatusUnauthorized),
})
if err != nil {
return
}
return
}
// Split the payload into username and password
pair := strings.SplitN(string(payload), ":", 2)
if len(pair) != 2 || pair[0] != basicAuth.Username || pair[1] != basicAuth.Password {
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
err := json.NewEncoder(w).Encode(ProxyResponseError{
Success: false,
Code: http.StatusUnauthorized,
Message: http.StatusText(http.StatusUnauthorized),
})
if err != nil {
return
}
return
}
// Continue to the next handler if the authentication is successful
next.ServeHTTP(w, r)
})
}

View File

@@ -0,0 +1,95 @@
package middleware
/*
Copyright 2024 Jonas Kaninda
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import (
"encoding/json"
"github.com/gorilla/mux"
"github.com/jkaninda/goma-gateway/pkg/logger"
"net/http"
"time"
)
// RateLimitMiddleware limits request based on the number of tokens peer minutes.
func (rl *TokenRateLimiter) RateLimitMiddleware() mux.MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !rl.Allow() {
// Rate limit exceeded, return a 429 Too Many Requests response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusTooManyRequests)
err := json.NewEncoder(w).Encode(ProxyResponseError{
Success: false,
Code: http.StatusTooManyRequests,
Message: "Too many requests, API rate limit exceeded. Please try again later.",
})
if err != nil {
return
}
return
}
// Proceed to the next handler if rate limit is not exceeded
next.ServeHTTP(w, r)
})
}
}
// RateLimitMiddleware limits request based on the number of requests peer minutes.
func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
clientID := getRealIP(r)
rl.mu.Lock()
client, exists := rl.ClientMap[clientID]
if !exists || time.Now().After(client.ExpiresAt) {
client = &Client{
RequestCount: 0,
ExpiresAt: time.Now().Add(rl.Window),
}
rl.ClientMap[clientID] = client
}
client.RequestCount++
rl.mu.Unlock()
if client.RequestCount > rl.Requests {
logger.Error("Too many requests from IP: %s %s %s", clientID, r.URL, r.UserAgent())
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusTooManyRequests)
err := json.NewEncoder(w).Encode(ProxyResponseError{
Success: false,
Code: http.StatusTooManyRequests,
Message: "Too many requests, API rate limit exceeded. Please try again later.",
})
if err != nil {
return
}
return
}
// Proceed to the next handler if rate limit is not exceeded
next.ServeHTTP(w, r)
})
}
}
func getRealIP(r *http.Request) string {
if ip := r.Header.Get("X-Real-IP"); ip != "" {
return ip
}
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
return ip
}
return r.RemoteAddr
}

View File

@@ -0,0 +1,105 @@
/*
* Copyright 2024 Jonas Kaninda
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package middleware
import (
"bytes"
"net/http"
"sync"
"time"
)
// RateLimiter defines rate limit properties.
type RateLimiter struct {
Requests int
Window time.Duration
ClientMap map[string]*Client
mu sync.Mutex
}
// Client stores request count and window expiration for each client.
type Client struct {
RequestCount int
ExpiresAt time.Time
}
// NewRateLimiterWindow creates a new RateLimiter.
func NewRateLimiterWindow(requests int, window time.Duration) *RateLimiter {
return &RateLimiter{
Requests: requests,
Window: window,
ClientMap: make(map[string]*Client),
}
}
// TokenRateLimiter stores tokenRate limit
type TokenRateLimiter struct {
tokens int
maxTokens int
refillRate time.Duration
lastRefill time.Time
mu sync.Mutex
}
// ProxyResponseError represents the structure of the JSON error response
type ProxyResponseError struct {
Success bool `json:"success"`
Code int `json:"code"`
Message string `json:"message"`
}
// JwtAuth stores JWT configuration
type JwtAuth struct {
AuthURL string
RequiredHeaders []string
Headers map[string]string
Params map[string]string
}
// AuthenticationMiddleware Define struct
type AuthenticationMiddleware struct {
AuthURL string
RequiredHeaders []string
Headers map[string]string
Params map[string]string
}
type AccessListMiddleware struct {
Path string
Destination string
List []string
}
// AuthBasic contains Basic auth configuration
type AuthBasic struct {
Username string
Password string
Headers map[string]string
Params map[string]string
}
// InterceptErrors contains backend status code errors to intercept
type InterceptErrors struct {
Errors []int
}
// responseRecorder intercepts the response body and status code
type responseRecorder struct {
http.ResponseWriter
statusCode int
body *bytes.Buffer
}

123
internal/middleware_test.go Normal file
View File

@@ -0,0 +1,123 @@
package pkg
/*
Copyright 2024 Jonas Kaninda
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import (
"fmt"
"gopkg.in/yaml.v3"
"log"
"os"
"testing"
)
const MidName = "google-jwt"
var rules = []string{"fake", "jwt", "google-jwt"}
func TestMiddleware(t *testing.T) {
TestInit(t)
middlewares := []Middleware{
{
Name: "basic-auth",
Type: "basic",
Paths: []string{"/", "/admin"},
Rule: BasicRuleMiddleware{
Username: "goma",
Password: "goma",
},
},
{
Name: "forbidden path access",
Type: "access",
Paths: []string{"/", "/admin"},
Rule: BasicRuleMiddleware{
Username: "goma",
Password: "goma",
},
},
{
Name: "jwt",
Type: "jwt",
Paths: []string{"/", "/admin"},
Rule: JWTRuleMiddleware{
URL: "https://www.googleapis.com/auth/userinfo.email",
Headers: map[string]string{},
Params: map[string]string{},
},
},
}
yamlData, err := yaml.Marshal(&middlewares)
if err != nil {
t.Fatalf("Error serializing configuration %v", err.Error())
}
err = os.WriteFile(configFile, yamlData, 0644)
if err != nil {
t.Fatalf("Unable to write config file %s", err)
}
log.Printf("Config file written to %s", configFile)
}
func TestReadMiddleware(t *testing.T) {
TestMiddleware(t)
middlewares := getMiddlewares(t)
middleware, err := getMiddleware(rules, middlewares)
if err != nil {
t.Fatalf("Error searching middleware %s", err.Error())
}
switch middleware.Type {
case "basic":
log.Println("Basic auth")
basicAuth, err := getBasicAuthMiddleware(middleware.Rule)
if err != nil {
log.Fatalln("error:", err)
}
log.Printf("Username: %s and password: %s\n", basicAuth.Username, basicAuth.Password)
case "jwt":
log.Println("JWT auth")
jwt, err := getJWTMiddleware(middleware.Rule)
if err != nil {
log.Fatalln("error:", err)
}
log.Printf("JWT authentification URL is %s\n", jwt.URL)
default:
t.Errorf("Unknown middleware type %s", middleware.Type)
}
}
func TestFoundMiddleware(t *testing.T) {
middlewares := getMiddlewares(t)
middleware, err := GetMiddleware("jwt", middlewares)
if err != nil {
t.Errorf("Error getting middleware %v", err)
}
fmt.Println(middleware.Type)
}
func getMiddlewares(t *testing.T) []Middleware {
buf, err := os.ReadFile(configFile)
if err != nil {
t.Fatalf("Unable to read config file %s", configFile)
}
c := &[]Middleware{}
err = yaml.Unmarshal(buf, c)
if err != nil {
t.Fatalf("Unable to parse config file %s", configFile)
}
return *c
}

88
internal/proxy.go Normal file
View File

@@ -0,0 +1,88 @@
package pkg
/*
Copyright 2024 Jonas Kaninda
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import (
"encoding/json"
"fmt"
"github.com/jkaninda/goma-gateway/pkg/logger"
"net/http"
"net/http/httputil"
"net/url"
"strings"
)
// ProxyHandler proxies requests to the backend
func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
logger.Info("%s %s %s %s", r.Method, getRealIP(r), r.URL, r.UserAgent())
// Set CORS headers from the cors config
//Update Cors Headers
for k, v := range proxyRoute.cors.Headers {
w.Header().Set(k, v)
}
if allowedOrigin(proxyRoute.cors.Origins, r.Header.Get("Origin")) {
// Handle preflight requests (OPTIONS)
if r.Method == "OPTIONS" {
w.Header().Set(accessControlAllowOrigin, r.Header.Get("Origin"))
w.WriteHeader(http.StatusNoContent)
return
} else {
w.Header().Set(accessControlAllowOrigin, r.Header.Get("Origin"))
}
}
// Parse the target backend URL
targetURL, err := url.Parse(proxyRoute.destination)
if err != nil {
logger.Error("Error parsing backend URL: %s", err)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
err := json.NewEncoder(w).Encode(ErrorResponse{
Message: "Internal server error",
Code: http.StatusInternalServerError,
Success: false,
})
if err != nil {
return
}
return
}
// Update the headers to allow for SSL redirection
if !proxyRoute.disableXForward {
r.URL.Host = targetURL.Host
r.URL.Scheme = targetURL.Scheme
r.Header.Set("X-Forwarded-Host", r.Header.Get("Host"))
r.Header.Set("X-Forwarded-For", getRealIP(r))
r.Header.Set("X-Real-IP", getRealIP(r))
r.Host = targetURL.Host
}
// Create proxy
proxy := httputil.NewSingleHostReverseProxy(targetURL)
// Rewrite
if proxyRoute.path != "" && proxyRoute.rewrite != "" {
// Rewrite the path
if strings.HasPrefix(r.URL.Path, fmt.Sprintf("%s/", proxyRoute.path)) {
r.URL.Path = strings.Replace(r.URL.Path, fmt.Sprintf("%s/", proxyRoute.path), proxyRoute.rewrite, 1)
}
}
w.Header().Set("Proxied-By", gatewayName) //Set Server name
w.Header().Set("Server", serverName)
// Custom error handler for proxy errors
proxy.ErrorHandler = ProxyErrorHandler
proxy.ServeHTTP(w, r)
}
}

169
internal/route.go Normal file
View File

@@ -0,0 +1,169 @@
package pkg
/*
Copyright 2024 Jonas Kaninda
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import (
"github.com/gorilla/mux"
"github.com/jkaninda/goma-gateway/internal/middleware"
"github.com/jkaninda/goma-gateway/pkg/logger"
"github.com/jkaninda/goma-gateway/util"
"time"
)
// Initialize the routes
func (gatewayServer GatewayServer) Initialize() *mux.Router {
gateway := gatewayServer.gateway
middlewares := gatewayServer.middlewares
r := mux.NewRouter()
heath := HealthCheckRoute{
DisableRouteHealthCheckError: gateway.DisableRouteHealthCheckError,
Routes: gateway.Routes,
}
// Routes health check
if !gateway.DisableHealthCheckStatus {
r.HandleFunc("/healthz", heath.HealthCheckHandler).Methods("GET")
}
// Readiness
r.HandleFunc("/readyz", heath.HealthReadyHandler).Methods("GET")
if gateway.RateLimiter != 0 {
//rateLimiter := middleware.NewRateLimiter(gateway.RateLimiter, time.Minute)
limiter := middleware.NewRateLimiterWindow(gateway.RateLimiter, time.Minute) // requests per minute
// Add rate limit middleware to all routes, if defined
r.Use(limiter.RateLimitMiddleware())
}
for _, route := range gateway.Routes {
if route.Path != "" {
// Apply middlewares to route
for _, mid := range route.Middlewares {
if mid != "" {
// Get Access middleware if it does exist
accessMiddleware, err := getMiddleware([]string{mid}, middlewares)
if err != nil {
logger.Error("Error: %v", err.Error())
} else {
// Apply access middleware
if accessMiddleware.Type == AccessMiddleware {
blM := middleware.AccessListMiddleware{
Path: route.Path,
List: accessMiddleware.Paths,
}
r.Use(blM.AccessMiddleware)
}
}
// Get route authentication middleware if it does exist
rMiddleware, err := getMiddleware([]string{mid}, middlewares)
if err != nil {
//Error: middleware not found
logger.Error("Error: %v", err.Error())
} else {
for _, midPath := range rMiddleware.Paths {
proxyRoute := ProxyRoute{
path: route.Path,
rewrite: route.Rewrite,
destination: route.Destination,
disableXForward: route.DisableHeaderXForward,
cors: route.Cors,
}
secureRouter := r.PathPrefix(util.ParseRoutePath(route.Path, midPath)).Subrouter()
//Check Authentication middleware
switch rMiddleware.Type {
case BasicAuth:
basicAuth, err := getBasicAuthMiddleware(rMiddleware.Rule)
if err != nil {
logger.Error("Error: %s", err.Error())
} else {
amw := middleware.AuthBasic{
Username: basicAuth.Username,
Password: basicAuth.Password,
Headers: nil,
Params: nil,
}
// Apply JWT authentication middleware
secureRouter.Use(amw.AuthMiddleware)
secureRouter.Use(CORSHandler(route.Cors))
secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler
secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler
}
case JWTAuth:
jwt, err := getJWTMiddleware(rMiddleware.Rule)
if err != nil {
logger.Error("Error: %s", err.Error())
} else {
amw := middleware.JwtAuth{
AuthURL: jwt.URL,
RequiredHeaders: jwt.RequiredHeaders,
Headers: jwt.Headers,
Params: jwt.Params,
}
// Apply JWT authentication middleware
secureRouter.Use(amw.AuthMiddleware)
secureRouter.Use(CORSHandler(route.Cors))
secureRouter.PathPrefix("/").Handler(proxyRoute.ProxyHandler()) // Proxy handler
secureRouter.PathPrefix("").Handler(proxyRoute.ProxyHandler()) // Proxy handler
}
case "OAuth":
logger.Error("OAuth is not yet implemented")
logger.Info("Auth middleware ignored")
default:
if !doesExist(rMiddleware.Type) {
logger.Error("Unknown middleware type %s", rMiddleware.Type)
}
}
}
}
} else {
logger.Error("Error, middleware path is empty")
logger.Error("Middleware ignored")
}
}
proxyRoute := ProxyRoute{
path: route.Path,
rewrite: route.Rewrite,
destination: route.Destination,
disableXForward: route.DisableHeaderXForward,
cors: route.Cors,
}
router := r.PathPrefix(route.Path).Subrouter()
// Apply route Cors
router.Use(CORSHandler(route.Cors))
if route.Host != "" {
router.Host(route.Host).PathPrefix("").Handler(proxyRoute.ProxyHandler())
} else {
router.PathPrefix("").Handler(proxyRoute.ProxyHandler())
}
} else {
logger.Error("Error, path is empty in route %s", route.Name)
logger.Debug("Route path ignored: %s", route.Path)
}
}
// Apply global Cors middlewares
r.Use(CORSHandler(gateway.Cors)) // Apply CORS middleware
// Apply errorInterceptor middleware
interceptErrors := middleware.InterceptErrors{
Errors: gateway.InterceptErrors,
}
r.Use(interceptErrors.ErrorInterceptor)
return r
}

70
internal/server.go Normal file
View File

@@ -0,0 +1,70 @@
package pkg
/*
Copyright 2024 Jonas Kaninda
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import (
"context"
"fmt"
"github.com/jkaninda/goma-gateway/pkg/logger"
"net/http"
"os"
"sync"
"time"
)
func (gatewayServer GatewayServer) Start(ctx context.Context) error {
logger.Info("Initializing routes...")
route := gatewayServer.Initialize()
logger.Info("Initializing routes...done")
srv := &http.Server{
Addr: gatewayServer.gateway.ListenAddr,
WriteTimeout: time.Second * time.Duration(gatewayServer.gateway.WriteTimeout),
ReadTimeout: time.Second * time.Duration(gatewayServer.gateway.ReadTimeout),
IdleTimeout: time.Second * time.Duration(gatewayServer.gateway.IdleTimeout),
Handler: route, // Pass our instance of gorilla/mux in.
}
if !gatewayServer.gateway.DisableDisplayRouteOnStart {
printRoute(gatewayServer.gateway.Routes)
}
// Set KeepAlive
srv.SetKeepAlivesEnabled(!gatewayServer.gateway.DisableKeepAlive)
go func() {
logger.Info("Started Goma Gateway server on %v", gatewayServer.gateway.ListenAddr)
if err := srv.ListenAndServe(); err != nil {
logger.Error("Error starting Goma Gateway server: %v", err)
}
}()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
<-ctx.Done()
shutdownCtx := context.Background()
shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second)
defer cancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
_, err := fmt.Fprintf(os.Stderr, "error shutting down Goma Gateway server: %s\n", err)
if err != nil {
return
}
}
}()
wg.Wait()
return nil
}

52
internal/server_test.go Normal file
View File

@@ -0,0 +1,52 @@
package pkg
import (
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
)
const testPath = "./tests"
var configFile = filepath.Join(testPath, "goma.yml")
func TestInit(t *testing.T) {
err := os.MkdirAll(testPath, os.ModePerm)
if err != nil {
t.Error(err)
}
}
func TestStart(t *testing.T) {
TestInit(t)
initConfig(configFile)
g := GatewayServer{}
gatewayServer, err := g.Config(configFile)
if err != nil {
t.Error(err)
}
route := gatewayServer.Initialize()
route.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) {
_, err := rw.Write([]byte("Hello Goma Proxy"))
if err != nil {
t.Fatalf("Failed writing HTTP response: %v", err)
}
})
assertResponseBody := func(t *testing.T, s *httptest.Server, expectedBody string) {
resp, err := s.Client().Get(s.URL)
if err != nil {
t.Fatalf("unexpected error getting from server: %v", err)
}
if resp.StatusCode != 200 {
t.Fatalf("expected a status code of 200, got %v", resp.StatusCode)
}
}
t.Run("httpServer", func(t *testing.T) {
s := httptest.NewServer(route)
defer s.Close()
assertResponseBody(t, s, "Hello Goma Proxy")
})
}

214
internal/types.go Normal file
View File

@@ -0,0 +1,214 @@
/*
* Copyright 2024 Jonas Kaninda
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package pkg
import (
"context"
"github.com/gorilla/mux"
)
type Config struct {
file string
}
type BasicRuleMiddleware struct {
Username string `yaml:"username"`
Password string `yaml:"password"`
}
type Cors struct {
// Cors Allowed origins,
//e.g:
//
// - http://localhost:80
//
// - https://example.com
Origins []string `yaml:"origins"`
//
//e.g:
//
//Access-Control-Allow-Origin: '*'
//
// Access-Control-Allow-Methods: 'GET, POST, PUT, DELETE, OPTIONS'
//
// Access-Control-Allow-Cors: 'Content-Type, Authorization'
Headers map[string]string `yaml:"headers"`
}
// JWTRuleMiddleware authentication using HTTP GET method
//
// JWTRuleMiddleware contains the authentication details
type JWTRuleMiddleware struct {
// URL contains the authentication URL, it supports HTTP GET method only.
URL string `yaml:"url"`
// RequiredHeaders , contains required before sending request to the backend.
RequiredHeaders []string `yaml:"requiredHeaders"`
// Headers Add header to the backend from Authentication request's header, depending on your requirements.
// Key is Http's response header Key, and value is the backend Request's header Key.
// In case you want to get headers from Authentication service and inject them to backend request's headers.
Headers map[string]string `yaml:"headers"`
// Params same as Headers, contains the request params.
//
// Gets authentication headers from authentication request and inject them as request params to the backend.
//
// Key is Http's response header Key, and value is the backend Request's request param Key.
//
// In case you want to get headers from Authentication service and inject them to next request's params.
//
//e.g: Header X-Auth-UserId to query userId
Params map[string]string `yaml:"params"`
}
type RateLimiter struct {
// ipBased, tokenBased
Type string `yaml:"type"`
Rate float64 `yaml:"rate"`
Rule int `yaml:"rule"`
}
type AccessRuleMiddleware struct {
ResponseCode int `yaml:"responseCode"` // HTTP Response code
}
// Middleware defined the route middleware
type Middleware struct {
//Path contains the name of middleware and must be unique
Name string `yaml:"name"`
// Type contains authentication types
//
// basic, jwt, auth0, rateLimit, access
Type string `yaml:"type"` // Middleware type [basic, jwt, auth0, rateLimit, access]
Paths []string `yaml:"paths"` // Protected paths
// Rule contains rule type of
Rule interface{} `yaml:"rule"` // Middleware rule
}
type MiddlewareName struct {
name string `yaml:"name"`
}
// Route defines gateway route
type Route struct {
// Name defines route name
Name string `yaml:"name"`
//Host Domain/host based request routing
Host string `yaml:"host"`
// Path defines route path
Path string `yaml:"path"`
// Rewrite rewrites route path to desired path
//
// E.g. /cart to / => It will rewrite /cart path to /
Rewrite string `yaml:"rewrite"`
// Destination Defines backend URL
Destination string `yaml:"destination"`
// Cors contains the route cors headers
Cors Cors `yaml:"cors"`
// DisableHeaderXForward Disable X-forwarded header.
//
// [X-Forwarded-Host, X-Forwarded-For, Host, Scheme ]
//
// It will not match the backend route
DisableHeaderXForward bool `yaml:"disableHeaderXForward"`
// HealthCheck Defines the backend is health check PATH
HealthCheck string `yaml:"healthCheck"`
// InterceptErrors intercepts backend errors based on the status codes
//
// Eg: [ 403, 405, 500 ]
InterceptErrors []int `yaml:"interceptErrors"`
// Middlewares Defines route middleware from Middleware names
Middlewares []string `yaml:"middlewares"`
}
// Gateway contains Goma Proxy Gateway's configs
type Gateway struct {
// ListenAddr Defines the server listenAddr
//
//e.g: localhost:8080
ListenAddr string `yaml:"listenAddr" env:"GOMA_LISTEN_ADDR, overwrite"`
// WriteTimeout defines proxy write timeout
WriteTimeout int `yaml:"writeTimeout" env:"GOMA_WRITE_TIMEOUT, overwrite"`
// ReadTimeout defines proxy read timeout
ReadTimeout int `yaml:"readTimeout" env:"GOMA_READ_TIMEOUT, overwrite"`
// IdleTimeout defines proxy idle timeout
IdleTimeout int `yaml:"idleTimeout" env:"GOMA_IDLE_TIMEOUT, overwrite"`
// RateLimiter Defines number of request peer minute
RateLimiter int `yaml:"rateLimiter" env:"GOMA_RATE_LIMITER, overwrite"`
AccessLog string `yaml:"accessLog" env:"GOMA_ACCESS_LOG, overwrite"`
ErrorLog string `yaml:"errorLog" env:"GOMA_ERROR_LOG=, overwrite"`
// DisableHealthCheckStatus enable and disable routes health check
DisableHealthCheckStatus bool `yaml:"disableHealthCheckStatus"`
// DisableRouteHealthCheckError allows enabling and disabling backend healthcheck errors
DisableRouteHealthCheckError bool `yaml:"disableRouteHealthCheckError"`
//Disable allows enabling and disabling displaying routes on start
DisableDisplayRouteOnStart bool `yaml:"disableDisplayRouteOnStart"`
// DisableKeepAlive allows enabling and disabling KeepALive server
DisableKeepAlive bool `yaml:"disableKeepAlive"`
// InterceptErrors holds the status codes to intercept the error from backend
InterceptErrors []int `yaml:"interceptErrors"`
// Cors holds proxy global cors
Cors Cors `yaml:"cors"`
// Routes holds proxy routes
Routes []Route `yaml:"routes"`
}
type GatewayConfig struct {
// GatewayConfig holds Gateway config
GatewayConfig Gateway `yaml:"gateway"`
// Middlewares holds proxy middlewares
Middlewares []Middleware `yaml:"middlewares"`
}
// ErrorResponse represents the structure of the JSON error response
type ErrorResponse struct {
Success bool `json:"success"`
Code int `json:"code"`
Message string `json:"message"`
}
type GatewayServer struct {
ctx context.Context
gateway Gateway
middlewares []Middleware
}
type ProxyRoute struct {
path string
rewrite string
destination string
cors Cors
disableXForward bool
}
type RoutePath struct {
route Route
path string
rules []string
middlewares []Middleware
router *mux.Router
}
type HealthCheckRoute struct {
DisableRouteHealthCheckError bool
Routes []Route
}
// HealthCheckResponse represents the health check response structure
type HealthCheckResponse struct {
Status string `json:"status"`
Routes []HealthCheckRouteResponse `json:"routes"`
}
// HealthCheckRouteResponse represents the health check response for a route
type HealthCheckRouteResponse struct {
Name string `json:"name"`
Status string `json:"status"`
Error string `json:"error"`
}

10
internal/var.go Normal file
View File

@@ -0,0 +1,10 @@
package pkg
const ConfigFile = "/config/goma.yml" // Default configuration file
const accessControlAllowOrigin = "Access-Control-Allow-Origin" // Cors
const serverName = "Goma"
const gatewayName = "Goma Gateway"
const AccessMiddleware = "access" // access middleware
const BasicAuth = "basic" // basic authentication middleware
const JWTAuth = "jwt" // JWT authentication middleware
const OAuth = "OAuth" // OAuth authentication middleware