Skip to content

Commit

Permalink
feat: adding proper canceling of token source
Browse files Browse the repository at this point in the history
  • Loading branch information
vlastahajek committed Sep 24, 2024
1 parent 7059db7 commit 2977c54
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 11 deletions.
20 changes: 15 additions & 5 deletions services/kafka/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,26 @@ func (c *Config) ApplyConditionalDefaults() {
}
}

type Closer interface {
Close()
}

type WriterConfig struct {
Closers []Closer
Config *kafka.Config
}

type WriteTarget struct {
Topic string
PartitionById bool
PartitionAlgorithm string
}

func (c Config) writerConfig(target WriteTarget) (*kafka.Config, error) {
func (c Config) writerConfig(target WriteTarget) (*WriterConfig, error) {
cfg := kafka.NewConfig()

if target.Topic == "" {
return cfg, errors.New("topic must not be empty")
return &WriterConfig{nil, cfg}, errors.New("topic must not be empty")
}
var partitioner kafka.PartitionerConstructor
if target.PartitionById {
Expand All @@ -104,7 +113,7 @@ func (c Config) writerConfig(target WriteTarget) (*kafka.Config, error) {
case "fnv-1a":
partitioner = kafka.NewHashPartitioner
default:
return cfg, fmt.Errorf("invalid partition algorithm: %q", target.PartitionAlgorithm)
return &WriterConfig{nil, cfg}, fmt.Errorf("invalid partition algorithm: %q", target.PartitionAlgorithm)
}
cfg.Producer.Partitioner = partitioner
}
Expand Down Expand Up @@ -135,10 +144,11 @@ func (c Config) writerConfig(target WriteTarget) (*kafka.Config, error) {
cfg.Producer.Flush.Frequency = time.Duration(c.BatchTimeout)

// SASL
if err := c.SASLAuth.SetSASLConfig(cfg); err != nil {
if o, err := c.SASLAuth.SetSASLConfig(cfg); err != nil {
return nil, err
} else {
return &WriterConfig{[]Closer{o}, cfg}, cfg.Validate()
}
return cfg, cfg.Validate()
}

type Configs []Config
Expand Down
11 changes: 7 additions & 4 deletions services/kafka/sasl.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,11 @@ func (k *SASLAuth) Validate() error {
// SetSASLConfig configures SASL for kafka (sarama)
// We mutate instead of returning the appropriate struct, because kafka.NewConfig() already populates certain defaults
// that we do not want to disrupt.
func (k *SASLAuth) SetSASLConfig(config *kafka.Config) error {
func (k *SASLAuth) SetSASLConfig(config *kafka.Config) (Closer, error) {

config.Net.SASL.User = k.SASLUsername
config.Net.SASL.Password = k.SASLPassword
var c Closer

if k.SASLMechanism != "" {
config.Net.SASL.Mechanism = kafka.SASLMechanism(k.SASLMechanism)
Expand Down Expand Up @@ -139,7 +140,9 @@ func (k *SASLAuth) SetSASLConfig(config *kafka.Config) error {
ctx, cancel := context.WithCancel(context.Background())
src := cfg.TokenSource(ctx)
source := oauth2.ReuseTokenSourceWithExpiry(nil, src, k.SASLOAUTHExpiryMargin)
config.Net.SASL.TokenProvider = NewRefreshingToken(source, cancel, k.SASLExtensions)
r := NewRefreshingToken(source, cancel, k.SASLExtensions)
config.Net.SASL.TokenProvider = r
c = r

case kafka.SASLTypeGSSAPI:
config.Net.SASL.GSSAPI.ServiceName = k.SASLGSSAPIServiceName
Expand All @@ -161,11 +164,11 @@ func (k *SASLAuth) SetSASLConfig(config *kafka.Config) error {

version, err := SASLVersion(config.Version, k.SASLVersion)
if err != nil {
return err
return nil, err
}
config.Net.SASL.Version = version
}
return nil
return c, nil
}

func SASLVersion(kafkaVersion kafka.KafkaVersion, saslVersion *int) (int16, error) {
Expand Down
10 changes: 8 additions & 2 deletions services/kafka/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ type writer struct {
statsKey string

done chan struct{}

closers []Closer
}

func (w *writer) Open() {
Expand Down Expand Up @@ -83,6 +85,9 @@ func (w *writer) Close() {

close(w.done)
vars.DeleteStatistic(w.statsKey)
for _, c := range w.closers {
c.Close()
}
err := w.kafka.Close()

if err != nil {
Expand Down Expand Up @@ -166,19 +171,20 @@ func (c *Cluster) writer(target WriteTarget, diagnostic Diagnostic) (*writer, er
if err != nil {
return nil, err
}
kp, err := kafka.NewAsyncProducer(c.cfg.Brokers, wc)
kp, err := kafka.NewAsyncProducer(c.cfg.Brokers, wc.Config)

if err != nil {
return nil, err
}

// Create new writer
w = &writer{
requestsInFlightMetric: metrics.GetOrRegisterCounter("requests-in-flight", wc.MetricRegistry),
requestsInFlightMetric: metrics.GetOrRegisterCounter("requests-in-flight", wc.Config.MetricRegistry),
kafka: kp,
cluster: c.cfg.ID,
topic: topic,
diagnostic: diagnostic,
closers: wc.Closers,
}
w.Open()
c.writers[topic] = w
Expand Down
5 changes: 5 additions & 0 deletions services/kafka/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ func (k *RefreshingToken) Token() (*kafka.AccessToken, error) {
}, nil
}

func (k *RefreshingToken) Close() {
// canceling the token refresh
k.cancel()
}

func NewStaticToken(token string, extensions map[string]string) *StaticToken {
return &StaticToken{
token: token,
Expand Down

0 comments on commit 2977c54

Please sign in to comment.