diff --git a/daemon/api_registry.go b/daemon/api_registry.go index 11073d0f747..c934f5e8827 100644 --- a/daemon/api_registry.go +++ b/daemon/api_registry.go @@ -60,6 +60,7 @@ func getView(c *Command, r *http.Request, _ *auth.UserState) Response { fields = strutil.CommaSeparatedList(fieldStr) } + // TODO: replace access w/ GetTransaction results, err := registrystateGetViaView(st, account, registryName, view, fields) if err != nil { return toAPIError(err) @@ -86,6 +87,7 @@ func setView(c *Command, r *http.Request, _ *auth.UserState) Response { return BadRequest("cannot decode registry request body: %v", err) } + // TODO: replace w/ GetTransaction + call ctx.Done() then return changeID err := registrystateSetViaView(st, account, registryName, view, values) if err != nil { return toAPIError(err) diff --git a/overlord/hookstate/ctlcmd/export_test.go b/overlord/hookstate/ctlcmd/export_test.go index f4fd11e5c31..6c2d4b964e1 100644 --- a/overlord/hookstate/ctlcmd/export_test.go +++ b/overlord/hookstate/ctlcmd/export_test.go @@ -28,9 +28,11 @@ import ( "github.com/snapcore/snapd/client/clientutil" "github.com/snapcore/snapd/overlord/devicestate" "github.com/snapcore/snapd/overlord/hookstate" + "github.com/snapcore/snapd/overlord/registrystate" "github.com/snapcore/snapd/overlord/servicestate" "github.com/snapcore/snapd/overlord/snapstate" "github.com/snapcore/snapd/overlord/state" + "github.com/snapcore/snapd/registry" "github.com/snapcore/snapd/snap" "github.com/snapcore/snapd/testutil" ) @@ -177,3 +179,11 @@ func MockNewStatusDecorator(f func(ctx context.Context, isGlobal bool, uid strin newStatusDecorator = f return restore } + +func MockRegistrystateGetTransaction(f func(*registrystate.Context, *state.State, *registry.View) (*registrystate.Transaction, error)) (restore func()) { + old := registrystateGetTransaction + registrystateGetTransaction = f + return func() { + registrystateGetTransaction = old + } +} diff --git a/overlord/hookstate/ctlcmd/get.go b/overlord/hookstate/ctlcmd/get.go index b94388fad2a..1b5ce6bafbf 100644 --- a/overlord/hookstate/ctlcmd/get.go +++ b/overlord/hookstate/ctlcmd/get.go @@ -39,6 +39,8 @@ import ( "github.com/snapcore/snapd/snap" ) +var registrystateGetTransaction = registrystate.GetTransaction + type getCommand struct { baseCommand @@ -369,7 +371,8 @@ func (c *getCommand) getRegistryValues(ctx *hookstate.Context, plugName string, return fmt.Errorf("cannot get registry: %v", err) } - tx, err := registrystate.RegistryTransaction(ctx, view.Registry()) + regCtx := registrystate.NewContext(ctx) + tx, err := registrystateGetTransaction(regCtx, ctx.State(), view) if err != nil { return err } diff --git a/overlord/hookstate/ctlcmd/get_test.go b/overlord/hookstate/ctlcmd/get_test.go index 60665df801a..9b186803db5 100644 --- a/overlord/hookstate/ctlcmd/get_test.go +++ b/overlord/hookstate/ctlcmd/get_test.go @@ -641,57 +641,6 @@ func (s *registrySuite) TestRegistryGetNoRequest(c *C) { c.Check(stderr, IsNil) } -func (s *registrySuite) TestRegistryGetHappensTransactionally(c *C) { - s.state.Lock() - err := registrystate.SetViaView(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{ - "ssid": "my-ssid", - }) - s.state.Unlock() - c.Assert(err, IsNil) - - // registry transaction is created when snapctl runs for the first time - stdout, stderr, err := ctlcmd.Run(s.mockContext, []string{"get", "--view", ":read-wifi"}, 0) - c.Assert(err, IsNil) - c.Check(string(stdout), Equals, `{ - "ssid": "my-ssid" -} -`) - c.Check(stderr, IsNil) - - s.state.Lock() - err = registrystate.SetViaView(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{ - "ssid": "other-ssid", - }) - s.state.Unlock() - c.Assert(err, IsNil) - - // the new write wasn't reflected because it didn't run in the same transaction - stdout, stderr, err = ctlcmd.Run(s.mockContext, []string{"get", "--view", ":read-wifi"}, 0) - c.Assert(err, IsNil) - c.Check(string(stdout), Equals, `{ - "ssid": "my-ssid" -} -`) - c.Check(stderr, IsNil) - - // make a new context so we get a new transaction - s.state.Lock() - task := s.state.NewTask("test-task", "my test task") - setup := &hookstate.HookSetup{Snap: "test-snap", Revision: snap.R(1), Hook: "test-hook"} - s.mockContext, err = hookstate.NewContext(task, s.state, setup, s.mockHandler, "") - s.state.Unlock() - c.Assert(err, IsNil) - - // now we get the new data - stdout, stderr, err = ctlcmd.Run(s.mockContext, []string{"get", "--view", ":read-wifi"}, 0) - c.Assert(err, IsNil) - c.Check(string(stdout), Equals, `{ - "ssid": "other-ssid" -} -`) - c.Check(stderr, IsNil) -} - func (s *registrySuite) TestRegistryGetInvalid(c *C) { type testcase struct { args []string @@ -804,7 +753,6 @@ func (s *registrySuite) TestRegistryGetAndSetAssertionNotFound(c *C) { c.Assert(err, ErrorMatches, fmt.Sprintf("cannot set registry: registry assertion %s/network not found", s.devAccID)) c.Check(stdout, IsNil) c.Check(stderr, IsNil) - } func (s *registrySuite) TestRegistryGetAndSetViewNotFound(c *C) { diff --git a/overlord/hookstate/ctlcmd/set.go b/overlord/hookstate/ctlcmd/set.go index 47b7b69beb4..47f38b8ab20 100644 --- a/overlord/hookstate/ctlcmd/set.go +++ b/overlord/hookstate/ctlcmd/set.go @@ -235,13 +235,15 @@ func setRegistryValues(ctx *hookstate.Context, plugName string, requests map[str return fmt.Errorf("cannot set registry: %v", err) } - tx, err := registrystate.RegistryTransaction(ctx, view.Registry()) + if registrystate.IsRegistryHook(ctx) && !strings.HasPrefix(ctx.HookName(), "change-view-") { + return fmt.Errorf("cannot modify registry in %q hook", ctx.HookName()) + } + + regCtx := registrystate.NewContext(ctx) + tx, err := registrystateGetTransaction(regCtx, ctx.State(), view) if err != nil { return err } - // TODO: once we have hooks, check that we don't set values in the wrong hooks - // (e.g., "registry-changed" hooks can only read data) - return registrystate.SetViaViewInTx(tx, view, requests) } diff --git a/overlord/hookstate/ctlcmd/set_test.go b/overlord/hookstate/ctlcmd/set_test.go index e12a7a16deb..d4c30532add 100644 --- a/overlord/hookstate/ctlcmd/set_test.go +++ b/overlord/hookstate/ctlcmd/set_test.go @@ -33,6 +33,7 @@ import ( "github.com/snapcore/snapd/overlord/hookstate/hooktest" "github.com/snapcore/snapd/overlord/registrystate" "github.com/snapcore/snapd/overlord/state" + "github.com/snapcore/snapd/registry" "github.com/snapcore/snapd/snap" ) @@ -405,6 +406,16 @@ func (s *setAttrSuite) TestSetCommandFailsOutsideOfValidContext(c *C) { } func (s *registrySuite) TestRegistrySetSingleView(c *C) { + s.state.Lock() + tx, err := registrystate.NewTransaction(s.state, s.devAccID, "network") + s.state.Unlock() + c.Assert(err, IsNil) + + restore := ctlcmd.MockRegistrystateGetTransaction(func(*registrystate.Context, *state.State, *registry.View) (*registrystate.Transaction, error) { + return tx, nil + }) + defer restore() + stdout, stderr, err := ctlcmd.Run(s.mockContext, []string{"set", "--view", ":write-wifi", "ssid=other-ssid"}, 0) c.Assert(err, IsNil) c.Check(stdout, IsNil) @@ -413,56 +424,34 @@ func (s *registrySuite) TestRegistrySetSingleView(c *C) { c.Assert(s.mockContext.Done(), IsNil) s.mockContext.Unlock() - s.state.Lock() - val, err := registrystate.GetViaView(s.state, s.devAccID, "network", "read-wifi", []string{"ssid"}) - s.state.Unlock() + val, err := tx.Get("wifi.ssid") c.Assert(err, IsNil) - c.Assert(val, DeepEquals, map[string]interface{}{"ssid": "other-ssid"}) + c.Assert(val, DeepEquals, "other-ssid") } func (s *registrySuite) TestRegistrySetManyViews(c *C) { - stdout, stderr, err := ctlcmd.Run(s.mockContext, []string{"set", "--view", ":write-wifi", "ssid=other-ssid", "password=other-secret"}, 0) - c.Assert(err, IsNil) - c.Check(stdout, IsNil) - c.Check(stderr, IsNil) - s.mockContext.Lock() - c.Assert(s.mockContext.Done(), IsNil) - s.mockContext.Unlock() - s.state.Lock() - val, err := registrystate.GetViaView(s.state, s.devAccID, "network", "read-wifi", []string{"ssid", "password"}) + tx, err := registrystate.NewTransaction(s.state, s.devAccID, "network") s.state.Unlock() c.Assert(err, IsNil) - c.Assert(val, DeepEquals, map[string]interface{}{ - "ssid": "other-ssid", - "password": "other-secret", + + restore := ctlcmd.MockRegistrystateGetTransaction(func(*registrystate.Context, *state.State, *registry.View) (*registrystate.Transaction, error) { + return tx, nil }) -} + defer restore() -func (s *registrySuite) TestRegistrySetHappensTransactionally(c *C) { - // sets values in a transaction - stdout, stderr, err := ctlcmd.Run(s.mockContext, []string{"set", "--view", ":write-wifi", "ssid=my-ssid"}, 0) + stdout, stderr, err := ctlcmd.Run(s.mockContext, []string{"set", "--view", ":write-wifi", "ssid=other-ssid", "password=other-secret"}, 0) c.Assert(err, IsNil) c.Check(stdout, IsNil) c.Check(stderr, IsNil) - s.state.Lock() - _, err = registrystate.GetViaView(s.state, s.devAccID, "network", "read-wifi", []string{"ssid"}) - s.state.Unlock() - c.Assert(err, ErrorMatches, ".*matching rules don't map to any values") - - // commit transaction - s.mockContext.Lock() - c.Assert(s.mockContext.Done(), IsNil) - s.mockContext.Unlock() + val, err := tx.Get("wifi.ssid") + c.Assert(err, IsNil) + c.Assert(val, Equals, "other-ssid") - s.state.Lock() - val, err := registrystate.GetViaView(s.state, s.devAccID, "network", "read-wifi", []string{"ssid"}) - s.state.Unlock() + val, err = tx.Get("wifi.psk") c.Assert(err, IsNil) - c.Assert(val, DeepEquals, map[string]interface{}{ - "ssid": "my-ssid", - }) + c.Assert(val, Equals, "other-secret") } func (s *registrySuite) TestRegistrySetInvalid(c *C) { @@ -491,12 +480,23 @@ func (s *registrySuite) TestRegistrySetInvalid(c *C) { } func (s *registrySuite) TestRegistrySetExclamationMark(c *C) { - stdout, stderr, err := ctlcmd.Run(s.mockContext, []string{"set", "--view", ":write-wifi", "ssid=other-ssid", "password=other-secret"}, 0) + s.state.Lock() + tx, err := registrystate.NewTransaction(s.state, s.devAccID, "network") + s.state.Unlock() c.Assert(err, IsNil) - c.Check(stdout, IsNil) - c.Check(stderr, IsNil) - stdout, stderr, err = ctlcmd.Run(s.mockContext, []string{"set", "--view", ":write-wifi", "password!"}, 0) + err = tx.Set("wifi.ssid", "foo") + c.Assert(err, IsNil) + + err = tx.Set("wifi.psk", "bar") + c.Assert(err, IsNil) + + restore := ctlcmd.MockRegistrystateGetTransaction(func(*registrystate.Context, *state.State, *registry.View) (*registrystate.Transaction, error) { + return tx, nil + }) + defer restore() + + stdout, stderr, err := ctlcmd.Run(s.mockContext, []string{"set", "--view", ":write-wifi", "password!"}, 0) c.Assert(err, IsNil) c.Check(stdout, IsNil) c.Check(stderr, IsNil) @@ -504,8 +504,46 @@ func (s *registrySuite) TestRegistrySetExclamationMark(c *C) { stdout, stderr, err = ctlcmd.Run(s.mockContext, []string{"get", "--view", ":read-wifi"}, 0) c.Assert(err, IsNil) c.Check(string(stdout), Equals, `{ - "ssid": "other-ssid" + "ssid": "foo" } `) c.Check(stderr, IsNil) + s.state.Lock() +} + +func (s *registrySuite) TestRegistryOnlyChangeViewCanSet(c *C) { + s.state.Lock() + defer s.state.Unlock() + task := s.state.NewTask("run-hook", "") + + setup := &hookstate.HookSetup{Snap: "test-snap", Hook: "save-view-plug"} + ctx, err := hookstate.NewContext(task, s.state, setup, s.mockHandler, "") + c.Assert(err, IsNil) + + tx, err := registrystate.NewTransaction(s.state, s.devAccID, "network") + c.Assert(err, IsNil) + + restore := ctlcmd.MockRegistrystateGetTransaction(func(*registrystate.Context, *state.State, *registry.View) (*registrystate.Transaction, error) { + return tx, nil + }) + defer restore() + + s.state.Unlock() + stdout, stderr, err := ctlcmd.Run(ctx, []string{"set", "--view", ":write-wifi", "password=thing"}, 0) + s.state.Lock() + c.Assert(err, ErrorMatches, `cannot modify registry in "save-view-plug" hook`) + c.Check(stdout, IsNil) + c.Check(stderr, IsNil) + + setup.Hook = "change-view-plug" + ctx, err = hookstate.NewContext(task, s.state, setup, s.mockHandler, "") + c.Assert(err, IsNil) + + s.state.Unlock() + stdout, stderr, err = ctlcmd.Run(ctx, []string{"set", "--view", ":write-wifi", "password=thing"}, 0) + s.state.Lock() + c.Assert(err, IsNil) + c.Check(stdout, IsNil) + c.Check(stderr, IsNil) + } diff --git a/overlord/hookstate/ctlcmd/unset_test.go b/overlord/hookstate/ctlcmd/unset_test.go index 01e999bf2f7..e3932686fc9 100644 --- a/overlord/hookstate/ctlcmd/unset_test.go +++ b/overlord/hookstate/ctlcmd/unset_test.go @@ -30,6 +30,7 @@ import ( "github.com/snapcore/snapd/overlord/hookstate/hooktest" "github.com/snapcore/snapd/overlord/registrystate" "github.com/snapcore/snapd/overlord/state" + "github.com/snapcore/snapd/registry" "github.com/snapcore/snapd/snap" ) @@ -164,52 +165,30 @@ func (s *unsetSuite) TestCommandWithoutContext(c *C) { func (s *registrySuite) TestRegistryUnsetManyViews(c *C) { s.state.Lock() - err := registrystate.SetViaView(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{"ssid": "my-ssid", "password": "my-secret"}) + tx, err := registrystate.NewTransaction(s.state, s.devAccID, "network") s.state.Unlock() c.Assert(err, IsNil) - stdout, stderr, err := ctlcmd.Run(s.mockContext, []string{"unset", "--view", ":write-wifi", "ssid", "password"}, 0) + err = tx.Set("wifi.ssid", "foo") c.Assert(err, IsNil) - c.Check(stdout, IsNil) - c.Check(stderr, IsNil) - s.mockContext.Lock() - c.Assert(s.mockContext.Done(), IsNil) - s.mockContext.Unlock() - s.state.Lock() - _, err = registrystate.GetViaView(s.state, s.devAccID, "network", "write-wifi", []string{"ssid", "password"}) - s.state.Unlock() - c.Assert(err, ErrorMatches, `cannot get "ssid", "password" .*: matching rules don't map to any values`) -} - -func (s *registrySuite) TestRegistryUnsetHappensTransactionally(c *C) { - s.state.Lock() - err := registrystate.SetViaView(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{"ssid": "my-ssid"}) - s.state.Unlock() + err = tx.Set("wifi.psk", "bar") c.Assert(err, IsNil) - stdout, stderr, err := ctlcmd.Run(s.mockContext, []string{"unset", "--view", ":write-wifi", "ssid"}, 0) + ctlcmd.MockRegistrystateGetTransaction(func(*registrystate.Context, *state.State, *registry.View) (*registrystate.Transaction, error) { + return tx, nil + }) + + stdout, stderr, err := ctlcmd.Run(s.mockContext, []string{"unset", "--view", ":write-wifi", "ssid", "password"}, 0) c.Assert(err, IsNil) c.Check(stdout, IsNil) c.Check(stderr, IsNil) - s.state.Lock() - val, err := registrystate.GetViaView(s.state, s.devAccID, "network", "read-wifi", []string{"ssid"}) - s.state.Unlock() - c.Assert(err, IsNil) - c.Assert(val, DeepEquals, map[string]interface{}{ - "ssid": "my-ssid", - }) - - // commit transaction - s.mockContext.Lock() - c.Assert(s.mockContext.Done(), IsNil) - s.mockContext.Unlock() + _, err = tx.Get("wifi.ssid") + c.Assert(err, ErrorMatches, `no value was found under path "wifi.ssid"`) - s.state.Lock() - _, err = registrystate.GetViaView(s.state, s.devAccID, "network", "read-wifi", []string{"ssid"}) - s.state.Unlock() - c.Assert(err, ErrorMatches, `cannot get "ssid" .*: matching rules don't map to any values`) + _, err = tx.Get("wifi.psk") + c.Assert(err, ErrorMatches, `no value was found under path "wifi.psk"`) } func (s *registrySuite) TestRegistryUnsetInvalid(c *C) { diff --git a/overlord/registrystate/export_test.go b/overlord/registrystate/export_test.go index 5fdb4263127..bd4d486e9e1 100644 --- a/overlord/registrystate/export_test.go +++ b/overlord/registrystate/export_test.go @@ -34,6 +34,11 @@ var ( UnsetOngoingTransaction = unsetOngoingTransaction ) +const ( + CommitEdge = commitEdge + LastEdge = lastEdge +) + func ChangeViewHandlerGenerator(ctx *hookstate.Context) hookstate.Handler { return &changeViewHandler{ctx: ctx} } @@ -57,3 +62,11 @@ func MockWriteDatabag(f func(st *state.State, databag registry.JSONDataBag, acco writeDatabag = old } } + +func MockEnsureNow(f func(*state.State)) func() { + old := ensureNow + ensureNow = f + return func() { + ensureNow = old + } +} diff --git a/overlord/registrystate/registrymgr_test.go b/overlord/registrystate/registrymgr_test.go index e1cbef1cd44..bb60cee60da 100644 --- a/overlord/registrystate/registrymgr_test.go +++ b/overlord/registrystate/registrymgr_test.go @@ -65,7 +65,7 @@ slots: registry-slot: interface: registry ` - info := mockInstalledSnap(c, s.state, coreYaml, "") + info := mockInstalledSnap(c, s.state, coreYaml, nil) coreSet, err := interfaces.NewSnapAppSet(info, nil) c.Assert(err, IsNil) @@ -82,7 +82,7 @@ plugs: view: network/setup-wifi ` - info = mockInstalledSnap(c, s.state, snapYaml, "") + info = mockInstalledSnap(c, s.state, snapYaml, nil) appSet, err := interfaces.NewSnapAppSet(info, nil) c.Assert(err, IsNil) err = s.repo.AddAppSet(appSet) diff --git a/overlord/registrystate/registrystate.go b/overlord/registrystate/registrystate.go index bb0ee4c2bf8..7b1d1b8af2b 100644 --- a/overlord/registrystate/registrystate.go +++ b/overlord/registrystate/registrystate.go @@ -22,6 +22,7 @@ import ( "errors" "fmt" "sort" + "strings" "github.com/snapcore/snapd/overlord/assertstate" "github.com/snapcore/snapd/overlord/hookstate" @@ -197,45 +198,120 @@ var writeDatabag = func(st *state.State, databag registry.JSONDataBag, account, return nil } -type cachedRegistryTx struct { - account string - registry string -} +// GetTransaction retrieves or creates a transaction to access the view's +// registry. The state must be locked by the caller. +func GetTransaction(ctx *Context, st *state.State, view *registry.View) (*Transaction, error) { + account, registryName := view.Registry().Account, view.Registry().Name + + // check if we're already running in the context of a committing transaction + hookCtx := ctx.hookCtx + if hookCtx != nil && !hookCtx.IsEphemeral() && IsRegistryHook(hookCtx) { + // running in the context of a transaction, so if the referenced registry + // doesn't match that tx, we only allow the caller to read the other registry + t, _ := hookCtx.Task() + var err error + tx, commitTask, err := GetStoredTransaction(t) + if err != nil { + return nil, fmt.Errorf("cannot access registry view %s/%s/%s: cannot get transaction: %v", account, registryName, view.Name, err) + } + + if tx.RegistryAccount != account || tx.RegistryName != registryName { + // TODO: can we allow accessing a different registry just for reading? we'll + // need to create a change and run hooks so will require accounting in state + return nil, fmt.Errorf("cannot access registry %s/%s: ongoing transaction for %s/%s", account, registryName, tx.RegistryAccount, tx.RegistryName) + } + + ctx.OnDone(func() error { + setTransaction(commitTask, tx) + return nil + }) -// RegistryTransaction returns the registry.Transaction cached in the context -// or creates one and caches it, if none existed. The context must be locked by -// the caller. -func RegistryTransaction(ctx *hookstate.Context, reg *registry.Registry) (*Transaction, error) { - key := cachedRegistryTx{ - account: reg.Account, - registry: reg.Name, - } - tx, ok := ctx.Cached(key).(*Transaction) - if ok { return tx, nil } + // TODO: + // * add concurrency checks + // * distinguish from reads/write (reads will require different hooks) - tx, err := NewTransaction(ctx.State(), reg.Account, reg.Name) + // not running in an existing registry hook context, so create a transaction + // and a change to verify its changes and commit + tx, err := NewTransaction(st, account, registryName) if err != nil { - return nil, err + return nil, fmt.Errorf("cannot modify registry view %s/%s/%s: cannot create transaction: %v", account, registryName, view.Name, err) } ctx.OnDone(func() error { - return tx.Commit(ctx.State(), reg.Schema) + var chg *state.Change + if hookCtx == nil || hookCtx.IsEphemeral() { + chg = st.NewChange("modify-registry", fmt.Sprintf("Modify registry \"%s/%s\"", account, registryName)) + } else { + // we're running in the context of a non-registry hook, add the tasks to that change + task, _ := hookCtx.Task() + chg = task.Change() + } + + var callingSnap string + if hookCtx != nil { + callingSnap = hookCtx.InstanceName() + } + + ts, err := createChangeRegistryTasks(st, chg, tx, view, callingSnap) + if err != nil { + return err + } + + commitTask, err := ts.Edge(commitEdge) + if err != nil { + return err + } + + err = setOngoingTransaction(st, account, registryName, commitTask.ID()) + if err != nil { + return err + } + + lastTask, err := ts.Edge(lastEdge) + if err != nil { + return err + } + + // have a buffer size >1, in the unlikely case that there are more than 1 + // status changes for before the reader unblocks and removes the handler + taskReady := make(chan struct{}, 5) + id := st.AddTaskStatusChangedHandler(func(t *state.Task, old, new state.Status) { + if t.ID() == lastTask.ID() && !old.Ready() && new.Ready() { + taskReady <- struct{}{} + return + } + }) + + ensureNow(st) + st.Unlock() + <-taskReady + st.Lock() + st.RemoveTaskStatusChangedHandler(id) + return nil }) - ctx.Cache(key, tx) return tx, nil } -func createChangeRegistryTasks(st *state.State, chg *state.Change, tx *Transaction, view *registry.View, callingSnap string) error { +var ensureNow = func(st *state.State) { + st.EnsureBefore(0) +} + +const ( + commitEdge = state.TaskSetEdge("commit-edge") + lastEdge = state.TaskSetEdge("last-edge") +) + +func createChangeRegistryTasks(st *state.State, chg *state.Change, tx *Transaction, view *registry.View, callingSnap string) (*state.TaskSet, error) { custodianPlugs, err := getCustodianPlugsForView(st, view) if err != nil { - return err + return nil, err } if len(custodianPlugs) == 0 { - return fmt.Errorf("cannot commit changes to registry %s/%s: no custodian snap installed", view.Registry().Account, view.Registry().Name) + return nil, fmt.Errorf("cannot commit changes to registry %s/%s: no custodian snap installed", view.Registry().Account, view.Registry().Name) } custodianNames := make([]string, 0, len(custodianPlugs)) @@ -247,13 +323,14 @@ func createChangeRegistryTasks(st *state.State, chg *state.Change, tx *Transacti // and potentially for the snaps themselves) sort.Strings(custodianNames) - var tasks []*state.Task + ts := state.NewTaskSet() linkTask := func(t *state.Task) { + tasks := ts.Tasks() if len(tasks) > 0 { t.WaitFor(tasks[len(tasks)-1]) } - tasks = append(tasks, t) chg.AddTask(t) + ts.AddTask(t) } // if the transaction errors, clear the tx from the state @@ -294,7 +371,7 @@ func createChangeRegistryTasks(st *state.State, chg *state.Change, tx *Transacti paths := tx.AlteredPaths() affectedPlugs, err := getPlugsAffectedByPaths(st, view.Registry(), paths) if err != nil { - return err + return nil, err } viewChangedSnaps := make([]string, 0, len(affectedPlugs)) @@ -321,17 +398,19 @@ func createChangeRegistryTasks(st *state.State, chg *state.Change, tx *Transacti commitTask := st.NewTask("commit-registry-tx", fmt.Sprintf("Commit changes to registry \"%s/%s\"", view.Registry().Account, view.Registry().Name)) commitTask.Set("registry-transaction", tx) // link all previous tasks to the commit task that carries the transaction - for _, t := range tasks { + for _, t := range ts.Tasks() { t.Set("commit-task", commitTask.ID()) } linkTask(commitTask) + ts.MarkEdge(commitTask, commitEdge) // clear the ongoing tx from the state and unblock other writers waiting for it clearTxTask := st.NewTask("clear-registry-tx", "Clears the ongoing registry transaction from state") linkTask(clearTxTask) clearTxTask.Set("commit-task", commitTask.ID()) + ts.MarkEdge(clearTxTask, lastEdge) - return nil + return ts, nil } func getCustodianPlugsForView(st *state.State, view *registry.View) (map[string]*snap.PlugInfo, error) { @@ -442,3 +521,45 @@ func GetStoredTransaction(t *state.Task) (*Transaction, *state.Task, error) { func setTransaction(t *state.Task, tx *Transaction) { t.Set("registry-transaction", tx) } + +func IsRegistryHook(ctx *hookstate.Context) bool { + return !ctx.IsEphemeral() && + (strings.HasPrefix(ctx.HookName(), "change-view-") || + strings.HasPrefix(ctx.HookName(), "save-view-") || + strings.HasSuffix(ctx.HookName(), "-view-changed")) +} + +type Context struct { + hookCtx *hookstate.Context + + onDone []func() error +} + +func NewContext(ctx *hookstate.Context) *Context { + regCtx := &Context{ + hookCtx: ctx, + } + + if ctx != nil { + ctx.OnDone(func() error { + return regCtx.Done() + }) + } + + return regCtx +} + +func (c *Context) OnDone(f func() error) { + c.onDone = append(c.onDone, f) +} + +func (c *Context) Done() error { + var firstErr error + for _, f := range c.onDone { + if err := f(); err != nil && firstErr == nil { + firstErr = err + } + } + + return firstErr +} diff --git a/overlord/registrystate/registrystate_test.go b/overlord/registrystate/registrystate_test.go index c56a68f662e..50d9276ddcb 100644 --- a/overlord/registrystate/registrystate_test.go +++ b/overlord/registrystate/registrystate_test.go @@ -20,9 +20,13 @@ package registrystate_test import ( "fmt" + "os" + "path/filepath" "testing" + "time" . "gopkg.in/check.v1" + "gopkg.in/tomb.v2" "github.com/snapcore/snapd/asserts" "github.com/snapcore/snapd/asserts/assertstest" @@ -33,6 +37,7 @@ import ( "github.com/snapcore/snapd/overlord/assertstate" "github.com/snapcore/snapd/overlord/assertstate/assertstatetest" "github.com/snapcore/snapd/overlord/hookstate" + "github.com/snapcore/snapd/overlord/hookstate/ctlcmd" "github.com/snapcore/snapd/overlord/hookstate/hooktest" "github.com/snapcore/snapd/overlord/ifacestate/ifacerepo" "github.com/snapcore/snapd/overlord/registrystate" @@ -321,79 +326,7 @@ func (s *registryTestSuite) TestRegistrystateGetEntireView(c *C) { }) } -func (s *registryTestSuite) TestRegistryTransaction(c *C) { - mkRegistry := func(account, name string) *registry.Registry { - reg, err := registry.New(account, name, map[string]interface{}{ - "bar": map[string]interface{}{ - "rules": []interface{}{ - map[string]interface{}{"request": "foo", "storage": "foo"}, - }, - }, - }, registry.NewJSONSchema()) - c.Assert(err, IsNil) - return reg - } - - s.state.Lock() - task := s.state.NewTask("test-task", "my test task") - setup := &hookstate.HookSetup{Snap: "test-snap", Revision: snap.R(1), Hook: "test-hook"} - s.state.Unlock() - mockHandler := hooktest.NewMockHandler() - - type testcase struct { - acc1, acc2 string - reg1, reg2 string - equals bool - } - - tcs := []testcase{ - { - // same transaction - acc1: "acc-1", reg1: "reg-1", - acc2: "acc-1", reg2: "reg-1", - equals: true, - }, - { - // different registry name, different transaction - acc1: "acc-1", reg1: "reg-1", - acc2: "acc-1", reg2: "reg-2", - }, - { - // different account, different transaction - acc1: "acc-1", reg1: "reg-1", - acc2: "acc-2", reg2: "reg-1", - }, - { - // both different, different transaction - acc1: "acc-1", reg1: "reg-1", - acc2: "acc-2", reg2: "reg-2", - }, - } - - for _, tc := range tcs { - ctx, err := hookstate.NewContext(task, task.State(), setup, mockHandler, "") - c.Assert(err, IsNil) - ctx.Lock() - - reg1 := mkRegistry(tc.acc1, tc.reg1) - reg2 := mkRegistry(tc.acc2, tc.reg2) - - tx1, err := registrystate.RegistryTransaction(ctx, reg1) - c.Assert(err, IsNil) - - tx2, err := registrystate.RegistryTransaction(ctx, reg2) - c.Assert(err, IsNil) - - if tc.equals { - c.Assert(tx1, Equals, tx2) - } else { - c.Assert(tx1, Not(Equals), tx2) - } - ctx.Unlock() - } -} - -func mockInstalledSnap(c *C, st *state.State, snapYaml, cohortKey string) *snap.Info { +func mockInstalledSnap(c *C, st *state.State, snapYaml string, hooks []string) *snap.Info { info := snaptest.MockSnapCurrent(c, snapYaml, &snap.SideInfo{Revision: snap.R(1)}) snapstate.Set(st, info.InstanceName(), &snapstate.SnapState{ Active: true, @@ -406,8 +339,14 @@ func mockInstalledSnap(c *C, st *state.State, snapYaml, cohortKey string) *snap. }), Current: info.Revision, TrackingChannel: "stable", - CohortKey: cohortKey, }) + + for _, hook := range hooks { + c.Assert(os.MkdirAll(info.HooksDir(), 0775), IsNil) + err := os.WriteFile(filepath.Join(info.HooksDir(), hook), nil, 0755) + c.Assert(err, IsNil) + } + return info } @@ -470,7 +409,7 @@ plugs: account: %[1]s view: reg/view-4 `, s.devAccID) - info := mockInstalledSnap(c, s.state, snapYaml, "") + info := mockInstalledSnap(c, s.state, snapYaml, nil) appSet, err := interfaces.NewSnapAppSet(info, nil) c.Assert(err, IsNil) @@ -484,7 +423,7 @@ slots: registry-slot: interface: registry ` - info = mockInstalledSnap(c, s.state, coreYaml, "") + info = mockInstalledSnap(c, s.state, coreYaml, nil) coreSet, err := interfaces.NewSnapAppSet(info, nil) c.Assert(err, IsNil) @@ -529,8 +468,17 @@ func (s *registryTestSuite) TestRegistryTasksUserSetWithCustodianInstalled(c *C) chg := s.state.NewChange("modify-registry", "") // a user (not a snap) changes a registry - err = registrystate.CreateChangeRegistryTasks(s.state, chg, tx, view, "") + ts, err := registrystate.CreateChangeRegistryTasks(s.state, chg, tx, view, "") + c.Assert(err, IsNil) + + // there are two edges in the taskset + commitTask, err := ts.Edge(registrystate.CommitEdge) + c.Assert(err, IsNil) + c.Assert(commitTask.Kind(), Equals, "commit-registry-tx") + + cleanupTask, err := ts.Edge(registrystate.LastEdge) c.Assert(err, IsNil) + c.Assert(cleanupTask.Kind(), Equals, "clear-registry-tx") // the custodian snap's hooks are run tasks := []string{"clear-registry-tx-on-error", "run-hook", "run-hook", "run-hook", "commit-registry-tx", "clear-registry-tx"} @@ -575,7 +523,7 @@ func (s *registryTestSuite) TestRegistryTasksCustodianSnapSet(c *C) { chg := s.state.NewChange("modify-registry", "") // a user (not a snap) changes a registry - err = registrystate.CreateChangeRegistryTasks(s.state, chg, tx, view, "custodian-snap") + _, err = registrystate.CreateChangeRegistryTasks(s.state, chg, tx, view, "custodian-snap") c.Assert(err, IsNil) // the custodian snap's hooks are run @@ -615,7 +563,7 @@ func (s *registryTestSuite) TestRegistryTasksObserverSnapSetWithCustodianInstall chg := s.state.NewChange("modify-registry", "") // a non-custodian snap modifies a registry - err = registrystate.CreateChangeRegistryTasks(s.state, chg, tx, view, "test-snap-1") + _, err = registrystate.CreateChangeRegistryTasks(s.state, chg, tx, view, "test-snap-1") c.Assert(err, IsNil) // we trigger hooks for the custodian snap and for the -view-changed for the @@ -681,7 +629,7 @@ func (s *registryTestSuite) testRegistryTasksNoCustodian(c *C) { chg := s.state.NewChange("modify-registry", "") // a non-custodian snap modifies a registry - err = registrystate.CreateChangeRegistryTasks(s.state, chg, tx, view, "test-snap-1") + _, err = registrystate.CreateChangeRegistryTasks(s.state, chg, tx, view, "test-snap-1") c.Assert(err, ErrorMatches, fmt.Sprintf("cannot commit changes to registry %s/network: no custodian snap installed", s.devAccID)) } @@ -701,7 +649,7 @@ slots: registry-slot: interface: registry ` - info := mockInstalledSnap(c, s.state, coreYaml, "") + info := mockInstalledSnap(c, s.state, coreYaml, nil) coreSet, err := interfaces.NewSnapAppSet(info, nil) c.Assert(err, IsNil) @@ -709,7 +657,7 @@ slots: err = s.repo.AddAppSet(coreSet) c.Assert(err, IsNil) - mockSnap := func(snapName string, isCustodian bool) { + mockSnap := func(snapName string, isCustodian bool, hooks []string) { snapYaml := fmt.Sprintf(`name: %s version: 1 type: app @@ -725,12 +673,10 @@ plugs: ` role: custodian` } - info := mockInstalledSnap(c, s.state, snapYaml, "") - - // by default, mock all the hooks a custodians can have - for _, hookName := range []string{"change-view-setup", "save-view-setup", "setup-view-changed"} { - info.Hooks[hookName] = &snap.HookInfo{ - Name: hookName, + info := mockInstalledSnap(c, s.state, snapYaml, hooks) + for _, hook := range hooks { + info.Hooks[hook] = &snap.HookInfo{ + Name: hook, Snap: info, } } @@ -749,15 +695,17 @@ plugs: } // mock custodians + hooks := []string{"change-view-setup", "save-view-setup", "setup-view-changed"} for _, snap := range custodians { isCustodian := true - mockSnap(snap, isCustodian) + mockSnap(snap, isCustodian, hooks) } // mock non-custodians + hooks = []string{"change-view-setup", "save-view-setup", "setup-view-changed", "install"} for _, snap := range nonCustodians { isCustodian := false - mockSnap(snap, isCustodian) + mockSnap(snap, isCustodian, hooks) } } @@ -844,3 +792,320 @@ func (s *registryTestSuite) TestGetStoredTransaction(c *C) { c.Assert(carryingTask, Equals, commitTask) } } + +func (s *registryTestSuite) checkOngoingRegistryTransaction(c *C, account, registryName string) { + var commitTasks map[string]string + err := s.state.Get("registry-commit-tasks", &commitTasks) + c.Assert(err, IsNil) + + registryRef := account + "/" + registryName + taskID, ok := commitTasks[registryRef] + c.Assert(ok, Equals, true) + commitTask := s.state.Task(taskID) + c.Assert(commitTask.Kind(), Equals, "commit-registry-tx") + c.Assert(commitTask.Status(), Equals, state.DoStatus) +} + +func (s *registryTestSuite) TestGetTransactionFromUserCreatesNewChange(c *C) { + hooks, restore := s.mockRegistryHooks(c) + defer restore() + + restore = registrystate.MockEnsureNow(func(*state.State) { + s.checkOngoingRegistryTransaction(c, s.devAccID, "network") + + go s.o.Settle(testutil.HostScaledTimeout(5 * time.Second)) + }) + defer restore() + + s.state.Lock() + defer s.state.Unlock() + + // only one custodian snap is installed + s.setupRegistryModificationScenario(c, []string{"custodian-snap"}, nil) + + view := s.registry.View("setup-wifi") + + ctx := registrystate.NewContext(nil) + tx, err := registrystate.GetTransaction(ctx, s.state, view) + c.Assert(err, IsNil) + c.Assert(tx, NotNil) + + err = tx.Set("wifi.ssid", "foo") + c.Assert(err, IsNil) + + // mock the daemon calling Done() in api_registry + ctx.Done() + + c.Assert(s.state.Changes(), HasLen, 1) + chg := s.state.Changes()[0] + c.Assert(chg.Kind(), Equals, "modify-registry") + + s.checkModifyRegistryChange(c, chg, hooks) +} + +func (s *registryTestSuite) TestGetTransactionFromSnapCreatesNewChange(c *C) { + hooks, restore := s.mockRegistryHooks(c) + defer restore() + + restore = registrystate.MockEnsureNow(func(*state.State) { + s.checkOngoingRegistryTransaction(c, s.devAccID, "network") + + go s.o.Settle(testutil.HostScaledTimeout(5 * time.Second)) + }) + defer restore() + + s.state.Lock() + defer s.state.Unlock() + + // only one custodian snap is installed + s.setupRegistryModificationScenario(c, []string{"custodian-snap"}, []string{"test-snap"}) + + ctx, err := hookstate.NewContext(nil, s.state, &hookstate.HookSetup{Snap: "test-snap"}, nil, "") + c.Assert(err, IsNil) + + s.state.Unlock() + stdout, stderr, err := ctlcmd.Run(ctx, []string{"set", "--view", ":setup", "ssid=foo"}, 0) + c.Assert(err, IsNil) + c.Check(stdout, IsNil) + c.Check(stderr, IsNil) + + // the daemon calls Done() in api_snapctl + ctx.Lock() + ctx.Done() + ctx.Unlock() + + s.state.Lock() + c.Assert(s.state.Changes(), HasLen, 1) + chg := s.state.Changes()[0] + c.Assert(chg.Kind(), Equals, "modify-registry") + + s.checkModifyRegistryChange(c, chg, hooks) +} + +func (s *registryTestSuite) TestGetTransactionFromNonRegistryHookAddsRegistryTx(c *C) { + var hooks []string + restore := hookstate.MockRunHook(func(ctx *hookstate.Context, _ *tomb.Tomb) ([]byte, error) { + t, _ := ctx.Task() + + ctx.State().Lock() + var hooksup *hookstate.HookSetup + err := t.Get("hook-setup", &hooksup) + ctx.State().Unlock() + if err != nil { + return nil, err + } + + if hooksup.Hook == "install" { + _, _, err := ctlcmd.Run(ctx, []string{"set", "--view", ":setup", "ssid=foo"}, 0) + c.Assert(err, IsNil) + return nil, nil + } + + hooks = append(hooks, hooksup.Hook) + return nil, nil + }) + defer restore() + + restore = registrystate.MockEnsureNow(func(st *state.State) { + // we actually want to call ensure here (since we use Loop) but check the + // transaction was added to the state as usual + s.checkOngoingRegistryTransaction(c, s.devAccID, "network") + st.EnsureBefore(0) + }) + defer restore() + + s.state.Lock() + defer s.state.Unlock() + // only one custodian snap is installed + s.setupRegistryModificationScenario(c, []string{"custodian-snap"}, []string{"test-snap"}) + + hookTask := s.state.NewTask("run-hook", "") + chg := s.state.NewChange("install", "") + chg.AddTask(hookTask) + + hooksup := &hookstate.HookSetup{ + Snap: "test-snap", + Hook: "install", + } + hookTask.Set("hook-setup", hooksup) + s.state.Unlock() + + c.Assert(s.o.StartUp(), IsNil) + s.state.EnsureBefore(0) + s.o.Loop() + defer s.o.Stop() + + select { + case <-chg.Ready(): + case <-time.After(5 * time.Second): + c.Fatalf("test timed out") + } + + s.state.Lock() + s.checkModifyRegistryChange(c, chg, &hooks) +} + +func (s *registryTestSuite) mockRegistryHooks(c *C) (*[]string, func()) { + var hooks []string + restore := hookstate.MockRunHook(func(ctx *hookstate.Context, _ *tomb.Tomb) ([]byte, error) { + t, _ := ctx.Task() + ctx.State().Lock() + defer ctx.State().Unlock() + + var hooksup *hookstate.HookSetup + err := t.Get("hook-setup", &hooksup) + if err != nil { + return nil, err + } + + hooks = append(hooks, hooksup.Hook) + return nil, nil + }) + + return &hooks, restore +} + +func (s *registryTestSuite) checkModifyRegistryChange(c *C, chg *state.Change, hooks *[]string) { + c.Assert(chg.Status(), Equals, state.DoneStatus) + c.Assert(*hooks, DeepEquals, []string{"change-view-setup", "save-view-setup", "setup-view-changed"}) + + commitTask := findTask(chg, "commit-registry-tx") + tx, _, err := registrystate.GetStoredTransaction(commitTask) + c.Assert(err, IsNil) + + // the state was cleared + var txCommits map[string]string + err = s.state.Get("registry-tx-commits", &txCommits) + c.Assert(err, testutil.ErrorIs, &state.NoStateError{}) + + err = tx.Clear(s.state) + c.Assert(err, IsNil) + + // was committed (otherwise would've been removed by Clear) + val, err := tx.Get("wifi.ssid") + c.Assert(err, IsNil) + c.Assert(val, Equals, "foo") +} + +func (s *registryTestSuite) TestGetTransactionDifferentFromOngoingOnlyForRead(c *C) { +} + +func (s *registryTestSuite) TestGetTransactionFromChangeViewHook(c *C) { + ctx := s.testGetReadableOngoingTransaction(c, "change-view-setup") + + // change-view hooks can also write to the transaction + stdout, stderr, err := ctlcmd.Run(ctx, []string{"set", "--view", ":setup", "ssid=bar"}, 0) + c.Assert(err, IsNil) + // accessed an ongoing transaction + c.Assert(stdout, IsNil) + c.Assert(stderr, IsNil) + + // this save the changes that the hook performs + ctx.Lock() + ctx.Done() + ctx.Unlock() + + s.state.Lock() + defer s.state.Unlock() + t, _ := ctx.Task() + tx, _, err := registrystate.GetStoredTransaction(t) + c.Assert(err, IsNil) + + val, err := tx.Get("wifi.ssid") + c.Assert(err, IsNil) + c.Assert(val, Equals, "bar") +} + +func (s *registryTestSuite) TestGetTransactionFromSaveViewHook(c *C) { + ctx := s.testGetReadableOngoingTransaction(c, "save-view-setup") + + // non change-view hooks cannot modify the transaction + stdout, stderr, err := ctlcmd.Run(ctx, []string{"set", "--view", ":setup", "ssid=bar"}, 0) + c.Assert(err, ErrorMatches, `cannot modify registry in "save-view-setup" hook`) + c.Assert(stdout, IsNil) + c.Assert(stderr, IsNil) +} + +func (s *registryTestSuite) TestGetTransactionFromViewChangedHook(c *C) { + ctx := s.testGetReadableOngoingTransaction(c, "setup-view-changed") + + // non change-view hooks cannot modify the transaction + stdout, stderr, err := ctlcmd.Run(ctx, []string{"set", "--view", ":setup", "ssid=bar"}, 0) + c.Assert(err, ErrorMatches, `cannot modify registry in "setup-view-changed" hook`) + c.Assert(stdout, IsNil) + c.Assert(stderr, IsNil) +} + +func (s *registryTestSuite) testGetReadableOngoingTransaction(c *C, hook string) *hookstate.Context { + s.state.Lock() + defer s.state.Unlock() + + s.setupRegistryModificationScenario(c, []string{"custodian-snap"}, []string{"test-snap"}) + + originalTx, err := registrystate.NewTransaction(s.state, s.devAccID, "network") + c.Assert(err, IsNil) + + err = originalTx.Set("wifi.ssid", "foo") + c.Assert(err, IsNil) + + chg := s.state.NewChange("test", "") + commitTask := s.state.NewTask("commit-registry-tx", "") + commitTask.Set("registry-transaction", originalTx) + chg.AddTask(commitTask) + + hookTask := s.state.NewTask("run-hook", "") + chg.AddTask(hookTask) + setup := &hookstate.HookSetup{Snap: "test-snap", Revision: snap.R(1), Hook: hook} + mockHandler := hooktest.NewMockHandler() + hookTask.Set("commit-task", commitTask.ID()) + + ctx, err := hookstate.NewContext(hookTask, s.state, setup, mockHandler, "") + c.Assert(err, IsNil) + + s.state.Unlock() + stdout, stderr, err := ctlcmd.Run(ctx, []string{"get", "--view", ":setup", "ssid"}, 0) + s.state.Lock() + c.Assert(err, IsNil) + // accessed an ongoing transaction + c.Assert(string(stdout), Equals, "foo\n") + c.Assert(stderr, IsNil) + + return ctx +} + +func (s *registryTestSuite) TestGetDifferentTransactionThanOngoing(c *C) { + s.state.Lock() + + tx, err := registrystate.NewTransaction(s.state, s.devAccID, "network") + c.Assert(err, IsNil) + + chg := s.state.NewChange("some-change", "") + commitTask := s.state.NewTask("commit", "") + chg.AddTask(commitTask) + commitTask.Set("registry-transaction", tx) + + refTask := s.state.NewTask("change-view-setup", "") + chg.AddTask(refTask) + refTask.Set("commit-task", commitTask.ID()) + + // make some other registry to access concurrently + reg, err := registry.New("foo", "bar", map[string]interface{}{ + "foo": map[string]interface{}{ + "rules": []interface{}{ + map[string]interface{}{"request": "foo", "storage": "foo"}, + }}}, registry.NewJSONSchema()) + c.Assert(err, IsNil) + s.state.Unlock() + + mockHandler := hooktest.NewMockHandler() + setup := &hookstate.HookSetup{Snap: "test-snap", Hook: "change-view-setup"} + hookCtx, err := hookstate.NewContext(refTask, s.state, setup, mockHandler, "") + c.Assert(err, IsNil) + + hookCtx.Lock() + ctx := registrystate.NewContext(hookCtx) + tx, err = registrystate.GetTransaction(ctx, s.state, reg.View("foo")) + hookCtx.Unlock() + c.Assert(err, ErrorMatches, fmt.Sprintf(`cannot access registry foo/bar: ongoing transaction for %s/network`, s.devAccID)) + c.Assert(tx, IsNil) +}