diff --git a/cmd/ocm-backplane/accessrequest/createAccessRequest.go b/cmd/ocm-backplane/accessrequest/createAccessRequest.go index fdff1a25..25c8533b 100644 --- a/cmd/ocm-backplane/accessrequest/createAccessRequest.go +++ b/cmd/ocm-backplane/accessrequest/createAccessRequest.go @@ -8,11 +8,12 @@ import ( "github.com/openshift/backplane-cli/pkg/accessrequest" - ocmcli "github.com/openshift-online/ocm-cli/pkg/ocm" - "github.com/openshift/backplane-cli/pkg/login" - "github.com/openshift/backplane-cli/pkg/utils" logger "github.com/sirupsen/logrus" "github.com/spf13/cobra" + + "github.com/openshift/backplane-cli/pkg/login" + "github.com/openshift/backplane-cli/pkg/ocm" + "github.com/openshift/backplane-cli/pkg/utils" ) var ( @@ -96,7 +97,7 @@ func runCreateAccessRequest(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to compute cluster ID: %v", err) } - ocmConnection, err := ocmcli.NewConnection().Build() + ocmConnection, err := ocm.DefaultOCMInterface.SetupOCMConnection() if err != nil { return fmt.Errorf("failed to create OCM connection: %v", err) } diff --git a/cmd/ocm-backplane/accessrequest/expireAccessRequest.go b/cmd/ocm-backplane/accessrequest/expireAccessRequest.go index 673bd999..e5e011b7 100644 --- a/cmd/ocm-backplane/accessrequest/expireAccessRequest.go +++ b/cmd/ocm-backplane/accessrequest/expireAccessRequest.go @@ -4,8 +4,8 @@ import ( "fmt" "github.com/openshift/backplane-cli/pkg/accessrequest" + "github.com/openshift/backplane-cli/pkg/ocm" - ocmcli "github.com/openshift-online/ocm-cli/pkg/ocm" "github.com/spf13/cobra" ) @@ -30,7 +30,7 @@ func runExpireAccessRequest(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to compute cluster ID: %v", err) } - ocmConnection, err := ocmcli.NewConnection().Build() + ocmConnection, err := ocm.DefaultOCMInterface.SetupOCMConnection() if err != nil { return fmt.Errorf("failed to create OCM connection: %v", err) } diff --git a/cmd/ocm-backplane/accessrequest/getAccessRequest.go b/cmd/ocm-backplane/accessrequest/getAccessRequest.go index 82f3edd5..08e73e2c 100644 --- a/cmd/ocm-backplane/accessrequest/getAccessRequest.go +++ b/cmd/ocm-backplane/accessrequest/getAccessRequest.go @@ -4,8 +4,8 @@ import ( "fmt" "github.com/openshift/backplane-cli/pkg/accessrequest" + "github.com/openshift/backplane-cli/pkg/ocm" - ocmcli "github.com/openshift-online/ocm-cli/pkg/ocm" logger "github.com/sirupsen/logrus" "github.com/spf13/cobra" ) @@ -31,7 +31,7 @@ func runGetAccessRequest(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to compute cluster ID: %v", err) } - ocmConnection, err := ocmcli.NewConnection().Build() + ocmConnection, err := ocm.DefaultOCMInterface.SetupOCMConnection() if err != nil { return fmt.Errorf("failed to create OCM connection: %v", err) } diff --git a/cmd/ocm-backplane/cloud/console.go b/cmd/ocm-backplane/cloud/console.go index 68a14664..86379fe2 100644 --- a/cmd/ocm-backplane/cloud/console.go +++ b/cmd/ocm-backplane/cloud/console.go @@ -6,8 +6,6 @@ import ( "os" "strconv" - ocmsdk "github.com/openshift-online/ocm-cli/pkg/ocm" - "github.com/openshift/backplane-cli/pkg/ocm" "github.com/pkg/browser" @@ -130,9 +128,9 @@ func runConsole(cmd *cobra.Command, argv []string) (err error) { logger.Infof("Using backplane URL: %s\n", backplaneConfiguration.URL) // Initialize OCM connection - ocmConnection, err := ocmsdk.NewConnection().Build() + ocmConnection, err := ocm.DefaultOCMInterface.SetupOCMConnection() if err != nil { - return fmt.Errorf("unable to build ocm sdk: %w", err) + return fmt.Errorf("failed to create OCM connection: %w", err) } defer ocmConnection.Close() diff --git a/cmd/ocm-backplane/cloud/credentials.go b/cmd/ocm-backplane/cloud/credentials.go index ea69509e..ff320177 100644 --- a/cmd/ocm-backplane/cloud/credentials.go +++ b/cmd/ocm-backplane/cloud/credentials.go @@ -4,7 +4,6 @@ import ( "encoding/json" "fmt" - ocmsdk "github.com/openshift-online/ocm-cli/pkg/ocm" logger "github.com/sirupsen/logrus" "github.com/spf13/cobra" "sigs.k8s.io/yaml" @@ -98,9 +97,9 @@ func runCredentials(cmd *cobra.Command, argv []string) error { logger.Infof("Using backplane URL: %s\n", backplaneConfiguration.URL) // Initialize OCM connection - ocmConnection, err := ocmsdk.NewConnection().Build() + ocmConnection, err := ocm.DefaultOCMInterface.SetupOCMConnection() if err != nil { - return fmt.Errorf("unable to build ocm sdk: %w", err) + return fmt.Errorf("failed to create OCM connection: %w", err) } defer ocmConnection.Close() diff --git a/go.mod b/go.mod index d305793d..b32b0d6a 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,7 @@ require ( github.com/spf13/cobra v1.8.1 github.com/spf13/pflag v1.0.5 github.com/spf13/viper v1.19.0 + github.com/trivago/tgo v1.0.7 golang.org/x/term v0.22.0 gopkg.in/AlecAivazis/survey.v1 v1.8.8 k8s.io/api v0.28.3 @@ -123,7 +124,6 @@ require ( github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.6.0 // indirect github.com/subosito/gotenv v1.6.0 // indirect - github.com/trivago/tgo v1.0.7 // indirect github.com/xlab/treeprint v1.2.0 // indirect github.com/zalando/go-keyring v0.2.3 // indirect go.starlark.net v0.0.0-20230525235612-a134d8f9ddca // indirect diff --git a/pkg/ocm/mocks/ocmWrapperMock.go b/pkg/ocm/mocks/ocmWrapperMock.go index 37848449..eb33977a 100644 --- a/pkg/ocm/mocks/ocmWrapperMock.go +++ b/pkg/ocm/mocks/ocmWrapperMock.go @@ -279,3 +279,18 @@ func (mr *MockOCMInterfaceMockRecorder) IsProduction() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsProduction", reflect.TypeOf((*MockOCMInterface)(nil).IsProduction)) } + +// SetupOCMConnection mocks base method. +func (m *MockOCMInterface) SetupOCMConnection() (*sdk.Connection, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetupOCMConnection") + ret0, _ := ret[0].(*sdk.Connection) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SetupOCMConnection indicates an expected call of SetupOCMConnection. +func (mr *MockOCMInterfaceMockRecorder) SetupOCMConnection() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetupOCMConnection", reflect.TypeOf((*MockOCMInterface)(nil).SetupOCMConnection)) +} diff --git a/pkg/ocm/ocmWrapper.go b/pkg/ocm/ocmWrapper.go index a69b8e79..bedbf31c 100644 --- a/pkg/ocm/ocmWrapper.go +++ b/pkg/ocm/ocmWrapper.go @@ -35,6 +35,7 @@ type OCMInterface interface { GetClusterActiveAccessRequest(ocmConnection *ocmsdk.Connection, clusterID string) (*acctrspv1.AccessRequest, error) CreateClusterAccessRequest(ocmConnection *ocmsdk.Connection, clusterID, reason, jiraIssueID, approvalDuration string) (*acctrspv1.AccessRequest, error) CreateAccessRequestDecision(ocmConnection *ocmsdk.Connection, accessRequest *acctrspv1.AccessRequest, decision acctrspv1.DecisionDecision, justification string) (*acctrspv1.Decision, error) + SetupOCMConnection() (*ocmsdk.Connection, error) } const ( @@ -45,9 +46,26 @@ type DefaultOCMInterfaceImpl struct{} var DefaultOCMInterface OCMInterface = &DefaultOCMInterfaceImpl{} +func (o *DefaultOCMInterfaceImpl) SetupOCMConnection() (*ocmsdk.Connection, error) { + ocmNotLoggedInMessage := "Not logged in" + + // Setup connection at the first try + connection, err := ocm.NewConnection().Build() + if err != nil { + if strings.Contains(err.Error(), ocmNotLoggedInMessage) { + return nil, fmt.Errorf("please make sure you have logged into OCM, " + + "use \"ocm login --use-auth-code --url $ENV\" to login ") + } else { + return nil, err + } + } + + return connection, nil +} + // IsClusterHibernating returns a boolean to indicate whether the cluster is hibernating func (o *DefaultOCMInterfaceImpl) IsClusterHibernating(clusterID string) (bool, error) { - connection, err := ocm.NewConnection().Build() + connection, err := o.SetupOCMConnection() if err != nil { return false, fmt.Errorf("failed to create OCM connection: %v", err) } @@ -64,7 +82,7 @@ func (o *DefaultOCMInterfaceImpl) IsClusterHibernating(clusterID string) (bool, // GetTargetCluster returns one single cluster based on the search key and survery. func (o *DefaultOCMInterfaceImpl) GetTargetCluster(clusterKey string) (clusterID, clusterName string, err error) { // Create the client for the OCM API: - connection, err := ocm.NewConnection().Build() + connection, err := o.SetupOCMConnection() if err != nil { return "", "", fmt.Errorf("failed to create OCM connection: %v", err) } @@ -96,7 +114,7 @@ func (o *DefaultOCMInterfaceImpl) GetTargetCluster(clusterKey string) (clusterID // for the given clusterID func (o *DefaultOCMInterfaceImpl) GetManagingCluster(targetClusterID string) (clusterID, clusterName string, isHostedControlPlane bool, err error) { // Create the client for the OCM API: - connection, err := ocm.NewConnection().Build() + connection, err := o.SetupOCMConnection() if err != nil { return "", "", false, fmt.Errorf("failed to create OCM connection: %v", err) } @@ -153,7 +171,7 @@ func (o *DefaultOCMInterfaceImpl) GetManagingCluster(targetClusterID string) (cl // GetServiceCluster gets the service cluster for a given hpyershift hosted cluster func (o *DefaultOCMInterfaceImpl) GetServiceCluster(targetClusterID string) (clusterID, clusterName string, err error) { // Create the client for the OCM API - connection, err := ocm.NewConnection().Build() + connection, err := o.SetupOCMConnection() if err != nil { return "", "", fmt.Errorf("failed to create OCM connection: %v", err) } @@ -207,7 +225,7 @@ func (o *DefaultOCMInterfaceImpl) GetServiceCluster(targetClusterID string) (clu // GetOCMAccessToken initiates the OCM connection and returns the access token func (o *DefaultOCMInterfaceImpl) GetOCMAccessToken() (*string, error) { // Get ocm access token - connection, err := ocm.NewConnection().Build() + connection, err := o.SetupOCMConnection() if err != nil { return nil, fmt.Errorf("failed to create OCM connection: %v", err) } @@ -236,7 +254,7 @@ func (o *DefaultOCMInterfaceImpl) GetPullSecret() (string, error) { // Get ocm access token logger.Debugln("Finding ocm token") - connection, err := ocm.NewConnection().Build() + connection, err := o.SetupOCMConnection() if err != nil { return "", fmt.Errorf("failed to create OCM connection: %v", err) } @@ -256,7 +274,7 @@ func (o *DefaultOCMInterfaceImpl) GetPullSecret() (string, error) { // for a given internal cluster id. func (o *DefaultOCMInterfaceImpl) GetClusterInfoByID(clusterID string) (*cmv1.Cluster, error) { // Create the client for the OCM API: - connection, err := ocm.NewConnection().Build() + connection, err := o.SetupOCMConnection() if err != nil { return nil, fmt.Errorf("failed to create OCM connection: %v", err) } @@ -282,7 +300,7 @@ func (o *DefaultOCMInterfaceImpl) GetClusterInfoByIDWithConn(ocmConnection *ocms // IsProduction checks if OCM is currently in production env func (o *DefaultOCMInterfaceImpl) IsProduction() (bool, error) { // Create the client for the OCM API: - connection, err := ocm.NewConnection().Build() + connection, err := o.SetupOCMConnection() if err != nil { return false, fmt.Errorf("failed to create OCM connection: %v", err) } @@ -302,7 +320,7 @@ func (o *DefaultOCMInterfaceImpl) GetStsSupportJumpRoleARN(ocmConnection *ocmsdk // GetBackplaneURL returns the Backplane API URL based on the OCM env func (o *DefaultOCMInterfaceImpl) GetOCMEnvironment() (*cmv1.Environment, error) { // Create the client for the OCM API - connection, err := ocm.NewConnection().Build() + connection, err := o.SetupOCMConnection() if err != nil { return nil, fmt.Errorf("failed to create OCM connection: %v", err) }