diff --git a/pkg/api/v1/workload_service.go b/pkg/api/v1/workload_service.go index 077c4c486f..e95277bb27 100644 --- a/pkg/api/v1/workload_service.go +++ b/pkg/api/v1/workload_service.go @@ -61,8 +61,8 @@ func NewWorkloadService( // CreateWorkloadFromRequest creates a workload from a request func (s *WorkloadService) CreateWorkloadFromRequest(ctx context.Context, req *createRequest) (*runner.RunConfig, error) { - // Build the full run config - runConfig, err := s.BuildFullRunConfig(ctx, req) + // Build the full run config (no existing port, so pass 0) + runConfig, err := s.BuildFullRunConfig(ctx, req, 0) if err != nil { return nil, err } @@ -83,9 +83,15 @@ func (s *WorkloadService) CreateWorkloadFromRequest(ctx context.Context, req *cr } // UpdateWorkloadFromRequest updates a workload from a request -func (s *WorkloadService) UpdateWorkloadFromRequest(ctx context.Context, name string, req *createRequest) (*runner.RunConfig, error) { //nolint:lll +func (s *WorkloadService) UpdateWorkloadFromRequest(ctx context.Context, name string, req *createRequest, existingPort int) (*runner.RunConfig, error) { //nolint:lll + // If ProxyPort is 0, reuse the existing port + if req.ProxyPort == 0 && existingPort > 0 { + req.ProxyPort = existingPort + logger.Debugf("Reusing existing port %d for workload %s", existingPort, name) + } + // Build the full run config - runConfig, err := s.BuildFullRunConfig(ctx, req) + runConfig, err := s.BuildFullRunConfig(ctx, req, existingPort) if err != nil { return nil, fmt.Errorf("failed to build workload config: %w", err) } @@ -101,7 +107,9 @@ func (s *WorkloadService) UpdateWorkloadFromRequest(ctx context.Context, name st // BuildFullRunConfig builds a complete RunConfig // //nolint:gocyclo // TODO: refactor this into shorter functions -func (s *WorkloadService) BuildFullRunConfig(ctx context.Context, req *createRequest) (*runner.RunConfig, error) { +func (s *WorkloadService) BuildFullRunConfig( + ctx context.Context, req *createRequest, existingPort int, +) (*runner.RunConfig, error) { // Default proxy mode to streamable-http if not specified (SSE is deprecated) if !types.IsValidProxyMode(req.ProxyMode) { if req.ProxyMode == "" { @@ -248,6 +256,11 @@ func (s *WorkloadService) BuildFullRunConfig(ctx context.Context, req *createReq runner.WithTelemetryConfigFromFlags("", false, false, false, "", 0.0, nil, false, nil), } + // Add existing port if provided (for update operations) + if existingPort > 0 { + options = append(options, runner.WithExistingPort(existingPort)) + } + // Determine transport type transportType := "streamable-http" if req.Transport != "" { diff --git a/pkg/api/v1/workloads.go b/pkg/api/v1/workloads.go index e5eabdec4c..8c7d40c922 100644 --- a/pkg/api/v1/workloads.go +++ b/pkg/api/v1/workloads.go @@ -338,8 +338,8 @@ func (s *WorkloadRoutes) updateWorkload(w http.ResponseWriter, r *http.Request) return } - // Check if workload exists - _, err := s.workloadManager.GetWorkload(ctx, name) + // Check if workload exists and get its current port + existingWorkload, err := s.workloadManager.GetWorkload(ctx, name) if err != nil { logger.Errorf("Failed to get workload: %v", err) http.Error(w, "Workload not found", http.StatusNotFound) @@ -352,7 +352,7 @@ func (s *WorkloadRoutes) updateWorkload(w http.ResponseWriter, r *http.Request) Name: name, // Use the name from URL path, not from request body } - runConfig, err := s.workloadService.UpdateWorkloadFromRequest(ctx, name, &createReq) + runConfig, err := s.workloadService.UpdateWorkloadFromRequest(ctx, name, &createReq, existingWorkload.Port) if err != nil { // Error messages already logged in UpdateWorkloadFromRequest if errors.Is(err, retriever.ErrImageNotFound) { diff --git a/pkg/api/v1/workloads_test.go b/pkg/api/v1/workloads_test.go index 38adf2d482..2548aa4f7c 100644 --- a/pkg/api/v1/workloads_test.go +++ b/pkg/api/v1/workloads_test.go @@ -435,6 +435,154 @@ func TestUpdateWorkload(t *testing.T) { } } +// TestUpdateWorkload_PortReuse tests the port reuse logic when editing workloads +func TestUpdateWorkload_PortReuse(t *testing.T) { + t.Parallel() + + logger.Initialize() + + tests := []struct { + name string + workloadName string + requestBody string + existingPort int + setupMock func(*testing.T, *workloadsmocks.MockManager, *runtimemocks.MockRuntime, *groupsmocks.MockManager) + expectedStatus int + expectedBody string + description string + }{ + { + name: "Edit with port=0 should reuse existing port", + workloadName: "test-workload", + requestBody: `{"image": "test-image", "proxy_port": 0}`, + existingPort: 8080, + setupMock: func(t *testing.T, wm *workloadsmocks.MockManager, _ *runtimemocks.MockRuntime, gm *groupsmocks.MockManager) { + t.Helper() + wm.EXPECT().GetWorkload(gomock.Any(), "test-workload"). + Return(core.Workload{Name: "test-workload", Port: 8080}, nil) + gm.EXPECT().Exists(gomock.Any(), "default").Return(true, nil) + wm.EXPECT().UpdateWorkload(gomock.Any(), "test-workload", gomock.Any()). + DoAndReturn(func(_ context.Context, _ string, runConfig *runner.RunConfig) (*errgroup.Group, error) { + assert.Equal(t, 8080, runConfig.Port, "Port should be reused from existing workload") + return &errgroup.Group{}, nil + }) + }, + expectedStatus: http.StatusOK, + expectedBody: "test-workload", + description: "When proxy_port is 0, the existing port should be reused", + }, + { + name: "Edit with explicit port should use that port", + workloadName: "test-workload", + requestBody: `{"image": "test-image", "proxy_port": 9090}`, + existingPort: 8080, + setupMock: func(t *testing.T, wm *workloadsmocks.MockManager, _ *runtimemocks.MockRuntime, gm *groupsmocks.MockManager) { + t.Helper() + wm.EXPECT().GetWorkload(gomock.Any(), "test-workload"). + Return(core.Workload{Name: "test-workload", Port: 8080}, nil) + gm.EXPECT().Exists(gomock.Any(), "default").Return(true, nil) + wm.EXPECT().UpdateWorkload(gomock.Any(), "test-workload", gomock.Any()). + DoAndReturn(func(_ context.Context, _ string, runConfig *runner.RunConfig) (*errgroup.Group, error) { + assert.Equal(t, 9090, runConfig.Port, "Port should be set to explicitly requested port") + return &errgroup.Group{}, nil + }) + }, + expectedStatus: http.StatusOK, + expectedBody: "test-workload", + description: "When an explicit port is provided, it should be used instead of reusing", + }, + { + name: "Edit with same port should skip validation", + workloadName: "test-workload", + requestBody: `{"image": "test-image", "proxy_port": 8080}`, + existingPort: 8080, + setupMock: func(t *testing.T, wm *workloadsmocks.MockManager, _ *runtimemocks.MockRuntime, gm *groupsmocks.MockManager) { + t.Helper() + wm.EXPECT().GetWorkload(gomock.Any(), "test-workload"). + Return(core.Workload{Name: "test-workload", Port: 8080}, nil) + gm.EXPECT().Exists(gomock.Any(), "default").Return(true, nil) + wm.EXPECT().UpdateWorkload(gomock.Any(), "test-workload", gomock.Any()). + DoAndReturn(func(_ context.Context, _ string, runConfig *runner.RunConfig) (*errgroup.Group, error) { + assert.Equal(t, 8080, runConfig.Port, "Port should remain the same") + return &errgroup.Group{}, nil + }) + }, + expectedStatus: http.StatusOK, + expectedBody: "test-workload", + description: "When reusing the same port, validation should be skipped", + }, + { + name: "Edit with no port specified should default to existing", + workloadName: "test-workload", + requestBody: `{"image": "test-image"}`, + existingPort: 8080, + setupMock: func(t *testing.T, wm *workloadsmocks.MockManager, _ *runtimemocks.MockRuntime, gm *groupsmocks.MockManager) { + t.Helper() + wm.EXPECT().GetWorkload(gomock.Any(), "test-workload"). + Return(core.Workload{Name: "test-workload", Port: 8080}, nil) + gm.EXPECT().Exists(gomock.Any(), "default").Return(true, nil) + wm.EXPECT().UpdateWorkload(gomock.Any(), "test-workload", gomock.Any()). + DoAndReturn(func(_ context.Context, _ string, runConfig *runner.RunConfig) (*errgroup.Group, error) { + assert.Equal(t, 8080, runConfig.Port, "Port should default to existing port") + return &errgroup.Group{}, nil + }) + }, + expectedStatus: http.StatusOK, + expectedBody: "test-workload", + description: "When no port is specified in request, existing port should be reused", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockWorkloadManager := workloadsmocks.NewMockManager(ctrl) + mockRuntime := runtimemocks.NewMockRuntime(ctrl) + mockGroupManager := groupsmocks.NewMockManager(ctrl) + tt.setupMock(t, mockWorkloadManager, mockRuntime, mockGroupManager) + + mockRetriever := makeMockRetriever(t, + "test-image", + "test-image", + ®types.ImageMetadata{Image: "test-image"}, + nil, + ) + + routes := &WorkloadRoutes{ + workloadManager: mockWorkloadManager, + containerRuntime: mockRuntime, + groupManager: mockGroupManager, + debugMode: false, + workloadService: &WorkloadService{ + groupManager: mockGroupManager, + workloadManager: mockWorkloadManager, + containerRuntime: mockRuntime, + imageRetriever: mockRetriever, + appConfig: &config.Config{}, + }, + } + + req := httptest.NewRequest("POST", "/api/v1beta/workloads/"+tt.workloadName+"/edit", + strings.NewReader(tt.requestBody)) + req.Header.Set("Content-Type", "application/json") + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("name", tt.workloadName) + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + w := httptest.NewRecorder() + routes.updateWorkload(w, req) + + assert.Equal(t, tt.expectedStatus, w.Code, tt.description) + assert.Contains(t, w.Body.String(), tt.expectedBody, tt.description) + }) + } +} + func makeMockRetriever( t *testing.T, expectedServerOrImage string, diff --git a/pkg/runner/config.go b/pkg/runner/config.go index 248bbe40fa..10acba04d3 100644 --- a/pkg/runner/config.go +++ b/pkg/runner/config.go @@ -164,6 +164,10 @@ type RunConfig struct { // MiddlewareConfigs contains the list of middleware to apply to the transport // and the configuration for each middleware. MiddlewareConfigs []types.MiddlewareConfig `json:"middleware_configs,omitempty" yaml:"middleware_configs,omitempty"` + + // existingPort is the port from an existing workload being updated (not serialized) + // Used during port validation to allow reusing the same port + existingPort int } // WriteJSON serializes the RunConfig to JSON and writes it to the provided writer @@ -344,11 +348,16 @@ func (c *RunConfig) WithPorts(proxyPort, targetPort int) (*RunConfig, error) { // If not available - treat as an error, since picking a random port here // is going to lead to confusion. if proxyPort != 0 { - if !networking.IsAvailable(proxyPort) { + // Skip validation if reusing the same port from existing workload (during update) + if proxyPort == c.existingPort && c.existingPort > 0 { + logger.Debugf("Reusing existing port: %d", proxyPort) + selectedPort = proxyPort + } else if !networking.IsAvailable(proxyPort) { return c, fmt.Errorf("requested proxy port %d is not available", proxyPort) + } else { + logger.Debugf("Using requested port: %d", proxyPort) + selectedPort = proxyPort } - logger.Debugf("Using requested port: %d", proxyPort) - selectedPort = proxyPort } else { // Otherwise - pick a random available port. selectedPort, err = networking.FindOrUsePort(proxyPort) diff --git a/pkg/runner/config_builder.go b/pkg/runner/config_builder.go index 37a098c835..a4680b9abc 100644 --- a/pkg/runner/config_builder.go +++ b/pkg/runner/config_builder.go @@ -300,6 +300,15 @@ func WithTransportAndPorts(mcpTransport string, port, targetPort int) RunConfigB } } +// WithExistingPort sets the existing port for update operations +// This allows port reuse during workload updates by skipping validation for the same port +func WithExistingPort(port int) RunConfigBuilderOption { + return func(b *runConfigBuilder) error { + b.config.existingPort = port + return nil + } +} + // WithAuditEnabled configures audit settings func WithAuditEnabled(enableAudit bool, auditConfigPath string) RunConfigBuilderOption { return func(b *runConfigBuilder) error { diff --git a/pkg/runner/config_test.go b/pkg/runner/config_test.go index 98d9f093e5..8f80e0dc97 100644 --- a/pkg/runner/config_test.go +++ b/pkg/runner/config_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "net" "os" "strings" "testing" @@ -1698,3 +1699,110 @@ func TestConfigFileLoading(t *testing.T) { assert.Nil(t, file, "File handle should be nil when file does not exist") }) } + +// TestRunConfig_WithPorts_PortReuse tests the port reuse logic when updating workloads +// +//nolint:tparallel,paralleltest // Subtests intentionally run sequentially to share the same listener +func TestRunConfig_WithPorts_PortReuse(t *testing.T) { + t.Parallel() + + logger.Initialize() + + // Create a listener to occupy a port for the entire test + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err, "Should be able to create listener") + defer listener.Close() + + usedPort := listener.Addr().(*net.TCPAddr).Port + + t.Run("Reuse same port during update - should skip validation", func(t *testing.T) { + config := &RunConfig{ + Transport: types.TransportTypeStdio, + existingPort: usedPort, + } + result, err := config.WithPorts(usedPort, 0) + + assert.NoError(t, err, "When updating a workload and reusing the same port, validation should be skipped") + assert.Equal(t, config, result, "WithPorts should return the same config instance") + assert.Equal(t, usedPort, config.Port, "Port should be set to requested port") + }) + + t.Run("Different port during update - should validate", func(t *testing.T) { + config := &RunConfig{ + Transport: types.TransportTypeStdio, + existingPort: 8888, // Different from the port we're requesting + } + result, err := config.WithPorts(usedPort, 0) + + assert.Error(t, err, "When updating with a different port, validation should still occur and fail if port is in use") + assert.Contains(t, err.Error(), "not available", "Error should indicate port is not available") + assert.Equal(t, config, result, "WithPorts returns config even on error") + }) + + t.Run("No existing port - should validate normally", func(t *testing.T) { + config := &RunConfig{ + Transport: types.TransportTypeStdio, + existingPort: 0, + } + result, err := config.WithPorts(usedPort, 0) + + assert.Error(t, err, "When creating new workload (no existing port), validation should occur normally") + assert.Contains(t, err.Error(), "not available", "Error should indicate port is not available") + assert.Equal(t, config, result, "WithPorts returns config even on error") + }) + + t.Run("Reuse existing port with value 0 should still work", func(t *testing.T) { + config := &RunConfig{ + Transport: types.TransportTypeStdio, + existingPort: 0, + } + result, err := config.WithPorts(0, 0) + + assert.NoError(t, err, "Port 0 should still auto-select a port") + assert.Equal(t, config, result, "WithPorts should return the same config instance") + assert.Greater(t, config.Port, 0, "Port should be auto-selected") + }) +} + +// TestWithExistingPort tests the WithExistingPort builder option +func TestWithExistingPort(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + existingPort int + expected int + }{ + { + name: "Set existing port to valid value", + existingPort: 8080, + expected: 8080, + }, + { + name: "Set existing port to 0", + existingPort: 0, + expected: 0, + }, + { + name: "Set existing port to high port", + existingPort: 65535, + expected: 65535, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + builder := &runConfigBuilder{ + config: &RunConfig{}, + } + + option := WithExistingPort(tc.existingPort) + err := option(builder) + + assert.NoError(t, err, "WithExistingPort should not return an error") + assert.Equal(t, tc.expected, builder.config.existingPort, "existingPort should be set correctly") + }) + } +}