Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Foldable Maps and Lists #995

Merged
merged 2 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions common/types/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,15 @@ func (l *baseList) IsZeroValue() bool {
return l.size == 0
}

// Fold calls the FoldEntry method for each (index, value) pair in the list.
func (l *baseList) Fold(f traits.Folder) {
for i := 0; i < l.size; i++ {
if !f.FoldEntry(i, l.get(i)) {
break
}
}
}

// Iterator implements the traits.Iterable interface method.
func (l *baseList) Iterator() traits.Iterator {
return newListIterator(l)
Expand Down Expand Up @@ -433,6 +442,15 @@ func (l *concatList) IsZeroValue() bool {
return l.Size().(Int) == 0
}

// Fold calls the FoldEntry method for each (index, value) pair in the list.
func (l *concatList) Fold(f traits.Folder) {
for i := Int(0); i < l.Size().(Int); i++ {
if !f.FoldEntry(i, l.Get(i)) {
break
}
}
}

// Iterator implements the traits.Iterable interface method.
func (l *concatList) Iterator() traits.Iterator {
return newListIterator(l)
Expand Down
113 changes: 113 additions & 0 deletions common/types/list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,119 @@ func TestValueListConvertToNative_Json(t *testing.T) {
}
}

func TestMutableList(t *testing.T) {
l := NewMutableList(DefaultTypeAdapter)
l.Add(NewRefValList(DefaultTypeAdapter, []ref.Val{String("hello")}))
l.Add(NewRefValList(DefaultTypeAdapter, []ref.Val{String("world")}))
il := l.ToImmutableList()
if il.Size() != Int(2) {
t.Errorf("il.Size() got %d, wanted size 2", il.Size())
}
l.Add(NewRefValList(DefaultTypeAdapter, []ref.Val{String("!")}))
if il.Size() != Int(2) {
t.Errorf("il.Size() got %d, wanted size 2", il.Size())
}
}

func TestListFold(t *testing.T) {

tests := []struct {
l any
folds int
foldLimit int
}{
{
l: []string{"hello", "world"},
folds: 2,
foldLimit: 2,
},
{
l: []string{"hello", "world"},
folds: 1,
foldLimit: 1,
},
{
l: []string{"hello"},
folds: 1,
foldLimit: 2,
},
{
l: []ref.Val{},
folds: 0,
foldLimit: 20,
},
{
l: []ref.Val{
String("hello"),
String("world"),
String("goodbye"),
String("cruel world"),
},
folds: 1,
foldLimit: 1,
},
{
l: []ref.Val{
String("hello"),
String("world"),
String("goodbye"),
String("cruel world"),
},
folds: 4,
foldLimit: 10,
},
{
l: DefaultTypeAdapter.NativeToValue([]ref.Val{
String("hello"),
String("world"),
}).(traits.Lister).Add(DefaultTypeAdapter.NativeToValue([]ref.Val{
String("goodbye"),
String("cruel world"),
})),
folds: 4,
foldLimit: 10,
},
{
l: DefaultTypeAdapter.NativeToValue([]ref.Val{
String("hello"),
String("world"),
}).(traits.Lister).Add(DefaultTypeAdapter.NativeToValue([]ref.Val{
String("goodbye"),
String("cruel world"),
})),
folds: 3,
foldLimit: 3,
},
}
reg := NewEmptyRegistry()
for i, tst := range tests {
tc := tst
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
f := &testListFolder{foldLimit: tc.foldLimit}
l := reg.NativeToValue(tc.l).(traits.Foldable)
l.Fold(f)
if f.folds != tc.folds {
t.Errorf("m.Fold(f) got %d, wanted %d folds", f.folds, tc.folds)
}
})
}
}

type testListFolder struct {
foldLimit int
folds int
}

func (f *testListFolder) FoldEntry(k, v any) bool {
if f.foldLimit != 0 {
if f.folds >= f.foldLimit {
return false
}
}
f.folds++
return true
}

func getElem(t *testing.T, list traits.Indexer, index ref.Val) any {
t.Helper()
val := list.Get(index)
Expand Down
96 changes: 96 additions & 0 deletions common/types/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,24 @@ func NewProtoMap(adapter Adapter, value *pb.Map) traits.Mapper {
}
}

// NewMutableMap constructs a mutable map from an adapter and a set of map values.
func NewMutableMap(adapter Adapter, mutableValues map[ref.Val]ref.Val) traits.MutableMapper {
TristonianJones marked this conversation as resolved.
Show resolved Hide resolved
mutableCopy := make(map[ref.Val]ref.Val, len(mutableValues))
for k, v := range mutableValues {
mutableCopy[k] = v
}
m := &mutableMap{
baseMap: &baseMap{
Adapter: adapter,
mapAccessor: newRefValMapAccessor(mutableCopy),
value: mutableCopy,
size: len(mutableCopy),
},
mutableValues: mutableCopy,
}
return m
}

// mapAccessor is a private interface for finding values within a map and iterating over the keys.
// This interface implements portions of the API surface area required by the traits.Mapper
// interface.
Expand All @@ -105,6 +123,9 @@ type mapAccessor interface {

// Iterator returns an Iterator over the map key set.
Iterator() traits.Iterator

// Fold calls the FoldEntry method for each (key, value) pair in the map.
Fold(traits.Folder)
}

// baseMap is a reflection based map implementation designed to handle a variety of map-like types.
Expand Down Expand Up @@ -307,6 +328,28 @@ func (m *baseMap) Value() any {
return m.value
}

// mutableMap holds onto a set of mutable values which are used for intermediate computations.
type mutableMap struct {
*baseMap
mutableValues map[ref.Val]ref.Val
}

// Insert implements the traits.MutableMapper interface method, returning true if the key insertion
// succeeds.
func (m *mutableMap) Insert(k, v ref.Val) bool {
if _, found := m.mutableValues[k]; found {
return false
}
m.mutableValues[k] = v
return true
}

// ToImmutableMap implements the traits.MutableMapper interface method, converting a mutable map
// an immutable map implementation.
func (m *mutableMap) ToImmutableMap() traits.Mapper {
return NewRefValMap(m.Adapter, m.mutableValues)
}

func newJSONStructAccessor(adapter Adapter, st map[string]*structpb.Value) mapAccessor {
return &jsonStructAccessor{
Adapter: adapter,
Expand Down Expand Up @@ -350,6 +393,15 @@ func (a *jsonStructAccessor) Iterator() traits.Iterator {
}
}

// Fold calls the FoldEntry method for each (key, value) pair in the map.
func (a *jsonStructAccessor) Fold(f traits.Folder) {
for k, v := range a.st {
if !f.FoldEntry(k, v) {
break
}
}
}

func newReflectMapAccessor(adapter Adapter, value reflect.Value) mapAccessor {
keyType := value.Type().Key()
return &reflectMapAccessor{
Expand Down Expand Up @@ -424,6 +476,16 @@ func (m *reflectMapAccessor) Iterator() traits.Iterator {
}
}

// Fold calls the FoldEntry method for each (key, value) pair in the map.
func (m *reflectMapAccessor) Fold(f traits.Folder) {
mapRange := m.refValue.MapRange()
for mapRange.Next() {
if !f.FoldEntry(mapRange.Key().Interface(), mapRange.Value().Interface()) {
break
}
}
}

func newRefValMapAccessor(mapVal map[ref.Val]ref.Val) mapAccessor {
return &refValMapAccessor{mapVal: mapVal}
}
Expand Down Expand Up @@ -477,6 +539,15 @@ func (a *refValMapAccessor) Iterator() traits.Iterator {
}
}

// Fold calls the FoldEntry method for each (key, value) pair in the map.
func (a *refValMapAccessor) Fold(f traits.Folder) {
for k, v := range a.mapVal {
if !f.FoldEntry(k, v) {
break
}
}
}

func newStringMapAccessor(strMap map[string]string) mapAccessor {
return &stringMapAccessor{mapVal: strMap}
}
Expand Down Expand Up @@ -515,6 +586,15 @@ func (a *stringMapAccessor) Iterator() traits.Iterator {
}
}

// Fold calls the FoldEntry method for each (key, value) pair in the map.
func (a *stringMapAccessor) Fold(f traits.Folder) {
for k, v := range a.mapVal {
if !f.FoldEntry(k, v) {
break
}
}
}

func newStringIfaceMapAccessor(adapter Adapter, mapVal map[string]any) mapAccessor {
return &stringIfaceMapAccessor{
Adapter: adapter,
Expand Down Expand Up @@ -557,6 +637,15 @@ func (a *stringIfaceMapAccessor) Iterator() traits.Iterator {
}
}

// Fold calls the FoldEntry method for each (key, value) pair in the map.
func (a *stringIfaceMapAccessor) Fold(f traits.Folder) {
for k, v := range a.mapVal {
if !f.FoldEntry(k, v) {
break
}
}
}

// protoMap is a specialized, separate implementation of the traits.Mapper interfaces tailored to
// accessing protoreflect.Map values.
type protoMap struct {
Expand Down Expand Up @@ -769,6 +858,13 @@ func (m *protoMap) Iterator() traits.Iterator {
}
}

// Fold calls the FoldEntry method for each (key, value) pair in the map.
func (m *protoMap) Fold(f traits.Folder) {
m.value.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
return f.FoldEntry(k.Interface(), v.Interface())
})
}

// Size returns the number of entries in the protoreflect.Map.
func (m *protoMap) Size() ref.Val {
return Int(m.value.Len())
Expand Down
Loading