diff --git a/script.go b/script.go index ea70a56..c471f74 100644 --- a/script.go +++ b/script.go @@ -514,7 +514,8 @@ func (p *Pipe) FilterScan(filter func(string, io.Writer)) *Pipe { // First produces only the first n lines of the pipe's contents, or all the // lines if there are less than n. If n is zero or negative, there is no output -// at all. +// at all. When n lines have been produced, First stops reading its input and +// sends EOF to its output. func (p *Pipe) First(n int) *Pipe { if p.Error() != nil { return p @@ -522,13 +523,15 @@ func (p *Pipe) First(n int) *Pipe { if n <= 0 { return NewPipe() } - i := 0 - return p.FilterScan(func(line string, w io.Writer) { - if i >= n { - return + return p.Filter(func(r io.Reader, w io.Writer) error { + scanner := newScanner(r) + for i := 0; i < n && scanner.Scan(); i++ { + _, err := fmt.Fprintln(w, scanner.Text()) + if err != nil { + return err + } } - fmt.Fprintln(w, line) - i++ + return scanner.Err() }) } diff --git a/script_test.go b/script_test.go index dd12e3b..b19b915 100644 --- a/script_test.go +++ b/script_test.go @@ -573,6 +573,24 @@ func TestFirstHasNoEffectGivenLessThanNInputLines(t *testing.T) { } } +func TestFirstDoesNotConsumeUnnecessaryData(t *testing.T) { + t.Parallel() + // First uses a 4096-byte buffer, so will always read at least + // that much, but no more (once N lines have been read). + r := strings.NewReader(strings.Repeat("line\n", 1000)) + got, err := script.NewPipe().WithReader(r).First(1).String() + if err != nil { + t.Fatal(err) + } + want := "line\n" + if want != got { + t.Errorf("want output %q, got %q", want, got) + } + if r.Len() == 0 { + t.Errorf("no data left in reader") + } +} + func TestFreqHandlesLongLines(t *testing.T) { t.Parallel() got, err := script.Echo(longLine).Freq().Slice()