/* * Copyright 2024 Jonas Kaninda * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package middlewares import ( "fmt" "github.com/jkaninda/goma-gateway/pkg/logger" "net/http" "regexp" ) // BlockExploitsMiddleware Middleware to block common exploits func BlockExploitsMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Patterns to detect SQL injection attempts sqlInjectionPattern := regexp.MustCompile(sqlPatterns) // Pattern to detect path traversal attempts pathTraversalPattern := regexp.MustCompile(traversalPatterns) // Pattern to detect simple XSS attempts xssPattern := regexp.MustCompile(xssPatterns) // Check query strings if sqlInjectionPattern.MatchString(r.URL.RawQuery) || pathTraversalPattern.MatchString(r.URL.Path) || xssPattern.MatchString(r.URL.RawQuery) { logger.Error("%s: %s Forbidden - Potential exploit detected", getRealIP(r), r.URL.Path) RespondWithError(w, http.StatusForbidden, fmt.Sprintf("%d Forbidden", http.StatusForbidden)) return } // Check form data (for POST requests) if r.Method == http.MethodPost { if err := r.ParseForm(); err == nil { for _, values := range r.Form { for _, value := range values { if sqlInjectionPattern.MatchString(value) || xssPattern.MatchString(value) { logger.Error("%s: %s %s Forbidden - Potential exploit detected", getRealIP(r), r.Method, r.URL.Path) RespondWithError(w, http.StatusForbidden, fmt.Sprintf("%d Forbidden", http.StatusForbidden)) return } } } } } // Pass to the next handler if no exploit patterns were detected next.ServeHTTP(w, r) }) }