diff --git a/cmd/describe/cluster/cmd.go b/cmd/describe/cluster/cmd.go index 83cc8ca28b..4331ae9f99 100644 --- a/cmd/describe/cluster/cmd.go +++ b/cmd/describe/cluster/cmd.go @@ -497,9 +497,14 @@ func run(cmd *cobra.Command, argv []string) { strings.Join(cluster.AWS().AdditionalAllowedPrincipals(), ",")) } if cluster.RegistryConfig() != nil { + registryConfigOutput, err := getClusterRegistryConfig(cluster, r.OCMClient) + if err != nil { + r.Reporter.Errorf("Failed to get cluster registry config for cluster '%s': %v", clusterKey, err) + os.Exit(1) + } str = fmt.Sprintf("%s"+ "Registry Configuration:\n"+ - "%s\n", str, getClusterRegistryConfig(cluster)) + "%s\n", str, registryConfigOutput) } } @@ -843,7 +848,7 @@ func getAuditLogForwardingStatus(cluster *cmv1.Cluster) string { return auditLogForwardingStatus } -func getClusterRegistryConfig(cluster *cmv1.Cluster) string { +func getClusterRegistryConfig(cluster *cmv1.Cluster, client *ocm.Client) (string, error) { var output string if cluster.RegistryConfig().RegistrySources() != nil { registryResources := cluster.RegistryConfig().RegistrySources() @@ -876,12 +881,18 @@ func getClusterRegistryConfig(cluster *cmv1.Cluster) string { } } if cluster.RegistryConfig().PlatformAllowlist().ID() != "" { + allowlist, err := client.GetAllowlist(cluster.RegistryConfig().PlatformAllowlist().ID()) + if err != nil { + return output, err + } + output = fmt.Sprintf("%s"+ - " - Platform Allowlist: %s\n", output, - cluster.RegistryConfig().PlatformAllowlist().ID()) + " - Platform Allowlist: %s\n"+ + " - Registries: %s\n", output, + cluster.RegistryConfig().PlatformAllowlist().ID(), strings.Join(allowlist.Registries(), ",")) } - return output + return output, nil } func getExternalAuthConfigStatus(cluster *cmv1.Cluster) string { diff --git a/cmd/describe/cluster/cmd_test.go b/cmd/describe/cluster/cmd_test.go index f46ee0e119..c2eb3e40ea 100644 --- a/cmd/describe/cluster/cmd_test.go +++ b/cmd/describe/cluster/cmd_test.go @@ -9,6 +9,9 @@ import ( . "github.com/onsi/ginkgo/v2/dsl/table" . "github.com/onsi/gomega" cmv1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1" + + "github.com/openshift/rosa/pkg/logging" + "github.com/openshift/rosa/pkg/ocm" ) const ( @@ -137,6 +140,15 @@ var _ = Describe("Cluster description", Ordered, func() { }) var _ = Describe("getClusterRegistryConfig", func() { + var client *ocm.Client + BeforeEach(func() { + // todo this test expects and uses a real ocm client + // disabling the test until we can mock this to run in prow + Skip("disabling test until ocm client is mocked") + c, err := ocm.NewClient().Logger(logging.NewLogger()).Build() + Expect(err).NotTo(HaveOccurred()) + client = c + }) It("Should return expected output", func() { mockCluster, err := cmv1.NewCluster().RegistryConfig(cmv1.NewClusterRegistryConfig(). RegistrySources(cmv1.NewRegistrySources(). @@ -147,7 +159,8 @@ var _ = Describe("getClusterRegistryConfig", func() { DomainName("quay.io").Insecure(true)). PlatformAllowlist(cmv1.NewRegistryAllowlist().ID("test-id"))).Build() Expect(err).NotTo(HaveOccurred()) - output := getClusterRegistryConfig(mockCluster) + output, err := getClusterRegistryConfig(mockCluster, client) + Expect(err).NotTo(HaveOccurred()) expectedOutput := " - Allowed Registries: allow1.com,allow2.com\n" + " - Blocked Registries: block1.com,block2.com\n" + " - Insecure Registries: insecure1.com,insecure2.com\n" + diff --git a/cmd/edit/cluster/cmd.go b/cmd/edit/cluster/cmd.go index 9ede5f6ff3..04a509724a 100644 --- a/cmd/edit/cluster/cmd.go +++ b/cmd/edit/cluster/cmd.go @@ -639,6 +639,12 @@ func run(cmd *cobra.Command, _ []string) { os.Exit(1) } if clusterRegistryConfigArgs != nil { + prompt := "Changing any registry related parameter will trigger a rollout across all machinepools " + + "(all machinepool nodes will be recreated, following pod draining from each node). Do you want to proceed?" + if !confirm.ConfirmRaw(prompt) { + r.Reporter.Warnf("You have not changed any registry configuration -- exiting.") + os.Exit(0) + } allowedRegistries, blockedRegistries, insecureRegistries, additionalTrustedCa, allowedRegistriesForImport := clusterregistryconfig.GetClusterRegistryConfigArgs( clusterRegistryConfigArgs) diff --git a/pkg/clusterregistryconfig/flags.go b/pkg/clusterregistryconfig/flags.go index c3410944c9..ffbe38ff9e 100644 --- a/pkg/clusterregistryconfig/flags.go +++ b/pkg/clusterregistryconfig/flags.go @@ -105,6 +105,10 @@ func GetClusterRegistryConfigOptions(cmd *pflag.FlagSet, result := &ClusterRegistryConfigArgs{} + if args.allowedRegistries != nil && args.blockedRegistries != nil { + return nil, fmt.Errorf("Allowed registries and blocked registries are mutually exclusive fields") + } + result.allowedRegistries = args.allowedRegistries result.insecureRegistries = args.insecureRegistries result.blockedRegistries = args.blockedRegistries @@ -139,6 +143,9 @@ func GetClusterRegistryConfigOptions(cmd *pflag.FlagSet, } enableRegistriesConfig := IsClusterRegistryConfigSetViaCLI(cmd) + if cluster.RegistryConfig() != nil { + enableRegistriesConfig = true + } if !enableRegistriesConfig && interactive.Enabled() && isHostedCP { updateRegistriesConfigValue, err := interactive.GetBool(interactive.Input{ @@ -152,25 +159,30 @@ func GetClusterRegistryConfigOptions(cmd *pflag.FlagSet, } if enableRegistriesConfig && interactive.Enabled() { - allowedRegistriesInputs, err := interactive.GetString(interactive.Input{ - Question: "Allowed Registries", - Help: cmd.Lookup(allowedRegistriesFlag).Usage, - Default: strings.Join(defaultAllowedRegistries, ","), - }) - if err != nil { - return nil, fmt.Errorf("Expected a valid value for allowed registries: %s", err) + // Allowed registries and blocked registries are mutually exclusive + if result.blockedRegistries == nil { + allowedRegistriesInputs, err := interactive.GetString(interactive.Input{ + Question: "Allowed Registries", + Help: cmd.Lookup(allowedRegistriesFlag).Usage, + Default: strings.Join(defaultAllowedRegistries, ","), + }) + if err != nil { + return nil, fmt.Errorf("Expected a comma-separated list of allowed registries: %s", err) + } + result.allowedRegistries = helper.HandleEmptyStringOnSlice(strings.Split(allowedRegistriesInputs, ",")) } - result.allowedRegistries = helper.HandleEmptyStringOnSlice(strings.Split(allowedRegistriesInputs, ",")) - blockedRegistriesInputs, err := interactive.GetString(interactive.Input{ - Question: "Blocked Registries", - Help: cmd.Lookup(blockedRegistriesFlag).Usage, - Default: strings.Join(defaultBlockedRegistries, ","), - }) - if err != nil { - return nil, fmt.Errorf("Expected a valid value for blocked registries: %s", err) + if result.allowedRegistries == nil { + blockedRegistriesInputs, err := interactive.GetString(interactive.Input{ + Question: "Blocked Registries", + Help: cmd.Lookup(blockedRegistriesFlag).Usage, + Default: strings.Join(defaultBlockedRegistries, ","), + }) + if err != nil { + return nil, fmt.Errorf("Expected a comma-separated list of blocked registries: %s", err) + } + result.blockedRegistries = helper.HandleEmptyStringOnSlice(strings.Split(blockedRegistriesInputs, ",")) } - result.blockedRegistries = helper.HandleEmptyStringOnSlice(strings.Split(blockedRegistriesInputs, ",")) insecureRegistriesInputs, err := interactive.GetString(interactive.Input{ Question: "Insecure Registries", @@ -178,7 +190,7 @@ func GetClusterRegistryConfigOptions(cmd *pflag.FlagSet, Default: strings.Join(defaultInsecureRegistries, ","), }) if err != nil { - return nil, fmt.Errorf("Expected a valid value for insecure registries: %s", err) + return nil, fmt.Errorf("Expected a comma-separated list of insecure registries: %s", err) } result.insecureRegistries = helper.HandleEmptyStringOnSlice(strings.Split(insecureRegistriesInputs, ",")) @@ -191,7 +203,7 @@ func GetClusterRegistryConfigOptions(cmd *pflag.FlagSet, }, }) if err != nil { - return nil, fmt.Errorf("Expected a valid value for allowed registries for import: %s", err) + return nil, fmt.Errorf("Expected a comma-separated list of allowed registries for import: %s", err) } result.additionalTrustedCa, err = interactive.GetString(interactive.Input{ diff --git a/pkg/ocm/registry_config.go b/pkg/ocm/registry_config.go index 3caffe2ef0..7661f1c3de 100644 --- a/pkg/ocm/registry_config.go +++ b/pkg/ocm/registry_config.go @@ -98,3 +98,14 @@ func BuildAllowedRegistriesForImport(allowedRegistriesForImport string) (map[str } return obj, nil } + +func (c *Client) GetAllowlist(id string) (*cmv1.RegistryAllowlist, error) { + response, err := c.ocm.ClustersMgmt().V1().RegistryAllowlists(). + RegistryAllowlist(id).Get(). + Send() + if err != nil { + return nil, handleErr(response.Error(), err) + } + + return response.Body(), nil +} diff --git a/pkg/ocm/registry_config_test.go b/pkg/ocm/registry_config_test.go index 487594d65f..7ac4511ca2 100644 --- a/pkg/ocm/registry_config_test.go +++ b/pkg/ocm/registry_config_test.go @@ -1,11 +1,95 @@ package ocm import ( + "bytes" + "net/http" + "time" + . "github.com/onsi/ginkgo/v2/dsl/core" . "github.com/onsi/gomega" + "github.com/onsi/gomega/ghttp" + sdk "github.com/openshift-online/ocm-sdk-go" cmv1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1" + "github.com/openshift-online/ocm-sdk-go/logging" + . "github.com/openshift-online/ocm-sdk-go/testing" ) +var _ = Describe("Allowlist", func() { + + var ssoServer, apiServer *ghttp.Server + var ocmClient *Client + var body string + var allowlist *cmv1.RegistryAllowlist + + BeforeEach(func() { + // Create the servers: + ssoServer = MakeTCPServer() + apiServer = MakeTCPServer() + apiServer.SetAllowUnhandledRequests(true) + apiServer.SetUnhandledRequestStatusCode(http.StatusInternalServerError) + + // Create the token: + accessToken := MakeTokenString("Bearer", 15*time.Minute) + + // Prepare the server: + ssoServer.AppendHandlers( + RespondWithAccessToken(accessToken), + ) + // Prepare the logger: + logger, err := logging.NewGoLoggerBuilder(). + Debug(true). + Build() + Expect(err).To(BeNil()) + // Set up the connection with the fake config + connection, err := sdk.NewConnectionBuilder(). + Logger(logger). + Tokens(accessToken). + URL(apiServer.URL()). + Build() + // Initialize client object + Expect(err).To(BeNil()) + ocmClient = &Client{ocm: connection} + + allowlist, body, err = CreateAllowlist() + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + // Close the servers: + ssoServer.Close() + apiServer.Close() + Expect(ocmClient.Close()).To(Succeed()) + }) + + It("KO: fails to get allowlist if returns error", func() { + apiServer.AppendHandlers( + RespondWithJSON( + http.StatusBadRequest, + body, + ), + ) + + _, err := ocmClient.GetAllowlist("id") + Expect(err).To(HaveOccurred()) + }) + + It("OK: gets allowlist when it exists", func() { + + apiServer.AppendHandlers( + RespondWithJSON( + http.StatusOK, + body, + ), + ) + + output, err := ocmClient.GetAllowlist("allowlist-id") + + Expect(err).To(BeNil()) + Expect(output).To(Not(BeNil())) + Expect(output.ID()).To(Equal(allowlist.ID())) + }) +}) + var _ = Describe("Registry Config", func() { Context("BuildAllowedRegistriesForImport", func() { It("OK: should pass if the user passes a valid string", func() { @@ -61,3 +145,21 @@ var _ = Describe("Registry Config", func() { }) }) }) + +func CreateAllowlist() (*cmv1.RegistryAllowlist, string, error) { + builder := cmv1.NewRegistryAllowlist() + allowlist, err := builder.ID("allowlist-id"). + Registries([]string{"quay.io", "registry.redhat.io"}...).Build() + if err != nil { + return &cmv1.RegistryAllowlist{}, "", err + } + + var buf bytes.Buffer + err = cmv1.MarshalRegistryAllowlist(allowlist, &buf) + + if err != nil { + return &cmv1.RegistryAllowlist{}, "", err + } + + return allowlist, buf.String(), nil +}