Skip to content

Commit

Permalink
polish
Browse files Browse the repository at this point in the history
  • Loading branch information
xuanyu66 committed Sep 5, 2024
1 parent ecddff0 commit f12e11c
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 103 deletions.
29 changes: 10 additions & 19 deletions internal/cli/serverless/dataimport/start/azblob.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (

"tidbcloud-cli/internal"
"tidbcloud-cli/internal/flag"
"tidbcloud-cli/internal/telemetry"
"tidbcloud-cli/internal/ui"
"tidbcloud-cli/internal/util"

Expand Down Expand Up @@ -49,7 +48,6 @@ func (o AzBlobOpts) SupportedFileTypes() []string {
func (o AzBlobOpts) Run(cmd *cobra.Command) error {
ctx := cmd.Context()
var fileType, uri, sasToken string
clusterId := o.clusterId
var authType imp.ImportAzureBlobAuthTypeEnum
var format *imp.CSVFormat
d, err := o.h.Client()
Expand All @@ -58,11 +56,6 @@ func (o AzBlobOpts) Run(cmd *cobra.Command) error {
}

if o.interactive {
cmd.Annotations[telemetry.InteractiveMode] = "true"
if !o.h.IOStreams.CanPrompt {
return errors.New("The terminal doesn't support interactive mode, please use non-interactive mode")
}

// interactive mode
authTypes := []interface{}{imp.IMPORTAZUREBLOBAUTHTYPEENUM_SAS_TOKEN}
model, err := ui.InitialSelectModel(authTypes, "Choose the auth type:")
Expand Down Expand Up @@ -123,10 +116,6 @@ func (o AzBlobOpts) Run(cmd *cobra.Command) error {
}
} else {
// non-interactive mode
clusterId, err = cmd.Flags().GetString(flag.ClusterID)
if err != nil {
return errors.Trace(err)
}
fileType, err = cmd.Flags().GetString(flag.FileType)
if err != nil {
return errors.Trace(err)
Expand All @@ -153,29 +142,31 @@ func (o AzBlobOpts) Run(cmd *cobra.Command) error {
authType = imp.IMPORTAZUREBLOBAUTHTYPEENUM_SAS_TOKEN

// optional flags
format, err = getCSVFlagValue(cmd)
if err != nil {
return errors.Trace(err)
if fileType == string(imp.IMPORTFILETYPEENUM_CSV) {
format, err = getCSVFlagValue(cmd)
if err != nil {
return errors.Trace(err)
}
}
}

cmd.Annotations[telemetry.ClusterID] = clusterId

source := imp.NewImportSource(imp.IMPORTSOURCETYPEENUM_AZURE_BLOB)
source.AzureBlob = imp.NewAzureBlobSource(authType, uri)
source.AzureBlob.AuthType = authType
source.AzureBlob.SasToken = &sasToken
options := imp.NewImportOptions(imp.ImportFileTypeEnum(fileType))
options.CsvFormat = format
if fileType == string(imp.IMPORTFILETYPEENUM_CSV) {
options.CsvFormat = format
}
body := imp.NewImportServiceCreateImportBody(*options, *source)

if o.h.IOStreams.CanPrompt {
err := spinnerWaitStartOp(ctx, o.h, d, clusterId, body)
err := spinnerWaitStartOp(ctx, o.h, d, o.clusterId, body)
if err != nil {
return err
}
} else {
err := waitStartOp(ctx, o.h, d, clusterId, body)
err := waitStartOp(ctx, o.h, d, o.clusterId, body)
if err != nil {
return err
}
Expand Down
29 changes: 10 additions & 19 deletions internal/cli/serverless/dataimport/start/gcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (

"tidbcloud-cli/internal"
"tidbcloud-cli/internal/flag"
"tidbcloud-cli/internal/telemetry"
"tidbcloud-cli/internal/ui"
"tidbcloud-cli/internal/util"

Expand Down Expand Up @@ -50,7 +49,6 @@ func (o GCSOpts) SupportedFileTypes() []string {
func (o GCSOpts) Run(cmd *cobra.Command) error {
ctx := cmd.Context()
var fileType, gcsUri, accountKey string
clusterId := o.clusterId
var authType imp.ImportGcsAuthTypeEnum
var format *imp.CSVFormat
d, err := o.h.Client()
Expand All @@ -59,11 +57,6 @@ func (o GCSOpts) Run(cmd *cobra.Command) error {
}

if o.interactive {
cmd.Annotations[telemetry.InteractiveMode] = "true"
if !o.h.IOStreams.CanPrompt {
return errors.New("The terminal doesn't support interactive mode, please use non-interactive mode")
}

// interactive mode
authTypes := []interface{}{imp.IMPORTGCSAUTHTYPEENUM_SERVICE_ACCOUNT_KEY}
model, err := ui.InitialSelectModel(authTypes, "Choose the auth type:")
Expand Down Expand Up @@ -124,10 +117,6 @@ func (o GCSOpts) Run(cmd *cobra.Command) error {
}
} else {
// non-interactive mode
clusterId, err = cmd.Flags().GetString(flag.ClusterID)
if err != nil {
return errors.Trace(err)
}
fileType, err = cmd.Flags().GetString(flag.FileType)
if err != nil {
return errors.Trace(err)
Expand All @@ -153,28 +142,30 @@ func (o GCSOpts) Run(cmd *cobra.Command) error {
authType = imp.IMPORTGCSAUTHTYPEENUM_SERVICE_ACCOUNT_KEY

// optional flags
format, err = getCSVFlagValue(cmd)
if err != nil {
return errors.Trace(err)
if fileType == string(imp.IMPORTFILETYPEENUM_CSV) {
format, err = getCSVFlagValue(cmd)
if err != nil {
return errors.Trace(err)
}
}
}

cmd.Annotations[telemetry.ClusterID] = clusterId

source := imp.NewImportSource(imp.IMPORTSOURCETYPEENUM_GCS)
source.Gcs = imp.NewGCSSource(gcsUri, authType)
source.Gcs.ServiceAccountKey = aws.String(accountKey)
options := imp.NewImportOptions(imp.ImportFileTypeEnum(fileType))
options.CsvFormat = format
if fileType == string(imp.IMPORTFILETYPEENUM_CSV) {
options.CsvFormat = format
}
body := imp.NewImportServiceCreateImportBody(*options, *source)

if o.h.IOStreams.CanPrompt {
err := spinnerWaitStartOp(ctx, o.h, d, clusterId, body)
err := spinnerWaitStartOp(ctx, o.h, d, o.clusterId, body)
if err != nil {
return err
}
} else {
err := waitStartOp(ctx, o.h, d, clusterId, body)
err := waitStartOp(ctx, o.h, d, o.clusterId, body)
if err != nil {
return err
}
Expand Down
36 changes: 16 additions & 20 deletions internal/cli/serverless/dataimport/start/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"tidbcloud-cli/internal/config"
"tidbcloud-cli/internal/flag"
"tidbcloud-cli/internal/service/aws/s3"
"tidbcloud-cli/internal/telemetry"
"tidbcloud-cli/internal/ui"
"tidbcloud-cli/internal/util"

Expand Down Expand Up @@ -64,7 +63,6 @@ func (o LocalOpts) SupportedFileTypes() []string {
func (o LocalOpts) Run(cmd *cobra.Command) error {
ctx := cmd.Context()
var fileType, targetDatabase, targetTable, filePath string
clusterId := o.clusterId
var format *imp.CSVFormat
d, err := o.h.Client()
if err != nil {
Expand All @@ -77,11 +75,6 @@ func (o LocalOpts) Run(cmd *cobra.Command) error {
}

if o.interactive {
cmd.Annotations[telemetry.InteractiveMode] = "true"
if !o.h.IOStreams.CanPrompt {
return errors.New("The terminal doesn't support interactive mode, please use non-interactive mode")
}

// interactive mode
var fileTypes []interface{}
for _, f := range o.SupportedFileTypes() {
Expand Down Expand Up @@ -132,27 +125,28 @@ func (o LocalOpts) Run(cmd *cobra.Command) error {
}
} else {
// non-interactive mode
clusterId = cmd.Flag(flag.ClusterID).Value.String()
fileType = cmd.Flag(flag.FileType).Value.String()
if !slices.Contains(o.SupportedFileTypes(), fileType) {
return fmt.Errorf("file type \"%s\" is not supported, please use one of %q", fileType, o.SupportedFileTypes())
}
targetDatabase = cmd.Flag(flag.LocalTargetDatabase).Value.String()
targetTable = cmd.Flag(flag.LocalTargetTable).Value.String()
f := cmd.Flags().Lookup(flag.LocalFilePath)
if !f.Changed {
filePath, err = cmd.Flags().GetString(flag.LocalFilePath)
if err != nil {
return errors.Trace(err)
}
if filePath == "" {
return errors.New("required flag(s) \"local.file-path\" not set")
}
filePath = f.Value.String()

format, err = getCSVFlagValue(cmd)
if err != nil {
return errors.Trace(err)
if fileType == string(imp.IMPORTFILETYPEENUM_CSV) {
format, err = getCSVFlagValue(cmd)
if err != nil {
return errors.Trace(err)
}
}
}

cmd.Annotations[telemetry.ClusterID] = clusterId

uploadFile, err := os.Open(filePath)
if err != nil {
return err
Expand All @@ -172,7 +166,7 @@ func (o LocalOpts) Run(cmd *cobra.Command) error {
DatabaseName: aws.String(targetDatabase),
TableName: aws.String(targetTable),
ContentLength: aws.Int64(stat.Size()),
ClusterID: clusterId,
ClusterID: o.clusterId,
Body: uploadFile,
}
if o.h.IOStreams.CanPrompt {
Expand All @@ -190,16 +184,18 @@ func (o LocalOpts) Run(cmd *cobra.Command) error {
source := imp.NewImportSource(imp.IMPORTSOURCETYPEENUM_LOCAL)
source.Local = imp.NewLocalSource(uploadID, targetDatabase, targetTable)
options := imp.NewImportOptions(imp.ImportFileTypeEnum(fileType))
options.CsvFormat = format
if fileType == string(imp.IMPORTFILETYPEENUM_CSV) {
options.CsvFormat = format
}
body := imp.NewImportServiceCreateImportBody(*options, *source)

if o.h.IOStreams.CanPrompt {
err := spinnerWaitStartOp(ctx, o.h, d, clusterId, body)
err := spinnerWaitStartOp(ctx, o.h, d, o.clusterId, body)
if err != nil {
return err
}
} else {
err := waitStartOp(ctx, o.h, d, clusterId, body)
err := waitStartOp(ctx, o.h, d, o.clusterId, body)
if err != nil {
return err
}
Expand Down
29 changes: 10 additions & 19 deletions internal/cli/serverless/dataimport/start/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (

"tidbcloud-cli/internal"
"tidbcloud-cli/internal/flag"
"tidbcloud-cli/internal/telemetry"
"tidbcloud-cli/internal/ui"
"tidbcloud-cli/internal/util"
imp "tidbcloud-cli/pkg/tidbcloud/v1beta1/serverless/import"
Expand Down Expand Up @@ -48,7 +47,6 @@ func (o S3Opts) SupportedFileTypes() []string {
func (o S3Opts) Run(cmd *cobra.Command) error {
ctx := cmd.Context()
var fileType, s3Uri, s3Arn, accessKeyID, secretAccessKey string
clusterId := o.clusterId
var authType imp.ImportS3AuthTypeEnum
var format *imp.CSVFormat
d, err := o.h.Client()
Expand All @@ -57,11 +55,6 @@ func (o S3Opts) Run(cmd *cobra.Command) error {
}

if o.interactive {
cmd.Annotations[telemetry.InteractiveMode] = "true"
if !o.h.IOStreams.CanPrompt {
return errors.New("The terminal doesn't support interactive mode, please use non-interactive mode")
}

// interactive mode
authTypes := []interface{}{imp.IMPORTS3AUTHTYPEENUM_ROLE_ARN, imp.IMPORTS3AUTHTYPEENUM_ACCESS_KEY}
model, err := ui.InitialSelectModel(authTypes, "Choose the auth type:")
Expand Down Expand Up @@ -140,10 +133,6 @@ func (o S3Opts) Run(cmd *cobra.Command) error {
}
} else {
// non-interactive mode
clusterId, err = cmd.Flags().GetString(flag.ClusterID)
if err != nil {
return errors.Trace(err)
}
fileType, err = cmd.Flags().GetString(flag.FileType)
if err != nil {
return errors.Trace(err)
Expand All @@ -160,9 +149,11 @@ func (o S3Opts) Run(cmd *cobra.Command) error {
}

// optional flags
format, err = getCSVFlagValue(cmd)
if err != nil {
return errors.Trace(err)
if fileType == string(imp.IMPORTFILETYPEENUM_CSV) {
format, err = getCSVFlagValue(cmd)
if err != nil {
return errors.Trace(err)
}
}
s3Arn, err = cmd.Flags().GetString(flag.S3RoleArn)
if err != nil {
Expand All @@ -185,8 +176,6 @@ func (o S3Opts) Run(cmd *cobra.Command) error {
}
}

cmd.Annotations[telemetry.ClusterID] = clusterId

source := imp.NewImportSource(imp.IMPORTSOURCETYPEENUM_S3)
source.S3 = imp.NewS3Source(s3Uri, authType)
if authType == imp.IMPORTS3AUTHTYPEENUM_ROLE_ARN {
Expand All @@ -200,16 +189,18 @@ func (o S3Opts) Run(cmd *cobra.Command) error {
}
}
options := imp.NewImportOptions(imp.ImportFileTypeEnum(fileType))
options.CsvFormat = format
if fileType == string(imp.IMPORTFILETYPEENUM_CSV) {
options.CsvFormat = format
}
body := imp.NewImportServiceCreateImportBody(*options, *source)

if o.h.IOStreams.CanPrompt {
err := spinnerWaitStartOp(ctx, o.h, d, clusterId, body)
err := spinnerWaitStartOp(ctx, o.h, d, o.clusterId, body)
if err != nil {
return err
}
} else {
err := waitStartOp(ctx, o.h, d, clusterId, body)
err := waitStartOp(ctx, o.h, d, o.clusterId, body)
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit f12e11c

Please sign in to comment.