diff --git a/pkg/utils/crypt/crypt.go b/pkg/utils/crypt/crypt.go new file mode 100644 index 00000000..d2df1bbc --- /dev/null +++ b/pkg/utils/crypt/crypt.go @@ -0,0 +1,73 @@ +package crypt + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "fmt" + "io" + "strings" +) + +// MustEncrypt ... +func MustEncrypt(key, plaintext string) string { + result, err := Encrypt(key, plaintext) + if err != nil { + panic(fmt.Sprintf("encrypt string failed: %v", err)) + } + return result +} + +// Encrypt ... +func Encrypt(key, plaintext string) (string, error) { + keyBytes := sha256.Sum256([]byte(key)) + + block, err := aes.NewCipher(keyBytes[:]) + if err != nil { + return "", err + } + + iv := make([]byte, aes.BlockSize) + _, err = io.ReadFull(rand.Reader, iv) + if err != nil { + return "", err + } + + reader := &cipher.StreamReader{S: cipher.NewCFBEncrypter(block, iv), R: strings.NewReader(plaintext)} + ciphertext, err := io.ReadAll(reader) + if err != nil { + return "", err + } + return base64.StdEncoding.WithPadding(base64.StdPadding).EncodeToString(append(iv, ciphertext...)), nil +} + +// Decrypt ... +func Decrypt(key, ciphertext string) (string, error) { + keyBytes := sha256.Sum256([]byte(key)) + + srcBytes, err := base64.StdEncoding.WithPadding(base64.StdPadding).DecodeString(ciphertext) + if err != nil { + return "", err + } + if len(srcBytes) < aes.BlockSize { + return "", fmt.Errorf("ciphertext should be have iv and length bigger than %d bytes", aes.BlockSize) + } + + block, err := aes.NewCipher(keyBytes[:]) + if err != nil { + return "", err + } + + iv := srcBytes[:aes.BlockSize] + + reader := &cipher.StreamReader{S: cipher.NewCFBDecrypter(block, iv), R: bytes.NewReader(srcBytes[aes.BlockSize:])} + + plaintext, err := io.ReadAll(reader) + if err != nil { + return "", err + } + return string(plaintext), nil +} diff --git a/pkg/utils/crypt/crypt_test.go b/pkg/utils/crypt/crypt_test.go new file mode 100644 index 00000000..de5f4393 --- /dev/null +++ b/pkg/utils/crypt/crypt_test.go @@ -0,0 +1,81 @@ +package crypt + +import ( + "testing" +) + +func TestEncrypt(t *testing.T) { + type args struct { + key string + plaintext string + } + tests := []struct { + name string + args args + want func(*testing.T, string, string) string + wantErr bool + }{ + { + name: "common", + args: args{ + key: "sigma", + plaintext: "sigma", + }, + want: func(t *testing.T, ciphertext, key string) string { + plaintext, err := Decrypt(key, ciphertext) + if err != nil { + t.Errorf("decrypt failed: %v", err) + } + return plaintext + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := Encrypt(tt.args.key, tt.args.plaintext) + if (err != nil) != tt.wantErr { + t.Errorf("Encrypt() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.want != nil && tt.args.plaintext != tt.want(t, got, tt.args.key) { + t.Errorf("Encrypt() = %v, want %v", got, tt.want(t, got, tt.args.key)) + } + }) + } +} + +func TestDecrypt(t *testing.T) { + type args struct { + key string + ciphertext string + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + { + name: "common", + args: args{ + key: "sigma", + ciphertext: "uIjWyiYunVcb6aLRw5vaeIavFq1K", + }, + want: "sigma", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := Decrypt(tt.args.key, tt.args.ciphertext) + if (err != nil) != tt.wantErr { + t.Errorf("Decrypt() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("Decrypt() = %v, want %v", got, tt.want) + } + }) + } +}