Skip to content

Commit

Permalink
end-to-end: add utility functions to allow hashing local files to mat…
Browse files Browse the repository at this point in the history
…ch encrypted files.

Updates odeke-em#543
  • Loading branch information
sselph committed May 22, 2016
1 parent 878265b commit 4750824
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 25 deletions.
27 changes: 27 additions & 0 deletions src/dcrypto/dcrypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"bytes"
"encoding/binary"
"fmt"
"hash"
"io"

"github.com/odeke-em/drive/src/dcrypto/v1"
Expand All @@ -34,6 +35,9 @@ type decrypter func(io.Reader, []byte) (io.ReadCloser, error)
// encrypter is a function that creates a encrypter.
type encrypter func(io.Reader, []byte) (io.Reader, error)

// hasher is a function that returns the hash of a plaintext as if it were encrypted.
type hasher func(io.Reader, io.Reader, []byte, hash.Hash) ([]byte, error)

// These are the different versions of the en/decryption library.
const (
V1 Version = iota
Expand All @@ -44,6 +48,12 @@ const PreferedVersion = V1

var encrypters map[Version]encrypter
var decrypters map[Version]decrypter
var hashers map[Version]hasher

// MaxHeaderSize is the maximum header size of all versions.
// This many bytes at the beginning of a file should be enough to compute
// a hash of a local file.
var MaxHeaderSize = v1.HeaderSize + 4

func init() {
decrypters = map[Version]decrypter{
Expand All @@ -53,6 +63,9 @@ func init() {
encrypters = map[Version]encrypter{
V1: v1.NewEncryptReader,
}
hashers = map[Version]hasher{
V1: v1.Hash,
}
}

// NewEncrypter returns an encrypting reader using the PreferedVersion.
Expand Down Expand Up @@ -85,6 +98,20 @@ func NewDecrypter(r io.Reader, password []byte) (io.ReadCloser, error) {
return decrypterFn(r, password)
}

// Hash will hash of plaintext based on the header of the encrypted file and returns the hash Sum.
func Hash(r io.Reader, header io.Reader, password []byte, hashFunc func() hash.Hash) ([]byte, error) {
h := hashFunc()
version, err := readVersion(io.TeeReader(header, h))
if err != nil {
return nil, err
}
hasherFn, ok := hashers[version]
if !ok {
return nil, fmt.Errorf("unknown hasher for version(%d)", version)
}
return hasherFn(r, header, password, h)
}

// writeVersion converts a Version to a []byte.
func writeVersion(i Version) ([]byte, error) {
buf := new(bytes.Buffer)
Expand Down
36 changes: 36 additions & 0 deletions src/dcrypto/dcrypto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ package dcrypto_test

import (
"bytes"
"crypto/md5"
"crypto/rand"
"io"
"io/ioutil"
"testing"

Expand Down Expand Up @@ -81,3 +83,37 @@ func TestRoundTrip(t *testing.T) {
}
}
}

func TestHash(t *testing.T) {
password := []byte("test")
sizes := []int{0, 24, 1337, 66560}
for _, size := range sizes {
h := md5.New()
t.Logf("Testing file of size: %db, with password: %s", size, password)
b, err := randBytes(size)
if err != nil {
t.Errorf("randBytes(%d) => %q; want nil", size, err)
continue
}
encReader, err := dcrypto.NewEncrypter(bytes.NewBuffer(b), password)
if err != nil {
t.Errorf("NewEncryper() => %q; want nil", err)
continue
}
cipher, err := ioutil.ReadAll(io.TeeReader(encReader, h))
if err != nil {
t.Errorf("ioutil.ReadAll(*EncryptReader) => %q; want nil", err)
continue
}
want := h.Sum(nil)
got, err := dcrypto.Hash(bytes.NewBuffer(b), bytes.NewBuffer(cipher[0:dcrypto.MaxHeaderSize]), password, md5.New)
if err != nil {
t.Errorf("Hash() => err = %q; want nil", err)
continue
}
if !bytes.Equal(got, want) {
t.Errorf("Hash() => %v; want %v", got, want)
}
}

}
88 changes: 63 additions & 25 deletions src/dcrypto/v1/v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,22 +165,54 @@ func newEncryptReader(r io.Reader, pass, salt []byte, iterations int32) (io.Read
if err != nil {
return nil, err
}
b, err := aes.NewCipher(aesKey)
iv, err := randBytes(blockSize)
if err != nil {
return nil, err
}
h := hmac.New(hashFunc, hmacKey)
iv, err := randBytes(blockSize)
var header []byte
header = append(header, iByte...)
header = append(header, salt...)
header = append(header, iv...)
return encrypter(r, aesKey, hmacKey, iv, header)
}

// encrypter returns the encrypted reader pased on the keys and IV provided.
func encrypter(r io.Reader, aesKey, hmacKey, iv, header []byte) (io.Reader, error) {
b, err := aes.NewCipher(aesKey)
if err != nil {
return nil, err
}
h := hmac.New(hashFunc, hmacKey)
hr := &hashReadWriter{hash: h}
sr := &cipher.StreamReader{R: r, S: cipher.NewCTR(b, iv)}
var header []byte
return io.MultiReader(io.TeeReader(io.MultiReader(bytes.NewBuffer(header), sr), hr), hr), nil
}

// decodeHeader decodes the header of the reader.
// It returns the keys, IV, and original header using the password and iterations in the reader.
func decodeHeader(r io.Reader, password []byte) (aesKey, hmacKey, iv, header []byte, err error) {
iByte, iterations, err := decInt32(r)
if err != nil {
return nil, nil, nil, nil, err
}
salt := make([]byte, saltSize)
iv = make([]byte, blockSize)
_, err = io.ReadFull(r, salt)
if err != nil {
return nil, nil, nil, nil, err
}
_, err = io.ReadFull(r, iv)
if err != nil {
return nil, nil, nil, nil, err
}
aesKey, hmacKey, err = keys(password, salt, int(iterations))
if err != nil {
return nil, nil, nil, nil, err
}
header = append(header, iByte...)
header = append(header, salt...)
header = append(header, iv...)
return io.MultiReader(io.TeeReader(io.MultiReader(bytes.NewBuffer(header), sr), hr), hr), nil
return aesKey, hmacKey, iv, header, err
}

// decryptReader wraps a io.Reader decrypting its content.
Expand All @@ -195,30 +227,13 @@ type decryptReader struct {
// If the file is athenticated, the DecryptReader will be returned and
// the resulting bytes will be the plaintext.
func NewDecryptReader(r io.Reader, pass []byte) (d io.ReadCloser, err error) {
iByte, iterations, err := decInt32(r)
if err != nil {
return nil, err
}
salt := make([]byte, saltSize)
iv := make([]byte, blockSize)
mac := make([]byte, hmacSize)
_, err = io.ReadFull(r, salt)
if err != nil {
return nil, err
}
_, err = io.ReadFull(r, iv)
if err != nil {
return nil, err
}
aesKey, hmacKey, err := keys(pass, salt, int(iterations))
aesKey, hmacKey, iv, header, err := decodeHeader(r, pass)
h := hmac.New(hashFunc, hmacKey)
h.Write(header)
if err != nil {
return nil, err
}
// Start Verifying the HMAC of the message.
h := hmac.New(hashFunc, hmacKey)
h.Write(iByte)
h.Write(salt)
h.Write(iv)
dst, err := tmpfile.New(&tmpfile.Context{
Dir: os.TempDir(),
Suffix: "drive-encrypted-",
Expand Down Expand Up @@ -277,3 +292,26 @@ func (d *decryptReader) Read(dst []byte) (int, error) {
func (d *decryptReader) Close() error {
return d.tmpFile.Done()
}

// Hash will hash of plaintext based on the header of the encrypted file and returns the hash Sum.
func Hash(plaintext io.Reader, header io.Reader, password []byte, h hash.Hash) ([]byte, error) {
aesKey, hmacKey, iv, eHeader, err := decodeHeader(header, password)
if err != nil {
return nil, err
}
encReader, err := encrypter(plaintext, aesKey, hmacKey, iv, eHeader)
if err != nil {
return nil, err
}
tr := io.TeeReader(encReader, h)
x := make([]byte, _16KB)
for {
if _, err := tr.Read(x); err != nil {
if err == io.EOF {
break
}
return nil, err
}
}
return h.Sum(nil), nil
}
36 changes: 36 additions & 0 deletions src/dcrypto/v1/v1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package v1

import (
"bytes"
"crypto/md5"
"io"
"io/ioutil"
"testing"
)
Expand Down Expand Up @@ -92,3 +94,37 @@ func TestRoundTrip(t *testing.T) {
}
}
}

func TestHash(t *testing.T) {
sizes := []int{24, 1024, 15872, 16364, 16384, 16394, 16896, 66560}
for _, size := range sizes {
h := md5.New()
t.Logf("Testing file of size: %db, with password: %s", size, password)
b, err := randBytes(size)
if err != nil {
t.Errorf("randBytes(%d) => %q; want nil", size, err)
continue
}
encReader, err := newEncryptReader(bytes.NewBuffer(b), password, salt, 1024)
if err != nil {
t.Errorf("NewEncryptReader() => %q; want nil", err)
continue
}
cipher, err := ioutil.ReadAll(io.TeeReader(encReader, h))
if err != nil {
t.Errorf("ioutil.ReadAll(*EncryptReader) => %q; want nil", err)
continue
}
want := h.Sum(nil)
h.Reset()
got, err := Hash(bytes.NewBuffer(b), bytes.NewBuffer(cipher[0:HeaderSize]), password, h)
if err != nil {
t.Errorf("Hash() => err = %q; want nil", err)
continue
}
if !bytes.Equal(got, want) {
t.Errorf("Hash() => %v; want %v", got, want)
}
}

}

0 comments on commit 4750824

Please sign in to comment.