Skip to content

Commit 6ddab40

Browse files
committed
Add ZMQ connection retry configuration
Signed-off-by: zhengkezhou1 <madzhou1@gmail.com>
1 parent efa82a5 commit 6ddab40

File tree

5 files changed

+46
-26
lines changed

5 files changed

+46
-26
lines changed

pkg/common/config.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ type Configuration struct {
126126

127127
// ZMQEndpoint is the ZMQ address to publish events, the default value is tcp://localhost:5557
128128
ZMQEndpoint string `yaml:"zmq-endpoint"`
129+
// ZMQRetriesTimes is the maximum number of times a failed ZMQ connection is retried; validation rejects values greater than 10
130+
ZMQRetriesTimes uint `yaml:"zmq-retries-times"`
129131
// EventBatchSize is the maximum number of kv-cache events to be sent together, defaults to 16
130132
EventBatchSize int `yaml:"event-batch-size"`
131133

@@ -354,6 +356,9 @@ func (c *Configuration) validate() error {
354356
if c.EventBatchSize < 1 {
355357
return errors.New("event batch size cannot less than 1")
356358
}
359+
if c.ZMQRetriesTimes > 10 {
360+
return errors.New("zmq retries times cannot be more than 10")
361+
}
357362

358363
if c.FakeMetrics != nil {
359364
if c.FakeMetrics.RunningRequests < 0 || c.FakeMetrics.WaitingRequests < 0 {
@@ -415,6 +420,7 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) {
415420
f.StringVar(&config.TokenizersCacheDir, "tokenizers-cache-dir", config.TokenizersCacheDir, "Directory for caching tokenizers")
416421
f.StringVar(&config.HashSeed, "hash-seed", config.HashSeed, "Seed for hash generation (if not set, is read from PYTHONHASHSEED environment variable)")
417422
f.StringVar(&config.ZMQEndpoint, "zmq-endpoint", config.ZMQEndpoint, "ZMQ address to publish events")
423+
f.UintVar(&config.ZMQRetriesTimes, "zmq-retries-times", config.ZMQRetriesTimes, "Number of times to retry ZMQ requests")
418424
f.IntVar(&config.EventBatchSize, "event-batch-size", config.EventBatchSize, "Maximum number of kv-cache events to be sent together")
419425

420426
// These values were manually parsed above in getParamValueFromArgs, we leave this in order to get these flags in --help

pkg/common/publisher.go

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"errors"
2424
"fmt"
2525
"sync/atomic"
26+
"time"
2627

2728
zmq "github.com/pebbe/zmq4"
2829
"github.com/vmihailenco/msgpack/v5"
@@ -38,24 +39,33 @@ type Publisher struct {
3839

3940
// NewPublisher creates a new ZMQ publisher.
4041
// endpoint is the ZMQ address to bind to (e.g., "tcp://*:5557").
41-
func NewPublisher(endpoint string) (*Publisher, error) {
42+
func NewPublisher(endpoint string, retries uint) (*Publisher, error) {
4243
socket, err := zmq.NewSocket(zmq.PUB)
4344
if err != nil {
4445
return nil, fmt.Errorf("failed to create ZMQ PUB socket: %w", err)
4546
}
4647

47-
if err := socket.Connect(endpoint); err != nil {
48-
errClose := socket.Close()
49-
return nil, errors.Join(
50-
fmt.Errorf("failed to connect to %s: %w", endpoint, err),
51-
errClose,
52-
)
48+
// Retry connection with specified retry times and intervals
49+
for i := uint(0); i <= retries; i++ {
50+
err = socket.Connect(endpoint)
51+
if err == nil {
52+
return &Publisher{
53+
socket: socket,
54+
endpoint: endpoint,
55+
}, nil
56+
}
57+
58+
// If not the last attempt, wait before retrying
59+
if i < retries {
60+
time.Sleep(1 * time.Second)
61+
}
5362
}
5463

55-
return &Publisher{
56-
socket: socket,
57-
endpoint: endpoint,
58-
}, nil
64+
errClose := socket.Close()
65+
return nil, errors.Join(
66+
fmt.Errorf("failed to connect to %s after %d retries: %w", endpoint, retries+1, err),
67+
errClose,
68+
)
5969
}
6070

6171
// PublishEvent publishes a KV cache event batch to the ZMQ topic.

pkg/common/publisher_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ const (
3232
topic = "test-topic"
3333
endpoint = "tcp://localhost:5557"
3434
data = "Hello"
35+
retries = 0
3536
)
3637

3738
var _ = Describe("Publisher", func() {
@@ -49,7 +50,7 @@ var _ = Describe("Publisher", func() {
4950

5051
time.Sleep(100 * time.Millisecond)
5152

52-
pub, err := NewPublisher(endpoint)
53+
pub, err := NewPublisher(endpoint, retries)
5354
Expect(err).NotTo(HaveOccurred())
5455

5556
ctx, cancel := context.WithCancel(context.Background())

pkg/kv-cache/block_cache.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func newBlockCache(config *common.Configuration, logger logr.Logger) (*blockCach
4848
// TODO read size of channel from config
4949
eChan := make(chan EventData, 10000)
5050

51-
publisher, err := common.NewPublisher(config.ZMQEndpoint)
51+
publisher, err := common.NewPublisher(config.ZMQEndpoint, config.ZMQRetriesTimes)
5252
if err != nil {
5353
return nil, err
5454
}

pkg/kv-cache/kv_cache_test.go

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -200,11 +200,12 @@ var _ = Describe("KV cache", Ordered, func() {
200200
time.Sleep(300 * time.Millisecond)
201201

202202
config := &common.Configuration{
203-
Port: 1234,
204-
Model: "model",
205-
KVCacheSize: test.cacheSize,
206-
ZMQEndpoint: endpoint,
207-
EventBatchSize: 1,
203+
Port: 1234,
204+
Model: "model",
205+
KVCacheSize: test.cacheSize,
206+
ZMQEndpoint: endpoint,
207+
ZMQRetriesTimes: 3,
208+
EventBatchSize: 1,
208209
}
209210

210211
sub, topic := createSub(config)
@@ -303,10 +304,11 @@ var _ = Describe("KV cache", Ordered, func() {
303304

304305
It("should send events correctly", func() {
305306
config := &common.Configuration{
306-
Port: 1234,
307-
Model: "model",
308-
KVCacheSize: 4,
309-
ZMQEndpoint: endpoint,
307+
Port: 1234,
308+
Model: "model",
309+
KVCacheSize: 4,
310+
ZMQEndpoint: endpoint,
311+
ZMQRetriesTimes: 3,
310312
}
311313

312314
sub, topic := createSub(config)
@@ -412,10 +414,11 @@ var _ = Describe("KV cache", Ordered, func() {
412414
for _, testCase := range testCases {
413415
It(testCase.name, func() {
414416
config := common.Configuration{
415-
Port: 1234,
416-
Model: "model",
417-
KVCacheSize: testCase.cacheSize,
418-
ZMQEndpoint: endpoint,
417+
Port: 1234,
418+
Model: "model",
419+
KVCacheSize: testCase.cacheSize,
420+
ZMQEndpoint: endpoint,
421+
ZMQRetriesTimes: 3,
419422
}
420423
blockCache, err := newBlockCache(&config, GinkgoLogr)
421424
Expect(err).NotTo(HaveOccurred())

0 commit comments

Comments
 (0)