Skip to content

Commit

Permalink
client: pass Context to client functions that call rpc2's Call()
Browse files Browse the repository at this point in the history
Let users of libovsdb pass a Context to functions that will block
on the OVS database so they can set a timeout.

Signed-off-by: Dan Williams <dcbw@redhat.com>
  • Loading branch information
dcbw committed Jul 28, 2021
1 parent 771d9ef commit a6498a7
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 72 deletions.
48 changes: 24 additions & 24 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ type Client interface {
SetOption(Option) error
Connected() bool
DisconnectNotify() chan struct{}
Echo() error
Transact(...ovsdb.Operation) ([]ovsdb.OperationResult, error)
Monitor(...TableMonitor) (string, error)
MonitorAll() (string, error)
MonitorCancel(id string) error
Echo(context.Context) error
Transact(context.Context, ...ovsdb.Operation) ([]ovsdb.OperationResult, error)
Monitor(context.Context, ...TableMonitor) (string, error)
MonitorAll(context.Context) (string, error)
MonitorCancel(ctx context.Context, id string) error
NewTableMonitor(m model.Model, fields ...interface{}) TableMonitor
API
}
Expand Down Expand Up @@ -149,7 +149,7 @@ func (o *ovsdbClient) connect(ctx context.Context, reconnect bool) error {
return err
}

dbs, err := o.listDbs()
dbs, err := o.listDbs(ctx)
if err != nil {
o.rpcClient.Close()
return err
Expand All @@ -167,7 +167,7 @@ func (o *ovsdbClient) connect(ctx context.Context, reconnect bool) error {
return fmt.Errorf("target database not found")
}

schema, err := o.getSchema(o.dbModel.Name())
schema, err := o.getSchema(ctx, o.dbModel.Name())
errors := o.dbModel.Validate(schema)
if len(errors) > 0 {
var combined []string
Expand Down Expand Up @@ -205,7 +205,7 @@ func (o *ovsdbClient) connect(ctx context.Context, reconnect bool) error {
o.monitorsMutex.Lock()
defer o.monitorsMutex.Unlock()
for id, request := range o.monitors {
err = o.monitor(id, reconnect, request...)
err = o.monitor(ctx, id, reconnect, request...)
if err != nil {
o.rpcClient.Close()
return err
Expand Down Expand Up @@ -312,10 +312,10 @@ func (o *ovsdbClient) update(args []json.RawMessage, reply *[]interface{}) error
// getSchema returns the schema in use for the provided database name
// RFC 7047 : get_schema
// Should only be called when mutex is held
func (o *ovsdbClient) getSchema(dbName string) (*ovsdb.DatabaseSchema, error) {
func (o *ovsdbClient) getSchema(ctx context.Context, dbName string) (*ovsdb.DatabaseSchema, error) {
args := ovsdb.NewGetSchemaArgs(dbName)
var reply ovsdb.DatabaseSchema
err := o.rpcClient.Call("get_schema", args, &reply)
err := o.rpcClient.CallWithContext(ctx, "get_schema", args, &reply)
if err != nil {
if err == rpc2.ErrShutdown {
return nil, ErrNotConnected
Expand All @@ -328,9 +328,9 @@ func (o *ovsdbClient) getSchema(dbName string) (*ovsdb.DatabaseSchema, error) {
// listDbs returns the list of databases on the server
// RFC 7047 : list_dbs
// Should only be called when mutex is held
func (o *ovsdbClient) listDbs() ([]string, error) {
func (o *ovsdbClient) listDbs(ctx context.Context) ([]string, error) {
var dbs []string
err := o.rpcClient.Call("list_dbs", nil, &dbs)
err := o.rpcClient.CallWithContext(ctx, "list_dbs", nil, &dbs)
if err != nil {
if err == rpc2.ErrShutdown {
return nil, ErrNotConnected
Expand All @@ -342,7 +342,7 @@ func (o *ovsdbClient) listDbs() ([]string, error) {

// Transact performs the provided Operations on the database
// RFC 7047 : transact
func (o *ovsdbClient) Transact(operation ...ovsdb.Operation) ([]ovsdb.OperationResult, error) {
func (o *ovsdbClient) Transact(ctx context.Context, operation ...ovsdb.Operation) ([]ovsdb.OperationResult, error) {
var reply []ovsdb.OperationResult
if ok := o.Schema().ValidateOperations(operation...); !ok {
return nil, fmt.Errorf("validation failed for the operation")
Expand All @@ -354,7 +354,7 @@ func (o *ovsdbClient) Transact(operation ...ovsdb.Operation) ([]ovsdb.OperationR
o.rpcMutex.Unlock()
return nil, ErrNotConnected
}
err := o.rpcClient.Call("transact", args, &reply)
err := o.rpcClient.CallWithContext(ctx, "transact", args, &reply)
o.rpcMutex.Unlock()
if err != nil {
if err == rpc2.ErrShutdown {
Expand All @@ -366,25 +366,25 @@ func (o *ovsdbClient) Transact(operation ...ovsdb.Operation) ([]ovsdb.OperationR
}

// MonitorAll is a convenience method to monitor every table/column
func (o *ovsdbClient) MonitorAll() (string, error) {
func (o *ovsdbClient) MonitorAll(ctx context.Context) (string, error) {
var options []TableMonitor
for name := range o.dbModel.Types() {
options = append(options, TableMonitor{Table: name})
}
return o.Monitor(options...)
return o.Monitor(ctx, options...)
}

// MonitorCancel will request cancel a previously issued monitor request
// RFC 7047 : monitor_cancel
func (o *ovsdbClient) MonitorCancel(id string) error {
func (o *ovsdbClient) MonitorCancel(ctx context.Context, id string) error {
var reply ovsdb.OperationResult
args := ovsdb.NewMonitorCancelArgs(id)
o.rpcMutex.Lock()
defer o.rpcMutex.Unlock()
if o.rpcClient == nil {
return ErrNotConnected
}
err := o.rpcClient.Call("monitor_cancel", args, &reply)
err := o.rpcClient.CallWithContext(ctx, "monitor_cancel", args, &reply)
if err != nil {
if err == rpc2.ErrShutdown {
return ErrNotConnected
Expand Down Expand Up @@ -428,12 +428,12 @@ func (o *ovsdbClient) NewTableMonitor(m model.Model, fields ...interface{}) Tabl
// and populate the cache with them. Subsequent updates will be processed
// by the Update Notifications
// RFC 7047 : monitor
func (o *ovsdbClient) Monitor(options ...TableMonitor) (string, error) {
func (o *ovsdbClient) Monitor(ctx context.Context, options ...TableMonitor) (string, error) {
id := uuid.NewString()
return id, o.monitor(id, false, options...)
return id, o.monitor(ctx, id, false, options...)
}

func (o *ovsdbClient) monitor(id string, reconnect bool, options ...TableMonitor) error {
func (o *ovsdbClient) monitor(ctx context.Context, id string, reconnect bool, options ...TableMonitor) error {
if len(options) == 0 {
return fmt.Errorf("no monitor options provided")
}
Expand Down Expand Up @@ -463,7 +463,7 @@ func (o *ovsdbClient) monitor(id string, reconnect bool, options ...TableMonitor
if o.rpcClient == nil {
return ErrNotConnected
}
err := o.rpcClient.Call("monitor", args, &reply)
err := o.rpcClient.CallWithContext(ctx, "monitor", args, &reply)
if err != nil {
if err == rpc2.ErrShutdown {
return ErrNotConnected
Expand All @@ -480,15 +480,15 @@ func (o *ovsdbClient) monitor(id string, reconnect bool, options ...TableMonitor
}

// Echo tests the liveness of the OVSDB connetion
func (o *ovsdbClient) Echo() error {
func (o *ovsdbClient) Echo(ctx context.Context) error {
args := ovsdb.NewEchoArgs()
var reply []interface{}
o.rpcMutex.RLock()
defer o.rpcMutex.RUnlock()
if o.rpcClient == nil {
return ErrNotConnected
}
err := o.rpcClient.Call("echo", args, &reply)
err := o.rpcClient.CallWithContext(ctx, "echo", args, &reply)
if err != nil {
if err == rpc2.ErrShutdown {
return ErrNotConnected
Expand Down
9 changes: 5 additions & 4 deletions client/client_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package client

import (
"context"
"encoding/json"
"reflect"
"strings"
Expand Down Expand Up @@ -744,28 +745,28 @@ func TestOperationWhenNotConnected(t *testing.T) {
{
"echo",
func() error {
return ovs.Echo()
return ovs.Echo(context.TODO())
},
},
{
"transact",
func() error {
comment := "this is only a test"
_, err := ovs.Transact(ovsdb.Operation{Op: ovsdb.OperationComment, Comment: &comment})
_, err := ovs.Transact(context.TODO(), ovsdb.Operation{Op: ovsdb.OperationComment, Comment: &comment})
return err
},
},
{
"monitor/monitor all",
func() error {
_, err := ovs.MonitorAll()
_, err := ovs.MonitorAll(context.TODO())
return err
},
},
{
"monitor cancel",
func() error {
return ovs.MonitorCancel("")
return ovs.MonitorCancel(context.TODO(), "")
},
},
}
Expand Down
10 changes: 5 additions & 5 deletions cmd/stress/stress.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,13 @@ func cleanup(ctx context.Context) {
if err != nil {
log.Fatal(err)
}
err = ovs.Connect(context.Background())
err = ovs.Connect(ctx)
if err != nil {
log.Fatal(err)
}
defer ovs.Disconnect()

if _, err := ovs.MonitorAll(); err != nil {
if _, err := ovs.MonitorAll(ctx); err != nil {
log.Fatal(err)
}

Expand Down Expand Up @@ -100,7 +100,7 @@ func run(ctx context.Context, resultsChan chan result, wg *sync.WaitGroup) {
if err != nil {
log.Fatal(err)
}
err = ovs.Connect(context.Background())
err = ovs.Connect(ctx)
if err != nil {
log.Fatal(err)
}
Expand Down Expand Up @@ -136,7 +136,7 @@ func run(ctx context.Context, resultsChan chan result, wg *sync.WaitGroup) {
},
)

if _, err := ovs.MonitorAll(); err != nil {
if _, err := ovs.MonitorAll(ctx); err != nil {
log.Fatal(err)
}

Expand Down Expand Up @@ -166,7 +166,7 @@ func run(ctx context.Context, resultsChan chan result, wg *sync.WaitGroup) {
}

func transact(ctx context.Context, ovs client.Client, operations []ovsdb.Operation) (bool, string) {
reply, err := ovs.Transact(operations...)
reply, err := ovs.Transact(ctx, operations...)
if err != nil {
return false, ""
}
Expand Down
2 changes: 1 addition & 1 deletion example/ovsdb-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func main() {
if err != nil {
log.Fatal(err)
}
reply, err := c.Transact(ovsOps...)
reply, err := c.Transact(context.Background(), ovsOps...)
if err != nil {
log.Fatal(err)
}
Expand Down
3 changes: 2 additions & 1 deletion example/play_with_ovs/main..go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func createBridge(ovs client.Client, bridgeName string) {
}

operations := append(insertOp, mutateOps...)
reply, err := ovs.Transact(operations...)
reply, err := ovs.Transact(context.TODO(), operations...)
if err != nil {
log.Fatal(err)
}
Expand Down Expand Up @@ -121,6 +121,7 @@ func main() {
},
})
_, err = ovs.Monitor(
context.TODO(),
ovs.NewTableMonitor(&vswitchd.OpenvSwitch{}),
ovs.NewTableMonitor(&vswitchd.Bridge{}),
)
Expand Down
Loading

0 comments on commit a6498a7

Please sign in to comment.