Skip to content

Commit

Permalink
ensure discovery completes on schedule
Browse files Browse the repository at this point in the history
  • Loading branch information
jcodybaker committed Feb 18, 2024
1 parent 97f9a85 commit 69d4322
Show file tree
Hide file tree
Showing 10 changed files with 54 additions and 27 deletions.
2 changes: 1 addition & 1 deletion cmd/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ func init() {
}
for _, childCmd := range c.Parent.Commands() {
childRun := childCmd.RunE
discoveryFlags(childCmd.Flags(), false, true)
discoveryFlags(childCmd.Flags(), discoveryFlagsOptions{interactive: true})
childCmd.RunE = func(cmd *cobra.Command, args []string) error {
if err := rootCmd.PersistentPreRunE(cmd, args); err != nil {
return err
Expand Down
37 changes: 21 additions & 16 deletions cmd/helper_discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@ import (

var addAll bool

func discoveryFlags(f *pflag.FlagSet, withTTL, interactive bool) {
// discoveryFlagsOptions carries the per-command settings that discoveryFlags
// consults while registering the shared discovery flags on a flag set.
type discoveryFlagsOptions struct {
	// withTTL gates registration of the device-ttl flag; interactive is the
	// default value given to both the search-interactive and interactive flags.
	withTTL, interactive bool

	// searchStrictTimeoutDefault is the default value given to the
	// search-strict-timeout flag.
	searchStrictTimeoutDefault bool
}

func discoveryFlags(f *pflag.FlagSet, opts discoveryFlagsOptions) {
f.String(
"auth",
"",
Expand Down Expand Up @@ -67,17 +73,23 @@ func discoveryFlags(f *pflag.FlagSet, withTTL, interactive bool) {
"timeout for devices to respond to the mDNS discovery query.",
)

f.Bool(
"search-strict-timeout",
opts.searchStrictTimeoutDefault,
"ignore devices which have been found but completed their initial query within the search-timeout",
)

// search-interactive and interactive cannot use the Bool() pattern as the default
// varies by command and the global would be set to whatever the last value was.
f.Bool(
"search-interactive",
interactive,
opts.interactive,
"if true confirm devices discovered in search before proceeding with commands. Defers to --interactive if not explicitly set.",
)

f.Bool(
"interactive",
interactive,
opts.interactive,
"if true prompt for confirmation or passwords.",
)

Expand All @@ -98,7 +110,7 @@ func discoveryFlags(f *pflag.FlagSet, withTTL, interactive bool) {
"continue with other hosts in the face errors.",
)

if withTTL {
if opts.withTTL {
f.Duration(
"device-ttl",
discovery.DefaultDeviceTTL,
Expand Down Expand Up @@ -134,22 +146,14 @@ func discoveryOptionsFromFlags(flags *pflag.FlagSet) (opts []discovery.Discovere
default:
return nil, errors.New("invalid value for --prefer-ip-version; must be `4` or `6`")
}
searchInteractive, err := flags.GetBool("search-interactive")
if err != nil {
return nil, err
}
searchInteractive := viper.GetBool("search-interactive")
explictSearchInteractive := flags.Lookup("search-interactive").Changed
interactive, err := flags.GetBool("interactive")
if err != nil {
return nil, err
}
interactive := viper.GetBool("interactive")

if !explictSearchInteractive {
searchInteractive = interactive
}
auth, err := flags.GetString("auth")
if err != nil {
return nil, err
}
auth := viper.GetString("auth")
if auth != "" {
opts = append(opts, discovery.WithAuthCallback(func(_ context.Context, _ string) (passwd string, err error) {
return auth, nil
Expand All @@ -174,6 +178,7 @@ func discoveryOptionsFromFlags(flags *pflag.FlagSet) (opts []discovery.Discovere
discovery.WithMDNSZone(viper.GetString("mdns-zone")),
discovery.WithMDNSService(viper.GetString("mdns-service")),
discovery.WithSearchTimeout(viper.GetDuration("search-timeout")),
discovery.WithSearchStrictTimeout(viper.GetBool("search-strict-timeout")),
discovery.WithConcurrency(viper.GetInt("discovery-concurrency")),
discovery.WithDeviceTTL(viper.GetDuration("device-ttl")),
discovery.WithMDNSSearchEnabled(mdnsSearch),
Expand Down
6 changes: 5 additions & 1 deletion cmd/prometheus.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ func init() {
prometheusCmd.Flags().Int("probe-concurrency", promserver.DefaultConcurrency, "set the number of concurrent probes which will be made to service a metrics request.")
prometheusCmd.Flags().Duration("device-timeout", promserver.DefaultDeviceTimeout, "set the maximum time allowed for a device to respond to it probe.")
prometheusCmd.Flags().Duration("scrape-duration-warning", promserver.DefaultScrapeDurationWarning, "sets the value for scrape duration warning. Scrapes which exceed this duration will log a warning generate. Default value 8s is 80% of the 10s default prometheus scrape_timeout.")
discoveryFlags(prometheusCmd.Flags(), true, false)
discoveryFlags(prometheusCmd.Flags(), discoveryFlagsOptions{
withTTL: true,
interactive: false,
searchStrictTimeoutDefault: true,
})
rootCmd.AddCommand(prometheusCmd)
rootCmd.AddGroup(&cobra.Group{
ID: "servers",
Expand Down
2 changes: 1 addition & 1 deletion cmd/shelly.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func init() {
"password", "", "password to use for auth. If empty, the password will be cleared.",
)
shellyComponent.Parent.AddCommand(shellyAuthCmd)
discoveryFlags(shellyAuthCmd.Flags(), false, true)
discoveryFlags(shellyAuthCmd.Flags(), discoveryFlagsOptions{interactive: true})
shellyAuthCmd.RunE = func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
ll := log.Ctx(ctx).With().Str("request", (&shelly.ShellySetAuthRequest{}).Method()).Logger()
Expand Down
2 changes: 1 addition & 1 deletion pkg/discovery/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func NewDiscoverer(opts ...DiscovererOption) *Discoverer {
mdnsService: DefaultMDNSService,
searchTimeout: DefaultMDNSSearchTimeout,
concurrency: DefaultConcurrency,
mdnsQueryFunc: mdns.Query,
mdnsQueryFunc: mdns.QueryContext,
},
}
for _, o := range opts {
Expand Down
7 changes: 6 additions & 1 deletion pkg/discovery/mdns.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ func (d *Discoverer) searchMDNS(ctx context.Context, stop chan struct{}) ([]*Dev
if !d.mdnsSearchEnabled {
return nil, nil
}
if d.searchStrictTimeout {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, d.searchTimeout)
defer cancel()
}
c := make(chan *mdns.ServiceEntry, mdnsSearchBuffer)
params := &mdns.QueryParam{
Service: d.mdnsService,
Expand Down Expand Up @@ -93,7 +98,7 @@ func (d *Discoverer) searchMDNS(ctx context.Context, stop chan struct{}) ([]*Dev
}
}()

if err := d.mdnsQueryFunc(params); err != nil {
if err := d.mdnsQueryFunc(ctx, params); err != nil {
close(c)
return nil, fmt.Errorf("querying mdns for devices: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/discovery/mdns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func TestDiscovererMDNSSearch(t *testing.T) {
serviceEntryTemplate.Port, err = strconv.Atoi(port)
require.NoError(t, err)

queryFunc := func(params *mdns.QueryParam) error {
queryFunc := func(ctx context.Context, params *mdns.QueryParam) error {
se := serviceEntryTemplate
params.Entries <- &se
return nil
Expand Down
18 changes: 14 additions & 4 deletions pkg/discovery/options.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package discovery

import (
"context"
"net"
"sync"
"time"
Expand All @@ -24,17 +25,18 @@ type options struct {
mdnsService string
mdnsSearchEnabled bool

searchTimeout time.Duration
searchConfirm SearchConfirm
concurrency int
searchStrictTimeout bool
searchTimeout time.Duration
searchConfirm SearchConfirm
concurrency int

// deviceTTL is relevant for long-lived commands (like prometheus metrics server) when
// mixed with mDNS or other ephemeral discovery.
deviceTTL time.Duration

preferIPVersion string

mdnsQueryFunc func(*mdns.QueryParam) error
mdnsQueryFunc func(context.Context, *mdns.QueryParam) error
}

// DiscovererOption provides optional parameters for the Discoverer.
Expand Down Expand Up @@ -126,4 +128,12 @@ func WithAuthCallback(authCallback AuthCallback) DiscovererOption {
}
}

// WithSearchStrictTimeout returns a DiscovererOption that stores strictTimeoutMode
// as the Discoverer's searchStrictTimeout setting. In strict mode, devices which
// have been discovered but not yet resolved and added must finish within the
// search timeout or be cancelled.
func WithSearchStrictTimeout(strictTimeoutMode bool) DiscovererOption {
	apply := func(d *Discoverer) {
		d.searchStrictTimeout = strictTimeoutMode
	}
	return apply
}

// DeviceOption is a function that receives a *Device.
// NOTE(review): presumably an optional-parameter hook for Device, mirroring
// DiscovererOption — confirm against its callers.
type DeviceOption func(*Device)
3 changes: 2 additions & 1 deletion pkg/discovery/test_harness.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package discovery

import (
"context"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -38,7 +39,7 @@ func NewTestDiscoverer(t *testing.T, opts ...DiscovererOption) *TestDiscoverer {
}

// SetMDNSQueryFunc facilitates overriding the mDNS query function for testing.
func (td *TestDiscoverer) SetMDNSQueryFunc(q func(*mdns.QueryParam) error) {
func (td *TestDiscoverer) SetMDNSQueryFunc(q func(context.Context, *mdns.QueryParam) error) {
td.mdnsQueryFunc = q
}

Expand Down
2 changes: 2 additions & 0 deletions pkg/promserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,9 +266,11 @@ func (s *Server) Collect(ch chan<- prometheus.Metric) {
}
l.Debug().Dur("duration", duration).Msg("finished all collection")
}()
l.Debug().Msg("starting discovery")
if _, err := s.discoverer.Search(s.ctx); err != nil {
l.Err(err).Msg("finding new devices")
}
l.Debug().Dur("duration", time.Since(start)).Msg("finished discovery")
var wg sync.WaitGroup
defer wg.Wait()
concurrencyLimit := make(chan struct{}, s.concurrency)
Expand Down

0 comments on commit 69d4322

Please sign in to comment.