diff --git a/pkg/controlsvc/connect.go b/pkg/controlsvc/connect.go index db3a9912f..bd92f43dc 100644 --- a/pkg/controlsvc/connect.go +++ b/pkg/controlsvc/connect.go @@ -9,15 +9,15 @@ import ( ) type ( - connectCommandType struct{} - connectCommand struct { + ConnectCommandType struct{} + ConnectCommand struct { targetNode string targetService string tlsConfigName string } ) -func (t *connectCommandType) InitFromString(params string) (ControlCommand, error) { +func (t *ConnectCommandType) InitFromString(params string) (ControlCommand, error) { tokens := strings.Split(params, " ") if len(tokens) < 2 { return nil, fmt.Errorf("no connect target") @@ -29,7 +29,7 @@ func (t *connectCommandType) InitFromString(params string) (ControlCommand, erro if len(tokens) == 3 { tlsConfigName = tokens[2] } - c := &connectCommand{ + c := &ConnectCommand{ targetNode: tokens[0], targetService: tokens[1], tlsConfigName: tlsConfigName, @@ -38,7 +38,7 @@ func (t *connectCommandType) InitFromString(params string) (ControlCommand, erro return c, nil } -func (t *connectCommandType) InitFromJSON(config map[string]interface{}) (ControlCommand, error) { +func (t *ConnectCommandType) InitFromJSON(config map[string]interface{}) (ControlCommand, error) { targetNode, ok := config["node"] if !ok { return nil, fmt.Errorf("no connect target node") @@ -65,7 +65,7 @@ func (t *connectCommandType) InitFromJSON(config map[string]interface{}) (Contro } else { tlsConfigStr = "" } - c := &connectCommand{ + c := &ConnectCommand{ targetNode: targetNodeStr, targetService: targetServiceStr, tlsConfigName: tlsConfigStr, @@ -74,7 +74,7 @@ func (t *connectCommandType) InitFromJSON(config map[string]interface{}) (Contro return c, nil } -func (c *connectCommand) ControlFunc(_ context.Context, nc NetceptorForControlCommand, cfo ControlFuncOperations) (map[string]interface{}, error) { +func (c *ConnectCommand) ControlFunc(_ context.Context, nc NetceptorForControlCommand, cfo ControlFuncOperations) (map[string]interface{}, error) { tlscfg, err := nc.GetClientTLSConfig(c.tlsConfigName, c.targetNode, netceptor.ExpectedHostnameTypeReceptor) if err != nil { return nil, err diff --git a/pkg/controlsvc/connect_test.go b/pkg/controlsvc/connect_test.go new file mode 100644 index 000000000..a402864f0 --- /dev/null +++ b/pkg/controlsvc/connect_test.go @@ -0,0 +1,150 @@ +package controlsvc_test + +import ( + "context" + "errors" + "testing" + + "github.com/ansible/receptor/pkg/controlsvc" + "github.com/ansible/receptor/pkg/controlsvc/mock_controlsvc" + "github.com/ansible/receptor/pkg/logger" + "github.com/golang/mock/gomock" +) + +func CheckExpectedError(expectedError bool, errorMessage string, t *testing.T, err error) { + if expectedError && errorMessage != err.Error() { + t.Errorf("expected: %s , received: %s", errorMessage, err) + } + + if !expectedError && err != nil { + t.Error(err) + } +} + +func TestConnectInitFromString(t *testing.T) { + connectCommandType := controlsvc.ConnectCommandType{} + + initFromStringTestCases := []struct { + name string + expectedError bool + errorMessage string + input string + }{ + { + name: "no connect target", + expectedError: true, + errorMessage: "no connect target", + input: "", + }, + { + name: "too many parameters", + expectedError: true, + errorMessage: "too many parameters", + input: "one two three four", + }, + { + name: "three params - pass", + expectedError: false, + errorMessage: "", + input: "one two three", + }, + } + + for _, testCase := range initFromStringTestCases { + t.Run(testCase.name, func(t *testing.T) { + _, err := connectCommandType.InitFromString(testCase.input) + + CheckExpectedError(testCase.expectedError, testCase.errorMessage, t, err) + }) + } +} + +func TestConnectInitFromJSON(t *testing.T) { + connectCommandType := controlsvc.ConnectCommandType{} + + initFromJSONTestCases := []struct { + name string + expectedError bool + errorMessage string + input map[string]interface{} + }{ + BuildInitFromJSONTestCases("no connect target node", true, "no connect target node", map[string]interface{}{}), + BuildInitFromJSONTestCases("connect target node must be string 1", true, "connect target node must be string", map[string]interface{}{"node": 7}), + BuildInitFromJSONTestCases("no connect target service", true, "no connect target service", map[string]interface{}{"node": "node1"}), + BuildInitFromJSONTestCases("connect target service must be string1", true, "connect target service must be string", map[string]interface{}{"node": "node2", "service": 7}), + BuildInitFromJSONTestCases("connect tls name be string", true, "connect tls name must be string", map[string]interface{}{"node": "node3", "service": "service1", "tls": 7}), + BuildInitFromJSONTestCases("pass with empty tls config", false, "connect target service must be string", map[string]interface{}{"node": "node4", "service": "service2"}), + BuildInitFromJSONTestCases("pass with all targets and tls config", false, "", map[string]interface{}{"node": "node4", "service": "service3", "tls": "tls1"}), + } + + for _, testCase := range initFromJSONTestCases { + t.Run(testCase.name, func(t *testing.T) { + _, err := connectCommandType.InitFromJSON(testCase.input) + CheckExpectedError(testCase.expectedError, testCase.errorMessage, t, err) + }) + } +} + +func TestConnectControlFunc(t *testing.T) { + connectCommand := controlsvc.ConnectCommand{} + ctrl := gomock.NewController(t) + mockNetceptor := mock_controlsvc.NewMockNetceptorForControlsvc(ctrl) + mockControlFunc := mock_controlsvc.NewMockControlFuncOperations(ctrl) + logger := logger.NewReceptorLogger("") + + controlFuncTestCases := []struct { + name string + expectedError bool + errorMessage string + expectedCalls func() + }{ + { + name: "tls config error", + expectedError: true, + errorMessage: "terminated tls", + expectedCalls: func() { + mockNetceptor.EXPECT().GetClientTLSConfig(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("terminated tls")) + }, + }, + { + name: "dial error", + errorMessage: "terminated dial", + expectedError: true, + expectedCalls: func() { + mockNetceptor.EXPECT().GetClientTLSConfig(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) + mockNetceptor.EXPECT().Dial(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("terminated dial")) + }, + }, + { + name: "bridge conn error", + errorMessage: "terminated bridge conn", + expectedError: true, + expectedCalls: func() { + mockNetceptor.EXPECT().GetClientTLSConfig(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) + mockNetceptor.EXPECT().Dial(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) + mockControlFunc.EXPECT().BridgeConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("terminated bridge conn")) + mockNetceptor.EXPECT().GetLogger().Return(logger) + }, + }, + { + name: "control func pass", + errorMessage: "", + expectedError: false, + expectedCalls: func() { + mockNetceptor.EXPECT().GetClientTLSConfig(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) + mockNetceptor.EXPECT().Dial(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) + mockControlFunc.EXPECT().BridgeConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + mockNetceptor.EXPECT().GetLogger().Return(logger) + }, + }, + } + + for _, testCase := range controlFuncTestCases { + t.Run(testCase.name, func(t *testing.T) { + testCase.expectedCalls() + _, err := connectCommand.ControlFunc(context.Background(), mockNetceptor, mockControlFunc) + + CheckExpectedError(testCase.expectedError, testCase.errorMessage, t, err) + }) + } +} diff --git a/pkg/controlsvc/controlsvc.go b/pkg/controlsvc/controlsvc.go index fdcce8a23..b219719cc 100644 --- a/pkg/controlsvc/controlsvc.go +++ b/pkg/controlsvc/controlsvc.go @@ -168,11 +168,11 @@ func New(stdServices bool, nc NetceptorForControlsvc) *Server { serverTLS: &TLS{}, } if stdServices { - s.controlTypes["ping"] = &pingCommandType{} - s.controlTypes["status"] = &statusCommandType{} - s.controlTypes["connect"] = &connectCommandType{} - s.controlTypes["traceroute"] = &tracerouteCommandType{} - s.controlTypes["reload"] = &reloadCommandType{} + s.controlTypes["ping"] = &PingCommandType{} + s.controlTypes["status"] = &StatusCommandType{} + s.controlTypes["connect"] = &ConnectCommandType{} + s.controlTypes["traceroute"] = &TracerouteCommandType{} + s.controlTypes["reload"] = &ReloadCommandType{} } return s diff --git a/pkg/controlsvc/mock_controlsvc/interfaces.go b/pkg/controlsvc/mock_controlsvc/interfaces.go index 70f75eb48..c63fded29 100644 --- a/pkg/controlsvc/mock_controlsvc/interfaces.go +++ b/pkg/controlsvc/mock_controlsvc/interfaces.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: pkg/controlsvc/interfaces.go +// Source: github.com/ansible/receptor/pkg/controlsvc (interfaces: ControlCommandType,NetceptorForControlCommand,ControlCommand,ControlFuncOperations) // Package mock_controlsvc is a generated GoMock package. package mock_controlsvc @@ -107,33 +107,33 @@ func (mr *MockNetceptorForControlCommandMockRecorder) CancelBackends() *gomock.C } // Dial mocks base method. -func (m *MockNetceptorForControlCommand) Dial(node, service string, tlscfg *tls.Config) (*netceptor.Conn, error) { +func (m *MockNetceptorForControlCommand) Dial(arg0, arg1 string, arg2 *tls.Config) (*netceptor.Conn, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Dial", node, service, tlscfg) + ret := m.ctrl.Call(m, "Dial", arg0, arg1, arg2) ret0, _ := ret[0].(*netceptor.Conn) ret1, _ := ret[1].(error) return ret0, ret1 } // Dial indicates an expected call of Dial. -func (mr *MockNetceptorForControlCommandMockRecorder) Dial(node, service, tlscfg interface{}) *gomock.Call { +func (mr *MockNetceptorForControlCommandMockRecorder) Dial(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Dial", reflect.TypeOf((*MockNetceptorForControlCommand)(nil).Dial), node, service, tlscfg) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Dial", reflect.TypeOf((*MockNetceptorForControlCommand)(nil).Dial), arg0, arg1, arg2) } // GetClientTLSConfig mocks base method. -func (m *MockNetceptorForControlCommand) GetClientTLSConfig(name, expectedHostName string, expectedHostNameType netceptor.ExpectedHostnameType) (*tls.Config, error) { +func (m *MockNetceptorForControlCommand) GetClientTLSConfig(arg0, arg1 string, arg2 netceptor.ExpectedHostnameType) (*tls.Config, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetClientTLSConfig", name, expectedHostName, expectedHostNameType) + ret := m.ctrl.Call(m, "GetClientTLSConfig", arg0, arg1, arg2) ret0, _ := ret[0].(*tls.Config) ret1, _ := ret[1].(error) return ret0, ret1 } // GetClientTLSConfig indicates an expected call of GetClientTLSConfig. -func (mr *MockNetceptorForControlCommandMockRecorder) GetClientTLSConfig(name, expectedHostName, expectedHostNameType interface{}) *gomock.Call { +func (mr *MockNetceptorForControlCommandMockRecorder) GetClientTLSConfig(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientTLSConfig", reflect.TypeOf((*MockNetceptorForControlCommand)(nil).GetClientTLSConfig), name, expectedHostName, expectedHostNameType) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientTLSConfig", reflect.TypeOf((*MockNetceptorForControlCommand)(nil).GetClientTLSConfig), arg0, arg1, arg2) } // GetLogger mocks base method. @@ -179,9 +179,9 @@ func (mr *MockNetceptorForControlCommandMockRecorder) NodeID() *gomock.Call { } // Ping mocks base method. -func (m *MockNetceptorForControlCommand) Ping(ctx context.Context, target string, hopsToLive byte) (time.Duration, string, error) { +func (m *MockNetceptorForControlCommand) Ping(arg0 context.Context, arg1 string, arg2 byte) (time.Duration, string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Ping", ctx, target, hopsToLive) + ret := m.ctrl.Call(m, "Ping", arg0, arg1, arg2) ret0, _ := ret[0].(time.Duration) ret1, _ := ret[1].(string) ret2, _ := ret[2].(error) @@ -189,9 +189,9 @@ func (m *MockNetceptorForControlCommand) Ping(ctx context.Context, target string } // Ping indicates an expected call of Ping. -func (mr *MockNetceptorForControlCommandMockRecorder) Ping(ctx, target, hopsToLive interface{}) *gomock.Call { +func (mr *MockNetceptorForControlCommandMockRecorder) Ping(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ping", reflect.TypeOf((*MockNetceptorForControlCommand)(nil).Ping), ctx, target, hopsToLive) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ping", reflect.TypeOf((*MockNetceptorForControlCommand)(nil).Ping), arg0, arg1, arg2) } // Status mocks base method. @@ -209,17 +209,17 @@ func (mr *MockNetceptorForControlCommandMockRecorder) Status() *gomock.Call { } // Traceroute mocks base method. -func (m *MockNetceptorForControlCommand) Traceroute(ctx context.Context, target string) <-chan *netceptor.TracerouteResult { +func (m *MockNetceptorForControlCommand) Traceroute(arg0 context.Context, arg1 string) <-chan *netceptor.TracerouteResult { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Traceroute", ctx, target) + ret := m.ctrl.Call(m, "Traceroute", arg0, arg1) ret0, _ := ret[0].(<-chan *netceptor.TracerouteResult) return ret0 } // Traceroute indicates an expected call of Traceroute. -func (mr *MockNetceptorForControlCommandMockRecorder) Traceroute(ctx, target interface{}) *gomock.Call { +func (mr *MockNetceptorForControlCommandMockRecorder) Traceroute(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Traceroute", reflect.TypeOf((*MockNetceptorForControlCommand)(nil).Traceroute), ctx, target) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Traceroute", reflect.TypeOf((*MockNetceptorForControlCommand)(nil).Traceroute), arg0, arg1) } // MockControlCommand is a mock of ControlCommand interface. @@ -284,17 +284,17 @@ func (m *MockControlFuncOperations) EXPECT() *MockControlFuncOperationsMockRecor } // BridgeConn mocks base method. -func (m *MockControlFuncOperations) BridgeConn(message string, bc io.ReadWriteCloser, bcName string, logger *logger.ReceptorLogger) error { +func (m *MockControlFuncOperations) BridgeConn(arg0 string, arg1 io.ReadWriteCloser, arg2 string, arg3 *logger.ReceptorLogger, arg4 controlsvc.Utiler) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BridgeConn", message, bc, bcName, logger) + ret := m.ctrl.Call(m, "BridgeConn", arg0, arg1, arg2, arg3, arg4) ret0, _ := ret[0].(error) return ret0 } // BridgeConn indicates an expected call of BridgeConn. -func (mr *MockControlFuncOperationsMockRecorder) BridgeConn(message, bc, bcName, logger interface{}) *gomock.Call { +func (mr *MockControlFuncOperationsMockRecorder) BridgeConn(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BridgeConn", reflect.TypeOf((*MockControlFuncOperations)(nil).BridgeConn), message, bc, bcName, logger) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BridgeConn", reflect.TypeOf((*MockControlFuncOperations)(nil).BridgeConn), arg0, arg1, arg2, arg3, arg4) } // Close mocks base method. @@ -312,17 +312,17 @@ func (mr *MockControlFuncOperationsMockRecorder) Close() *gomock.Call { } // ReadFromConn mocks base method. -func (m *MockControlFuncOperations) ReadFromConn(message string, out io.Writer) error { +func (m *MockControlFuncOperations) ReadFromConn(arg0 string, arg1 io.Writer, arg2 controlsvc.Copier) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReadFromConn", message, out) + ret := m.ctrl.Call(m, "ReadFromConn", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } // ReadFromConn indicates an expected call of ReadFromConn. -func (mr *MockControlFuncOperationsMockRecorder) ReadFromConn(message, out interface{}) *gomock.Call { +func (mr *MockControlFuncOperationsMockRecorder) ReadFromConn(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadFromConn", reflect.TypeOf((*MockControlFuncOperations)(nil).ReadFromConn), message, out) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadFromConn", reflect.TypeOf((*MockControlFuncOperations)(nil).ReadFromConn), arg0, arg1, arg2) } // RemoteAddr mocks base method. @@ -340,15 +340,15 @@ func (mr *MockControlFuncOperationsMockRecorder) RemoteAddr() *gomock.Call { } // WriteToConn mocks base method. -func (m *MockControlFuncOperations) WriteToConn(message string, in chan []byte) error { +func (m *MockControlFuncOperations) WriteToConn(arg0 string, arg1 chan []byte) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "WriteToConn", message, in) + ret := m.ctrl.Call(m, "WriteToConn", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // WriteToConn indicates an expected call of WriteToConn. -func (mr *MockControlFuncOperationsMockRecorder) WriteToConn(message, in interface{}) *gomock.Call { +func (mr *MockControlFuncOperationsMockRecorder) WriteToConn(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteToConn", reflect.TypeOf((*MockControlFuncOperations)(nil).WriteToConn), message, in) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteToConn", reflect.TypeOf((*MockControlFuncOperations)(nil).WriteToConn), arg0, arg1) } diff --git a/pkg/controlsvc/ping.go b/pkg/controlsvc/ping.go index aa58dc3bb..4c3f3644f 100644 --- a/pkg/controlsvc/ping.go +++ b/pkg/controlsvc/ping.go @@ -6,24 +6,24 @@ import ( ) type ( - pingCommandType struct{} - pingCommand struct { + PingCommandType struct{} + PingCommand struct { target string } ) -func (t *pingCommandType) InitFromString(params string) (ControlCommand, error) { +func (t *PingCommandType) InitFromString(params string) (ControlCommand, error) { if params == "" { return nil, fmt.Errorf("no ping target") } - c := &pingCommand{ + c := &PingCommand{ target: params, } return c, nil } -func (t *pingCommandType) InitFromJSON(config map[string]interface{}) (ControlCommand, error) { +func (t *PingCommandType) InitFromJSON(config map[string]interface{}) (ControlCommand, error) { target, ok := config["target"] if !ok { return nil, fmt.Errorf("no ping target") @@ -32,14 +32,14 @@ func (t *pingCommandType) InitFromJSON(config map[string]interface{}) (ControlCo if !ok { return nil, fmt.Errorf("ping target must be string") } - c := &pingCommand{ + c := &PingCommand{ target: targetStr, } return c, nil } -func (c *pingCommand) ControlFunc(ctx context.Context, nc NetceptorForControlCommand, _ ControlFuncOperations) (map[string]interface{}, error) { +func (c *PingCommand) ControlFunc(ctx context.Context, nc NetceptorForControlCommand, _ ControlFuncOperations) (map[string]interface{}, error) { pingTime, pingRemote, err := nc.Ping(ctx, c.target, nc.MaxForwardingHops()) cfr := make(map[string]interface{}) if err == nil { diff --git a/pkg/controlsvc/ping_test.go b/pkg/controlsvc/ping_test.go new file mode 100644 index 000000000..6b606a658 --- /dev/null +++ b/pkg/controlsvc/ping_test.go @@ -0,0 +1,128 @@ +package controlsvc_test + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/ansible/receptor/pkg/controlsvc" + "github.com/ansible/receptor/pkg/controlsvc/mock_controlsvc" + "github.com/golang/mock/gomock" +) + +func TestPingInitFromString(t *testing.T) { + pingCommandType := controlsvc.PingCommandType{} + + initFromStringTestCases := []struct { + name string + expectedError bool + errorMessage string + input string + }{ + { + name: "no ping target 1", + expectedError: true, + errorMessage: "no ping target", + input: "", + }, + { + name: "single param - pass", + expectedError: false, + errorMessage: "", + input: "one", + }, + } + + for _, testCase := range initFromStringTestCases { + t.Run(testCase.name, func(t *testing.T) { + _, err := pingCommandType.InitFromString(testCase.input) + + CheckExpectedError(testCase.expectedError, testCase.errorMessage, t, err) + }) + } +} + +type InitFromJSONTestCase struct { + name string + expectedError bool + errorMessage string + input map[string]interface{} +} + +func BuildInitFromJSONTestCases(name string, expectedError bool, errorMessage string, input map[string]interface{}) InitFromJSONTestCase { + return InitFromJSONTestCase{ + name: name, + expectedError: expectedError, + errorMessage: errorMessage, + input: input, + } +} + +func TestPingInitFromJSON(t *testing.T) { + pingCommandType := controlsvc.PingCommandType{} + + initFromJSONTestCases := []InitFromJSONTestCase{ + BuildInitFromJSONTestCases("no ping target 2", true, "no ping target", map[string]interface{}{}), + BuildInitFromJSONTestCases("ping target must be string", true, "ping target must be string", map[string]interface{}{"target": 7}), + BuildInitFromJSONTestCases("three params - pass", false, "", map[string]interface{}{"target": "some target"}), + } + + for _, testCase := range initFromJSONTestCases { + t.Run(testCase.name, func(t *testing.T) { + _, err := pingCommandType.InitFromJSON(testCase.input) + + CheckExpectedError(testCase.expectedError, testCase.errorMessage, t, err) + }) + } +} + +func TestPingControlFunc(t *testing.T) { + pingCommand := controlsvc.PingCommand{} + ctrl := gomock.NewController(t) + mockNetceptor := mock_controlsvc.NewMockNetceptorForControlsvc(ctrl) + mockControlFunc := mock_controlsvc.NewMockControlFuncOperations(ctrl) + + controlFuncTestCases := []struct { + name string + expectedError bool + errorMessage string + expectedCalls func() + }{ + { + name: "ping error", + expectedError: true, + errorMessage: "terminated ping", + expectedCalls: func() { + mockNetceptor.EXPECT().Ping(gomock.Any(), gomock.Any(), gomock.Any()).Return(time.Second, "", errors.New("terminated ping")) + mockNetceptor.EXPECT().MaxForwardingHops() + }, + }, + { + name: "control func pass", + errorMessage: "", + expectedError: false, + expectedCalls: func() { + mockNetceptor.EXPECT().Ping(gomock.Any(), gomock.Any(), gomock.Any()).Return(time.Second, "", nil) + mockNetceptor.EXPECT().MaxForwardingHops() + }, + }, + } + + for _, testCase := range controlFuncTestCases { + t.Run(testCase.name, func(t *testing.T) { + testCase.expectedCalls() + + cfr, _ := pingCommand.ControlFunc(context.Background(), mockNetceptor, mockControlFunc) + err, ok := cfr["Error"] + + if testCase.expectedError && testCase.errorMessage != err { + t.Errorf("expected: %s , received: %s", testCase.errorMessage, err) + } + + if !testCase.expectedError && ok { + t.Error(cfr["Error"]) + } + }) + } +} diff --git a/pkg/controlsvc/reload.go b/pkg/controlsvc/reload.go index 2a04167a2..498158e5e 100644 --- a/pkg/controlsvc/reload.go +++ b/pkg/controlsvc/reload.go @@ -11,8 +11,8 @@ import ( ) type ( - reloadCommandType struct{} - reloadCommand struct{} + ReloadCommandType struct{} + ReloadCommand struct{} ) var configPath = "" @@ -136,14 +136,14 @@ func checkReload() error { return parseConfigForReload(configPath, true) } -func (t *reloadCommandType) InitFromString(_ string) (ControlCommand, error) { - c := &reloadCommand{} +func (t *ReloadCommandType) InitFromString(_ string) (ControlCommand, error) { + c := &ReloadCommand{} return c, nil } -func (t *reloadCommandType) InitFromJSON(_ map[string]interface{}) (ControlCommand, error) { - c := &reloadCommand{} +func (t *ReloadCommandType) InitFromJSON(_ map[string]interface{}) (ControlCommand, error) { + c := &ReloadCommand{} return c, nil } @@ -157,7 +157,7 @@ func handleError(err error, errorcode int, logger *logger.ReceptorLogger) (map[s return cfr, nil } -func (c *reloadCommand) ControlFunc(_ context.Context, nc NetceptorForControlCommand, _ ControlFuncOperations) (map[string]interface{}, error) { +func (c *ReloadCommand) ControlFunc(_ context.Context, nc NetceptorForControlCommand, _ ControlFuncOperations) (map[string]interface{}, error) { // Reload command stops all backends, and re-runs the ParseAndRun() on the // initial config file nc.GetLogger().Debug("Reloading") diff --git a/pkg/controlsvc/status.go b/pkg/controlsvc/status.go index 74230f2c5..62c21362a 100644 --- a/pkg/controlsvc/status.go +++ b/pkg/controlsvc/status.go @@ -9,22 +9,22 @@ import ( ) type ( - statusCommandType struct{} - statusCommand struct { + StatusCommandType struct{} + StatusCommand struct { requestedFields []string } ) -func (t *statusCommandType) InitFromString(params string) (ControlCommand, error) { +func (t *StatusCommandType) InitFromString(params string) (ControlCommand, error) { if params != "" { return nil, fmt.Errorf("status command does not take parameters") } - c := &statusCommand{} + c := &StatusCommand{} return c, nil } -func (t *statusCommandType) InitFromJSON(config map[string]interface{}) (ControlCommand, error) { +func (t *StatusCommandType) InitFromJSON(config map[string]interface{}) (ControlCommand, error) { requestedFields, ok := config["requested_fields"] var requestedFieldsStr []string if ok { @@ -39,14 +39,14 @@ func (t *statusCommandType) InitFromJSON(config map[string]interface{}) (Control } else { requestedFieldsStr = nil } - c := &statusCommand{ + c := &StatusCommand{ requestedFields: requestedFieldsStr, } return c, nil } -func (c *statusCommand) ControlFunc(_ context.Context, nc NetceptorForControlCommand, _ ControlFuncOperations) (map[string]interface{}, error) { +func (c *StatusCommand) ControlFunc(_ context.Context, nc NetceptorForControlCommand, _ ControlFuncOperations) (map[string]interface{}, error) { status := nc.Status() statusGetters := make(map[string]func() interface{}) statusGetters["Version"] = func() interface{} { return version.Version } diff --git a/pkg/controlsvc/status_test.go b/pkg/controlsvc/status_test.go new file mode 100644 index 000000000..12a69b567 --- /dev/null +++ b/pkg/controlsvc/status_test.go @@ -0,0 +1,124 @@ +package controlsvc_test + +import ( + "context" + "testing" + + "github.com/ansible/receptor/pkg/controlsvc" + "github.com/ansible/receptor/pkg/controlsvc/mock_controlsvc" + "github.com/ansible/receptor/pkg/netceptor" + "github.com/golang/mock/gomock" +) + +func TestStatusInitFromString(t *testing.T) { + statusCommandType := controlsvc.StatusCommandType{} + + initFromStringTestCases := []struct { + name string + expectedError bool + errorMessage string + input string + }{ + { + name: "status command does not take parameters", + expectedError: true, + errorMessage: "status command does not take parameters", + input: "one", + }, + { + name: "pass without params", + expectedError: false, + errorMessage: "", + input: "", + }, + } + + for _, testCase := range initFromStringTestCases { + t.Run(testCase.name, func(t *testing.T) { + _, err := statusCommandType.InitFromString(testCase.input) + + CheckExpectedError(testCase.expectedError, testCase.errorMessage, t, err) + }) + } +} + +func TestStatusInitFromJSON(t *testing.T) { + statusCommandType := controlsvc.StatusCommandType{} + + initFromJSONTestCases := []struct { + name string + expectedError bool + errorMessage string + input map[string]interface{} + }{ + { + name: "each element of requested_fields must be a string", + expectedError: true, + errorMessage: "each element of requested_fields must be a string", + input: map[string]interface{}{ + "requested_fields": []interface{}{ + 0: 7, + }, + }, + }, + { + name: "pass with no requested fields", + expectedError: false, + errorMessage: "", + input: map[string]interface{}{}, + }, + { + name: "pass with requested fields", + expectedError: false, + errorMessage: "", + input: map[string]interface{}{ + "requested_fields": []interface{}{ + 0: "request", + }, + }, + }, + } + + for _, testCase := range initFromJSONTestCases { + t.Run(testCase.name, func(t *testing.T) { + _, err := statusCommandType.InitFromJSON(testCase.input) + + CheckExpectedError(testCase.expectedError, testCase.errorMessage, t, err) + }) + } +} + +func TestStatusControlFunc(t *testing.T) { + statusCommand := controlsvc.StatusCommand{} + ctrl := gomock.NewController(t) + mockNetceptor := mock_controlsvc.NewMockNetceptorForControlsvc(ctrl) + mockControlFunc := mock_controlsvc.NewMockControlFuncOperations(ctrl) + + controlFuncTestCases := []struct { + name string + expectedError bool + errorMessage string + expectedCalls func() + }{ + { + name: "control func pass", + errorMessage: "", + expectedError: false, + expectedCalls: func() { + mockNetceptor.EXPECT().Status().Return(netceptor.Status{}) + }, + }, + } + + for _, testCase := range controlFuncTestCases { + t.Run(testCase.name, func(t *testing.T) { + testCase.expectedCalls() + + _, err := statusCommand.ControlFunc(context.Background(), mockNetceptor, mockControlFunc) + + if !testCase.expectedError && err != nil { + t.Error(err) + } + }) + } +} diff --git a/pkg/controlsvc/traceroute.go b/pkg/controlsvc/traceroute.go index b7f69d801..6081b2f33 100644 --- a/pkg/controlsvc/traceroute.go +++ b/pkg/controlsvc/traceroute.go @@ -7,24 +7,24 @@ import ( ) type ( - tracerouteCommandType struct{} - tracerouteCommand struct { + TracerouteCommandType struct{} + TracerouteCommand struct { target string } ) -func (t *tracerouteCommandType) InitFromString(params string) (ControlCommand, error) { +func (t *TracerouteCommandType) InitFromString(params string) (ControlCommand, error) { if params == "" { return nil, fmt.Errorf("no traceroute target") } - c := &tracerouteCommand{ + c := &TracerouteCommand{ target: params, } return c, nil } -func (t *tracerouteCommandType) InitFromJSON(config map[string]interface{}) (ControlCommand, error) { +func (t *TracerouteCommandType) InitFromJSON(config map[string]interface{}) (ControlCommand, error) { target, ok := config["target"] if !ok { return nil, fmt.Errorf("no traceroute target") @@ -33,14 +33,14 @@ func (t *tracerouteCommandType) InitFromJSON(config map[string]interface{}) (Con if !ok { return nil, fmt.Errorf("traceroute target must be string") } - c := &tracerouteCommand{ + c := &TracerouteCommand{ target: targetStr, } return c, nil } -func (c *tracerouteCommand) ControlFunc(ctx context.Context, nc NetceptorForControlCommand, _ ControlFuncOperations) (map[string]interface{}, error) { +func (c *TracerouteCommand) ControlFunc(ctx context.Context, nc NetceptorForControlCommand, _ ControlFuncOperations) (map[string]interface{}, error) { cfr := make(map[string]interface{}) results := nc.Traceroute(ctx, c.target) i := 0 diff --git a/pkg/controlsvc/traceroute_test.go b/pkg/controlsvc/traceroute_test.go new file mode 100644 index 000000000..165a035a0 --- /dev/null +++ b/pkg/controlsvc/traceroute_test.go @@ -0,0 +1,124 @@ +package controlsvc_test + +import ( + "context" + "errors" + "testing" + + "github.com/ansible/receptor/pkg/controlsvc" + "github.com/ansible/receptor/pkg/controlsvc/mock_controlsvc" + "github.com/ansible/receptor/pkg/netceptor" + "github.com/golang/mock/gomock" +) + +func TestTracerouteInitFromString(t *testing.T) { + tracerouteCommandType := controlsvc.TracerouteCommandType{} + + initFromStringTestCases := []struct { + name string + expectedError bool + errorMessage string + input string + }{ + { + name: "no traceroute target 1", + expectedError: true, + errorMessage: "no traceroute target", + input: "", + }, + { + name: "pass with params", + expectedError: false, + errorMessage: "", + input: "one", + }, + } + + for _, testCase := range initFromStringTestCases { + t.Run(testCase.name, func(t *testing.T) { + _, err := tracerouteCommandType.InitFromString(testCase.input) + + CheckExpectedError(testCase.expectedError, testCase.errorMessage, t, err) + }) + } +} + +func TestTracerouteInitFromJSON(t *testing.T) { + tracerouteCommandType := controlsvc.TracerouteCommandType{} + + initFromJSONTestCases := []InitFromJSONTestCase{ + BuildInitFromJSONTestCases("no traceroute target 2", true, "no traceroute target", map[string]interface{}{}), + BuildInitFromJSONTestCases("traceroute target must be string", true, "traceroute target must be string", map[string]interface{}{"target": 7}), + BuildInitFromJSONTestCases("pass with target", false, "", map[string]interface{}{"target": "some target"}), + } + + for _, testCase := range initFromJSONTestCases { + t.Run(testCase.name, func(t *testing.T) { + _, err := tracerouteCommandType.InitFromJSON(testCase.input) + + CheckExpectedError(testCase.expectedError, testCase.errorMessage, t, err) + }) + } +} + +func TestTracerouteControlFunc(t *testing.T) { + tracerouteCommand := controlsvc.TracerouteCommand{} + ctrl := gomock.NewController(t) + mockNetceptor := mock_controlsvc.NewMockNetceptorForControlsvc(ctrl) + mockControlFunc := mock_controlsvc.NewMockControlFuncOperations(ctrl) + + controlFuncTestCases := []struct { + name string + expectedError bool + errorMessage string + expectedCalls func() + }{ + { + name: "control func pass with result error", + errorMessage: "terminated", + expectedError: false, + expectedCalls: func() { + c := make(chan *netceptor.TracerouteResult) + + go func() { + c <- &netceptor.TracerouteResult{ + Err: errors.New("terminated"), + } + close(c) + }() + mockNetceptor.EXPECT().Traceroute(gomock.Any(), gomock.Any()).Return(c) + }, + }, + { + name: "control func pass", + errorMessage: "", + expectedError: false, + expectedCalls: func() { + c := make(chan *netceptor.TracerouteResult) + + go func() { + c <- &netceptor.TracerouteResult{} + close(c) + }() + mockNetceptor.EXPECT().Traceroute(gomock.Any(), gomock.Any()).Return(c) + }, + }, + } + + for _, testCase := range controlFuncTestCases { + t.Run(testCase.name, func(t *testing.T) { + testCase.expectedCalls() + + cfr, _ := tracerouteCommand.ControlFunc(context.Background(), mockNetceptor, mockControlFunc) + err, ok := cfr["Error"] + + if testCase.expectedError && testCase.errorMessage != err { + t.Errorf("expected: %s , received: %s", testCase.errorMessage, err) + } + + if !testCase.expectedError && ok { + t.Error(cfr["Error"]) + } + }) + } +}