diff --git a/README.md b/README.md index 328bc608..275f8e13 100644 --- a/README.md +++ b/README.md @@ -122,6 +122,7 @@ For more details see the 10 { + return errors.New("zmq retries times cannot be more than 10") + } if c.FakeMetrics != nil { if c.FakeMetrics.RunningRequests < 0 || c.FakeMetrics.WaitingRequests < 0 { @@ -415,6 +420,7 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { f.StringVar(&config.TokenizersCacheDir, "tokenizers-cache-dir", config.TokenizersCacheDir, "Directory for caching tokenizers") f.StringVar(&config.HashSeed, "hash-seed", config.HashSeed, "Seed for hash generation (if not set, is read from PYTHONHASHSEED environment variable)") f.StringVar(&config.ZMQEndpoint, "zmq-endpoint", config.ZMQEndpoint, "ZMQ address to publish events") + f.UintVar(&config.ZMQMaxConnectAttempts, "zmq-max-connect-attempts", config.ZMQMaxConnectAttempts, "Maximum number of times to retry ZMQ requests") f.IntVar(&config.EventBatchSize, "event-batch-size", config.EventBatchSize, "Maximum number of kv-cache events to be sent together") // These values were manually parsed above in getParamValueFromArgs, we leave this in order to get these flags in --help diff --git a/pkg/common/config_test.go b/pkg/common/config_test.go index d4ec9677..f50c40a9 100644 --- a/pkg/common/config_test.go +++ b/pkg/common/config_test.go @@ -103,12 +103,14 @@ var _ = Describe("Simulator configuration", func() { "{\"name\":\"lora4\",\"path\":\"/path/to/lora4\"}", } c.EventBatchSize = 5 + c.ZMQMaxConnectAttempts = 1 test = testCase{ name: "config file with command line args", args: []string{"cmd", "--model", model, "--config", "../../manifests/config.yaml", "--port", "8002", "--served-model-name", "alias1", "alias2", "--seed", "100", "--lora-modules", "{\"name\":\"lora3\",\"path\":\"/path/to/lora3\"}", "{\"name\":\"lora4\",\"path\":\"/path/to/lora4\"}", "--event-batch-size", "5", + "--zmq-max-connect-attempts", "1", }, expectedConfig: c, } @@ -121,6 +123,7 @@ var _ = Describe("Simulator configuration", func() { c.LoraModulesString = []string{ "{\"name\":\"lora3\",\"path\":\"/path/to/lora3\"}", } + c.ZMQMaxConnectAttempts = 0 test = testCase{ name: "config file with command line args with different format", args: []string{"cmd", "--model", model, "--config", "../../manifests/config.yaml", "--port", "8002", @@ -377,6 +380,14 @@ var _ = Describe("Simulator configuration", func() { args: []string{"cmd", "--fake-metrics", "{\"running-requests\":10,\"waiting-requests\":30,\"kv-cache-usage\":40}", "--config", "../../manifests/config.yaml"}, }, + { + name: "invalid (negative) zmq-max-connect-attempts for argument", + args: []string{"cmd", "zmq-max-connect-attempts", "-1", "--config", "../../manifests/config.yaml"}, + }, + { + name: "invalid (negative) zmq-max-connect-attempts for config file", + args: []string{"cmd", "--config", "../../manifests/invalid-config.yaml"}, + }, } for _, test := range invalidTests { diff --git a/pkg/common/publisher.go b/pkg/common/publisher.go index d7d6e325..883c05a2 100644 --- a/pkg/common/publisher.go +++ b/pkg/common/publisher.go @@ -23,6 +23,7 @@ import ( "errors" "fmt" "sync/atomic" + "time" zmq "github.com/pebbe/zmq4" "github.com/vmihailenco/msgpack/v5" @@ -38,24 +39,34 @@ type Publisher struct { // NewPublisher creates a new ZMQ publisher. // endpoint is the ZMQ address to bind to (e.g., "tcp://*:5557"). -func NewPublisher(endpoint string) (*Publisher, error) { +// retries is the maximum number of connection attempts. +func NewPublisher(endpoint string, retries uint) (*Publisher, error) { socket, err := zmq.NewSocket(zmq.PUB) if err != nil { return nil, fmt.Errorf("failed to create ZMQ PUB socket: %w", err) } - if err := socket.Connect(endpoint); err != nil { - errClose := socket.Close() - return nil, errors.Join( - fmt.Errorf("failed to connect to %s: %w", endpoint, err), - errClose, - ) + // Retry connection with specified retry times and intervals + for i := uint(0); i <= retries; i++ { + err = socket.Connect(endpoint) + if err == nil { + return &Publisher{ + socket: socket, + endpoint: endpoint, + }, nil + } + + // If not the last attempt, wait before retrying + if i < retries { + time.Sleep(1 * time.Second) + } } - return &Publisher{ - socket: socket, - endpoint: endpoint, - }, nil + errClose := socket.Close() + return nil, errors.Join( + fmt.Errorf("failed to connect to %s after %d retries: %w", endpoint, retries+1, err), + errClose, + ) } // PublishEvent publishes a KV cache event batch to the ZMQ topic. diff --git a/pkg/common/publisher_test.go b/pkg/common/publisher_test.go index 8f4609d5..a9d6582b 100644 --- a/pkg/common/publisher_test.go +++ b/pkg/common/publisher_test.go @@ -33,6 +33,7 @@ const ( subEndpoint = "tcp://*:5557" pubEndpoint = "tcp://localhost:5557" data = "Hello" + retries = 0 ) var _ = Describe("Publisher", func() { @@ -50,7 +51,7 @@ var _ = Describe("Publisher", func() { time.Sleep(100 * time.Millisecond) - pub, err := NewPublisher(pubEndpoint) + pub, err := NewPublisher(pubEndpoint, retries) Expect(err).NotTo(HaveOccurred()) ctx, cancel := context.WithCancel(context.Background()) @@ -78,4 +79,40 @@ var _ = Describe("Publisher", func() { Expect(err).NotTo(HaveOccurred()) Expect(payload).To(Equal(data)) }) + It("should fail when connection attempts exceed maximum retries", func() { + // Use invalid address format, which will cause connection to fail + invalidEndpoint := "invalid-address-format" + + pub, err := NewPublisher(invalidEndpoint, 2) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("failed to connect")) + Expect(err.Error()).To(ContainSubstring("after 3 retries")) // 2 retries = 3 total attempts + + if pub != nil { + //nolint + pub.Close() + } + }) + It("should retry connection successfully", func() { + // Step 1: Try to connect to a temporarily non-existent service + // This will trigger the retry mechanism + go func() { + // Delay starting the server to simulate service recovery + time.Sleep(2 * time.Second) + + // Start subscriber as server + sub, err := zmq.NewSocket(zmq.SUB) + Expect(err).NotTo(HaveOccurred()) + //nolint + defer sub.Close() + err = sub.Bind(subEndpoint) + Expect(err).NotTo(HaveOccurred()) + }() + + // Step 2: Publisher will retry connection and eventually succeed + pub, err := NewPublisher(pubEndpoint, 5) // 5 retries + Expect(err).NotTo(HaveOccurred()) // Should eventually succeed + //nolint + defer pub.Close() + }) }) diff --git a/pkg/kv-cache/block_cache.go b/pkg/kv-cache/block_cache.go index e66c7224..56d2253b 100644 --- a/pkg/kv-cache/block_cache.go +++ b/pkg/kv-cache/block_cache.go @@ -48,7 +48,7 @@ func newBlockCache(config *common.Configuration, logger logr.Logger) (*blockCach // TODO read size of channel from config eChan := make(chan EventData, 10000) - publisher, err := common.NewPublisher(config.ZMQEndpoint) + publisher, err := common.NewPublisher(config.ZMQEndpoint, config.ZMQMaxConnectAttempts) if err != nil { return nil, err } diff --git a/pkg/kv-cache/kv_cache_test.go b/pkg/kv-cache/kv_cache_test.go index 8f57c516..7731196e 100644 --- a/pkg/kv-cache/kv_cache_test.go +++ b/pkg/kv-cache/kv_cache_test.go @@ -201,11 +201,12 @@ var _ = Describe("KV cache", Ordered, func() { time.Sleep(300 * time.Millisecond) config := &common.Configuration{ - Port: 1234, - Model: "model", - KVCacheSize: test.cacheSize, - ZMQEndpoint: pubEndpoint, - EventBatchSize: 1, + Port: 1234, + Model: "model", + KVCacheSize: test.cacheSize, + ZMQEndpoint: pubEndpoint, + ZMQMaxConnectAttempts: 3, + EventBatchSize: 1, } sub, topic := createSub(config) @@ -304,10 +305,11 @@ var _ = Describe("KV cache", Ordered, func() { It("should send events correctly", func() { config := &common.Configuration{ - Port: 1234, - Model: "model", - KVCacheSize: 4, - ZMQEndpoint: pubEndpoint, + Port: 1234, + Model: "model", + KVCacheSize: 4, + ZMQEndpoint: pubEndpoint, + ZMQMaxConnectAttempts: 3, } sub, topic := createSub(config) @@ -413,10 +415,11 @@ var _ = Describe("KV cache", Ordered, func() { for _, testCase := range testCases { It(testCase.name, func() { config := common.Configuration{ - Port: 1234, - Model: "model", - KVCacheSize: testCase.cacheSize, - ZMQEndpoint: pubEndpoint, + Port: 1234, + Model: "model", + KVCacheSize: testCase.cacheSize, + ZMQEndpoint: pubEndpoint, + ZMQMaxConnectAttempts: 3, } blockCache, err := newBlockCache(&config, GinkgoLogr) Expect(err).NotTo(HaveOccurred())