From c01332c19ceedf588d9a99f31555dc055dd62c14 Mon Sep 17 00:00:00 2001 From: Dimitar Dimitrov Date: Fri, 19 Apr 2024 11:37:05 +0300 Subject: [PATCH 1/6] [#58] Request device owner's approval - add config before which phases a owner consent will be needed - owner consent API - add owner consent client, used internally in UM - add owner consent agent client, to be used by the app getting the owner approval - modify orchestrator to wait for a owner consent - unit test Signed-off-by: Dimitar Dimitrov --- api/client.go | 26 ++ api/types/owner_consent.go | 29 +++ api/update_orchestrator.go | 1 + cmd/update-manager/main.go | 40 ++- config/config_internal.go | 5 +- config/config_test.go | 1 + config/flags_internal.go | 33 ++- config/flags_test.go | 101 ++++++++ config/testdata/config.json | 1 + go.mod | 2 +- mqtt/desired_state_client.go | 11 +- mqtt/desired_state_client_test.go | 4 +- mqtt/owner_consent_agent_client .go | 124 ++++++++++ mqtt/owner_consent_agent_client_test.go | 173 +++++++++++++ mqtt/owner_consent_client _test.go | 201 +++++++++++++++ mqtt/owner_consent_client.go | 114 +++++++++ mqtt/update_agent_client.go | 10 +- mqtt/util.go | 30 +++ test/mocks/client_mock.go | 232 ++++++++++++++++++ test/mocks/update_orchestrator_mock.go | 14 ++ updatem/orchestration/update_operation.go | 5 +- updatem/orchestration/update_orchestrator.go | 23 +- .../update_orchestrator_apply.go | 51 ++++ .../update_orchestrator_apply_test.go | 137 ++++++++++- .../orchestration/update_orchestrator_test.go | 38 ++- updatem/orchestration/update_phase.go | 9 + 26 files changed, 1373 insertions(+), 42 deletions(-) create mode 100644 api/types/owner_consent.go create mode 100755 mqtt/owner_consent_agent_client .go create mode 100644 mqtt/owner_consent_agent_client_test.go create mode 100755 mqtt/owner_consent_client _test.go create mode 100755 mqtt/owner_consent_client.go create mode 100644 mqtt/util.go diff --git a/api/client.go b/api/client.go index b6a40ab..b5782ba 100755 --- a/api/client.go +++ b/api/client.go @@ -54,3 +54,29 @@ type DesiredStateClient interface { SendDesiredStateCommand(string, *types.DesiredStateCommand) error SendCurrentStateGet(string) error } + +// OwnerConsentAgentHandler defines functions for handling the owner consent requests +type OwnerConsentAgentHandler interface { + HandleOwnerConsentGet(string, int64, *types.OwnerConsent) error +} + +// OwnerConsentAgentClient defines an interface for handling for owner consent requests +type OwnerConsentAgentClient interface { + BaseClient + + Start(OwnerConsentAgentHandler) error + SendOwnerConsent(string, *types.OwnerConsent) error +} + +// OwnerConsentHandler defines functions for handling the owner consent +type OwnerConsentHandler interface { + HandleOwnerConsent(string, int64, *types.OwnerConsent) error +} + +// OwnerConsentClient defines an interface for triggering requests for owner consent +type OwnerConsentClient interface { + BaseClient + + Start(OwnerConsentHandler) error + SendOwnerConsentGet(string, *types.DesiredState) error +} diff --git a/api/types/owner_consent.go b/api/types/owner_consent.go new file mode 100644 index 0000000..c05497d --- /dev/null +++ b/api/types/owner_consent.go @@ -0,0 +1,29 @@ +// Copyright (c) 2024 Contributors to the Eclipse Foundation +// +// See the NOTICE file(s) distributed with this work for additional +// information regarding copyright ownership. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License 2.0 which is available at +// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +// which is available at https://www.apache.org/licenses/LICENSE-2.0. +// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 + +package types + +// ConsentStatusType defines values for status within the owner consent +type ConsentStatusType string + +const ( + // StatusApproved denotes that the owner has consented. + StatusApproved ConsentStatusType = "APPROVED" + // StatusDenied denotes that the owner has not consented. + StatusDenied ConsentStatusType = "DENIED" +) + +// OwnerConsent defines the payload for Owner Consent response. +type OwnerConsent struct { + Status ConsentStatusType `json:"status,omitempty"` + // time field for scheduling could be added here +} diff --git a/api/update_orchestrator.go b/api/update_orchestrator.go index 8b894c9..b2eaf6c 100644 --- a/api/update_orchestrator.go +++ b/api/update_orchestrator.go @@ -23,4 +23,5 @@ type UpdateOrchestrator interface { Apply(context.Context, map[string]UpdateManager, string, *types.DesiredState, DesiredStateFeedbackHandler) bool DesiredStateFeedbackHandler + OwnerConsentHandler } diff --git a/cmd/update-manager/main.go b/cmd/update-manager/main.go index 66d57bc..6006abb 100755 --- a/cmd/update-manager/main.go +++ b/cmd/update-manager/main.go @@ -41,17 +41,9 @@ func main() { } defer loggerOut.Close() - var client api.UpdateAgentClient - if cfg.ThingsEnabled { - client, err = mqtt.NewUpdateAgentThingsClient(cfg.Domain, cfg.MQTT) - } else { - client, err = mqtt.NewUpdateAgentClient(cfg.Domain, cfg.MQTT) - } + uac, um, err := initUpdateManager(cfg) if err == nil { - updateManager, err := orchestration.NewUpdateManager(version, cfg, client, orchestration.NewUpdateOrchestrator(cfg)) - if err == nil { - err = app.Launch(cfg, client, updateManager) - } + err = app.Launch(cfg, uac, um) } if err != nil { @@ -60,3 +52,31 @@ func main() { os.Exit(1) } } + +func initUpdateManager(cfg *config.Config) (api.UpdateAgentClient, api.UpdateManager, error) { + var ( + uac api.UpdateAgentClient + occ api.OwnerConsentClient + um api.UpdateManager + err error + ) + + if cfg.ThingsEnabled { + uac, err = mqtt.NewUpdateAgentThingsClient(cfg.Domain, cfg.MQTT) + } else { + uac, err = mqtt.NewUpdateAgentClient(cfg.Domain, cfg.MQTT) + } + if err != nil { + return nil, nil, err + } + + if len(cfg.OwnerConsentPhases) != 0 { + if occ, err = mqtt.NewOwnerConsentClient(cfg.Domain, uac); err != nil { + return nil, nil, err + } + } + if um, err = orchestration.NewUpdateManager(version, cfg, uac, orchestration.NewUpdateOrchestrator(cfg, occ)); err != nil { + return nil, nil, err + } + return uac, um, nil +} diff --git a/config/config_internal.go b/config/config_internal.go index 2cb0ce3..32c44ad 100755 --- a/config/config_internal.go +++ b/config/config_internal.go @@ -12,7 +12,9 @@ package config -import "github.com/eclipse-kanto/update-manager/api" +import ( + "github.com/eclipse-kanto/update-manager/api" +) const ( // default log config @@ -42,6 +44,7 @@ type Config struct { ReportFeedbackInterval string `json:"reportFeedbackInterval"` CurrentStateDelay string `json:"currentStateDelay"` PhaseTimeout string `json:"phaseTimeout"` + OwnerConsentPhases []string `json:"ownerConsentPhases"` } func newDefaultConfig() *Config { diff --git a/config/config_test.go b/config/config_test.go index e74ffdf..6f907e4 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -136,6 +136,7 @@ func TestLoadConfigFromFile(t *testing.T) { ReportFeedbackInterval: "2m", CurrentStateDelay: "1m", PhaseTimeout: "2m", + OwnerConsentPhases: []string{"download"}, } assert.True(t, reflect.DeepEqual(*cfg, expectedConfigValues)) }) diff --git a/config/flags_internal.go b/config/flags_internal.go index 4d01620..d737fd9 100755 --- a/config/flags_internal.go +++ b/config/flags_internal.go @@ -24,14 +24,17 @@ import ( const ( // domains flag - domainsFlagID = "domains" + domainsFlagID = "domains" + domainsDesc = "Specify a comma-separated list of domains handled by the update manager" + ownerConsentPhasesFlagID = "owner-consent-phases" + ownerConsentPhasesDesc = "Specify a comma-separated list of update phase, before which an owner consent should be granted. Possible values are: 'download', 'update', 'activation'" ) // SetupAllUpdateManagerFlags adds all flags for the configuration of the update manager func SetupAllUpdateManagerFlags(flagSet *flag.FlagSet, cfg *Config) { SetupFlags(flagSet, cfg.BaseConfig) - flagSet.String(domainsFlagID, "", "Specify a comma-separated list of domains handled by the update manager") + flagSet.String(domainsFlagID, "", domainsDesc) flagSet.BoolVar(&cfg.RebootEnabled, "reboot-enabled", EnvToBool("REBOOT_ENABLED", cfg.RebootEnabled), "Specify a flag that controls the enabling/disabling of the reboot process after successful update operation") flagSet.StringVar(&cfg.RebootAfter, "reboot-after", EnvToString("REBOOT_AFTER", cfg.RebootAfter), "Specify the timeout in cron format to wait before a reboot process is initiated after successful update operation. Value should be a positive integer number followed by a unit suffix, such as '60s', '10m', etc") @@ -39,13 +42,16 @@ func SetupAllUpdateManagerFlags(flagSet *flag.FlagSet, cfg *Config) { flagSet.StringVar(&cfg.PhaseTimeout, "phase-timeout", EnvToString("PHASE_TIMEOUT", cfg.PhaseTimeout), "Specify the timeout for completing an Update Orchestration phase. Value should be a positive integer number followed by a unit suffix, such as '60s', '10m', etc") flagSet.StringVar(&cfg.ReportFeedbackInterval, "report-feedback-interval", EnvToString("REPORT_FEEDBACK_INTERVAL", cfg.ReportFeedbackInterval), "Specify the time interval for reporting intermediate desired state feedback messages during an active update operation. Value should be a positive integer number followed by a unit suffix, such as '60s', '10m', etc") flagSet.StringVar(&cfg.CurrentStateDelay, "current-state-delay", EnvToString("CURRENT_STATE_DELAY", cfg.CurrentStateDelay), "Specify the time delay for reporting current state messages. Value should be a positive integer number followed by a unit suffix, such as '60s', '10m', etc") - + flagSet.String(ownerConsentPhasesFlagID, "", "Specify a comma-separated list of update phase, before which an owner consent should be granted. Possible values are: 'download', 'update', 'activation'") setupAgentsConfigFlags(flagSet, cfg) } func parseFlags(cfg *Config, version string) { domains := parseDomainsFlag() prepareAgentsConfig(cfg, domains) + if ownerConsentPhases := parseOwnerConsentPhasesFlag(); len(ownerConsentPhases) > 0 { + cfg.OwnerConsentPhases = ownerConsentPhases + } flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) flagSet := flag.CommandLine @@ -63,11 +69,30 @@ func parseFlags(cfg *Config, version string) { } } +func parseOwnerConsentPhasesFlag() []string { + var listPhases string + flagSet := flag.NewFlagSet("", flag.ContinueOnError) + flagSet.SetOutput(io.Discard) + flagSet.StringVar(&listPhases, ownerConsentPhasesFlagID, EnvToString("OWNER_CONSENT_PHASES", ""), ownerConsentPhasesDesc) + if err := flagSet.Parse(getFlagArgs(ownerConsentPhasesFlagID)); err != nil { + logger.ErrorErr(err, "Cannot parse %s flag", ownerConsentPhasesFlagID) + } + + var result []string + for _, phase := range strings.Split(listPhases, ",") { + p := strings.TrimSpace(phase) + if len(p) > 0 { + result = append(result, p) + } + } + return result +} + func parseDomainsFlag() map[string]bool { var listDomains string flagSet := flag.NewFlagSet("", flag.ContinueOnError) flagSet.SetOutput(io.Discard) - flagSet.StringVar(&listDomains, domainsFlagID, EnvToString("DOMAINS", ""), "Specify a comma-separated list of domains handled by the update manager") + flagSet.StringVar(&listDomains, domainsFlagID, EnvToString("DOMAINS", ""), domainsDesc) if err := flagSet.Parse(getFlagArgs(domainsFlagID)); err != nil { logger.ErrorErr(err, "Cannot parse domain flag") } diff --git a/config/flags_test.go b/config/flags_test.go index ea91cc6..e668c41 100644 --- a/config/flags_test.go +++ b/config/flags_test.go @@ -17,6 +17,8 @@ import ( "fmt" "os" "reflect" + "slices" + "strings" "testing" "github.com/eclipse-kanto/update-manager/api" @@ -129,6 +131,10 @@ func TestSetupFlags(t *testing.T) { flag: "phase-timeout", expectedType: reflect.String.String(), }, + "test_flags_owner_consent_phases": { + flag: "owner-consent-phases", + expectedType: reflect.String.String(), + }, } for testName, testCase := range tests { t.Run(testName, func(t *testing.T) { @@ -230,6 +236,68 @@ func TestParseDomainsFlag(t *testing.T) { } }) } +func TestParseOwnerConsentPhasesFlag(t *testing.T) { + oldArgs := os.Args + defer func() { os.Args = oldArgs }() + testPhases := "download" + + t.Run("test_parse_consent_phases_flag_1", func(t *testing.T) { + os.Args = []string{oldArgs[0], fmt.Sprintf("-%s=%s", ownerConsentPhasesFlagID, testPhases)} + actualConsentPhases := parseOwnerConsentPhasesFlag() + if len(actualConsentPhases) != 1 && !slices.Contains(actualConsentPhases, testPhases) { + t.Error("consent phase not set") + } + }) + + t.Run("test_parse_consent_phases_flag_2", func(t *testing.T) { + os.Args = []string{oldArgs[0], fmt.Sprintf("--%s=%s", ownerConsentPhasesFlagID, testPhases)} + actualConsentPhases := parseOwnerConsentPhasesFlag() + if len(actualConsentPhases) != 1 && !slices.Contains(actualConsentPhases, testPhases) { + t.Error("consent phase not set") + } + }) + + t.Run("test_parse_consent_phases_flag_3", func(t *testing.T) { + os.Args = []string{oldArgs[0], fmt.Sprintf("-%s", ownerConsentPhasesFlagID), testPhases} + actualConsentPhases := parseOwnerConsentPhasesFlag() + if len(actualConsentPhases) != 1 && !slices.Contains(actualConsentPhases, testPhases) { + t.Error("consent phase not set") + } + }) + + t.Run("test_parse_consent_phases_flag_4", func(t *testing.T) { + os.Args = []string{oldArgs[0], fmt.Sprintf("-%s", ownerConsentPhasesFlagID), testPhases} + actualConsentPhases := parseOwnerConsentPhasesFlag() + if len(actualConsentPhases) != 1 && !slices.Contains(actualConsentPhases, testPhases) { + t.Error("consent phase not set") + } + }) + + t.Run("test_parse_consent_phase_flag_err", func(t *testing.T) { + invalidConsentPhasesFlagID := "invalid" + os.Args = []string{oldArgs[0], fmt.Sprintf("--%s=%s", invalidConsentPhasesFlagID, testPhases)} + actualConsentPhases := parseOwnerConsentPhasesFlag() + if len(actualConsentPhases) != 0 { + t.Errorf("\"incorrect value: %v , expecting: empty \"", actualConsentPhases) + } + }) + + t.Run("test_parse_consent_phases_flag_err_1", func(t *testing.T) { + os.Args = []string{oldArgs[0], fmt.Sprintf("-%s", ownerConsentPhasesFlagID)} + actualConsentPhases := parseOwnerConsentPhasesFlag() + if len(actualConsentPhases) != 0 { + t.Errorf("\"incorrect value: %v , expecting: empty \"", actualConsentPhases) + } + }) + + t.Run("test_parse_consent_phases_flag_err_2", func(t *testing.T) { + os.Args = []string{oldArgs[0], fmt.Sprintf("--%s", ownerConsentPhasesFlagID)} + actualConsentPhases := parseOwnerConsentPhasesFlag() + if len(actualConsentPhases) != 0 { + t.Errorf("\"incorrect value: %v , expecting: empty \"", actualConsentPhases) + } + }) +} func TestParseFlags(t *testing.T) { testVersion := "testVersion" @@ -372,4 +440,37 @@ func TestParseFlags(t *testing.T) { parseFlags(cfg, testVersion) assert.Equal(t, expectedAgents, cfg.Agents) }) + t.Run("test_owner_consent_phases", func(t *testing.T) { + oldArgs := os.Args + defer func() { os.Args = oldArgs }() + + testConfigPath := "../config/testdata/config.json" + expectedPhases := []string{"download"} + + os.Args = []string{oldArgs[0], fmt.Sprintf("--%s=%s", configFileFlagID, testConfigPath)} + cfg := newDefaultConfig() + configFilePath := ParseConfigFilePath() + if configFilePath != "" { + assert.NoError(t, LoadConfigFromFile(configFilePath, cfg)) + } + parseFlags(cfg, testVersion) + assert.Equal(t, expectedPhases, cfg.OwnerConsentPhases) + }) + t.Run("test_overwrite_owner_consent_phases", func(t *testing.T) { + oldArgs := os.Args + defer func() { os.Args = oldArgs }() + + testConfigPath := "../config/testdata/config.json" + expectedPhases := []string{"update", "activation"} + + os.Args = []string{oldArgs[0], fmt.Sprintf("--%s=%s", configFileFlagID, testConfigPath), + fmt.Sprintf("--%s=%s", ownerConsentPhasesFlagID, strings.Join(expectedPhases, ","))} + cfg := newDefaultConfig() + configFilePath := ParseConfigFilePath() + if configFilePath != "" { + assert.NoError(t, LoadConfigFromFile(configFilePath, cfg)) + } + parseFlags(cfg, testVersion) + assert.Equal(t, expectedPhases, cfg.OwnerConsentPhases) + }) } diff --git a/config/testdata/config.json b/config/testdata/config.json index e093e56..0e8a624 100644 --- a/config/testdata/config.json +++ b/config/testdata/config.json @@ -24,6 +24,7 @@ "reportFeedbackInterval": "2m", "currentStateDelay": "1m", "phaseTimeout": "2m", + "ownerConsentPhases": ["download"], "agents": { "self-update": { "rebootRequired": false, diff --git a/go.mod b/go.mod index b2471b6..2b37a96 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/eclipse-kanto/update-manager -go 1.17 +go 1.18 require ( github.com/eclipse/ditto-clients-golang v0.0.0-20230504175246-3e6e17510ac4 diff --git a/mqtt/desired_state_client.go b/mqtt/desired_state_client.go index d5e508c..b36b1c8 100755 --- a/mqtt/desired_state_client.go +++ b/mqtt/desired_state_client.go @@ -33,14 +33,9 @@ type desiredStateClient struct { // NewDesiredStateClient instantiates a new client for triggering MQTT requests. func NewDesiredStateClient(domain string, updateAgent api.UpdateAgentClient) (api.DesiredStateClient, error) { - var mqttClient *mqttClient - switch v := updateAgent.(type) { - case *updateAgentClient: - mqttClient = updateAgent.(*updateAgentClient).mqttClient - case *updateAgentThingsClient: - mqttClient = updateAgent.(*updateAgentThingsClient).mqttClient - default: - return nil, fmt.Errorf("Unexpected type: %T", v) + mqttClient, err := getMQTTClient(updateAgent) + if err != nil { + return nil, err } return &desiredStateClient{ mqttClient: newInternalClient(domain, mqttClient.mqttConfig, mqttClient.pahoClient), diff --git a/mqtt/desired_state_client_test.go b/mqtt/desired_state_client_test.go index fedc148..b545495 100644 --- a/mqtt/desired_state_client_test.go +++ b/mqtt/desired_state_client_test.go @@ -66,14 +66,14 @@ func TestNewDesiredStateClient(t *testing.T) { }, "test_error": { client: mockClient, - err: fmt.Sprintf("Unexpected type: %T", mockClient), + err: fmt.Sprintf("unexpected type: %T", mockClient), }, } for name, test := range tests { t.Run(name, func(t *testing.T) { client, err := NewDesiredStateClient("testDomain", test.client) if test.err != "" { - assert.EqualError(t, err, fmt.Sprintf("Unexpected type: %T", test.client)) + assert.EqualError(t, err, fmt.Sprintf("unexpected type: %T", test.client)) } else { assert.NoError(t, err) assert.NotNil(t, client) diff --git a/mqtt/owner_consent_agent_client .go b/mqtt/owner_consent_agent_client .go new file mode 100755 index 0000000..11afa99 --- /dev/null +++ b/mqtt/owner_consent_agent_client .go @@ -0,0 +1,124 @@ +// Copyright (c) 2024 Contributors to the Eclipse Foundation +// +// See the NOTICE file(s) distributed with this work for additional +// information regarding copyright ownership. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License 2.0 which is available at +// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +// which is available at https://www.apache.org/licenses/LICENSE-2.0. +// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 + +package mqtt + +import ( + "fmt" + + "github.com/eclipse-kanto/update-manager/api" + "github.com/eclipse-kanto/update-manager/api/types" + "github.com/eclipse-kanto/update-manager/logger" + + pahomqtt "github.com/eclipse/paho.mqtt.golang" + "github.com/pkg/errors" +) + +type ownerConsentAgentClient struct { + *mqttClient + domain string + handler api.OwnerConsentAgentHandler +} + +// NewOwnerConsentAgentClient instantiates a new client for triggering MQTT requests. +func NewOwnerConsentAgentClient(domain string, config *ConnectionConfig) (api.OwnerConsentAgentClient, error) { + client := &ownerConsentAgentClient{ + mqttClient: newInternalClient(domain, newInternalConnectionConfig(config), nil), + domain: domain, + } + pahoClient, err := newClient(client.mqttConfig, client.onConnect) + if err == nil { + client.pahoClient = pahoClient + } + return client, err +} + +func (client *ownerConsentAgentClient) onConnect(_ pahomqtt.Client) { + if err := client.subscribe(); err != nil { + logger.ErrorErr(err, "[%s] error subscribing for OwnerConsentGet requests", client.Domain()) + } else { + logger.Debug("[%s] subscribed for OwnerConsentGet requests", client.Domain()) + } +} + +// Start connects the client to the MQTT broker. +func (client *ownerConsentAgentClient) Start(handler api.OwnerConsentAgentHandler) error { + client.handler = handler + token := client.pahoClient.Connect() + if !token.WaitTimeout(client.mqttConfig.ConnectTimeout) { + return fmt.Errorf("[%s] connect timed out", client.Domain()) + } + return token.Error() +} + +func (client *ownerConsentAgentClient) Domain() string { + return client.domain +} + +// Stop removes the client subscription to the MQTT broker for the MQTT topics for getting owner consent. +func (client *ownerConsentAgentClient) Stop() error { + if err := client.unsubscribe(); err != nil { + logger.WarnErr(err, "[%s] error unsubscribing for OwnerConsentGet requests", client.Domain()) + } else { + logger.Debug("[%s] unsubscribed for OwnerConsentGet messages", client.Domain()) + } + client.pahoClient.Disconnect(disconnectQuiesce) + client.handler = nil + return nil +} + +func (client *ownerConsentAgentClient) subscribe() error { + logger.Debug("subscribing for '%v' topic", client.topicOwnerConsentGet) + token := client.pahoClient.Subscribe(client.topicOwnerConsentGet, 1, client.handleMessage) + if !token.WaitTimeout(client.mqttConfig.SubscribeTimeout) { + return fmt.Errorf("cannot subscribe for topic '%s' in '%v'", client.topicOwnerConsentGet, client.mqttConfig.SubscribeTimeout) + } + return token.Error() +} + +func (client *ownerConsentAgentClient) unsubscribe() error { + logger.Debug("unsubscribing from '%s' topic", client.topicOwnerConsentGet) + token := client.pahoClient.Unsubscribe(client.topicOwnerConsentGet) + if !token.WaitTimeout(client.mqttConfig.UnsubscribeTimeout) { + return fmt.Errorf("cannot unsubscribe from topic '%s' in '%v'", client.topicOwnerConsentGet, client.mqttConfig.UnsubscribeTimeout) + } + return token.Error() +} + +func (client *ownerConsentAgentClient) handleMessage(mqttClient pahomqtt.Client, message pahomqtt.Message) { + topic := message.Topic() + logger.Debug("[%s] received %s message", client.Domain(), topic) + if topic == client.topicOwnerConsentGet { + consent := &types.OwnerConsent{} + envelope, err := types.FromEnvelope(message.Payload(), consent) + if err != nil { + logger.ErrorErr(err, "[%s] cannot parse owner conset get message", client.Domain()) + return + } + if err := client.handler.HandleOwnerConsentGet(envelope.ActivityID, envelope.Timestamp, consent); err != nil { + logger.ErrorErr(err, "[%s] error processing owner consent get message", client.Domain()) + } + } +} + +func (client *ownerConsentAgentClient) SendOwnerConsent(activityID string, consent *types.OwnerConsent) error { + logger.Debug("publishing to topic '%s'", client.topicOwnerConsent) + desiredStateBytes, err := types.ToEnvelope(activityID, consent) + if err != nil { + return errors.Wrapf(err, "cannot marshal owner consent message for activity-id %s", activityID) + } + token := client.pahoClient.Publish(client.topicOwnerConsent, 1, false, desiredStateBytes) + if !token.WaitTimeout(client.mqttConfig.AcknowledgeTimeout) { + return fmt.Errorf("cannot publish to topic '%s' in '%v'", client.topicOwnerConsent, client.mqttConfig.AcknowledgeTimeout) + } + return token.Error() +} diff --git a/mqtt/owner_consent_agent_client_test.go b/mqtt/owner_consent_agent_client_test.go new file mode 100644 index 0000000..0d41f8c --- /dev/null +++ b/mqtt/owner_consent_agent_client_test.go @@ -0,0 +1,173 @@ +// Copyright (c) 2024 Contributors to the Eclipse Foundation +// +// See the NOTICE file(s) distributed with this work for additional +// information regarding copyright ownership. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License 2.0 which is available at +// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +// which is available at https://www.apache.org/licenses/LICENSE-2.0. +// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 + +package mqtt + +import ( + "errors" + "testing" + + "github.com/eclipse-kanto/update-manager/api/types" + mqttmocks "github.com/eclipse-kanto/update-manager/mqtt/mocks" + "github.com/eclipse-kanto/update-manager/test/mocks" + + pahomqtt "github.com/eclipse/paho.mqtt.golang" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" +) + +func TestOwnerConsentAgentClientStart(t *testing.T) { + tests := map[string]testCaseOutgoing{ + "test_connect_ok": {domain: "testdomain", isTimedOut: false}, + "test_connect_timeout": {domain: "mydomain", isTimedOut: true}, + } + + mockCtrl, mockPaho, mockToken := setupCommonMocks(t) + defer mockCtrl.Finish() + + mockHandler := mocks.NewMockOwnerConsentAgentHandler(mockCtrl) + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + client := &ownerConsentAgentClient{ + domain: test.domain, + mqttClient: newInternalClient(test.domain, mqttTestConfig, mockPaho), + } + + mockPaho.EXPECT().Connect().Return(mockToken) + setupMockToken(mockToken, mqttTestConfig.ConnectTimeout, test.isTimedOut) + + assertOutgoingResult(t, test.isTimedOut, client.Start(mockHandler)) + assert.Equal(t, mockHandler, client.handler) + }) + } +} + +func TestOwnerConsentAgentClientStop(t *testing.T) { + tests := map[string]testCaseOutgoing{ + //"test_disconnect_ok": {domain: "testdomain", isTimedOut: false}, + "test_disconnect_timeout": {domain: "mydomain", isTimedOut: true}, + } + + mockCtrl, mockPaho, mockToken := setupCommonMocks(t) + defer mockCtrl.Finish() + + mockHandler := mocks.NewMockOwnerConsentAgentHandler(mockCtrl) + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + client := &ownerConsentAgentClient{ + domain: test.domain, + mqttClient: newInternalClient(test.domain, mqttTestConfig, mockPaho), + handler: mockHandler, + } + + mockPaho.EXPECT().Unsubscribe(test.domain + "update/ownerconsent/get").Return(mockToken) + mockPaho.EXPECT().Disconnect(disconnectQuiesce) + setupMockToken(mockToken, mqttTestConfig.UnsubscribeTimeout, test.isTimedOut) + + assert.NoError(t, client.Stop()) + assert.Nil(t, client.handler) + }) + } +} + +func TestSendOwnerConsent(t *testing.T) { + tests := map[string]testCaseOutgoing{ + "test_send_owner_consent_ok": {domain: "testdomain", isTimedOut: false}, + "test_send_owner_consent_error": {domain: "mydomain", isTimedOut: true}, + } + + mockCtrl, mockPaho, mockToken := setupCommonMocks(t) + defer mockCtrl.Finish() + + testConsent := &types.OwnerConsent{ + Status: types.StatusApproved, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + client := &ownerConsentAgentClient{ + domain: test.domain, + mqttClient: newInternalClient(test.domain, mqttTestConfig, mockPaho), + } + mockPaho.EXPECT().Publish(test.domain+"update/ownerconsent", uint8(1), false, gomock.Any()).DoAndReturn( + func(topic string, qos byte, retained bool, payload interface{}) pahomqtt.Token { + consent := &types.OwnerConsent{} + envelope, err := types.FromEnvelope(payload.([]byte), consent) + assert.NoError(t, err) + assert.Equal(t, name, envelope.ActivityID) + assert.True(t, envelope.Timestamp > 0) + assert.Equal(t, testConsent, consent) + return mockToken + }) + setupMockToken(mockToken, mqttTestConfig.AcknowledgeTimeout, false) + + assert.NoError(t, client.SendOwnerConsent(name, testConsent)) + }) + } +} + +func TestOwnerConsentOnConnect(t *testing.T) { + mockCtrl, mockPaho, mockToken := setupCommonMocks(t) + defer mockCtrl.Finish() + + client := newInternalClient("test", mqttTestConfig, mockPaho) + + t.Run("test_onConnect", func(t *testing.T) { + mockHandler := mocks.NewMockOwnerConsentAgentHandler(mockCtrl) + client := &ownerConsentAgentClient{ + mqttClient: client, + domain: "test", + handler: mockHandler, + } + mockPaho.EXPECT().Subscribe("testupdate/ownerconsent/get", uint8(1), gomock.Any()).Return(mockToken) + setupMockToken(mockToken, mqttTestConfig.SubscribeTimeout, false) + + client.onConnect(nil) + }) +} + +func TestHandleOwnerConsentGetMessage(t *testing.T) { + tests := map[string]testCaseIncoming{ + "test_handle_owner_conset_get_ok": {domain: "testdomain", handlerError: nil, expectedJSONErr: false}, + "test_handle_owner_conset_get_error": {domain: "mydomain", handlerError: errors.New("handler error"), expectedJSONErr: false}, + "test_handle_owner_conset_get_json_error": {domain: "testdomain", handlerError: nil, expectedJSONErr: true}, + } + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + mockMessage := mqttmocks.NewMockMessage(mockCtrl) + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + testConsent := &types.OwnerConsent{ + Status: types.StatusApproved, + } + testBytes, expectedCalls := testBytesToEnvelope(t, name, testConsent, test.expectedJSONErr) + + mockHandler := mocks.NewMockOwnerConsentAgentHandler(mockCtrl) + mockHandler.EXPECT().HandleOwnerConsentGet(name, gomock.Any(), testConsent).Times(expectedCalls).Return(test.handlerError) + + client := &ownerConsentAgentClient{ + mqttClient: newInternalClient(test.domain, mqttTestConfig, nil), + domain: test.domain, + handler: mockHandler, + } + mockMessage.EXPECT().Topic().Return(test.domain + "update/ownerconsent/get") + mockMessage.EXPECT().Payload().Return(testBytes) + + client.handleMessage(nil, mockMessage) + }) + } +} diff --git a/mqtt/owner_consent_client _test.go b/mqtt/owner_consent_client _test.go new file mode 100755 index 0000000..9bfdf0f --- /dev/null +++ b/mqtt/owner_consent_client _test.go @@ -0,0 +1,201 @@ +// Copyright (c) 2024 Contributors to the Eclipse Foundation +// +// See the NOTICE file(s) distributed with this work for additional +// information regarding copyright ownership. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License 2.0 which is available at +// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +// which is available at https://www.apache.org/licenses/LICENSE-2.0. +// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 + +package mqtt + +import ( + "fmt" + "testing" + + "github.com/eclipse-kanto/update-manager/api" + "github.com/eclipse-kanto/update-manager/api/types" + clientsmocks "github.com/eclipse-kanto/update-manager/mqtt/mocks" + "github.com/eclipse-kanto/update-manager/test/mocks" + pahomqtt "github.com/eclipse/paho.mqtt.golang" + "github.com/golang/mock/gomock" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" +) + +func TestNewOwnerConsentClient(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + mockPaho := clientsmocks.NewMockClient(mockCtrl) + mockClient := mocks.NewMockUpdateAgentClient(mockCtrl) + + tests := map[string]struct { + client api.UpdateAgentClient + err string + }{ + "test_update_agent_client": { + client: &updateAgentClient{ + mqttClient: newInternalClient("testDomain", &internalConnectionConfig{}, mockPaho), + }, + }, + "test_update_agent_things_client": { + client: &updateAgentThingsClient{ + updateAgentClient: &updateAgentClient{ + mqttClient: newInternalClient("testDomain", &internalConnectionConfig{}, mockPaho), + }, + }, + }, + "test_error": { + client: mockClient, + err: fmt.Sprintf("unexpected type: %T", mockClient), + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + client, err := NewDesiredStateClient("testDomain", test.client) + if test.err != "" { + assert.EqualError(t, err, fmt.Sprintf("unexpected type: %T", test.client)) + } else { + assert.NoError(t, err) + assert.NotNil(t, client) + } + }) + } +} + +func TestOwnerConsentClientStart(t *testing.T) { + tests := map[string]testCaseOutgoing{ + "test_subscribe_ok": {domain: "testdomain", isTimedOut: false}, + "test_subscribe_timeout": {domain: "mydomain", isTimedOut: true}, + } + + mockCtrl, mockPaho, mockToken := setupCommonMocks(t) + defer mockCtrl.Finish() + + mockHandler := mocks.NewMockOwnerConsentHandler(mockCtrl) + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + client := &ownerConsentClient{ + mqttClient: newInternalClient(test.domain, mqttTestConfig, mockPaho), + domain: test.domain, + } + mockPaho.EXPECT().Subscribe(test.domain+"update/ownerconsent", uint8(1), gomock.Any()).Return(mockToken) + setupMockToken(mockToken, mqttTestConfig.SubscribeTimeout, test.isTimedOut) + + assertOutgoingResult(t, test.isTimedOut, client.Start(mockHandler)) + if test.isTimedOut { + assert.Nil(t, client.handler) + } else { + assert.Equal(t, mockHandler, client.handler) + } + }) + } +} + +func TestOwnerConsentClientStop(t *testing.T) { + tests := map[string]testCaseOutgoing{ + "test_unsubscribe_ok": {domain: "testdomain", isTimedOut: false}, + "test_unsubscribe_timeout": {domain: "mydomain", isTimedOut: true}, + } + + mockCtrl, mockPaho, mockToken := setupCommonMocks(t) + defer mockCtrl.Finish() + + mockHandler := mocks.NewMockOwnerConsentHandler(mockCtrl) + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + client := &ownerConsentClient{ + mqttClient: newInternalClient(test.domain, mqttTestConfig, mockPaho), + domain: test.domain, + handler: mockHandler, + } + mockPaho.EXPECT().Unsubscribe(test.domain + "update/ownerconsent").Return(mockToken) + setupMockToken(mockToken, mqttTestConfig.UnsubscribeTimeout, test.isTimedOut) + + assertOutgoingResult(t, test.isTimedOut, client.Stop()) + if test.isTimedOut { + assert.Equal(t, mockHandler, client.handler) + } else { + assert.Nil(t, client.handler) + } + }) + } +} + +func TestSendOwnerConsentGet(t *testing.T) { + tests := map[string]testCaseOutgoing{ + "test_send_owner_consent_get_ok": {domain: "testdomain", isTimedOut: false}, + "test_send_owner_consent_get_error": {domain: "mydomain", isTimedOut: true}, + } + + mockCtrl, mockPaho, mockToken := setupCommonMocks(t) + defer mockCtrl.Finish() + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + testDesiredState := &types.DesiredState{ + Domains: []*types.Domain{ + {ID: test.domain}, + }, + } + client, _ := NewOwnerConsentClient(test.domain, &updateAgentClient{ + mqttClient: newInternalClient("testDomain", mqttTestConfig, mockPaho), + }) + mockPaho.EXPECT().Publish(test.domain+"update/ownerconsent/get", uint8(1), false, gomock.Any()).DoAndReturn( + func(topic string, qos byte, retained bool, payload interface{}) pahomqtt.Token { + desiresState := &types.DesiredState{} + envelope, err := types.FromEnvelope(payload.([]byte), desiresState) + assert.NoError(t, err) + assert.Equal(t, name, envelope.ActivityID) + assert.True(t, envelope.Timestamp > 0) + assert.Equal(t, testDesiredState, desiresState) + return mockToken + }) + setupMockToken(mockToken, mqttTestConfig.AcknowledgeTimeout, test.isTimedOut) + + assertOutgoingResult(t, test.isTimedOut, client.SendOwnerConsentGet(name, testDesiredState)) + }) + } +} + +func TestHandleOwnerConsentMessage(t *testing.T) { + tests := map[string]testCaseIncoming{ + "test_handle_owner_consent_ok": {domain: "testdomain", handlerError: nil, expectedJSONErr: false}, + "test_handle_owner_consent_error": {domain: "mydomain", handlerError: errors.New("handler error"), expectedJSONErr: false}, + "test_handle_owner_consent_json_error": {domain: "testdomain", handlerError: nil, expectedJSONErr: true}, + } + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + mockMessage := clientsmocks.NewMockMessage(mockCtrl) + + testConsent := &types.OwnerConsent{ + Status: types.StatusApproved, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + testBytes, expectedCalls := testBytesToEnvelope(t, name, testConsent, test.expectedJSONErr) + + handler := mocks.NewMockOwnerConsentHandler(mockCtrl) + handler.EXPECT().HandleOwnerConsent(name, gomock.Any(), testConsent).Times(expectedCalls).Return(test.handlerError) + + client := &ownerConsentClient{ + mqttClient: newInternalClient(test.domain, &internalConnectionConfig{}, nil), + domain: test.domain, + handler: handler, + } + mockMessage.EXPECT().Topic().Return(test.domain + "update/ownerconsent") + mockMessage.EXPECT().Payload().Return(testBytes) + + client.handleMessage(nil, mockMessage) + }) + } +} diff --git a/mqtt/owner_consent_client.go b/mqtt/owner_consent_client.go new file mode 100755 index 0000000..27238cc --- /dev/null +++ b/mqtt/owner_consent_client.go @@ -0,0 +1,114 @@ +// Copyright (c) 2024 Contributors to the Eclipse Foundation +// +// See the NOTICE file(s) distributed with this work for additional +// information regarding copyright ownership. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License 2.0 which is available at +// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +// which is available at https://www.apache.org/licenses/LICENSE-2.0. +// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 + +package mqtt + +import ( + "fmt" + + "github.com/eclipse-kanto/update-manager/api" + "github.com/eclipse-kanto/update-manager/api/types" + "github.com/eclipse-kanto/update-manager/logger" + + pahomqtt "github.com/eclipse/paho.mqtt.golang" + "github.com/pkg/errors" +) + +type ownerConsentClient struct { + *mqttClient + domain string + handler api.OwnerConsentHandler +} + +// NewOwnerConsentClient instantiates a new client for triggering MQTT requests. +func NewOwnerConsentClient(domain string, updateAgent api.UpdateAgentClient) (api.OwnerConsentClient, error) { + mqttClient, err := getMQTTClient(updateAgent) + if err != nil { + return nil, err + } + return &ownerConsentClient{ + mqttClient: newInternalClient(domain, mqttClient.mqttConfig, mqttClient.pahoClient), + domain: domain, + }, nil +} + +func (client *ownerConsentClient) Domain() string { + return client.domain +} + +// Start makes a client subscription to the MQTT broker for the MQTT topics for consent. +func (client *ownerConsentClient) Start(consentHandler api.OwnerConsentHandler) error { + client.handler = consentHandler + if err := client.subscribe(); err != nil { + client.handler = nil + return fmt.Errorf("[%s] error subscribing for OwnerConsent messages: %w", client.Domain(), err) + } + logger.Debug("[%s] subscribed for OwnerConsent messages", client.Domain()) + return nil +} + +// Stop removes the client subscription to the MQTT broker for the MQTT topics for owner consent. +func (client *ownerConsentClient) Stop() error { + if err := client.unsubscribe(); err != nil { + return fmt.Errorf("[%s] error unsubscribing for OwnerConsent messages: %w", client.Domain(), err) + } + logger.Debug("[%s] unsubscribed for OwnerConsent messages", client.Domain()) + client.handler = nil + return nil +} + +func (client *ownerConsentClient) subscribe() error { + logger.Debug("subscribing for '%v' topic", client.topicOwnerConsent) + token := client.pahoClient.Subscribe(client.topicOwnerConsent, 1, client.handleMessage) + if !token.WaitTimeout(client.mqttConfig.SubscribeTimeout) { + return fmt.Errorf("cannot subscribe for topic '%s' in '%v'", client.topicOwnerConsent, client.mqttConfig.SubscribeTimeout) + } + return token.Error() +} + +func (client *ownerConsentClient) unsubscribe() error { + logger.Debug("unsubscribing from '%s' topic", client.topicOwnerConsent) + token := client.pahoClient.Unsubscribe(client.topicOwnerConsent) + if !token.WaitTimeout(client.mqttConfig.UnsubscribeTimeout) { + return fmt.Errorf("cannot unsubscribe from topic '%s' in '%v'", client.topicOwnerConsent, client.mqttConfig.UnsubscribeTimeout) + } + return token.Error() +} + +func (client *ownerConsentClient) handleMessage(mqttClient pahomqtt.Client, message pahomqtt.Message) { + topic := message.Topic() + logger.Debug("[%s] received %s message", client.Domain(), topic) + if topic == client.topicOwnerConsent { + ownerConsent := &types.OwnerConsent{} + envelope, err := types.FromEnvelope(message.Payload(), ownerConsent) + if err != nil { + logger.ErrorErr(err, "[%s] cannot parse owner consent message", client.Domain()) + return + } + if err := client.handler.HandleOwnerConsent(envelope.ActivityID, envelope.Timestamp, ownerConsent); err != nil { + logger.ErrorErr(err, "[%s] error processing owner consent message", client.Domain()) + } + } +} + +func (client *ownerConsentClient) SendOwnerConsentGet(activityID string, desiredState *types.DesiredState) error { + logger.Debug("publishing to topic '%s'", client.topicOwnerConsentGet) + desiredStateBytes, err := types.ToEnvelope(activityID, desiredState) + if err != nil { + return errors.Wrapf(err, "cannot marshal owner consent get message for activity-id %s", activityID) + } + token := client.pahoClient.Publish(client.topicOwnerConsentGet, 1, false, desiredStateBytes) + if !token.WaitTimeout(client.mqttConfig.AcknowledgeTimeout) { + return fmt.Errorf("cannot publish to topic '%s' in '%v'", client.topicOwnerConsentGet, client.mqttConfig.AcknowledgeTimeout) + } + return token.Error() +} diff --git a/mqtt/update_agent_client.go b/mqtt/update_agent_client.go index b33b2f4..1b1575b 100755 --- a/mqtt/update_agent_client.go +++ b/mqtt/update_agent_client.go @@ -37,6 +37,8 @@ const ( suffixCurrentState = "/currentstate" suffixCurrentStateGet = "/currentstate/get" suffixDesiredStateFeedback = "/desiredstatefeedback" + suffixOwnerConsentGet = "/ownerconsent/get" + suffixOwnerConsent = "/ownerconsent" disconnectQuiesce uint = 10000 ) @@ -77,13 +79,15 @@ type mqttClient struct { mqttConfig *internalConnectionConfig pahoClient pahomqtt.Client - // incoming topics + // UM incoming topics topicCurrentState string topicDesiredStateFeedback string - // outgoing topics + topicOwnerConsent string + // UM outgoing topics topicDesiredState string topicDesiredStateCommand string topicCurrentStateGet string + topicOwnerConsentGet string } func newInternalClient(domain string, config *internalConnectionConfig, pahoClient pahomqtt.Client) *mqttClient { @@ -97,6 +101,8 @@ func newInternalClient(domain string, config *internalConnectionConfig, pahoClie topicDesiredState: mqttPrefix + suffixDesiredState, topicDesiredStateCommand: mqttPrefix + suffixDesiredStateCommand, topicDesiredStateFeedback: mqttPrefix + suffixDesiredStateFeedback, + topicOwnerConsent: mqttPrefix + suffixOwnerConsent, + topicOwnerConsentGet: mqttPrefix + suffixOwnerConsentGet, } } diff --git a/mqtt/util.go b/mqtt/util.go new file mode 100644 index 0000000..aee649e --- /dev/null +++ b/mqtt/util.go @@ -0,0 +1,30 @@ +// Copyright (c) 2024 Contributors to the Eclipse Foundation +// +// See the NOTICE file(s) distributed with this work for additional +// information regarding copyright ownership. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License 2.0 which is available at +// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +// which is available at https://www.apache.org/licenses/LICENSE-2.0. +// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 + +package mqtt + +import ( + "fmt" + + "github.com/eclipse-kanto/update-manager/api" +) + +func getMQTTClient(client api.UpdateAgentClient) (*mqttClient, error) { + switch v := client.(type) { + case *updateAgentClient: + return client.(*updateAgentClient).mqttClient, nil + case *updateAgentThingsClient: + return client.(*updateAgentThingsClient).mqttClient, nil + default: + return nil, fmt.Errorf("unexpected type: %T", v) + } +} diff --git a/test/mocks/client_mock.go b/test/mocks/client_mock.go index b36a298..f136a53 100644 --- a/test/mocks/client_mock.go +++ b/test/mocks/client_mock.go @@ -390,3 +390,235 @@ func (mr *MockDesiredStateClientMockRecorder) Stop() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockDesiredStateClient)(nil).Stop)) } + +// MockOwnerConsentAgentHandler is a mock of OwnerConsentAgentHandler interface. +type MockOwnerConsentAgentHandler struct { + ctrl *gomock.Controller + recorder *MockOwnerConsentAgentHandlerMockRecorder +} + +// MockOwnerConsentAgentHandlerMockRecorder is the mock recorder for MockOwnerConsentAgentHandler. +type MockOwnerConsentAgentHandlerMockRecorder struct { + mock *MockOwnerConsentAgentHandler +} + +// NewMockOwnerConsentAgentHandler creates a new mock instance. +func NewMockOwnerConsentAgentHandler(ctrl *gomock.Controller) *MockOwnerConsentAgentHandler { + mock := &MockOwnerConsentAgentHandler{ctrl: ctrl} + mock.recorder = &MockOwnerConsentAgentHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockOwnerConsentAgentHandler) EXPECT() *MockOwnerConsentAgentHandlerMockRecorder { + return m.recorder +} + +// HandleOwnerConsentGet mocks base method. +func (m *MockOwnerConsentAgentHandler) HandleOwnerConsentGet(arg0 string, arg1 int64, arg2 *types.OwnerConsent) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HandleOwnerConsentGet", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// HandleOwnerConsentGet indicates an expected call of HandleOwnerConsentGet. +func (mr *MockOwnerConsentAgentHandlerMockRecorder) HandleOwnerConsentGet(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleOwnerConsentGet", reflect.TypeOf((*MockOwnerConsentAgentHandler)(nil).HandleOwnerConsentGet), arg0, arg1, arg2) +} + +// MockOwnerConsentAgentClient is a mock of OwnerConsentAgentClient interface. +type MockOwnerConsentAgentClient struct { + ctrl *gomock.Controller + recorder *MockOwnerConsentAgentClientMockRecorder +} + +// MockOwnerConsentAgentClientMockRecorder is the mock recorder for MockOwnerConsentAgentClient. +type MockOwnerConsentAgentClientMockRecorder struct { + mock *MockOwnerConsentAgentClient +} + +// NewMockOwnerConsentAgentClient creates a new mock instance. +func NewMockOwnerConsentAgentClient(ctrl *gomock.Controller) *MockOwnerConsentAgentClient { + mock := &MockOwnerConsentAgentClient{ctrl: ctrl} + mock.recorder = &MockOwnerConsentAgentClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockOwnerConsentAgentClient) EXPECT() *MockOwnerConsentAgentClientMockRecorder { + return m.recorder +} + +// Domain mocks base method. +func (m *MockOwnerConsentAgentClient) Domain() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Domain") + ret0, _ := ret[0].(string) + return ret0 +} + +// Domain indicates an expected call of Domain. +func (mr *MockOwnerConsentAgentClientMockRecorder) Domain() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Domain", reflect.TypeOf((*MockOwnerConsentAgentClient)(nil).Domain)) +} + +// SendOwnerConsent mocks base method. +func (m *MockOwnerConsentAgentClient) SendOwnerConsent(arg0 string, arg1 *types.OwnerConsent) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendOwnerConsent", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendOwnerConsent indicates an expected call of SendOwnerConsent. +func (mr *MockOwnerConsentAgentClientMockRecorder) SendOwnerConsent(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendOwnerConsent", reflect.TypeOf((*MockOwnerConsentAgentClient)(nil).SendOwnerConsent), arg0, arg1) +} + +// Start mocks base method. +func (m *MockOwnerConsentAgentClient) Start(arg0 api.OwnerConsentAgentHandler) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Start", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Start indicates an expected call of Start. +func (mr *MockOwnerConsentAgentClientMockRecorder) Start(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockOwnerConsentAgentClient)(nil).Start), arg0) +} + +// Stop mocks base method. +func (m *MockOwnerConsentAgentClient) Stop() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stop") + ret0, _ := ret[0].(error) + return ret0 +} + +// Stop indicates an expected call of Stop. +func (mr *MockOwnerConsentAgentClientMockRecorder) Stop() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockOwnerConsentAgentClient)(nil).Stop)) +} + +// MockOwnerConsentHandler is a mock of OwnerConsentHandler interface. +type MockOwnerConsentHandler struct { + ctrl *gomock.Controller + recorder *MockOwnerConsentHandlerMockRecorder +} + +// MockOwnerConsentHandlerMockRecorder is the mock recorder for MockOwnerConsentHandler. +type MockOwnerConsentHandlerMockRecorder struct { + mock *MockOwnerConsentHandler +} + +// NewMockOwnerConsentHandler creates a new mock instance. +func NewMockOwnerConsentHandler(ctrl *gomock.Controller) *MockOwnerConsentHandler { + mock := &MockOwnerConsentHandler{ctrl: ctrl} + mock.recorder = &MockOwnerConsentHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockOwnerConsentHandler) EXPECT() *MockOwnerConsentHandlerMockRecorder { + return m.recorder +} + +// HandleOwnerConsent mocks base method. +func (m *MockOwnerConsentHandler) HandleOwnerConsent(arg0 string, arg1 int64, arg2 *types.OwnerConsent) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HandleOwnerConsent", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// HandleOwnerConsent indicates an expected call of HandleOwnerConsent. +func (mr *MockOwnerConsentHandlerMockRecorder) HandleOwnerConsent(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleOwnerConsent", reflect.TypeOf((*MockOwnerConsentHandler)(nil).HandleOwnerConsent), arg0, arg1, arg2) +} + +// MockOwnerConsentClient is a mock of OwnerConsentClient interface. +type MockOwnerConsentClient struct { + ctrl *gomock.Controller + recorder *MockOwnerConsentClientMockRecorder +} + +// MockOwnerConsentClientMockRecorder is the mock recorder for MockOwnerConsentClient. +type MockOwnerConsentClientMockRecorder struct { + mock *MockOwnerConsentClient +} + +// NewMockOwnerConsentClient creates a new mock instance. +func NewMockOwnerConsentClient(ctrl *gomock.Controller) *MockOwnerConsentClient { + mock := &MockOwnerConsentClient{ctrl: ctrl} + mock.recorder = &MockOwnerConsentClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockOwnerConsentClient) EXPECT() *MockOwnerConsentClientMockRecorder { + return m.recorder +} + +// Domain mocks base method. +func (m *MockOwnerConsentClient) Domain() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Domain") + ret0, _ := ret[0].(string) + return ret0 +} + +// Domain indicates an expected call of Domain. +func (mr *MockOwnerConsentClientMockRecorder) Domain() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Domain", reflect.TypeOf((*MockOwnerConsentClient)(nil).Domain)) +} + +// SendOwnerConsentGet mocks base method. +func (m *MockOwnerConsentClient) SendOwnerConsentGet(arg0 string, arg1 *types.DesiredState) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendOwnerConsentGet", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendOwnerConsentGet indicates an expected call of SendOwnerConsentGet. +func (mr *MockOwnerConsentClientMockRecorder) SendOwnerConsentGet(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendOwnerConsentGet", reflect.TypeOf((*MockOwnerConsentClient)(nil).SendOwnerConsentGet), arg0, arg1) +} + +// Start mocks base method. +func (m *MockOwnerConsentClient) Start(arg0 api.OwnerConsentHandler) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Start", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Start indicates an expected call of Start. +func (mr *MockOwnerConsentClientMockRecorder) Start(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockOwnerConsentClient)(nil).Start), arg0) +} + +// Stop mocks base method. +func (m *MockOwnerConsentClient) Stop() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stop") + ret0, _ := ret[0].(error) + return ret0 +} + +// Stop indicates an expected call of Stop. +func (mr *MockOwnerConsentClientMockRecorder) Stop() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockOwnerConsentClient)(nil).Stop)) +} \ No newline at end of file diff --git a/test/mocks/update_orchestrator_mock.go b/test/mocks/update_orchestrator_mock.go index b8ddd6e..c32654f 100644 --- a/test/mocks/update_orchestrator_mock.go +++ b/test/mocks/update_orchestrator_mock.go @@ -73,3 +73,17 @@ func (mr *MockUpdateOrchestratorMockRecorder) HandleDesiredStateFeedbackEvent(do mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleDesiredStateFeedbackEvent", reflect.TypeOf((*MockUpdateOrchestrator)(nil).HandleDesiredStateFeedbackEvent), domain, activityID, baseline, status, message, actions) } + +// HandleOwnerConsent mocks base method. +func (m *MockUpdateOrchestrator) HandleOwnerConsent(arg0 string, arg1 int64, arg2 *types.OwnerConsent) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HandleOwnerConsent", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// HandleOwnerConsent indicates an expected call of HandleOwnerConsent. +func (mr *MockUpdateOrchestratorMockRecorder) HandleOwnerConsent(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleOwnerConsent", reflect.TypeOf((*MockUpdateOrchestrator)(nil).HandleOwnerConsent), arg0, arg1, arg2) +} diff --git a/updatem/orchestration/update_operation.go b/updatem/orchestration/update_operation.go index c5e7445..3b515e0 100644 --- a/updatem/orchestration/update_operation.go +++ b/updatem/orchestration/update_operation.go @@ -39,6 +39,8 @@ type updateOperation struct { errChan chan bool errMsg string + ownerConsented chan bool + rebootRequired bool desiredStateCallback api.DesiredStateFeedbackHandler @@ -74,7 +76,8 @@ func newUpdateOperation(domainAgents map[string]api.UpdateManager, activityID st desiredState: desiredState, phaseChannels: generatePhaseChannels(), - errChan: make(chan bool, 1), + errChan: make(chan bool, 1), + ownerConsented: make(chan bool), desiredStateCallback: desiredStateCallback, }, nil diff --git a/updatem/orchestration/update_orchestrator.go b/updatem/orchestration/update_orchestrator.go index 54a9f90..138e2b9 100644 --- a/updatem/orchestration/update_orchestrator.go +++ b/updatem/orchestration/update_orchestrator.go @@ -28,8 +28,9 @@ type updateOrchestrator struct { operationLock sync.Mutex actionsLock sync.Mutex - cfg *config.Config - phaseTimeout time.Duration + cfg *config.Config + phaseTimeout time.Duration + ownerConsentClient api.OwnerConsentClient operation *updateOperation } @@ -39,11 +40,13 @@ func (orchestrator *updateOrchestrator) Name() string { } // NewUpdateOrchestrator creates a new update orchestrator that does not handle cross-domain dependencies -func NewUpdateOrchestrator(cfg *config.Config) api.UpdateOrchestrator { - return &updateOrchestrator{ - cfg: cfg, - phaseTimeout: util.ParseDuration("phase-timeout", cfg.PhaseTimeout, 10*time.Minute, 10*time.Minute), +func NewUpdateOrchestrator(cfg *config.Config, ownerApprovalClient api.OwnerConsentClient) api.UpdateOrchestrator { + ua := &updateOrchestrator{ + cfg: cfg, + phaseTimeout: util.ParseDuration("phase-timeout", cfg.PhaseTimeout, 10*time.Minute, 10*time.Minute), + ownerConsentClient: ownerApprovalClient, } + return ua } // Apply is called by the update manager. @@ -77,3 +80,11 @@ func (orchestrator *updateOrchestrator) Apply(ctx context.Context, domainAgents rebootRequired, applyErr := orchestrator.apply(ctx) return rebootRequired } + +func (orchestrator *updateOrchestrator) HandleOwnerConsent(activityID string, timestamp int64, consent *types.OwnerConsent) error { + if orchestrator.operation != nil && activityID == orchestrator.operation.activityID { + logger.Info("owner consent received with status: %v, timestamp: %d", consent.Status, timestamp) + orchestrator.operation.ownerConsented <- consent.Status == types.StatusApproved + } + return nil +} diff --git a/updatem/orchestration/update_orchestrator_apply.go b/updatem/orchestration/update_orchestrator_apply.go index bc06393..81e859c 100644 --- a/updatem/orchestration/update_orchestrator_apply.go +++ b/updatem/orchestration/update_orchestrator_apply.go @@ -15,6 +15,7 @@ package orchestration import ( "context" "fmt" + "slices" "time" "github.com/eclipse-kanto/update-manager/api" @@ -75,6 +76,13 @@ func handlePhaseCompletion(ctx context.Context, completedPhase phase, orchestrat return } + if err := orchestrator.getOwnerConsent(ctx, completedPhase); err != nil { + // should a rollback be performed at this point? + orchestrator.operation.errChan <- true + orchestrator.operation.errMsg = err.Error() + return + } + executeCommand := func(status types.StatusType, command types.CommandType) { for domain, domainStatus := range orchestrator.operation.domains { if domainStatus == status { @@ -82,6 +90,7 @@ func handlePhaseCompletion(ctx context.Context, completedPhase phase, orchestrat } } } + switch completedPhase { case phaseIdentification: executeCommand(types.StatusIdentified, types.CommandDownload) @@ -91,11 +100,53 @@ func handlePhaseCompletion(ctx context.Context, completedPhase phase, orchestrat executeCommand(types.BaselineStatusUpdateSuccess, types.CommandActivate) case phaseActivation: executeCommand(types.BaselineStatusActivationSuccess, types.CommandCleanup) + case phaseCleanup: + // nothing to do default: logger.Error("unknown phase %s", completedPhase) } } +func (orchestrator *updateOrchestrator) getOwnerConsent(ctx context.Context, completedPhase phase) error { + nextPhase := completedPhase.next() + if nextPhase == "" || !slices.Contains(orchestrator.cfg.OwnerConsentPhases, string(nextPhase)) { + return nil + } + if nextPhase == phaseCleanup || nextPhase == phaseIdentification { + // no need for owner consent + return nil + } + + if orchestrator.ownerConsentClient == nil { + return fmt.Errorf("owner consent client not available") + } + + if err := orchestrator.ownerConsentClient.Start(orchestrator); err != nil { + return err + } + defer func() { + if err := orchestrator.ownerConsentClient.Stop(); err != nil { + logger.Error("failed to stop owner consent client: %v", err) + } + }() + + if err := orchestrator.ownerConsentClient.SendOwnerConsentGet(orchestrator.operation.activityID, orchestrator.operation.desiredState); err != nil { + return err + } + + select { + case approved := <-orchestrator.operation.ownerConsented: + if !approved { + return fmt.Errorf("owner approval not granted") + } + return nil + case <-time.After(orchestrator.phaseTimeout): + return fmt.Errorf("owner consent not granted in %v", orchestrator.phaseTimeout) + case <-ctx.Done(): + return fmt.Errorf("the update manager instance is terminated") + } +} + func (orchestrator *updateOrchestrator) command(ctx context.Context, activityID, domain string, commandName types.CommandType) { domainAgent := orchestrator.getDomainAgent(domain) if domainAgent == nil { diff --git a/updatem/orchestration/update_orchestrator_apply_test.go b/updatem/orchestration/update_orchestrator_apply_test.go index 3a2869f..00b544e 100644 --- a/updatem/orchestration/update_orchestrator_apply_test.go +++ b/updatem/orchestration/update_orchestrator_apply_test.go @@ -21,6 +21,7 @@ import ( "github.com/eclipse-kanto/update-manager/api" "github.com/eclipse-kanto/update-manager/api/types" + "github.com/eclipse-kanto/update-manager/config" "github.com/eclipse-kanto/update-manager/test" "github.com/eclipse-kanto/update-manager/test/mocks" "github.com/golang/mock/gomock" @@ -260,11 +261,12 @@ func TestHandlePhaseCompletion(t *testing.T) { testDomain1: types.StatusIdentifying, testDomain2: types.StatusIdentifying, }, + errChan: make(chan bool, 1), } operation.statesPerDomain = map[api.UpdateManager]*types.DesiredState{ mockUpdateManager: {}, } - orchestrator := &updateOrchestrator{} + orchestrator := &updateOrchestrator{cfg: &config.Config{}} mockCommand := func(mockUpdateManager *mocks.MockUpdateManager, command types.CommandType, domains ...string) func() { return func() { @@ -276,11 +278,12 @@ func TestHandlePhaseCompletion(t *testing.T) { } testCases := map[string]struct { - noOperation bool - domainStatus1 types.StatusType - domainStatus2 types.StatusType - phase phase - expectedCalls func() + noOperation bool + noConsentClient bool + domainStatus1 types.StatusType + domainStatus2 types.StatusType + phase phase + expectedCalls func() }{ "test_handle_phase_completion_identify": { domainStatus1: types.StatusIdentified, @@ -327,9 +330,21 @@ func TestHandlePhaseCompletion(t *testing.T) { phase: phase("unknown"), expectedCalls: func() {}, }, + "test_handle_phase_completion_consent_error": { + noConsentClient: true, + domainStatus1: types.StatusIdentified, + phase: phaseIdentification, + expectedCalls: func() {}, + }, } for testName, testCase := range testCases { t.Run(testName, func(t *testing.T) { + if testCase.noConsentClient { + orchestrator.cfg.OwnerConsentPhases = []string{"download"} + go func() { + <-orchestrator.operation.errChan + }() + } if testCase.noOperation { orchestrator.operation = nil } else { @@ -400,9 +415,11 @@ func TestSetupUpdateOperation(t *testing.T) { assert.NotNil(t, orchestrator.operation.phaseChannels) assert.NotNil(t, orchestrator.operation.errChan) + assert.NotNil(t, orchestrator.operation.ownerConsented) orchestrator.operation.errChan = nil orchestrator.operation.phaseChannels = nil + orchestrator.operation.ownerConsented = nil assert.Equal(t, expectedOp, orchestrator.operation) assert.Nil(t, err) @@ -448,3 +465,111 @@ func TestDisposeUpdateOperation(t *testing.T) { assert.Nil(t, orchestrator.operation) }) } + +func TestGetOwnerConsent(t *testing.T) { + tests := map[string]struct { + updateOrchestrator *updateOrchestrator + currentPhase phase + expectedErr error + mock func(*gomock.Controller) (*mocks.MockOwnerConsentClient, chan bool) + }{ + "test_no_next_phase": { + updateOrchestrator: &updateOrchestrator{}, + currentPhase: phaseCleanup, + }, + "test_consent_not_needed": { + updateOrchestrator: &updateOrchestrator{cfg: &config.Config{OwnerConsentPhases: []string{"download"}}}, + currentPhase: phaseUpdate, + }, + "test_no_consent_for_cleanup": { + updateOrchestrator: &updateOrchestrator{cfg: &config.Config{OwnerConsentPhases: []string{"cleanup"}}}, + currentPhase: phaseActivation, + }, + "test_no_owner_consent_client": { + updateOrchestrator: &updateOrchestrator{cfg: &config.Config{OwnerConsentPhases: []string{"download"}}}, + currentPhase: phaseIdentification, + expectedErr: fmt.Errorf("owner consent client not available"), + }, + "test_owner_consent_client_start_err": { + updateOrchestrator: &updateOrchestrator{cfg: &config.Config{OwnerConsentPhases: []string{"download"}}}, + currentPhase: phaseIdentification, + expectedErr: fmt.Errorf("start error"), + mock: func(ctrl *gomock.Controller) (*mocks.MockOwnerConsentClient, chan bool) { + mockClient := mocks.NewMockOwnerConsentClient(ctrl) + mockClient.EXPECT().Start(gomock.Any()).Return(fmt.Errorf("start error")) + return mockClient, nil + }, + }, + "test_owner_consent_client_send_err": { + updateOrchestrator: &updateOrchestrator{cfg: &config.Config{OwnerConsentPhases: []string{"download"}}}, + currentPhase: phaseIdentification, + expectedErr: fmt.Errorf("send error"), + mock: func(ctrl *gomock.Controller) (*mocks.MockOwnerConsentClient, chan bool) { + mockClient := mocks.NewMockOwnerConsentClient(ctrl) + mockClient.EXPECT().Start(gomock.Any()).Return(nil) + mockClient.EXPECT().Stop().Return(nil) + mockClient.EXPECT().SendOwnerConsentGet(test.ActivityID, gomock.Any()).Return(fmt.Errorf("send error")) + return mockClient, nil + }, + }, + "test_owner_consent_approved": { + updateOrchestrator: &updateOrchestrator{cfg: &config.Config{OwnerConsentPhases: []string{"download"}}}, + currentPhase: phaseIdentification, + mock: func(ctrl *gomock.Controller) (*mocks.MockOwnerConsentClient, chan bool) { + mockClient := mocks.NewMockOwnerConsentClient(ctrl) + mockClient.EXPECT().Start(gomock.Any()).Return(nil) + mockClient.EXPECT().Stop().Return(nil) + mockClient.EXPECT().SendOwnerConsentGet(test.ActivityID, gomock.Any()).Return(nil) + ch := make(chan bool) + go func() { + ch <- true + }() + return mockClient, ch + }, + }, + "test_owner_consent_denied": { + updateOrchestrator: &updateOrchestrator{cfg: &config.Config{OwnerConsentPhases: []string{"download"}}}, + currentPhase: phaseIdentification, + expectedErr: fmt.Errorf("owner approval not granted"), + mock: func(ctrl *gomock.Controller) (*mocks.MockOwnerConsentClient, chan bool) { + mockClient := mocks.NewMockOwnerConsentClient(ctrl) + mockClient.EXPECT().Start(gomock.Any()).Return(nil) + mockClient.EXPECT().Stop().Return(nil) + mockClient.EXPECT().SendOwnerConsentGet(test.ActivityID, gomock.Any()).Return(nil) + ch := make(chan bool) + go func() { + ch <- false + }() + return mockClient, ch + }, + }, + "test_owner_consent_timeout": { + updateOrchestrator: &updateOrchestrator{cfg: &config.Config{OwnerConsentPhases: []string{"download"}}}, + currentPhase: phaseIdentification, + expectedErr: fmt.Errorf("owner consent not granted in %v", test.Interval), + mock: func(ctrl *gomock.Controller) (*mocks.MockOwnerConsentClient, chan bool) { + mockClient := mocks.NewMockOwnerConsentClient(ctrl) + mockClient.EXPECT().Start(gomock.Any()).Return(nil) + mockClient.EXPECT().Stop().Return(nil) + mockClient.EXPECT().SendOwnerConsentGet(test.ActivityID, gomock.Any()).Return(nil) + return mockClient, make(chan bool) + }, + }, + } + + for testName, testCase := range tests { + t.Run(testName, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + orch := testCase.updateOrchestrator + orch.operation = &updateOperation{activityID: test.ActivityID} + orch.phaseTimeout = test.Interval + if testCase.mock != nil { + orch.ownerConsentClient, orch.operation.ownerConsented = testCase.mock(mockCtrl) + } + err := orch.getOwnerConsent(context.Background(), testCase.currentPhase) + assert.Equal(t, testCase.expectedErr, err) + }) + } +} diff --git a/updatem/orchestration/update_orchestrator_test.go b/updatem/orchestration/update_orchestrator_test.go index 4742f98..2e6856b 100644 --- a/updatem/orchestration/update_orchestrator_test.go +++ b/updatem/orchestration/update_orchestrator_test.go @@ -33,7 +33,7 @@ func TestNewUpdateOrchestrator(t *testing.T) { }, phaseTimeout: 10 * time.Minute, } - assert.Equal(t, expectedOrchestrator, NewUpdateOrchestrator(&config.Config{RebootEnabled: true})) + assert.Equal(t, expectedOrchestrator, NewUpdateOrchestrator(&config.Config{RebootEnabled: true}, nil)) } func TestUpdOrchApply(t *testing.T) { @@ -86,6 +86,42 @@ func TestUpdOrchApply(t *testing.T) { }) } +func TestHandleOwnerConsent(t *testing.T) { + updateOrchestrator := &updateOrchestrator{ + operation: &updateOperation{ + activityID: test.ActivityID, + ownerConsented: make(chan bool), + }, + } + t.Run("test_handle_owner_approved", func(t *testing.T) { + go updateOrchestrator.HandleOwnerConsent(test.ActivityID, 0, &types.OwnerConsent{Status: types.StatusApproved}) + select { + case consented := <-updateOrchestrator.operation.ownerConsented: + assert.True(t, consented) + case <-time.After(1 * time.Second): + t.Fatal("owner consent not received") + } + }) + t.Run("test_handle_owner_denied", func(t *testing.T) { + go updateOrchestrator.HandleOwnerConsent(test.ActivityID, 0, &types.OwnerConsent{Status: types.StatusDenied}) + select { + case consented := <-updateOrchestrator.operation.ownerConsented: + assert.False(t, consented) + case <-time.After(1 * time.Second): + t.Fatal("owner consent not received") + } + }) + t.Run("test_handle_owner_approved_another_activity", func(t *testing.T) { + go updateOrchestrator.HandleOwnerConsent("anotherActivity", 0, &types.OwnerConsent{Status: types.StatusApproved}) + select { + case <-updateOrchestrator.operation.ownerConsented: + t.Fatal("unexpected owner consent") + case <-time.After(1 * time.Second): + // do nothing + } + }) +} + func applyDesiredState(ctx context.Context, updOrch *updateOrchestrator, done chan bool, domainAgents map[string]api.UpdateManager, activityID string, desiredState *types.DesiredState, apiDesState api.DesiredStateFeedbackHandler) { updOrch.Apply(ctx, domainAgents, activityID, desiredState, apiDesState) done <- true diff --git a/updatem/orchestration/update_phase.go b/updatem/orchestration/update_phase.go index dce6bd6..66b1b92 100644 --- a/updatem/orchestration/update_phase.go +++ b/updatem/orchestration/update_phase.go @@ -23,3 +23,12 @@ const ( ) var orderedPhases = []phase{phaseIdentification, phaseDownload, phaseUpdate, phaseActivation, phaseCleanup} + +func (p phase) next() phase { + for i := 0; i < len(orderedPhases)-1; i++ { + if orderedPhases[i] == p { + return orderedPhases[i+1] + } + } + return "" +} From 1d3be432b30beaeba8cf296b3fd5b20c00c63729 Mon Sep 17 00:00:00 2001 From: Dimitar Dimitrov Date: Fri, 19 Apr 2024 12:01:17 +0300 Subject: [PATCH 2/6] Fix formatting Signed-off-by: Dimitar Dimitrov --- test/mocks/client_mock.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mocks/client_mock.go b/test/mocks/client_mock.go index f136a53..6a5c251 100644 --- a/test/mocks/client_mock.go +++ b/test/mocks/client_mock.go @@ -621,4 +621,4 @@ func (m *MockOwnerConsentClient) Stop() error { func (mr *MockOwnerConsentClientMockRecorder) Stop() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockOwnerConsentClient)(nil).Stop)) -} \ No newline at end of file +} From cdc4c6fc7bae508b84b4c65f093900d746cc7dcb Mon Sep 17 00:00:00 2001 From: Dimitar Dimitrov Date: Mon, 22 Apr 2024 10:18:50 +0300 Subject: [PATCH 3/6] Remove desired state payload from owner consent get message Signed-off-by: Dimitar Dimitrov --- api/client.go | 2 +- mqtt/owner_consent_client _test.go | 12 +++--------- mqtt/owner_consent_client.go | 6 +++--- test/mocks/client_mock.go | 8 ++++---- updatem/orchestration/update_orchestrator_apply.go | 2 +- .../orchestration/update_orchestrator_apply_test.go | 8 ++++---- 6 files changed, 16 insertions(+), 22 deletions(-) diff --git a/api/client.go b/api/client.go index b5782ba..bcfd4ef 100755 --- a/api/client.go +++ b/api/client.go @@ -78,5 +78,5 @@ type OwnerConsentClient interface { BaseClient Start(OwnerConsentHandler) error - SendOwnerConsentGet(string, *types.DesiredState) error + SendOwnerConsentGet(string) error } diff --git a/mqtt/owner_consent_client _test.go b/mqtt/owner_consent_client _test.go index 9bfdf0f..d20976f 100755 --- a/mqtt/owner_consent_client _test.go +++ b/mqtt/owner_consent_client _test.go @@ -139,27 +139,21 @@ func TestSendOwnerConsentGet(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - testDesiredState := &types.DesiredState{ - Domains: []*types.Domain{ - {ID: test.domain}, - }, - } client, _ := NewOwnerConsentClient(test.domain, &updateAgentClient{ mqttClient: newInternalClient("testDomain", mqttTestConfig, mockPaho), }) mockPaho.EXPECT().Publish(test.domain+"update/ownerconsent/get", uint8(1), false, gomock.Any()).DoAndReturn( func(topic string, qos byte, retained bool, payload interface{}) pahomqtt.Token { - desiresState := &types.DesiredState{} - envelope, err := types.FromEnvelope(payload.([]byte), desiresState) + envelope, err := types.FromEnvelope(payload.([]byte), nil) assert.NoError(t, err) assert.Equal(t, name, envelope.ActivityID) assert.True(t, envelope.Timestamp > 0) - assert.Equal(t, testDesiredState, desiresState) + assert.Nil(t, envelope.Payload) return mockToken }) setupMockToken(mockToken, mqttTestConfig.AcknowledgeTimeout, test.isTimedOut) - assertOutgoingResult(t, test.isTimedOut, client.SendOwnerConsentGet(name, testDesiredState)) + assertOutgoingResult(t, test.isTimedOut, client.SendOwnerConsentGet(name)) }) } } diff --git a/mqtt/owner_consent_client.go b/mqtt/owner_consent_client.go index 27238cc..2a510fd 100755 --- a/mqtt/owner_consent_client.go +++ b/mqtt/owner_consent_client.go @@ -100,13 +100,13 @@ func (client *ownerConsentClient) handleMessage(mqttClient pahomqtt.Client, mess } } -func (client *ownerConsentClient) SendOwnerConsentGet(activityID string, desiredState *types.DesiredState) error { +func (client *ownerConsentClient) SendOwnerConsentGet(activityID string) error { logger.Debug("publishing to topic '%s'", client.topicOwnerConsentGet) - desiredStateBytes, err := types.ToEnvelope(activityID, desiredState) + consentGetBytes, err := types.ToEnvelope(activityID, nil) if err != nil { return errors.Wrapf(err, "cannot marshal owner consent get message for activity-id %s", activityID) } - token := client.pahoClient.Publish(client.topicOwnerConsentGet, 1, false, desiredStateBytes) + token := client.pahoClient.Publish(client.topicOwnerConsentGet, 1, false, consentGetBytes) if !token.WaitTimeout(client.mqttConfig.AcknowledgeTimeout) { return fmt.Errorf("cannot publish to topic '%s' in '%v'", client.topicOwnerConsentGet, client.mqttConfig.AcknowledgeTimeout) } diff --git a/test/mocks/client_mock.go b/test/mocks/client_mock.go index 6a5c251..1952171 100644 --- a/test/mocks/client_mock.go +++ b/test/mocks/client_mock.go @@ -582,17 +582,17 @@ func (mr *MockOwnerConsentClientMockRecorder) Domain() *gomock.Call { } // SendOwnerConsentGet mocks base method. -func (m *MockOwnerConsentClient) SendOwnerConsentGet(arg0 string, arg1 *types.DesiredState) error { +func (m *MockOwnerConsentClient) SendOwnerConsentGet(arg0 string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SendOwnerConsentGet", arg0, arg1) + ret := m.ctrl.Call(m, "SendOwnerConsentGet", arg0) ret0, _ := ret[0].(error) return ret0 } // SendOwnerConsentGet indicates an expected call of SendOwnerConsentGet. -func (mr *MockOwnerConsentClientMockRecorder) SendOwnerConsentGet(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockOwnerConsentClientMockRecorder) SendOwnerConsentGet(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendOwnerConsentGet", reflect.TypeOf((*MockOwnerConsentClient)(nil).SendOwnerConsentGet), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendOwnerConsentGet", reflect.TypeOf((*MockOwnerConsentClient)(nil).SendOwnerConsentGet), arg0) } // Start mocks base method. diff --git a/updatem/orchestration/update_orchestrator_apply.go b/updatem/orchestration/update_orchestrator_apply.go index 81e859c..ee64e8a 100644 --- a/updatem/orchestration/update_orchestrator_apply.go +++ b/updatem/orchestration/update_orchestrator_apply.go @@ -130,7 +130,7 @@ func (orchestrator *updateOrchestrator) getOwnerConsent(ctx context.Context, com } }() - if err := orchestrator.ownerConsentClient.SendOwnerConsentGet(orchestrator.operation.activityID, orchestrator.operation.desiredState); err != nil { + if err := orchestrator.ownerConsentClient.SendOwnerConsentGet(orchestrator.operation.activityID); err != nil { return err } diff --git a/updatem/orchestration/update_orchestrator_apply_test.go b/updatem/orchestration/update_orchestrator_apply_test.go index 00b544e..c691c55 100644 --- a/updatem/orchestration/update_orchestrator_apply_test.go +++ b/updatem/orchestration/update_orchestrator_apply_test.go @@ -508,7 +508,7 @@ func TestGetOwnerConsent(t *testing.T) { mockClient := mocks.NewMockOwnerConsentClient(ctrl) mockClient.EXPECT().Start(gomock.Any()).Return(nil) mockClient.EXPECT().Stop().Return(nil) - mockClient.EXPECT().SendOwnerConsentGet(test.ActivityID, gomock.Any()).Return(fmt.Errorf("send error")) + mockClient.EXPECT().SendOwnerConsentGet(test.ActivityID).Return(fmt.Errorf("send error")) return mockClient, nil }, }, @@ -519,7 +519,7 @@ func TestGetOwnerConsent(t *testing.T) { mockClient := mocks.NewMockOwnerConsentClient(ctrl) mockClient.EXPECT().Start(gomock.Any()).Return(nil) mockClient.EXPECT().Stop().Return(nil) - mockClient.EXPECT().SendOwnerConsentGet(test.ActivityID, gomock.Any()).Return(nil) + mockClient.EXPECT().SendOwnerConsentGet(test.ActivityID).Return(nil) ch := make(chan bool) go func() { ch <- true @@ -535,7 +535,7 @@ func TestGetOwnerConsent(t *testing.T) { mockClient := mocks.NewMockOwnerConsentClient(ctrl) mockClient.EXPECT().Start(gomock.Any()).Return(nil) mockClient.EXPECT().Stop().Return(nil) - mockClient.EXPECT().SendOwnerConsentGet(test.ActivityID, gomock.Any()).Return(nil) + mockClient.EXPECT().SendOwnerConsentGet(test.ActivityID).Return(nil) ch := make(chan bool) go func() { ch <- false @@ -551,7 +551,7 @@ func TestGetOwnerConsent(t *testing.T) { mockClient := mocks.NewMockOwnerConsentClient(ctrl) mockClient.EXPECT().Start(gomock.Any()).Return(nil) mockClient.EXPECT().Stop().Return(nil) - mockClient.EXPECT().SendOwnerConsentGet(test.ActivityID, gomock.Any()).Return(nil) + mockClient.EXPECT().SendOwnerConsentGet(test.ActivityID).Return(nil) return mockClient, make(chan bool) }, }, From 2451dc0f1f5441d9917bde25036a2f4fb9acdb4e Mon Sep 17 00:00:00 2001 From: Dimitar Dimitrov Date: Wed, 24 Apr 2024 18:40:13 +0300 Subject: [PATCH 4/6] Rework request owner's approval - add command to the consent request - remove phases and re-use desired state command values - remove unit test will be added in separate PR Signed-off-by: Dimitar Dimitrov --- api/client.go | 10 +- api/types/owner_consent.go | 9 +- cmd/update-manager/main.go | 2 +- config/config_internal.go | 3 +- config/config_test.go | 3 +- config/flags_internal.go | 35 +-- config/flags_test.go | 101 -------- config/testdata/config.json | 2 +- mqtt/owner_consent_agent_client .go | 42 +-- mqtt/owner_consent_agent_client_test.go | 173 ------------- mqtt/owner_consent_client _test.go | 195 -------------- mqtt/owner_consent_client.go | 38 +-- mqtt/update_agent_client.go | 8 +- test/mocks/client_mock.go | 48 ++-- test/mocks/update_orchestrator_mock.go | 10 +- updatem/orchestration/update_operation.go | 17 +- .../orchestration/update_operation_test.go | 2 +- updatem/orchestration/update_orchestrator.go | 5 +- .../update_orchestrator_apply.go | 96 ++++--- .../update_orchestrator_apply_test.go | 243 +++++------------- .../update_orchestrator_feedback.go | 12 +- .../update_orchestrator_feedback_test.go | 45 ++-- .../orchestration/update_orchestrator_test.go | 36 --- updatem/orchestration/update_phase.go | 34 --- 24 files changed, 268 insertions(+), 901 deletions(-) delete mode 100644 mqtt/owner_consent_agent_client_test.go delete mode 100755 mqtt/owner_consent_client _test.go delete mode 100644 updatem/orchestration/update_phase.go diff --git a/api/client.go b/api/client.go index bcfd4ef..cbf9c17 100755 --- a/api/client.go +++ b/api/client.go @@ -57,7 +57,7 @@ type DesiredStateClient interface { // OwnerConsentAgentHandler defines functions for handling the owner consent requests type OwnerConsentAgentHandler interface { - HandleOwnerConsentGet(string, int64, *types.OwnerConsent) error + HandleOwnerConsent(string, int64, *types.OwnerConsent) error } // OwnerConsentAgentClient defines an interface for handling for owner consent requests @@ -65,12 +65,12 @@ type OwnerConsentAgentClient interface { BaseClient Start(OwnerConsentAgentHandler) error - SendOwnerConsent(string, *types.OwnerConsent) error + SendOwnerConsentFeedback(string, *types.OwnerConsentFeedback) error } -// OwnerConsentHandler defines functions for handling the owner consent +// OwnerConsentHandler defines functions for handling the owner consent feedback type OwnerConsentHandler interface { - HandleOwnerConsent(string, int64, *types.OwnerConsent) error + HandleOwnerConsentFeedback(string, int64, *types.OwnerConsentFeedback) error } // OwnerConsentClient defines an interface for triggering requests for owner consent @@ -78,5 +78,5 @@ type OwnerConsentClient interface { BaseClient Start(OwnerConsentHandler) error - SendOwnerConsentGet(string) error + SendOwnerConsent(string, *types.OwnerConsent) error } diff --git a/api/types/owner_consent.go b/api/types/owner_consent.go index c05497d..78fea62 100644 --- a/api/types/owner_consent.go +++ b/api/types/owner_consent.go @@ -22,8 +22,13 @@ const ( StatusDenied ConsentStatusType = "DENIED" ) -// OwnerConsent defines the payload for Owner Consent response. -type OwnerConsent struct { +// OwnerConsentFeedback defines the payload for Owner Consent Feedback. +type OwnerConsentFeedback struct { Status ConsentStatusType `json:"status,omitempty"` // time field for scheduling could be added here } + +// OwnerConsent defines the payload for Owner Consent. +type OwnerConsent struct { + Command CommandType `json:"command,omitempty"` +} diff --git a/cmd/update-manager/main.go b/cmd/update-manager/main.go index 6006abb..2ec218e 100755 --- a/cmd/update-manager/main.go +++ b/cmd/update-manager/main.go @@ -70,7 +70,7 @@ func initUpdateManager(cfg *config.Config) (api.UpdateAgentClient, api.UpdateMan return nil, nil, err } - if len(cfg.OwnerConsentPhases) != 0 { + if len(cfg.OwnerConsentCommands) != 0 { if occ, err = mqtt.NewOwnerConsentClient(cfg.Domain, uac); err != nil { return nil, nil, err } diff --git a/config/config_internal.go b/config/config_internal.go index 32c44ad..0e8aee6 100755 --- a/config/config_internal.go +++ b/config/config_internal.go @@ -14,6 +14,7 @@ package config import ( "github.com/eclipse-kanto/update-manager/api" + "github.com/eclipse-kanto/update-manager/api/types" ) const ( @@ -44,7 +45,7 @@ type Config struct { ReportFeedbackInterval string `json:"reportFeedbackInterval"` CurrentStateDelay string `json:"currentStateDelay"` PhaseTimeout string `json:"phaseTimeout"` - OwnerConsentPhases []string `json:"ownerConsentPhases"` + OwnerConsentCommands []types.CommandType `json:"ownerConsentCommands"` } func newDefaultConfig() *Config { diff --git a/config/config_test.go b/config/config_test.go index 6f907e4..7ccd0a7 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -17,6 +17,7 @@ import ( "testing" "github.com/eclipse-kanto/update-manager/api" + "github.com/eclipse-kanto/update-manager/api/types" "github.com/eclipse-kanto/update-manager/logger" "github.com/eclipse-kanto/update-manager/mqtt" @@ -136,7 +137,7 @@ func TestLoadConfigFromFile(t *testing.T) { ReportFeedbackInterval: "2m", CurrentStateDelay: "1m", PhaseTimeout: "2m", - OwnerConsentPhases: []string{"download"}, + OwnerConsentCommands: []types.CommandType{types.CommandDownload}, } assert.True(t, reflect.DeepEqual(*cfg, expectedConfigValues)) }) diff --git a/config/flags_internal.go b/config/flags_internal.go index d737fd9..897f148 100755 --- a/config/flags_internal.go +++ b/config/flags_internal.go @@ -19,15 +19,16 @@ import ( "os" "strings" + "github.com/eclipse-kanto/update-manager/api/types" "github.com/eclipse-kanto/update-manager/logger" ) const ( // domains flag - domainsFlagID = "domains" - domainsDesc = "Specify a comma-separated list of domains handled by the update manager" - ownerConsentPhasesFlagID = "owner-consent-phases" - ownerConsentPhasesDesc = "Specify a comma-separated list of update phase, before which an owner consent should be granted. Possible values are: 'download', 'update', 'activation'" + domainsFlagID = "domains" + domainsDesc = "Specify a comma-separated list of domains handled by the update manager" + ownerConsentCommandsFlagID = "owner-consent-commands" + ownerConsentCommandsDesc = "Specify a comma-separated list of commands, before which an owner consent should be granted. Possible values are: 'download', 'update', 'activate'" ) // SetupAllUpdateManagerFlags adds all flags for the configuration of the update manager @@ -42,15 +43,15 @@ func SetupAllUpdateManagerFlags(flagSet *flag.FlagSet, cfg *Config) { flagSet.StringVar(&cfg.PhaseTimeout, "phase-timeout", EnvToString("PHASE_TIMEOUT", cfg.PhaseTimeout), "Specify the timeout for completing an Update Orchestration phase. Value should be a positive integer number followed by a unit suffix, such as '60s', '10m', etc") flagSet.StringVar(&cfg.ReportFeedbackInterval, "report-feedback-interval", EnvToString("REPORT_FEEDBACK_INTERVAL", cfg.ReportFeedbackInterval), "Specify the time interval for reporting intermediate desired state feedback messages during an active update operation. Value should be a positive integer number followed by a unit suffix, such as '60s', '10m', etc") flagSet.StringVar(&cfg.CurrentStateDelay, "current-state-delay", EnvToString("CURRENT_STATE_DELAY", cfg.CurrentStateDelay), "Specify the time delay for reporting current state messages. Value should be a positive integer number followed by a unit suffix, such as '60s', '10m', etc") - flagSet.String(ownerConsentPhasesFlagID, "", "Specify a comma-separated list of update phase, before which an owner consent should be granted. Possible values are: 'download', 'update', 'activation'") + flagSet.String(ownerConsentCommandsFlagID, "", ownerConsentCommandsDesc) setupAgentsConfigFlags(flagSet, cfg) } func parseFlags(cfg *Config, version string) { domains := parseDomainsFlag() prepareAgentsConfig(cfg, domains) - if ownerConsentPhases := parseOwnerConsentPhasesFlag(); len(ownerConsentPhases) > 0 { - cfg.OwnerConsentPhases = ownerConsentPhases + if ownerConsentPhases := parseOwnerConsentCommandsFlag(); len(ownerConsentPhases) > 0 { + cfg.OwnerConsentCommands = ownerConsentPhases } flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) @@ -69,20 +70,20 @@ func parseFlags(cfg *Config, version string) { } } -func parseOwnerConsentPhasesFlag() []string { - var listPhases string +func parseOwnerConsentCommandsFlag() []types.CommandType { + var listCommands string flagSet := flag.NewFlagSet("", flag.ContinueOnError) flagSet.SetOutput(io.Discard) - flagSet.StringVar(&listPhases, ownerConsentPhasesFlagID, EnvToString("OWNER_CONSENT_PHASES", ""), ownerConsentPhasesDesc) - if err := flagSet.Parse(getFlagArgs(ownerConsentPhasesFlagID)); err != nil { - logger.ErrorErr(err, "Cannot parse %s flag", ownerConsentPhasesFlagID) + flagSet.StringVar(&listCommands, ownerConsentCommandsFlagID, EnvToString("OWNER_CONSENT_COMMANDS", ""), ownerConsentCommandsDesc) + if err := flagSet.Parse(getFlagArgs(ownerConsentCommandsFlagID)); err != nil { + logger.ErrorErr(err, "Cannot parse %s flag", ownerConsentCommandsFlagID) } - var result []string - for _, phase := range strings.Split(listPhases, ",") { - p := strings.TrimSpace(phase) - if len(p) > 0 { - result = append(result, p) + var result []types.CommandType + for _, command := range strings.Split(listCommands, ",") { + c := strings.TrimSpace(command) + if len(c) > 0 { + result = append(result, types.CommandType(strings.ToUpper(c))) } } return result diff --git a/config/flags_test.go b/config/flags_test.go index e668c41..ea91cc6 100644 --- a/config/flags_test.go +++ b/config/flags_test.go @@ -17,8 +17,6 @@ import ( "fmt" "os" "reflect" - "slices" - "strings" "testing" "github.com/eclipse-kanto/update-manager/api" @@ -131,10 +129,6 @@ func TestSetupFlags(t *testing.T) { flag: "phase-timeout", expectedType: reflect.String.String(), }, - "test_flags_owner_consent_phases": { - flag: "owner-consent-phases", - expectedType: reflect.String.String(), - }, } for testName, testCase := range tests { t.Run(testName, func(t *testing.T) { @@ -236,68 +230,6 @@ func TestParseDomainsFlag(t *testing.T) { } }) } -func TestParseOwnerConsentPhasesFlag(t *testing.T) { - oldArgs := os.Args - defer func() { os.Args = oldArgs }() - testPhases := "download" - - t.Run("test_parse_consent_phases_flag_1", func(t *testing.T) { - os.Args = []string{oldArgs[0], fmt.Sprintf("-%s=%s", ownerConsentPhasesFlagID, testPhases)} - actualConsentPhases := parseOwnerConsentPhasesFlag() - if len(actualConsentPhases) != 1 && !slices.Contains(actualConsentPhases, testPhases) { - t.Error("consent phase not set") - } - }) - - t.Run("test_parse_consent_phases_flag_2", func(t *testing.T) { - os.Args = []string{oldArgs[0], fmt.Sprintf("--%s=%s", ownerConsentPhasesFlagID, testPhases)} - actualConsentPhases := parseOwnerConsentPhasesFlag() - if len(actualConsentPhases) != 1 && !slices.Contains(actualConsentPhases, testPhases) { - t.Error("consent phase not set") - } - }) - - t.Run("test_parse_consent_phases_flag_3", func(t *testing.T) { - os.Args = []string{oldArgs[0], fmt.Sprintf("-%s", ownerConsentPhasesFlagID), testPhases} - actualConsentPhases := parseOwnerConsentPhasesFlag() - if len(actualConsentPhases) != 1 && !slices.Contains(actualConsentPhases, testPhases) { - t.Error("consent phase not set") - } - }) - - t.Run("test_parse_consent_phases_flag_4", func(t *testing.T) { - os.Args = []string{oldArgs[0], fmt.Sprintf("-%s", ownerConsentPhasesFlagID), testPhases} - actualConsentPhases := parseOwnerConsentPhasesFlag() - if len(actualConsentPhases) != 1 && !slices.Contains(actualConsentPhases, testPhases) { - t.Error("consent phase not set") - } - }) - - t.Run("test_parse_consent_phase_flag_err", func(t *testing.T) { - invalidConsentPhasesFlagID := "invalid" - os.Args = []string{oldArgs[0], fmt.Sprintf("--%s=%s", invalidConsentPhasesFlagID, testPhases)} - actualConsentPhases := parseOwnerConsentPhasesFlag() - if len(actualConsentPhases) != 0 { - t.Errorf("\"incorrect value: %v , expecting: empty \"", actualConsentPhases) - } - }) - - t.Run("test_parse_consent_phases_flag_err_1", func(t *testing.T) { - os.Args = []string{oldArgs[0], fmt.Sprintf("-%s", ownerConsentPhasesFlagID)} - actualConsentPhases := parseOwnerConsentPhasesFlag() - if len(actualConsentPhases) != 0 { - t.Errorf("\"incorrect value: %v , expecting: empty \"", actualConsentPhases) - } - }) - - t.Run("test_parse_consent_phases_flag_err_2", func(t *testing.T) { - os.Args = []string{oldArgs[0], fmt.Sprintf("--%s", ownerConsentPhasesFlagID)} - actualConsentPhases := parseOwnerConsentPhasesFlag() - if len(actualConsentPhases) != 0 { - t.Errorf("\"incorrect value: %v , expecting: empty \"", actualConsentPhases) - } - }) -} func TestParseFlags(t *testing.T) { testVersion := "testVersion" @@ -440,37 +372,4 @@ func TestParseFlags(t *testing.T) { parseFlags(cfg, testVersion) assert.Equal(t, expectedAgents, cfg.Agents) }) - t.Run("test_owner_consent_phases", func(t *testing.T) { - oldArgs := os.Args - defer func() { os.Args = oldArgs }() - - testConfigPath := "../config/testdata/config.json" - expectedPhases := []string{"download"} - - os.Args = []string{oldArgs[0], fmt.Sprintf("--%s=%s", configFileFlagID, testConfigPath)} - cfg := newDefaultConfig() - configFilePath := ParseConfigFilePath() - if configFilePath != "" { - assert.NoError(t, LoadConfigFromFile(configFilePath, cfg)) - } - parseFlags(cfg, testVersion) - assert.Equal(t, expectedPhases, cfg.OwnerConsentPhases) - }) - t.Run("test_overwrite_owner_consent_phases", func(t *testing.T) { - oldArgs := os.Args - defer func() { os.Args = oldArgs }() - - testConfigPath := "../config/testdata/config.json" - expectedPhases := []string{"update", "activation"} - - os.Args = []string{oldArgs[0], fmt.Sprintf("--%s=%s", configFileFlagID, testConfigPath), - fmt.Sprintf("--%s=%s", ownerConsentPhasesFlagID, strings.Join(expectedPhases, ","))} - cfg := newDefaultConfig() - configFilePath := ParseConfigFilePath() - if configFilePath != "" { - assert.NoError(t, LoadConfigFromFile(configFilePath, cfg)) - } - parseFlags(cfg, testVersion) - assert.Equal(t, expectedPhases, cfg.OwnerConsentPhases) - }) } diff --git a/config/testdata/config.json b/config/testdata/config.json index 0e8a624..174bfba 100644 --- a/config/testdata/config.json +++ b/config/testdata/config.json @@ -24,7 +24,7 @@ "reportFeedbackInterval": "2m", "currentStateDelay": "1m", "phaseTimeout": "2m", - "ownerConsentPhases": ["download"], + "ownerConsentCommands": ["DOWNLOAD"], "agents": { "self-update": { "rebootRequired": false, diff --git a/mqtt/owner_consent_agent_client .go b/mqtt/owner_consent_agent_client .go index 11afa99..f75cc7b 100755 --- a/mqtt/owner_consent_agent_client .go +++ b/mqtt/owner_consent_agent_client .go @@ -44,9 +44,9 @@ func NewOwnerConsentAgentClient(domain string, config *ConnectionConfig) (api.Ow func (client *ownerConsentAgentClient) onConnect(_ pahomqtt.Client) { if err := client.subscribe(); err != nil { - logger.ErrorErr(err, "[%s] error subscribing for OwnerConsentGet requests", client.Domain()) + logger.ErrorErr(err, "[%s] error subscribing for OwnerConsent requests", client.Domain()) } else { - logger.Debug("[%s] subscribed for OwnerConsentGet requests", client.Domain()) + logger.Debug("[%s] subscribed for OwnerConsent requests", client.Domain()) } } @@ -64,12 +64,12 @@ func (client *ownerConsentAgentClient) Domain() string { return client.domain } -// Stop removes the client subscription to the MQTT broker for the MQTT topics for getting owner consent. +// Stop removes the client subscription to the MQTT broker for the MQTT topics for requesting owner consent. func (client *ownerConsentAgentClient) Stop() error { if err := client.unsubscribe(); err != nil { - logger.WarnErr(err, "[%s] error unsubscribing for OwnerConsentGet requests", client.Domain()) + logger.WarnErr(err, "[%s] error unsubscribing for OwnerConsent requests", client.Domain()) } else { - logger.Debug("[%s] unsubscribed for OwnerConsentGet messages", client.Domain()) + logger.Debug("[%s] unsubscribed for OwnerConsent messages", client.Domain()) } client.pahoClient.Disconnect(disconnectQuiesce) client.handler = nil @@ -77,19 +77,19 @@ func (client *ownerConsentAgentClient) Stop() error { } func (client *ownerConsentAgentClient) subscribe() error { - logger.Debug("subscribing for '%v' topic", client.topicOwnerConsentGet) - token := client.pahoClient.Subscribe(client.topicOwnerConsentGet, 1, client.handleMessage) + logger.Debug("subscribing for '%v' topic", client.topicOwnerConsent) + token := client.pahoClient.Subscribe(client.topicOwnerConsent, 1, client.handleMessage) if !token.WaitTimeout(client.mqttConfig.SubscribeTimeout) { - return fmt.Errorf("cannot subscribe for topic '%s' in '%v'", client.topicOwnerConsentGet, client.mqttConfig.SubscribeTimeout) + return fmt.Errorf("cannot subscribe for topic '%s' in '%v'", client.topicOwnerConsent, client.mqttConfig.SubscribeTimeout) } return token.Error() } func (client *ownerConsentAgentClient) unsubscribe() error { - logger.Debug("unsubscribing from '%s' topic", client.topicOwnerConsentGet) - token := client.pahoClient.Unsubscribe(client.topicOwnerConsentGet) + logger.Debug("unsubscribing from '%s' topic", client.topicOwnerConsent) + token := client.pahoClient.Unsubscribe(client.topicOwnerConsent) if !token.WaitTimeout(client.mqttConfig.UnsubscribeTimeout) { - return fmt.Errorf("cannot unsubscribe from topic '%s' in '%v'", client.topicOwnerConsentGet, client.mqttConfig.UnsubscribeTimeout) + return fmt.Errorf("cannot unsubscribe from topic '%s' in '%v'", client.topicOwnerConsent, client.mqttConfig.UnsubscribeTimeout) } return token.Error() } @@ -97,28 +97,28 @@ func (client *ownerConsentAgentClient) unsubscribe() error { func (client *ownerConsentAgentClient) handleMessage(mqttClient pahomqtt.Client, message pahomqtt.Message) { topic := message.Topic() logger.Debug("[%s] received %s message", client.Domain(), topic) - if topic == client.topicOwnerConsentGet { + if topic == client.topicOwnerConsent { consent := &types.OwnerConsent{} envelope, err := types.FromEnvelope(message.Payload(), consent) if err != nil { - logger.ErrorErr(err, "[%s] cannot parse owner conset get message", client.Domain()) + logger.ErrorErr(err, "[%s] cannot parse owner consent message", client.Domain()) return } - if err := client.handler.HandleOwnerConsentGet(envelope.ActivityID, envelope.Timestamp, consent); err != nil { - logger.ErrorErr(err, "[%s] error processing owner consent get message", client.Domain()) + if err := client.handler.HandleOwnerConsent(envelope.ActivityID, envelope.Timestamp, consent); err != nil { + logger.ErrorErr(err, "[%s] error processing owner consent message", client.Domain()) } } } -func (client *ownerConsentAgentClient) SendOwnerConsent(activityID string, consent *types.OwnerConsent) error { - logger.Debug("publishing to topic '%s'", client.topicOwnerConsent) - desiredStateBytes, err := types.ToEnvelope(activityID, consent) +func (client *ownerConsentAgentClient) SendOwnerConsentFeedback(activityID string, consentFeedback *types.OwnerConsentFeedback) error { + logger.Debug("publishing to topic '%s'", client.topicOwnerConsentFeedback) + desiredStateBytes, err := types.ToEnvelope(activityID, consentFeedback) if err != nil { - return errors.Wrapf(err, "cannot marshal owner consent message for activity-id %s", activityID) + return errors.Wrapf(err, "cannot marshal owner consent feedback message for activity-id %s", activityID) } - token := client.pahoClient.Publish(client.topicOwnerConsent, 1, false, desiredStateBytes) + token := client.pahoClient.Publish(client.topicOwnerConsentFeedback, 1, false, desiredStateBytes) if !token.WaitTimeout(client.mqttConfig.AcknowledgeTimeout) { - return fmt.Errorf("cannot publish to topic '%s' in '%v'", client.topicOwnerConsent, client.mqttConfig.AcknowledgeTimeout) + return fmt.Errorf("cannot publish to topic '%s' in '%v'", client.topicOwnerConsentFeedback, client.mqttConfig.AcknowledgeTimeout) } return token.Error() } diff --git a/mqtt/owner_consent_agent_client_test.go b/mqtt/owner_consent_agent_client_test.go deleted file mode 100644 index 0d41f8c..0000000 --- a/mqtt/owner_consent_agent_client_test.go +++ /dev/null @@ -1,173 +0,0 @@ -// Copyright (c) 2024 Contributors to the Eclipse Foundation -// -// See the NOTICE file(s) distributed with this work for additional -// information regarding copyright ownership. -// -// This program and the accompanying materials are made available under the -// terms of the Eclipse Public License 2.0 which is available at -// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 -// which is available at https://www.apache.org/licenses/LICENSE-2.0. -// -// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 - -package mqtt - -import ( - "errors" - "testing" - - "github.com/eclipse-kanto/update-manager/api/types" - mqttmocks "github.com/eclipse-kanto/update-manager/mqtt/mocks" - "github.com/eclipse-kanto/update-manager/test/mocks" - - pahomqtt "github.com/eclipse/paho.mqtt.golang" - "github.com/golang/mock/gomock" - "github.com/stretchr/testify/assert" -) - -func TestOwnerConsentAgentClientStart(t *testing.T) { - tests := map[string]testCaseOutgoing{ - "test_connect_ok": {domain: "testdomain", isTimedOut: false}, - "test_connect_timeout": {domain: "mydomain", isTimedOut: true}, - } - - mockCtrl, mockPaho, mockToken := setupCommonMocks(t) - defer mockCtrl.Finish() - - mockHandler := mocks.NewMockOwnerConsentAgentHandler(mockCtrl) - - for name, test := range tests { - t.Run(name, func(t *testing.T) { - client := &ownerConsentAgentClient{ - domain: test.domain, - mqttClient: newInternalClient(test.domain, mqttTestConfig, mockPaho), - } - - mockPaho.EXPECT().Connect().Return(mockToken) - setupMockToken(mockToken, mqttTestConfig.ConnectTimeout, test.isTimedOut) - - assertOutgoingResult(t, test.isTimedOut, client.Start(mockHandler)) - assert.Equal(t, mockHandler, client.handler) - }) - } -} - -func TestOwnerConsentAgentClientStop(t *testing.T) { - tests := map[string]testCaseOutgoing{ - //"test_disconnect_ok": {domain: "testdomain", isTimedOut: false}, - "test_disconnect_timeout": {domain: "mydomain", isTimedOut: true}, - } - - mockCtrl, mockPaho, mockToken := setupCommonMocks(t) - defer mockCtrl.Finish() - - mockHandler := mocks.NewMockOwnerConsentAgentHandler(mockCtrl) - - for name, test := range tests { - t.Run(name, func(t *testing.T) { - client := &ownerConsentAgentClient{ - domain: test.domain, - mqttClient: newInternalClient(test.domain, mqttTestConfig, mockPaho), - handler: mockHandler, - } - - mockPaho.EXPECT().Unsubscribe(test.domain + "update/ownerconsent/get").Return(mockToken) - mockPaho.EXPECT().Disconnect(disconnectQuiesce) - setupMockToken(mockToken, mqttTestConfig.UnsubscribeTimeout, test.isTimedOut) - - assert.NoError(t, client.Stop()) - assert.Nil(t, client.handler) - }) - } -} - -func TestSendOwnerConsent(t *testing.T) { - tests := map[string]testCaseOutgoing{ - "test_send_owner_consent_ok": {domain: "testdomain", isTimedOut: false}, - "test_send_owner_consent_error": {domain: "mydomain", isTimedOut: true}, - } - - mockCtrl, mockPaho, mockToken := setupCommonMocks(t) - defer mockCtrl.Finish() - - testConsent := &types.OwnerConsent{ - Status: types.StatusApproved, - } - - for name, test := range tests { - t.Run(name, func(t *testing.T) { - client := &ownerConsentAgentClient{ - domain: test.domain, - mqttClient: newInternalClient(test.domain, mqttTestConfig, mockPaho), - } - mockPaho.EXPECT().Publish(test.domain+"update/ownerconsent", uint8(1), false, gomock.Any()).DoAndReturn( - func(topic string, qos byte, retained bool, payload interface{}) pahomqtt.Token { - consent := &types.OwnerConsent{} - envelope, err := types.FromEnvelope(payload.([]byte), consent) - assert.NoError(t, err) - assert.Equal(t, name, envelope.ActivityID) - assert.True(t, envelope.Timestamp > 0) - assert.Equal(t, testConsent, consent) - return mockToken - }) - setupMockToken(mockToken, mqttTestConfig.AcknowledgeTimeout, false) - - assert.NoError(t, client.SendOwnerConsent(name, testConsent)) - }) - } -} - -func TestOwnerConsentOnConnect(t *testing.T) { - mockCtrl, mockPaho, mockToken := setupCommonMocks(t) - defer mockCtrl.Finish() - - client := newInternalClient("test", mqttTestConfig, mockPaho) - - t.Run("test_onConnect", func(t *testing.T) { - mockHandler := mocks.NewMockOwnerConsentAgentHandler(mockCtrl) - client := &ownerConsentAgentClient{ - mqttClient: client, - domain: "test", - handler: mockHandler, - } - mockPaho.EXPECT().Subscribe("testupdate/ownerconsent/get", uint8(1), gomock.Any()).Return(mockToken) - setupMockToken(mockToken, mqttTestConfig.SubscribeTimeout, false) - - client.onConnect(nil) - }) -} - -func TestHandleOwnerConsentGetMessage(t *testing.T) { - tests := map[string]testCaseIncoming{ - "test_handle_owner_conset_get_ok": {domain: "testdomain", handlerError: nil, expectedJSONErr: false}, - "test_handle_owner_conset_get_error": {domain: "mydomain", handlerError: errors.New("handler error"), expectedJSONErr: false}, - "test_handle_owner_conset_get_json_error": {domain: "testdomain", handlerError: nil, expectedJSONErr: true}, - } - - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - mockMessage := mqttmocks.NewMockMessage(mockCtrl) - - for name, test := range tests { - t.Run(name, func(t *testing.T) { - testConsent := &types.OwnerConsent{ - Status: types.StatusApproved, - } - testBytes, expectedCalls := testBytesToEnvelope(t, name, testConsent, test.expectedJSONErr) - - mockHandler := mocks.NewMockOwnerConsentAgentHandler(mockCtrl) - mockHandler.EXPECT().HandleOwnerConsentGet(name, gomock.Any(), testConsent).Times(expectedCalls).Return(test.handlerError) - - client := &ownerConsentAgentClient{ - mqttClient: newInternalClient(test.domain, mqttTestConfig, nil), - domain: test.domain, - handler: mockHandler, - } - mockMessage.EXPECT().Topic().Return(test.domain + "update/ownerconsent/get") - mockMessage.EXPECT().Payload().Return(testBytes) - - client.handleMessage(nil, mockMessage) - }) - } -} diff --git a/mqtt/owner_consent_client _test.go b/mqtt/owner_consent_client _test.go deleted file mode 100755 index d20976f..0000000 --- a/mqtt/owner_consent_client _test.go +++ /dev/null @@ -1,195 +0,0 @@ -// Copyright (c) 2024 Contributors to the Eclipse Foundation -// -// See the NOTICE file(s) distributed with this work for additional -// information regarding copyright ownership. -// -// This program and the accompanying materials are made available under the -// terms of the Eclipse Public License 2.0 which is available at -// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 -// which is available at https://www.apache.org/licenses/LICENSE-2.0. -// -// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 - -package mqtt - -import ( - "fmt" - "testing" - - "github.com/eclipse-kanto/update-manager/api" - "github.com/eclipse-kanto/update-manager/api/types" - clientsmocks "github.com/eclipse-kanto/update-manager/mqtt/mocks" - "github.com/eclipse-kanto/update-manager/test/mocks" - pahomqtt "github.com/eclipse/paho.mqtt.golang" - "github.com/golang/mock/gomock" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" -) - -func TestNewOwnerConsentClient(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - mockPaho := clientsmocks.NewMockClient(mockCtrl) - mockClient := mocks.NewMockUpdateAgentClient(mockCtrl) - - tests := map[string]struct { - client api.UpdateAgentClient - err string - }{ - "test_update_agent_client": { - client: &updateAgentClient{ - mqttClient: newInternalClient("testDomain", &internalConnectionConfig{}, mockPaho), - }, - }, - "test_update_agent_things_client": { - client: &updateAgentThingsClient{ - updateAgentClient: &updateAgentClient{ - mqttClient: newInternalClient("testDomain", &internalConnectionConfig{}, mockPaho), - }, - }, - }, - "test_error": { - client: mockClient, - err: fmt.Sprintf("unexpected type: %T", mockClient), - }, - } - for name, test := range tests { - t.Run(name, func(t *testing.T) { - client, err := NewDesiredStateClient("testDomain", test.client) - if test.err != "" { - assert.EqualError(t, err, fmt.Sprintf("unexpected type: %T", test.client)) - } else { - assert.NoError(t, err) - assert.NotNil(t, client) - } - }) - } -} - -func TestOwnerConsentClientStart(t *testing.T) { - tests := map[string]testCaseOutgoing{ - "test_subscribe_ok": {domain: "testdomain", isTimedOut: false}, - "test_subscribe_timeout": {domain: "mydomain", isTimedOut: true}, - } - - mockCtrl, mockPaho, mockToken := setupCommonMocks(t) - defer mockCtrl.Finish() - - mockHandler := mocks.NewMockOwnerConsentHandler(mockCtrl) - - for name, test := range tests { - t.Run(name, func(t *testing.T) { - client := &ownerConsentClient{ - mqttClient: newInternalClient(test.domain, mqttTestConfig, mockPaho), - domain: test.domain, - } - mockPaho.EXPECT().Subscribe(test.domain+"update/ownerconsent", uint8(1), gomock.Any()).Return(mockToken) - setupMockToken(mockToken, mqttTestConfig.SubscribeTimeout, test.isTimedOut) - - assertOutgoingResult(t, test.isTimedOut, client.Start(mockHandler)) - if test.isTimedOut { - assert.Nil(t, client.handler) - } else { - assert.Equal(t, mockHandler, client.handler) - } - }) - } -} - -func TestOwnerConsentClientStop(t *testing.T) { - tests := map[string]testCaseOutgoing{ - "test_unsubscribe_ok": {domain: "testdomain", isTimedOut: false}, - "test_unsubscribe_timeout": {domain: "mydomain", isTimedOut: true}, - } - - mockCtrl, mockPaho, mockToken := setupCommonMocks(t) - defer mockCtrl.Finish() - - mockHandler := mocks.NewMockOwnerConsentHandler(mockCtrl) - - for name, test := range tests { - t.Run(name, func(t *testing.T) { - client := &ownerConsentClient{ - mqttClient: newInternalClient(test.domain, mqttTestConfig, mockPaho), - domain: test.domain, - handler: mockHandler, - } - mockPaho.EXPECT().Unsubscribe(test.domain + "update/ownerconsent").Return(mockToken) - setupMockToken(mockToken, mqttTestConfig.UnsubscribeTimeout, test.isTimedOut) - - assertOutgoingResult(t, test.isTimedOut, client.Stop()) - if test.isTimedOut { - assert.Equal(t, mockHandler, client.handler) - } else { - assert.Nil(t, client.handler) - } - }) - } -} - -func TestSendOwnerConsentGet(t *testing.T) { - tests := map[string]testCaseOutgoing{ - "test_send_owner_consent_get_ok": {domain: "testdomain", isTimedOut: false}, - "test_send_owner_consent_get_error": {domain: "mydomain", isTimedOut: true}, - } - - mockCtrl, mockPaho, mockToken := setupCommonMocks(t) - defer mockCtrl.Finish() - - for name, test := range tests { - t.Run(name, func(t *testing.T) { - client, _ := NewOwnerConsentClient(test.domain, &updateAgentClient{ - mqttClient: newInternalClient("testDomain", mqttTestConfig, mockPaho), - }) - mockPaho.EXPECT().Publish(test.domain+"update/ownerconsent/get", uint8(1), false, gomock.Any()).DoAndReturn( - func(topic string, qos byte, retained bool, payload interface{}) pahomqtt.Token { - envelope, err := types.FromEnvelope(payload.([]byte), nil) - assert.NoError(t, err) - assert.Equal(t, name, envelope.ActivityID) - assert.True(t, envelope.Timestamp > 0) - assert.Nil(t, envelope.Payload) - return mockToken - }) - setupMockToken(mockToken, mqttTestConfig.AcknowledgeTimeout, test.isTimedOut) - - assertOutgoingResult(t, test.isTimedOut, client.SendOwnerConsentGet(name)) - }) - } -} - -func TestHandleOwnerConsentMessage(t *testing.T) { - tests := map[string]testCaseIncoming{ - "test_handle_owner_consent_ok": {domain: "testdomain", handlerError: nil, expectedJSONErr: false}, - "test_handle_owner_consent_error": {domain: "mydomain", handlerError: errors.New("handler error"), expectedJSONErr: false}, - "test_handle_owner_consent_json_error": {domain: "testdomain", handlerError: nil, expectedJSONErr: true}, - } - - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - mockMessage := clientsmocks.NewMockMessage(mockCtrl) - - testConsent := &types.OwnerConsent{ - Status: types.StatusApproved, - } - - for name, test := range tests { - t.Run(name, func(t *testing.T) { - testBytes, expectedCalls := testBytesToEnvelope(t, name, testConsent, test.expectedJSONErr) - - handler := mocks.NewMockOwnerConsentHandler(mockCtrl) - handler.EXPECT().HandleOwnerConsent(name, gomock.Any(), testConsent).Times(expectedCalls).Return(test.handlerError) - - client := &ownerConsentClient{ - mqttClient: newInternalClient(test.domain, &internalConnectionConfig{}, nil), - domain: test.domain, - handler: handler, - } - mockMessage.EXPECT().Topic().Return(test.domain + "update/ownerconsent") - mockMessage.EXPECT().Payload().Return(testBytes) - - client.handleMessage(nil, mockMessage) - }) - } -} diff --git a/mqtt/owner_consent_client.go b/mqtt/owner_consent_client.go index 2a510fd..c42dcb0 100755 --- a/mqtt/owner_consent_client.go +++ b/mqtt/owner_consent_client.go @@ -50,36 +50,36 @@ func (client *ownerConsentClient) Start(consentHandler api.OwnerConsentHandler) client.handler = consentHandler if err := client.subscribe(); err != nil { client.handler = nil - return fmt.Errorf("[%s] error subscribing for OwnerConsent messages: %w", client.Domain(), err) + return fmt.Errorf("[%s] error subscribing for OwnerConsentFeedback messages: %w", client.Domain(), err) } - logger.Debug("[%s] subscribed for OwnerConsent messages", client.Domain()) + logger.Debug("[%s] subscribed for OwnerConsentFeedback messages", client.Domain()) return nil } // Stop removes the client subscription to the MQTT broker for the MQTT topics for owner consent. func (client *ownerConsentClient) Stop() error { if err := client.unsubscribe(); err != nil { - return fmt.Errorf("[%s] error unsubscribing for OwnerConsent messages: %w", client.Domain(), err) + return fmt.Errorf("[%s] error unsubscribing for OwnerConsentFeedback messages: %w", client.Domain(), err) } - logger.Debug("[%s] unsubscribed for OwnerConsent messages", client.Domain()) + logger.Debug("[%s] unsubscribed for OwnerConsentFeedback messages", client.Domain()) client.handler = nil return nil } func (client *ownerConsentClient) subscribe() error { - logger.Debug("subscribing for '%v' topic", client.topicOwnerConsent) - token := client.pahoClient.Subscribe(client.topicOwnerConsent, 1, client.handleMessage) + logger.Debug("subscribing for '%v' topic", client.topicOwnerConsentFeedback) + token := client.pahoClient.Subscribe(client.topicOwnerConsentFeedback, 1, client.handleMessage) if !token.WaitTimeout(client.mqttConfig.SubscribeTimeout) { - return fmt.Errorf("cannot subscribe for topic '%s' in '%v'", client.topicOwnerConsent, client.mqttConfig.SubscribeTimeout) + return fmt.Errorf("cannot subscribe for topic '%s' in '%v'", client.topicOwnerConsentFeedback, client.mqttConfig.SubscribeTimeout) } return token.Error() } func (client *ownerConsentClient) unsubscribe() error { - logger.Debug("unsubscribing from '%s' topic", client.topicOwnerConsent) - token := client.pahoClient.Unsubscribe(client.topicOwnerConsent) + logger.Debug("unsubscribing from '%s' topic", client.topicOwnerConsentFeedback) + token := client.pahoClient.Unsubscribe(client.topicOwnerConsentFeedback) if !token.WaitTimeout(client.mqttConfig.UnsubscribeTimeout) { - return fmt.Errorf("cannot unsubscribe from topic '%s' in '%v'", client.topicOwnerConsent, client.mqttConfig.UnsubscribeTimeout) + return fmt.Errorf("cannot unsubscribe from topic '%s' in '%v'", client.topicOwnerConsentFeedback, client.mqttConfig.UnsubscribeTimeout) } return token.Error() } @@ -87,28 +87,28 @@ func (client *ownerConsentClient) unsubscribe() error { func (client *ownerConsentClient) handleMessage(mqttClient pahomqtt.Client, message pahomqtt.Message) { topic := message.Topic() logger.Debug("[%s] received %s message", client.Domain(), topic) - if topic == client.topicOwnerConsent { - ownerConsent := &types.OwnerConsent{} + if topic == client.topicOwnerConsentFeedback { + ownerConsent := &types.OwnerConsentFeedback{} envelope, err := types.FromEnvelope(message.Payload(), ownerConsent) if err != nil { logger.ErrorErr(err, "[%s] cannot parse owner consent message", client.Domain()) return } - if err := client.handler.HandleOwnerConsent(envelope.ActivityID, envelope.Timestamp, ownerConsent); err != nil { + if err := client.handler.HandleOwnerConsentFeedback(envelope.ActivityID, envelope.Timestamp, ownerConsent); err != nil { logger.ErrorErr(err, "[%s] error processing owner consent message", client.Domain()) } } } -func (client *ownerConsentClient) SendOwnerConsentGet(activityID string) error { - logger.Debug("publishing to topic '%s'", client.topicOwnerConsentGet) - consentGetBytes, err := types.ToEnvelope(activityID, nil) +func (client *ownerConsentClient) SendOwnerConsent(activityID string, consent *types.OwnerConsent) error { + logger.Debug("publishing to topic '%s'", client.topicOwnerConsent) + consentGetBytes, err := types.ToEnvelope(activityID, consent) if err != nil { - return errors.Wrapf(err, "cannot marshal owner consent get message for activity-id %s", activityID) + return errors.Wrapf(err, "cannot marshal owner consent message for activity-id %s", activityID) } - token := client.pahoClient.Publish(client.topicOwnerConsentGet, 1, false, consentGetBytes) + token := client.pahoClient.Publish(client.topicOwnerConsent, 1, false, consentGetBytes) if !token.WaitTimeout(client.mqttConfig.AcknowledgeTimeout) { - return fmt.Errorf("cannot publish to topic '%s' in '%v'", client.topicOwnerConsentGet, client.mqttConfig.AcknowledgeTimeout) + return fmt.Errorf("cannot publish to topic '%s' in '%v'", client.topicOwnerConsent, client.mqttConfig.AcknowledgeTimeout) } return token.Error() } diff --git a/mqtt/update_agent_client.go b/mqtt/update_agent_client.go index 1b1575b..bcbe937 100755 --- a/mqtt/update_agent_client.go +++ b/mqtt/update_agent_client.go @@ -37,8 +37,8 @@ const ( suffixCurrentState = "/currentstate" suffixCurrentStateGet = "/currentstate/get" suffixDesiredStateFeedback = "/desiredstatefeedback" - suffixOwnerConsentGet = "/ownerconsent/get" suffixOwnerConsent = "/ownerconsent" + suffixOwnerConsentFeedback = "/ownerconsentfeedback" disconnectQuiesce uint = 10000 ) @@ -82,12 +82,12 @@ type mqttClient struct { // UM incoming topics topicCurrentState string topicDesiredStateFeedback string - topicOwnerConsent string + topicOwnerConsentFeedback string // UM outgoing topics topicDesiredState string topicDesiredStateCommand string topicCurrentStateGet string - topicOwnerConsentGet string + topicOwnerConsent string } func newInternalClient(domain string, config *internalConnectionConfig, pahoClient pahomqtt.Client) *mqttClient { @@ -102,7 +102,7 @@ func newInternalClient(domain string, config *internalConnectionConfig, pahoClie topicDesiredStateCommand: mqttPrefix + suffixDesiredStateCommand, topicDesiredStateFeedback: mqttPrefix + suffixDesiredStateFeedback, topicOwnerConsent: mqttPrefix + suffixOwnerConsent, - topicOwnerConsentGet: mqttPrefix + suffixOwnerConsentGet, + topicOwnerConsentFeedback: mqttPrefix + suffixOwnerConsentFeedback, } } diff --git a/test/mocks/client_mock.go b/test/mocks/client_mock.go index 1952171..422e867 100644 --- a/test/mocks/client_mock.go +++ b/test/mocks/client_mock.go @@ -414,18 +414,18 @@ func (m *MockOwnerConsentAgentHandler) EXPECT() *MockOwnerConsentAgentHandlerMoc return m.recorder } -// HandleOwnerConsentGet mocks base method. -func (m *MockOwnerConsentAgentHandler) HandleOwnerConsentGet(arg0 string, arg1 int64, arg2 *types.OwnerConsent) error { +// HandleOwnerConsent mocks base method. +func (m *MockOwnerConsentAgentHandler) HandleOwnerConsent(arg0 string, arg1 int64, arg2 *types.OwnerConsent) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HandleOwnerConsentGet", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "HandleOwnerConsent", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } -// HandleOwnerConsentGet indicates an expected call of HandleOwnerConsentGet. -func (mr *MockOwnerConsentAgentHandlerMockRecorder) HandleOwnerConsentGet(arg0, arg1, arg2 interface{}) *gomock.Call { +// HandleOwnerConsent indicates an expected call of HandleOwnerConsent. +func (mr *MockOwnerConsentAgentHandlerMockRecorder) HandleOwnerConsent(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleOwnerConsentGet", reflect.TypeOf((*MockOwnerConsentAgentHandler)(nil).HandleOwnerConsentGet), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleOwnerConsent", reflect.TypeOf((*MockOwnerConsentAgentHandler)(nil).HandleOwnerConsent), arg0, arg1, arg2) } // MockOwnerConsentAgentClient is a mock of OwnerConsentAgentClient interface. @@ -465,18 +465,18 @@ func (mr *MockOwnerConsentAgentClientMockRecorder) Domain() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Domain", reflect.TypeOf((*MockOwnerConsentAgentClient)(nil).Domain)) } -// SendOwnerConsent mocks base method. -func (m *MockOwnerConsentAgentClient) SendOwnerConsent(arg0 string, arg1 *types.OwnerConsent) error { +// SendOwnerConsentFeedback mocks base method. +func (m *MockOwnerConsentAgentClient) SendOwnerConsentFeedback(arg0 string, arg1 *types.OwnerConsentFeedback) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SendOwnerConsent", arg0, arg1) + ret := m.ctrl.Call(m, "SendOwnerConsentFeedback", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } -// SendOwnerConsent indicates an expected call of SendOwnerConsent. -func (mr *MockOwnerConsentAgentClientMockRecorder) SendOwnerConsent(arg0, arg1 interface{}) *gomock.Call { +// SendOwnerConsentFeedback indicates an expected call of SendOwnerConsentFeedback. +func (mr *MockOwnerConsentAgentClientMockRecorder) SendOwnerConsentFeedback(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendOwnerConsent", reflect.TypeOf((*MockOwnerConsentAgentClient)(nil).SendOwnerConsent), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendOwnerConsentFeedback", reflect.TypeOf((*MockOwnerConsentAgentClient)(nil).SendOwnerConsentFeedback), arg0, arg1) } // Start mocks base method. @@ -530,18 +530,18 @@ func (m *MockOwnerConsentHandler) EXPECT() *MockOwnerConsentHandlerMockRecorder return m.recorder } -// HandleOwnerConsent mocks base method. -func (m *MockOwnerConsentHandler) HandleOwnerConsent(arg0 string, arg1 int64, arg2 *types.OwnerConsent) error { +// HandleOwnerConsentFeedback mocks base method. +func (m *MockOwnerConsentHandler) HandleOwnerConsentFeedback(arg0 string, arg1 int64, arg2 *types.OwnerConsentFeedback) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HandleOwnerConsent", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "HandleOwnerConsentFeedback", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } -// HandleOwnerConsent indicates an expected call of HandleOwnerConsent. -func (mr *MockOwnerConsentHandlerMockRecorder) HandleOwnerConsent(arg0, arg1, arg2 interface{}) *gomock.Call { +// HandleOwnerConsentFeedback indicates an expected call of HandleOwnerConsentFeedback. +func (mr *MockOwnerConsentHandlerMockRecorder) HandleOwnerConsentFeedback(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleOwnerConsent", reflect.TypeOf((*MockOwnerConsentHandler)(nil).HandleOwnerConsent), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleOwnerConsentFeedback", reflect.TypeOf((*MockOwnerConsentHandler)(nil).HandleOwnerConsentFeedback), arg0, arg1, arg2) } // MockOwnerConsentClient is a mock of OwnerConsentClient interface. @@ -581,18 +581,18 @@ func (mr *MockOwnerConsentClientMockRecorder) Domain() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Domain", reflect.TypeOf((*MockOwnerConsentClient)(nil).Domain)) } -// SendOwnerConsentGet mocks base method. -func (m *MockOwnerConsentClient) SendOwnerConsentGet(arg0 string) error { +// SendOwnerConsent mocks base method. +func (m *MockOwnerConsentClient) SendOwnerConsent(arg0 string, arg1 *types.OwnerConsent) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SendOwnerConsentGet", arg0) + ret := m.ctrl.Call(m, "SendOwnerConsent", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } -// SendOwnerConsentGet indicates an expected call of SendOwnerConsentGet. -func (mr *MockOwnerConsentClientMockRecorder) SendOwnerConsentGet(arg0 interface{}) *gomock.Call { +// SendOwnerConsent indicates an expected call of SendOwnerConsent. +func (mr *MockOwnerConsentClientMockRecorder) SendOwnerConsent(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendOwnerConsentGet", reflect.TypeOf((*MockOwnerConsentClient)(nil).SendOwnerConsentGet), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendOwnerConsent", reflect.TypeOf((*MockOwnerConsentClient)(nil).SendOwnerConsent), arg0, arg1) } // Start mocks base method. diff --git a/test/mocks/update_orchestrator_mock.go b/test/mocks/update_orchestrator_mock.go index c32654f..5b87efe 100644 --- a/test/mocks/update_orchestrator_mock.go +++ b/test/mocks/update_orchestrator_mock.go @@ -74,16 +74,16 @@ func (mr *MockUpdateOrchestratorMockRecorder) HandleDesiredStateFeedbackEvent(do return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleDesiredStateFeedbackEvent", reflect.TypeOf((*MockUpdateOrchestrator)(nil).HandleDesiredStateFeedbackEvent), domain, activityID, baseline, status, message, actions) } -// HandleOwnerConsent mocks base method. -func (m *MockUpdateOrchestrator) HandleOwnerConsent(arg0 string, arg1 int64, arg2 *types.OwnerConsent) error { +// HandleOwnerConsentFeedback mocks base method. +func (m *MockUpdateOrchestrator) HandleOwnerConsentFeedback(arg0 string, arg1 int64, arg2 *types.OwnerConsentFeedback) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "HandleOwnerConsent", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } -// HandleOwnerConsent indicates an expected call of HandleOwnerConsent. -func (mr *MockUpdateOrchestratorMockRecorder) HandleOwnerConsent(arg0, arg1, arg2 interface{}) *gomock.Call { +// HandleOwnerConsentFeedback indicates an expected call of HandleOwnerConsentFeedback. +func (mr *MockUpdateOrchestratorMockRecorder) HandleOwnerConsentFeedback(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleOwnerConsent", reflect.TypeOf((*MockUpdateOrchestrator)(nil).HandleOwnerConsent), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleOwnerConsentFeedback", reflect.TypeOf((*MockUpdateOrchestrator)(nil).HandleOwnerConsentFeedback), arg0, arg1, arg2) } diff --git a/updatem/orchestration/update_operation.go b/updatem/orchestration/update_operation.go index 3b515e0..ed4f836 100644 --- a/updatem/orchestration/update_operation.go +++ b/updatem/orchestration/update_operation.go @@ -34,7 +34,8 @@ type updateOperation struct { desiredState *types.DesiredState statesPerDomain map[api.UpdateManager]*types.DesiredState - phaseChannels map[phase]chan bool + commandChannels map[types.CommandType]chan bool + done chan bool errChan chan bool errMsg string @@ -74,7 +75,9 @@ func newUpdateOperation(domainAgents map[string]api.UpdateManager, activityID st statesPerDomain: statesPerDomain, desiredState: desiredState, - phaseChannels: generatePhaseChannels(), + commandChannels: generateCommandChannels(), + + done: make(chan bool, 1), errChan: make(chan bool, 1), ownerConsented: make(chan bool), @@ -90,10 +93,10 @@ func (operation *updateOperation) updateStatus(status types.StatusType) { operation.status = status } -func generatePhaseChannels() map[phase]chan bool { - phaseChannels := make(map[phase]chan bool, len(orderedPhases)) - for _, phase := range orderedPhases { - phaseChannels[phase] = make(chan bool, 1) +func generateCommandChannels() map[types.CommandType]chan bool { + commandChannels := make(map[types.CommandType]chan bool, len(orderedCommands)) + for _, command := range orderedCommands { + commandChannels[command] = make(chan bool, 1) } - return phaseChannels + return commandChannels } diff --git a/updatem/orchestration/update_operation_test.go b/updatem/orchestration/update_operation_test.go index cd844d7..adb9318 100644 --- a/updatem/orchestration/update_operation_test.go +++ b/updatem/orchestration/update_operation_test.go @@ -59,7 +59,7 @@ func TestNewUpdateOperation(t *testing.T) { assert.Equal(t, types.StatusIdentifying, testOp.domains["domain2"]) assert.NotNil(t, testOp.actions) assert.NotNil(t, testOp.statesPerDomain) - assert.NotNil(t, testOp.phaseChannels) + assert.NotNil(t, testOp.commandChannels) assert.False(t, testOp.rebootRequired) assert.Equal(t, handler, testOp.desiredStateCallback) }) diff --git a/updatem/orchestration/update_orchestrator.go b/updatem/orchestration/update_orchestrator.go index 138e2b9..67feee8 100644 --- a/updatem/orchestration/update_orchestrator.go +++ b/updatem/orchestration/update_orchestrator.go @@ -78,10 +78,13 @@ func (orchestrator *updateOrchestrator) Apply(ctx context.Context, domainAgents } rebootRequired, applyErr := orchestrator.apply(ctx) + if applyErr != nil { + logger.Error("failed to apply '%s' desired state: %v", activityID, applyErr) + } return rebootRequired } -func (orchestrator *updateOrchestrator) HandleOwnerConsent(activityID string, timestamp int64, consent *types.OwnerConsent) error { +func (orchestrator *updateOrchestrator) HandleOwnerConsentFeedback(activityID string, timestamp int64, consent *types.OwnerConsentFeedback) error { if orchestrator.operation != nil && activityID == orchestrator.operation.activityID { logger.Info("owner consent received with status: %v, timestamp: %d", consent.Status, timestamp) orchestrator.operation.ownerConsented <- consent.Status == types.StatusApproved diff --git a/updatem/orchestration/update_orchestrator_apply.go b/updatem/orchestration/update_orchestrator_apply.go index ee64e8a..b6b2f4e 100644 --- a/updatem/orchestration/update_orchestrator_apply.go +++ b/updatem/orchestration/update_orchestrator_apply.go @@ -23,6 +23,8 @@ import ( "github.com/eclipse-kanto/update-manager/logger" ) +var orderedCommands = []types.CommandType{types.CommandDownload, types.CommandUpdate, types.CommandActivate, types.CommandCleanup} + func (orchestrator *updateOrchestrator) apply(ctx context.Context) (bool, error) { orchestrator.notifyFeedback(types.StatusIdentifying, "") for updateManagerForDomain, statePerDomain := range orchestrator.operation.statesPerDomain { @@ -31,44 +33,57 @@ func (orchestrator *updateOrchestrator) apply(ctx context.Context) (bool, error) }(updateManagerForDomain, statePerDomain) } - wait, err := orchestrator.waitPhase(ctx, phaseIdentification, handlePhaseCompletion) + // send DOWNLOAD command when identification is done + running, err := orchestrator.waitCommandSignal(ctx, types.CommandDownload, handleCommandSignal) if err != nil { return false, err } - - for i := 1; i < len(orderedPhases) && wait; i++ { - wait, err = orchestrator.waitPhase(ctx, orderedPhases[i], handlePhaseCompletion) + // send the rest of the commands in order + for i := 1; i < len(orderedCommands) && running; i++ { + running, err = orchestrator.waitCommandSignal(ctx, orderedCommands[i], handleCommandSignal) + } + // wait for the last command(CLEANUP) to finish + if running { + _, _, err = orchestrator.waitSignal(ctx, orchestrator.operation.done) } return orchestrator.operation.rebootRequired && orchestrator.operation.status == types.StatusCompleted, err } -type phaseHandler func(ctx context.Context, phase phase, orchestrator *updateOrchestrator) +type commandSignalHandler func(ctx context.Context, command types.CommandType, orchestrator *updateOrchestrator) + +func (orchestrator *updateOrchestrator) waitCommandSignal(ctx context.Context, command types.CommandType, handle commandSignalHandler) (bool, error) { + signalValue, timeout, err := orchestrator.waitSignal(ctx, orchestrator.operation.commandChannels[command]) + if err != nil { + if timeout { + if command == types.CommandDownload { + orchestrator.operation.updateStatus(types.StatusIdentificationFailed) + } else { + orchestrator.operation.updateStatus(types.StatusIncomplete) + } + } + return false, fmt.Errorf("failed to wait for command '%s' signal: %v", command, err) + } + if signalValue { + go handle(ctx, command, orchestrator) + } + return signalValue, nil +} -func (orchestrator *updateOrchestrator) waitPhase(ctx context.Context, currentPhase phase, handle phaseHandler) (bool, error) { +func (orchestrator *updateOrchestrator) waitSignal(ctx context.Context, signal chan bool) (bool, bool, error) { select { case <-time.After(orchestrator.phaseTimeout): - if currentPhase == phaseIdentification { - orchestrator.operation.updateStatus(types.StatusIdentificationFailed) - } else { - orchestrator.operation.updateStatus(types.StatusIncomplete) - } - return false, fmt.Errorf("%s phase not done in %v", currentPhase, orchestrator.phaseTimeout) + return false, true, fmt.Errorf("not received in %v", orchestrator.phaseTimeout) case <-orchestrator.operation.errChan: - return false, fmt.Errorf(orchestrator.operation.errMsg) - case running := <-orchestrator.operation.phaseChannels[currentPhase]: - logger.Info("the %s phase is done", currentPhase) - if running { - go handle(ctx, currentPhase, orchestrator) - return true, nil - } - return false, nil + return false, false, fmt.Errorf(orchestrator.operation.errMsg) + case value := <-signal: + return value, false, nil case <-ctx.Done(): orchestrator.operation.updateStatus(types.StatusIncomplete) - return false, fmt.Errorf("the update manager instance is terminated") + return false, false, fmt.Errorf("the update manager instance is terminated") } } -func handlePhaseCompletion(ctx context.Context, completedPhase phase, orchestrator *updateOrchestrator) { +func handleCommandSignal(ctx context.Context, command types.CommandType, orchestrator *updateOrchestrator) { orchestrator.operationLock.Lock() defer orchestrator.operationLock.Unlock() @@ -76,14 +91,14 @@ func handlePhaseCompletion(ctx context.Context, completedPhase phase, orchestrat return } - if err := orchestrator.getOwnerConsent(ctx, completedPhase); err != nil { + if err := orchestrator.getOwnerConsent(ctx, command); err != nil { // should a rollback be performed at this point? - orchestrator.operation.errChan <- true orchestrator.operation.errMsg = err.Error() + orchestrator.operation.errChan <- true return } - executeCommand := func(status types.StatusType, command types.CommandType) { + executeCommand := func(status types.StatusType) { for domain, domainStatus := range orchestrator.operation.domains { if domainStatus == status { orchestrator.command(ctx, orchestrator.operation.activityID, domain, command) @@ -91,28 +106,27 @@ func handlePhaseCompletion(ctx context.Context, completedPhase phase, orchestrat } } - switch completedPhase { - case phaseIdentification: - executeCommand(types.StatusIdentified, types.CommandDownload) - case phaseDownload: - executeCommand(types.BaselineStatusDownloadSuccess, types.CommandUpdate) - case phaseUpdate: - executeCommand(types.BaselineStatusUpdateSuccess, types.CommandActivate) - case phaseActivation: - executeCommand(types.BaselineStatusActivationSuccess, types.CommandCleanup) - case phaseCleanup: + switch command { + case types.CommandDownload: + executeCommand(types.StatusIdentified) + case types.CommandUpdate: + executeCommand(types.BaselineStatusDownloadSuccess) + case types.CommandActivate: + executeCommand(types.BaselineStatusUpdateSuccess) + case types.CommandCleanup: + executeCommand(types.BaselineStatusActivationSuccess) + case types.CommandRollback: // nothing to do default: - logger.Error("unknown phase %s", completedPhase) + logger.Error("unknown command %s", command) } } -func (orchestrator *updateOrchestrator) getOwnerConsent(ctx context.Context, completedPhase phase) error { - nextPhase := completedPhase.next() - if nextPhase == "" || !slices.Contains(orchestrator.cfg.OwnerConsentPhases, string(nextPhase)) { +func (orchestrator *updateOrchestrator) getOwnerConsent(ctx context.Context, command types.CommandType) error { + if command == "" || !slices.Contains(orchestrator.cfg.OwnerConsentCommands, command) { return nil } - if nextPhase == phaseCleanup || nextPhase == phaseIdentification { + if command == types.CommandRollback || command == types.CommandCleanup { // no need for owner consent return nil } @@ -130,7 +144,7 @@ func (orchestrator *updateOrchestrator) getOwnerConsent(ctx context.Context, com } }() - if err := orchestrator.ownerConsentClient.SendOwnerConsentGet(orchestrator.operation.activityID); err != nil { + if err := orchestrator.ownerConsentClient.SendOwnerConsent(orchestrator.operation.activityID, &types.OwnerConsent{Command: command}); err != nil { return err } diff --git a/updatem/orchestration/update_orchestrator_apply_test.go b/updatem/orchestration/update_orchestrator_apply_test.go index c691c55..8c9308a 100644 --- a/updatem/orchestration/update_orchestrator_apply_test.go +++ b/updatem/orchestration/update_orchestrator_apply_test.go @@ -47,7 +47,7 @@ func TestApply(t *testing.T) { operation: &updateOperation{ desiredState: &types.DesiredState{}, desiredStateCallback: eventCallback, - phaseChannels: generatePhaseChannels(), + commandChannels: generateCommandChannels(), activityID: test.ActivityID, actions: map[string]map[string]*types.Action{ "action1": { @@ -64,8 +64,8 @@ func TestApply(t *testing.T) { mockUpdateManager: statePerDomain, } - orchestrator.operation.phaseChannels[phaseIdentification] <- true - orchestrator.operation.phaseChannels[phaseDownload] <- false + orchestrator.operation.commandChannels[types.CommandDownload] <- true + orchestrator.operation.commandChannels[types.CommandUpdate] <- false expectedActions := []*types.Action{ { Message: "testMsg", @@ -130,7 +130,7 @@ func TestApply(t *testing.T) { go applyCall(ctx, orchestrator, doneChan, successReturnChan, errReturnChan) <-applyChan - assert.Equal(t, fmt.Errorf("testErrMsg"), <-errReturnChan) + assert.Equal(t, fmt.Errorf("failed to wait for command 'DOWNLOAD' signal: testErrMsg"), <-errReturnChan) assert.False(t, orchestrator.operation.rebootRequired) assert.False(t, <-successReturnChan) <-doneChan @@ -144,16 +144,16 @@ func applyCall(ctx context.Context, orchestrator *updateOrchestrator, done chan done <- true } -func TestWaitPhase(t *testing.T) { +func TestWaitCommandSignal(t *testing.T) { const ( - phaseDone = "1" - errChan = "2" - none = "4" + commandDone = "1" + errChan = "2" + none = "3" ) testCases := map[string]struct { ctx context.Context testChan string - phase phase + command types.CommandType phaseDone bool terminateContext bool expectedWait bool @@ -163,41 +163,43 @@ func TestWaitPhase(t *testing.T) { "test_case_errChan": { ctx: context.Background(), testChan: errChan, - expectedErr: fmt.Errorf("testErrMsg"), + expectedErr: fmt.Errorf("failed to wait for command 'CLEANUP' signal: testErrMsg"), + command: types.CommandCleanup, expectedStatus: types.StatusIdentifying, }, - "test_case_phaseDone_identification": { + "test_case_command_download": { ctx: context.Background(), - testChan: phaseDone, - phase: phaseIdentification, + testChan: commandDone, + command: types.CommandDownload, expectedWait: true, expectedStatus: types.StatusIdentifying, }, - "test_case_phaseDone_cleanup": { + "test_case_command_": { ctx: context.Background(), - testChan: phaseDone, - phase: phaseCleanup, + testChan: commandDone, + command: types.CommandActivate, expectedStatus: types.StatusIdentifying, }, "test_case_terminateContext": { ctx: context.Background(), testChan: none, - expectedErr: fmt.Errorf("the update manager instance is terminated"), + command: types.CommandDownload, + expectedErr: fmt.Errorf("failed to wait for command 'DOWNLOAD' signal: the update manager instance is terminated"), expectedStatus: types.StatusIncomplete, terminateContext: true, }, - "test_case_timeout_identification": { + "test_case_timeout_download": { ctx: context.Background(), testChan: none, - phase: phaseIdentification, - expectedErr: fmt.Errorf("identification phase not done in 1s"), + command: types.CommandDownload, + expectedErr: fmt.Errorf("failed to wait for command 'DOWNLOAD' signal: not received in 1s"), expectedStatus: types.StatusIdentificationFailed, }, - "test_case_timeout_download": { + "test_case_timeout_update": { ctx: context.Background(), testChan: none, - phase: phaseDownload, - expectedErr: fmt.Errorf("download phase not done in 1s"), + command: types.CommandUpdate, + expectedErr: fmt.Errorf("failed to wait for command 'UPDATE' signal: not received in 1s"), expectedStatus: types.StatusIncomplete, }, } @@ -206,27 +208,27 @@ func TestWaitPhase(t *testing.T) { t.Run(testName, func(t *testing.T) { orchestrator := &updateOrchestrator{ operation: &updateOperation{ - errChan: make(chan bool, 1), - phaseChannels: generatePhaseChannels(), - errMsg: "testErrMsg", - status: types.StatusIdentifying, + errChan: make(chan bool, 1), + commandChannels: generateCommandChannels(), + errMsg: "testErrMsg", + status: types.StatusIdentifying, }, phaseTimeout: time.Second, } wg := sync.WaitGroup{} - var phaseHandler phaseHandler + var commandHandler commandSignalHandler if testCase.testChan == errChan { orchestrator.operation.errChan <- true - } else if testCase.testChan == phaseDone { - if testCase.phase == phaseIdentification { + } else if testCase.testChan == commandDone { + if testCase.command == types.CommandDownload { wg.Add(1) - phaseHandler = func(ctx context.Context, phase phase, orchestrator *updateOrchestrator) { + commandHandler = func(ctx context.Context, command types.CommandType, orchestrator *updateOrchestrator) { wg.Done() } } - orchestrator.operation.phaseChannels[testCase.phase] <- testCase.expectedWait + orchestrator.operation.commandChannels[testCase.command] <- testCase.expectedWait } var actualErr error @@ -234,9 +236,9 @@ func TestWaitPhase(t *testing.T) { if testCase.terminateContext { newContext, cancel := context.WithTimeout(testCase.ctx, time.Second) cancel() - actualWait, actualErr = orchestrator.waitPhase(newContext, testCase.phase, phaseHandler) + actualWait, actualErr = orchestrator.waitCommandSignal(newContext, testCase.command, commandHandler) } else { - actualWait, actualErr = orchestrator.waitPhase(testCase.ctx, testCase.phase, phaseHandler) + actualWait, actualErr = orchestrator.waitCommandSignal(testCase.ctx, testCase.command, commandHandler) } assert.Equal(t, testCase.expectedErr, actualErr) @@ -248,7 +250,7 @@ func TestWaitPhase(t *testing.T) { } } -func TestHandlePhaseCompletion(t *testing.T) { +func TestHandleCommandSignal(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() mockUpdateManager := mocks.NewMockUpdateManager(mockCtrl) @@ -278,73 +280,54 @@ func TestHandlePhaseCompletion(t *testing.T) { } testCases := map[string]struct { - noOperation bool - noConsentClient bool - domainStatus1 types.StatusType - domainStatus2 types.StatusType - phase phase - expectedCalls func() + noOperation bool + domainStatus1 types.StatusType + domainStatus2 types.StatusType + command types.CommandType + expectedCalls func() }{ - "test_handle_phase_completion_identify": { + "test_handle_command_signal_download": { domainStatus1: types.StatusIdentified, domainStatus2: types.StatusIdentified, - phase: phaseIdentification, + command: types.CommandDownload, expectedCalls: mockCommand(mockUpdateManager, types.CommandDownload, testDomain1, testDomain2), }, - "test_handle_phase_completion_download": { + "test_handle_command_signal_update": { domainStatus1: types.BaselineStatusDownloadSuccess, domainStatus2: types.BaselineStatusDownloadFailure, - phase: phaseDownload, + command: types.CommandUpdate, expectedCalls: mockCommand(mockUpdateManager, types.CommandUpdate, testDomain1), }, - "test_handle_phase_completion_update": { + "test_handle_command_signal_activate": { domainStatus1: types.BaselineStatusUpdateSuccess, domainStatus2: types.BaselineStatusUpdateFailure, - phase: phaseUpdate, + command: types.CommandActivate, expectedCalls: mockCommand(mockUpdateManager, types.CommandActivate, testDomain1), }, - "test_handle_phase_completion_activate": { + "test_handle_command_signal_cleanup": { domainStatus1: types.BaselineStatusActivationSuccess, domainStatus2: types.BaselineStatusActivationFailure, - phase: phaseActivation, + command: types.CommandCleanup, expectedCalls: mockCommand(mockUpdateManager, types.CommandCleanup, testDomain1), }, - "test_handle_phase_completion_cleanup": { - domainStatus1: types.BaselineStatusCleanupSuccess, - domainStatus2: types.BaselineStatusCleanupFailure, - phase: phaseCleanup, - expectedCalls: func() {}, - }, - "test_handle_phase_completion_update_failure": { + "test_handle_command_signal_activate_failure": { domainStatus1: types.BaselineStatusUpdateFailure, domainStatus2: types.BaselineStatusUpdateFailure, - phase: phaseUpdate, + command: types.CommandActivate, expectedCalls: mockCommand(mockUpdateManager, types.CommandActivate), }, - "test_handle_phase_completion_no_operation": { + "test_handle_command_signal_no_operation": { noOperation: true, }, - "test_handle_phase_completion_unknown_phase": { + "test_handle_command_signal_unknown_command": { domainStatus1: types.BaselineStatusCleanupSuccess, domainStatus2: types.BaselineStatusCleanupFailure, - phase: phase("unknown"), + command: types.CommandType("unknown"), expectedCalls: func() {}, }, - "test_handle_phase_completion_consent_error": { - noConsentClient: true, - domainStatus1: types.StatusIdentified, - phase: phaseIdentification, - expectedCalls: func() {}, - }, } for testName, testCase := range testCases { t.Run(testName, func(t *testing.T) { - if testCase.noConsentClient { - orchestrator.cfg.OwnerConsentPhases = []string{"download"} - go func() { - <-orchestrator.operation.errChan - }() - } if testCase.noOperation { orchestrator.operation = nil } else { @@ -353,7 +336,7 @@ func TestHandlePhaseCompletion(t *testing.T) { orchestrator.operation.domains[testDomain2] = testCase.domainStatus2 testCase.expectedCalls() } - handlePhaseCompletion(context.Background(), testCase.phase, orchestrator) + handleCommandSignal(context.Background(), testCase.command, orchestrator) }) } } @@ -413,12 +396,14 @@ func TestSetupUpdateOperation(t *testing.T) { err := orchestrator.setupUpdateOperation(domainAgents, test.ActivityID, test.DesiredState, handler) - assert.NotNil(t, orchestrator.operation.phaseChannels) + assert.NotNil(t, orchestrator.operation.commandChannels) assert.NotNil(t, orchestrator.operation.errChan) + assert.NotNil(t, orchestrator.operation.done) assert.NotNil(t, orchestrator.operation.ownerConsented) orchestrator.operation.errChan = nil - orchestrator.operation.phaseChannels = nil + orchestrator.operation.done = nil + orchestrator.operation.commandChannels = nil orchestrator.operation.ownerConsented = nil assert.Equal(t, expectedOp, orchestrator.operation) @@ -465,111 +450,3 @@ func TestDisposeUpdateOperation(t *testing.T) { assert.Nil(t, orchestrator.operation) }) } - -func TestGetOwnerConsent(t *testing.T) { - tests := map[string]struct { - updateOrchestrator *updateOrchestrator - currentPhase phase - expectedErr error - mock func(*gomock.Controller) (*mocks.MockOwnerConsentClient, chan bool) - }{ - "test_no_next_phase": { - updateOrchestrator: &updateOrchestrator{}, - currentPhase: phaseCleanup, - }, - "test_consent_not_needed": { - updateOrchestrator: &updateOrchestrator{cfg: &config.Config{OwnerConsentPhases: []string{"download"}}}, - currentPhase: phaseUpdate, - }, - "test_no_consent_for_cleanup": { - updateOrchestrator: &updateOrchestrator{cfg: &config.Config{OwnerConsentPhases: []string{"cleanup"}}}, - currentPhase: phaseActivation, - }, - "test_no_owner_consent_client": { - updateOrchestrator: &updateOrchestrator{cfg: &config.Config{OwnerConsentPhases: []string{"download"}}}, - currentPhase: phaseIdentification, - expectedErr: fmt.Errorf("owner consent client not available"), - }, - "test_owner_consent_client_start_err": { - updateOrchestrator: &updateOrchestrator{cfg: &config.Config{OwnerConsentPhases: []string{"download"}}}, - currentPhase: phaseIdentification, - expectedErr: fmt.Errorf("start error"), - mock: func(ctrl *gomock.Controller) (*mocks.MockOwnerConsentClient, chan bool) { - mockClient := mocks.NewMockOwnerConsentClient(ctrl) - mockClient.EXPECT().Start(gomock.Any()).Return(fmt.Errorf("start error")) - return mockClient, nil - }, - }, - "test_owner_consent_client_send_err": { - updateOrchestrator: &updateOrchestrator{cfg: &config.Config{OwnerConsentPhases: []string{"download"}}}, - currentPhase: phaseIdentification, - expectedErr: fmt.Errorf("send error"), - mock: func(ctrl *gomock.Controller) (*mocks.MockOwnerConsentClient, chan bool) { - mockClient := mocks.NewMockOwnerConsentClient(ctrl) - mockClient.EXPECT().Start(gomock.Any()).Return(nil) - mockClient.EXPECT().Stop().Return(nil) - mockClient.EXPECT().SendOwnerConsentGet(test.ActivityID).Return(fmt.Errorf("send error")) - return mockClient, nil - }, - }, - "test_owner_consent_approved": { - updateOrchestrator: &updateOrchestrator{cfg: &config.Config{OwnerConsentPhases: []string{"download"}}}, - currentPhase: phaseIdentification, - mock: func(ctrl *gomock.Controller) (*mocks.MockOwnerConsentClient, chan bool) { - mockClient := mocks.NewMockOwnerConsentClient(ctrl) - mockClient.EXPECT().Start(gomock.Any()).Return(nil) - mockClient.EXPECT().Stop().Return(nil) - mockClient.EXPECT().SendOwnerConsentGet(test.ActivityID).Return(nil) - ch := make(chan bool) - go func() { - ch <- true - }() - return mockClient, ch - }, - }, - "test_owner_consent_denied": { - updateOrchestrator: &updateOrchestrator{cfg: &config.Config{OwnerConsentPhases: []string{"download"}}}, - currentPhase: phaseIdentification, - expectedErr: fmt.Errorf("owner approval not granted"), - mock: func(ctrl *gomock.Controller) (*mocks.MockOwnerConsentClient, chan bool) { - mockClient := mocks.NewMockOwnerConsentClient(ctrl) - mockClient.EXPECT().Start(gomock.Any()).Return(nil) - mockClient.EXPECT().Stop().Return(nil) - mockClient.EXPECT().SendOwnerConsentGet(test.ActivityID).Return(nil) - ch := make(chan bool) - go func() { - ch <- false - }() - return mockClient, ch - }, - }, - "test_owner_consent_timeout": { - updateOrchestrator: &updateOrchestrator{cfg: &config.Config{OwnerConsentPhases: []string{"download"}}}, - currentPhase: phaseIdentification, - expectedErr: fmt.Errorf("owner consent not granted in %v", test.Interval), - mock: func(ctrl *gomock.Controller) (*mocks.MockOwnerConsentClient, chan bool) { - mockClient := mocks.NewMockOwnerConsentClient(ctrl) - mockClient.EXPECT().Start(gomock.Any()).Return(nil) - mockClient.EXPECT().Stop().Return(nil) - mockClient.EXPECT().SendOwnerConsentGet(test.ActivityID).Return(nil) - return mockClient, make(chan bool) - }, - }, - } - - for testName, testCase := range tests { - t.Run(testName, func(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - orch := testCase.updateOrchestrator - orch.operation = &updateOperation{activityID: test.ActivityID} - orch.phaseTimeout = test.Interval - if testCase.mock != nil { - orch.ownerConsentClient, orch.operation.ownerConsented = testCase.mock(mockCtrl) - } - err := orch.getOwnerConsent(context.Background(), testCase.currentPhase) - assert.Equal(t, testCase.expectedErr, err) - }) - } -} diff --git a/updatem/orchestration/update_orchestrator_feedback.go b/updatem/orchestration/update_orchestrator_feedback.go index a992d6d..1e728a5 100644 --- a/updatem/orchestration/update_orchestrator_feedback.go +++ b/updatem/orchestration/update_orchestrator_feedback.go @@ -109,11 +109,11 @@ func handleDomainIdentified(orchestrator *updateOrchestrator, domain, message st if isIdentified { orchestrator.domainUpdateRunning() orchestrator.operation.updateStatus(types.StatusRunning) - orchestrator.operation.phaseChannels[phaseIdentification] <- true + orchestrator.operation.commandChannels[types.CommandDownload] <- true } else { // no actions(status CleanupSuccess for all domains), operation is done orchestrator.operation.updateStatus(types.StatusCompleted) - orchestrator.operation.phaseChannels[phaseIdentification] <- false + orchestrator.operation.commandChannels[types.CommandDownload] <- false } } @@ -165,7 +165,7 @@ func handleDomainDownloadSuccess(orchestrator *updateOrchestrator, domain, messa return } } - orchestrator.operation.phaseChannels[phaseDownload] <- true + orchestrator.operation.commandChannels[types.CommandUpdate] <- true orchestrator.domainUpdateRunning() } @@ -207,7 +207,7 @@ func handleDomainUpdateSuccess(orchestrator *updateOrchestrator, domain, message return } } - orchestrator.operation.phaseChannels[phaseUpdate] <- true + orchestrator.operation.commandChannels[types.CommandActivate] <- true orchestrator.domainUpdateRunning() } @@ -250,7 +250,7 @@ func handleDomainActivationSuccess(orchestrator *updateOrchestrator, domain, mes return } } - orchestrator.operation.phaseChannels[phaseActivation] <- true + orchestrator.operation.commandChannels[types.CommandCleanup] <- true orchestrator.domainUpdateRunning() } @@ -331,7 +331,7 @@ func (orchestrator *updateOrchestrator) domainUpdateCompleted() { return } orchestrator.operation.updateStatus(types.StatusCompleted) - orchestrator.operation.phaseChannels[phaseCleanup] <- false + orchestrator.operation.done <- true } func (orchestrator *updateOrchestrator) domainUpdateRunning() { diff --git a/updatem/orchestration/update_orchestrator_feedback_test.go b/updatem/orchestration/update_orchestrator_feedback_test.go index 51853c8..85368b7 100644 --- a/updatem/orchestration/update_orchestrator_feedback_test.go +++ b/updatem/orchestration/update_orchestrator_feedback_test.go @@ -43,7 +43,7 @@ var testUpdateActions = []*types.Action{ }, } -type testWaitChannel func(map[phase]chan bool) error +type testWaitChannel func(map[types.CommandType]chan bool) error func TestFeedbackHandleDesiredStateFeedbackEvent(t *testing.T) { mockCtrl := gomock.NewController(t) @@ -51,14 +51,14 @@ func TestFeedbackHandleDesiredStateFeedbackEvent(t *testing.T) { eventCallback := mocks.NewMockUpdateManagerCallback(mockCtrl) mockUpdateManager := mocks.NewMockUpdateManager(mockCtrl) - waitPhaseChannel := func(currentPhase phase) testWaitChannel { - return func(phases map[phase]chan bool) error { + waitCommandChannel := func(command types.CommandType) testWaitChannel { + return func(commands map[types.CommandType]chan bool) error { timeout := time.Second select { - case <-phases[currentPhase]: + case <-commands[command]: return nil case <-time.After(timeout): - return fmt.Errorf("%s phase not done in %v", currentPhase, timeout) + return fmt.Errorf("%s command signal not received in %v", command, timeout) } } } @@ -87,7 +87,7 @@ func TestFeedbackHandleDesiredStateFeedbackEvent(t *testing.T) { eventCallback.EXPECT().HandleDesiredStateFeedbackEvent("device", test.ActivityID, "", types.StatusIdentified, "", gomock.Any()) eventCallback.EXPECT().HandleDesiredStateFeedbackEvent("device", test.ActivityID, "", types.StatusRunning, "", gomock.Any()) }, - waitChannel: waitPhaseChannel(phaseIdentification), + waitChannel: waitCommandChannel(types.CommandDownload), }, "test_handleDomainIdentified_domainUpdateStatus_StatusIdentificationFailed": { @@ -294,7 +294,7 @@ func TestFeedbackHandleDesiredStateFeedbackEvent(t *testing.T) { testCode: func() { eventCallback.EXPECT().HandleDesiredStateFeedbackEvent("device", test.ActivityID, "", types.StatusRunning, "", gomock.Any()) }, - waitChannel: waitPhaseChannel(phaseDownload), + waitChannel: waitCommandChannel(types.CommandUpdate), }, "test_handleDomainDownloadSuccess_orchestration_status_not_running": { handleStatus: types.BaselineStatusDownloadSuccess, @@ -394,7 +394,7 @@ func TestFeedbackHandleDesiredStateFeedbackEvent(t *testing.T) { testCode: func() { eventCallback.EXPECT().HandleDesiredStateFeedbackEvent("device", test.ActivityID, "", types.StatusRunning, "", gomock.Any()) }, - waitChannel: waitPhaseChannel(phaseUpdate), + waitChannel: waitCommandChannel(types.CommandActivate), }, "test_handleDomainUpdateSuccess_orchestration_status_not_running": { handleStatus: types.BaselineStatusUpdateSuccess, @@ -493,7 +493,7 @@ func TestFeedbackHandleDesiredStateFeedbackEvent(t *testing.T) { testCode: func() { eventCallback.EXPECT().HandleDesiredStateFeedbackEvent("device", test.ActivityID, "", types.StatusRunning, "", gomock.Any()) }, - waitChannel: waitPhaseChannel(phaseActivation), + waitChannel: waitCommandChannel(types.CommandCleanup), }, "test_handleDomainActivationSuccess_orchestration_status_not_running": { handleStatus: types.BaselineStatusActivationSuccess, @@ -590,7 +590,6 @@ func TestFeedbackHandleDesiredStateFeedbackEvent(t *testing.T) { expectedStatus: types.StatusCompleted, expectedDomainStatus: types.BaselineStatusCleanupSuccess, testCode: func() {}, - waitChannel: waitPhaseChannel(phaseCleanup), }, "test_handleDomainCleanupSuccess_orchestration_status_not_running": { handleStatus: types.BaselineStatusCleanupSuccess, @@ -678,7 +677,8 @@ func TestFeedbackHandleDesiredStateFeedbackEvent(t *testing.T) { desiredStateCallback: eventCallback, actions: testOperationActions, status: testCase.updateOrchStatus, - phaseChannels: generatePhaseChannels(), + commandChannels: generateCommandChannels(), + done: make(chan bool, 1), errChan: make(chan bool, 1), statesPerDomain: map[api.UpdateManager]*types.DesiredState{ mockUpdateManager: {}, @@ -691,7 +691,7 @@ func TestFeedbackHandleDesiredStateFeedbackEvent(t *testing.T) { errorChan := make(chan error, 1) go func() { if testCase.waitChannel != nil { - errorChan <- testCase.waitChannel(testUpdOrch.operation.phaseChannels) + errorChan <- testCase.waitChannel(testUpdOrch.operation.commandChannels) } errorChan <- nil }() @@ -809,10 +809,10 @@ func TestDomainUpdateCompleted(t *testing.T) { assert.Equal(t, 0, len(orchestrator.operation.errChan)) } if testCase.doneChanLen { - assert.Equal(t, 1, len(orchestrator.operation.phaseChannels[phaseCleanup])) - assert.False(t, <-orchestrator.operation.phaseChannels[phaseCleanup]) + assert.Equal(t, 1, len(orchestrator.operation.done)) + assert.True(t, <-orchestrator.operation.done) } else { - assert.Equal(t, 0, len(orchestrator.operation.phaseChannels[phaseCleanup])) + assert.Equal(t, 0, len(orchestrator.operation.done)) } assert.Equal(t, testCase.expectedStatus, orchestrator.operation.status) @@ -825,13 +825,14 @@ func TestDomainUpdateCompleted(t *testing.T) { func generateUpdOrch(cfgRebootReqired bool, actions map[string]map[string]*types.Action, delayedStatus types.StatusType, domains map[string]types.StatusType) *updateOrchestrator { return &updateOrchestrator{ operation: &updateOperation{ - actions: actions, - delayedStatus: delayedStatus, - status: types.StatusIdentifying, - errMsg: "", - errChan: make(chan bool, 1), - phaseChannels: generatePhaseChannels(), - domains: domains, + actions: actions, + delayedStatus: delayedStatus, + status: types.StatusIdentifying, + errMsg: "", + errChan: make(chan bool, 1), + done: make(chan bool, 1), + commandChannels: generateCommandChannels(), + domains: domains, }, cfg: createTestConfig(cfgRebootReqired, true), phaseTimeout: 10 * time.Minute, diff --git a/updatem/orchestration/update_orchestrator_test.go b/updatem/orchestration/update_orchestrator_test.go index 2e6856b..6a45048 100644 --- a/updatem/orchestration/update_orchestrator_test.go +++ b/updatem/orchestration/update_orchestrator_test.go @@ -86,42 +86,6 @@ func TestUpdOrchApply(t *testing.T) { }) } -func TestHandleOwnerConsent(t *testing.T) { - updateOrchestrator := &updateOrchestrator{ - operation: &updateOperation{ - activityID: test.ActivityID, - ownerConsented: make(chan bool), - }, - } - t.Run("test_handle_owner_approved", func(t *testing.T) { - go updateOrchestrator.HandleOwnerConsent(test.ActivityID, 0, &types.OwnerConsent{Status: types.StatusApproved}) - select { - case consented := <-updateOrchestrator.operation.ownerConsented: - assert.True(t, consented) - case <-time.After(1 * time.Second): - t.Fatal("owner consent not received") - } - }) - t.Run("test_handle_owner_denied", func(t *testing.T) { - go updateOrchestrator.HandleOwnerConsent(test.ActivityID, 0, &types.OwnerConsent{Status: types.StatusDenied}) - select { - case consented := <-updateOrchestrator.operation.ownerConsented: - assert.False(t, consented) - case <-time.After(1 * time.Second): - t.Fatal("owner consent not received") - } - }) - t.Run("test_handle_owner_approved_another_activity", func(t *testing.T) { - go updateOrchestrator.HandleOwnerConsent("anotherActivity", 0, &types.OwnerConsent{Status: types.StatusApproved}) - select { - case <-updateOrchestrator.operation.ownerConsented: - t.Fatal("unexpected owner consent") - case <-time.After(1 * time.Second): - // do nothing - } - }) -} - func applyDesiredState(ctx context.Context, updOrch *updateOrchestrator, done chan bool, domainAgents map[string]api.UpdateManager, activityID string, desiredState *types.DesiredState, apiDesState api.DesiredStateFeedbackHandler) { updOrch.Apply(ctx, domainAgents, activityID, desiredState, apiDesState) done <- true diff --git a/updatem/orchestration/update_phase.go b/updatem/orchestration/update_phase.go deleted file mode 100644 index 66b1b92..0000000 --- a/updatem/orchestration/update_phase.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright (c) 2024 Contributors to the Eclipse Foundation -// -// See the NOTICE file(s) distributed with this work for additional -// information regarding copyright ownership. -// -// This program and the accompanying materials are made available under the -// terms of the Eclipse Public License 2.0 which is available at -// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 -// which is available at https://www.apache.org/licenses/LICENSE-2.0. -// -// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 - -package orchestration - -type phase string - -const ( - phaseIdentification phase = "identification" - phaseDownload phase = "download" - phaseUpdate phase = "update" - phaseActivation phase = "activation" - phaseCleanup phase = "cleanup" -) - -var orderedPhases = []phase{phaseIdentification, phaseDownload, phaseUpdate, phaseActivation, phaseCleanup} - -func (p phase) next() phase { - for i := 0; i < len(orderedPhases)-1; i++ { - if orderedPhases[i] == p { - return orderedPhases[i+1] - } - } - return "" -} From f2c85afd62640e976b51f9f29ef1ff83c949b9df Mon Sep 17 00:00:00 2001 From: Dimitar Dimitrov Date: Tue, 30 Apr 2024 10:58:48 +0300 Subject: [PATCH 5/6] Fixed PR comments and added rollback when owner denies the update Signed-off-by: Dimitar Dimitrov --- api/types/owner_consent.go | 4 +- config/flags_internal.go | 17 ++---- updatem/orchestration/update_operation.go | 5 +- .../update_orchestrator_apply.go | 53 +++++++++++-------- .../update_orchestrator_apply_test.go | 6 ++- .../update_orchestrator_feedback.go | 52 ++++++++++++++++++ 6 files changed, 99 insertions(+), 38 deletions(-) diff --git a/api/types/owner_consent.go b/api/types/owner_consent.go index 78fea62..bd62d7f 100644 --- a/api/types/owner_consent.go +++ b/api/types/owner_consent.go @@ -16,9 +16,9 @@ package types type ConsentStatusType string const ( - // StatusApproved denotes that the owner has consented. + // StatusApproved denotes that the owner approved the update operation. StatusApproved ConsentStatusType = "APPROVED" - // StatusDenied denotes that the owner has not consented. + // StatusDenied denotes that the owner denied the update operation. StatusDenied ConsentStatusType = "DENIED" ) diff --git a/config/flags_internal.go b/config/flags_internal.go index 897f148..2da5b0a 100755 --- a/config/flags_internal.go +++ b/config/flags_internal.go @@ -43,16 +43,12 @@ func SetupAllUpdateManagerFlags(flagSet *flag.FlagSet, cfg *Config) { flagSet.StringVar(&cfg.PhaseTimeout, "phase-timeout", EnvToString("PHASE_TIMEOUT", cfg.PhaseTimeout), "Specify the timeout for completing an Update Orchestration phase. Value should be a positive integer number followed by a unit suffix, such as '60s', '10m', etc") flagSet.StringVar(&cfg.ReportFeedbackInterval, "report-feedback-interval", EnvToString("REPORT_FEEDBACK_INTERVAL", cfg.ReportFeedbackInterval), "Specify the time interval for reporting intermediate desired state feedback messages during an active update operation. Value should be a positive integer number followed by a unit suffix, such as '60s', '10m', etc") flagSet.StringVar(&cfg.CurrentStateDelay, "current-state-delay", EnvToString("CURRENT_STATE_DELAY", cfg.CurrentStateDelay), "Specify the time delay for reporting current state messages. Value should be a positive integer number followed by a unit suffix, such as '60s', '10m', etc") - flagSet.String(ownerConsentCommandsFlagID, "", ownerConsentCommandsDesc) setupAgentsConfigFlags(flagSet, cfg) } func parseFlags(cfg *Config, version string) { domains := parseDomainsFlag() prepareAgentsConfig(cfg, domains) - if ownerConsentPhases := parseOwnerConsentCommandsFlag(); len(ownerConsentPhases) > 0 { - cfg.OwnerConsentCommands = ownerConsentPhases - } flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) flagSet := flag.CommandLine @@ -60,6 +56,7 @@ func parseFlags(cfg *Config, version string) { SetupAllUpdateManagerFlags(flagSet, cfg) fVersion := flagSet.Bool("version", false, "Prints current version and exits") + listCommands := flagSet.String(ownerConsentCommandsFlagID, "", ownerConsentCommandsDesc) if err := flagSet.Parse(os.Args[1:]); err != nil { logger.ErrorErr(err, "Cannot parse command flags") } @@ -68,17 +65,13 @@ func parseFlags(cfg *Config, version string) { fmt.Println(version) os.Exit(0) } -} -func parseOwnerConsentCommandsFlag() []types.CommandType { - var listCommands string - flagSet := flag.NewFlagSet("", flag.ContinueOnError) - flagSet.SetOutput(io.Discard) - flagSet.StringVar(&listCommands, ownerConsentCommandsFlagID, EnvToString("OWNER_CONSENT_COMMANDS", ""), ownerConsentCommandsDesc) - if err := flagSet.Parse(getFlagArgs(ownerConsentCommandsFlagID)); err != nil { - logger.ErrorErr(err, "Cannot parse %s flag", ownerConsentCommandsFlagID) + if len(*listCommands) != 0 { + cfg.OwnerConsentCommands = parseOwnerConsentCommandsFlag(*listCommands) } +} +func parseOwnerConsentCommandsFlag(listCommands string) []types.CommandType { var result []types.CommandType for _, command := range strings.Split(listCommands, ",") { c := strings.TrimSpace(command) diff --git a/updatem/orchestration/update_operation.go b/updatem/orchestration/update_operation.go index ed4f836..2ae85ab 100644 --- a/updatem/orchestration/update_operation.go +++ b/updatem/orchestration/update_operation.go @@ -27,6 +27,7 @@ type updateOperation struct { statusLock sync.Mutex status types.StatusType delayedStatus types.StatusType + delayedErrMsg string domains map[string]types.StatusType actions map[string]map[string]*types.Action @@ -41,6 +42,7 @@ type updateOperation struct { errMsg string ownerConsented chan bool + rollbackChan chan bool rebootRequired bool @@ -80,7 +82,8 @@ func newUpdateOperation(domainAgents map[string]api.UpdateManager, activityID st done: make(chan bool, 1), errChan: make(chan bool, 1), - ownerConsented: make(chan bool), + ownerConsented: make(chan bool, 1), + rollbackChan: make(chan bool, 1), desiredStateCallback: desiredStateCallback, }, nil diff --git a/updatem/orchestration/update_orchestrator_apply.go b/updatem/orchestration/update_orchestrator_apply.go index b6b2f4e..d81eeba 100644 --- a/updatem/orchestration/update_orchestrator_apply.go +++ b/updatem/orchestration/update_orchestrator_apply.go @@ -34,25 +34,28 @@ func (orchestrator *updateOrchestrator) apply(ctx context.Context) (bool, error) } // send DOWNLOAD command when identification is done - running, err := orchestrator.waitCommandSignal(ctx, types.CommandDownload, handleCommandSignal) + running, rollback, err := orchestrator.waitCommandSignal(ctx, types.CommandDownload, handleCommandSignal) if err != nil { return false, err } // send the rest of the commands in order for i := 1; i < len(orderedCommands) && running; i++ { - running, err = orchestrator.waitCommandSignal(ctx, orderedCommands[i], handleCommandSignal) + if rollback && orderedCommands[i] != types.CommandCleanup { + continue + } + running, rollback, err = orchestrator.waitCommandSignal(ctx, orderedCommands[i], handleCommandSignal) } // wait for the last command(CLEANUP) to finish if running { - _, _, err = orchestrator.waitSignal(ctx, orchestrator.operation.done) + _, _, _, err = orchestrator.waitSignal(ctx, orchestrator.operation.done) } return orchestrator.operation.rebootRequired && orchestrator.operation.status == types.StatusCompleted, err } type commandSignalHandler func(ctx context.Context, command types.CommandType, orchestrator *updateOrchestrator) -func (orchestrator *updateOrchestrator) waitCommandSignal(ctx context.Context, command types.CommandType, handle commandSignalHandler) (bool, error) { - signalValue, timeout, err := orchestrator.waitSignal(ctx, orchestrator.operation.commandChannels[command]) +func (orchestrator *updateOrchestrator) waitCommandSignal(ctx context.Context, command types.CommandType, handle commandSignalHandler) (bool, bool, error) { + signalValue, rollback, timeout, err := orchestrator.waitSignal(ctx, orchestrator.operation.commandChannels[command]) if err != nil { if timeout { if command == types.CommandDownload { @@ -61,25 +64,27 @@ func (orchestrator *updateOrchestrator) waitCommandSignal(ctx context.Context, c orchestrator.operation.updateStatus(types.StatusIncomplete) } } - return false, fmt.Errorf("failed to wait for command '%s' signal: %v", command, err) + return false, false, fmt.Errorf("failed to wait for command '%s' signal: %v", command, err) } - if signalValue { + if signalValue && !rollback { go handle(ctx, command, orchestrator) } - return signalValue, nil + return signalValue, rollback, nil } -func (orchestrator *updateOrchestrator) waitSignal(ctx context.Context, signal chan bool) (bool, bool, error) { +func (orchestrator *updateOrchestrator) waitSignal(ctx context.Context, signal chan bool) (bool, bool, bool, error) { select { case <-time.After(orchestrator.phaseTimeout): - return false, true, fmt.Errorf("not received in %v", orchestrator.phaseTimeout) + return false, false, true, fmt.Errorf("not received in %v", orchestrator.phaseTimeout) case <-orchestrator.operation.errChan: - return false, false, fmt.Errorf(orchestrator.operation.errMsg) + return false, false, false, fmt.Errorf(orchestrator.operation.errMsg) + case <-orchestrator.operation.rollbackChan: + return true, true, false, nil case value := <-signal: - return value, false, nil + return value, false, false, nil case <-ctx.Done(): orchestrator.operation.updateStatus(types.StatusIncomplete) - return false, false, fmt.Errorf("the update manager instance is terminated") + return false, false, false, fmt.Errorf("the update manager instance is terminated") } } @@ -92,15 +97,21 @@ func handleCommandSignal(ctx context.Context, command types.CommandType, orchest } if err := orchestrator.getOwnerConsent(ctx, command); err != nil { - // should a rollback be performed at this point? - orchestrator.operation.errMsg = err.Error() - orchestrator.operation.errChan <- true - return + if command != types.CommandUpdate && command != types.CommandActivate { + orchestrator.operation.updateStatus(types.StatusIncomplete) + orchestrator.operation.errMsg = err.Error() + orchestrator.operation.errChan <- true + return + } + command = types.CommandRollback + orchestrator.operation.delayedStatus = types.StatusIncomplete + orchestrator.operation.delayedErrMsg = err.Error() + orchestrator.operation.rollbackChan <- true } - executeCommand := func(status types.StatusType) { + executeCommand := func(statuses ...types.StatusType) { for domain, domainStatus := range orchestrator.operation.domains { - if domainStatus == status { + if slices.Contains(statuses, domainStatus) { orchestrator.command(ctx, orchestrator.operation.activityID, domain, command) } } @@ -114,9 +125,9 @@ func handleCommandSignal(ctx context.Context, command types.CommandType, orchest case types.CommandActivate: executeCommand(types.BaselineStatusUpdateSuccess) case types.CommandCleanup: - executeCommand(types.BaselineStatusActivationSuccess) + executeCommand(types.BaselineStatusActivationSuccess, types.BaselineStatusRollbackSuccess) case types.CommandRollback: - // nothing to do + executeCommand(types.BaselineStatusDownloadSuccess, types.BaselineStatusUpdateSuccess) default: logger.Error("unknown command %s", command) } diff --git a/updatem/orchestration/update_orchestrator_apply_test.go b/updatem/orchestration/update_orchestrator_apply_test.go index 8c9308a..f55927d 100644 --- a/updatem/orchestration/update_orchestrator_apply_test.go +++ b/updatem/orchestration/update_orchestrator_apply_test.go @@ -236,9 +236,9 @@ func TestWaitCommandSignal(t *testing.T) { if testCase.terminateContext { newContext, cancel := context.WithTimeout(testCase.ctx, time.Second) cancel() - actualWait, actualErr = orchestrator.waitCommandSignal(newContext, testCase.command, commandHandler) + actualWait, _, actualErr = orchestrator.waitCommandSignal(newContext, testCase.command, commandHandler) } else { - actualWait, actualErr = orchestrator.waitCommandSignal(testCase.ctx, testCase.command, commandHandler) + actualWait, _, actualErr = orchestrator.waitCommandSignal(testCase.ctx, testCase.command, commandHandler) } assert.Equal(t, testCase.expectedErr, actualErr) @@ -400,11 +400,13 @@ func TestSetupUpdateOperation(t *testing.T) { assert.NotNil(t, orchestrator.operation.errChan) assert.NotNil(t, orchestrator.operation.done) assert.NotNil(t, orchestrator.operation.ownerConsented) + assert.NotNil(t, orchestrator.operation.rollbackChan) orchestrator.operation.errChan = nil orchestrator.operation.done = nil orchestrator.operation.commandChannels = nil orchestrator.operation.ownerConsented = nil + orchestrator.operation.rollbackChan = nil assert.Equal(t, expectedOp, orchestrator.operation) assert.Nil(t, err) diff --git a/updatem/orchestration/update_orchestrator_feedback.go b/updatem/orchestration/update_orchestrator_feedback.go index 1e728a5..9fda7da 100644 --- a/updatem/orchestration/update_orchestrator_feedback.go +++ b/updatem/orchestration/update_orchestrator_feedback.go @@ -40,6 +40,9 @@ var statusHandlers = map[types.StatusType]statusHandler{ types.BaselineStatusCleanup: handleDomainCleanup, types.BaselineStatusCleanupSuccess: handleDomainCleanupSuccess, types.BaselineStatusCleanupFailure: handleDomainCleanupFailure, + types.BaselineStatusRollback: handleDomainRollback, + types.BaselineStatusRollbackSuccess: handleDomainRollbackSuccess, + types.BaselineStatusRollbackFailure: handleDomainRollbackFailure, } func (orchestrator *updateOrchestrator) HandleDesiredStateFeedbackEvent(domain, activityID, baseline string, status types.StatusType, message string, actions []*types.Action) { @@ -278,6 +281,52 @@ func handleDomainActivating(orchestrator *updateOrchestrator, domain, message st orchestrator.domainUpdateRunning() } +func handleDomainRollbackSuccess(orchestrator *updateOrchestrator, domain, message string, actions []*types.Action) { + if orchestrator.operation.status != types.StatusRunning { + return + } + domainStatus := orchestrator.operation.domains[domain] + if domainStatus != types.BaselineStatusDownloadSuccess && domainStatus != types.BaselineStatusUpdateSuccess && + domainStatus != types.BaselineStatusRollback { + return + } + orchestrator.operation.domains[domain] = types.BaselineStatusRollbackSuccess + for _, status := range orchestrator.operation.domains { + if status == types.BaselineStatusDownloadSuccess || status == types.BaselineStatusUpdateSuccess || + status == types.BaselineStatusRollback { + return + } + } + orchestrator.operation.commandChannels[types.CommandCleanup] <- true + orchestrator.domainUpdateRunning() +} + +func handleDomainRollbackFailure(orchestrator *updateOrchestrator, domain, message string, actions []*types.Action) { + if orchestrator.operation.status != types.StatusRunning { + return + } + domainStatus := orchestrator.operation.domains[domain] + if domainStatus != types.BaselineStatusDownloadSuccess && domainStatus != types.BaselineStatusUpdateSuccess && + domainStatus != types.BaselineStatusRollback { + return + } + orchestrator.operation.delayedStatus = types.StatusIncomplete + orchestrator.operation.domains[domain] = types.BaselineStatusRollbackFailure + orchestrator.command(context.Background(), orchestrator.operation.activityID, domain, types.CommandCleanup) +} + +func handleDomainRollback(orchestrator *updateOrchestrator, domain, message string, actions []*types.Action) { + if orchestrator.operation.status != types.StatusRunning { + return + } + domainStatus := orchestrator.operation.domains[domain] + if domainStatus != types.BaselineStatusDownloadSuccess && domainStatus != types.BaselineStatusUpdateSuccess && + domainStatus != types.BaselineStatusRollback { + return + } + orchestrator.domainUpdateRunning() +} + func handleDomainCleanupSuccess(orchestrator *updateOrchestrator, domain, message string, actions []*types.Action) { domainStatus := orchestrator.operation.domains[domain] if domainStatus != types.BaselineStatusActivationSuccess && domainStatus != types.BaselineStatusDownloadFailure && @@ -327,6 +376,9 @@ func (orchestrator *updateOrchestrator) domainUpdateCompleted() { if orchestrator.operation.delayedStatus == types.StatusIncomplete { orchestrator.operation.updateStatus(types.StatusIncomplete) orchestrator.operation.errMsg = "the update process is incompleted" + if orchestrator.operation.delayedErrMsg != "" { + orchestrator.operation.errMsg = fmt.Sprintf("%s: %s", orchestrator.operation.errMsg, orchestrator.operation.delayedErrMsg) + } orchestrator.operation.errChan <- true return } From 11af7ce8c13eb58464ca48c4b3898aca4209e370 Mon Sep 17 00:00:00 2001 From: Dimitar Dimitrov Date: Mon, 13 May 2024 13:53:56 +0300 Subject: [PATCH 6/6] Add owner consent timeout Signed-off-by: Dimitar Dimitrov --- config/config_internal.go | 3 ++ config/config_test.go | 2 ++ config/flags_internal.go | 1 + config/testdata/config.json | 1 + updatem/orchestration/update_orchestrator.go | 14 +++++---- .../update_orchestrator_apply.go | 4 +-- .../update_orchestrator_apply_test.go | 30 ++++++++++++------- .../orchestration/update_orchestrator_test.go | 3 +- 8 files changed, 38 insertions(+), 20 deletions(-) diff --git a/config/config_internal.go b/config/config_internal.go index 0e8aee6..ef162bb 100755 --- a/config/config_internal.go +++ b/config/config_internal.go @@ -32,6 +32,7 @@ const ( currentStateDelayDefault = "30s" phaseTimeoutDefault = "10m" readTimeoutDefault = "1m" + ownerConsentTimeoutDefault = "30m" domainContainers = "containers" ) @@ -46,6 +47,7 @@ type Config struct { CurrentStateDelay string `json:"currentStateDelay"` PhaseTimeout string `json:"phaseTimeout"` OwnerConsentCommands []types.CommandType `json:"ownerConsentCommands"` + OwnerConsentTimeout string `json:"ownerConsentTimeout"` } func newDefaultConfig() *Config { @@ -57,6 +59,7 @@ func newDefaultConfig() *Config { ReportFeedbackInterval: reportFeedbackIntervalDefault, CurrentStateDelay: currentStateDelayDefault, PhaseTimeout: phaseTimeoutDefault, + OwnerConsentTimeout: ownerConsentTimeoutDefault, } } diff --git a/config/config_test.go b/config/config_test.go index 7ccd0a7..bc3a121 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -64,6 +64,7 @@ func TestNewDefaultConfig(t *testing.T) { ReportFeedbackInterval: "1m", CurrentStateDelay: "30s", PhaseTimeout: "10m", + OwnerConsentTimeout: "30m", } cfg := newDefaultConfig() @@ -137,6 +138,7 @@ func TestLoadConfigFromFile(t *testing.T) { ReportFeedbackInterval: "2m", CurrentStateDelay: "1m", PhaseTimeout: "2m", + OwnerConsentTimeout: "4m", OwnerConsentCommands: []types.CommandType{types.CommandDownload}, } assert.True(t, reflect.DeepEqual(*cfg, expectedConfigValues)) diff --git a/config/flags_internal.go b/config/flags_internal.go index 2da5b0a..a433811 100755 --- a/config/flags_internal.go +++ b/config/flags_internal.go @@ -43,6 +43,7 @@ func SetupAllUpdateManagerFlags(flagSet *flag.FlagSet, cfg *Config) { flagSet.StringVar(&cfg.PhaseTimeout, "phase-timeout", EnvToString("PHASE_TIMEOUT", cfg.PhaseTimeout), "Specify the timeout for completing an Update Orchestration phase. Value should be a positive integer number followed by a unit suffix, such as '60s', '10m', etc") flagSet.StringVar(&cfg.ReportFeedbackInterval, "report-feedback-interval", EnvToString("REPORT_FEEDBACK_INTERVAL", cfg.ReportFeedbackInterval), "Specify the time interval for reporting intermediate desired state feedback messages during an active update operation. Value should be a positive integer number followed by a unit suffix, such as '60s', '10m', etc") flagSet.StringVar(&cfg.CurrentStateDelay, "current-state-delay", EnvToString("CURRENT_STATE_DELAY", cfg.CurrentStateDelay), "Specify the time delay for reporting current state messages. Value should be a positive integer number followed by a unit suffix, such as '60s', '10m', etc") + flagSet.StringVar(&cfg.OwnerConsentTimeout, "owner-consent-timeout", EnvToString("OWNER_CONSENT_TIMEOUT", cfg.OwnerConsentTimeout), "Specify the timeout to wait for owner consent. Value should be a positive integer number followed by a unit suffix, such as '60s', '10m', etc") setupAgentsConfigFlags(flagSet, cfg) } diff --git a/config/testdata/config.json b/config/testdata/config.json index 174bfba..3977169 100644 --- a/config/testdata/config.json +++ b/config/testdata/config.json @@ -25,6 +25,7 @@ "currentStateDelay": "1m", "phaseTimeout": "2m", "ownerConsentCommands": ["DOWNLOAD"], + "ownerConsentTimeout": "4m", "agents": { "self-update": { "rebootRequired": false, diff --git a/updatem/orchestration/update_orchestrator.go b/updatem/orchestration/update_orchestrator.go index 67feee8..155380a 100644 --- a/updatem/orchestration/update_orchestrator.go +++ b/updatem/orchestration/update_orchestrator.go @@ -28,9 +28,10 @@ type updateOrchestrator struct { operationLock sync.Mutex actionsLock sync.Mutex - cfg *config.Config - phaseTimeout time.Duration - ownerConsentClient api.OwnerConsentClient + cfg *config.Config + phaseTimeout time.Duration + ownerConsentTimeout time.Duration + ownerConsentClient api.OwnerConsentClient operation *updateOperation } @@ -42,9 +43,10 @@ func (orchestrator *updateOrchestrator) Name() string { // NewUpdateOrchestrator creates a new update orchestrator that does not handle cross-domain dependencies func NewUpdateOrchestrator(cfg *config.Config, ownerApprovalClient api.OwnerConsentClient) api.UpdateOrchestrator { ua := &updateOrchestrator{ - cfg: cfg, - phaseTimeout: util.ParseDuration("phase-timeout", cfg.PhaseTimeout, 10*time.Minute, 10*time.Minute), - ownerConsentClient: ownerApprovalClient, + cfg: cfg, + phaseTimeout: util.ParseDuration("phase-timeout", cfg.PhaseTimeout, 10*time.Minute, 10*time.Minute), + ownerConsentTimeout: util.ParseDuration("owner-consent-timeout", cfg.OwnerConsentTimeout, 30*time.Minute, 30*time.Minute), + ownerConsentClient: ownerApprovalClient, } return ua } diff --git a/updatem/orchestration/update_orchestrator_apply.go b/updatem/orchestration/update_orchestrator_apply.go index d81eeba..feb1e19 100644 --- a/updatem/orchestration/update_orchestrator_apply.go +++ b/updatem/orchestration/update_orchestrator_apply.go @@ -165,8 +165,8 @@ func (orchestrator *updateOrchestrator) getOwnerConsent(ctx context.Context, com return fmt.Errorf("owner approval not granted") } return nil - case <-time.After(orchestrator.phaseTimeout): - return fmt.Errorf("owner consent not granted in %v", orchestrator.phaseTimeout) + case <-time.After(orchestrator.ownerConsentTimeout): + return fmt.Errorf("owner consent not granted in %v", orchestrator.ownerConsentTimeout) case <-ctx.Done(): return fmt.Errorf("the update manager instance is terminated") } diff --git a/updatem/orchestration/update_orchestrator_apply_test.go b/updatem/orchestration/update_orchestrator_apply_test.go index f55927d..6ae709f 100644 --- a/updatem/orchestration/update_orchestrator_apply_test.go +++ b/updatem/orchestration/update_orchestrator_apply_test.go @@ -213,7 +213,7 @@ func TestWaitCommandSignal(t *testing.T) { errMsg: "testErrMsg", status: types.StatusIdentifying, }, - phaseTimeout: time.Second, + phaseTimeout: test.Interval, } wg := sync.WaitGroup{} @@ -253,7 +253,8 @@ func TestWaitCommandSignal(t *testing.T) { func TestHandleCommandSignal(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() - mockUpdateManager := mocks.NewMockUpdateManager(mockCtrl) + mockUpdateManager1 := mocks.NewMockUpdateManager(mockCtrl) + mockUpdateManager2 := mocks.NewMockUpdateManager(mockCtrl) testDomain1 := "testName1" testDomain2 := "testName2" @@ -266,15 +267,22 @@ func TestHandleCommandSignal(t *testing.T) { errChan: make(chan bool, 1), } operation.statesPerDomain = map[api.UpdateManager]*types.DesiredState{ - mockUpdateManager: {}, + mockUpdateManager1: {}, + mockUpdateManager2: {}, } orchestrator := &updateOrchestrator{cfg: &config.Config{}} - mockCommand := func(mockUpdateManager *mocks.MockUpdateManager, command types.CommandType, domains ...string) func() { + mockUpdateManager1.EXPECT().Name().Return(testDomain1).AnyTimes() + mockUpdateManager2.EXPECT().Name().Return(testDomain2).AnyTimes() + + mockCommand := func(command types.CommandType, domains ...string) func() { return func() { for _, domain := range domains { - mockUpdateManager.EXPECT().Name().Return(domain).Times(1) - mockUpdateManager.EXPECT().Command(context.Background(), test.ActivityID, generateCommand(command)) + if testDomain1 == domain { + mockUpdateManager1.EXPECT().Command(context.Background(), test.ActivityID, generateCommand(command)) + } else if testDomain2 == domain { + mockUpdateManager2.EXPECT().Command(context.Background(), test.ActivityID, generateCommand(command)) + } } } } @@ -290,31 +298,31 @@ func TestHandleCommandSignal(t *testing.T) { domainStatus1: types.StatusIdentified, domainStatus2: types.StatusIdentified, command: types.CommandDownload, - expectedCalls: mockCommand(mockUpdateManager, types.CommandDownload, testDomain1, testDomain2), + expectedCalls: mockCommand(types.CommandDownload, testDomain2, testDomain1), }, "test_handle_command_signal_update": { domainStatus1: types.BaselineStatusDownloadSuccess, domainStatus2: types.BaselineStatusDownloadFailure, command: types.CommandUpdate, - expectedCalls: mockCommand(mockUpdateManager, types.CommandUpdate, testDomain1), + expectedCalls: mockCommand(types.CommandUpdate, testDomain1), }, "test_handle_command_signal_activate": { domainStatus1: types.BaselineStatusUpdateSuccess, domainStatus2: types.BaselineStatusUpdateFailure, command: types.CommandActivate, - expectedCalls: mockCommand(mockUpdateManager, types.CommandActivate, testDomain1), + expectedCalls: mockCommand(types.CommandActivate, testDomain1), }, "test_handle_command_signal_cleanup": { domainStatus1: types.BaselineStatusActivationSuccess, domainStatus2: types.BaselineStatusActivationFailure, command: types.CommandCleanup, - expectedCalls: mockCommand(mockUpdateManager, types.CommandCleanup, testDomain1), + expectedCalls: mockCommand(types.CommandCleanup, testDomain1), }, "test_handle_command_signal_activate_failure": { domainStatus1: types.BaselineStatusUpdateFailure, domainStatus2: types.BaselineStatusUpdateFailure, command: types.CommandActivate, - expectedCalls: mockCommand(mockUpdateManager, types.CommandActivate), + expectedCalls: func() {}, }, "test_handle_command_signal_no_operation": { noOperation: true, diff --git a/updatem/orchestration/update_orchestrator_test.go b/updatem/orchestration/update_orchestrator_test.go index 6a45048..0cec876 100644 --- a/updatem/orchestration/update_orchestrator_test.go +++ b/updatem/orchestration/update_orchestrator_test.go @@ -31,7 +31,8 @@ func TestNewUpdateOrchestrator(t *testing.T) { cfg: &config.Config{ RebootEnabled: true, }, - phaseTimeout: 10 * time.Minute, + phaseTimeout: 10 * time.Minute, + ownerConsentTimeout: 30 * time.Minute, } assert.Equal(t, expectedOrchestrator, NewUpdateOrchestrator(&config.Config{RebootEnabled: true}, nil)) }