diff --git a/fill.go b/fill.go new file mode 100644 index 0000000..353d6a6 --- /dev/null +++ b/fill.go @@ -0,0 +1,21 @@ +package plumbing + +import "io" + +type fillReader struct { + b byte +} + +func (r *fillReader) Read(p []byte) (int, error) { + for i := range p { + p[i] = r.b + } + + return len(p), nil +} + +// FillReader returns an io.Reader such that Read calls return an unlimited +// stream of b bytes. +func FillReader(b byte) io.Reader { + return &fillReader{b} +} diff --git a/padded.go b/padded.go index cb7f027..675c85d 100644 --- a/padded.go +++ b/padded.go @@ -1,7 +1,6 @@ package plumbing import ( - "bytes" "io" ) @@ -9,6 +8,5 @@ import ( // fewer than n bytes are available from r then any remaining bytes return // fill instead. func PaddedReader(r io.Reader, n int64, fill byte) io.Reader { - // Naive, but works - return io.LimitReader(io.MultiReader(r, bytes.NewBuffer(bytes.Repeat([]byte{fill}, int(n)))), n) + return io.LimitReader(io.MultiReader(r, FillReader(fill)), n) } diff --git a/zero.go b/zero.go new file mode 100644 index 0000000..943a35d --- /dev/null +++ b/zero.go @@ -0,0 +1,18 @@ +package plumbing + +import "io" + +type devZero struct { + io.Reader +} + +func (w *devZero) Write(p []byte) (int, error) { + return len(p), nil +} + +// DevZero returns an io.ReadWriter that behaves like /dev/zero such that Read +// calls return an unlimited stream of zero bytes and all Write calls succeed +// without doing anything. +func DevZero() io.ReadWriter { + return &devZero{FillReader(0)} +} diff --git a/zero_test.go b/zero_test.go new file mode 100644 index 0000000..32d5d40 --- /dev/null +++ b/zero_test.go @@ -0,0 +1,36 @@ +package plumbing_test + +import ( + "bytes" + "io" + "testing" + + "github.com/bodgit/plumbing" + "github.com/stretchr/testify/assert" +) + +const limit = 10 + +func TestDevZero(t *testing.T) { + t.Parallel() + + rw := plumbing.DevZero() + b := new(bytes.Buffer) + + n, err := io.Copy(b, io.LimitReader(rw, limit)) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, limit, int(n)) + assert.Equal(t, limit, b.Len()) + assert.Equal(t, bytes.Repeat([]byte{0x00}, limit), b.Bytes()) + + n, err = io.Copy(rw, b) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, limit, int(n)) + assert.Equal(t, 0, b.Len()) +}