Merge pull request #4 from jkaninda/refactor

Refactor
This commit is contained in:
2024-10-28 02:45:01 +01:00
committed by GitHub
6 changed files with 39 additions and 21 deletions

3
.gitignore vendored
View File

@@ -10,4 +10,5 @@ goma
bin bin
Makefile Makefile
NOTES.md NOTES.md
tests tests
configs

View File

@@ -18,8 +18,11 @@ package cmd
import ( import (
"context" "context"
"fmt"
"github.com/common-nighthawk/go-figure"
"github.com/jkaninda/goma-gateway/internal/logger" "github.com/jkaninda/goma-gateway/internal/logger"
"github.com/jkaninda/goma-gateway/pkg" "github.com/jkaninda/goma-gateway/pkg"
"github.com/jkaninda/goma-gateway/util"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@@ -27,7 +30,7 @@ var ServerCmd = &cobra.Command{
Use: "server", Use: "server",
Short: "Start server", Short: "Start server",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
pkg.Intro() intro()
configFile, _ := cmd.Flags().GetString("config") configFile, _ := cmd.Flags().GetString("config")
if configFile == "" { if configFile == "" {
configFile = pkg.GetConfigPaths() configFile = pkg.GetConfigPaths()
@@ -49,3 +52,10 @@ var ServerCmd = &cobra.Command{
func init() { func init() {
ServerCmd.Flags().StringP("config", "", "", "Goma config file") ServerCmd.Flags().StringP("config", "", "", "Goma config file")
} }
func intro() {
nameFigure := figure.NewFigure("Goma", "", true)
nameFigure.Print()
fmt.Printf("Version: %s\n", util.FullVersion())
fmt.Println("Copyright (c) 2024 Jonas Kaninda")
fmt.Println("Starting Goma Gateway server...")
}

View File

@@ -88,7 +88,7 @@ func (heathRoute HealthCheckRoute) HealthCheckHandler(w http.ResponseWriter, r *
continue continue
} }
} else { } else {
logger.Error("Route %s's healthCheck is undefined", route.Name) logger.Warn("Route %s's healthCheck is undefined", route.Name)
routes = append(routes, HealthCheckRouteResponse{Name: route.Name, Status: "undefined", Error: ""}) routes = append(routes, HealthCheckRouteResponse{Name: route.Name, Status: "undefined", Error: ""})
continue continue

View File

@@ -11,18 +11,10 @@ You may get a copy of the License at
*/ */
import ( import (
"fmt" "fmt"
"github.com/common-nighthawk/go-figure"
"github.com/jedib0t/go-pretty/v6/table" "github.com/jedib0t/go-pretty/v6/table"
"github.com/jkaninda/goma-gateway/util" "net/http"
) )
func Intro() {
nameFigure := figure.NewFigure("Goma", "", true)
nameFigure.Print()
fmt.Printf("Version: %s\n", util.FullVersion())
fmt.Println("Copyright (c) 2024 Jonas Kaninda")
fmt.Println("Starting Goma Gateway server...")
}
func printRoute(routes []Route) { func printRoute(routes []Route) {
t := table.NewWriter() t := table.NewWriter()
t.AppendHeader(table.Row{"Name", "Route", "Rewrite", "Destination"}) t.AppendHeader(table.Row{"Name", "Route", "Rewrite", "Destination"})
@@ -31,3 +23,12 @@ func printRoute(routes []Route) {
} }
fmt.Println(t.Render()) fmt.Println(t.Render())
} }
func getRealIP(r *http.Request) string {
if ip := r.Header.Get("X-Real-IP"); ip != "" {
return ip
}
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
return ip
}
return r.RemoteAddr
}

View File

@@ -52,11 +52,7 @@ func (rl *TokenRateLimiter) RateLimitMiddleware() mux.MiddlewareFunc {
func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc { func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
clientID := getRealIP(r)
//TODO:
clientID := r.RemoteAddr
logger.Info(clientID)
rl.mu.Lock() rl.mu.Lock()
client, exists := rl.ClientMap[clientID] client, exists := rl.ClientMap[clientID]
if !exists || time.Now().After(client.ExpiresAt) { if !exists || time.Now().After(client.ExpiresAt) {
@@ -70,6 +66,7 @@ func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc {
rl.mu.Unlock() rl.mu.Unlock()
if client.RequestCount > rl.Requests { if client.RequestCount > rl.Requests {
logger.Warn("Too many request from IP: %s %s %s", clientID, r.URL, r.UserAgent())
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusTooManyRequests) w.WriteHeader(http.StatusTooManyRequests)
err := json.NewEncoder(w).Encode(ProxyResponseError{ err := json.NewEncoder(w).Encode(ProxyResponseError{
@@ -82,9 +79,17 @@ func (rl *RateLimiter) RateLimitMiddleware() mux.MiddlewareFunc {
} }
return return
} }
// Proceed to the next handler if rate limit is not exceeded // Proceed to the next handler if rate limit is not exceeded
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}) })
} }
} }
func getRealIP(r *http.Request) string {
if ip := r.Header.Get("X-Real-IP"); ip != "" {
return ip
}
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
return ip
}
return r.RemoteAddr
}

View File

@@ -36,7 +36,8 @@ type ProxyRoute struct {
// ProxyHandler proxies requests to the backend // ProxyHandler proxies requests to the backend
func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc { func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
logger.Info("%s %s %s %s", r.Method, r.RemoteAddr, r.URL, r.UserAgent()) realIP := getRealIP(r)
logger.Info("%s %s %s %s", r.Method, realIP, r.URL, r.UserAgent())
// Set CORS headers from the cors config // Set CORS headers from the cors config
//Update Cors Headers //Update Cors Headers
for k, v := range proxyRoute.cors.Headers { for k, v := range proxyRoute.cors.Headers {
@@ -76,8 +77,8 @@ func (proxyRoute ProxyRoute) ProxyHandler() http.HandlerFunc {
r.URL.Host = targetURL.Host r.URL.Host = targetURL.Host
r.URL.Scheme = targetURL.Scheme r.URL.Scheme = targetURL.Scheme
r.Header.Set("X-Forwarded-Host", r.Header.Get("Host")) r.Header.Set("X-Forwarded-Host", r.Header.Get("Host"))
r.Header.Set("X-Forwarded-For", r.RemoteAddr) r.Header.Set("X-Forwarded-For", realIP)
r.Header.Set("X-Real-IP", r.RemoteAddr) r.Header.Set("X-Real-IP", realIP)
r.Host = targetURL.Host r.Host = targetURL.Host
} }
// Create proxy // Create proxy