Skip to content

Commit

Permalink
✨ Add crypt utils (#169)
Browse files Browse the repository at this point in the history
  • Loading branch information
tosone authored Aug 3, 2023
1 parent eebd27e commit 46dfba9
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 0 deletions.
73 changes: 73 additions & 0 deletions pkg/utils/crypt/crypt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package crypt

import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"fmt"
"io"
"strings"
)

// MustEncrypt ...
func MustEncrypt(key, plaintext string) string {
result, err := Encrypt(key, plaintext)
if err != nil {
panic(fmt.Sprintf("encrypt string failed: %v", err))
}
return result
}

// Encrypt ...
func Encrypt(key, plaintext string) (string, error) {
keyBytes := sha256.Sum256([]byte(key))

block, err := aes.NewCipher(keyBytes[:])
if err != nil {
return "", err
}

iv := make([]byte, aes.BlockSize)
_, err = io.ReadFull(rand.Reader, iv)
if err != nil {
return "", err
}

reader := &cipher.StreamReader{S: cipher.NewCFBEncrypter(block, iv), R: strings.NewReader(plaintext)}
ciphertext, err := io.ReadAll(reader)
if err != nil {
return "", err
}
return base64.StdEncoding.WithPadding(base64.StdPadding).EncodeToString(append(iv, ciphertext...)), nil
}

// Decrypt ...
func Decrypt(key, ciphertext string) (string, error) {
keyBytes := sha256.Sum256([]byte(key))

srcBytes, err := base64.StdEncoding.WithPadding(base64.StdPadding).DecodeString(ciphertext)
if err != nil {
return "", err
}
if len(srcBytes) < aes.BlockSize {
return "", fmt.Errorf("ciphertext should be have iv and length bigger than %d bytes", aes.BlockSize)
}

block, err := aes.NewCipher(keyBytes[:])
if err != nil {
return "", err
}

iv := srcBytes[:aes.BlockSize]

reader := &cipher.StreamReader{S: cipher.NewCFBDecrypter(block, iv), R: bytes.NewReader(srcBytes[aes.BlockSize:])}

plaintext, err := io.ReadAll(reader)
if err != nil {
return "", err
}
return string(plaintext), nil
}
81 changes: 81 additions & 0 deletions pkg/utils/crypt/crypt_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package crypt

import (
"testing"
)

func TestEncrypt(t *testing.T) {
type args struct {
key string
plaintext string
}
tests := []struct {
name string
args args
want func(*testing.T, string, string) string
wantErr bool
}{
{
name: "common",
args: args{
key: "sigma",
plaintext: "sigma",
},
want: func(t *testing.T, ciphertext, key string) string {
plaintext, err := Decrypt(key, ciphertext)
if err != nil {
t.Errorf("decrypt failed: %v", err)
}
return plaintext
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Encrypt(tt.args.key, tt.args.plaintext)
if (err != nil) != tt.wantErr {
t.Errorf("Encrypt() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.want != nil && tt.args.plaintext != tt.want(t, got, tt.args.key) {
t.Errorf("Encrypt() = %v, want %v", got, tt.want(t, got, tt.args.key))
}
})
}
}

func TestDecrypt(t *testing.T) {
type args struct {
key string
ciphertext string
}
tests := []struct {
name string
args args
want string
wantErr bool
}{
{
name: "common",
args: args{
key: "sigma",
ciphertext: "uIjWyiYunVcb6aLRw5vaeIavFq1K",
},
want: "sigma",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Decrypt(tt.args.key, tt.args.ciphertext)
if (err != nil) != tt.wantErr {
t.Errorf("Decrypt() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("Decrypt() = %v, want %v", got, tt.want)
}
})
}
}

0 comments on commit 46dfba9

Please sign in to comment.