Skip to content

Commit

Permalink
Merge pull request #130 from awnumar/dev
Browse files Browse the repository at this point in the history
Changes and improvements
  • Loading branch information
awnumar authored Mar 20, 2020
2 parents 7286b7e + ec1b812 commit 153b5da
Show file tree
Hide file tree
Showing 11 changed files with 214 additions and 64 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<p align="center">
<a href="https://cirrus-ci.com/github/awnumar/memguard"><img src="https://api.cirrus-ci.com/github/awnumar/memguard.svg"></a>
<a href="https://www.codacy.com/app/awnumar/memguard?utm_source=github.com&amp;utm_medium=referral&amp;utm_content=awnumar/memguard&amp;utm_campaign=Badge_Grade"><img src="https://api.codacy.com/project/badge/Grade/eebb7ecd6e794890999cfcf26328e9cb"/></a>
<a href="https://godoc.org/github.com/awnumar/memguard"><img src="https://godoc.org/github.com/awnumar/memguard?status.svg"></a>
<a href="https://pkg.go.dev/github.com/awnumar/memguard?tab=doc"><img src="https://godoc.org/github.com/awnumar/memguard?status.svg"></a>
</p>
</p>

Expand Down
38 changes: 22 additions & 16 deletions buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,42 +76,42 @@ func NewBufferFromBytes(src []byte) *LockedBuffer {
/*
NewBufferFromReader reads some number of bytes from an io.Reader into an immutable LockedBuffer.
If an error is encountered before size bytes are read, they will be returned. The number of bytes read can be inferred using the Size method.
An error is returned precisely when the number of bytes read is less than the requested amount. Any data read is returned in either case.
*/
func NewBufferFromReader(r io.Reader, size int) *LockedBuffer {
func NewBufferFromReader(r io.Reader, size int) (*LockedBuffer, error) {
// Construct a buffer of the provided size.
b := NewBuffer(size)
if b.Size() == 0 {
return b
return b, nil
}

// Attempt to fill it with data from the Reader.
if n, err := io.ReadFull(r, b.Bytes()); err != nil {
if n == 0 {
// nothing was read
b.Destroy()
return newNullBuffer()
return newNullBuffer(), err
}

// partial read
d := NewBuffer(n)
d.Copy(b.Bytes()[:n])
d.Freeze()
b.Destroy()
return d
return d, err
}

// success
b.Freeze()
return b
return b, nil
}

/*
NewBufferFromReaderUntil constructs an immutable buffer containing data sourced from an io.Reader object.
It will continue reading until it encounters the delimiter value, or an error occurs. The delimiter will not be included in the returned data.
If an error is encountered before the delimiter value, the error will be returned along with the data read up until that point.
*/
func NewBufferFromReaderUntil(r io.Reader, delim byte) *LockedBuffer {
func NewBufferFromReaderUntil(r io.Reader, delim byte) (*LockedBuffer, error) {
// Construct a buffer with a data page that fills an entire memory page.
b := NewBuffer(os.Getpagesize())

Expand Down Expand Up @@ -140,35 +140,37 @@ func NewBufferFromReaderUntil(r io.Reader, delim byte) *LockedBuffer {
// if instead there was an error, we're done early
if i == 0 { // no data read
b.Destroy()
return newNullBuffer()
return newNullBuffer(), err
}
d := NewBuffer(i)
d.Copy(b.Bytes()[:i])
d.Freeze()
b.Destroy()
return d
return d, err
}
// we managed to read a byte, check if it was the delimiter
// note that errors are ignored in this case where we got data
if b.Bytes()[i] == delim {
if i == 0 {
// if first byte was delimiter, there's no data to return
b.Destroy()
return newNullBuffer()
return newNullBuffer(), nil
}
d := NewBuffer(i)
d.Copy(b.Bytes()[:i])
d.Freeze()
b.Destroy()
return d
return d, nil
}
}
}

/*
NewBufferFromEntireReader reads from an io.Reader into an immutable buffer. It will continue reading until EOF or any other error.
NewBufferFromEntireReader reads from an io.Reader into an immutable buffer. It will continue reading until EOF.
A nil error is returned precisely when we managed to read all the way until EOF. Any data read is returned in either case.
*/
func NewBufferFromEntireReader(r io.Reader) *LockedBuffer {
func NewBufferFromEntireReader(r io.Reader) (*LockedBuffer, error) {
// Create a buffer with a data region of one page size.
b := NewBuffer(os.Getpagesize())

Expand All @@ -189,17 +191,21 @@ func NewBufferFromEntireReader(r io.Reader) *LockedBuffer {
read += n

if err != nil {
// Suppress EOF error
if err == io.EOF {
err = nil
}
// We're done, return the data.
if read == 0 {
// No data read.
b.Destroy()
return newNullBuffer()
return newNullBuffer(), err
}
d := NewBuffer(read)
d.Copy(b.Bytes()[:read])
d.Freeze()
b.Destroy()
return d
return d, err
}

// If we've filled this buffer, grow it by another page size.
Expand Down
110 changes: 94 additions & 16 deletions buffer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package memguard
import (
"bytes"
"crypto/rand"
"errors"
"io"
"io/ioutil"
mrand "math/rand"
Expand Down Expand Up @@ -99,7 +100,10 @@ func TestNewBufferFromBytes(t *testing.T) {
}

func TestNewBufferFromReader(t *testing.T) {
b := NewBufferFromReader(rand.Reader, 4096)
b, err := NewBufferFromReader(rand.Reader, 4096)
if err != nil {
t.Error(err)
}
if b.Size() != 4096 {
t.Error("buffer of incorrect size")
}
Expand All @@ -112,7 +116,10 @@ func TestNewBufferFromReader(t *testing.T) {
b.Destroy()

r := bytes.NewReader([]byte("yellow submarine"))
b = NewBufferFromReader(r, 16)
b, err = NewBufferFromReader(r, 16)
if err != nil {
t.Error(err)
}
if b.Size() != 16 {
t.Error("buffer of incorrect size")
}
Expand All @@ -125,7 +132,10 @@ func TestNewBufferFromReader(t *testing.T) {
b.Destroy()

r = bytes.NewReader([]byte("yellow submarine"))
b = NewBufferFromReader(r, 17)
b, err = NewBufferFromReader(r, 17)
if err == nil {
t.Error("expected error got nil;", err)
}
if b.Size() != 16 {
t.Error("incorrect size")
}
Expand All @@ -138,7 +148,10 @@ func TestNewBufferFromReader(t *testing.T) {
b.Destroy()

r = bytes.NewReader([]byte(""))
b = NewBufferFromReader(r, 32)
b, err = NewBufferFromReader(r, 32)
if err == nil {
t.Error("expected error got nil")
}
if b.IsAlive() {
t.Error("expected destroyed buffer")
}
Expand All @@ -149,7 +162,10 @@ func TestNewBufferFromReader(t *testing.T) {
t.Error("expected nul sized buffer")
}
r = bytes.NewReader([]byte("yellow submarine"))
b = NewBufferFromReader(r, 0)
b, err = NewBufferFromReader(r, 0)
if err != nil {
t.Error(err)
}
if b.Bytes() != nil {
t.Error("data slice should be nil")
}
Expand Down Expand Up @@ -185,7 +201,10 @@ func TestNewBufferFromReaderUntil(t *testing.T) {
data := make([]byte, 5000)
data[4999] = 1
r := bytes.NewReader(data)
b := NewBufferFromReaderUntil(r, 1)
b, err := NewBufferFromReaderUntil(r, 1)
if err != nil {
t.Error(err)
}
if b.Size() != 4999 {
t.Error("buffer has incorrect size")
}
Expand All @@ -200,7 +219,10 @@ func TestNewBufferFromReaderUntil(t *testing.T) {
b.Destroy()

r = bytes.NewReader(data[:32])
b = NewBufferFromReaderUntil(r, 1)
b, err = NewBufferFromReaderUntil(r, 1)
if err == nil {
t.Error("expected error got nil")
}
if b.Size() != 32 {
t.Error("invalid size")
}
Expand All @@ -215,7 +237,10 @@ func TestNewBufferFromReaderUntil(t *testing.T) {
b.Destroy()

r = bytes.NewReader([]byte{'x'})
b = NewBufferFromReaderUntil(r, 'x')
b, err = NewBufferFromReaderUntil(r, 'x')
if err != nil {
t.Error(err)
}
if b.Size() != 0 {
t.Error("expected no data")
}
Expand All @@ -224,7 +249,10 @@ func TestNewBufferFromReaderUntil(t *testing.T) {
}

r = bytes.NewReader([]byte(""))
b = NewBufferFromReaderUntil(r, 1)
b, err = NewBufferFromReaderUntil(r, 1)
if err == nil {
t.Error("expected error got nil")
}
if b.IsAlive() {
t.Error("expected destroyed buffer")
}
Expand All @@ -236,7 +264,10 @@ func TestNewBufferFromReaderUntil(t *testing.T) {
}

rr := new(s)
b = NewBufferFromReaderUntil(rr, 1)
b, err = NewBufferFromReaderUntil(rr, 1)
if err != nil {
t.Error(err)
}
if b.Size() != 4999 {
t.Error("invalid size")
}
Expand Down Expand Up @@ -267,9 +298,25 @@ func (reader *ss) Read(p []byte) (n int, err error) {
return 1, nil
}

type se struct {
count int
}

func (reader *se) Read(p []byte) (n int, err error) {
copy(p, []byte{0})
reader.count++
if reader.count == 5000 {
return 1, errors.New("shut up bro")
}
return 1, nil
}

func TestNewBufferFromEntireReader(t *testing.T) {
r := bytes.NewReader([]byte("yellow submarine"))
b := NewBufferFromEntireReader(r)
b, err := NewBufferFromEntireReader(r)
if err != nil {
t.Error(err)
}
if b.Size() != 16 {
t.Error("incorrect size", b.Size())
}
Expand All @@ -284,7 +331,10 @@ func TestNewBufferFromEntireReader(t *testing.T) {
data := make([]byte, 16000)
ScrambleBytes(data)
r = bytes.NewReader(data)
b = NewBufferFromEntireReader(r)
b, err = NewBufferFromEntireReader(r)
if err != nil {
t.Error(err)
}
if b.Size() != len(data) {
t.Error("incorrect size", b.Size())
}
Expand All @@ -297,7 +347,10 @@ func TestNewBufferFromEntireReader(t *testing.T) {
b.Destroy()

r = bytes.NewReader([]byte{})
b = NewBufferFromEntireReader(r)
b, err = NewBufferFromEntireReader(r)
if err != nil {
t.Error(err)
}
if b.Size() != 0 {
t.Error("buffer should be nil size")
}
Expand All @@ -306,7 +359,10 @@ func TestNewBufferFromEntireReader(t *testing.T) {
}

rr := new(ss)
b = NewBufferFromEntireReader(rr)
b, err = NewBufferFromEntireReader(rr)
if err != nil {
t.Error(err)
}
if b.Size() != 4999 {
t.Error("incorrect size", b.Size())
}
Expand All @@ -318,6 +374,22 @@ func TestNewBufferFromEntireReader(t *testing.T) {
}
b.Destroy()

re := new(se)
b, err = NewBufferFromEntireReader(re)
if err == nil {
t.Error("expected error got nil")
}
if b.Size() != 5000 {
t.Error(b.Size())
}
if !b.EqualTo(make([]byte, 5000)) {
t.Error("incorrect data")
}
if b.IsMutable() {
t.Error("buffer should be immutable")
}
b.Destroy()

// real world test
f, err := os.Open("LICENSE")
if err != nil {
Expand All @@ -331,7 +403,10 @@ func TestNewBufferFromEntireReader(t *testing.T) {
if err != nil {
t.Error(err)
}
b = NewBufferFromEntireReader(f)
b, err = NewBufferFromEntireReader(f)
if err != nil {
t.Error(err)
}
if !b.EqualTo(data) {
t.Error("incorrect data")
}
Expand Down Expand Up @@ -797,7 +872,10 @@ func TestBytes(t *testing.T) {

func TestReader(t *testing.T) {
b := NewBufferRandom(32)
c := NewBufferFromReader(b.Reader(), 32)
c, err := NewBufferFromReader(b.Reader(), 32)
if err != nil {
t.Error(err)
}
if !bytes.Equal(b.Bytes(), c.Bytes()) {
t.Error("data not equal")
}
Expand Down
8 changes: 4 additions & 4 deletions examples/socketkey/socketkey.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func SocketKey(size int) {
memguard.CatchSignal(func(s os.Signal) {
fmt.Println("Received signal:", s.String())
listener.Close()
})
}, os.Interrupt, os.Kill)

// Purge the session before returning.
defer memguard.Purge()
Expand Down Expand Up @@ -90,9 +90,9 @@ func SocketKey(size int) {
}

// Read the data directly into a guarded memory region
buf := memguard.NewBufferFromReader(conn, size)
if buf.Size() != size {
memguard.SafePanic("not enough data read")
buf, err := memguard.NewBufferFromReader(conn, size)
if err != nil {
memguard.SafePanic(err)
}
defer buf.Destroy()
conn.Close()
Expand Down
Loading

0 comments on commit 153b5da

Please sign in to comment.