Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions pkg/api/v1/workload_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ func NewWorkloadService(

// CreateWorkloadFromRequest creates a workload from a request
func (s *WorkloadService) CreateWorkloadFromRequest(ctx context.Context, req *createRequest) (*runner.RunConfig, error) {
// Build the full run config
runConfig, err := s.BuildFullRunConfig(ctx, req)
// Build the full run config (no existing port, so pass 0)
runConfig, err := s.BuildFullRunConfig(ctx, req, 0)
if err != nil {
return nil, err
}
Expand All @@ -83,9 +83,15 @@ func (s *WorkloadService) CreateWorkloadFromRequest(ctx context.Context, req *cr
}

// UpdateWorkloadFromRequest updates a workload from a request
func (s *WorkloadService) UpdateWorkloadFromRequest(ctx context.Context, name string, req *createRequest) (*runner.RunConfig, error) { //nolint:lll
func (s *WorkloadService) UpdateWorkloadFromRequest(ctx context.Context, name string, req *createRequest, existingPort int) (*runner.RunConfig, error) { //nolint:lll
// If ProxyPort is 0, reuse the existing port
if req.ProxyPort == 0 && existingPort > 0 {
req.ProxyPort = existingPort
logger.Debugf("Reusing existing port %d for workload %s", existingPort, name)
}

// Build the full run config
runConfig, err := s.BuildFullRunConfig(ctx, req)
runConfig, err := s.BuildFullRunConfig(ctx, req, existingPort)
if err != nil {
return nil, fmt.Errorf("failed to build workload config: %w", err)
}
Expand All @@ -101,7 +107,9 @@ func (s *WorkloadService) UpdateWorkloadFromRequest(ctx context.Context, name st
// BuildFullRunConfig builds a complete RunConfig
//
//nolint:gocyclo // TODO: refactor this into shorter functions
func (s *WorkloadService) BuildFullRunConfig(ctx context.Context, req *createRequest) (*runner.RunConfig, error) {
func (s *WorkloadService) BuildFullRunConfig(
ctx context.Context, req *createRequest, existingPort int,
) (*runner.RunConfig, error) {
// Default proxy mode to streamable-http if not specified (SSE is deprecated)
if !types.IsValidProxyMode(req.ProxyMode) {
if req.ProxyMode == "" {
Expand Down Expand Up @@ -248,6 +256,11 @@ func (s *WorkloadService) BuildFullRunConfig(ctx context.Context, req *createReq
runner.WithTelemetryConfigFromFlags("", false, false, false, "", 0.0, nil, false, nil),
}

// Add existing port if provided (for update operations)
if existingPort > 0 {
options = append(options, runner.WithExistingPort(existingPort))
}

// Determine transport type
transportType := "streamable-http"
if req.Transport != "" {
Expand Down
6 changes: 3 additions & 3 deletions pkg/api/v1/workloads.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,8 @@ func (s *WorkloadRoutes) updateWorkload(w http.ResponseWriter, r *http.Request)
return
}

// Check if workload exists
_, err := s.workloadManager.GetWorkload(ctx, name)
// Check if workload exists and get its current port
existingWorkload, err := s.workloadManager.GetWorkload(ctx, name)
if err != nil {
logger.Errorf("Failed to get workload: %v", err)
http.Error(w, "Workload not found", http.StatusNotFound)
Expand All @@ -352,7 +352,7 @@ func (s *WorkloadRoutes) updateWorkload(w http.ResponseWriter, r *http.Request)
Name: name, // Use the name from URL path, not from request body
}

runConfig, err := s.workloadService.UpdateWorkloadFromRequest(ctx, name, &createReq)
runConfig, err := s.workloadService.UpdateWorkloadFromRequest(ctx, name, &createReq, existingWorkload.Port)
if err != nil {
// Error messages already logged in UpdateWorkloadFromRequest
if errors.Is(err, retriever.ErrImageNotFound) {
Expand Down
148 changes: 148 additions & 0 deletions pkg/api/v1/workloads_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,154 @@ func TestUpdateWorkload(t *testing.T) {
}
}

// TestUpdateWorkload_PortReuse tests the port reuse logic when editing workloads
func TestUpdateWorkload_PortReuse(t *testing.T) {
t.Parallel()

logger.Initialize()

tests := []struct {
name string
workloadName string
requestBody string
existingPort int
setupMock func(*testing.T, *workloadsmocks.MockManager, *runtimemocks.MockRuntime, *groupsmocks.MockManager)
expectedStatus int
expectedBody string
description string
}{
{
name: "Edit with port=0 should reuse existing port",
workloadName: "test-workload",
requestBody: `{"image": "test-image", "proxy_port": 0}`,
existingPort: 8080,
setupMock: func(t *testing.T, wm *workloadsmocks.MockManager, _ *runtimemocks.MockRuntime, gm *groupsmocks.MockManager) {
t.Helper()
wm.EXPECT().GetWorkload(gomock.Any(), "test-workload").
Return(core.Workload{Name: "test-workload", Port: 8080}, nil)
gm.EXPECT().Exists(gomock.Any(), "default").Return(true, nil)
wm.EXPECT().UpdateWorkload(gomock.Any(), "test-workload", gomock.Any()).
DoAndReturn(func(_ context.Context, _ string, runConfig *runner.RunConfig) (*errgroup.Group, error) {
assert.Equal(t, 8080, runConfig.Port, "Port should be reused from existing workload")
return &errgroup.Group{}, nil
})
},
expectedStatus: http.StatusOK,
expectedBody: "test-workload",
description: "When proxy_port is 0, the existing port should be reused",
},
{
name: "Edit with explicit port should use that port",
workloadName: "test-workload",
requestBody: `{"image": "test-image", "proxy_port": 9090}`,
existingPort: 8080,
setupMock: func(t *testing.T, wm *workloadsmocks.MockManager, _ *runtimemocks.MockRuntime, gm *groupsmocks.MockManager) {
t.Helper()
wm.EXPECT().GetWorkload(gomock.Any(), "test-workload").
Return(core.Workload{Name: "test-workload", Port: 8080}, nil)
gm.EXPECT().Exists(gomock.Any(), "default").Return(true, nil)
wm.EXPECT().UpdateWorkload(gomock.Any(), "test-workload", gomock.Any()).
DoAndReturn(func(_ context.Context, _ string, runConfig *runner.RunConfig) (*errgroup.Group, error) {
assert.Equal(t, 9090, runConfig.Port, "Port should be set to explicitly requested port")
return &errgroup.Group{}, nil
})
},
expectedStatus: http.StatusOK,
expectedBody: "test-workload",
description: "When an explicit port is provided, it should be used instead of reusing",
},
{
name: "Edit with same port should skip validation",
workloadName: "test-workload",
requestBody: `{"image": "test-image", "proxy_port": 8080}`,
existingPort: 8080,
setupMock: func(t *testing.T, wm *workloadsmocks.MockManager, _ *runtimemocks.MockRuntime, gm *groupsmocks.MockManager) {
t.Helper()
wm.EXPECT().GetWorkload(gomock.Any(), "test-workload").
Return(core.Workload{Name: "test-workload", Port: 8080}, nil)
gm.EXPECT().Exists(gomock.Any(), "default").Return(true, nil)
wm.EXPECT().UpdateWorkload(gomock.Any(), "test-workload", gomock.Any()).
DoAndReturn(func(_ context.Context, _ string, runConfig *runner.RunConfig) (*errgroup.Group, error) {
assert.Equal(t, 8080, runConfig.Port, "Port should remain the same")
return &errgroup.Group{}, nil
})
},
expectedStatus: http.StatusOK,
expectedBody: "test-workload",
description: "When reusing the same port, validation should be skipped",
},
{
name: "Edit with no port specified should default to existing",
workloadName: "test-workload",
requestBody: `{"image": "test-image"}`,
existingPort: 8080,
setupMock: func(t *testing.T, wm *workloadsmocks.MockManager, _ *runtimemocks.MockRuntime, gm *groupsmocks.MockManager) {
t.Helper()
wm.EXPECT().GetWorkload(gomock.Any(), "test-workload").
Return(core.Workload{Name: "test-workload", Port: 8080}, nil)
gm.EXPECT().Exists(gomock.Any(), "default").Return(true, nil)
wm.EXPECT().UpdateWorkload(gomock.Any(), "test-workload", gomock.Any()).
DoAndReturn(func(_ context.Context, _ string, runConfig *runner.RunConfig) (*errgroup.Group, error) {
assert.Equal(t, 8080, runConfig.Port, "Port should default to existing port")
return &errgroup.Group{}, nil
})
},
expectedStatus: http.StatusOK,
expectedBody: "test-workload",
description: "When no port is specified in request, existing port should be reused",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockWorkloadManager := workloadsmocks.NewMockManager(ctrl)
mockRuntime := runtimemocks.NewMockRuntime(ctrl)
mockGroupManager := groupsmocks.NewMockManager(ctrl)
tt.setupMock(t, mockWorkloadManager, mockRuntime, mockGroupManager)

mockRetriever := makeMockRetriever(t,
"test-image",
"test-image",
&regtypes.ImageMetadata{Image: "test-image"},
nil,
)

routes := &WorkloadRoutes{
workloadManager: mockWorkloadManager,
containerRuntime: mockRuntime,
groupManager: mockGroupManager,
debugMode: false,
workloadService: &WorkloadService{
groupManager: mockGroupManager,
workloadManager: mockWorkloadManager,
containerRuntime: mockRuntime,
imageRetriever: mockRetriever,
appConfig: &config.Config{},
},
}

req := httptest.NewRequest("POST", "/api/v1beta/workloads/"+tt.workloadName+"/edit",
strings.NewReader(tt.requestBody))
req.Header.Set("Content-Type", "application/json")

rctx := chi.NewRouteContext()
rctx.URLParams.Add("name", tt.workloadName)
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))

w := httptest.NewRecorder()
routes.updateWorkload(w, req)

assert.Equal(t, tt.expectedStatus, w.Code, tt.description)
assert.Contains(t, w.Body.String(), tt.expectedBody, tt.description)
})
}
}

func makeMockRetriever(
t *testing.T,
expectedServerOrImage string,
Expand Down
15 changes: 12 additions & 3 deletions pkg/runner/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ type RunConfig struct {
// MiddlewareConfigs contains the list of middleware to apply to the transport
// and the configuration for each middleware.
MiddlewareConfigs []types.MiddlewareConfig `json:"middleware_configs,omitempty" yaml:"middleware_configs,omitempty"`

// existingPort is the port from an existing workload being updated (not serialized)
// Used during port validation to allow reusing the same port
existingPort int
}

// WriteJSON serializes the RunConfig to JSON and writes it to the provided writer
Expand Down Expand Up @@ -344,11 +348,16 @@ func (c *RunConfig) WithPorts(proxyPort, targetPort int) (*RunConfig, error) {
// If not available - treat as an error, since picking a random port here
// is going to lead to confusion.
if proxyPort != 0 {
if !networking.IsAvailable(proxyPort) {
// Skip validation if reusing the same port from existing workload (during update)
if proxyPort == c.existingPort && c.existingPort > 0 {
logger.Debugf("Reusing existing port: %d", proxyPort)
selectedPort = proxyPort
} else if !networking.IsAvailable(proxyPort) {
return c, fmt.Errorf("requested proxy port %d is not available", proxyPort)
} else {
logger.Debugf("Using requested port: %d", proxyPort)
selectedPort = proxyPort
}
logger.Debugf("Using requested port: %d", proxyPort)
selectedPort = proxyPort
} else {
// Otherwise - pick a random available port.
selectedPort, err = networking.FindOrUsePort(proxyPort)
Expand Down
9 changes: 9 additions & 0 deletions pkg/runner/config_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,15 @@ func WithTransportAndPorts(mcpTransport string, port, targetPort int) RunConfigB
}
}

// WithExistingPort sets the existing port for update operations
// This allows port reuse during workload updates by skipping validation for the same port
func WithExistingPort(port int) RunConfigBuilderOption {
return func(b *runConfigBuilder) error {
b.config.existingPort = port
return nil
}
}

// WithAuditEnabled configures audit settings
func WithAuditEnabled(enableAudit bool, auditConfigPath string) RunConfigBuilderOption {
return func(b *runConfigBuilder) error {
Expand Down
108 changes: 108 additions & 0 deletions pkg/runner/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"fmt"
"net"
"os"
"strings"
"testing"
Expand Down Expand Up @@ -1698,3 +1699,110 @@ func TestConfigFileLoading(t *testing.T) {
assert.Nil(t, file, "File handle should be nil when file does not exist")
})
}

// TestRunConfig_WithPorts_PortReuse tests the port reuse logic when updating workloads
//
//nolint:tparallel,paralleltest // Subtests intentionally run sequentially to share the same listener
func TestRunConfig_WithPorts_PortReuse(t *testing.T) {
t.Parallel()

logger.Initialize()

// Create a listener to occupy a port for the entire test
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "Should be able to create listener")
defer listener.Close()

usedPort := listener.Addr().(*net.TCPAddr).Port

t.Run("Reuse same port during update - should skip validation", func(t *testing.T) {
config := &RunConfig{
Transport: types.TransportTypeStdio,
existingPort: usedPort,
}
result, err := config.WithPorts(usedPort, 0)

assert.NoError(t, err, "When updating a workload and reusing the same port, validation should be skipped")
assert.Equal(t, config, result, "WithPorts should return the same config instance")
assert.Equal(t, usedPort, config.Port, "Port should be set to requested port")
})

t.Run("Different port during update - should validate", func(t *testing.T) {
config := &RunConfig{
Transport: types.TransportTypeStdio,
existingPort: 8888, // Different from the port we're requesting
}
result, err := config.WithPorts(usedPort, 0)

assert.Error(t, err, "When updating with a different port, validation should still occur and fail if port is in use")
assert.Contains(t, err.Error(), "not available", "Error should indicate port is not available")
assert.Equal(t, config, result, "WithPorts returns config even on error")
})

t.Run("No existing port - should validate normally", func(t *testing.T) {
config := &RunConfig{
Transport: types.TransportTypeStdio,
existingPort: 0,
}
result, err := config.WithPorts(usedPort, 0)

assert.Error(t, err, "When creating new workload (no existing port), validation should occur normally")
assert.Contains(t, err.Error(), "not available", "Error should indicate port is not available")
assert.Equal(t, config, result, "WithPorts returns config even on error")
})

t.Run("Reuse existing port with value 0 should still work", func(t *testing.T) {
config := &RunConfig{
Transport: types.TransportTypeStdio,
existingPort: 0,
}
result, err := config.WithPorts(0, 0)

assert.NoError(t, err, "Port 0 should still auto-select a port")
assert.Equal(t, config, result, "WithPorts should return the same config instance")
assert.Greater(t, config.Port, 0, "Port should be auto-selected")
})
}

// TestWithExistingPort tests the WithExistingPort builder option
func TestWithExistingPort(t *testing.T) {
t.Parallel()

testCases := []struct {
name string
existingPort int
expected int
}{
{
name: "Set existing port to valid value",
existingPort: 8080,
expected: 8080,
},
{
name: "Set existing port to 0",
existingPort: 0,
expected: 0,
},
{
name: "Set existing port to high port",
existingPort: 65535,
expected: 65535,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

builder := &runConfigBuilder{
config: &RunConfig{},
}

option := WithExistingPort(tc.existingPort)
err := option(builder)

assert.NoError(t, err, "WithExistingPort should not return an error")
assert.Equal(t, tc.expected, builder.config.existingPort, "existingPort should be set correctly")
})
}
}
Loading