From a051ac2d90d28c7867bd5283203a2f4218c96e45 Mon Sep 17 00:00:00 2001
From: TristonianJones <tswadell@google.com>
Date: Fri, 27 Sep 2024 20:14:10 -0700
Subject: [PATCH] Interop foldable maps and lists with map mutation helper

---
 common/types/list.go          |  27 ++++++++
 common/types/list_test.go     |  67 +++++++++++++++++---
 common/types/map.go           |  58 ++++++++++++++++-
 common/types/map_test.go      | 113 +++++++++++++++++++++++++++++++---
 common/types/traits/lister.go |   3 +
 common/types/traits/mapper.go |   9 ++-
 6 files changed, 253 insertions(+), 24 deletions(-)

diff --git a/common/types/list.go b/common/types/list.go
index 3e71e33b..ca47d39f 100644
--- a/common/types/list.go
+++ b/common/types/list.go
@@ -545,3 +545,30 @@ func IndexOrError(index ref.Val) (int, error) {
 		return -1, fmt.Errorf("unsupported index type '%s' in list", index.Type())
 	}
 }
+
+// ToFoldableList will create a Foldable version of a list suitable for key-value pair iteration.
+//
+// For values which are already Foldable, this call is a no-op. For all other values, the fold is
+// driven via the Size() and Get() calls which means that the folding will function, but take a
+// performance hit.
+func ToFoldableList(l traits.Lister) traits.Foldable {
+	if f, ok := l.(traits.Foldable); ok {
+		return f
+	}
+	return interopFoldableList{Lister: l}
+}
+
+type interopFoldableList struct {
+	traits.Lister
+}
+
+// Fold implements the traits.Foldable interface method and performs an iteration over the
+// range of elements of the list.
+func (l interopFoldableList) Fold(f traits.Folder) {
+	sz := l.Size().(Int)
+	for i := Int(0); i < sz; i++ {
+		if !f.FoldEntry(i, l.Get(i)) {
+			break
+		}
+	}
+}
diff --git a/common/types/list_test.go b/common/types/list_test.go
index ea92f23e..ba6c498f 100644
--- a/common/types/list_test.go
+++ b/common/types/list_test.go
@@ -787,14 +787,20 @@ func TestListFold(t *testing.T) {
 	reg := NewEmptyRegistry()
 	for i, tst := range tests {
 		tc := tst
-		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
-			f := &testListFolder{foldLimit: tc.foldLimit}
-			l := reg.NativeToValue(tc.l).(traits.Foldable)
-			l.Fold(f)
-			if f.folds != tc.folds {
-				t.Errorf("m.Fold(f) got %d, wanted %d folds", f.folds, tc.folds)
-			}
-		})
+		l := reg.NativeToValue(tc.l).(traits.Lister)
+		foldKinds := map[string]traits.Foldable{
+			"modern": ToFoldableList(l),
+			"legacy": ToFoldableList(proxyLegacyList{proxy: l}),
+		}
+		for foldKind, foldable := range foldKinds {
+			t.Run(fmt.Sprintf("[%d]%s", i, foldKind), func(t *testing.T) {
+				f := &testListFolder{foldLimit: tc.foldLimit}
+				foldable.Fold(f)
+				if f.folds != tc.folds {
+					t.Errorf("m.Fold(f) got %d, wanted %d folds", f.folds, tc.folds)
+				}
+			})
+		}
 	}
 }
 
@@ -813,6 +819,51 @@ func (f *testListFolder) FoldEntry(k, v any) bool {
 	return true
 }
 
+// proxyLegacyList omits the foldable interfaces associated with all core Lister implementations
+type proxyLegacyList struct {
+	proxy traits.Lister
+}
+
+func (m proxyLegacyList) ConvertToNative(typeDesc reflect.Type) (any, error) {
+	return m.proxy.ConvertToNative(typeDesc)
+}
+
+func (m proxyLegacyList) ConvertToType(typeValue ref.Type) ref.Val {
+	return m.proxy.ConvertToType(typeValue)
+}
+
+func (m proxyLegacyList) Equal(other ref.Val) ref.Val {
+	return m.proxy.Equal(other)
+}
+
+func (m proxyLegacyList) Type() ref.Type {
+	return m.proxy.Type()
+}
+
+func (m proxyLegacyList) Value() any {
+	return m.proxy.Value()
+}
+
+func (m proxyLegacyList) Add(other ref.Val) ref.Val {
+	return m.proxy.Add(other)
+}
+
+func (m proxyLegacyList) Contains(value ref.Val) ref.Val {
+	return m.proxy.Contains(value)
+}
+
+func (m proxyLegacyList) Get(index ref.Val) ref.Val {
+	return m.proxy.Get(index)
+}
+
+func (m proxyLegacyList) Iterator() traits.Iterator {
+	return m.proxy.Iterator()
+}
+
+func (m proxyLegacyList) Size() ref.Val {
+	return m.proxy.Size()
+}
+
 func getElem(t *testing.T, list traits.Indexer, index ref.Val) any {
 	t.Helper()
 	val := list.Get(index)
diff --git a/common/types/map.go b/common/types/map.go
index bc20239f..89b33f90 100644
--- a/common/types/map.go
+++ b/common/types/map.go
@@ -336,12 +336,12 @@ type mutableMap struct {
 
 // Insert implements the traits.MutableMapper interface method, returning true if the key insertion
 // succeeds.
-func (m *mutableMap) Insert(k, v ref.Val) bool {
+func (m *mutableMap) Insert(k, v ref.Val) ref.Val {
 	if _, found := m.mutableValues[k]; found {
-		return false
+		return NewErr("insert failed: key %v already exists", k)
 	}
 	m.mutableValues[k] = v
-	return true
+	return m
 }
 
 // ToImmutableMap implements the traits.MutableMapper interface method, converting a mutable map
@@ -948,3 +948,55 @@ func (it *stringKeyIterator) Next() ref.Val {
 	}
 	return nil
 }
+
+// ToFoldableMap will create a Foldable version of a map suitable for key-value pair iteration.
+//
+// For values which are already Foldable, this call is a no-op. For all other values, the fold
+// is driven via the Iterator HasNext() and Next() calls as well as the map's Get() method
+// which means that the folding will function, but take a performance hit.
+func ToFoldableMap(m traits.Mapper) traits.Foldable {
+	if f, ok := m.(traits.Foldable); ok {
+		return f
+	}
+	return interopFoldableMap{Mapper: m}
+}
+
+type interopFoldableMap struct {
+	traits.Mapper
+}
+
+func (m interopFoldableMap) Fold(f traits.Folder) {
+	it := m.Iterator()
+	for it.HasNext() == True {
+		k := it.Next()
+		if !f.FoldEntry(k, m.Get(k)) {
+			break
+		}
+	}
+}
+
+// InsertMapKeyValue inserts a key, value pair into the target map if the target map does not
+// already contain the given key.
+//
+// If the map is mutable, it is modified in-place per the MutableMapper contract.
+// If the map is not mutable, a copy containing the new key, value pair is made.
+func InsertMapKeyValue(m traits.Mapper, k, v ref.Val) ref.Val {
+	if mutable, ok := m.(traits.MutableMapper); ok {
+		return mutable.Insert(k, v)
+	}
+
+	// Otherwise perform the slow version of the insertion which makes a copy of the incoming map.
+	if _, found := m.Find(k); !found {
+		size := m.Size().(Int)
+		copy := make(map[ref.Val]ref.Val, size+1)
+		copy[k] = v
+		it := m.Iterator()
+		for it.HasNext() == True {
+			nextK := it.Next()
+			nextV := m.Get(nextK)
+			copy[nextK] = nextV
+		}
+		return DefaultTypeAdapter.NativeToValue(copy)
+	}
+	return NewErr("insert failed: key %v already exists", k)
+}
diff --git a/common/types/map_test.go b/common/types/map_test.go
index c96939b5..b16422c7 100644
--- a/common/types/map_test.go
+++ b/common/types/map_test.go
@@ -976,8 +976,8 @@ func TestMutableMap(t *testing.T) {
 	if im.Size() != Int(2) {
 		t.Errorf("m.ToImmutableMap() had size %d, wanted 2", im.Size())
 	}
-	if m.Insert(String("goodbye"), String("happy world")) {
-		t.Error("m.Insert('goodbye', 'happy world') got true, wanted false")
+	if !IsError(m.Insert(String("goodbye"), String("happy world"))) {
+		t.Error("m.Insert('goodbye', 'happy world') suceeded, wanted error")
 	}
 	m.Insert(String("well"), String("well"))
 	if im.Size() != Int(2) {
@@ -1090,14 +1090,62 @@ func TestMapFold(t *testing.T) {
 	reg := NewEmptyRegistry()
 	for i, tst := range tests {
 		tc := tst
-		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
-			f := &testMapFolder{foldLimit: tc.foldLimit}
-			m := reg.NativeToValue(tc.m).(traits.Foldable)
-			m.Fold(f)
-			if f.folds != tc.folds {
-				t.Errorf("m.Fold(f) got %d, wanted %d folds", f.folds, tc.folds)
-			}
-		})
+		m := reg.NativeToValue(tc.m).(traits.Mapper)
+		foldKinds := map[string]traits.Foldable{
+			"modern": ToFoldableMap(m),
+			"legacy": ToFoldableMap(proxyLegacyMap{proxy: m}),
+		}
+		for foldKind, foldable := range foldKinds {
+			t.Run(fmt.Sprintf("[%d]%s", i, foldKind), func(t *testing.T) {
+				f := &testMapFolder{foldLimit: tc.foldLimit}
+				foldable.Fold(f)
+				if f.folds != tc.folds {
+					t.Errorf("m.Fold(f) got %d, wanted %d folds", f.folds, tc.folds)
+				}
+			})
+		}
+	}
+}
+
+func TestInsertMapKeyValue_MutableMapper(t *testing.T) {
+	m := NewMutableMap(DefaultTypeAdapter, map[ref.Val]ref.Val{String("first"): Int(1)})
+	modified := InsertMapKeyValue(m, String("second"), Int(2))
+	if IsError(modified) {
+		t.Fatalf("InsertMapKeyValue() got error: %v, wanted insertion", modified)
+	}
+	if modified != m {
+		t.Fatalf("InsertMapKeyValue() created a new map for a mutable input: %v", modified)
+	}
+	im := m.ToImmutableMap()
+	if _, found := im.Find(String("first")); !found {
+		t.Errorf("InsertMapKeyValue() did not preserve entry 'first': %v", im)
+	}
+	if _, found := im.Find(String("second")); !found {
+		t.Errorf("InsertMapKeyValue() did not insert entry 'second': %v", im)
+	}
+	if !IsError(InsertMapKeyValue(m, String("second"), Int(3))) {
+		t.Errorf("InsertMapKeyValue('second', 3) modified the map instead of erroring: %v", m)
+	}
+}
+
+func TestInsertMapKeyValue_Mapper(t *testing.T) {
+	m := NewRefValMap(DefaultTypeAdapter, map[ref.Val]ref.Val{String("first"): Int(1)})
+	modified := InsertMapKeyValue(m, String("second"), Int(2))
+	if IsError(modified) {
+		t.Fatalf("InsertMapKeyValue() got error: %v, wanted insertion", modified)
+	}
+	if modified == m {
+		t.Fatalf("InsertMapKeyValue() modified an immutable input: %v", modified)
+	}
+	im := modified.(traits.Mapper)
+	if _, found := im.Find(String("first")); !found {
+		t.Errorf("InsertMapKeyValue() did not preserve entry 'first': %v", im)
+	}
+	if _, found := im.Find(String("second")); !found {
+		t.Errorf("InsertMapKeyValue() did not insert entry 'second': %v", im)
+	}
+	if !IsError(InsertMapKeyValue(im, String("second"), Int(3))) {
+		t.Errorf("InsertMapKeyValue('second', 3) modified the map instead of erroring: %v", m)
 	}
 }
 
@@ -1124,3 +1172,48 @@ func testCreateStruct(t *testing.T, m map[string]any) *structpb.Struct {
 	}
 	return v
 }
+
+// proxyLegacyMap omits the foldable interfaces associated with all core Mapper implementations
+type proxyLegacyMap struct {
+	proxy traits.Mapper
+}
+
+func (m proxyLegacyMap) ConvertToNative(typeDesc reflect.Type) (any, error) {
+	return m.proxy.ConvertToNative(typeDesc)
+}
+
+func (m proxyLegacyMap) ConvertToType(typeValue ref.Type) ref.Val {
+	return m.proxy.ConvertToType(typeValue)
+}
+
+func (m proxyLegacyMap) Equal(other ref.Val) ref.Val {
+	return m.proxy.Equal(other)
+}
+
+func (m proxyLegacyMap) Type() ref.Type {
+	return m.proxy.Type()
+}
+
+func (m proxyLegacyMap) Value() any {
+	return m.proxy.Value()
+}
+
+func (m proxyLegacyMap) Contains(value ref.Val) ref.Val {
+	return m.proxy.Contains(value)
+}
+
+func (m proxyLegacyMap) Find(key ref.Val) (ref.Val, bool) {
+	return m.proxy.Find(key)
+}
+
+func (m proxyLegacyMap) Get(index ref.Val) ref.Val {
+	return m.proxy.Get(index)
+}
+
+func (m proxyLegacyMap) Iterator() traits.Iterator {
+	return m.proxy.Iterator()
+}
+
+func (m proxyLegacyMap) Size() ref.Val {
+	return m.proxy.Size()
+}
diff --git a/common/types/traits/lister.go b/common/types/traits/lister.go
index 5cf2593f..e54781a6 100644
--- a/common/types/traits/lister.go
+++ b/common/types/traits/lister.go
@@ -27,6 +27,9 @@ type Lister interface {
 }
 
 // MutableLister interface which emits an immutable result after an intermediate computation.
+//
+// Note, this interface is intended only to be used within Comprehensions where the mutable
+// value is not directly observable within the user-authored CEL expression.
 type MutableLister interface {
 	Lister
 	ToImmutableList() Lister
diff --git a/common/types/traits/mapper.go b/common/types/traits/mapper.go
index 5f1a66b9..d13333f3 100644
--- a/common/types/traits/mapper.go
+++ b/common/types/traits/mapper.go
@@ -33,12 +33,15 @@ type Mapper interface {
 }
 
 // MutableMapper interface which emits an immutable result after an intermediate computation.
+//
+// Note, this interface is intended only to be used within Comprehensions where the mutable
+// value is not directly observable within the user-authored CEL expression.
 type MutableMapper interface {
 	Mapper
 
-	// Insert a key, value pair into the map, returning true if key does not already exist in the map
-	// to indicate the insert is successful.
-	Insert(k, v ref.Val) bool
+	// Insert a key, value pair into the map, returning the map if the insert is successful
+	// and an error if key already exists in the mutable map.
+	Insert(k, v ref.Val) ref.Val
 
 	// ToImmutableMap converts a mutable map into an immutable map.
 	ToImmutableMap() Mapper