diff --git a/client/client.go b/client/client.go index 5a95e541..d1b214ff 100644 --- a/client/client.go +++ b/client/client.go @@ -46,11 +46,11 @@ type Client interface { SetOption(Option) error Connected() bool DisconnectNotify() chan struct{} - Echo() error - Transact(...ovsdb.Operation) ([]ovsdb.OperationResult, error) - Monitor(...TableMonitor) (string, error) - MonitorAll() (string, error) - MonitorCancel(id string) error + Echo(context.Context) error + Transact(context.Context, ...ovsdb.Operation) ([]ovsdb.OperationResult, error) + Monitor(context.Context, ...TableMonitor) (string, error) + MonitorAll(context.Context) (string, error) + MonitorCancel(ctx context.Context, id string) error NewTableMonitor(m model.Model, fields ...interface{}) TableMonitor API } @@ -149,7 +149,7 @@ func (o *ovsdbClient) connect(ctx context.Context, reconnect bool) error { return err } - dbs, err := o.listDbs() + dbs, err := o.listDbs(ctx) if err != nil { o.rpcClient.Close() return err @@ -167,7 +167,7 @@ func (o *ovsdbClient) connect(ctx context.Context, reconnect bool) error { return fmt.Errorf("target database not found") } - schema, err := o.getSchema(o.dbModel.Name()) + schema, err := o.getSchema(ctx, o.dbModel.Name()) errors := o.dbModel.Validate(schema) if len(errors) > 0 { var combined []string @@ -205,7 +205,7 @@ func (o *ovsdbClient) connect(ctx context.Context, reconnect bool) error { o.monitorsMutex.Lock() defer o.monitorsMutex.Unlock() for id, request := range o.monitors { - err = o.monitor(id, reconnect, request...) + err = o.monitor(ctx, id, reconnect, request...) if err != nil { o.rpcClient.Close() return err @@ -312,10 +312,10 @@ func (o *ovsdbClient) update(args []json.RawMessage, reply *[]interface{}) error // getSchema returns the schema in use for the provided database name // RFC 7047 : get_schema // Should only be called when mutex is held -func (o *ovsdbClient) getSchema(dbName string) (*ovsdb.DatabaseSchema, error) { +func (o *ovsdbClient) getSchema(ctx context.Context, dbName string) (*ovsdb.DatabaseSchema, error) { args := ovsdb.NewGetSchemaArgs(dbName) var reply ovsdb.DatabaseSchema - err := o.rpcClient.Call("get_schema", args, &reply) + err := o.rpcClient.CallWithContext(ctx, "get_schema", args, &reply) if err != nil { if err == rpc2.ErrShutdown { return nil, ErrNotConnected @@ -328,9 +328,9 @@ func (o *ovsdbClient) getSchema(dbName string) (*ovsdb.DatabaseSchema, error) { // listDbs returns the list of databases on the server // RFC 7047 : list_dbs // Should only be called when mutex is held -func (o *ovsdbClient) listDbs() ([]string, error) { +func (o *ovsdbClient) listDbs(ctx context.Context) ([]string, error) { var dbs []string - err := o.rpcClient.Call("list_dbs", nil, &dbs) + err := o.rpcClient.CallWithContext(ctx, "list_dbs", nil, &dbs) if err != nil { if err == rpc2.ErrShutdown { return nil, ErrNotConnected @@ -342,7 +342,7 @@ func (o *ovsdbClient) listDbs() ([]string, error) { // Transact performs the provided Operations on the database // RFC 7047 : transact -func (o *ovsdbClient) Transact(operation ...ovsdb.Operation) ([]ovsdb.OperationResult, error) { +func (o *ovsdbClient) Transact(ctx context.Context, operation ...ovsdb.Operation) ([]ovsdb.OperationResult, error) { var reply []ovsdb.OperationResult if ok := o.Schema().ValidateOperations(operation...); !ok { return nil, fmt.Errorf("validation failed for the operation") @@ -354,7 +354,7 @@ func (o *ovsdbClient) Transact(operation ...ovsdb.Operation) ([]ovsdb.OperationR o.rpcMutex.Unlock() return nil, ErrNotConnected } - err := o.rpcClient.Call("transact", args, &reply) + err := o.rpcClient.CallWithContext(ctx, "transact", args, &reply) o.rpcMutex.Unlock() if err != nil { if err == rpc2.ErrShutdown { @@ -366,17 +366,17 @@ func (o *ovsdbClient) Transact(operation ...ovsdb.Operation) ([]ovsdb.OperationR } // MonitorAll is a convenience method to monitor every table/column -func (o *ovsdbClient) MonitorAll() (string, error) { +func (o *ovsdbClient) MonitorAll(ctx context.Context) (string, error) { var options []TableMonitor for name := range o.dbModel.Types() { options = append(options, TableMonitor{Table: name}) } - return o.Monitor(options...) + return o.Monitor(ctx, options...) } // MonitorCancel will request cancel a previously issued monitor request // RFC 7047 : monitor_cancel -func (o *ovsdbClient) MonitorCancel(id string) error { +func (o *ovsdbClient) MonitorCancel(ctx context.Context, id string) error { var reply ovsdb.OperationResult args := ovsdb.NewMonitorCancelArgs(id) o.rpcMutex.Lock() @@ -384,7 +384,7 @@ func (o *ovsdbClient) MonitorCancel(id string) error { if o.rpcClient == nil { return ErrNotConnected } - err := o.rpcClient.Call("monitor_cancel", args, &reply) + err := o.rpcClient.CallWithContext(ctx, "monitor_cancel", args, &reply) if err != nil { if err == rpc2.ErrShutdown { return ErrNotConnected @@ -428,12 +428,12 @@ func (o *ovsdbClient) NewTableMonitor(m model.Model, fields ...interface{}) Tabl // and populate the cache with them. Subsequent updates will be processed // by the Update Notifications // RFC 7047 : monitor -func (o *ovsdbClient) Monitor(options ...TableMonitor) (string, error) { +func (o *ovsdbClient) Monitor(ctx context.Context, options ...TableMonitor) (string, error) { id := uuid.NewString() - return id, o.monitor(id, false, options...) + return id, o.monitor(ctx, id, false, options...) } -func (o *ovsdbClient) monitor(id string, reconnect bool, options ...TableMonitor) error { +func (o *ovsdbClient) monitor(ctx context.Context, id string, reconnect bool, options ...TableMonitor) error { if len(options) == 0 { return fmt.Errorf("no monitor options provided") } @@ -463,7 +463,7 @@ func (o *ovsdbClient) monitor(id string, reconnect bool, options ...TableMonitor if o.rpcClient == nil { return ErrNotConnected } - err := o.rpcClient.Call("monitor", args, &reply) + err := o.rpcClient.CallWithContext(ctx, "monitor", args, &reply) if err != nil { if err == rpc2.ErrShutdown { return ErrNotConnected @@ -480,7 +480,7 @@ func (o *ovsdbClient) monitor(id string, reconnect bool, options ...TableMonitor } // Echo tests the liveness of the OVSDB connetion -func (o *ovsdbClient) Echo() error { +func (o *ovsdbClient) Echo(ctx context.Context) error { args := ovsdb.NewEchoArgs() var reply []interface{} o.rpcMutex.RLock() @@ -488,7 +488,7 @@ func (o *ovsdbClient) Echo() error { if o.rpcClient == nil { return ErrNotConnected } - err := o.rpcClient.Call("echo", args, &reply) + err := o.rpcClient.CallWithContext(ctx, "echo", args, &reply) if err != nil { if err == rpc2.ErrShutdown { return ErrNotConnected diff --git a/client/client_test.go b/client/client_test.go index ebd93fea..fde431aa 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1,6 +1,7 @@ package client import ( + "context" "encoding/json" "reflect" "strings" @@ -744,28 +745,28 @@ func TestOperationWhenNotConnected(t *testing.T) { { "echo", func() error { - return ovs.Echo() + return ovs.Echo(context.TODO()) }, }, { "transact", func() error { comment := "this is only a test" - _, err := ovs.Transact(ovsdb.Operation{Op: ovsdb.OperationComment, Comment: &comment}) + _, err := ovs.Transact(context.TODO(), ovsdb.Operation{Op: ovsdb.OperationComment, Comment: &comment}) return err }, }, { "monitor/monitor all", func() error { - _, err := ovs.MonitorAll() + _, err := ovs.MonitorAll(context.TODO()) return err }, }, { "monitor cancel", func() error { - return ovs.MonitorCancel("") + return ovs.MonitorCancel(context.TODO(), "") }, }, } diff --git a/cmd/stress/stress.go b/cmd/stress/stress.go index e0094b7b..511bb724 100644 --- a/cmd/stress/stress.go +++ b/cmd/stress/stress.go @@ -58,13 +58,13 @@ func cleanup(ctx context.Context) { if err != nil { log.Fatal(err) } - err = ovs.Connect(context.Background()) + err = ovs.Connect(ctx) if err != nil { log.Fatal(err) } defer ovs.Disconnect() - if _, err := ovs.MonitorAll(); err != nil { + if _, err := ovs.MonitorAll(ctx); err != nil { log.Fatal(err) } @@ -100,7 +100,7 @@ func run(ctx context.Context, resultsChan chan result, wg *sync.WaitGroup) { if err != nil { log.Fatal(err) } - err = ovs.Connect(context.Background()) + err = ovs.Connect(ctx) if err != nil { log.Fatal(err) } @@ -136,7 +136,7 @@ func run(ctx context.Context, resultsChan chan result, wg *sync.WaitGroup) { }, ) - if _, err := ovs.MonitorAll(); err != nil { + if _, err := ovs.MonitorAll(ctx); err != nil { log.Fatal(err) } @@ -166,7 +166,7 @@ func run(ctx context.Context, resultsChan chan result, wg *sync.WaitGroup) { } func transact(ctx context.Context, ovs client.Client, operations []ovsdb.Operation) (bool, string) { - reply, err := ovs.Transact(operations...) + reply, err := ovs.Transact(ctx, operations...) if err != nil { return false, "" } diff --git a/example/ovsdb-server/main.go b/example/ovsdb-server/main.go index 2d374e1c..b510beb0 100644 --- a/example/ovsdb-server/main.go +++ b/example/ovsdb-server/main.go @@ -96,7 +96,7 @@ func main() { if err != nil { log.Fatal(err) } - reply, err := c.Transact(ovsOps...) + reply, err := c.Transact(context.Background(), ovsOps...) if err != nil { log.Fatal(err) } diff --git a/example/play_with_ovs/main..go b/example/play_with_ovs/main..go index 4b5d8b3a..76949188 100644 --- a/example/play_with_ovs/main..go +++ b/example/play_with_ovs/main..go @@ -70,7 +70,7 @@ func createBridge(ovs client.Client, bridgeName string) { } operations := append(insertOp, mutateOps...) - reply, err := ovs.Transact(operations...) + reply, err := ovs.Transact(context.TODO(), operations...) if err != nil { log.Fatal(err) } @@ -121,6 +121,7 @@ func main() { }, }) _, err = ovs.Monitor( + context.TODO(), ovs.NewTableMonitor(&vswitchd.OpenvSwitch{}), ovs.NewTableMonitor(&vswitchd.Bridge{}), ) diff --git a/server/server_integration_test.go b/server/server_integration_test.go index f8c4e7b9..703f1a34 100644 --- a/server/server_integration_test.go +++ b/server/server_integration_test.go @@ -87,7 +87,7 @@ func TestClientServerEcho(t *testing.T) { require.NoError(t, err) err = ovs.Connect(context.Background()) require.NoError(t, err) - err = ovs.Echo() + err = ovs.Echo(context.Background()) assert.Nil(t, err) } @@ -124,7 +124,7 @@ func TestClientServerInsert(t *testing.T) { require.NoError(t, err) err = ovs.Connect(context.Background()) require.NoError(t, err) - _, err = ovs.MonitorAll() + _, err = ovs.MonitorAll(context.Background()) require.NoError(t, err) bridgeRow := &bridgeType{ @@ -134,7 +134,7 @@ func TestClientServerInsert(t *testing.T) { ops, err := ovs.Create(bridgeRow) require.Nil(t, err) - reply, err := ovs.Transact(ops...) + reply, err := ovs.Transact(context.Background(), ops...) assert.Nil(t, err) opErr, err := ovsdb.CheckOperationResults(reply, ops) assert.NoErrorf(t, err, "%+v", opErr) @@ -229,14 +229,14 @@ func TestClientServerMonitor(t *testing.T) { var ops []ovsdb.Operation ovsOps, err := ovs.Create(ovsRow) require.Nil(t, err) - reply, err := ovs.Transact(ovsOps...) + reply, err := ovs.Transact(context.Background(), ovsOps...) require.Nil(t, err) _, err = ovsdb.CheckOperationResults(reply, ovsOps) require.Nil(t, err) require.NotEmpty(t, reply[0].UUID.GoUUID) ovsRow.UUID = reply[0].UUID.GoUUID - _, err = ovs.MonitorAll() + _, err = ovs.MonitorAll(context.Background()) require.Nil(t, err) require.Eventually(t, func() bool { seenMutex.RLock() @@ -256,7 +256,7 @@ func TestClientServerMonitor(t *testing.T) { require.Nil(t, err) ops = append(ops, mutateOps...) - reply, err = ovs.Transact(ops...) + reply, err = ovs.Transact(context.Background(), ops...) require.Nil(t, err) _, err = ovsdb.CheckOperationResults(reply, ops) @@ -308,7 +308,7 @@ func TestClientServerInsertAndDelete(t *testing.T) { require.NoError(t, err) err = ovs.Connect(context.Background()) require.NoError(t, err) - _, err = ovs.MonitorAll() + _, err = ovs.MonitorAll(context.Background()) require.NoError(t, err) bridgeRow := &bridgeType{ @@ -318,7 +318,7 @@ func TestClientServerInsertAndDelete(t *testing.T) { ops, err := ovs.Create(bridgeRow) require.Nil(t, err) - reply, err := ovs.Transact(ops...) + reply, err := ovs.Transact(context.Background(), ops...) require.Nil(t, err) _, err = ovsdb.CheckOperationResults(reply, ops) require.Nil(t, err) @@ -334,7 +334,7 @@ func TestClientServerInsertAndDelete(t *testing.T) { deleteOp, err := ovs.Where(bridgeRow).Delete() require.Nil(t, err) - reply, err = ovs.Transact(deleteOp...) + reply, err = ovs.Transact(context.Background(), deleteOp...) assert.Nil(t, err) _, err = ovsdb.CheckOperationResults(reply, ops) assert.Nil(t, err) @@ -382,13 +382,13 @@ func TestClientServerInsertDuplicate(t *testing.T) { ops, err := ovs.Create(bridgeRow) require.Nil(t, err) - reply, err := ovs.Transact(ops...) + reply, err := ovs.Transact(context.Background(), ops...) require.Nil(t, err) _, err = ovsdb.CheckOperationResults(reply, ops) require.Nil(t, err) // duplicate - reply, err = ovs.Transact(ops...) + reply, err = ovs.Transact(context.Background(), ops...) require.Nil(t, err) opErrs, err := ovsdb.CheckOperationResults(reply, ops) require.Error(t, err) @@ -431,7 +431,7 @@ func TestClientServerInsertAndUpdate(t *testing.T) { require.NoError(t, err) defer ovs.Disconnect() - _, err = ovs.MonitorAll() + _, err = ovs.MonitorAll(context.Background()) require.NoError(t, err) bridgeRow := &bridgeType{ @@ -441,7 +441,7 @@ func TestClientServerInsertAndUpdate(t *testing.T) { ops, err := ovs.Create(bridgeRow) require.NoError(t, err) - reply, err := ovs.Transact(ops...) + reply, err := ovs.Transact(context.Background(), ops...) require.NoError(t, err) _, err = ovsdb.CheckOperationResults(reply, ops) require.NoError(t, err) @@ -458,7 +458,7 @@ func TestClientServerInsertAndUpdate(t *testing.T) { bridgeRow.Name = "br-update2" ops, err = ovs.Where(bridgeRow).Update(bridgeRow) require.NoError(t, err) - reply, err = ovs.Transact(ops...) + reply, err = ovs.Transact(context.Background(), ops...) require.NoError(t, err) _, err = ovsdb.CheckOperationResults(reply, ops) require.Error(t, err) @@ -472,7 +472,7 @@ func TestClientServerInsertAndUpdate(t *testing.T) { bridgeRow.OtherConfig = map[string]string{"foo": "bar"} ops, err = ovs.Where(bridgeRow).Update(bridgeRow) require.NoError(t, err) - reply, err = ovs.Transact(ops...) + reply, err = ovs.Transact(context.Background(), ops...) require.NoError(t, err) opErrs, err := ovsdb.CheckOperationResults(reply, ops) require.NoErrorf(t, err, "%+v", opErrs) @@ -491,7 +491,7 @@ func TestClientServerInsertAndUpdate(t *testing.T) { bridgeRow.ExternalIds = newExternalIds ops, err = ovs.Where(bridgeRow).Update(bridgeRow, &bridgeRow.ExternalIds) require.NoError(t, err) - reply, err = ovs.Transact(ops...) + reply, err = ovs.Transact(context.Background(), ops...) require.NoError(t, err) opErr, err := ovsdb.CheckOperationResults(reply, ops) require.NoErrorf(t, err, "%+v", opErr) diff --git a/test/ovs/ovs_integration_test.go b/test/ovs/ovs_integration_test.go index 1a76bc5e..06179ad5 100644 --- a/test/ovs/ovs_integration_test.go +++ b/test/ovs/ovs_integration_test.go @@ -85,7 +85,7 @@ func (suite *OVSIntegrationSuite) SetupSuite() { // give ovsdb-server some time to start up - _, err = suite.client.MonitorAll() + _, err = suite.client.MonitorAll(context.TODO()) require.NoError(suite.T(), err) } @@ -127,7 +127,7 @@ var defDB, _ = model.NewDBModel("Open_vSwitch", map[string]model.Model{ func (suite *OVSIntegrationSuite) TestConnectReconnect() { assert.True(suite.T(), suite.client.Connected()) - err := suite.client.Echo() + err := suite.client.Echo(context.TODO()) require.NoError(suite.T(), err) bridgeName := "br-discoreco" @@ -176,7 +176,7 @@ func (suite *OVSIntegrationSuite) TestConnectReconnect() { assert.Equal(suite.T(), false, suite.client.Connected()) - err = suite.client.Echo() + err = suite.client.Echo(context.TODO()) require.EqualError(suite.T(), err, client.ErrNotConnected.Error()) ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second) @@ -193,10 +193,10 @@ func (suite *OVSIntegrationSuite) TestConnectReconnect() { err = suite.client.Get(br) require.Error(suite.T(), err, client.ErrNotFound) - err = suite.client.Echo() + err = suite.client.Echo(context.TODO()) assert.NoError(suite.T(), err) - _, err = suite.client.MonitorAll() + _, err = suite.client.MonitorAll(context.TODO()) require.NoError(suite.T(), err) // assert cache has been re-populated @@ -209,7 +209,7 @@ func (suite *OVSIntegrationSuite) TestConnectReconnect() { func (suite *OVSIntegrationSuite) TestWithReconnect() { assert.Equal(suite.T(), true, suite.client.Connected()) - err := suite.client.Echo() + err := suite.client.Echo(context.TODO()) require.NoError(suite.T(), err) // Disconnect client @@ -238,14 +238,14 @@ func (suite *OVSIntegrationSuite) TestWithReconnect() { require.NoError(suite.T(), err) // check the connection is working - err = suite.client.Echo() + err = suite.client.Echo(context.TODO()) require.NoError(suite.T(), err) // check the cache is purged require.True(suite.T(), suite.client.Cache().Table("Bridge").Len() == 0) // set up the monitor again - _, err = suite.client.MonitorAll() + _, err = suite.client.MonitorAll(context.TODO()) require.NoError(suite.T(), err) // add a bridge and verify our handler gets called @@ -276,7 +276,7 @@ func (suite *OVSIntegrationSuite) TestWithReconnect() { return suite.client.Connected() }, 2*time.Second, 500*time.Millisecond) - err = suite.client.Echo() + err = suite.client.Echo(context.TODO()) require.NoError(suite.T(), err) // check our original bridge is in the cache @@ -324,7 +324,7 @@ LOOP: assert.Equal(suite.T(), false, suite.client.Connected()) - err = suite.client.Echo() + err = suite.client.Echo(context.TODO()) require.EqualError(suite.T(), err, client.ErrNotConnected.Error()) ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second) @@ -332,10 +332,10 @@ LOOP: err = suite.client.Connect(ctx) require.NoError(suite.T(), err) - err = suite.client.Echo() + err = suite.client.Echo(context.TODO()) assert.NoError(suite.T(), err) - _, err = suite.client.MonitorAll() + _, err = suite.client.MonitorAll(context.TODO()) require.NoError(suite.T(), err) } @@ -375,7 +375,7 @@ func (suite *OVSIntegrationSuite) TestInsertAndDeleteTransactIntegration() { require.NoError(suite.T(), err) delOperations := append(deleteOp, delMutateOp...) - delReply, err := suite.client.Transact(delOperations...) + delReply, err := suite.client.Transact(context.TODO(), delOperations...) require.NoError(suite.T(), err) delOperationErrs, err := ovsdb.CheckOperationResults(delReply, delOperations) @@ -393,7 +393,7 @@ func (suite *OVSIntegrationSuite) TestTableSchemaValidationIntegration() { Table: "InvalidTable", Row: ovsdb.Row(map[string]interface{}{"name": "docker-ovs"}), } - _, err := suite.client.Transact(operation) + _, err := suite.client.Transact(context.TODO(), operation) assert.Error(suite.T(), err) } @@ -404,7 +404,7 @@ func (suite *OVSIntegrationSuite) TestColumnSchemaInRowValidationIntegration() { Row: ovsdb.Row(map[string]interface{}{"name": "docker-ovs", "invalid_column": "invalid_column"}), } - _, err := suite.client.Transact(operation) + _, err := suite.client.Transact(context.TODO(), operation) assert.Error(suite.T(), err) } @@ -418,7 +418,7 @@ func (suite *OVSIntegrationSuite) TestColumnSchemaInMultipleRowsValidationIntegr Table: "Bridge", Rows: rows, } - _, err := suite.client.Transact(operation) + _, err := suite.client.Transact(context.TODO(), operation) assert.Error(suite.T(), err) } @@ -428,7 +428,7 @@ func (suite *OVSIntegrationSuite) TestColumnSchemaValidationIntegration() { Table: "Bridge", Columns: []string{"name", "invalidColumn"}, } - _, err := suite.client.Transact(operation) + _, err := suite.client.Transact(context.TODO(), operation) assert.Error(suite.T(), err) } @@ -440,12 +440,13 @@ func (suite *OVSIntegrationSuite) TestMonitorCancelIntegration() { } monitorID, err := suite.client.Monitor( + context.TODO(), suite.client.NewTableMonitor(&ovsType{}), suite.client.NewTableMonitor(&bridgeType{}), ) require.NoError(suite.T(), err) - err = suite.client.MonitorCancel(monitorID) + err = suite.client.MonitorCancel(context.TODO(), monitorID) assert.NoError(suite.T(), err) uuid, err := suite.createBridge("br-monitor") @@ -515,7 +516,7 @@ func (suite *OVSIntegrationSuite) TestUpdate() { bridgeRow.Name = "br-update2" ops, err := suite.client.Where(bridgeRow).Update(bridgeRow) require.NoError(suite.T(), err) - reply, err := suite.client.Transact(ops...) + reply, err := suite.client.Transact(context.TODO(), ops...) require.NoError(suite.T(), err) _, err = ovsdb.CheckOperationResults(reply, ops) require.Error(suite.T(), err) @@ -526,7 +527,7 @@ func (suite *OVSIntegrationSuite) TestUpdate() { bridgeRow.ExternalIds = newExternalIds ops, err = suite.client.Where(bridgeRow).Update(bridgeRow, &bridgeRow.ExternalIds) require.NoError(suite.T(), err) - reply, err = suite.client.Transact(ops...) + reply, err = suite.client.Transact(context.TODO(), ops...) require.NoError(suite.T(), err) _, err = ovsdb.CheckOperationResults(reply, ops) require.NoError(suite.T(), err) @@ -567,7 +568,7 @@ func (suite *OVSIntegrationSuite) createBridge(bridgeName string) (string, error require.NoError(suite.T(), err) operations := append(insertOp, mutateOp...) - reply, err := suite.client.Transact(operations...) + reply, err := suite.client.Transact(context.TODO(), operations...) require.NoError(suite.T(), err) _, err = ovsdb.CheckOperationResults(reply, operations)