From 10dd79e390cc0d0c49d38350708e5b1e38a9b50b Mon Sep 17 00:00:00 2001 From: liweiyi88 Date: Fri, 16 Dec 2022 16:34:41 +1100 Subject: [PATCH 1/8] refactoring --- dump/dump.go | 366 --------------------------------------------- dump/dump_test.go | 165 -------------------- storage/local.go | 19 --- storage/s3.go | 83 ++-------- storage/storage.go | 33 +++- 5 files changed, 39 insertions(+), 627 deletions(-) delete mode 100644 dump/dump.go delete mode 100644 dump/dump_test.go delete mode 100644 storage/local.go diff --git a/dump/dump.go b/dump/dump.go deleted file mode 100644 index 712c255..0000000 --- a/dump/dump.go +++ /dev/null @@ -1,366 +0,0 @@ -package dump - -import ( - "bytes" - "compress/gzip" - "errors" - "fmt" - "net" - "os" - "os/exec" - "strings" - "time" - - "github.com/liweiyi88/onedump/driver" - "github.com/liweiyi88/onedump/storage" - "golang.org/x/crypto/ssh" -) - -type Dump struct { - Jobs []*Job `yaml:"jobs"` -} - -func (dump *Dump) Validate() error { - errorCollection := make([]string, 0) - - for _, job := range dump.Jobs { - err := job.validate() - if err != nil { - errorCollection = append(errorCollection, err.Error()) - } - } - - if len(errorCollection) == 0 { - return nil - } - - return errors.New(strings.Join(errorCollection, ",")) -} - -type JobResult struct { - Error error - JobName string - Elapsed time.Duration -} - -func (result *JobResult) Print() { - if result.Error != nil { - fmt.Printf("Job: %s failed, it took %s with error: %v \n", result.JobName, result.Elapsed, result.Error) - } else { - fmt.Printf("Job: %s succeeded, it took %v \n", result.JobName, result.Elapsed) - } -} - -type Job struct { - DumpFile string `yaml:"dumpfile"` - Name string `yaml:"name"` - DBDriver string `yaml:"dbdriver"` - DBDsn string `yaml:"dbdsn"` - Gzip bool `yaml:"gzip"` - SshHost string `yaml:"sshhost"` - SshUser string `yaml:"sshuser"` - PrivateKeyFile string `yaml:"privatekeyfile"` - DumpOptions []string `yaml:"options"` - S3 *storage.AWSCredentials 
`yaml:"s3"` -} - -type Option func(job *Job) - -func WithSshHost(sshHost string) Option { - return func(job *Job) { - job.SshHost = sshHost - } -} - -func WithSshUser(sshUser string) Option { - return func(job *Job) { - job.SshUser = sshUser - } -} - -func WithGzip(gzip bool) Option { - return func(job *Job) { - job.Gzip = gzip - } -} - -func WithDumpOptions(dumpOptions []string) Option { - return func(job *Job) { - job.DumpOptions = dumpOptions - } -} - -func WithPrivateKeyFile(privateKeyFile string) Option { - return func(job *Job) { - job.PrivateKeyFile = privateKeyFile - } -} - -func NewJob(name, driver, dumpFile, dbDsn string, opts ...Option) *Job { - job := &Job{ - Name: name, - DBDriver: driver, - DumpFile: dumpFile, - DBDsn: dbDsn, - } - - for _, opt := range opts { - opt(job) - } - - return job -} - -func (job Job) validate() error { - if strings.TrimSpace(job.Name) == "" { - return errors.New("job name is required") - } - - if strings.TrimSpace(job.DumpFile) == "" { - return errors.New("dump file path is required") - } - - if strings.TrimSpace(job.DBDsn) == "" { - return errors.New("databse dsn is required") - } - - if strings.TrimSpace(job.DBDriver) == "" { - return errors.New("databse driver is required") - } - - return nil -} - -func (job *Job) viaSsh() bool { - if strings.TrimSpace(job.SshHost) != "" && strings.TrimSpace(job.SshUser) != "" && strings.TrimSpace(job.PrivateKeyFile) != "" { - return true - } - - return false -} - -func (job *Job) getDBDriver() (driver.Driver, error) { - switch job.DBDriver { - case "mysql": - driver, err := driver.NewMysqlDriver(job.DBDsn, job.DumpOptions, job.viaSsh()) - if err != nil { - return nil, err - } - - return driver, nil - default: - return nil, fmt.Errorf("%s is not a supported database driver", job.DBDriver) - } -} - -func ensureHaveSSHPort(addr string) string { - if _, _, err := net.SplitHostPort(addr); err != nil { - return net.JoinHostPort(addr, "22") - } - return addr -} - -func (job *Job) sshDump() 
error { - host := ensureHaveSSHPort(job.SshHost) - - pKey, err := os.ReadFile(job.PrivateKeyFile) - if err != nil { - return fmt.Errorf("can not read the private key file :%w", err) - } - - signer, err := ssh.ParsePrivateKey(pKey) - if err != nil { - return fmt.Errorf("failed to create singer :%w", err) - } - - conf := &ssh.ClientConfig{ - User: job.SshUser, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), - Auth: []ssh.AuthMethod{ - ssh.PublicKeys(signer), - }, - } - - client, err := ssh.Dial("tcp", host, conf) - if err != nil { - return fmt.Errorf("failed to dial remote server via ssh: %w", err) - } - - defer client.Close() - - session, err := client.NewSession() - if err != nil { - return fmt.Errorf("failed to start ssh session: %w", err) - } - - defer session.Close() - - err = job.dump(session) - if err != nil { - return err - } - - return nil -} - -func (job *Job) execDump() error { - err := job.dump(nil) - if err != nil { - return fmt.Errorf("failed to exec command dump: %w", err) - } - - return nil -} - -func (job *Job) Run() *JobResult { - start := time.Now() - var result JobResult - - defer func() { - elapsed := time.Since(start) - result.Elapsed = elapsed - }() - - result.JobName = job.Name - - if job.viaSsh() { - err := job.sshDump() - if err != nil { - result.Error = fmt.Errorf("job %s, failed to run ssh dump command: %v", job.Name, err) - } - - return &result - } - - err := job.execDump() - if err != nil { - result.Error = fmt.Errorf("job %s, failed to run dump command: %v", job.Name, err) - - } - - return &result -} - -func (job *Job) dumpToFile(sshSession *ssh.Session, store storage.Storage) error { - file, err := store.CreateDumpFile() - if err != nil { - return fmt.Errorf("failed to create storage dump file: %w", err) - } - - var gzipWriter *gzip.Writer - if job.Gzip { - gzipWriter = gzip.NewWriter(file) - } - - defer func() { - if gzipWriter != nil { - gzipWriter.Close() - } - - file.Close() - }() - - driver, err := job.getDBDriver() - if err != 
nil { - return fmt.Errorf("failed to get db driver: %w", err) - } - - if sshSession != nil { - var remoteErr bytes.Buffer - sshSession.Stderr = &remoteErr - if gzipWriter != nil { - sshSession.Stdout = gzipWriter - } else { - sshSession.Stdout = file - } - - sshCommand, err := driver.GetSshDumpCommand() - if err != nil { - return fmt.Errorf("failed to get ssh dump command %w", err) - } - - if err := sshSession.Run(sshCommand); err != nil { - return fmt.Errorf("remote command error: %s, %v", remoteErr.String(), err) - } - - return nil - } - - command, args, err := driver.GetDumpCommand() - if err != nil { - return fmt.Errorf("job %s failed to get dump command: %v", job.Name, err) - } - - cmd := exec.Command(command, args...) - - cmd.Stderr = os.Stderr - if gzipWriter != nil { - cmd.Stdout = gzipWriter - } else { - cmd.Stdout = file - } - - if err := cmd.Run(); err != nil { - return fmt.Errorf("remote command error: %v", err) - } - - return nil -} - -// The core function that dump db content to a file (locally or remotely). -// It checks the filename to determine if we need to upload the file to remote storage or keep it locally. -// For uploading file to S3 bucket, the filename shold follow the pattern: s3:/// . -// For any remote upload, we try to cache it in a local dir then upload it to the remote storage. -func (job *Job) dump(sshSession *ssh.Session) error { - store, err := job.createStorage() - if err != nil { - return fmt.Errorf("failed to create storage: %w", err) - } - - err = job.dumpToFile(sshSession, store) - if err != nil { - return err - } - - cloudStore, ok := store.(storage.CloudStorage) - - if ok { - err := cloudStore.Upload() - if err != nil { - return fmt.Errorf("failed to upload file to cloud storage: %w", err) - } - } - - return nil -} - -// Factory method to create the storage struct based on filename. 
-func (job *Job) createStorage() (storage.Storage, error) { - filename := ensureFileSuffix(job.DumpFile, job.Gzip) - s3Storage, ok, err := storage.CreateS3Storage(filename, job.S3) - - if err != nil { - return nil, err - } - - if ok { - return s3Storage, nil - } - - return &storage.LocalStorage{ - Filename: filename, - }, nil -} - -// Ensure a file has proper file extension. -func ensureFileSuffix(filename string, shouldGzip bool) string { - if !shouldGzip { - return filename - } - - if strings.HasSuffix(filename, ".gz") { - return filename - } - - return filename + ".gz" -} diff --git a/dump/dump_test.go b/dump/dump_test.go deleted file mode 100644 index 5b360eb..0000000 --- a/dump/dump_test.go +++ /dev/null @@ -1,165 +0,0 @@ -package dump - -import ( - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "encoding/pem" - "fmt" - "os" - "testing" -) - -func generateTestRSAPrivatePEMFile() (string, error) { - tempDir := os.TempDir() - - key, err := rsa.GenerateKey(rand.Reader, 4096) - if err != nil { - return "", fmt.Errorf("could not genereate rsa key pair %w", err) - } - - keyPEM := pem.EncodeToMemory( - &pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(key), - }, - ) - - privatePEMFile := fmt.Sprintf("%s/%s", tempDir, "sshdump_test.rsa") - - if err := os.WriteFile(privatePEMFile, keyPEM, 0700); err != nil { - return "", fmt.Errorf("failed to write private key to file %w", err) - } - - return privatePEMFile, nil -} - -func TestEnsureSSHHostHavePort(t *testing.T) { - sshHost := "127.0.0.1" - - if ensureHaveSSHPort(sshHost) != sshHost+":22" { - t.Error("ssh host port is not ensured") - } -} - -// func TestNewSshDumper(t *testing.T) { -// sshDumper := NewSshDumper("127.0.0.1", "root", "~/.ssh/test.pem") - -// if sshDumper.Host != "127.0.0.1" { -// t.Errorf("ssh host is unexpected, exepct: %s, actual: %s", "127.0.0.1", sshDumper.Host) -// } - -// if sshDumper.User != "root" { -// t.Errorf("ssh user is unexpected, exepct: %s, actual: %s", "root", 
sshDumper.User) -// } - -// if sshDumper.PrivateKeyFile != "~/.ssh/test.pem" { -// t.Errorf("ssh private key file path is unexpected, exepct: %s, actual: %s", "~/.ssh/test.pem", sshDumper.PrivateKeyFile) -// } -// } - -// func TestSSHDump(t *testing.T) { -// privateKeyFile, err := generateTestRSAPrivatePEMFile() - -// dumpFile := os.TempDir() + "sshdump.sql.gz" - -// if err != nil { -// t.Error("failed to generate test rsa key pairs", err) -// } - -// defer func() { -// err := os.Remove(privateKeyFile) -// if err != nil { -// t.Logf("failed to remove private key file %s", privateKeyFile) -// } -// }() - -// go func() { -// sshDumper := NewSshDumper("127.0.0.1:2022", "root", privateKeyFile) -// err := sshDumper.Dump(dumpFile, "echo hello", true) -// if err != nil { -// t.Error("failed to dump file", err) -// } -// }() - -// // An SSH server is represented by a ServerConfig, which holds -// // certificate details and handles authentication of ServerConns. -// config := &ssh.ServerConfig{ -// PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { -// return &ssh.Permissions{ -// // Record the public key used for authentication. -// Extensions: map[string]string{ -// "pubkey-fp": ssh.FingerprintSHA256(pubKey), -// }, -// }, nil -// }, -// } - -// privateBytes, err := os.ReadFile(privateKeyFile) -// if err != nil { -// t.Fatal("Failed to load private key: ", err) -// } - -// private, err := ssh.ParsePrivateKey(privateBytes) -// if err != nil { -// t.Fatal("Failed to parse private key: ", err) -// } - -// config.AddHostKey(private) - -// // Once a ServerConfig has been configured, connections can be -// // accepted. 
-// listener, err := net.Listen("tcp", "0.0.0.0:2022") -// if err != nil { -// t.Fatal("failed to listen for connection: ", err) -// } - -// nConn, err := listener.Accept() -// if err != nil { -// t.Fatal("failed to accept incoming connection: ", err) -// } - -// // Before use, a handshake must be performed on the incoming -// // net.Conn. -// conn, chans, reqs, err := ssh.NewServerConn(nConn, config) -// if err != nil { -// log.Fatal("failed to handshake: ", err) -// } -// t.Logf("logged in with key %s", conn.Permissions.Extensions["pubkey-fp"]) - -// // The incoming Request channel must be serviced. -// go ssh.DiscardRequests(reqs) - -// // Service the incoming Channel channel. -// newChannel := <-chans -// // Channels have a type, depending on the application level -// // protocol intended. In the case of a shell, the type is -// // "session" and ServerShell may be used to present a simple -// // terminal interface. -// if newChannel.ChannelType() != "session" { -// newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") -// t.Fatal("unknown channel type") -// } - -// channel, requests, err := newChannel.Accept() - -// if err != nil { -// log.Fatalf("Could not accept channel: %v", err) -// } - -// req := <-requests -// req.Reply(true, []byte("ssh dump")) -// channel.SendRequest("exit-status", false, []byte{0, 0, 0, 0}) - -// channel.Close() -// conn.Close() - -// if _, err := os.Stat(dumpFile); errors.Is(err, os.ErrNotExist) { -// t.Error("dump file does not existed") -// } else { -// err := os.Remove(dumpFile) -// if err != nil { -// t.Fatal("failed to remove the test dump file", err) -// } -// } -// } diff --git a/storage/local.go b/storage/local.go deleted file mode 100644 index dcb5869..0000000 --- a/storage/local.go +++ /dev/null @@ -1,19 +0,0 @@ -package storage - -import ( - "fmt" - "os" -) - -type LocalStorage struct { - Filename string -} - -func (local *LocalStorage) CreateDumpFile() (*os.File, error) { - file, err := 
os.Create(local.Filename) - if err != nil { - return nil, fmt.Errorf("failed to create dump file") - } - - return file, err -} diff --git a/storage/s3.go b/storage/s3.go index 32034ca..67a8821 100644 --- a/storage/s3.go +++ b/storage/s3.go @@ -2,9 +2,7 @@ package storage import ( "fmt" - "log" - "os" - "strings" + "io" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" @@ -12,39 +10,16 @@ import ( "github.com/aws/aws-sdk-go/service/s3/s3manager" ) -const s3Prefix = "s3://" +const S3Prefix = "s3://" -var ErrInvalidS3Path = fmt.Errorf("invalid s3 filename, it should follow the format %s/", s3Prefix) - -func CreateS3Storage(filename string, credentials *AWSCredentials) (*S3Storage, bool, error) { - name := strings.TrimSpace(filename) - - if !strings.HasPrefix(name, s3Prefix) { - return nil, false, nil - } - - path := strings.TrimPrefix(name, s3Prefix) - - pathChunks := strings.Split(path, "/") - - if len(pathChunks) < 2 { - return nil, false, ErrInvalidS3Path - } - - bucket := pathChunks[0] - s3Filename := pathChunks[len(pathChunks)-1] - key := strings.Join(pathChunks[1:], "/") - - cacheDir := uploadCacheDir() +var ErrInvalidS3Path = fmt.Errorf("invalid s3 filename, it should follow the format %s/", S3Prefix) +func NewS3Storage(bucket, key string, credentials *AWSCredentials) *S3Storage { return &S3Storage{ - CacheDir: cacheDir, - CacheFile: s3Filename, - CacheFilePath: fmt.Sprintf("%s/%s", cacheDir, s3Filename), - Credentials: credentials, - Bucket: bucket, - Key: key, - }, true, nil + Credentials: credentials, + Bucket: bucket, + Key: key, + } } type AWSCredentials struct { @@ -54,43 +29,14 @@ type AWSCredentials struct { } type S3Storage struct { - Bucket string - Key string - CacheFile string - CacheDir string - CacheFilePath string - Credentials *AWSCredentials + Bucket string + Key string + Credentials *AWSCredentials } -func (s3 *S3Storage) CreateDumpFile() (*os.File, error) { - err := os.MkdirAll(s3.CacheDir, 0750) - if err != nil { - 
return nil, fmt.Errorf("failed to create upload cache dir for remote upload. %w", err) - } - - file, err := os.Create(s3.CacheFilePath) - if err != nil { - return nil, fmt.Errorf("failed to create dump file in cache dir. %w", err) - } - - return file, err -} - -func (s3 *S3Storage) Upload() error { - uploadFile, err := os.Open(s3.CacheFilePath) - if err != nil { - return fmt.Errorf("failed to open dumped file %w", err) - } - +func (s3 *S3Storage) Upload(reader io.ReadCloser) error { defer func() { - uploadFile.Close() - - // Remove local cache dir after uploading to s3 bucket. - log.Printf("removing cache dir %s ... ", s3.CacheDir) - err = os.RemoveAll(s3.CacheDir) - if err != nil { - log.Println("failed to remove cache dir after uploading to s3", err) - } + reader.Close() }() var awsConfig aws.Config @@ -107,12 +53,11 @@ func (s3 *S3Storage) Upload() error { session := session.Must(session.NewSession(&awsConfig)) uploader := s3manager.NewUploader(session) - log.Printf("uploading file %s to s3...", uploadFile.Name()) // TODO: implement re-try _, uploadErr := uploader.Upload(&s3manager.UploadInput{ Bucket: aws.String(s3.Bucket), Key: aws.String(s3.Key), - Body: uploadFile, + Body: reader, }) if uploadErr != nil { diff --git a/storage/storage.go b/storage/storage.go index 78834d1..2fe0b95 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -2,25 +2,28 @@ package storage import ( "fmt" + "io" "log" + "math/rand" "os" + "time" ) -type Storage interface { - CreateDumpFile() (*os.File, error) -} - type CloudStorage interface { - Upload() error + Upload(reader io.ReadCloser) error CloudFilePath() string } const uploadDumpCacheDir = ".onedump" +func init() { + rand.Seed(time.Now().UnixNano()) +} + // For uploading dump file to remote storage, we need to firstly dump the db content to a dir locally. 
-// We firstly try to get current dir, if not successful, then try to get home dir, if still not successful we finally try temp dir -// We need to be aware of the size limit of a temp dir in different OS. -func uploadCacheDir() string { +// We firstly try to get current work dir, if not successful, then try to get home dir and finally try temp dir. +// Be aware of the size limit of a temp dir in different OS. +func UploadCacheDir() string { dir, err := os.Getwd() if err != nil { log.Printf("Cannot get the current directory: %v, using $HOME directory!", err) @@ -33,3 +36,17 @@ func uploadCacheDir() string { return fmt.Sprintf("%s/%s", dir, uploadDumpCacheDir) } + +func UploadCacheFilePath() string { + return fmt.Sprintf("%s/%s", UploadCacheDir(), generateCacheFileName(8)+".sql") +} + +func generateCacheFileName(n int) string { + const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + + b := make([]byte, n) + for i := range b { + b[i] = letterBytes[rand.Intn(len(letterBytes))] + } + return string(b) +} From 3279a6424440f0aef165056179662be68c59a33e Mon Sep 17 00:00:00 2001 From: liweiyi88 Date: Fri, 16 Dec 2022 16:36:02 +1100 Subject: [PATCH 2/8] refactoring --- dump/job.go | 390 ++++++++++++++++++++++++++++++++++++++++ dump/job_test.go | 165 +++++++++++++++++ storage/storage_test.go | 18 ++ 3 files changed, 573 insertions(+) create mode 100644 dump/job.go create mode 100644 dump/job_test.go create mode 100644 storage/storage_test.go diff --git a/dump/job.go b/dump/job.go new file mode 100644 index 0000000..a65f7a3 --- /dev/null +++ b/dump/job.go @@ -0,0 +1,390 @@ +package dump + +import ( + "bytes" + "compress/gzip" + "errors" + "fmt" + "io" + "log" + "net" + "os" + "os/exec" + "strings" + "time" + + "github.com/liweiyi88/onedump/driver" + "github.com/liweiyi88/onedump/storage" + "golang.org/x/crypto/ssh" +) + +type Dump struct { + Jobs []*Job `yaml:"jobs"` +} + +func (dump *Dump) Validate() error { + errorCollection := make([]string, 0) + + for 
_, job := range dump.Jobs { + err := job.validate() + if err != nil { + errorCollection = append(errorCollection, err.Error()) + } + } + + if len(errorCollection) == 0 { + return nil + } + + return errors.New(strings.Join(errorCollection, ",")) +} + +type JobResult struct { + Error error + JobName string + Elapsed time.Duration +} + +func (result *JobResult) Print() { + if result.Error != nil { + fmt.Printf("Job: %s failed, it took %s with error: %v \n", result.JobName, result.Elapsed, result.Error) + } else { + fmt.Printf("Job: %s succeeded, it took %v \n", result.JobName, result.Elapsed) + } +} + +type Job struct { + DumpFile string `yaml:"dumpfile"` + Name string `yaml:"name"` + DBDriver string `yaml:"dbdriver"` + DBDsn string `yaml:"dbdsn"` + Gzip bool `yaml:"gzip"` + SshHost string `yaml:"sshhost"` + SshUser string `yaml:"sshuser"` + PrivateKeyFile string `yaml:"privatekeyfile"` + DumpOptions []string `yaml:"options"` + S3 *storage.AWSCredentials `yaml:"s3"` +} + +type Option func(job *Job) + +func WithSshHost(sshHost string) Option { + return func(job *Job) { + job.SshHost = sshHost + } +} + +func WithSshUser(sshUser string) Option { + return func(job *Job) { + job.SshUser = sshUser + } +} + +func WithGzip(gzip bool) Option { + return func(job *Job) { + job.Gzip = gzip + } +} + +func WithDumpOptions(dumpOptions []string) Option { + return func(job *Job) { + job.DumpOptions = dumpOptions + } +} + +func WithPrivateKeyFile(privateKeyFile string) Option { + return func(job *Job) { + job.PrivateKeyFile = privateKeyFile + } +} + +func NewJob(name, driver, dumpFile, dbDsn string, opts ...Option) *Job { + job := &Job{ + Name: name, + DBDriver: driver, + DumpFile: dumpFile, + DBDsn: dbDsn, + } + + for _, opt := range opts { + opt(job) + } + + return job +} + +func (job Job) validate() error { + if strings.TrimSpace(job.Name) == "" { + return errors.New("job name is required") + } + + if strings.TrimSpace(job.DumpFile) == "" { + return errors.New("dump file path is 
required") + } + + if strings.TrimSpace(job.DBDsn) == "" { + return errors.New("databse dsn is required") + } + + if strings.TrimSpace(job.DBDriver) == "" { + return errors.New("databse driver is required") + } + + return nil +} + +func (job *Job) viaSsh() bool { + if strings.TrimSpace(job.SshHost) != "" && strings.TrimSpace(job.SshUser) != "" && strings.TrimSpace(job.PrivateKeyFile) != "" { + return true + } + + return false +} + +func (job *Job) getDBDriver() (driver.Driver, error) { + switch job.DBDriver { + case "mysql": + driver, err := driver.NewMysqlDriver(job.DBDsn, job.DumpOptions, job.viaSsh()) + if err != nil { + return nil, err + } + + return driver, nil + default: + return nil, fmt.Errorf("%s is not a supported database driver", job.DBDriver) + } +} + +func ensureHaveSSHPort(addr string) string { + if _, _, err := net.SplitHostPort(addr); err != nil { + return net.JoinHostPort(addr, "22") + } + return addr +} + +func (job *Job) sshDump() error { + host := ensureHaveSSHPort(job.SshHost) + + pKey, err := os.ReadFile(job.PrivateKeyFile) + if err != nil { + return fmt.Errorf("can not read the private key file :%w", err) + } + + signer, err := ssh.ParsePrivateKey(pKey) + if err != nil { + return fmt.Errorf("failed to create singer :%w", err) + } + + conf := &ssh.ClientConfig{ + User: job.SshUser, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(signer), + }, + } + + client, err := ssh.Dial("tcp", host, conf) + if err != nil { + return fmt.Errorf("failed to dial remote server via ssh: %w", err) + } + + defer client.Close() + + session, err := client.NewSession() + if err != nil { + return fmt.Errorf("failed to start ssh session: %w", err) + } + + defer session.Close() + + err = job.dump(session) + if err != nil { + return err + } + + return nil +} + +func (job *Job) execDump() error { + err := job.dump(nil) + if err != nil { + return fmt.Errorf("failed to exec command dump: %w", err) + } + + return nil +} + +func 
(job *Job) Run() *JobResult { + start := time.Now() + var result JobResult + + defer func() { + elapsed := time.Since(start) + result.Elapsed = elapsed + }() + + result.JobName = job.Name + + if job.viaSsh() { + err := job.sshDump() + if err != nil { + result.Error = fmt.Errorf("job %s, failed to run ssh dump command: %v", job.Name, err) + } + + return &result + } + + err := job.execDump() + if err != nil { + result.Error = fmt.Errorf("job %s, failed to run dump command: %v", job.Name, err) + + } + + return &result +} + +func (job *Job) dumpToFile(sshSession *ssh.Session, file io.WriteCloser) error { + var gzipWriter *gzip.Writer + if job.Gzip { + gzipWriter = gzip.NewWriter(file) + } + + defer func() { + if gzipWriter != nil { + gzipWriter.Close() + } + + file.Close() + }() + + driver, err := job.getDBDriver() + if err != nil { + return fmt.Errorf("failed to get db driver: %w", err) + } + + if sshSession != nil { + var remoteErr bytes.Buffer + sshSession.Stderr = &remoteErr + if gzipWriter != nil { + sshSession.Stdout = gzipWriter + } else { + sshSession.Stdout = file + } + + sshCommand, err := driver.GetSshDumpCommand() + if err != nil { + return fmt.Errorf("failed to get ssh dump command %w", err) + } + + if err := sshSession.Run(sshCommand); err != nil { + return fmt.Errorf("remote command error: %s, %v", remoteErr.String(), err) + } + + return nil + } + + command, args, err := driver.GetDumpCommand() + if err != nil { + return fmt.Errorf("job %s failed to get dump command: %v", job.Name, err) + } + + cmd := exec.Command(command, args...) + + cmd.Stderr = os.Stderr + if gzipWriter != nil { + cmd.Stdout = gzipWriter + } else { + cmd.Stdout = file + } + + if err := cmd.Run(); err != nil { + return fmt.Errorf("remote command error: %v", err) + } + + return nil +} + +// The core function that dump db content to a file (locally or remotely). +// It checks the filename to determine if we need to upload the file to remote storage or keep it locally. 
+// For uploading file to S3 bucket, the filename shold follow the pattern: s3:/// . +// For any remote upload, we try to cache it in a local dir then upload it to the remote storage. +func (job *Job) dump(sshSession *ssh.Session) error { + filename := ensureFileSuffix(job.DumpFile, job.Gzip) + + store := job.createCloudStorage(filename) + if store != nil { + err := os.MkdirAll(storage.UploadCacheDir(), 0750) + if err != nil { + return fmt.Errorf("failed to create upload cache dir for remote upload. %w", err) + } + + defer func() { + err = os.RemoveAll(storage.UploadCacheDir()) + if err != nil { + log.Println("failed to remove cache dir after dump", err) + } + }() + + filename = ensureFileSuffix(storage.UploadCacheFilePath(), job.Gzip) + } + + file, err := os.Create(filename) + if err != nil { + return fmt.Errorf("failed to create dump file: %w", err) + } + + err = job.dumpToFile(sshSession, file) + if err != nil { + return err + } + + if store != nil { + uploadFile, err := os.Open(file.Name()) + if err != nil { + return fmt.Errorf("failed to open the cached dump file %w", err) + } + + err = store.Upload(uploadFile) + if err != nil { + return fmt.Errorf("failed to upload file to cloud storage: %w", err) + } + + log.Printf("successfully upload dump file to %s", store.CloudFilePath()) + } + + return nil +} + +// Factory method to create the cloud storage struct based on filename. +func (job *Job) createCloudStorage(filename string) storage.CloudStorage { + name := strings.TrimSpace(filename) + + if strings.HasPrefix(name, storage.S3Prefix) { + path := strings.TrimPrefix(name, storage.S3Prefix) + pathChunks := strings.Split(path, "/") + + if len(pathChunks) < 2 { + panic(storage.ErrInvalidS3Path) + } + + bucket := pathChunks[0] + key := strings.Join(pathChunks[1:], "/") + + return storage.NewS3Storage(bucket, key, job.S3) + } + + return nil +} + +// Ensure a file has proper file extension. 
+func ensureFileSuffix(filename string, shouldGzip bool) string { + if !shouldGzip { + return filename + } + + if strings.HasSuffix(filename, ".gz") { + return filename + } + + return filename + ".gz" +} diff --git a/dump/job_test.go b/dump/job_test.go new file mode 100644 index 0000000..5b360eb --- /dev/null +++ b/dump/job_test.go @@ -0,0 +1,165 @@ +package dump + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" + "os" + "testing" +) + +func generateTestRSAPrivatePEMFile() (string, error) { + tempDir := os.TempDir() + + key, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + return "", fmt.Errorf("could not genereate rsa key pair %w", err) + } + + keyPEM := pem.EncodeToMemory( + &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + }, + ) + + privatePEMFile := fmt.Sprintf("%s/%s", tempDir, "sshdump_test.rsa") + + if err := os.WriteFile(privatePEMFile, keyPEM, 0700); err != nil { + return "", fmt.Errorf("failed to write private key to file %w", err) + } + + return privatePEMFile, nil +} + +func TestEnsureSSHHostHavePort(t *testing.T) { + sshHost := "127.0.0.1" + + if ensureHaveSSHPort(sshHost) != sshHost+":22" { + t.Error("ssh host port is not ensured") + } +} + +// func TestNewSshDumper(t *testing.T) { +// sshDumper := NewSshDumper("127.0.0.1", "root", "~/.ssh/test.pem") + +// if sshDumper.Host != "127.0.0.1" { +// t.Errorf("ssh host is unexpected, exepct: %s, actual: %s", "127.0.0.1", sshDumper.Host) +// } + +// if sshDumper.User != "root" { +// t.Errorf("ssh user is unexpected, exepct: %s, actual: %s", "root", sshDumper.User) +// } + +// if sshDumper.PrivateKeyFile != "~/.ssh/test.pem" { +// t.Errorf("ssh private key file path is unexpected, exepct: %s, actual: %s", "~/.ssh/test.pem", sshDumper.PrivateKeyFile) +// } +// } + +// func TestSSHDump(t *testing.T) { +// privateKeyFile, err := generateTestRSAPrivatePEMFile() + +// dumpFile := os.TempDir() + "sshdump.sql.gz" + +// if err != nil { 
+// t.Error("failed to generate test rsa key pairs", err) +// } + +// defer func() { +// err := os.Remove(privateKeyFile) +// if err != nil { +// t.Logf("failed to remove private key file %s", privateKeyFile) +// } +// }() + +// go func() { +// sshDumper := NewSshDumper("127.0.0.1:2022", "root", privateKeyFile) +// err := sshDumper.Dump(dumpFile, "echo hello", true) +// if err != nil { +// t.Error("failed to dump file", err) +// } +// }() + +// // An SSH server is represented by a ServerConfig, which holds +// // certificate details and handles authentication of ServerConns. +// config := &ssh.ServerConfig{ +// PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { +// return &ssh.Permissions{ +// // Record the public key used for authentication. +// Extensions: map[string]string{ +// "pubkey-fp": ssh.FingerprintSHA256(pubKey), +// }, +// }, nil +// }, +// } + +// privateBytes, err := os.ReadFile(privateKeyFile) +// if err != nil { +// t.Fatal("Failed to load private key: ", err) +// } + +// private, err := ssh.ParsePrivateKey(privateBytes) +// if err != nil { +// t.Fatal("Failed to parse private key: ", err) +// } + +// config.AddHostKey(private) + +// // Once a ServerConfig has been configured, connections can be +// // accepted. +// listener, err := net.Listen("tcp", "0.0.0.0:2022") +// if err != nil { +// t.Fatal("failed to listen for connection: ", err) +// } + +// nConn, err := listener.Accept() +// if err != nil { +// t.Fatal("failed to accept incoming connection: ", err) +// } + +// // Before use, a handshake must be performed on the incoming +// // net.Conn. +// conn, chans, reqs, err := ssh.NewServerConn(nConn, config) +// if err != nil { +// log.Fatal("failed to handshake: ", err) +// } +// t.Logf("logged in with key %s", conn.Permissions.Extensions["pubkey-fp"]) + +// // The incoming Request channel must be serviced. +// go ssh.DiscardRequests(reqs) + +// // Service the incoming Channel channel. 
+// newChannel := <-chans +// // Channels have a type, depending on the application level +// // protocol intended. In the case of a shell, the type is +// // "session" and ServerShell may be used to present a simple +// // terminal interface. +// if newChannel.ChannelType() != "session" { +// newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") +// t.Fatal("unknown channel type") +// } + +// channel, requests, err := newChannel.Accept() + +// if err != nil { +// log.Fatalf("Could not accept channel: %v", err) +// } + +// req := <-requests +// req.Reply(true, []byte("ssh dump")) +// channel.SendRequest("exit-status", false, []byte{0, 0, 0, 0}) + +// channel.Close() +// conn.Close() + +// if _, err := os.Stat(dumpFile); errors.Is(err, os.ErrNotExist) { +// t.Error("dump file does not existed") +// } else { +// err := os.Remove(dumpFile) +// if err != nil { +// t.Fatal("failed to remove the test dump file", err) +// } +// } +// } diff --git a/storage/storage_test.go b/storage/storage_test.go new file mode 100644 index 0000000..6d9c8e6 --- /dev/null +++ b/storage/storage_test.go @@ -0,0 +1,18 @@ +package storage + +import ( + "fmt" + "os" + "testing" +) + +func TestUploadCacheDir(t *testing.T) { + actual := UploadCacheDir() + + workDir, _ := os.Getwd() + expected := fmt.Sprintf("%s/%s", workDir, uploadDumpCacheDir) + + if actual != expected { + t.Errorf("get unexpected cache dir: expected: %s, actual: %s", expected, actual) + } +} From e922b582229efe7bb8d391ef0fe8f7a1fe39939e Mon Sep 17 00:00:00 2001 From: liweiyi88 Date: Fri, 16 Dec 2022 21:13:44 +1100 Subject: [PATCH 3/8] refactoring --- dump/job.go | 45 +++++++++++++++++++++++++++------------------ storage/s3.go | 6 +----- storage/storage.go | 2 +- 3 files changed, 29 insertions(+), 24 deletions(-) diff --git a/dump/job.go b/dump/job.go index a65f7a3..77d4bd2 100644 --- a/dump/job.go +++ b/dump/job.go @@ -243,20 +243,13 @@ func (job *Job) Run() *JobResult { return &result } -func (job *Job) 
dumpToFile(sshSession *ssh.Session, file io.WriteCloser) error { +func (job *Job) writeToFile(sshSession *ssh.Session, file io.Writer) error { var gzipWriter *gzip.Writer if job.Gzip { gzipWriter = gzip.NewWriter(file) + defer gzipWriter.Close() } - defer func() { - if gzipWriter != nil { - gzipWriter.Close() - } - - file.Close() - }() - driver, err := job.getDBDriver() if err != nil { return fmt.Errorf("failed to get db driver: %w", err) @@ -304,12 +297,30 @@ func (job *Job) dumpToFile(sshSession *ssh.Session, file io.WriteCloser) error { return nil } +func (job *Job) dumpToFile(filename string, sshSession *ssh.Session) (string, error) { + dumpFileName := ensureFileSuffix(filename, job.Gzip) + + file, err := os.Create(dumpFileName) + if err != nil { + return "", fmt.Errorf("failed to create dump file: %w", err) + } + + defer file.Close() + + err = job.writeToFile(sshSession, file) + if err != nil { + return "", fmt.Errorf("failed to write dump content to file: %w,", err) + } + + return file.Name(), nil +} + // The core function that dump db content to a file (locally or remotely). // It checks the filename to determine if we need to upload the file to remote storage or keep it locally. // For uploading file to S3 bucket, the filename shold follow the pattern: s3:/// . // For any remote upload, we try to cache it in a local dir then upload it to the remote storage. 
func (job *Job) dump(sshSession *ssh.Session) error { - filename := ensureFileSuffix(job.DumpFile, job.Gzip) + filename := job.DumpFile store := job.createCloudStorage(filename) if store != nil { @@ -325,25 +336,23 @@ func (job *Job) dump(sshSession *ssh.Session) error { } }() - filename = ensureFileSuffix(storage.UploadCacheFilePath(), job.Gzip) + filename = storage.UploadCacheFilePath() } - file, err := os.Create(filename) - if err != nil { - return fmt.Errorf("failed to create dump file: %w", err) - } + dumpFile, err := job.dumpToFile(filename, sshSession) - err = job.dumpToFile(sshSession, file) if err != nil { - return err + return fmt.Errorf("failed to dump db content to file %w: ", err) } if store != nil { - uploadFile, err := os.Open(file.Name()) + uploadFile, err := os.Open(dumpFile) if err != nil { return fmt.Errorf("failed to open the cached dump file %w", err) } + defer uploadFile.Close() + err = store.Upload(uploadFile) if err != nil { return fmt.Errorf("failed to upload file to cloud storage: %w", err) diff --git a/storage/s3.go b/storage/s3.go index 67a8821..789f849 100644 --- a/storage/s3.go +++ b/storage/s3.go @@ -34,11 +34,7 @@ type S3Storage struct { Credentials *AWSCredentials } -func (s3 *S3Storage) Upload(reader io.ReadCloser) error { - defer func() { - reader.Close() - }() - +func (s3 *S3Storage) Upload(reader io.Reader) error { var awsConfig aws.Config if s3.Credentials != nil { if s3.Credentials.Region != "" { diff --git a/storage/storage.go b/storage/storage.go index 2fe0b95..e1ad3e6 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -10,7 +10,7 @@ import ( ) type CloudStorage interface { - Upload(reader io.ReadCloser) error + Upload(reader io.Reader) error CloudFilePath() string } From fa3a5887c03016ba3912a6112dcfe1a6aa30be59 Mon Sep 17 00:00:00 2001 From: liweiyi88 Date: Sat, 17 Dec 2022 20:47:20 +1100 Subject: [PATCH 4/8] refactoring --- dump/job.go | 141 ++++++++++++++++++++++++--------------------- go.mod | 2 + go.sum | 4 ++ 
storage/s3.go | 68 ---------------------- storage/storage.go | 19 +++++- 5 files changed, 97 insertions(+), 137 deletions(-) delete mode 100644 storage/s3.go diff --git a/dump/job.go b/dump/job.go index 77d4bd2..f04dfe1 100644 --- a/dump/job.go +++ b/dump/job.go @@ -11,10 +11,13 @@ import ( "os" "os/exec" "strings" + "sync" "time" "github.com/liweiyi88/onedump/driver" "github.com/liweiyi88/onedump/storage" + "github.com/liweiyi88/onedump/storage/local" + "github.com/liweiyi88/onedump/storage/s3" "golang.org/x/crypto/ssh" ) @@ -54,16 +57,18 @@ func (result *JobResult) Print() { } type Job struct { - DumpFile string `yaml:"dumpfile"` - Name string `yaml:"name"` - DBDriver string `yaml:"dbdriver"` - DBDsn string `yaml:"dbdsn"` - Gzip bool `yaml:"gzip"` - SshHost string `yaml:"sshhost"` - SshUser string `yaml:"sshuser"` - PrivateKeyFile string `yaml:"privatekeyfile"` - DumpOptions []string `yaml:"options"` - S3 *storage.AWSCredentials `yaml:"s3"` + Name string `yaml:"name"` + DBDriver string `yaml:"dbdriver"` + DBDsn string `yaml:"dbdsn"` + Gzip bool `yaml:"gzip"` + SshHost string `yaml:"sshhost"` + SshUser string `yaml:"sshuser"` + PrivateKeyFile string `yaml:"privatekeyfile"` + DumpOptions []string `yaml:"options"` + Storage struct { + Local []*local.Local `yaml:"local"` + S3 []*s3.S3 `yaml:"s3"` + } `yaml:"storage"` } type Option func(job *Job) @@ -102,7 +107,6 @@ func NewJob(name, driver, dumpFile, dbDsn string, opts ...Option) *Job { job := &Job{ Name: name, DBDriver: driver, - DumpFile: dumpFile, DBDsn: dbDsn, } @@ -118,10 +122,6 @@ func (job Job) validate() error { return errors.New("job name is required") } - if strings.TrimSpace(job.DumpFile) == "" { - return errors.New("dump file path is required") - } - if strings.TrimSpace(job.DBDsn) == "" { return errors.New("databse dsn is required") } @@ -297,8 +297,8 @@ func (job *Job) writeToFile(sshSession *ssh.Session, file io.Writer) error { return nil } -func (job *Job) dumpToFile(filename string, sshSession 
*ssh.Session) (string, error) { - dumpFileName := ensureFileSuffix(filename, job.Gzip) +func (job *Job) dumpToCacheFile(sshSession *ssh.Session) (string, error) { + dumpFileName := storage.EnsureFileSuffix(storage.UploadCacheFilePath(), job.Gzip) file, err := os.Create(dumpFileName) if err != nil { @@ -320,80 +320,89 @@ func (job *Job) dumpToFile(filename string, sshSession *ssh.Session) (string, er // For uploading file to S3 bucket, the filename shold follow the pattern: s3:/// . // For any remote upload, we try to cache it in a local dir then upload it to the remote storage. func (job *Job) dump(sshSession *ssh.Session) error { - filename := job.DumpFile + err := os.MkdirAll(storage.UploadCacheDir(), 0750) + if err != nil { + return fmt.Errorf("failed to create upload cache dir for remote upload. %w", err) + } - store := job.createCloudStorage(filename) - if store != nil { - err := os.MkdirAll(storage.UploadCacheDir(), 0750) + defer func() { + err = os.RemoveAll(storage.UploadCacheDir()) if err != nil { - return fmt.Errorf("failed to create upload cache dir for remote upload. 
%w", err) + log.Println("failed to remove cache dir after dump", err) } + }() - defer func() { - err = os.RemoveAll(storage.UploadCacheDir()) - if err != nil { - log.Println("failed to remove cache dir after dump", err) - } - }() + cacheFile, err := job.dumpToCacheFile(sshSession) - filename = storage.UploadCacheFilePath() + dumpFile, err := os.Open(cacheFile) + if err != nil { + return fmt.Errorf("failed to open the cached dump file %w", err) } - dumpFile, err := job.dumpToFile(filename, sshSession) + defer dumpFile.Close() - if err != nil { - return fmt.Errorf("failed to dump db content to file %w: ", err) - } + job.dumpToDestinations(dumpFile) - if store != nil { - uploadFile, err := os.Open(dumpFile) - if err != nil { - return fmt.Errorf("failed to open the cached dump file %w", err) - } + return nil +} - defer uploadFile.Close() +func (job *Job) dumpToDestinations(cacheFile io.Reader) error { + storages := job.getStorages() + numberOfStorages := len(storages) - err = store.Upload(uploadFile) - if err != nil { - return fmt.Errorf("failed to upload file to cloud storage: %w", err) + if numberOfStorages > 0 { + readers, writer, closer := storageReadWriteCloser(numberOfStorages) + + go func() { + io.Copy(writer, cacheFile) + closer.Close() + }() + + var wg sync.WaitGroup + wg.Add(numberOfStorages) + for i, s := range storages { + storage := s + go func(i int) { + defer wg.Done() + storage.Save(readers[i], job.Gzip) + }(i) } - log.Printf("successfully upload dump file to %s", store.CloudFilePath()) + wg.Wait() } return nil } -// Factory method to create the cloud storage struct based on filename. 
-func (job *Job) createCloudStorage(filename string) storage.CloudStorage { - name := strings.TrimSpace(filename) - - if strings.HasPrefix(name, storage.S3Prefix) { - path := strings.TrimPrefix(name, storage.S3Prefix) - pathChunks := strings.Split(path, "/") - - if len(pathChunks) < 2 { - panic(storage.ErrInvalidS3Path) +func (job *Job) getStorages() []storage.Storage { + var storages []storage.Storage + if len(job.Storage.Local) > 0 { + for _, v := range job.Storage.Local { + storages = append(storages, v) } + } - bucket := pathChunks[0] - key := strings.Join(pathChunks[1:], "/") - - return storage.NewS3Storage(bucket, key, job.S3) + if len(job.Storage.S3) > 0 { + for _, v := range job.Storage.S3 { + storages = append(storages, v) + } } - return nil + return storages } -// Ensure a file has proper file extension. -func ensureFileSuffix(filename string, shouldGzip bool) string { - if !shouldGzip { - return filename - } +// Pipe readers, writers and closer for fanout the same os.file +func storageReadWriteCloser(count int) ([]io.Reader, io.Writer, io.Closer) { + var prs []io.Reader + var pws []io.Writer + var pcs []io.Closer + for i := 0; i < count; i++ { + pr, pw := io.Pipe() - if strings.HasSuffix(filename, ".gz") { - return filename + prs = append(prs, pr) + pws = append(pws, pw) + pcs = append(pcs, pw) } - return filename + ".gz" + return prs, io.MultiWriter(pws...), NewMultiCloser(pcs) } diff --git a/go.mod b/go.mod index 4383688..b0f70b7 100644 --- a/go.mod +++ b/go.mod @@ -5,12 +5,14 @@ go 1.19 require ( github.com/aws/aws-sdk-go v1.44.151 github.com/go-sql-driver/mysql v1.6.0 + github.com/hashicorp/go-multierror v1.1.1 github.com/spf13/cobra v1.5.0 golang.org/x/exp v0.0.0-20220921164117-439092de6870 gopkg.in/yaml.v3 v3.0.1 ) require ( + github.com/hashicorp/errwrap v1.0.0 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect golang.org/x/sys v0.1.0 // indirect ) diff --git a/go.sum b/go.sum index 6e889a4..58c8b1a 100644 --- a/go.sum +++ b/go.sum @@ 
-5,6 +5,10 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/inconshreveable/mousetrap v1.0.1 h1:U3uMjPSQEBMNp1lFxmllqCPM6P5u/Xq7Pgzkat/bFNc= github.com/inconshreveable/mousetrap v1.0.1/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= diff --git a/storage/s3.go b/storage/s3.go deleted file mode 100644 index 789f849..0000000 --- a/storage/s3.go +++ /dev/null @@ -1,68 +0,0 @@ -package storage - -import ( - "fmt" - "io" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/s3/s3manager" -) - -const S3Prefix = "s3://" - -var ErrInvalidS3Path = fmt.Errorf("invalid s3 filename, it should follow the format %s/", S3Prefix) - -func NewS3Storage(bucket, key string, credentials *AWSCredentials) *S3Storage { - return &S3Storage{ - Credentials: credentials, - Bucket: bucket, - Key: key, - } -} - -type AWSCredentials struct { - Region string `yaml:"region"` - AccessKeyId string `yaml:"access-key-id"` - SecretAccessKey string `yaml:"secret-access-key"` -} - -type S3Storage struct { - Bucket string - Key string - Credentials *AWSCredentials -} - -func (s3 *S3Storage) Upload(reader io.Reader) error { - var 
awsConfig aws.Config - if s3.Credentials != nil { - if s3.Credentials.Region != "" { - awsConfig.Region = aws.String(s3.Credentials.Region) - } - - if s3.Credentials.AccessKeyId != "" && s3.Credentials.SecretAccessKey != "" { - awsConfig.Credentials = credentials.NewStaticCredentials(s3.Credentials.AccessKeyId, s3.Credentials.SecretAccessKey, "") - } - } - - session := session.Must(session.NewSession(&awsConfig)) - uploader := s3manager.NewUploader(session) - - // TODO: implement re-try - _, uploadErr := uploader.Upload(&s3manager.UploadInput{ - Bucket: aws.String(s3.Bucket), - Key: aws.String(s3.Key), - Body: reader, - }) - - if uploadErr != nil { - return fmt.Errorf("failed to upload file to s3 bucket %w", uploadErr) - } - - return nil -} - -func (s3 *S3Storage) CloudFilePath() string { - return fmt.Sprintf("s3://%s/%s", s3.Bucket, s3.Key) -} diff --git a/storage/storage.go b/storage/storage.go index e1ad3e6..0d47599 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -6,12 +6,12 @@ import ( "log" "math/rand" "os" + "strings" "time" ) -type CloudStorage interface { - Upload(reader io.Reader) error - CloudFilePath() string +type Storage interface { + Save(reader io.Reader, gzip bool) error } const uploadDumpCacheDir = ".onedump" @@ -50,3 +50,16 @@ func generateCacheFileName(n int) string { } return string(b) } + +// Ensure a file has proper file extension. 
+func EnsureFileSuffix(filename string, shouldGzip bool) string { + if !shouldGzip { + return filename + } + + if strings.HasSuffix(filename, ".gz") { + return filename + } + + return filename + ".gz" +} From c7938db5d71db9a28c9a3ed53031d9a52b450509 Mon Sep 17 00:00:00 2001 From: liweiyi88 Date: Sat, 17 Dec 2022 20:47:25 +1100 Subject: [PATCH 5/8] refactoring --- dump/closer.go | 27 +++++++++++++++++++ storage/local/local.go | 31 ++++++++++++++++++++++ storage/s3/s3.go | 60 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 118 insertions(+) create mode 100644 dump/closer.go create mode 100644 storage/local/local.go create mode 100644 storage/s3/s3.go diff --git a/dump/closer.go b/dump/closer.go new file mode 100644 index 0000000..a15408b --- /dev/null +++ b/dump/closer.go @@ -0,0 +1,27 @@ +package dump + +import ( + "io" + + "github.com/hashicorp/go-multierror" +) + +type MultiCloser struct { + closers []io.Closer +} + +func NewMultiCloser(closers []io.Closer) *MultiCloser { + return &MultiCloser{ + closers: closers, + } +} + +func (m *MultiCloser) Close() error { + var err error + for _, c := range m.closers { + if e := c.Close(); e != nil { + err = multierror.Append(err, e) + } + } + return err +} diff --git a/storage/local/local.go b/storage/local/local.go new file mode 100644 index 0000000..af58dda --- /dev/null +++ b/storage/local/local.go @@ -0,0 +1,31 @@ +package local + +import ( + "fmt" + "io" + "os" + + "github.com/liweiyi88/onedump/storage" +) + +type Local struct { + Path string `yaml:"path"` +} + +func (local *Local) Save(reader io.Reader, gzip bool) error { + path := storage.EnsureFileSuffix(local.Path, gzip) + file, err := os.Create(path) + if err != nil { + return fmt.Errorf("failed to create local dump file: %w", err) + } + + defer file.Close() + + _, err = io.Copy(file, reader) + + if err != nil { + return fmt.Errorf("failed to copy cache file to the dest file: %w", err) + } + + return nil +} diff --git a/storage/s3/s3.go 
b/storage/s3/s3.go new file mode 100644 index 0000000..937f111 --- /dev/null +++ b/storage/s3/s3.go @@ -0,0 +1,60 @@ +package s3 + +import ( + "fmt" + "io" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3/s3manager" + "github.com/liweiyi88/onedump/storage" +) + +func NewS3(bucket, key, region, accessKeyId, secretAccessKey string) *S3 { + return &S3{ + Bucket: bucket, + Key: key, + Region: region, + AccessKeyId: accessKeyId, + SecretAccessKey: secretAccessKey, + } +} + +type S3 struct { + Bucket string + Key string + Region string `yaml:"region"` + AccessKeyId string `yaml:"access-key-id"` + SecretAccessKey string `yaml:"secret-access-key"` +} + +func (s3 *S3) Save(reader io.Reader, gzip bool) error { + var awsConfig aws.Config + + if s3.Region != "" { + awsConfig.Region = aws.String(s3.Region) + } + + if s3.AccessKeyId != "" && s3.SecretAccessKey != "" { + awsConfig.Credentials = credentials.NewStaticCredentials(s3.AccessKeyId, s3.SecretAccessKey, "") + } + + session := session.Must(session.NewSession(&awsConfig)) + uploader := s3manager.NewUploader(session) + + key := storage.EnsureFileSuffix(s3.Key, gzip) + + // TODO: implement re-try + _, uploadErr := uploader.Upload(&s3manager.UploadInput{ + Bucket: aws.String(s3.Bucket), + Key: aws.String(key), + Body: reader, + }) + + if uploadErr != nil { + return fmt.Errorf("failed to upload file to s3 bucket %w", uploadErr) + } + + return nil +} From 23585a77109f21d832ffc0ada387aafd1a6ff187 Mon Sep 17 00:00:00 2001 From: liweiyi88 Date: Mon, 19 Dec 2022 16:13:42 +1100 Subject: [PATCH 6/8] add more tests --- dump/closer_test.go | 46 +++++++++++++++++++++++++++++++ dump/job.go | 6 ++-- storage/local/local_test.go | 31 +++++++++++++++++++++ storage/s3/s3_test.go | 55 +++++++++++++++++++++++++++++++++++++ storage/storage.go | 5 ++-- storage/storage_test.go | 44 +++++++++++++++++++++++++++++ 6 files changed, 182 
insertions(+), 5 deletions(-) create mode 100644 dump/closer_test.go create mode 100644 storage/local/local_test.go create mode 100644 storage/s3/s3_test.go diff --git a/dump/closer_test.go b/dump/closer_test.go new file mode 100644 index 0000000..2e0fc37 --- /dev/null +++ b/dump/closer_test.go @@ -0,0 +1,46 @@ +package dump + +import ( + "bytes" + "fmt" + "io" + "os" + "testing" +) + +type mockCloser struct{} + +func (m mockCloser) Close() error { + fmt.Print("close") + + return nil +} + +func TestNewMultiCloser(t *testing.T) { + + closer1 := mockCloser{} + closer2 := mockCloser{} + + closers := make([]io.Closer, 0) + closers = append(closers, closer1, closer2) + + multiCloser := NewMultiCloser(closers) + + old := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + multiCloser.Close() + + w.Close() + os.Stdout = old + + var buf bytes.Buffer + io.Copy(&buf, r) + + expected := buf.String() + actual := "closeclose" + if expected != actual { + t.Errorf("expected: %s, actual: %s", expected, actual) + } +} diff --git a/dump/job.go b/dump/job.go index f04dfe1..c5cbde7 100644 --- a/dump/job.go +++ b/dump/job.go @@ -298,7 +298,7 @@ func (job *Job) writeToFile(sshSession *ssh.Session, file io.Writer) error { } func (job *Job) dumpToCacheFile(sshSession *ssh.Session) (string, error) { - dumpFileName := storage.EnsureFileSuffix(storage.UploadCacheFilePath(), job.Gzip) + dumpFileName := storage.UploadCacheFilePath(job.Gzip) file, err := os.Create(dumpFileName) if err != nil { @@ -341,12 +341,12 @@ func (job *Job) dump(sshSession *ssh.Session) error { defer dumpFile.Close() - job.dumpToDestinations(dumpFile) + job.saveToDestinations(dumpFile) return nil } -func (job *Job) dumpToDestinations(cacheFile io.Reader) error { +func (job *Job) saveToDestinations(cacheFile io.Reader) error { storages := job.getStorages() numberOfStorages := len(storages) diff --git a/storage/local/local_test.go b/storage/local/local_test.go new file mode 100644 index 0000000..701a2f1 --- /dev/null +++ 
b/storage/local/local_test.go @@ -0,0 +1,31 @@ +package local + +import ( + "os" + "strings" + "testing" +) + +func TestSave(t *testing.T) { + filename := os.TempDir() + "test.sql.gz" + local := &Local{Path: filename} + + expected := "hello" + reader := strings.NewReader(expected) + + err := local.Save(reader, true) + if err != nil { + t.Errorf("failed to save file: %v", err) + } + + data, err := os.ReadFile(filename) + if err != nil { + t.Errorf("can not read file %s", err) + } + + if string(data) != expected { + t.Errorf("expected string: %s but actual got: %s", expected, data) + } + + defer os.Remove(filename) +} diff --git a/storage/s3/s3_test.go b/storage/s3/s3_test.go new file mode 100644 index 0000000..ed3dec2 --- /dev/null +++ b/storage/s3/s3_test.go @@ -0,0 +1,55 @@ +package s3 + +import ( + "errors" + "strings" + "testing" +) + +func TestNewS3(t *testing.T) { + expectedBucket := "onedump" + expectedKey := "/backup/dump.sql" + expectedRegion := "ap-southeast-2" + expectedAccessKeyId := "accessKey" + expectedSecretKey := "secret" + + s3 := NewS3(expectedBucket, expectedKey, expectedRegion, expectedAccessKeyId, expectedSecretKey) + + if s3.Bucket != expectedBucket { + t.Errorf("expected: %s, actual: %s", expectedBucket, s3.Bucket) + } + + if s3.Key != expectedKey { + t.Errorf("expected: %s, actual: %s", expectedBucket, s3.Key) + } + + if s3.Region != expectedRegion { + t.Errorf("expected: %s, actual: %s", expectedRegion, s3.Region) + } + + if s3.AccessKeyId != expectedAccessKeyId { + t.Errorf("expected: %s, actual: %s", expectedAccessKeyId, s3.AccessKeyId) + } + + if s3.SecretAccessKey != expectedSecretKey { + t.Errorf("expected: %s, actual: %s", expectedSecretKey, s3.SecretAccessKey) + } +} + +func TestSave(t *testing.T) { + s3 := &S3{ + Bucket: "onedump", + Key: "/backup/dump.sql", + Region: "ap-southeast-2", + AccessKeyId: "none", + SecretAccessKey: "none", + } + + reader := strings.NewReader("hello s3") + err := s3.Save(reader, true) + + actual := 
errors.Unwrap(err).Error() + if !strings.HasPrefix(actual, "InvalidAccessKeyId") { + t.Errorf("expeceted invalid access key id but actual got: %s", actual) + } +} diff --git a/storage/storage.go b/storage/storage.go index 0d47599..c97656c 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -37,8 +37,9 @@ func UploadCacheDir() string { return fmt.Sprintf("%s/%s", dir, uploadDumpCacheDir) } -func UploadCacheFilePath() string { - return fmt.Sprintf("%s/%s", UploadCacheDir(), generateCacheFileName(8)+".sql") +func UploadCacheFilePath(shouldGzip bool) string { + filename := fmt.Sprintf("%s/%s", UploadCacheDir(), generateCacheFileName(8)+".sql") + return EnsureFileSuffix(filename, shouldGzip) } func generateCacheFileName(n int) string { diff --git a/storage/storage_test.go b/storage/storage_test.go index 6d9c8e6..0c439aa 100644 --- a/storage/storage_test.go +++ b/storage/storage_test.go @@ -3,6 +3,7 @@ package storage import ( "fmt" "os" + "strings" "testing" ) @@ -16,3 +17,46 @@ func TestUploadCacheDir(t *testing.T) { t.Errorf("get unexpected cache dir: expected: %s, actual: %s", expected, actual) } } + +func TestGenerateCacheFileName(t *testing.T) { + expectedLen := 5 + name := generateCacheFileName(expectedLen) + + actualLen := len([]rune(name)) + if actualLen != expectedLen { + t.Errorf("unexpected cache filename, expected length: %d, actual length: %d", 5, actualLen) + } +} + +func TestUploadCacheFilePath(t *testing.T) { + gziped := UploadCacheFilePath(true) + + if !strings.HasSuffix(gziped, ".gz") { + t.Errorf("expected filename has .gz extention, actual file name: %s", gziped) + } + + sql := UploadCacheFilePath(false) + + if !strings.HasSuffix(sql, ".sql") { + t.Errorf("expected filename has .sql extention, actual file name: %s", sql) + } + + sql2 := UploadCacheFilePath(false) + + if sql == sql2 { + t.Errorf("expected unique file name but got same filename %s", sql) + } +} + +func TestEnsureFileSuffix(t *testing.T) { + gzip := EnsureFileSuffix("test.sql", 
true) + if gzip != "test.sql.gz" { + t.Errorf("expected filename has .gz extention, actual file name: %s", gzip) + } + + sql := EnsureFileSuffix("test.sql.gz", true) + + if sql != "test.sql.gz" { + t.Errorf("expected: %s is not equals to actual: %s", sql, "test.sql.gz") + } +} From 3ad7db7e87391a22ac5dfa26465760b649be92c0 Mon Sep 17 00:00:00 2001 From: liweiyi88 Date: Mon, 19 Dec 2022 16:20:11 +1100 Subject: [PATCH 7/8] fix dir --- storage/local/local_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/storage/local/local_test.go b/storage/local/local_test.go index 701a2f1..5f737b5 100644 --- a/storage/local/local_test.go +++ b/storage/local/local_test.go @@ -7,7 +7,7 @@ import ( ) func TestSave(t *testing.T) { - filename := os.TempDir() + "test.sql.gz" + filename := os.TempDir() + "/test.sql.gz" local := &Local{Path: filename} expected := "hello" From 45917a3c732f9d32840f94dcaadcddc3f3d8b34c Mon Sep 17 00:00:00 2001 From: liweiyi88 Date: Thu, 22 Dec 2022 11:24:57 +1100 Subject: [PATCH 8/8] add more tests --- cmd/apply.go | 68 ------------------------ cmd/root.go | 81 +++++++++++++++------------- driver/mysql.go | 8 ++- driver/mysql_test.go | 76 ++++++++++++++------------ dump/job.go | 56 +++++++++++++------ dump/job_test.go | 118 ++++++++++++++++++++++++++++++++++++----- storage/local/local.go | 8 ++- 7 files changed, 244 insertions(+), 171 deletions(-) delete mode 100644 cmd/apply.go diff --git a/cmd/apply.go b/cmd/apply.go deleted file mode 100644 index d8332c5..0000000 --- a/cmd/apply.go +++ /dev/null @@ -1,68 +0,0 @@ -package cmd - -import ( - "log" - "os" - "sync" - - "github.com/liweiyi88/onedump/dump" - "github.com/spf13/cobra" - "gopkg.in/yaml.v3" -) - -var file string - -var applyCmd = &cobra.Command{ - Use: "apply -f /path/to/jobs.yaml", - Args: cobra.ExactArgs(0), - Short: "Dump database content from different sources to different destinations with a yaml config file.", - Run: func(cmd *cobra.Command, args []string) { - content, 
err := os.ReadFile(file) - if err != nil { - log.Fatalf("failed to read job file from %s, error: %v", file, err) - } - - var oneDump dump.Dump - err = yaml.Unmarshal(content, &oneDump) - if err != nil { - log.Fatalf("failed to read job content from %s, error: %v", file, err) - } - - err = oneDump.Validate() - if err != nil { - log.Fatalf("invalid job configuration, error: %v", err) - } - - numberOfJobs := len(oneDump.Jobs) - if numberOfJobs == 0 { - log.Printf("no job is defined in the file %s", file) - return - } - - resultCh := make(chan *dump.JobResult) - - for _, job := range oneDump.Jobs { - go func(job *dump.Job, resultCh chan *dump.JobResult) { - resultCh <- job.Run() - }(job, resultCh) - } - - var wg sync.WaitGroup - wg.Add(numberOfJobs) - go func(resultCh chan *dump.JobResult) { - for result := range resultCh { - result.Print() - wg.Done() - } - }(resultCh) - - wg.Wait() - close(resultCh) - }, -} - -func init() { - rootCmd.AddCommand(applyCmd) - applyCmd.Flags().StringVarP(&file, "file", "f", "", "jobs yaml file path.") - applyCmd.MarkFlagRequired("file") -} diff --git a/cmd/root.go b/cmd/root.go index 60a3fda..24ced76 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -4,50 +4,61 @@ import ( "fmt" "log" "os" - "strings" + "sync" "github.com/liweiyi88/onedump/dump" "github.com/spf13/cobra" + "gopkg.in/yaml.v3" ) -var ( - sshHost, sshUser, sshPrivateKeyFile string - dbDsn string - jobName string - dumpOptions []string - gzip bool -) +var file string var rootCmd = &cobra.Command{ - Use: " ", - Short: "Dump database content from a source to a destination via cli command.", - Args: cobra.ExactArgs(2), + Use: "-f /path/to/jobs.yaml", + Short: "Dump database content from different sources to different destinations with a yaml config file.", + Args: cobra.ExactArgs(0), Run: func(cmd *cobra.Command, args []string) { - driver := strings.TrimSpace(args[0]) + content, err := os.ReadFile(file) + if err != nil { + log.Fatalf("failed to read job file from %s, error: %v", 
file, err) + } + + var oneDump dump.Dump + err = yaml.Unmarshal(content, &oneDump) + if err != nil { + log.Fatalf("failed to read job content from %s, error: %v", file, err) + } - dumpFile := strings.TrimSpace(args[1]) - if dumpFile == "" { - log.Fatal("you must specify the dump file path. e.g. /download/dump.sql") + err = oneDump.Validate() + if err != nil { + log.Fatalf("invalid job configuration, error: %v", err) } - name := "dump via cli" - if strings.TrimSpace(jobName) != "" { - name = jobName + numberOfJobs := len(oneDump.Jobs) + if numberOfJobs == 0 { + log.Printf("no job is defined in the file %s", file) + return + } + + resultCh := make(chan *dump.JobResult) + + for _, job := range oneDump.Jobs { + go func(job *dump.Job, resultCh chan *dump.JobResult) { + resultCh <- job.Run() + }(job, resultCh) } - job := dump.NewJob( - name, - driver, - dumpFile, - dbDsn, - dump.WithGzip(gzip), - dump.WithSshHost(sshHost), - dump.WithSshUser(sshUser), - dump.WithPrivateKeyFile(sshPrivateKeyFile), - dump.WithDumpOptions(dumpOptions), - ) + var wg sync.WaitGroup + wg.Add(numberOfJobs) + go func(resultCh chan *dump.JobResult) { + for result := range resultCh { + result.Print() + wg.Done() + } + }(resultCh) - job.Run().Print() + wg.Wait() + close(resultCh) }, } @@ -59,12 +70,6 @@ func Execute() { } func init() { - rootCmd.Flags().StringVarP(&sshHost, "ssh-host", "", "", "SSH host e.g. yourdomain.com (you can omit port as it uses 22 by default) or 56.09.139.09:33. (required) ") - rootCmd.Flags().StringVarP(&sshUser, "ssh-user", "", "root", "SSH username") - rootCmd.Flags().StringVarP(&sshPrivateKeyFile, "privatekey", "f", "", "private key file path for SSH connection") - rootCmd.Flags().StringArrayVarP(&dumpOptions, "dump-options", "", nil, "use options to overwrite or add new dump command options. e.g. for mysql: --dump-options \"--no-create-info\" --dump-options \"--skip-comments\"") - rootCmd.Flags().StringVarP(&dbDsn, "db-dsn", "d", "", "the database dsn for connection. 
e.g. :@tcp(:)/") - rootCmd.MarkFlagRequired("db-dsn") - rootCmd.Flags().BoolVarP(&gzip, "gzip", "g", true, "if need to gzip the file") - rootCmd.Flags().StringVarP(&jobName, "job-name", "", "", "The dump job name") + rootCmd.Flags().StringVarP(&file, "file", "f", "", "jobs yaml file path.") + rootCmd.MarkFlagRequired("file") } diff --git a/driver/mysql.go b/driver/mysql.go index ca43a85..b7c3ef4 100644 --- a/driver/mysql.go +++ b/driver/mysql.go @@ -2,6 +2,7 @@ package driver import ( "fmt" + "log" "net" "os" "os/exec" @@ -115,7 +116,12 @@ host = %s` return fileName, fmt.Errorf("failed to create temp folder: %w", err) } - defer file.Close() + defer func() { + err := file.Close() + if err != nil { + log.Printf("failed to close temp file for storing mysql credentials: %v", err) + } + }() _, err = file.WriteString(contents) if err != nil { diff --git a/driver/mysql_test.go b/driver/mysql_test.go index 161a944..f328f31 100644 --- a/driver/mysql_test.go +++ b/driver/mysql_test.go @@ -2,6 +2,8 @@ package driver import ( "os" + "os/exec" + "strings" "testing" "golang.org/x/exp/slices" @@ -113,37 +115,43 @@ host = 127.0.0.1` t.Log("removed temp credential file", fileName) } -// func TestDump(t *testing.T) { -// dsn := "root@tcp(127.0.0.1:3306)/test_local" -// mysql, err := NewMysqlDumper(dsn, nil, false) -// if err != nil { -// t.Fatal(err) -// } - -// dumpfile, err := os.CreateTemp("", "dbdump") -// if err != nil { -// t.Fatal(err) -// } -// defer dumpfile.Close() - -// err = mysql.Dump(dumpfile.Name(), false) -// if err != nil { -// t.Fatal(err) -// } - -// out, err := os.ReadFile(dumpfile.Name()) -// if err != nil { -// t.Fatal("failed to read the test dump file") -// } - -// if len(out) == 0 { -// t.Fatal("test dump file is empty") -// } - -// t.Log("test dump file content size", len(out)) - -// err = os.Remove(dumpfile.Name()) -// if err != nil { -// t.Fatal("can not cleanup the test dump file", err) -// } -// } +func TestGetSshDumpCommand(t *testing.T) { + mysql, err 
:= NewMysqlDriver(testDBDsn, nil, false) + if err != nil { + t.Fatal(err) + } + + command, err := mysql.GetSshDumpCommand() + if err != nil { + t.Errorf("failed to get dump command %v", command) + } + + if !strings.Contains(command, "mysqldump --defaults-extra-file") || !strings.Contains(command, "--skip-comments --extended-insert dump_test") { + t.Errorf("unexpected command: %s", command) + } +} + +func TestGetDumpCommand(t *testing.T) { + mysql, err := NewMysqlDriver(testDBDsn, nil, false) + if err != nil { + t.Fatal(err) + } + + mysqldumpPath, err := exec.LookPath(mysql.MysqlDumpBinaryPath) + if err != nil { + t.Fatal(err) + } + + path, args, err := mysql.GetDumpCommand() + if err != nil { + t.Error("failed to get dump command") + } + + if mysqldumpPath != path { + t.Errorf("expected mysqldump path: %s, actual got: %s", mysqldumpPath, path) + } + + if len(args) != 4 { + t.Errorf("get unexpected args, expected %d args, but got: %d", 4, len(args)) + } +} diff --git a/dump/job.go b/dump/job.go index c5cbde7..426536c 100644 --- a/dump/job.go +++ b/dump/job.go @@ -14,6 +14,7 @@ import ( "sync" "time" + "github.com/hashicorp/go-multierror" "github.com/liweiyi88/onedump/driver" "github.com/liweiyi88/onedump/storage" "github.com/liweiyi88/onedump/storage/local" @@ -21,25 +22,27 @@ import ( "golang.org/x/crypto/ssh" ) +var ( + ErrMissingJobName = errors.New("job name is required") + ErrMissingDBDsn = errors.New("databse dsn is required") + ErrMissingDBDriver = errors.New("databse driver is required") +) + type Dump struct { Jobs []*Job `yaml:"jobs"` } func (dump *Dump) Validate() error { - errorCollection := make([]string, 0) + var errs error for _, job := range dump.Jobs { err := job.validate() if err != nil { - errorCollection = append(errorCollection, err.Error()) + errs = multierror.Append(errs, err) } } - if len(errorCollection) == 0 { - return nil - } - - return errors.New(strings.Join(errorCollection, ",")) + return errs } type JobResult struct { @@ -91,7 +94,7 @@ 
func WithGzip(gzip bool) Option { } } -func WithDumpOptions(dumpOptions []string) Option { +func WithDumpOptions(dumpOptions ...string) Option { return func(job *Job) { job.DumpOptions = dumpOptions } @@ -119,15 +122,15 @@ func NewJob(name, driver, dumpFile, dbDsn string, opts ...Option) *Job { func (job Job) validate() error { if strings.TrimSpace(job.Name) == "" { - return errors.New("job name is required") + return ErrMissingJobName } if strings.TrimSpace(job.DBDsn) == "" { - return errors.New("databse dsn is required") + return ErrMissingDBDsn } if strings.TrimSpace(job.DBDriver) == "" { - return errors.New("databse driver is required") + return ErrMissingDBDriver } return nil @@ -188,15 +191,19 @@ func (job *Job) sshDump() error { return fmt.Errorf("failed to dial remote server via ssh: %w", err) } - defer client.Close() + defer func() { + // Do not need to call session.Close() here as it will only give EOF error. + err = client.Close() + if err != nil { + log.Printf("failed to close ssh client: %v", err) + } + }() session, err := client.NewSession() if err != nil { return fmt.Errorf("failed to start ssh session: %w", err) } - defer session.Close() - err = job.dump(session) if err != nil { return err @@ -247,7 +254,12 @@ func (job *Job) writeToFile(sshSession *ssh.Session, file io.Writer) error { var gzipWriter *gzip.Writer if job.Gzip { gzipWriter = gzip.NewWriter(file) - defer gzipWriter.Close() + defer func() { + err := gzipWriter.Close() + if err != nil { + log.Printf("failed to close gzip writer: %v", err) + } + }() } driver, err := job.getDBDriver() @@ -305,7 +317,12 @@ func (job *Job) dumpToCacheFile(sshSession *ssh.Session) (string, error) { return "", fmt.Errorf("failed to create dump file: %w", err) } - defer file.Close() + defer func() { + err := file.Close() + if err != nil { + log.Printf("failed to close dump cache file: %v", err) + } + }() err = job.writeToFile(sshSession, file) if err != nil { @@ -339,7 +356,12 @@ func (job *Job) dump(sshSession 
*ssh.Session) error { return fmt.Errorf("failed to open the cached dump file %w", err) } - defer dumpFile.Close() + defer func() { + err := dumpFile.Close() + if err != nil { + log.Printf("failed to close dump cache file for saving to destination: %v", err) + } + }() job.saveToDestinations(dumpFile) diff --git a/dump/job_test.go b/dump/job_test.go index 5b360eb..0b8a296 100644 --- a/dump/job_test.go +++ b/dump/job_test.go @@ -5,11 +5,15 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" + "errors" "fmt" "os" + "sync" "testing" ) +var testDBDsn = "root@tcp(127.0.0.1:3306)/dump_test" + func generateTestRSAPrivatePEMFile() (string, error) { tempDir := os.TempDir() @@ -40,23 +44,113 @@ func TestEnsureSSHHostHavePort(t *testing.T) { if ensureHaveSSHPort(sshHost) != sshHost+":22" { t.Error("ssh host port is not ensured") } + + sshHost = "127.0.0.1:22" + actual := ensureHaveSSHPort(sshHost) + if actual != sshHost { + t.Errorf("expect ssh host: %s, actual: %s", sshHost, actual) + } } -// func TestNewSshDumper(t *testing.T) { -// sshDumper := NewSshDumper("127.0.0.1", "root", "~/.ssh/test.pem") +func TestGetDBDriver(t *testing.T) { + job := NewJob("job1", "mysql", "test.sql", testDBDsn) -// if sshDumper.Host != "127.0.0.1" { -// t.Errorf("ssh host is unexpected, exepct: %s, actual: %s", "127.0.0.1", sshDumper.Host) -// } + _, err := job.getDBDriver() + if err != nil { + t.Errorf("expect get mysql db driver, but get err: %v", err) + } -// if sshDumper.User != "root" { -// t.Errorf("ssh user is unexpected, exepct: %s, actual: %s", "root", sshDumper.User) -// } + job = NewJob("job1", "x", "test.sql", testDBDsn) + _, err = job.getDBDriver() + if err == nil { + t.Error("expect unsupport database driver err, but actual get nil") + } +} -// if sshDumper.PrivateKeyFile != "~/.ssh/test.pem" { -// t.Errorf("ssh private key file path is unexpected, exepct: %s, actual: %s", "~/.ssh/test.pem", sshDumper.PrivateKeyFile) -// } -// } +func TestDumpValidate(t *testing.T) { + jobs := 
make([]*Job, 0)
+ job1 := NewJob(
+ "job1",
+ "mysql",
+ "test.sql",
+ testDBDsn,
+ WithGzip(true),
+ WithDumpOptions("--skip-comments"),
+ WithPrivateKeyFile("/privatekey.pen"),
+ WithSshUser("root"),
+ WithSshHost("localhost"),
+ )
+ jobs = append(jobs, job1)
+
+ dump := Dump{Jobs: jobs}
+
+ err := dump.Validate()
+ if err != nil {
+ t.Errorf("expected validate dump but got err: %v", err)
+ }
+
+ job2 := NewJob("", "mysql", "dump.sql", "")
+ jobs = append(jobs, job2)
+ dump.Jobs = jobs
+ err = dump.Validate()
+
+ if !errors.Is(err, ErrMissingJobName) {
+ t.Errorf("expected err: %v, actual got: %v", ErrMissingJobName, err)
+ }
+
+ job3 := NewJob("job3", "mysql", "dump.sql", "")
+ jobs = append(jobs, job3)
+ dump.Jobs = jobs
+ err = dump.Validate()
+
+ if !errors.Is(err, ErrMissingDBDsn) {
+ t.Errorf("expected err: %v, actual got: %v", ErrMissingDBDsn, err)
+ }
+
+ job4 := NewJob("job3", "", "dump.sql", testDBDsn)
+ jobs = append(jobs, job4)
+ dump.Jobs = jobs
+ err = dump.Validate()
+
+ if !errors.Is(err, ErrMissingDBDriver) {
+ t.Errorf("expected err: %v, actual got: %v", ErrMissingDBDriver, err)
+ }
+}
+
+func TestRun(t *testing.T) {
+ tempDir := os.TempDir()
+ privateKeyFile, err := generateTestRSAPrivatePEMFile()
+ if err != nil {
+ t.Errorf("failed to generate test rsa key pairs %v", err)
+ }
+
+ jobs := make([]*Job, 0, 2)
+ job1 := NewJob("exec-dump", "mysql", tempDir+"/test.sql", testDBDsn)
+ job2 := NewJob("ssh", "mysql", tempDir+"/test.sql", testDBDsn, WithSshHost("127.0.0.1:2022"), WithSshUser("root"), WithPrivateKeyFile(privateKeyFile))
+
+ jobs = append(jobs, job1)
+ jobs = append(jobs, job2)
+ dump := Dump{Jobs: jobs}
+
+ resultCh := make(chan *JobResult)
+ for _, job := range dump.Jobs {
+ go func(job *Job, resultCh chan *JobResult) {
+ resultCh <- job.Run()
+ }(job, resultCh)
+ }
+
+ var wg sync.WaitGroup
+ wg.Add(2)
+ go func(resultCh chan *JobResult) {
+ for result := range resultCh {
+ result.Print()
+ wg.Done()
+ }
+ }(resultCh)
+
+ wg.Wait()
+ 
close(resultCh) +} // func TestSSHDump(t *testing.T) { // privateKeyFile, err := generateTestRSAPrivatePEMFile() diff --git a/storage/local/local.go b/storage/local/local.go index af58dda..6bcf82e 100644 --- a/storage/local/local.go +++ b/storage/local/local.go @@ -3,6 +3,7 @@ package local import ( "fmt" "io" + "log" "os" "github.com/liweiyi88/onedump/storage" @@ -19,7 +20,12 @@ func (local *Local) Save(reader io.Reader, gzip bool) error { return fmt.Errorf("failed to create local dump file: %w", err) } - defer file.Close() + defer func() { + err := file.Close() + if err != nil { + log.Printf("failed to close local dump file %v", err) + } + }() _, err = io.Copy(file, reader)