Skip to content

Commit

Permalink
OCM-11405 | fix: A few bug fixes for registry config
Browse files Browse the repository at this point in the history
fix fint and fmt

Signed-off-by: Maggie Chen <magchen@redhat.com>

change text

Signed-off-by: Maggie Chen <magchen@redhat.com>
  • Loading branch information
chenz4027 authored and openshift-cherrypick-robot committed Sep 26, 2024
1 parent 81075e6 commit 18a63d9
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 24 deletions.
21 changes: 16 additions & 5 deletions cmd/describe/cluster/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -497,9 +497,14 @@ func run(cmd *cobra.Command, argv []string) {
strings.Join(cluster.AWS().AdditionalAllowedPrincipals(), ","))
}
if cluster.RegistryConfig() != nil {
registryConfigOutput, err := getClusterRegistryConfig(cluster, r.OCMClient)
if err != nil {
r.Reporter.Errorf("Failed to get cluster registry config for cluster '%s': %v", clusterKey, err)
os.Exit(1)
}
str = fmt.Sprintf("%s"+
"Registry Configuration:\n"+
"%s\n", str, getClusterRegistryConfig(cluster))
"%s\n", str, registryConfigOutput)
}
}

Expand Down Expand Up @@ -843,7 +848,7 @@ func getAuditLogForwardingStatus(cluster *cmv1.Cluster) string {
return auditLogForwardingStatus
}

func getClusterRegistryConfig(cluster *cmv1.Cluster) string {
func getClusterRegistryConfig(cluster *cmv1.Cluster, client *ocm.Client) (string, error) {
var output string
if cluster.RegistryConfig().RegistrySources() != nil {
registryResources := cluster.RegistryConfig().RegistrySources()
Expand Down Expand Up @@ -876,12 +881,18 @@ func getClusterRegistryConfig(cluster *cmv1.Cluster) string {
}
}
if cluster.RegistryConfig().PlatformAllowlist().ID() != "" {
allowlist, err := client.GetAllowlist(cluster.RegistryConfig().PlatformAllowlist().ID())
if err != nil {
return output, err
}

output = fmt.Sprintf("%s"+
" - Platform Allowlist: %s\n", output,
cluster.RegistryConfig().PlatformAllowlist().ID())
" - Platform Allowlist: %s\n"+
" - Registries: %s\n", output,
cluster.RegistryConfig().PlatformAllowlist().ID(), strings.Join(allowlist.Registries(), ","))
}

return output
return output, nil
}

func getExternalAuthConfigStatus(cluster *cmv1.Cluster) string {
Expand Down
15 changes: 14 additions & 1 deletion cmd/describe/cluster/cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import (
. "github.com/onsi/ginkgo/v2/dsl/table"
. "github.com/onsi/gomega"
cmv1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1"

"github.com/openshift/rosa/pkg/logging"
"github.com/openshift/rosa/pkg/ocm"
)

const (
Expand Down Expand Up @@ -137,6 +140,15 @@ var _ = Describe("Cluster description", Ordered, func() {
})

var _ = Describe("getClusterRegistryConfig", func() {
var client *ocm.Client
BeforeEach(func() {
// todo this test expects and uses a real ocm client
// disabling the test until we can mock this to run in prow
Skip("disabling test until ocm client is mocked")
c, err := ocm.NewClient().Logger(logging.NewLogger()).Build()
Expect(err).NotTo(HaveOccurred())
client = c
})
It("Should return expected output", func() {
mockCluster, err := cmv1.NewCluster().RegistryConfig(cmv1.NewClusterRegistryConfig().
RegistrySources(cmv1.NewRegistrySources().
Expand All @@ -147,7 +159,8 @@ var _ = Describe("getClusterRegistryConfig", func() {
DomainName("quay.io").Insecure(true)).
PlatformAllowlist(cmv1.NewRegistryAllowlist().ID("test-id"))).Build()
Expect(err).NotTo(HaveOccurred())
output := getClusterRegistryConfig(mockCluster)
output, err := getClusterRegistryConfig(mockCluster, client)
Expect(err).NotTo(HaveOccurred())
expectedOutput := " - Allowed Registries: allow1.com,allow2.com\n" +
" - Blocked Registries: block1.com,block2.com\n" +
" - Insecure Registries: insecure1.com,insecure2.com\n" +
Expand Down
6 changes: 6 additions & 0 deletions cmd/edit/cluster/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,12 @@ func run(cmd *cobra.Command, _ []string) {
os.Exit(1)
}
if clusterRegistryConfigArgs != nil {
prompt := "Changing any registry related parameter will trigger a rollout across all machinepools " +
"(all machinepool nodes will be recreated, following pod draining from each node). Do you want to proceed?"
if !confirm.ConfirmRaw(prompt) {
r.Reporter.Warnf("You have not changed any registry configuration -- exiting.")
os.Exit(0)
}
allowedRegistries, blockedRegistries, insecureRegistries,
additionalTrustedCa, allowedRegistriesForImport := clusterregistryconfig.GetClusterRegistryConfigArgs(
clusterRegistryConfigArgs)
Expand Down
48 changes: 30 additions & 18 deletions pkg/clusterregistryconfig/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ func GetClusterRegistryConfigOptions(cmd *pflag.FlagSet,

result := &ClusterRegistryConfigArgs{}

if args.allowedRegistries != nil && args.blockedRegistries != nil {
return nil, fmt.Errorf("Allowed registries and blocked registries are mutually exclusive fields")
}

result.allowedRegistries = args.allowedRegistries
result.insecureRegistries = args.insecureRegistries
result.blockedRegistries = args.blockedRegistries
Expand Down Expand Up @@ -139,6 +143,9 @@ func GetClusterRegistryConfigOptions(cmd *pflag.FlagSet,
}

enableRegistriesConfig := IsClusterRegistryConfigSetViaCLI(cmd)
if cluster.RegistryConfig() != nil {
enableRegistriesConfig = true
}

if !enableRegistriesConfig && interactive.Enabled() && isHostedCP {
updateRegistriesConfigValue, err := interactive.GetBool(interactive.Input{
Expand All @@ -152,33 +159,38 @@ func GetClusterRegistryConfigOptions(cmd *pflag.FlagSet,
}

if enableRegistriesConfig && interactive.Enabled() {
allowedRegistriesInputs, err := interactive.GetString(interactive.Input{
Question: "Allowed Registries",
Help: cmd.Lookup(allowedRegistriesFlag).Usage,
Default: strings.Join(defaultAllowedRegistries, ","),
})
if err != nil {
return nil, fmt.Errorf("Expected a valid value for allowed registries: %s", err)
// Allowed registries and blocked registries are mutually exclusive
if result.blockedRegistries == nil {
allowedRegistriesInputs, err := interactive.GetString(interactive.Input{
Question: "Allowed Registries",
Help: cmd.Lookup(allowedRegistriesFlag).Usage,
Default: strings.Join(defaultAllowedRegistries, ","),
})
if err != nil {
return nil, fmt.Errorf("Expected a comma-separated list of allowed registries: %s", err)
}
result.allowedRegistries = helper.HandleEmptyStringOnSlice(strings.Split(allowedRegistriesInputs, ","))
}
result.allowedRegistries = helper.HandleEmptyStringOnSlice(strings.Split(allowedRegistriesInputs, ","))

blockedRegistriesInputs, err := interactive.GetString(interactive.Input{
Question: "Blocked Registries",
Help: cmd.Lookup(blockedRegistriesFlag).Usage,
Default: strings.Join(defaultBlockedRegistries, ","),
})
if err != nil {
return nil, fmt.Errorf("Expected a valid value for blocked registries: %s", err)
if result.allowedRegistries == nil {
blockedRegistriesInputs, err := interactive.GetString(interactive.Input{
Question: "Blocked Registries",
Help: cmd.Lookup(blockedRegistriesFlag).Usage,
Default: strings.Join(defaultBlockedRegistries, ","),
})
if err != nil {
return nil, fmt.Errorf("Expected a comma-separated list of blocked registries: %s", err)
}
result.blockedRegistries = helper.HandleEmptyStringOnSlice(strings.Split(blockedRegistriesInputs, ","))
}
result.blockedRegistries = helper.HandleEmptyStringOnSlice(strings.Split(blockedRegistriesInputs, ","))

insecureRegistriesInputs, err := interactive.GetString(interactive.Input{
Question: "Insecure Registries",
Help: cmd.Lookup(insecureRegistriesFlag).Usage,
Default: strings.Join(defaultInsecureRegistries, ","),
})
if err != nil {
return nil, fmt.Errorf("Expected a valid value for insecure registries: %s", err)
return nil, fmt.Errorf("Expected a comma-separated list of insecure registries: %s", err)
}
result.insecureRegistries = helper.HandleEmptyStringOnSlice(strings.Split(insecureRegistriesInputs, ","))

Expand All @@ -191,7 +203,7 @@ func GetClusterRegistryConfigOptions(cmd *pflag.FlagSet,
},
})
if err != nil {
return nil, fmt.Errorf("Expected a valid value for allowed registries for import: %s", err)
return nil, fmt.Errorf("Expected a comma-separated list of allowed registries for import: %s", err)
}

result.additionalTrustedCa, err = interactive.GetString(interactive.Input{
Expand Down
11 changes: 11 additions & 0 deletions pkg/ocm/registry_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,14 @@ func BuildAllowedRegistriesForImport(allowedRegistriesForImport string) (map[str
}
return obj, nil
}

func (c *Client) GetAllowlist(id string) (*cmv1.RegistryAllowlist, error) {
response, err := c.ocm.ClustersMgmt().V1().RegistryAllowlists().
RegistryAllowlist(id).Get().
Send()
if err != nil {
return nil, handleErr(response.Error(), err)
}

return response.Body(), nil
}
102 changes: 102 additions & 0 deletions pkg/ocm/registry_config_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,95 @@
package ocm

import (
"bytes"
"net/http"
"time"

. "github.com/onsi/ginkgo/v2/dsl/core"
. "github.com/onsi/gomega"
"github.com/onsi/gomega/ghttp"
sdk "github.com/openshift-online/ocm-sdk-go"
cmv1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1"
"github.com/openshift-online/ocm-sdk-go/logging"
. "github.com/openshift-online/ocm-sdk-go/testing"
)

var _ = Describe("Allowlist", func() {

var ssoServer, apiServer *ghttp.Server
var ocmClient *Client
var body string
var allowlist *cmv1.RegistryAllowlist

BeforeEach(func() {
// Create the servers:
ssoServer = MakeTCPServer()
apiServer = MakeTCPServer()
apiServer.SetAllowUnhandledRequests(true)
apiServer.SetUnhandledRequestStatusCode(http.StatusInternalServerError)

// Create the token:
accessToken := MakeTokenString("Bearer", 15*time.Minute)

// Prepare the server:
ssoServer.AppendHandlers(
RespondWithAccessToken(accessToken),
)
// Prepare the logger:
logger, err := logging.NewGoLoggerBuilder().
Debug(true).
Build()
Expect(err).To(BeNil())
// Set up the connection with the fake config
connection, err := sdk.NewConnectionBuilder().
Logger(logger).
Tokens(accessToken).
URL(apiServer.URL()).
Build()
// Initialize client object
Expect(err).To(BeNil())
ocmClient = &Client{ocm: connection}

allowlist, body, err = CreateAllowlist()
Expect(err).NotTo(HaveOccurred())
})

AfterEach(func() {
// Close the servers:
ssoServer.Close()
apiServer.Close()
Expect(ocmClient.Close()).To(Succeed())
})

It("KO: fails to get allowlist if returns error", func() {
apiServer.AppendHandlers(
RespondWithJSON(
http.StatusBadRequest,
body,
),
)

_, err := ocmClient.GetAllowlist("id")
Expect(err).To(HaveOccurred())
})

It("OK: gets allowlist when it exists", func() {

apiServer.AppendHandlers(
RespondWithJSON(
http.StatusOK,
body,
),
)

output, err := ocmClient.GetAllowlist("allowlist-id")

Expect(err).To(BeNil())
Expect(output).To(Not(BeNil()))
Expect(output.ID()).To(Equal(allowlist.ID()))
})
})

var _ = Describe("Registry Config", func() {
Context("BuildAllowedRegistriesForImport", func() {
It("OK: should pass if the user passes a valid string", func() {
Expand Down Expand Up @@ -61,3 +145,21 @@ var _ = Describe("Registry Config", func() {
})
})
})

func CreateAllowlist() (*cmv1.RegistryAllowlist, string, error) {
builder := cmv1.NewRegistryAllowlist()
allowlist, err := builder.ID("allowlist-id").
Registries([]string{"quay.io", "registry.redhat.io"}...).Build()
if err != nil {
return &cmv1.RegistryAllowlist{}, "", err
}

var buf bytes.Buffer
err = cmv1.MarshalRegistryAllowlist(allowlist, &buf)

if err != nil {
return &cmv1.RegistryAllowlist{}, "", err
}

return allowlist, buf.String(), nil
}

0 comments on commit 18a63d9

Please sign in to comment.