From 023b6d24b180e4c8cae64f23dd837f446bb685a0 Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Thu, 12 Sep 2024 00:02:51 -0700 Subject: [PATCH 1/2] Introduce Foldable lists and maps --- common/types/list.go | 18 ++++ common/types/list_test.go | 113 +++++++++++++++++++++++ common/types/map.go | 96 +++++++++++++++++++ common/types/map_test.go | 158 ++++++++++++++++++++++++++++++++ common/types/traits/iterator.go | 13 +++ common/types/traits/mapper.go | 12 +++ common/types/traits/traits.go | 5 +- 7 files changed, 414 insertions(+), 1 deletion(-) diff --git a/common/types/list.go b/common/types/list.go index 06f48dde..e4b2affb 100644 --- a/common/types/list.go +++ b/common/types/list.go @@ -256,6 +256,15 @@ func (l *baseList) IsZeroValue() bool { return l.size == 0 } +// Fold calls the FoldEntry method for each (index, value) pair in the list. +func (l *baseList) Fold(f traits.Folder) { + for i := 0; i < l.size; i++ { + if !f.FoldEntry(ListType, i, l.get(i)) { + break + } + } +} + // Iterator implements the traits.Iterable interface method. func (l *baseList) Iterator() traits.Iterator { return newListIterator(l) @@ -433,6 +442,15 @@ func (l *concatList) IsZeroValue() bool { return l.Size().(Int) == 0 } +// Fold calls the FoldEntry method for each (index, value) pair in the list. +func (l *concatList) Fold(f traits.Folder) { + for i := Int(0); i < l.Size().(Int); i++ { + if !f.FoldEntry(ListType, i, l.Get(i)) { + break + } + } +} + // Iterator implements the traits.Iterable interface method. func (l *concatList) Iterator() traits.Iterator { return newListIterator(l) diff --git a/common/types/list_test.go b/common/types/list_test.go index 7e47ff5d..b57dc531 100644 --- a/common/types/list_test.go +++ b/common/types/list_test.go @@ -700,6 +700,119 @@ func TestValueListConvertToNative_Json(t *testing.T) { } } +func TestMutableList(t *testing.T) { + l := NewMutableList(DefaultTypeAdapter) + l.Add(NewRefValList(DefaultTypeAdapter, []ref.Val{String("hello")})) + l.Add(NewRefValList(DefaultTypeAdapter, []ref.Val{String("world")})) + il := l.ToImmutableList() + if il.Size() != Int(2) { + t.Errorf("il.Size() got %d, wanted size 2", il.Size()) + } + l.Add(NewRefValList(DefaultTypeAdapter, []ref.Val{String("!")})) + if il.Size() != Int(2) { + t.Errorf("il.Size() got %d, wanted size 2", il.Size()) + } +} + +func TestListFold(t *testing.T) { + + tests := []struct { + l any + folds int + foldLimit int + }{ + { + l: []string{"hello", "world"}, + folds: 2, + foldLimit: 2, + }, + { + l: []string{"hello", "world"}, + folds: 1, + foldLimit: 1, + }, + { + l: []string{"hello"}, + folds: 1, + foldLimit: 2, + }, + { + l: []ref.Val{}, + folds: 0, + foldLimit: 20, + }, + { + l: []ref.Val{ + String("hello"), + String("world"), + String("goodbye"), + String("cruel world"), + }, + folds: 1, + foldLimit: 1, + }, + { + l: []ref.Val{ + String("hello"), + String("world"), + String("goodbye"), + String("cruel world"), + }, + folds: 4, + foldLimit: 10, + }, + { + l: DefaultTypeAdapter.NativeToValue([]ref.Val{ + String("hello"), + String("world"), + }).(traits.Lister).Add(DefaultTypeAdapter.NativeToValue([]ref.Val{ + String("goodbye"), + String("cruel world"), + })), + folds: 4, + foldLimit: 10, + }, + { + l: DefaultTypeAdapter.NativeToValue([]ref.Val{ + String("hello"), + String("world"), + }).(traits.Lister).Add(DefaultTypeAdapter.NativeToValue([]ref.Val{ + String("goodbye"), + String("cruel world"), + })), + folds: 3, + foldLimit: 3, + }, + } + reg := NewEmptyRegistry() + for i, tst := range tests { + tc := tst + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + f := &testListFolder{foldLimit: tc.foldLimit} + l := reg.NativeToValue(tc.l).(traits.Foldable) + l.Fold(f) + if f.folds != tc.folds { + t.Errorf("m.Fold(f) got %d, wanted %d folds", f.folds, tc.folds) + } + }) + } +} + +type testListFolder struct { + foldLimit int + folds int +} + +func (f *testListFolder) FoldEntry(t ref.Type, k, v any) bool { + if f.foldLimit != 0 { + if f.folds >= f.foldLimit { + return false + } + } + f.folds++ + return true +} + func getElem(t *testing.T, list traits.Indexer, index ref.Val) any { t.Helper() val := list.Get(index) diff --git a/common/types/map.go b/common/types/map.go index 739b7aab..ce6bf9f5 100644 --- a/common/types/map.go +++ b/common/types/map.go @@ -94,6 +94,24 @@ func NewProtoMap(adapter Adapter, value *pb.Map) traits.Mapper { } } +// NewMutableMap constructs a mutable map from an adapter and a set of map values. +func NewMutableMap(adapter Adapter, mutableValues map[ref.Val]ref.Val) traits.MutableMapper { + mutableCopy := make(map[ref.Val]ref.Val, len(mutableValues)) + for k, v := range mutableValues { + mutableCopy[k] = v + } + m := &mutableMap{ + baseMap: &baseMap{ + Adapter: adapter, + mapAccessor: newRefValMapAccessor(mutableCopy), + value: mutableCopy, + size: len(mutableCopy), + }, + mutableValues: mutableCopy, + } + return m +} + // mapAccessor is a private interface for finding values within a map and iterating over the keys. // This interface implements portions of the API surface area required by the traits.Mapper // interface. @@ -105,6 +123,9 @@ type mapAccessor interface { // Iterator returns an Iterator over the map key set. Iterator() traits.Iterator + + // Fold calls the FoldEntry method for each (key, value) pair in the map. + Fold(traits.Folder) } // baseMap is a reflection based map implementation designed to handle a variety of map-like types. @@ -307,6 +328,28 @@ func (m *baseMap) Value() any { return m.value } +// mutableMap holds onto a set of mutable values which are used for intermediate computations. +type mutableMap struct { + *baseMap + mutableValues map[ref.Val]ref.Val +} + +// Insert implements the traits.MutableMapper interface method, returning true if the key insertion +// succeeds. +func (m *mutableMap) Insert(k, v ref.Val) bool { + if _, found := m.mutableValues[k]; found { + return false + } + m.mutableValues[k] = v + return true +} + +// ToImmutableMap implements the traits.MutableMapper interface method, converting a mutable map +// an immutable map implementation. +func (m *mutableMap) ToImmutableMap() traits.Mapper { + return NewRefValMap(m.Adapter, m.mutableValues) +} + func newJSONStructAccessor(adapter Adapter, st map[string]*structpb.Value) mapAccessor { return &jsonStructAccessor{ Adapter: adapter, @@ -350,6 +393,15 @@ func (a *jsonStructAccessor) Iterator() traits.Iterator { } } +// Fold calls the FoldEntry method for each (key, value) pair in the map. +func (a *jsonStructAccessor) Fold(f traits.Folder) { + for k, v := range a.st { + if !f.FoldEntry(MapType, k, v) { + break + } + } +} + func newReflectMapAccessor(adapter Adapter, value reflect.Value) mapAccessor { keyType := value.Type().Key() return &reflectMapAccessor{ @@ -424,6 +476,16 @@ func (m *reflectMapAccessor) Iterator() traits.Iterator { } } +// Fold calls the FoldEntry method for each (key, value) pair in the map. +func (m *reflectMapAccessor) Fold(f traits.Folder) { + mapRange := m.refValue.MapRange() + for mapRange.Next() { + if !f.FoldEntry(MapType, mapRange.Key().Interface(), mapRange.Value().Interface()) { + break + } + } +} + func newRefValMapAccessor(mapVal map[ref.Val]ref.Val) mapAccessor { return &refValMapAccessor{mapVal: mapVal} } @@ -477,6 +539,15 @@ func (a *refValMapAccessor) Iterator() traits.Iterator { } } +// Fold calls the FoldEntry method for each (key, value) pair in the map. +func (a *refValMapAccessor) Fold(f traits.Folder) { + for k, v := range a.mapVal { + if !f.FoldEntry(MapType, k, v) { + break + } + } +} + func newStringMapAccessor(strMap map[string]string) mapAccessor { return &stringMapAccessor{mapVal: strMap} } @@ -515,6 +586,15 @@ func (a *stringMapAccessor) Iterator() traits.Iterator { } } +// Fold calls the FoldEntry method for each (key, value) pair in the map. +func (a *stringMapAccessor) Fold(f traits.Folder) { + for k, v := range a.mapVal { + if !f.FoldEntry(MapType, k, v) { + break + } + } +} + func newStringIfaceMapAccessor(adapter Adapter, mapVal map[string]any) mapAccessor { return &stringIfaceMapAccessor{ Adapter: adapter, @@ -557,6 +637,15 @@ func (a *stringIfaceMapAccessor) Iterator() traits.Iterator { } } +// Fold calls the FoldEntry method for each (key, value) pair in the map. +func (a *stringIfaceMapAccessor) Fold(f traits.Folder) { + for k, v := range a.mapVal { + if !f.FoldEntry(MapType, k, v) { + break + } + } +} + // protoMap is a specialized, separate implementation of the traits.Mapper interfaces tailored to // accessing protoreflect.Map values. type protoMap struct { @@ -769,6 +858,13 @@ func (m *protoMap) Iterator() traits.Iterator { } } +// Fold calls the FoldEntry method for each (key, value) pair in the map. +func (m *protoMap) Fold(f traits.Folder) { + m.value.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool { + return f.FoldEntry(MapType, k.Interface(), v.Interface()) + }) +} + // Size returns the number of entries in the protoreflect.Map. func (m *protoMap) Size() ref.Val { return Int(m.value.Len()) diff --git a/common/types/map_test.go b/common/types/map_test.go index f8377cc1..8e6282c6 100644 --- a/common/types/map_test.go +++ b/common/types/map_test.go @@ -966,3 +966,161 @@ func TestProtoMapConvertToNative_NestedProto(t *testing.T) { } } } + +func TestMutableMap(t *testing.T) { + m := NewMutableMap( + DefaultTypeAdapter, + map[ref.Val]ref.Val{String("hello"): String("world")}) + m.Insert(String("goodbye"), String("cruel world")) + im := m.ToImmutableMap() + if im.Size() != Int(2) { + t.Errorf("m.ToImmutableMap() had size %d, wanted 2", im.Size()) + } + if m.Insert(String("goodbye"), String("happy world")) { + t.Error("m.Insert('goodbye', 'happy world') got true, wanted false") + } + m.Insert(String("well"), String("well")) + if im.Size() != Int(2) { + t.Errorf("m.Insert() mutated storage for immutable map: had size %d, wanted 2", im.Size()) + } +} + +func TestMapFold(t *testing.T) { + pbDB := pb.NewDb() + fd, err := pbDB.RegisterMessage(&proto3pb.TestAllTypes{}) + if err != nil { + t.Fatalf("pbdb.RegisterMessage(TestAllTypes) failed: %v", err) + } + td, found := fd.GetTypeDescription(string((&proto3pb.TestAllTypes{}).ProtoReflect().Descriptor().FullName())) + if !found { + t.Fatal("fd.GetTypeDescription() failed") + } + mapStrStrFD, found := td.FieldByName("map_string_string") + if !found { + t.Fatal("Could not find map_string_string field") + } + + mapStrDesc := (&proto3pb.TestAllTypes{}).ProtoReflect().Descriptor().Fields().ByName("map_string_string") + tests := []struct { + m any + folds int + foldLimit int + }{ + { + m: map[string]any{"a": 1, "b": 2}, + folds: 2, + foldLimit: 2, + }, + { + m: map[string]string{"hello": "world"}, + folds: 1, + foldLimit: 2, + }, + { + m: map[string]string{"hello": "world", "goodbye": "cruel world"}, + folds: 1, + foldLimit: 1, + }, + { + m: map[ref.Val]ref.Val{}, + folds: 0, + foldLimit: 20, + }, + { + m: map[ref.Val]ref.Val{ + (String("hello")): String("world"), + (String("goodbye")): String("cruel world"), + }, + folds: 1, + foldLimit: 1, + }, + { + m: testCreateStruct(t, map[string]any{ + "hello": []any{}, + "world": map[string]any{}, + }), + folds: 2, + foldLimit: 2, + }, + { + m: testCreateStruct(t, map[string]any{ + "hello": []any{}, + "world": map[string]any{}, + }), + folds: 1, + foldLimit: 1, + }, + { + m: (&proto3pb.TestAllTypes{ + MapInt64NestedType: map[int64]*proto3pb.NestedTestAllTypes{ + 1: {}, + 2: {}, + 3: {}, + }, + }).GetMapInt64NestedType(), + folds: 3, + foldLimit: 3, + }, + { + m: (&proto3pb.TestAllTypes{ + MapInt64NestedType: map[int64]*proto3pb.NestedTestAllTypes{ + 1: {}, + 2: {}, + 3: {}, + }, + }).GetMapInt64NestedType(), + folds: 2, + foldLimit: 2, + }, + { + m: &pb.Map{ + Map: (&proto3pb.TestAllTypes{ + MapStringString: map[string]string{ + "1": "one", + "2": "two", + }, + }).ProtoReflect().Get(mapStrDesc).Map(), + KeyType: mapStrStrFD.KeyType, + ValueType: mapStrStrFD.ValueType, + }, + folds: 1, + foldLimit: 1, + }, + } + reg := NewEmptyRegistry() + for i, tst := range tests { + tc := tst + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + f := &testMapFolder{foldLimit: tc.foldLimit} + m := reg.NativeToValue(tc.m).(traits.Foldable) + m.Fold(f) + if f.folds != tc.folds { + t.Errorf("m.Fold(f) got %d, wanted %d folds", f.folds, tc.folds) + } + }) + } +} + +type testMapFolder struct { + foldLimit int + folds int +} + +func (f *testMapFolder) FoldEntry(t ref.Type, k, v any) bool { + if f.foldLimit != 0 { + if f.folds >= f.foldLimit { + return false + } + } + f.folds++ + return true +} + +func testCreateStruct(t *testing.T, m map[string]any) *structpb.Struct { + t.Helper() + v, err := structpb.NewStruct(m) + if err != nil { + t.Fatalf("structpb.NewStruct(m) failed: %v", err) + } + return v +} diff --git a/common/types/traits/iterator.go b/common/types/traits/iterator.go index 42dd371a..61b03402 100644 --- a/common/types/traits/iterator.go +++ b/common/types/traits/iterator.go @@ -34,3 +34,16 @@ type Iterator interface { // Next returns the next element. Next() ref.Val } + +// Foldable aggregate types support iteration over (key, value) or (index, value) pairs. +type Foldable interface { + // Fold invokes the Folder.FoldEntry for all entries in the type + Fold(Folder) +} + +// Folder performs a fold on a given entry and indicates whether to continue folding. +type Folder interface { + // FoldEntry indicates the calling type and the (key, value) pair associated with the entry. + // If the output is true, continue folding. Otherwise, terminate the fold. + FoldEntry(t ref.Type, key, val any) bool +} diff --git a/common/types/traits/mapper.go b/common/types/traits/mapper.go index 2f7c919a..5f1a66b9 100644 --- a/common/types/traits/mapper.go +++ b/common/types/traits/mapper.go @@ -31,3 +31,15 @@ type Mapper interface { // (Unknown|Err, false). Find(key ref.Val) (ref.Val, bool) } + +// MutableMapper interface which emits an immutable result after an intermediate computation. +type MutableMapper interface { + Mapper + + // Insert a key, value pair into the map, returning true if key does not already exist in the map + // to indicate the insert is successful. + Insert(k, v ref.Val) bool + + // ToImmutableMap converts a mutable map into an immutable map. + ToImmutableMap() Mapper +} diff --git a/common/types/traits/traits.go b/common/types/traits/traits.go index 6da3e6a3..41361dd4 100644 --- a/common/types/traits/traits.go +++ b/common/types/traits/traits.go @@ -59,6 +59,9 @@ const ( // SizerType types support the size() method. SizerType - // SubtractorType type support '-' operations. + // SubtractorType types support '-' operations. SubtractorType + + // FoldableType types support comprehensions v2 macros which iterate over (key, value) pairs. + FoldableType ) From 189426bebc2492816209e61f81ede20c75e706c3 Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Mon, 23 Sep 2024 09:50:27 -0700 Subject: [PATCH 2/2] Remove FoldEntry type reference --- common/types/list.go | 4 ++-- common/types/list_test.go | 2 +- common/types/map.go | 12 ++++++------ common/types/map_test.go | 2 +- common/types/traits/iterator.go | 4 ++-- 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/common/types/list.go b/common/types/list.go index e4b2affb..3e71e33b 100644 --- a/common/types/list.go +++ b/common/types/list.go @@ -259,7 +259,7 @@ func (l *baseList) IsZeroValue() bool { // Fold calls the FoldEntry method for each (index, value) pair in the list. func (l *baseList) Fold(f traits.Folder) { for i := 0; i < l.size; i++ { - if !f.FoldEntry(ListType, i, l.get(i)) { + if !f.FoldEntry(i, l.get(i)) { break } } @@ -445,7 +445,7 @@ func (l *concatList) IsZeroValue() bool { // Fold calls the FoldEntry method for each (index, value) pair in the list. func (l *concatList) Fold(f traits.Folder) { for i := Int(0); i < l.Size().(Int); i++ { - if !f.FoldEntry(ListType, i, l.Get(i)) { + if !f.FoldEntry(i, l.Get(i)) { break } } diff --git a/common/types/list_test.go b/common/types/list_test.go index b57dc531..ea92f23e 100644 --- a/common/types/list_test.go +++ b/common/types/list_test.go @@ -803,7 +803,7 @@ type testListFolder struct { folds int } -func (f *testListFolder) FoldEntry(t ref.Type, k, v any) bool { +func (f *testListFolder) FoldEntry(k, v any) bool { if f.foldLimit != 0 { if f.folds >= f.foldLimit { return false diff --git a/common/types/map.go b/common/types/map.go index ce6bf9f5..bc20239f 100644 --- a/common/types/map.go +++ b/common/types/map.go @@ -396,7 +396,7 @@ func (a *jsonStructAccessor) Iterator() traits.Iterator { // Fold calls the FoldEntry method for each (key, value) pair in the map. func (a *jsonStructAccessor) Fold(f traits.Folder) { for k, v := range a.st { - if !f.FoldEntry(MapType, k, v) { + if !f.FoldEntry(k, v) { break } } @@ -480,7 +480,7 @@ func (m *reflectMapAccessor) Iterator() traits.Iterator { func (m *reflectMapAccessor) Fold(f traits.Folder) { mapRange := m.refValue.MapRange() for mapRange.Next() { - if !f.FoldEntry(MapType, mapRange.Key().Interface(), mapRange.Value().Interface()) { + if !f.FoldEntry(mapRange.Key().Interface(), mapRange.Value().Interface()) { break } } @@ -542,7 +542,7 @@ func (a *refValMapAccessor) Iterator() traits.Iterator { // Fold calls the FoldEntry method for each (key, value) pair in the map. func (a *refValMapAccessor) Fold(f traits.Folder) { for k, v := range a.mapVal { - if !f.FoldEntry(MapType, k, v) { + if !f.FoldEntry(k, v) { break } } @@ -589,7 +589,7 @@ func (a *stringMapAccessor) Iterator() traits.Iterator { // Fold calls the FoldEntry method for each (key, value) pair in the map. func (a *stringMapAccessor) Fold(f traits.Folder) { for k, v := range a.mapVal { - if !f.FoldEntry(MapType, k, v) { + if !f.FoldEntry(k, v) { break } } @@ -640,7 +640,7 @@ func (a *stringIfaceMapAccessor) Iterator() traits.Iterator { // Fold calls the FoldEntry method for each (key, value) pair in the map. func (a *stringIfaceMapAccessor) Fold(f traits.Folder) { for k, v := range a.mapVal { - if !f.FoldEntry(MapType, k, v) { + if !f.FoldEntry(k, v) { break } } @@ -861,7 +861,7 @@ func (m *protoMap) Iterator() traits.Iterator { // Fold calls the FoldEntry method for each (key, value) pair in the map. func (m *protoMap) Fold(f traits.Folder) { m.value.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool { - return f.FoldEntry(MapType, k.Interface(), v.Interface()) + return f.FoldEntry(k.Interface(), v.Interface()) }) } diff --git a/common/types/map_test.go b/common/types/map_test.go index 8e6282c6..c96939b5 100644 --- a/common/types/map_test.go +++ b/common/types/map_test.go @@ -1106,7 +1106,7 @@ type testMapFolder struct { folds int } -func (f *testMapFolder) FoldEntry(t ref.Type, k, v any) bool { +func (f *testMapFolder) FoldEntry(k, v any) bool { if f.foldLimit != 0 { if f.folds >= f.foldLimit { return false diff --git a/common/types/traits/iterator.go b/common/types/traits/iterator.go index 61b03402..91c10f08 100644 --- a/common/types/traits/iterator.go +++ b/common/types/traits/iterator.go @@ -43,7 +43,7 @@ type Foldable interface { // Folder performs a fold on a given entry and indicates whether to continue folding. type Folder interface { - // FoldEntry indicates the calling type and the (key, value) pair associated with the entry. + // FoldEntry indicates the key, value pair associated with the entry. // If the output is true, continue folding. Otherwise, terminate the fold. - FoldEntry(t ref.Type, key, val any) bool + FoldEntry(key, val any) bool }