diff --git a/updater/updater_test.go b/updater/updater_test.go index f0969f045..0d42b487c 100644 --- a/updater/updater_test.go +++ b/updater/updater_test.go @@ -26,9 +26,9 @@ type testOmahaHandler struct { handler *omaha.Handler } -func newTestHandler(a *api.API) *testOmahaHandler { +func newTestHandler(api *api.API) *testOmahaHandler { return &testOmahaHandler{ - handler: omaha.NewHandler(a), + handler: omaha.NewHandler(api), } } @@ -59,16 +59,15 @@ func newForTest(t *testing.T) *api.API { log.Printf("NEBRASKA_DB_URL not set, setting to default %q\n", defaultTestDbURL) _ = os.Setenv("NEBRASKA_DB_URL", defaultTestDbURL) } - a, err := api.New(api.OptionInitDB) + api, err := api.New(api.OptionInitDB) require.NoError(t, err) - require.NotNil(t, a) + require.NotNil(t, api) - return a + return api } func TestNewUpdater(t *testing.T) { - conf := updater.Config{ OmahaURL: "http://localhost:8000", AppID: "io.phony.App", @@ -76,16 +75,25 @@ func TestNewUpdater(t *testing.T) { InstanceID: "instance001", InstanceVersion: "0.1.0", } - + // Valid Config _, err := updater.New(conf) assert.NoError(t, err) + + // Invalid Config + conf.OmahaURL = "http://invalidurl.test\\" + updater, err := updater.New(conf) + assert.Nil(t, updater) + assert.Error(t, err) } func TestCheckForUpdates(t *testing.T) { - a := newForTest(t) - defer a.Close() + apiInstance := newForTest(t) - appID, track, tChannel := setup(t, a, "0.1.0", true, 2) + t.Cleanup(func() { + apiInstance.Close() + }) + + appID, track, tChannel := setup(&config{t: t, api: apiInstance, pkgVersion: "0.1.0", policySafeMode: true, policyMaxUpdatesPerPeriod: 2}) u, err := updater.New(updater.Config{ OmahaURL: "http://localhost:8000", @@ -93,35 +101,36 @@ func TestCheckForUpdates(t *testing.T) { Channel: track, InstanceID: "instance001", InstanceVersion: "0.2.0", - OmahaReqHandler: newTestHandler(a), + OmahaReqHandler: newTestHandler(apiInstance), }) require.NoError(t, err) info, err := u.CheckForUpdates(context.TODO()) - assert.NoError(t, err) + require.NoError(t, err) assert.False(t, info.HasUpdate) assert.Equal(t, "", info.GetVersion()) - newPkg, _ := a.AddPackage(&api.Package{Type: api.PkgTypeOther, URL: "http://sample.url/pkg", Version: "0.3.0", ApplicationID: appID, Arch: api.ArchAMD64, Filename: null.StringFrom("updatefile.txt")}) + newPkg, err := apiInstance.AddPackage(&api.Package{Type: api.PkgTypeOther, URL: "http://sample.url/pkg", Version: "0.3.0", ApplicationID: appID, Arch: api.ArchAMD64, Filename: null.StringFrom("updatefile.txt")}) + require.NoError(t, err) tChannel.PackageID = null.StringFrom(newPkg.ID) - err = a.UpdateChannel(tChannel) - assert.NoError(t, err) + err = apiInstance.UpdateChannel(tChannel) + require.NoError(t, err) info, err = u.CheckForUpdates(context.TODO()) - assert.NoError(t, err) + require.NoError(t, err) assert.True(t, info.HasUpdate) version := info.GetVersion() assert.Equal(t, "0.3.0", version) urls := info.GetURLs() - assert.NotNil(t, urls) + require.NotNil(t, urls) assert.Equal(t, 1, len(urls)) assert.Equal(t, urls[len(urls)-1], info.GetURL()) assert.Equal(t, "http://sample.url/pkg", info.GetURL()) pkg := info.GetPackage() - assert.NotNil(t, pkg) + require.NotNil(t, pkg) assert.Equal(t, "updatefile.txt", pkg.Name) } @@ -138,50 +147,60 @@ func (u updateTestHandler) ApplyUpdate(ctx context.Context, info *updater.Update return u.applyUpdateResult } -func setup(t *testing.T, a *api.API, version string, policySafeMode bool, policyMaxUpdatesPerPeriod int) (string, string, *api.Channel) { - t.Helper() - tTeam, err := a.AddTeam(&api.Team{Name: "test_team"}) - require.NoError(t, err) - tApp, err := a.AddApp(&api.Application{Name: "io.phony.App", TeamID: tTeam.ID}) - require.NoError(t, err) - tPkg, err := a.AddPackage(&api.Package{Type: api.PkgTypeOther, URL: "http://sample.url/pkg", Version: version, ApplicationID: tApp.ID, Arch: api.ArchAMD64}) - require.NoError(t, err) - tChannel, err := a.AddChannel(&api.Channel{Name: "channel1", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID), Arch: api.ArchAMD64}) - require.NoError(t, err) - tGroup, err := a.AddGroup(&api.Group{Name: "group1", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: policySafeMode, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: policyMaxUpdatesPerPeriod, PolicyUpdateTimeout: "60 minutes", Track: "stable"}) - require.NoError(t, err) +type config struct { + t *testing.T + api *api.API + pkgVersion string + policySafeMode bool + policyMaxUpdatesPerPeriod int +} + +func setup(cnf *config) (string, string, *api.Channel) { + cnf.t.Helper() + tTeam, err := cnf.api.AddTeam(&api.Team{Name: "test_team"}) + require.NoError(cnf.t, err) + tApp, err := cnf.api.AddApp(&api.Application{Name: "io.phony.App", TeamID: tTeam.ID}) + require.NoError(cnf.t, err) + tPkg, err := cnf.api.AddPackage(&api.Package{Type: api.PkgTypeOther, URL: "http://sample.url/pkg", Version: cnf.pkgVersion, ApplicationID: tApp.ID, Arch: api.ArchAMD64}) + require.NoError(cnf.t, err) + tChannel, err := cnf.api.AddChannel(&api.Channel{Name: "channel1", Color: "blue", ApplicationID: tApp.ID, PackageID: null.StringFrom(tPkg.ID), Arch: api.ArchAMD64}) + require.NoError(cnf.t, err) + tGroup, err := cnf.api.AddGroup(&api.Group{Name: "group1", ApplicationID: tApp.ID, ChannelID: null.StringFrom(tChannel.ID), PolicyUpdatesEnabled: true, PolicySafeMode: cnf.policySafeMode, PolicyPeriodInterval: "15 minutes", PolicyMaxUpdatesPerPeriod: cnf.policyMaxUpdatesPerPeriod, PolicyUpdateTimeout: "60 minutes", Track: "stable"}) + require.NoError(cnf.t, err) return tApp.ID, tGroup.Track, tChannel } func TestTryUpdate(t *testing.T) { - a := newForTest(t) - defer a.Close() + api := newForTest(t) + + t.Cleanup(func() { + api.Close() + }) oldVersion := "0.2.0" pkgVersion := "0.4.0" - appID, track, _ := setup(t, a, pkgVersion, false, 10) + appID, track, _ := setup(&config{t: t, api: api, pkgVersion: pkgVersion, policySafeMode: false, policyMaxUpdatesPerPeriod: 10}) - type test struct { + tests := []struct { name string fetchUpdateResult error applyUpdateResult error isErr bool - } - - tests := []test{ + }{ { - name: "error fetching update", + name: "error_fetching_update", fetchUpdateResult: errors.New("something went wrong fetching the update"), applyUpdateResult: nil, isErr: true, }, { - name: "error applying update", + name: "error_applying_update", fetchUpdateResult: nil, applyUpdateResult: errors.New("something went wrong fetching the update"), isErr: true, - }, { - name: "success try update", + }, + { + name: "success_try_update", fetchUpdateResult: nil, applyUpdateResult: nil, isErr: false, @@ -191,14 +210,13 @@ func TestTryUpdate(t *testing.T) { for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { - u, err := updater.New(updater.Config{ OmahaURL: "http://localhost:8000", AppID: appID, Channel: track, InstanceID: "instance001", InstanceVersion: "0.2.0", - OmahaReqHandler: newTestHandler(a), + OmahaReqHandler: newTestHandler(api), Debug: true, }) require.NoError(t, err)