From 5045e79978783036bd14d27ae37fb6275dc2f4bd Mon Sep 17 00:00:00 2001 From: Bob Glickstein Date: Fri, 26 Jul 2024 10:25:53 -0700 Subject: [PATCH] Checkpoint. Some renaming. --- parallel/parallel.go | 7 ++-- seqs/accum_test.go | 34 +++++++++++++---- seqs/chan.go | 12 +++--- seqs/chan_test.go | 10 ++--- seqs/seqs.go | 90 ++++++++++++++++++++++++++++++++++++++------ seqs/seqs_test.go | 82 +++++++++++++++++++++++++++++++++------- seqs/string.go | 38 +++++++++++++++++++ seqs/string_test.go | 44 ++++++++++++++++++++++ seqs/zip_test.go | 2 +- 9 files changed, 270 insertions(+), 49 deletions(-) create mode 100644 seqs/string.go create mode 100644 seqs/string_test.go diff --git a/parallel/parallel.go b/parallel/parallel.go index 1f2f9d0..fb875e7 100644 --- a/parallel/parallel.go +++ b/parallel/parallel.go @@ -70,9 +70,8 @@ func Values[T any, F ~func(context.Context, int) (T, error)](ctx context.Context // // The caller gets an iterator over the values produced // and a non-nil pointer to an error. -// Once the iterator has been consumed, -// the caller may dereference the error pointer to see if any worker failed. -// There is the risk of a data race if the caller dereferences the error pointer before the iterator is consumed. +// The caller may dereference the error pointer to see if any worker failed, +// but not before the iterator has been fully consumed. // The error (if there is one) is of type [Error], // whose N field indicates which worker failed. func Producers[T any, F ~func(context.Context, int, func(T) error) error](ctx context.Context, n int, f F) (iter.Seq[T], *error) { @@ -104,7 +103,7 @@ func Producers[T any, F ~func(context.Context, int, func(T) error) error](ctx co close(ch) }() - return seqs.Chan(ch), &err + return seqs.FromChan(ch), &err } // Consumers launches n parallel workers each consuming values supplied by the caller. diff --git a/seqs/accum_test.go b/seqs/accum_test.go index a870bec..f39e955 100644 --- a/seqs/accum_test.go +++ b/seqs/accum_test.go @@ -1,18 +1,36 @@ package seqs import ( + "fmt" "slices" "testing" ) func TestAccum(t *testing.T) { - var ( - inp = slices.Values([]int{1, 2, 3, 4}) - a = Accum(inp, func(a, b int) int { return a + b }) - got = slices.Collect(a) - want = []int{1, 3, 6, 10} - ) - if !slices.Equal(got, want) { - t.Errorf("got %v, want [1 3 6 10]", got) + cases := []struct { + inp []int + want []int + }{{}, { + inp: []int{1}, + want: []int{1}, + }, { + inp: []int{1, 2}, + want: []int{1, 3}, + }, { + inp: []int{1, 2, 3}, + want: []int{1, 3, 6}, + }} + + for i, tc := range cases { + t.Run(fmt.Sprintf("case_%02d", i+1), func(t *testing.T) { + var ( + inp = slices.Values(tc.inp) + a = Accum(inp, func(a, b int) int { return a + b }) + got = slices.Collect(a) + ) + if !slices.Equal(got, tc.want) { + t.Errorf("got %v, want %v", got, tc.want) + } + }) } } diff --git a/seqs/chan.go b/seqs/chan.go index 03667bb..b17a191 100644 --- a/seqs/chan.go +++ b/seqs/chan.go @@ -5,8 +5,8 @@ import ( "iter" ) -// Chan produces an [iter.Seq] over the contents of a channel. -func Chan[T any](inp <-chan T) iter.Seq[T] { +// FromChan produces an [iter.Seq] over the contents of a channel. +func FromChan[T any](inp <-chan T) iter.Seq[T] { return func(yield func(T) bool) { for x := range inp { if !yield(x) { @@ -16,13 +16,13 @@ func Chan[T any](inp <-chan T) iter.Seq[T] { } } -// ChanContext produces an [iter.Seq] over the contents of a channel. +// FromChanContext produces an [iter.Seq] over the contents of a channel. // It stops at the end of the channel or when the given context is canceled. // // The caller can dereference the returned error pointer to check for errors // (such as context cancellation), // but only after iteration is done. -func ChanContext[T any](ctx context.Context, inp <-chan T) (iter.Seq[T], *error) { +func FromChanContext[T any](ctx context.Context, inp <-chan T) (iter.Seq[T], *error) { var err error f := func(yield func(T) bool) { @@ -76,9 +76,11 @@ func ToChanContext[T any](ctx context.Context, f iter.Seq[T]) (<-chan T, *error) defer close(ch) for val := range f { + // This extra check helps to ensure that context cancellation "wins" when both cases in the select can proceed. if err = ctx.Err(); err != nil { return } + select { case ch <- val: // OK, do nothing. @@ -107,5 +109,5 @@ func Go[T any, F ~func(chan<- T) error](f F) (iter.Seq[T], *error) { close(ch) }() - return Chan(ch), &err + return FromChan(ch), &err } diff --git a/seqs/chan_test.go b/seqs/chan_test.go index 6612a6b..f64bb83 100644 --- a/seqs/chan_test.go +++ b/seqs/chan_test.go @@ -26,7 +26,7 @@ func TestToChan(t *testing.T) { func TestToChanContext(t *testing.T) { var ( ch1 = make(chan int, 1) - seq1 = Chan(ch1) + seq1 = FromChan(ch1) ctx = context.Background() ) ctx, cancel := context.WithCancel(ctx) @@ -63,7 +63,7 @@ func TestToChanContext(t *testing.T) { } } -func TestChan(t *testing.T) { +func TestFromChan(t *testing.T) { ch := make(chan int) go func() { for i := 0; i < 3; i++ { @@ -73,7 +73,7 @@ func TestChan(t *testing.T) { }() var ( - seq = Chan(ch) + seq = FromChan(ch) got = slices.Collect(seq) want = []int{0, 1, 2} ) @@ -82,7 +82,7 @@ func TestChan(t *testing.T) { } } -func TestChanContext(t *testing.T) { +func TestFromChanContext(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -98,7 +98,7 @@ func TestChanContext(t *testing.T) { } }() - it, errptr := ChanContext(ctx, ch) + it, errptr := FromChanContext(ctx, ch) next, stop := iter.Pull(it) defer stop() if _, ok := next(); !ok { diff --git a/seqs/seqs.go b/seqs/seqs.go index 1846971..df24ef2 100644 --- a/seqs/seqs.go +++ b/seqs/seqs.go @@ -1,6 +1,7 @@ package seqs import ( + "cmp" "iter" "slices" ) @@ -16,8 +17,8 @@ type Pair[T, U any] struct { Y U } -// Seq1 changes an [iter.Seq2] to an [iter.Seq] of [Pair]s. -func Seq1[T, U any](inp iter.Seq2[T, U]) iter.Seq[Pair[T, U]] { +// ToPairs changes an [iter.Seq2] to an [iter.Seq] of [Pair]s. +func ToPairs[T, U any](inp iter.Seq2[T, U]) iter.Seq[Pair[T, U]] { return func(yield func(Pair[T, U]) bool) { for x, y := range inp { if !yield(Pair[T, U]{X: x, Y: y}) { @@ -27,6 +28,28 @@ func Seq1[T, U any](inp iter.Seq2[T, U]) iter.Seq[Pair[T, U]] { } } +// Left changes an [iter.Seq2] to an [iter.Seq] by dropping the second value. +func Left[T, U any](inp iter.Seq2[T, U]) iter.Seq[T] { + return func(yield func(T) bool) { + for x := range inp { + if !yield(x) { + return + } + } + } +} + +// Right changes an [iter.Seq2] to an [iter.Seq] by dropping the first value. +func Right[T, U any](inp iter.Seq2[T, U]) iter.Seq[U] { + return func(yield func(U) bool) { + for _, y := range inp { + if !yield(y) { + return + } + } + } +} + // Enumerate changes an [iter.Seq] to an [iter.Seq2] of (index, val) pairs. func Enumerate[T any](inp iter.Seq[T]) iter.Seq2[int, T] { return func(yield func(int, T) bool) { @@ -40,8 +63,8 @@ func Enumerate[T any](inp iter.Seq[T]) iter.Seq2[int, T] { } } -// Seq2 changes an [iter.Seq] of [Pair]s to an [iter.Seq2]. -func Seq2[T, U any](inp iter.Seq[Pair[T, U]]) iter.Seq2[T, U] { +// FromPairs changes an [iter.Seq] of [Pair]s to an [iter.Seq2]. +func FromPairs[T, U any](inp iter.Seq[Pair[T, U]]) iter.Seq2[T, U] { return func(yield func(T, U) bool) { for val := range inp { if !yield(val.X, val.Y) { @@ -51,20 +74,63 @@ func Seq2[T, U any](inp iter.Seq[Pair[T, U]]) iter.Seq2[T, U] { } } -// String produces an [iter.Seq2] over position-rune pairs in a string. -// The position of each rune is measured in bytes from the beginning of the string. -func String(inp string) iter.Seq2[int, rune] { - return func(yield func(int, rune) bool) { - for i, r := range inp { - if !yield(i, r) { - return - } +// Compare performs an elementwise comparison of two sequences. +// It returns the result of [cmp.Compare] on the first pair of unequal elements. +// If a ends before b, Compare returns -1. +// If b ends before a, Compare returns 1. +// If the sequences are equal, Compare returns 0. +func Compare[T cmp.Ordered](a, b iter.Seq[T]) int { + return CompareFunc(a, b, cmp.Compare) +} + +// CompareFunc performs an elementwise comparison of two sequences, using a custom comparison function. +// The function should return a negative number if the first argument is less than the second, +// a positive number if the first argument is greater than the second, +// and zero if the arguments are equal. +// +// CompareFunc returns the result of f on the first pair of unequal elements. +// If a ends before b, CompareFunc returns -1. +// If b ends before a, CompareFunc returns 1. +// If the sequences are equal, CompareFunc returns 0. +func CompareFunc[T any](a, b iter.Seq[T], f func(T, T) int) int { + anext, astop := iter.Pull(a) + defer astop() + + bnext, bstop := iter.Pull(b) + defer bstop() + + aOK, bOK := true, true + + var aVal, bVal T + + for { + if aOK { + aVal, aOK = anext() + } + if bOK { + bVal, bOK = bnext() + } + if !aOK && !bOK { + return 0 + } + if !aOK { + return -1 + } + if !bOK { + return 1 + } + if cmp := f(aVal, bVal); cmp != 0 { + return cmp } } } // Empty is an empty sequence that can be used where an [iter.Seq] is expected. +// Usage note: you generally don't want to call this function, +// just refer to it as Empty[typename]. func Empty[T any](func(T) bool) {} // Empty2 is an empty sequence that can be used where an [iter.Seq2] is expected. +// Usage note: you generally don't want to call this function, +// just refer to it as Empty2[typename1, typename2]. func Empty2[T, U any](func(T, U) bool) {} diff --git a/seqs/seqs_test.go b/seqs/seqs_test.go index fd97f35..e3b5bb1 100644 --- a/seqs/seqs_test.go +++ b/seqs/seqs_test.go @@ -7,6 +7,18 @@ import ( "testing" ) +func TestEmpty(t *testing.T) { + got := slices.Collect(Empty[int]) + if len(got) != 0 { + t.Errorf("got %v, want []", got) + } + + got2 := slices.Collect(ToPairs(Empty2[int, int])) + if len(got2) != 0 { + t.Errorf("got %v, want []", got2) + } +} + func TestFrom(t *testing.T) { var ( slice = []int{1, 2, 3} @@ -18,11 +30,11 @@ func TestFrom(t *testing.T) { } } -func TestSeq1(t *testing.T) { +func TestToPairs(t *testing.T) { var ( m = map[int]int{1: 2, 3: 4} seq2 = maps.All(m) - seq1 = Seq1(seq2) + seq1 = ToPairs(seq2) got = slices.Collect(seq1) want = []Pair[int, int]{{X: 1, Y: 2}, {X: 3, Y: 4}} ) @@ -37,7 +49,7 @@ func TestEnumerate(t *testing.T) { slice = []string{"alice", "bob", "charlie"} seq = slices.Values(slice) enum = Enumerate(seq) - enum1 = Seq1(enum) + enum1 = ToPairs(enum) got = slices.Collect(enum1) want = []Pair[int, string]{{X: 0, Y: "alice"}, {X: 1, Y: "bob"}, {X: 2, Y: "charlie"}} ) @@ -46,12 +58,12 @@ func TestEnumerate(t *testing.T) { } } -func TestSeq2(t *testing.T) { +func TestFromPairs(t *testing.T) { var ( m = map[int]int{1: 2, 3: 4} seq2 = maps.All(m) - seq1 = Seq1(seq2) - seq2a = Seq2(seq1) + seq1 = ToPairs(seq2) + seq2a = FromPairs(seq1) got = maps.Collect(seq2a) want = map[int]int{1: 2, 3: 4} ) @@ -60,14 +72,56 @@ func TestSeq2(t *testing.T) { } } -func TestString(t *testing.T) { +func TestLeftRight(t *testing.T) { var ( - seq = String("abc") - s1 = Seq1(seq) - got = slices.Collect(s1) - want = []Pair[int, rune]{{X: 0, Y: 'a'}, {X: 1, Y: 'b'}, {X: 2, Y: 'c'}} + left = []int{1, 2, 3, 4, 5} + right = []int{6, 7, 8, 9, 10} + zipped = Zip(slices.Values(left), slices.Values(right)) ) - if !slices.Equal(got, want) { - t.Errorf("got %v, want %v", got, want) - } + t.Run("left", func(t *testing.T) { + got := slices.Collect(Left(zipped)) + if !slices.Equal(got, left) { + t.Errorf("got %v, want %v", got, left) + } + }) + t.Run("right", func(t *testing.T) { + got := slices.Collect(Right(zipped)) + if !slices.Equal(got, right) { + t.Errorf("got %v, want %v", got, right) + } + }) +} + +func TestCompare(t *testing.T) { + var ( + a = []int{1, 2, 3} + b = []int{1, 2, 4} + c = []int{1, 2} + d = []int{1, 2, 3, 4} + ) + t.Run("equal", func(t *testing.T) { + if got := Compare(slices.Values(a), slices.Values(a)); got != 0 { + t.Errorf("got %d, want 0", got) + } + }) + t.Run("shorter", func(t *testing.T) { + if got := Compare(slices.Values(a), slices.Values(d)); got >= 0 { + t.Errorf("got %d, want < 0", got) + } + }) + t.Run("longer", func(t *testing.T) { + if got := Compare(slices.Values(a), slices.Values(c)); got <= 0 { + t.Errorf("got %d, want > 0", got) + } + }) + t.Run("less", func(t *testing.T) { + if got := Compare(slices.Values(a), slices.Values(b)); got >= 0 { + t.Errorf("got %d, want < 0", got) + } + }) + t.Run("greater", func(t *testing.T) { + if got := Compare(slices.Values(b), slices.Values(a)); got <= 0 { + t.Errorf("got %d, want > 0", got) + } + }) } diff --git a/seqs/string.go b/seqs/string.go new file mode 100644 index 0000000..5c4bd96 --- /dev/null +++ b/seqs/string.go @@ -0,0 +1,38 @@ +package seqs + +import "iter" + +// String produces an [iter.Seq2] over position-rune pairs in a string. +// The position of each rune is measured in bytes from the beginning of the string. +func String(inp string) iter.Seq2[int, rune] { + return func(yield func(int, rune) bool) { + for i, r := range inp { + if !yield(i, r) { + return + } + } + } +} + +// Bytes returns an iterator over the bytes in a string. +func Bytes(inp string) iter.Seq[byte] { + return func(yield func(byte) bool) { + for i := 0; i < len(inp); i++ { + if !yield(inp[i]) { + return + } + } + } +} + +// Runes returns an iterator over the runes in a string. +// This is the same as Right(String(inp)). +func Runes(inp string) iter.Seq[rune] { + return func(yield func(rune) bool) { + for _, r := range inp { + if !yield(r) { + return + } + } + } +} diff --git a/seqs/string_test.go b/seqs/string_test.go new file mode 100644 index 0000000..e6607a7 --- /dev/null +++ b/seqs/string_test.go @@ -0,0 +1,44 @@ +package seqs + +import ( + "slices" + "testing" +) + +func TestString(t *testing.T) { + const s = "こんにちは" + + t.Run("string", func(t *testing.T) { + var ( + seq = String(s) + pairs = ToPairs(seq) + got = slices.Collect(pairs) + want = []Pair[int, rune]{{X: 0, Y: 'こ'}, {X: 3, Y: 'ん'}, {X: 6, Y: 'に'}, {X: 9, Y: 'ち'}, {X: 12, Y: 'は'}} + ) + if !slices.Equal(got, want) { + t.Errorf("got %v, want %v", got, want) + } + }) + + t.Run("runes", func(t *testing.T) { + var ( + seq = Runes(s) + got = slices.Collect(seq) + want = []rune{'こ', 'ん', 'に', 'ち', 'は'} + ) + if !slices.Equal(got, want) { + t.Errorf("got %v, want %v", got, want) + } + }) + + t.Run("bytes", func(t *testing.T) { + var ( + seq = Bytes(s) + got = slices.Collect(seq) + want = []byte{227, 129, 147, 227, 130, 147, 227, 129, 171, 227, 129, 161, 227, 129, 175} + ) + if !slices.Equal(got, want) { + t.Errorf("got %v, want %v", got, want) + } + }) +} diff --git a/seqs/zip_test.go b/seqs/zip_test.go index 784f4ac..5a006a2 100644 --- a/seqs/zip_test.go +++ b/seqs/zip_test.go @@ -10,7 +10,7 @@ func TestZip(t *testing.T) { inp1 = slices.Values([]int{1, 2, 3}) inp2 = slices.Values([]string{"a", "b", "c", "d"}) z = Zip(inp1, inp2) - z1 = Seq1(z) + z1 = ToPairs(z) got = slices.Collect(z1) want = []Pair[int, string]{{X: 1, Y: "a"}, {X: 2, Y: "b"}, {X: 3, Y: "c"}, {Y: "d"}} )