From cb31faf65fb0f0c6e8ddf50a0e8fa03214606cff Mon Sep 17 00:00:00 2001 From: Jonas Kaninda Date: Fri, 15 Nov 2024 08:19:22 +0100 Subject: [PATCH] fix: routes health check --- internal/handler.go | 7 ++++--- internal/route.go | 7 ++++++- internal/server_test.go | 10 ++++++++++ util/helpers.go | 14 ++++++++++++++ 4 files changed, 34 insertions(+), 4 deletions(-) diff --git a/internal/handler.go b/internal/handler.go index 1bbe058..107b4bf 100644 --- a/internal/handler.go +++ b/internal/handler.go @@ -67,12 +67,12 @@ func ProxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) { // HealthCheckHandler handles health check of routes func (heathRoute HealthCheckRoute) HealthCheckHandler(w http.ResponseWriter, r *http.Request) { logger.Debug("%s %s %s %s", r.Method, r.RemoteAddr, r.URL, r.UserAgent()) + healthRoutes := healthCheckRoutes(heathRoute.Routes) wg := sync.WaitGroup{} - wg.Add(len(heathRoute.Routes)) + wg.Add(len(healthRoutes)) var routes []HealthCheckRouteResponse - for _, health := range healthCheckRoutes(heathRoute.Routes) { + for _, health := range healthRoutes { go func() { - defer wg.Done() err := health.Check() if err != nil { if heathRoute.DisableRouteHealthCheckError { @@ -83,6 +83,7 @@ func (heathRoute HealthCheckRoute) HealthCheckHandler(w http.ResponseWriter, r * logger.Debug("Route %s is healthy", health.Name) routes = append(routes, HealthCheckRouteResponse{Name: health.Name, Status: "healthy", Error: ""}) } + defer wg.Done() }() diff --git a/internal/route.go b/internal/route.go index 3c3655b..dd9c077 100644 --- a/internal/route.go +++ b/internal/route.go @@ -236,10 +236,15 @@ func (gatewayServer GatewayServer) Initialize() *mux.Router { logger.Info("Block common exploits enabled") router.Use(middleware.BlockExploitsMiddleware) } + id := string(rune(rIndex)) + if len(route.Name) != 0 { + // Use route name as ID + id = util.Slug(route.Name) + } // Apply route rate limit if route.RateLimit > 0 { rateLimit := middleware.RateLimit{ - Id: string(rune(rIndex)), // Use route index as ID + Id: id, // Use route index as ID Requests: route.RateLimit, Window: time.Minute, // requests per minute Origins: route.Cors.Origins, diff --git a/internal/server_test.go b/internal/server_test.go index d91e8fd..2c57f48 100644 --- a/internal/server_test.go +++ b/internal/server_test.go @@ -54,6 +54,15 @@ func TestStart(t *testing.T) { t.Fatalf("expected a status code of 200, got %v", resp.StatusCode) } } + assertRoutesResponseBody := func(t *testing.T, s *httptest.Server) { + resp, err := s.Client().Get(s.URL + "/health/routes") + if err != nil { + t.Fatalf("unexpected error getting from server: %v", err) + } + if resp.StatusCode != 200 { + t.Fatalf("expected a status code of 200, got %v", resp.StatusCode) + } + } ctx := context.Background() go func() { err = gatewayServer.Start(ctx) @@ -67,6 +76,7 @@ func TestStart(t *testing.T) { s := httptest.NewServer(route) defer s.Close() assertResponseBody(t, s) + assertRoutesResponseBody(t, s) }) ctx.Done() } diff --git a/util/helpers.go b/util/helpers.go index c896cc5..f3aaf51 100644 --- a/util/helpers.go +++ b/util/helpers.go @@ -142,3 +142,17 @@ func ParseDuration(durationStr string) (time.Duration, error) { } return duration, nil } + +func Slug(text string) string { + // Convert to lowercase + text = strings.ToLower(text) + + // Replace spaces and special characters with hyphens + re := regexp.MustCompile(`[^\w]+`) + text = re.ReplaceAllString(text, "-") + + // Remove leading and trailing hyphens + text = strings.Trim(text, "-") + + return text +}