diff --git a/cmd/apply.go b/cmd/apply.go
deleted file mode 100644
index d8332c5..0000000
--- a/cmd/apply.go
+++ /dev/null
@@ -1,68 +0,0 @@
-package cmd
-
-import (
-	"log"
-	"os"
-	"sync"
-
-	"github.com/liweiyi88/onedump/dump"
-	"github.com/spf13/cobra"
-	"gopkg.in/yaml.v3"
-)
-
-var file string
-
-var applyCmd = &cobra.Command{
-	Use:   "apply -f /path/to/jobs.yaml",
-	Args:  cobra.ExactArgs(0),
-	Short: "Dump database content from different sources to different destinations with a yaml config file.",
-	Run: func(cmd *cobra.Command, args []string) {
-		content, err := os.ReadFile(file)
-		if err != nil {
-			log.Fatalf("failed to read job file from %s, error: %v", file, err)
-		}
-
-		var oneDump dump.Dump
-		err = yaml.Unmarshal(content, &oneDump)
-		if err != nil {
-			log.Fatalf("failed to read job content from %s, error: %v", file, err)
-		}
-
-		err = oneDump.Validate()
-		if err != nil {
-			log.Fatalf("invalid job configuration, error: %v", err)
-		}
-
-		numberOfJobs := len(oneDump.Jobs)
-		if numberOfJobs == 0 {
-			log.Printf("no job is defined in the file %s", file)
-			return
-		}
-
-		resultCh := make(chan *dump.JobResult)
-
-		for _, job := range oneDump.Jobs {
-			go func(job *dump.Job, resultCh chan *dump.JobResult) {
-				resultCh <- job.Run()
-			}(job, resultCh)
-		}
-
-		var wg sync.WaitGroup
-		wg.Add(numberOfJobs)
-		go func(resultCh chan *dump.JobResult) {
-			for result := range resultCh {
-				result.Print()
-				wg.Done()
-			}
-		}(resultCh)
-
-		wg.Wait()
-		close(resultCh)
-	},
-}
-
-func init() {
-	rootCmd.AddCommand(applyCmd)
-	applyCmd.Flags().StringVarP(&file, "file", "f", "", "jobs yaml file path.")
-	applyCmd.MarkFlagRequired("file")
-}
diff --git a/cmd/root.go b/cmd/root.go
index 60a3fda..24ced76 100644
--- a/cmd/root.go
+++ b/cmd/root.go
@@ -4,50 +4,61 @@ import (
 	"fmt"
 	"log"
 	"os"
-	"strings"
+	"sync"
 
 	"github.com/liweiyi88/onedump/dump"
 	"github.com/spf13/cobra"
+	"gopkg.in/yaml.v3"
 )
 
-var (
-	sshHost, sshUser, sshPrivateKeyFile string
-	dbDsn                               string
-	jobName                             string
-	dumpOptions                         []string
-	gzip                                bool
-)
+var file string
 
 var rootCmd = &cobra.Command{
-	Use:   " ",
-	Short: "Dump database content from a source to a destination via cli command.",
-	Args:  cobra.ExactArgs(2),
+	Use:   "-f /path/to/jobs.yaml",
+	Short: "Dump database content from different sources to different destinations with a yaml config file.",
+	Args:  cobra.ExactArgs(0),
 	Run: func(cmd *cobra.Command, args []string) {
-		driver := strings.TrimSpace(args[0])
+		content, err := os.ReadFile(file)
+		if err != nil {
+			log.Fatalf("failed to read job file from %s, error: %v", file, err)
+		}
+
+		var oneDump dump.Dump
+		err = yaml.Unmarshal(content, &oneDump)
+		if err != nil {
+			log.Fatalf("failed to read job content from %s, error: %v", file, err)
+		}
 
-		dumpFile := strings.TrimSpace(args[1])
-		if dumpFile == "" {
-			log.Fatal("you must specify the dump file path. e.g. /download/dump.sql")
+		err = oneDump.Validate()
+		if err != nil {
+			log.Fatalf("invalid job configuration, error: %v", err)
 		}
 
-		name := "dump via cli"
-		if strings.TrimSpace(jobName) != "" {
-			name = jobName
+		numberOfJobs := len(oneDump.Jobs)
+		if numberOfJobs == 0 {
+			log.Printf("no job is defined in the file %s", file)
+			return
+		}
+
+		resultCh := make(chan *dump.JobResult)
+
+		for _, job := range oneDump.Jobs {
+			go func(job *dump.Job, resultCh chan *dump.JobResult) {
+				resultCh <- job.Run()
+			}(job, resultCh)
 		}
 
-		job := dump.NewJob(
-			name,
-			driver,
-			dumpFile,
-			dbDsn,
-			dump.WithGzip(gzip),
-			dump.WithSshHost(sshHost),
-			dump.WithSshUser(sshUser),
-			dump.WithPrivateKeyFile(sshPrivateKeyFile),
-			dump.WithDumpOptions(dumpOptions),
-		)
+		var wg sync.WaitGroup
+		wg.Add(numberOfJobs)
+		go func(resultCh chan *dump.JobResult) {
+			for result := range resultCh {
+				result.Print()
+				wg.Done()
+			}
+		}(resultCh)
 
-		job.Run().Print()
+		wg.Wait()
+		close(resultCh)
 	},
 }
 
@@ -59,12 +70,6 @@ func Execute() {
 }
 
 func init() {
-	rootCmd.Flags().StringVarP(&sshHost, "ssh-host", "", "", "SSH host e.g. yourdomain.com (you can omit port as it uses 22 by default) or 56.09.139.09:33. (required) ")
-	rootCmd.Flags().StringVarP(&sshUser, "ssh-user", "", "root", "SSH username")
-	rootCmd.Flags().StringVarP(&sshPrivateKeyFile, "privatekey", "f", "", "private key file path for SSH connection")
-	rootCmd.Flags().StringArrayVarP(&dumpOptions, "dump-options", "", nil, "use options to overwrite or add new dump command options. e.g. for mysql: --dump-options \"--no-create-info\" --dump-options \"--skip-comments\"")
-	rootCmd.Flags().StringVarP(&dbDsn, "db-dsn", "d", "", "the database dsn for connection. e.g. <user>:<password>@tcp(<host>:<port>)/<dbname>")
-	rootCmd.MarkFlagRequired("db-dsn")
-	rootCmd.Flags().BoolVarP(&gzip, "gzip", "g", true, "if need to gzip the file")
-	rootCmd.Flags().StringVarP(&jobName, "job-name", "", "", "The dump job name")
+	rootCmd.Flags().StringVarP(&file, "file", "f", "", "jobs yaml file path.")
+	rootCmd.MarkFlagRequired("file")
 }
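Review note: `apply` is folded into the root command, so the YAML config becomes the only entry point. A minimal sketch of driving the same code path programmatically; the job name, DSN, and path below are made up for illustration, and the nested `storage` block refers to the new struct introduced in `dump/job.go` later in this diff:

```go
package main

import (
	"log"

	"github.com/liweiyi88/onedump/dump"
	"gopkg.in/yaml.v3"
)

// Illustrative config; values are hypothetical.
const jobsYaml = `
jobs:
  - name: local-dump
    dbdriver: mysql
    dbdsn: root@tcp(127.0.0.1:3306)/dump_test
    gzip: true
    storage:
      local:
        - path: /tmp/dump.sql
`

func main() {
	var oneDump dump.Dump
	if err := yaml.Unmarshal([]byte(jobsYaml), &oneDump); err != nil {
		log.Fatalf("failed to parse config: %v", err)
	}

	if err := oneDump.Validate(); err != nil {
		log.Fatalf("invalid job configuration: %v", err)
	}

	// Sequential here for brevity; root.go fans jobs out to goroutines.
	for _, job := range oneDump.Jobs {
		job.Run().Print()
	}
}
```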
/download/dump.sql") + err = oneDump.Validate() + if err != nil { + log.Fatalf("invalid job configuration, error: %v", err) } - name := "dump via cli" - if strings.TrimSpace(jobName) != "" { - name = jobName + numberOfJobs := len(oneDump.Jobs) + if numberOfJobs == 0 { + log.Printf("no job is defined in the file %s", file) + return + } + + resultCh := make(chan *dump.JobResult) + + for _, job := range oneDump.Jobs { + go func(job *dump.Job, resultCh chan *dump.JobResult) { + resultCh <- job.Run() + }(job, resultCh) } - job := dump.NewJob( - name, - driver, - dumpFile, - dbDsn, - dump.WithGzip(gzip), - dump.WithSshHost(sshHost), - dump.WithSshUser(sshUser), - dump.WithPrivateKeyFile(sshPrivateKeyFile), - dump.WithDumpOptions(dumpOptions), - ) + var wg sync.WaitGroup + wg.Add(numberOfJobs) + go func(resultCh chan *dump.JobResult) { + for result := range resultCh { + result.Print() + wg.Done() + } + }(resultCh) - job.Run().Print() + wg.Wait() + close(resultCh) }, } @@ -59,12 +70,6 @@ func Execute() { } func init() { - rootCmd.Flags().StringVarP(&sshHost, "ssh-host", "", "", "SSH host e.g. yourdomain.com (you can omit port as it uses 22 by default) or 56.09.139.09:33. (required) ") - rootCmd.Flags().StringVarP(&sshUser, "ssh-user", "", "root", "SSH username") - rootCmd.Flags().StringVarP(&sshPrivateKeyFile, "privatekey", "f", "", "private key file path for SSH connection") - rootCmd.Flags().StringArrayVarP(&dumpOptions, "dump-options", "", nil, "use options to overwrite or add new dump command options. e.g. for mysql: --dump-options \"--no-create-info\" --dump-options \"--skip-comments\"") - rootCmd.Flags().StringVarP(&dbDsn, "db-dsn", "d", "", "the database dsn for connection. e.g. :@tcp(:)/") - rootCmd.MarkFlagRequired("db-dsn") - rootCmd.Flags().BoolVarP(&gzip, "gzip", "g", true, "if need to gzip the file") - rootCmd.Flags().StringVarP(&jobName, "job-name", "", "", "The dump job name") + rootCmd.Flags().StringVarP(&file, "file", "f", "", "jobs yaml file path.") + rootCmd.MarkFlagRequired("file") } diff --git a/driver/mysql.go b/driver/mysql.go index ca43a85..b7c3ef4 100644 --- a/driver/mysql.go +++ b/driver/mysql.go @@ -2,6 +2,7 @@ package driver import ( "fmt" + "log" "net" "os" "os/exec" @@ -115,7 +116,12 @@ host = %s` return fileName, fmt.Errorf("failed to create temp folder: %w", err) } - defer file.Close() + defer func() { + err := file.Close() + if err != nil { + log.Printf("failed to close temp file for storing mysql credentials: %v", err) + } + }() _, err = file.WriteString(contents) if err != nil { diff --git a/driver/mysql_test.go b/driver/mysql_test.go index 161a944..f328f31 100644 --- a/driver/mysql_test.go +++ b/driver/mysql_test.go @@ -2,6 +2,8 @@ package driver import ( "os" + "os/exec" + "strings" "testing" "golang.org/x/exp/slices" @@ -113,37 +115,43 @@ host = 127.0.0.1` t.Log("removed temp credential file", fileName) } -// func TestDump(t *testing.T) { -// dsn := "root@tcp(127.0.0.1:3306)/test_local" -// mysql, err := NewMysqlDumper(dsn, nil, false) -// if err != nil { -// t.Fatal(err) -// } - -// dumpfile, err := os.CreateTemp("", "dbdump") -// if err != nil { -// t.Fatal(err) -// } -// defer dumpfile.Close() - -// err = mysql.Dump(dumpfile.Name(), false) -// if err != nil { -// t.Fatal(err) -// } - -// out, err := os.ReadFile(dumpfile.Name()) -// if err != nil { -// t.Fatal("failed to read the test dump file") -// } - -// if len(out) == 0 { -// t.Fatal("test dump file is empty") -// } - -// t.Log("test dump file content size", len(out)) - -// err = 
diff --git a/dump/closer.go b/dump/closer.go
new file mode 100644
index 0000000..a15408b
--- /dev/null
+++ b/dump/closer.go
@@ -0,0 +1,27 @@
+package dump
+
+import (
+	"io"
+
+	"github.com/hashicorp/go-multierror"
+)
+
+type MultiCloser struct {
+	closers []io.Closer
+}
+
+func NewMultiCloser(closers []io.Closer) *MultiCloser {
+	return &MultiCloser{
+		closers: closers,
+	}
+}
+
+func (m *MultiCloser) Close() error {
+	var err error
+	for _, c := range m.closers {
+		if e := c.Close(); e != nil {
+			err = multierror.Append(err, e)
+		}
+	}
+	return err
+}
diff --git a/dump/closer_test.go b/dump/closer_test.go
new file mode 100644
index 0000000..2e0fc37
--- /dev/null
+++ b/dump/closer_test.go
@@ -0,0 +1,46 @@
+package dump
+
+import (
+	"bytes"
+	"fmt"
+	"io"
+	"os"
+	"testing"
+)
+
+type mockCloser struct{}
+
+func (m mockCloser) Close() error {
+	fmt.Print("close")
+
+	return nil
+}
+
+func TestNewMultiCloser(t *testing.T) {
+	closer1 := mockCloser{}
+	closer2 := mockCloser{}
+
+	closers := make([]io.Closer, 0)
+	closers = append(closers, closer1, closer2)
+
+	multiCloser := NewMultiCloser(closers)
+
+	old := os.Stdout
+	r, w, _ := os.Pipe()
+	os.Stdout = w
+
+	multiCloser.Close()
+
+	w.Close()
+	os.Stdout = old
+
+	var buf bytes.Buffer
+	io.Copy(&buf, r)
+
+	expected := "closeclose"
+	actual := buf.String()
+	if expected != actual {
+		t.Errorf("expected: %s, actual: %s", expected, actual)
+	}
+}
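`MultiCloser` gives the fan-out code later in this diff a single handle to close all pipe writers, aggregating failures instead of stopping at the first one. A small usage sketch; the temp files just stand in for arbitrary `io.Closer`s:

```go
package main

import (
	"fmt"
	"io"
	"os"

	"github.com/liweiyi88/onedump/dump"
)

func main() {
	// Two throwaway files standing in for the pipe writers used in job.go.
	f1, _ := os.CreateTemp("", "a")
	f2, _ := os.CreateTemp("", "b")

	closer := dump.NewMultiCloser([]io.Closer{f1, f2})

	// One Close call; any individual failures come back as a single
	// aggregated error rather than short-circuiting.
	if err := closer.Close(); err != nil {
		fmt.Println("close errors:", err)
	}
}
```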
diff --git a/dump/dump.go b/dump/job.go
similarity index 59%
rename from dump/dump.go
rename to dump/job.go
index 712c255..426536c 100644
--- a/dump/dump.go
+++ b/dump/job.go
@@ -5,36 +5,44 @@ import (
 	"compress/gzip"
 	"errors"
 	"fmt"
+	"io"
+	"log"
 	"net"
 	"os"
 	"os/exec"
 	"strings"
+	"sync"
 	"time"
 
+	"github.com/hashicorp/go-multierror"
 	"github.com/liweiyi88/onedump/driver"
 	"github.com/liweiyi88/onedump/storage"
+	"github.com/liweiyi88/onedump/storage/local"
+	"github.com/liweiyi88/onedump/storage/s3"
 	"golang.org/x/crypto/ssh"
 )
 
+var (
+	ErrMissingJobName  = errors.New("job name is required")
+	ErrMissingDBDsn    = errors.New("database dsn is required")
+	ErrMissingDBDriver = errors.New("database driver is required")
+)
+
 type Dump struct {
 	Jobs []*Job `yaml:"jobs"`
 }
 
 func (dump *Dump) Validate() error {
-	errorCollection := make([]string, 0)
+	var errs error
 
 	for _, job := range dump.Jobs {
 		err := job.validate()
 		if err != nil {
-			errorCollection = append(errorCollection, err.Error())
+			errs = multierror.Append(errs, err)
 		}
 	}
 
-	if len(errorCollection) == 0 {
-		return nil
-	}
-
-	return errors.New(strings.Join(errorCollection, ","))
+	return errs
 }
 
 type JobResult struct {
@@ -52,16 +60,18 @@ func (result *JobResult) Print() {
 }
 
 type Job struct {
-	DumpFile       string                  `yaml:"dumpfile"`
-	Name           string                  `yaml:"name"`
-	DBDriver       string                  `yaml:"dbdriver"`
-	DBDsn          string                  `yaml:"dbdsn"`
-	Gzip           bool                    `yaml:"gzip"`
-	SshHost        string                  `yaml:"sshhost"`
-	SshUser        string                  `yaml:"sshuser"`
-	PrivateKeyFile string                  `yaml:"privatekeyfile"`
-	DumpOptions    []string                `yaml:"options"`
-	S3             *storage.AWSCredentials `yaml:"s3"`
+	Name           string   `yaml:"name"`
+	DBDriver       string   `yaml:"dbdriver"`
+	DBDsn          string   `yaml:"dbdsn"`
+	Gzip           bool     `yaml:"gzip"`
+	SshHost        string   `yaml:"sshhost"`
+	SshUser        string   `yaml:"sshuser"`
+	PrivateKeyFile string   `yaml:"privatekeyfile"`
+	DumpOptions    []string `yaml:"options"`
+	Storage        struct {
+		Local []*local.Local `yaml:"local"`
+		S3    []*s3.S3       `yaml:"s3"`
+	} `yaml:"storage"`
 }
 
 type Option func(job *Job)
@@ -84,7 +94,7 @@ func WithGzip(gzip bool) Option {
 	}
 }
 
-func WithDumpOptions(dumpOptions []string) Option {
+func WithDumpOptions(dumpOptions ...string) Option {
 	return func(job *Job) {
 		job.DumpOptions = dumpOptions
 	}
@@ -100,7 +110,6 @@ func NewJob(name, driver, dumpFile, dbDsn string, opts ...Option) *Job {
 	job := &Job{
 		Name:     name,
 		DBDriver: driver,
-		DumpFile: dumpFile,
 		DBDsn:    dbDsn,
 	}
 
@@ -113,19 +122,15 @@ func NewJob(name, driver, dumpFile, dbDsn string, opts ...Option) *Job {
 
 func (job Job) validate() error {
 	if strings.TrimSpace(job.Name) == "" {
-		return errors.New("job name is required")
-	}
-
-	if strings.TrimSpace(job.DumpFile) == "" {
-		return errors.New("dump file path is required")
+		return ErrMissingJobName
 	}
 
 	if strings.TrimSpace(job.DBDsn) == "" {
-		return errors.New("databse dsn is required")
+		return ErrMissingDBDsn
 	}
 
 	if strings.TrimSpace(job.DBDriver) == "" {
-		return errors.New("databse driver is required")
+		return ErrMissingDBDriver
 	}
 
 	return nil
@@ -186,15 +191,19 @@ func (job *Job) sshDump() error {
 		return fmt.Errorf("failed to dial remote server via ssh: %w", err)
 	}
 
-	defer client.Close()
+	defer func() {
+		// No need to call session.Close() here as it would only return an EOF error.
+		if err := client.Close(); err != nil {
+			log.Printf("failed to close ssh client: %v", err)
+		}
+	}()
 
 	session, err := client.NewSession()
 	if err != nil {
 		return fmt.Errorf("failed to start ssh session: %w", err)
 	}
 
-	defer session.Close()
-
 	err = job.dump(session)
 	if err != nil {
 		return err
@@ -241,25 +250,18 @@ func (job *Job) Run() *JobResult {
 	return &result
 }
 
-func (job *Job) dumpToFile(sshSession *ssh.Session, store storage.Storage) error {
-	file, err := store.CreateDumpFile()
-	if err != nil {
-		return fmt.Errorf("failed to create storage dump file: %w", err)
-	}
-
+func (job *Job) writeToFile(sshSession *ssh.Session, file io.Writer) error {
 	var gzipWriter *gzip.Writer
 	if job.Gzip {
 		gzipWriter = gzip.NewWriter(file)
+		defer func() {
+			err := gzipWriter.Close()
+			if err != nil {
+				log.Printf("failed to close gzip writer: %v", err)
+			}
+		}()
 	}
 
-	defer func() {
-		if gzipWriter != nil {
-			gzipWriter.Close()
-		}
-
-		file.Close()
-	}()
-
 	driver, err := job.getDBDriver()
 	if err != nil {
 		return fmt.Errorf("failed to get db driver: %w", err)
@@ -307,60 +309,122 @@ func (job *Job) dumpToFile(sshSession *ssh.Session, store storage.Storage) error
 	return nil
 }
 
+func (job *Job) dumpToCacheFile(sshSession *ssh.Session) (string, error) {
+	dumpFileName := storage.UploadCacheFilePath(job.Gzip)
+
+	file, err := os.Create(dumpFileName)
+	if err != nil {
+		return "", fmt.Errorf("failed to create dump file: %w", err)
+	}
+
+	defer func() {
+		err := file.Close()
+		if err != nil {
+			log.Printf("failed to close dump cache file: %v", err)
+		}
+	}()
+
+	err = job.writeToFile(sshSession, file)
+	if err != nil {
+		return "", fmt.Errorf("failed to write dump content to file: %w", err)
+	}
+
+	return file.Name(), nil
+}
+
 // The core function that dumps db content to a file (locally or remotely).
-// It checks the filename to determine if we need to upload the file to remote storage or keep it locally.
-// For uploading file to S3 bucket, the filename shold follow the pattern: s3://<bucket>/<key>.
-// For any remote upload, we try to cache it in a local dir then upload it to the remote storage.
+// It dumps the db content to a file in a local cache dir first, then saves it to every configured destination.
 func (job *Job) dump(sshSession *ssh.Session) error {
-	store, err := job.createStorage()
+	err := os.MkdirAll(storage.UploadCacheDir(), 0750)
 	if err != nil {
-		return fmt.Errorf("failed to create storage: %w", err)
+		return fmt.Errorf("failed to create upload cache dir for remote upload. %w", err)
 	}
 
-	err = job.dumpToFile(sshSession, store)
-	if err != nil {
-		return err
-	}
+	defer func() {
+		err = os.RemoveAll(storage.UploadCacheDir())
+		if err != nil {
+			log.Println("failed to remove cache dir after dump", err)
+		}
+	}()
 
-	cloudStore, ok := store.(storage.CloudStorage)
-
-	if ok {
-		err := cloudStore.Upload()
-		if err != nil {
-			return fmt.Errorf("failed to upload file to cloud storage: %w", err)
-		}
-	}
+	cacheFile, err := job.dumpToCacheFile(sshSession)
+	if err != nil {
+		return fmt.Errorf("failed to dump content to cache file: %w", err)
+	}
+
+	dumpFile, err := os.Open(cacheFile)
+	if err != nil {
+		return fmt.Errorf("failed to open the cached dump file %w", err)
+	}
+
+	defer func() {
+		err := dumpFile.Close()
+		if err != nil {
+			log.Printf("failed to close dump cache file for saving to destination: %v", err)
+		}
+	}()
 
-	return nil
+	return job.saveToDestinations(dumpFile)
 }
 
-// Factory method to create the storage struct based on filename.
-func (job *Job) createStorage() (storage.Storage, error) {
-	filename := ensureFileSuffix(job.DumpFile, job.Gzip)
-	s3Storage, ok, err := storage.CreateS3Storage(filename, job.S3)
-
-	if err != nil {
-		return nil, err
-	}
-
-	if ok {
-		return s3Storage, nil
-	}
-
-	return &storage.LocalStorage{
-		Filename: filename,
-	}, nil
-}
+func (job *Job) saveToDestinations(cacheFile io.Reader) error {
+	storages := job.getStorages()
+	numberOfStorages := len(storages)
+
+	if numberOfStorages > 0 {
+		readers, writer, closer := storageReadWriteCloser(numberOfStorages)
+
+		go func() {
+			if _, err := io.Copy(writer, cacheFile); err != nil {
+				log.Printf("failed to copy cache file to storage writers: %v", err)
+			}
+			closer.Close()
+		}()
+
+		var wg sync.WaitGroup
+		wg.Add(numberOfStorages)
+		for i, s := range storages {
+			storage := s
+			go func(i int) {
+				defer wg.Done()
+				if err := storage.Save(readers[i], job.Gzip); err != nil {
+					log.Printf("failed to save dump file to destination: %v", err)
+				}
+			}(i)
+		}
+
+		wg.Wait()
+	}
+
+	return nil
+}
+
+func (job *Job) getStorages() []storage.Storage {
+	var storages []storage.Storage
+	if len(job.Storage.Local) > 0 {
+		for _, v := range job.Storage.Local {
+			storages = append(storages, v)
+		}
+	}
 
-// Ensure a file has proper file extension.
-func ensureFileSuffix(filename string, shouldGzip bool) string {
-	if !shouldGzip {
-		return filename
-	}
+	if len(job.Storage.S3) > 0 {
+		for _, v := range job.Storage.S3 {
+			storages = append(storages, v)
+		}
+	}
 
-	if strings.HasSuffix(filename, ".gz") {
-		return filename
-	}
+	return storages
+}
 
-	return filename + ".gz"
+// Pipe readers, writers and a closer for fanning out the same dump stream to multiple storages.
+func storageReadWriteCloser(count int) ([]io.Reader, io.Writer, io.Closer) {
+	var prs []io.Reader
+	var pws []io.Writer
+	var pcs []io.Closer
+	for i := 0; i < count; i++ {
+		pr, pw := io.Pipe()
+
+		prs = append(prs, pr)
+		pws = append(pws, pw)
+		pcs = append(pcs, pw)
+	}
+
+	return prs, io.MultiWriter(pws...), NewMultiCloser(pcs)
 }
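`storageReadWriteCloser` is the heart of the multi-destination design: one `io.Pipe` per destination, an `io.MultiWriter` that duplicates the dump stream into every pipe, and a `MultiCloser` so the producer can signal EOF everywhere at once. A stdlib-only sketch of the same pattern:

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"strings"
	"sync"
)

// One source stream duplicated to N consumers via pipes and MultiWriter,
// mirroring what saveToDestinations does with storage destinations.
func main() {
	source := strings.NewReader("dump content")

	pr1, pw1 := io.Pipe()
	pr2, pw2 := io.Pipe()
	writer := io.MultiWriter(pw1, pw2)

	// Producer: copy the source once, then close both pipe writers so
	// the consumers' io.Copy calls can return.
	go func() {
		io.Copy(writer, source)
		pw1.Close()
		pw2.Close()
	}()

	var wg sync.WaitGroup
	var buf1, buf2 bytes.Buffer

	wg.Add(2)
	go func() { defer wg.Done(); io.Copy(&buf1, pr1) }()
	go func() { defer wg.Done(); io.Copy(&buf2, pr2) }()
	wg.Wait()

	fmt.Println(buf1.String(), "|", buf2.String())
}
```

One caveat of this pattern: pipe writes block until read, so every consumer must drain its reader to completion; a stalled destination stalls them all. Worth keeping in mind for the `Save` implementations.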
diff --git a/dump/dump_test.go b/dump/job_test.go
similarity index 60%
rename from dump/dump_test.go
rename to dump/job_test.go
index 5b360eb..0b8a296 100644
--- a/dump/dump_test.go
+++ b/dump/job_test.go
@@ -5,11 +5,15 @@ import (
 	"crypto/rsa"
 	"crypto/x509"
 	"encoding/pem"
+	"errors"
 	"fmt"
 	"os"
+	"sync"
 	"testing"
 )
 
+var testDBDsn = "root@tcp(127.0.0.1:3306)/dump_test"
+
 func generateTestRSAPrivatePEMFile() (string, error) {
 	tempDir := os.TempDir()
 
@@ -40,23 +44,113 @@ func TestEnsureSSHHostHavePort(t *testing.T) {
 	if ensureHaveSSHPort(sshHost) != sshHost+":22" {
 		t.Error("ssh host port is not ensured")
 	}
+
+	sshHost = "127.0.0.1:22"
+	actual := ensureHaveSSHPort(sshHost)
+	if actual != sshHost {
+		t.Errorf("expected ssh host: %s, actual: %s", sshHost, actual)
+	}
 }
 
-// func TestNewSshDumper(t *testing.T) {
-// 	sshDumper := NewSshDumper("127.0.0.1", "root", "~/.ssh/test.pem")
-
-// 	if sshDumper.Host != "127.0.0.1" {
-// 		t.Errorf("ssh host is unexpected, exepct: %s, actual: %s", "127.0.0.1", sshDumper.Host)
-// 	}
-
-// 	if sshDumper.User != "root" {
-// 		t.Errorf("ssh user is unexpected, exepct: %s, actual: %s", "root", sshDumper.User)
-// 	}
-
-// 	if sshDumper.PrivateKeyFile != "~/.ssh/test.pem" {
-// 		t.Errorf("ssh private key file path is unexpected, exepct: %s, actual: %s", "~/.ssh/test.pem", sshDumper.PrivateKeyFile)
-// 	}
-// }
+func TestGetDBDriver(t *testing.T) {
+	job := NewJob("job1", "mysql", "test.sql", testDBDsn)
+
+	_, err := job.getDBDriver()
+	if err != nil {
+		t.Errorf("expected to get mysql db driver, but got err: %v", err)
+	}
+
+	job = NewJob("job1", "x", "test.sql", testDBDsn)
+	_, err = job.getDBDriver()
+	if err == nil {
+		t.Error("expected unsupported database driver error, but got nil")
+	}
+}
+
+func TestDumpValidate(t *testing.T) {
+	jobs := make([]*Job, 0)
+	job1 := NewJob(
+		"job1",
+		"mysql",
+		"test.sql",
+		testDBDsn,
+		WithGzip(true),
+		WithDumpOptions("--skip-comments"),
+		WithPrivateKeyFile("/privatekey.pem"),
+		WithSshUser("root"),
+		WithSshHost("localhost"),
+	)
+	jobs = append(jobs, job1)
+
+	dump := Dump{Jobs: jobs}
+
+	err := dump.Validate()
+	if err != nil {
+		t.Errorf("expected valid dump config but got err: %v", err)
+	}
+
+	job2 := NewJob("", "mysql", "dump.sql", "")
+	jobs = append(jobs, job2)
+	dump.Jobs = jobs
+	err = dump.Validate()
+
+	if !errors.Is(err, ErrMissingJobName) {
+		t.Errorf("expected err: %v, actual got: %v", ErrMissingJobName, err)
+	}
+
+	job3 := NewJob("job3", "mysql", "dump.sql", "")
+	jobs = append(jobs, job3)
+	dump.Jobs = jobs
+	err = dump.Validate()
+
+	if !errors.Is(err, ErrMissingDBDsn) {
+		t.Errorf("expected err: %v, actual got: %v", ErrMissingDBDsn, err)
+	}
+
+	job4 := NewJob("job3", "", "dump.sql", testDBDsn)
+	jobs = append(jobs, job4)
+	dump.Jobs = jobs
+	err = dump.Validate()
+
+	if !errors.Is(err, ErrMissingDBDriver) {
+		t.Errorf("expected err: %v, actual got: %v", ErrMissingDBDriver, err)
+	}
+}
+
+func TestRun(t *testing.T) {
+	tempDir := os.TempDir()
+	privateKeyFile, err := generateTestRSAPrivatePEMFile()
+	if err != nil {
+		t.Errorf("failed to generate test rsa key pair: %v", err)
+	}
+
+	jobs := make([]*Job, 0, 2)
+	job1 := NewJob("exec-dump", "mysql", tempDir+"/test.sql", testDBDsn)
+	job2 := NewJob("ssh", "mysql", tempDir+"/test.sql", testDBDsn, WithSshHost("127.0.0.1:2022"), WithSshUser("root"), WithPrivateKeyFile(privateKeyFile))
+
+	jobs = append(jobs, job1)
+	jobs = append(jobs, job2)
+	dump := Dump{Jobs: jobs}
+
+	resultCh := make(chan *JobResult)
+	for _, job := range dump.Jobs {
+		go func(job *Job, resultCh chan *JobResult) {
+			resultCh <- job.Run()
+		}(job, resultCh)
+	}
+
+	var wg sync.WaitGroup
+	wg.Add(2)
+	go func(resultCh chan *JobResult) {
+		for result := range resultCh {
+			result.Print()
+			wg.Done()
+		}
+	}(resultCh)
+
+	wg.Wait()
+	close(resultCh)
+}
 
 // func TestSSHDump(t *testing.T) {
 // 	privateKeyFile, err := generateTestRSAPrivatePEMFile()
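`TestDumpValidate` leans on `errors.Is` seeing through the aggregated error, which works because go-multierror v1.1+ implements `Unwrap` as a chain over the collected errors. A tiny self-contained demonstration:

```go
package main

import (
	"errors"
	"fmt"

	"github.com/hashicorp/go-multierror"
)

var (
	errA = errors.New("a failed")
	errB = errors.New("b failed")
)

func main() {
	var errs error
	errs = multierror.Append(errs, errA)
	errs = multierror.Append(errs, errB)

	// Sentinel errors stay discoverable through the aggregate.
	fmt.Println(errors.Is(errs, errA), errors.Is(errs, errB)) // true true
}
```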
diff --git a/go.mod b/go.mod
index 4383688..b0f70b7 100644
--- a/go.mod
+++ b/go.mod
@@ -5,12 +5,14 @@ go 1.19
 require (
 	github.com/aws/aws-sdk-go v1.44.151
 	github.com/go-sql-driver/mysql v1.6.0
+	github.com/hashicorp/go-multierror v1.1.1
 	github.com/spf13/cobra v1.5.0
 	golang.org/x/exp v0.0.0-20220921164117-439092de6870
 	gopkg.in/yaml.v3 v3.0.1
 )
 
 require (
+	github.com/hashicorp/errwrap v1.0.0 // indirect
 	github.com/jmespath/go-jmespath v0.4.0 // indirect
 	golang.org/x/sys v0.1.0 // indirect
 )
diff --git a/go.sum b/go.sum
index 6e889a4..58c8b1a 100644
--- a/go.sum
+++ b/go.sum
@@ -5,6 +5,10 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
 github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
+github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=
+github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
+github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
+github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
 github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
 github.com/inconshreveable/mousetrap v1.0.1 h1:U3uMjPSQEBMNp1lFxmllqCPM6P5u/Xq7Pgzkat/bFNc=
 github.com/inconshreveable/mousetrap v1.0.1/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
diff --git a/storage/local.go b/storage/local.go
deleted file mode 100644
index dcb5869..0000000
--- a/storage/local.go
+++ /dev/null
@@ -1,19 +0,0 @@
-package storage
-
-import (
-	"fmt"
-	"os"
-)
-
-type LocalStorage struct {
-	Filename string
-}
-
-func (local *LocalStorage) CreateDumpFile() (*os.File, error) {
-	file, err := os.Create(local.Filename)
-	if err != nil {
-		return nil, fmt.Errorf("failed to create dump file")
-	}
-
-	return file, err
-}
diff --git a/storage/local/local.go b/storage/local/local.go
new file mode 100644
index 0000000..6bcf82e
--- /dev/null
+++ b/storage/local/local.go
@@ -0,0 +1,37 @@
+package local
+
+import (
+	"fmt"
+	"io"
+	"log"
+	"os"
+
+	"github.com/liweiyi88/onedump/storage"
+)
+
+type Local struct {
+	Path string `yaml:"path"`
+}
+
+func (local *Local) Save(reader io.Reader, gzip bool) error {
+	path := storage.EnsureFileSuffix(local.Path, gzip)
+	file, err := os.Create(path)
+	if err != nil {
+		return fmt.Errorf("failed to create local dump file: %w", err)
+	}
+
+	defer func() {
+		err := file.Close()
+		if err != nil {
+			log.Printf("failed to close local dump file %v", err)
+		}
+	}()
+
+	_, err = io.Copy(file, reader)
+	if err != nil {
+		return fmt.Errorf("failed to copy cache file to the dest file: %w", err)
+	}
+
+	return nil
+}
diff --git a/storage/local/local_test.go b/storage/local/local_test.go
new file mode 100644
index 0000000..5f737b5
--- /dev/null
+++ b/storage/local/local_test.go
@@ -0,0 +1,31 @@
+package local
+
+import (
+	"os"
+	"strings"
+	"testing"
+)
+
+func TestSave(t *testing.T) {
+	filename := os.TempDir() + "/test.sql.gz"
+	local := &Local{Path: filename}
+	defer os.Remove(filename)
+
+	expected := "hello"
+	reader := strings.NewReader(expected)
+
+	err := local.Save(reader, true)
+	if err != nil {
+		t.Errorf("failed to save file: %v", err)
+	}
+
+	data, err := os.ReadFile(filename)
+	if err != nil {
+		t.Errorf("can not read file %s", err)
+	}
+
+	if string(data) != expected {
+		t.Errorf("expected string: %s but actual got: %s", expected, data)
+	}
+}
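One design point worth calling out: the `gzip` flag passed into `Save` only adjusts the destination suffix via `EnsureFileSuffix`; the bytes arriving in `reader` are already compressed upstream by `job.writeToFile`. A sketch that feeds `Local.Save` a gzipped stream the same way (the destination path is illustrative):

```go
package main

import (
	"compress/gzip"
	"io"
	"log"

	"github.com/liweiyi88/onedump/storage/local"
)

func main() {
	pr, pw := io.Pipe()

	// Producer: gzip some bytes into the pipe, mirroring what
	// job.writeToFile does before the storages ever see the stream.
	go func() {
		gz := gzip.NewWriter(pw)
		gz.Write([]byte("SELECT 1;\n"))
		gz.Close()
		pw.Close()
	}()

	// Save streams the already-compressed bytes to disk; the true flag
	// only ensures the ".gz" suffix on the destination path.
	dest := &local.Local{Path: "/tmp/dump.sql"}
	if err := dest.Save(pr, true); err != nil {
		log.Fatal(err)
	}
}
```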
diff --git a/storage/s3.go b/storage/s3.go
deleted file mode 100644
index 32034ca..0000000
--- a/storage/s3.go
+++ /dev/null
@@ -1,127 +0,0 @@
-package storage
-
-import (
-	"fmt"
-	"log"
-	"os"
-	"strings"
-
-	"github.com/aws/aws-sdk-go/aws"
-	"github.com/aws/aws-sdk-go/aws/credentials"
-	"github.com/aws/aws-sdk-go/aws/session"
-	"github.com/aws/aws-sdk-go/service/s3/s3manager"
-)
-
-const s3Prefix = "s3://"
-
-var ErrInvalidS3Path = fmt.Errorf("invalid s3 filename, it should follow the format %s<bucket>/<key>", s3Prefix)
-
-func CreateS3Storage(filename string, credentials *AWSCredentials) (*S3Storage, bool, error) {
-	name := strings.TrimSpace(filename)
-
-	if !strings.HasPrefix(name, s3Prefix) {
-		return nil, false, nil
-	}
-
-	path := strings.TrimPrefix(name, s3Prefix)
-
-	pathChunks := strings.Split(path, "/")
-
-	if len(pathChunks) < 2 {
-		return nil, false, ErrInvalidS3Path
-	}
-
-	bucket := pathChunks[0]
-	s3Filename := pathChunks[len(pathChunks)-1]
-	key := strings.Join(pathChunks[1:], "/")
-
-	cacheDir := uploadCacheDir()
-
-	return &S3Storage{
-		CacheDir:      cacheDir,
-		CacheFile:     s3Filename,
-		CacheFilePath: fmt.Sprintf("%s/%s", cacheDir, s3Filename),
-		Credentials:   credentials,
-		Bucket:        bucket,
-		Key:           key,
-	}, true, nil
-}
-
-type AWSCredentials struct {
-	Region          string `yaml:"region"`
-	AccessKeyId     string `yaml:"access-key-id"`
-	SecretAccessKey string `yaml:"secret-access-key"`
-}
-
-type S3Storage struct {
-	Bucket        string
-	Key           string
-	CacheFile     string
-	CacheDir      string
-	CacheFilePath string
-	Credentials   *AWSCredentials
-}
-
-func (s3 *S3Storage) CreateDumpFile() (*os.File, error) {
-	err := os.MkdirAll(s3.CacheDir, 0750)
-	if err != nil {
-		return nil, fmt.Errorf("failed to create upload cache dir for remote upload. %w", err)
-	}
-
-	file, err := os.Create(s3.CacheFilePath)
-	if err != nil {
-		return nil, fmt.Errorf("failed to create dump file in cache dir. %w", err)
-	}
-
-	return file, err
-}
-
-func (s3 *S3Storage) Upload() error {
-	uploadFile, err := os.Open(s3.CacheFilePath)
-	if err != nil {
-		return fmt.Errorf("failed to open dumped file %w", err)
-	}
-
-	defer func() {
-		uploadFile.Close()
-
-		// Remove local cache dir after uploading to s3 bucket.
-		log.Printf("removing cache dir %s ... ", s3.CacheDir)
-		err = os.RemoveAll(s3.CacheDir)
-		if err != nil {
-			log.Println("failed to remove cache dir after uploading to s3", err)
-		}
-	}()
-
-	var awsConfig aws.Config
-	if s3.Credentials != nil {
-		if s3.Credentials.Region != "" {
-			awsConfig.Region = aws.String(s3.Credentials.Region)
-		}
-
-		if s3.Credentials.AccessKeyId != "" && s3.Credentials.SecretAccessKey != "" {
-			awsConfig.Credentials = credentials.NewStaticCredentials(s3.Credentials.AccessKeyId, s3.Credentials.SecretAccessKey, "")
-		}
-	}
-
-	session := session.Must(session.NewSession(&awsConfig))
-	uploader := s3manager.NewUploader(session)
-
-	log.Printf("uploading file %s to s3...", uploadFile.Name())
-	// TODO: implement re-try
-	_, uploadErr := uploader.Upload(&s3manager.UploadInput{
-		Bucket: aws.String(s3.Bucket),
-		Key:    aws.String(s3.Key),
-		Body:   uploadFile,
-	})
-
-	if uploadErr != nil {
-		return fmt.Errorf("failed to upload file to s3 bucket %w", uploadErr)
-	}
-
-	return nil
-}
-
-func (s3 *S3Storage) CloudFilePath() string {
-	return fmt.Sprintf("s3://%s/%s", s3.Bucket, s3.Key)
-}
diff --git a/storage/s3/s3.go b/storage/s3/s3.go
new file mode 100644
index 0000000..937f111
--- /dev/null
+++ b/storage/s3/s3.go
@@ -0,0 +1,60 @@
+package s3
+
+import (
+	"fmt"
+	"io"
+
+	"github.com/aws/aws-sdk-go/aws"
+	"github.com/aws/aws-sdk-go/aws/credentials"
+	"github.com/aws/aws-sdk-go/aws/session"
+	"github.com/aws/aws-sdk-go/service/s3/s3manager"
+	"github.com/liweiyi88/onedump/storage"
+)
+
+func NewS3(bucket, key, region, accessKeyId, secretAccessKey string) *S3 {
+	return &S3{
+		Bucket:          bucket,
+		Key:             key,
+		Region:          region,
+		AccessKeyId:     accessKeyId,
+		SecretAccessKey: secretAccessKey,
+	}
+}
+
+type S3 struct {
+	Bucket          string
+	Key             string
+	Region          string `yaml:"region"`
+	AccessKeyId     string `yaml:"access-key-id"`
+	SecretAccessKey string `yaml:"secret-access-key"`
+}
+
+func (s3 *S3) Save(reader io.Reader, gzip bool) error {
+	var awsConfig aws.Config
+
+	if s3.Region != "" {
+		awsConfig.Region = aws.String(s3.Region)
+	}
+
+	if s3.AccessKeyId != "" && s3.SecretAccessKey != "" {
+		awsConfig.Credentials = credentials.NewStaticCredentials(s3.AccessKeyId, s3.SecretAccessKey, "")
+	}
+
+	session := session.Must(session.NewSession(&awsConfig))
+	uploader := s3manager.NewUploader(session)
+
+	key := storage.EnsureFileSuffix(s3.Key, gzip)
+
+	// TODO: implement re-try
+	_, uploadErr := uploader.Upload(&s3manager.UploadInput{
+		Bucket: aws.String(s3.Bucket),
+		Key:    aws.String(key),
+		Body:   reader,
+	})
+
+	if uploadErr != nil {
+		return fmt.Errorf("failed to upload file to s3 bucket %w", uploadErr)
+	}
+
+	return nil
+}
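The `// TODO: implement re-try` survives the rewrite. One possible shape for it, sketched here as a hypothetical backoff helper; it is not part of this PR, and the attempt count and delays are arbitrary:

```go
package s3

import (
	"fmt"
	"time"
)

// retry is a sketch for the "implement re-try" TODO: run fn up to
// attempts times, doubling the wait after each failure.
func retry(attempts int, initialDelay time.Duration, fn func() error) error {
	delay := initialDelay

	var err error
	for i := 0; i < attempts; i++ {
		if err = fn(); err == nil {
			return nil
		}

		time.Sleep(delay)
		delay *= 2
	}

	return fmt.Errorf("all %d attempts failed: %w", attempts, err)
}
```

`uploader.Upload` could then be wrapped as `retry(3, time.Second, func() error { _, err := uploader.Upload(input); return err })` — though with a streaming `io.Reader` body, a real retry would also need a way to rewind or re-open the cached dump file between attempts.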
diff --git a/storage/s3/s3_test.go b/storage/s3/s3_test.go
new file mode 100644
index 0000000..ed3dec2
--- /dev/null
+++ b/storage/s3/s3_test.go
@@ -0,0 +1,55 @@
+package s3
+
+import (
+	"errors"
+	"strings"
+	"testing"
+)
+
+func TestNewS3(t *testing.T) {
+	expectedBucket := "onedump"
+	expectedKey := "/backup/dump.sql"
+	expectedRegion := "ap-southeast-2"
+	expectedAccessKeyId := "accessKey"
+	expectedSecretKey := "secret"
+
+	s3 := NewS3(expectedBucket, expectedKey, expectedRegion, expectedAccessKeyId, expectedSecretKey)
+
+	if s3.Bucket != expectedBucket {
+		t.Errorf("expected: %s, actual: %s", expectedBucket, s3.Bucket)
+	}
+
+	if s3.Key != expectedKey {
+		t.Errorf("expected: %s, actual: %s", expectedKey, s3.Key)
+	}
+
+	if s3.Region != expectedRegion {
+		t.Errorf("expected: %s, actual: %s", expectedRegion, s3.Region)
+	}
+
+	if s3.AccessKeyId != expectedAccessKeyId {
+		t.Errorf("expected: %s, actual: %s", expectedAccessKeyId, s3.AccessKeyId)
+	}
+
+	if s3.SecretAccessKey != expectedSecretKey {
+		t.Errorf("expected: %s, actual: %s", expectedSecretKey, s3.SecretAccessKey)
+	}
+}
+
+func TestSave(t *testing.T) {
+	s3 := &S3{
+		Bucket:          "onedump",
+		Key:             "/backup/dump.sql",
+		Region:          "ap-southeast-2",
+		AccessKeyId:     "none",
+		SecretAccessKey: "none",
+	}
+
+	reader := strings.NewReader("hello s3")
+	err := s3.Save(reader, true)
+
+	actual := errors.Unwrap(err).Error()
+	if !strings.HasPrefix(actual, "InvalidAccessKeyId") {
+		t.Errorf("expected invalid access key id error but actually got: %s", actual)
+	}
+}
diff --git a/storage/storage.go b/storage/storage.go
index 78834d1..c97656c 100644
--- a/storage/storage.go
+++ b/storage/storage.go
@@ -2,25 +2,28 @@ package storage
 
 import (
 	"fmt"
+	"io"
 	"log"
+	"math/rand"
 	"os"
+	"strings"
+	"time"
 )
 
 type Storage interface {
-	CreateDumpFile() (*os.File, error)
-}
-
-type CloudStorage interface {
-	Upload() error
-	CloudFilePath() string
+	Save(reader io.Reader, gzip bool) error
 }
 
 const uploadDumpCacheDir = ".onedump"
 
+func init() {
+	rand.Seed(time.Now().UnixNano())
+}
+
 // For uploading dump file to remote storage, we need to firstly dump the db content to a dir locally.
-// We firstly try to get current dir, if not successful, then try to get home dir, if still not successful we finally try temp dir
-// We need to be aware of the size limit of a temp dir in different OS.
-func uploadCacheDir() string {
+// We first try the current working dir; if that fails, we fall back to the home dir and finally the temp dir.
+// Be aware of the size limit of a temp dir in different OS.
+func UploadCacheDir() string {
 	dir, err := os.Getwd()
 	if err != nil {
 		log.Printf("Cannot get the current directory: %v, using $HOME directory!", err)
@@ -33,3 +36,31 @@
 
 	return fmt.Sprintf("%s/%s", dir, uploadDumpCacheDir)
 }
+
+func UploadCacheFilePath(shouldGzip bool) string {
+	filename := fmt.Sprintf("%s/%s", UploadCacheDir(), generateCacheFileName(8)+".sql")
+	return EnsureFileSuffix(filename, shouldGzip)
+}
+
+func generateCacheFileName(n int) string {
+	const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+
+	b := make([]byte, n)
+	for i := range b {
+		b[i] = letterBytes[rand.Intn(len(letterBytes))]
+	}
+	return string(b)
+}
+
+// Ensure a file has the proper file extension.
+func EnsureFileSuffix(filename string, shouldGzip bool) string {
+	if !shouldGzip {
+		return filename
+	}
+
+	if strings.HasSuffix(filename, ".gz") {
+		return filename
+	}
+
+	return filename + ".gz"
+}
diff --git a/storage/storage_test.go b/storage/storage_test.go
new file mode 100644
index 0000000..0c439aa
--- /dev/null
+++ b/storage/storage_test.go
@@ -0,0 +1,62 @@
+package storage
+
+import (
+	"fmt"
+	"os"
+	"strings"
+	"testing"
+)
+
+func TestUploadCacheDir(t *testing.T) {
+	actual := UploadCacheDir()
+
+	workDir, _ := os.Getwd()
+	expected := fmt.Sprintf("%s/%s", workDir, uploadDumpCacheDir)
+
+	if actual != expected {
+		t.Errorf("got unexpected cache dir: expected: %s, actual: %s", expected, actual)
+	}
+}
+
+func TestGenerateCacheFileName(t *testing.T) {
+	expectedLen := 5
+	name := generateCacheFileName(expectedLen)
+
+	actualLen := len([]rune(name))
+	if actualLen != expectedLen {
+		t.Errorf("unexpected cache filename, expected length: %d, actual length: %d", expectedLen, actualLen)
+	}
+}
+
+func TestUploadCacheFilePath(t *testing.T) {
+	gziped := UploadCacheFilePath(true)
+
+	if !strings.HasSuffix(gziped, ".gz") {
+		t.Errorf("expected filename to have .gz extension, actual file name: %s", gziped)
+	}
+
+	sql := UploadCacheFilePath(false)
+
+	if !strings.HasSuffix(sql, ".sql") {
+		t.Errorf("expected filename to have .sql extension, actual file name: %s", sql)
+	}
+
+	sql2 := UploadCacheFilePath(false)
+
+	if sql == sql2 {
+		t.Errorf("expected unique file names but got the same filename %s", sql)
+	}
+}
+
+func TestEnsureFileSuffix(t *testing.T) {
+	gzip := EnsureFileSuffix("test.sql", true)
+	if gzip != "test.sql.gz" {
+		t.Errorf("expected filename to have .gz extension, actual file name: %s", gzip)
+	}
+
+	sql := EnsureFileSuffix("test.sql.gz", true)
+
+	if sql != "test.sql.gz" {
+		t.Errorf("expected: %s to equal actual: %s", "test.sql.gz", sql)
+	}
+}