diff --git a/upgrade/change_v21.03.0.go b/upgrade/change_v21.03.0.go index e3e81341518..87ebba284bf 100644 --- a/upgrade/change_v21.03.0.go +++ b/upgrade/change_v21.03.0.go @@ -152,7 +152,7 @@ func dropDeprecated(dg *dgo.Dgraph) error { } func upgradePersitentQuery() error { - dg, cb := x.GetDgraphClient(Upgrade.Conf, true) + dg, cb := x.GetDgraphClient(Upgrade.Conf, hasAclCreds()) defer cb() jwt, err := getAccessJwt() @@ -205,7 +205,7 @@ func upgradePersitentQuery() error { } func upgradeCORS() error { - dg, cb := x.GetDgraphClient(Upgrade.Conf, true) + dg, cb := x.GetDgraphClient(Upgrade.Conf, hasAclCreds()) defer cb() jwt, err := getAccessJwt() @@ -215,7 +215,7 @@ func upgradeCORS() error { // Get CORS. corsData := make(map[string][]cors) - if err = getQueryResult(dg, queryCORS_v21_03_0, &corsData); err != nil { + if err := getQueryResult(dg, queryCORS_v21_03_0, &corsData); err != nil { return errors.Wrap(err, "error querying cors") } @@ -239,7 +239,7 @@ func upgradeCORS() error { // Get GraphQL schema. schemaData := make(map[string][]sch) - if err = getQueryResult(dg, querySchema_v21_03_0, &schemaData); err != nil { + if err := getQueryResult(dg, querySchema_v21_03_0, &schemaData); err != nil { return errors.Wrap(err, "error querying graphql schema") } diff --git a/upgrade/upgrade.go b/upgrade/upgrade.go index f20a1b9b3e6..fa53bdfad53 100644 --- a/upgrade/upgrade.go +++ b/upgrade/upgrade.go @@ -41,7 +41,6 @@ var ( type versionComparisonResult uint8 const ( - acl = "acl" dryRun = "dry-run" alpha = "alpha" slashGrpc = "slash_grpc_endpoint" @@ -151,9 +150,9 @@ func init() { }, Annotations: map[string]string{"group": "tool"}, } + Upgrade.EnvPrefix = "DGRAPH_UPGRADE" Upgrade.Cmd.SetHelpTemplate(x.NonRootTemplate) flag := Upgrade.Cmd.Flags() - flag.Bool(acl, false, "upgrade ACL from v1.2.2 to >=v20.03.0") flag.Bool(dryRun, false, "dry-run the upgrade") flag.StringP(alpha, "a", "127.0.0.1:9080", "Comma separated list of Dgraph Alpha gRPC server address") @@ -188,24 +187,11 @@ func run() { } func validateAndParseInput() (*commandInput, error) { - if !Upgrade.Conf.GetBool(acl) { - return nil, formatAsFlagParsingError(acl, - fmt.Errorf("we only support acl upgrade as of now")) - } - _, _, err := net.SplitHostPort(strings.TrimSpace(Upgrade.Conf.GetString(alpha))) if err != nil { return nil, formatAsFlagParsingError(alpha, err) } - if strings.TrimSpace(Upgrade.Conf.GetString(user)) == "" { - return nil, formatAsFlagRequiredError(user) - } - - if strings.TrimSpace(Upgrade.Conf.GetString(password)) == "" { - return nil, formatAsFlagRequiredError(password) - } - fromVersionParsed, err := parseVersionFromString(Upgrade.Conf.GetString(from)) if err != nil { return nil, formatAsFlagParsingError(from, err) diff --git a/upgrade/utils.go b/upgrade/utils.go index d576fa68ad5..5c2fa43ab32 100644 --- a/upgrade/utils.go +++ b/upgrade/utils.go @@ -33,8 +33,15 @@ import ( "github.com/pkg/errors" ) +func hasAclCreds() bool { + return len(Upgrade.Conf.GetString(user)) > 0 +} + // getAccessJwt gets the access jwt token from by logging into the cluster. func getAccessJwt() (*api.Jwt, error) { + if !hasAclCreds() { + return &api.Jwt{}, nil + } user := Upgrade.Conf.GetString(user) password := Upgrade.Conf.GetString(password) header := http.Header{}