Skip to content

Commit

Permalink
chore: add more utilities to xiter
Browse files Browse the repository at this point in the history
This is breaking change PR:
- Reorder `func` in xiter package, so that it's always first. That helps with several transformations in one place. See [this](golang/go#61898 (comment)).
- Add `Single` and `Single2` iterators.
- Add `Find` and `Find2` iterators.
- Rename `Fold` to `Reduce`.
- Other small changes.

Signed-off-by: Dmitriy Matrenichev <dmitry.matrenichev@siderolabs.com>
  • Loading branch information
DmitriyMV committed Oct 31, 2024
1 parent f3c5a2b commit e847d2a
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 57 deletions.
68 changes: 46 additions & 22 deletions xiter/xiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func Equal2[K, V comparable](x, y iter.Seq2[K, V]) bool {
}

// EqualFunc reports whether the two sequences are equal according to the function f.
func EqualFunc[V1, V2 any](x iter.Seq[V1], y iter.Seq[V2], f func(V1, V2) bool) bool {
func EqualFunc[V1, V2 any](f func(V1, V2) bool, x iter.Seq[V1], y iter.Seq[V2]) bool {
next, stop := iter.Pull(y)
defer stop()

Expand All @@ -87,7 +87,7 @@ func EqualFunc[V1, V2 any](x iter.Seq[V1], y iter.Seq[V2], f func(V1, V2) bool)
}

// EqualFunc2 reports whether the two sequences are equal according to the function f.
func EqualFunc2[K1, V1, K2, V2 any](x iter.Seq2[K1, V1], y iter.Seq2[K2, V2], f func(K1, V1, K2, V2) bool) bool {
func EqualFunc2[K1, V1, K2, V2 any](f func(K1, V1, K2, V2) bool, x iter.Seq2[K1, V1], y iter.Seq2[K2, V2]) bool {
next, stop := iter.Pull2(y)
defer stop()

Expand All @@ -104,7 +104,7 @@ func EqualFunc2[K1, V1, K2, V2 any](x iter.Seq2[K1, V1], y iter.Seq2[K2, V2], f
}

// Map returns an iterator over f applied to seq.
func Map[In, Out any](seq iter.Seq[In], f func(In) Out) iter.Seq[Out] {
func Map[In, Out any](f func(In) Out, seq iter.Seq[In]) iter.Seq[Out] {
return func(yield func(Out) bool) {
for in := range seq {
if !yield(f(in)) {
Expand All @@ -115,7 +115,7 @@ func Map[In, Out any](seq iter.Seq[In], f func(In) Out) iter.Seq[Out] {
}

// Map2 returns an iterator over f applied to seq.
func Map2[KIn, VIn, KOut, VOut any](seq iter.Seq2[KIn, VIn], f func(KIn, VIn) (KOut, VOut)) iter.Seq2[KOut, VOut] {
func Map2[KIn, VIn, KOut, VOut any](f func(KIn, VIn) (KOut, VOut), seq iter.Seq2[KIn, VIn]) iter.Seq2[KOut, VOut] {
return func(yield func(KOut, VOut) bool) {
for k, v := range seq {
if !yield(f(k, v)) {
Expand All @@ -126,7 +126,7 @@ func Map2[KIn, VIn, KOut, VOut any](seq iter.Seq2[KIn, VIn], f func(KIn, VIn) (K
}

// Filter returns an iterator over the elements in seq for which f returns true.
func Filter[V any](seq iter.Seq[V], f func(V) bool) iter.Seq[V] {
func Filter[V any](f func(V) bool, seq iter.Seq[V]) iter.Seq[V] {
return func(yield func(V) bool) {
for e := range seq {
if !f(e) {
Expand All @@ -141,7 +141,7 @@ func Filter[V any](seq iter.Seq[V], f func(V) bool) iter.Seq[V] {
}

// Filter2 returns an iterator over the elements in seq for which f returns true.
func Filter2[K, V any](seq iter.Seq2[K, V], f func(K, V) bool) iter.Seq2[K, V] {
func Filter2[K, V any](f func(K, V) bool, seq iter.Seq2[K, V]) iter.Seq2[K, V] {
return func(yield func(K, V) bool) {
for k, v := range seq {
if !f(k, v) {
Expand All @@ -155,8 +155,8 @@ func Filter2[K, V any](seq iter.Seq2[K, V], f func(K, V) bool) iter.Seq2[K, V] {
}
}

// IterKeys returns an iterator over the keys in seq.
func IterKeys[K, V any](seq iter.Seq2[K, V]) iter.Seq[K] {
// Keys returns an iterator over the keys in seq.
func Keys[K, V any](seq iter.Seq2[K, V]) iter.Seq[K] {
return func(yield func(K) bool) {
for k := range seq {
if !yield(k) {
Expand All @@ -178,7 +178,7 @@ func Values[K, V any](seq iter.Seq2[K, V]) iter.Seq[V] {
}

// ToSeq returns an iterator where each element is the result of applying fn to the elements in seq.
func ToSeq[K, V, R any](seq iter.Seq2[K, V], fn func(K, V) R) iter.Seq[R] {
func ToSeq[K, V, R any](fn func(K, V) R, seq iter.Seq2[K, V]) iter.Seq[R] {
return func(yield func(R) bool) {
for k, v := range seq {
if !yield(fn(k, v)) {
Expand All @@ -189,7 +189,7 @@ func ToSeq[K, V, R any](seq iter.Seq2[K, V], fn func(K, V) R) iter.Seq[R] {
}

// ToSeq2 returns an iterator where each element is the result of applying fn to the elements in seq.
func ToSeq2[V1, R1, R2 any](seq iter.Seq[V1], fn func(V1) (R1, R2)) iter.Seq2[R1, R2] {
func ToSeq2[V1, R1, R2 any](fn func(V1) (R1, R2), seq iter.Seq[V1]) iter.Seq2[R1, R2] {
return func(yield func(R1, R2) bool) {
for v := range seq {
if !yield(fn(v)) {
Expand All @@ -199,30 +199,54 @@ func ToSeq2[V1, R1, R2 any](seq iter.Seq[V1], fn func(V1) (R1, R2)) iter.Seq2[R1
}
}

// Fold applies f to the elements in seq, starting with the initial value.
func Fold[V, R any](seq iter.Seq[V], initial R, f func(R, V) R) R {
result := initial

// Reduce applies f to the elements in seq, starting with the initial value.
func Reduce[V, R any](f func(R, V) R, sum R, seq iter.Seq[V]) R {
for e := range seq {
result = f(result, e)
sum = f(sum, e)
}

return result
return sum
}

// Fold2 applies f to the elements in seq, starting with the initial value.
func Fold2[K, V, R any](seq iter.Seq2[K, V], initial R, f func(R, K, V) R) R {
result := initial

// Reduce2 applies f to the elements in seq, starting with the initial value.
func Reduce2[K, V, R any](f func(R, K, V) R, sum R, seq iter.Seq2[K, V]) R {
for k, v := range seq {
result = f(result, k, v)
sum = f(sum, k, v)
}

return result
return sum
}

// Empty returns an empty iterator.
func Empty[V any](func(V) bool) {}

// Empty2 returns an empty iterator.
func Empty2[V, V2 any](func(V, V2) bool) {}

// Single returns an iterator over a single element.
func Single[V any](v V) iter.Seq[V] { return func(yield func(V) bool) { yield(v) } }

// Single2 returns an iterator over a single element.
func Single2[K, V any](k K, v V) iter.Seq2[K, V] { return func(yield func(K, V) bool) { yield(k, v) } }

// Find returns the first element in seq for which f returns true.
func Find[V any](f func(V) bool, seq iter.Seq[V]) (V, bool) {
for e := range seq {
if f(e) {
return e, true
}
}

return *new(V), false
}

// Find2 returns the first element in seq for which f returns true.
func Find2[K, V any](f func(K, V) bool, seq iter.Seq2[K, V]) (K, V, bool) {
for k, v := range seq {
if f(k, v) {
return k, v, true
}
}

return *new(K), *new(V), false
}
94 changes: 59 additions & 35 deletions xiter/xiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package xiter_test

import (
"fmt"
"iter"
"maps"
"slices"
"strconv"
Expand All @@ -18,11 +17,11 @@ import (
func Example_with_numbers() {
numbers := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}

oddNumbers := xiter.Filter(xiter.Values(slices.All(numbers)), func(n int) bool { return n%2 == 1 })
evenNumbers := xiter.Filter(xiter.Values(slices.All(numbers)), func(n int) bool { return n%2 == 0 })
oddNumbers := xiter.Filter(func(n int) bool { return n%2 == 1 }, xiter.Values(slices.All(numbers)))
evenNumbers := xiter.Filter(func(n int) bool { return n%2 == 0 }, xiter.Values(slices.All(numbers)))

fmt.Println("Odd numbers:", xiter.Fold(oddNumbers, 0, func(acc, _ int) int { acc++; return acc }))
fmt.Println("Even numbers:", xiter.Fold(evenNumbers, 0, func(acc, _ int) int { acc++; return acc }))
fmt.Println("Odd numbers:", xiter.Reduce(func(acc, _ int) int { acc++; return acc }, 0, oddNumbers))
fmt.Println("Even numbers:", xiter.Reduce(func(acc, _ int) int { acc++; return acc }, 0, evenNumbers))

// Print all odd numbers followed by all even numbers

Expand All @@ -35,8 +34,8 @@ func Example_with_numbers() {
// Print all odd numbers followed by all even numbers but with text this time

for v := range xiter.Map(
xiter.Concat(oddNumbers, evenNumbers),
func(v int) string { return m[v] },
xiter.Concat(oddNumbers, evenNumbers),
) {
fmt.Print(v, ",")
}
Expand All @@ -47,7 +46,7 @@ func Example_with_numbers() {

slc := []string{"1", "2", "3", "NaN"}

for val, err := range xiter.ToSeq2(xiter.Values(slices.All(slc)), strconv.Atoi) {
for val, err := range xiter.ToSeq2(strconv.Atoi, xiter.Values(slices.All(slc))) {
if err != nil {
fmt.Print(err)

Expand All @@ -61,11 +60,11 @@ func Example_with_numbers() {

// Print the positions of prime numbers

primeNumbers := xiter.Filter2(slices.All(numbers), func(_, n int) bool { return isPrime(n) })
primeNumbers := xiter.Filter2(func(_, n int) bool { return isPrime(n) }, slices.All(numbers))

fmt.Print("Prime number positions:")

for pos := range xiter.IterKeys(primeNumbers) {
for pos := range xiter.Keys(primeNumbers) {
fmt.Print(pos, ",")
}

Expand Down Expand Up @@ -94,21 +93,21 @@ func Example_with_numbers() {
))

fmt.Println("numbers and rev(reverseNumbers) are equal:", xiter.EqualFunc(
xiter.ToSeq(slices.All(numbers), func(_, v int) int { return v }),
xiter.ToSeq(slices.Backward(reverseNumbers), func(_, v int) int { return v }),
func(a, b int) bool { return a == b },
xiter.ToSeq(func(_, v int) int { return v }, slices.All(numbers)),
xiter.ToSeq(func(_, v int) int { return v }, slices.Backward(reverseNumbers)),
))

fmt.Println("numbers and rev(reverseNumbers) with pos dropped are equal:", xiter.EqualFunc2(
func(_, a, _, b int) bool { return a == b },
slices.All(numbers),
slices.Backward(reverseNumbers),
func(_, a, _, b int) bool { return a == b },
))

fmt.Println("numbers and reverseNumbers are not equal:", !xiter.EqualFunc(
func(a, b int) bool { return a == b },
xiter.Values(slices.All(numbers)),
xiter.Values(slices.All(reverseNumbers)),
func(a, b int) bool { return a == b },
))

// Output:
Expand All @@ -128,18 +127,7 @@ func Example_with_numbers() {
}

func ExampleConcat2() {
result := xiter.Fold2(
xiter.Map2(
xiter.Concat2(maps.All(numbersAndLetters), maps.All(numbersAndLetters2)),
func(k, v string) (int64, error) {
if v == "number" {
return strconv.ParseInt(k, 10, 64)
}

return 0, nil
},
),
0,
result := xiter.Reduce2(
func(acc int, k int64, v error) int {
if v != nil {
fmt.Println("Error:", v)
Expand All @@ -149,6 +137,17 @@ func ExampleConcat2() {

return acc + int(k)
},
0,
xiter.Map2(
func(k, v string) (int64, error) {
if v == "number" {
return strconv.ParseInt(k, 10, 64)
}

return 0, nil
},
xiter.Concat2(maps.All(numbersAndLetters), maps.All(numbersAndLetters2)),
),
)

fmt.Println(result)
Expand Down Expand Up @@ -204,19 +203,44 @@ func isPrime(n int) bool {
return true
}

func ExampleEmpty() {
var it iter.Seq[int] = xiter.Empty
func Example_single_and_empty() {
it := xiter.Single(42)

for v := range it {
fmt.Printf("This %d should not be printed\n", v)
}
fmt.Println("Found 42 in seq:")
fmt.Println(xiter.Find(func(v int) bool { return v == 42 }, it))

var it2 iter.Seq2[int, string] = xiter.Empty2
fmt.Println("Found 43 in seq:")
fmt.Println(xiter.Find(func(v int) bool { return v == 43 }, it))

for v, s := range it2 {
fmt.Printf("This %d %s should not be printed\n", v, s)
}
it = xiter.Empty

fmt.Println("Found 42 in seq:")
fmt.Println(xiter.Find(func(v int) bool { return v == 42 }, it))

it2 := xiter.Single2(42, 2012)

fmt.Println("Found 42 and 2012 in seq2:")
fmt.Println(xiter.Find2(func(k, v int) bool { return k == 42 && v == 2012 }, it2))

fmt.Println("Found 43 and 2012 in seq2:")
fmt.Println(xiter.Find2(func(k, v int) bool { return k == 43 && v == 2012 }, it2))

it2 = xiter.Empty2

fmt.Println("Found 42 and 2012 in seq2:")
fmt.Println(xiter.Find2(func(k, v int) bool { return k == 42 && v == 2012 }, it2))

// Output:
//
// Found 42 in seq:
// 42 true
// Found 43 in seq:
// 0 false
// Found 42 in seq:
// 0 false
// Found 42 and 2012 in seq2:
// 42 2012 true
// Found 43 and 2012 in seq2:
// 0 0 false
// Found 42 and 2012 in seq2:
// 0 0 false
}

0 comments on commit e847d2a

Please sign in to comment.