Skip to content

Commit

Permalink
Add ProximityMap class.
Browse files Browse the repository at this point in the history
  • Loading branch information
nicktobey committed Oct 2, 2024
1 parent bbf447d commit 91a5a5d
Show file tree
Hide file tree
Showing 3 changed files with 728 additions and 0 deletions.
113 changes: 113 additions & 0 deletions go/store/prolly/proximity_map.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package prolly

import (
	"context"
	"fmt"

	"github.com/dolthub/dolt/go/store/hash"
	"github.com/dolthub/dolt/go/store/prolly/message"
	"github.com/dolthub/dolt/go/store/prolly/tree"
	"github.com/dolthub/dolt/go/store/val"
	"github.com/dolthub/go-mysql-server/sql"
	"github.com/dolthub/go-mysql-server/sql/expression"
)

// ProximityMap is the tuple-typed view of a tree.ProximityMap: callers work
// with val.Tuple keys and values instead of raw byte strings.
type ProximityMap struct {
	// tuples is the underlying generic proximity map that does the actual work.
	tuples tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc]
	// keyDesc and valDesc are the descriptors supplied for key and value tuples.
	keyDesc, valDesc val.TupleDesc
	// ctx is not populated by the constructors in this file.
	ctx context.Context
}

// NewProximityMap returns a ProximityMap over the existing tree rooted at
// |node|, with nodes read from |ns|. It does not build or modify the tree.
//
// Keys are ordered by |keyDesc| and distances are always computed with
// expression.DistanceL2Squared. The vector for a key is obtained by reading
// the JSON address stored in field 0 of the key tuple, loading that document
// from |ns|, and converting it with sql.ConvertToVector.
//
// |ctx| is captured by the conversion callback and reused for every document
// load the returned map performs. The callback has no error return, so any
// failure while resolving a key's vector panics.
func NewProximityMap(ctx context.Context, node tree.Node, ns tree.NodeStore, keyDesc val.TupleDesc, valDesc val.TupleDesc) ProximityMap {
	// toVector resolves a serialized key tuple to the vector it addresses.
	toVector := func(bytes []byte) []float64 {
		addr, ok := keyDesc.GetJSONAddr(0, bytes)
		if !ok {
			// Without an address there is no document to load; fail loudly
			// here rather than attempting to read a zero hash from the store.
			panic("proximity map key has no JSON address in field 0")
		}
		jsonWrapper, err := tree.NewJSONDoc(addr, ns).ToIndexedJSONDocument(ctx)
		if err != nil {
			panic(err)
		}
		floats, err := sql.ConvertToVector(jsonWrapper)
		if err != nil {
			panic(err)
		}
		return floats
	}

	return ProximityMap{
		tuples: tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc]{
			Root:         node,
			NodeStore:    ns,
			Order:        keyDesc,
			DistanceType: expression.DistanceL2Squared{},
			Convert:      toVector,
		},
		keyDesc: keyDesc,
		valDesc: valDesc,
	}
}

// VectorIter produces key/value pairs one call at a time. The key is untyped
// (interface{}) while the value is a val.Tuple.
type VectorIter interface {
// Next returns the next key |k| and value |v|.
// NOTE(review): how exhaustion is signaled is not visible here, since the
// method has no error or bool result; confirm with the implementations.
Next(ctx context.Context) (k interface{}, v val.Tuple)
}

// NewProximityMapFromTupleIter builds a new proximity map from parallel slices
// of keys and values (despite the name, it takes slices, not an iterator).
//
// keys[i] is paired with values[i], so the slices must have the same length;
// a mismatch is reported as an error. The pairs are first chunked with a
// deterministic splitter configured by |logChunkSize|, and the resulting tree
// is then reordered by tree.FixupProximityMap using |distanceType|.
//
// The returned map uses |distanceType| for its lookups as well, so queries
// measure distance the same way the tree was organized.
func NewProximityMapFromTupleIter(ctx context.Context, ns tree.NodeStore, distanceType expression.DistanceType, keyDesc val.TupleDesc, valDesc val.TupleDesc, keys []val.Tuple, values []val.Tuple, logChunkSize uint8) (ProximityMap, error) {
	if len(keys) != len(values) {
		return ProximityMap{}, fmt.Errorf("proximity map needs one value per key: got %d keys and %d values", len(keys), len(values))
	}

	serializer := message.NewVectorIndexSerializer(ns.Pool())
	ch, err := tree.NewChunkerWithDeterministicSplitter(ctx, nil, 0, ns, serializer, logChunkSize)
	if err != nil {
		return ProximityMap{}, err
	}

	for i, key := range keys {
		if err = ch.AddPair(ctx, tree.Item(key), tree.Item(values[i])); err != nil {
			return ProximityMap{}, err
		}
	}

	root, err := ch.Done(ctx)
	if err != nil {
		return ProximityMap{}, err
	}

	// We now have a map where each node is at the right level, but now we need to sort it.

	getHash := func(tuple []byte) hash.Hash {
		// The presence flag is dropped because this callback can only return a
		// hash. NOTE(review): a key without an address yields whatever hash
		// GetJSONAddr returns in that case; confirm keys always carry one.
		h, _ := keyDesc.GetJSONAddr(0, tuple)
		return h
	}
	newRoot, err := tree.FixupProximityMap[val.Tuple, val.TupleDesc](ctx, ns, distanceType, root, getHash, keyDesc)
	if err != nil {
		return ProximityMap{}, err
	}

	m := NewProximityMap(ctx, newRoot, ns, keyDesc, valDesc)
	if distanceType != nil {
		// NewProximityMap hard-codes L2 squared; use the metric the tree was
		// actually fixed up with so lookups agree with the tree's layout.
		m.tuples.DistanceType = distanceType
	}
	return m, nil
}

// Count reports how many key-value pairs the map holds, as reported by the
// underlying tree.ProximityMap.
func (m ProximityMap) Count() (int, error) {
	n, err := m.tuples.Count()
	return n, err
}

// Get looks up the entry that matches |query| exactly by delegating to the
// underlying map's GetExact, and hands the result to |cb|.
// NOTE(review): the previous comment stated that a nil key-value pair is
// passed to |cb| when there is no match; confirm against GetExact.
func (m ProximityMap) Get(ctx context.Context, query interface{}, cb tree.KeyValueFn[val.Tuple, val.Tuple]) (err error) {
	err = m.tuples.GetExact(ctx, query, cb)
	return err
}

// GetClosest delegates a nearest-neighbor search for |query| to the underlying
// map. |cb| receives each reported key and value together with its distance
// from |query|.
// NOTE(review): |limit| presumably caps the number of results reported;
// confirm against tree.ProximityMap.GetClosest.
func (m ProximityMap) GetClosest(ctx context.Context, query interface{}, cb tree.KeyValueDistanceFn[val.Tuple, val.Tuple], limit int) (err error) {
	err = m.tuples.GetClosest(ctx, query, cb, limit)
	return err
}
192 changes: 192 additions & 0 deletions go/store/prolly/proximity_map_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package prolly

import (
"context"
"github.com/dolthub/dolt/go/store/hash"
"github.com/dolthub/dolt/go/store/pool"
"github.com/dolthub/dolt/go/store/prolly/tree"
"github.com/dolthub/dolt/go/store/val"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/types"
"github.com/stretchr/testify/require"
"testing"
)

// newJsonValue runs |v| through types.JSON.Convert and returns the result as a
// sql.JSONWrapper. A conversion error fails the test immediately.
func newJsonValue(t *testing.T, v interface{}) sql.JSONWrapper {
	converted, _, err := types.JSON.Convert(v)
	require.NoError(t, err)
	wrapper := converted.(sql.JSONWrapper)
	return wrapper
}

// newJsonDocument converts |v| to a JSON value, serializes it into |ns| with
// tree.SerializeJsonToAddr, and returns the hash of the resulting root node.
func newJsonDocument(t *testing.T, ctx context.Context, ns tree.NodeStore, v interface{}) hash.Hash {
	node, err := tree.SerializeJsonToAddr(ctx, ns, newJsonValue(t, v))
	require.NoError(t, err)
	return node.HashOf()
}

// createProximityMap builds a ProximityMap for tests. Each entry of |vectors|
// is stored as a JSON document in |ns| and its address becomes a key; the
// matching entry of |pks| becomes the int64 value. The map is built with the
// L2-squared distance and the given |logChunkSize|.
//
// It asserts that |vectors| and |pks| have the same length and that the built
// map reports that many entries, then returns the map with the key and value
// tuples in input order.
func createProximityMap(t *testing.T, ctx context.Context, ns tree.NodeStore, vectors []interface{}, pks []int64, logChunkSize uint8) (ProximityMap, []val.Tuple, []val.Tuple) {
	buffPool := pool.NewBuffPool()

	count := len(vectors)
	require.Equal(t, count, len(pks))

	// Single-column descriptors: a JSON address for keys, an int64 for values.
	keyDesc := val.NewTupleDescriptor(val.Type{Enc: val.JSONAddrEnc, Nullable: true})
	valDesc := val.NewTupleDescriptor(val.Type{Enc: val.Int64Enc, Nullable: true})

	keys := make([]val.Tuple, count)
	keyBuilder := val.NewTupleBuilder(keyDesc)
	for i := range vectors {
		keyBuilder.PutJSONAddr(0, newJsonDocument(t, ctx, ns, vectors[i]))
		keys[i] = keyBuilder.Build(buffPool)
	}

	valueBuilder := val.NewTupleBuilder(valDesc)
	values := make([]val.Tuple, count)
	for i := range pks {
		valueBuilder.PutInt64(0, pks[i])
		values[i] = valueBuilder.Build(buffPool)
	}

	m, err := NewProximityMapFromTupleIter(ctx, ns, expression.DistanceL2Squared{}, keyDesc, valDesc, keys, values, logChunkSize)
	require.NoError(t, err)

	mapCount, err := m.Count()
	require.NoError(t, err)
	require.Equal(t, count, mapCount)

	return m, keys, values
}

func TestEmptyProximityMap(t *testing.T) {
ctx := context.Background()
ns := tree.NewTestNodeStore()
createProximityMap(t, ctx, ns, nil, nil, 10)
}

// TestSingleEntryProximityMap builds a one-entry map and checks that an exact
// lookup of its vector invokes the callback once with the stored pair.
func TestSingleEntryProximityMap(t *testing.T) {
	ctx := context.Background()
	ns := tree.NewTestNodeStore()
	m, keys, values := createProximityMap(t, ctx, ns, []interface{}{"[1.0]"}, []int64{1}, 10)

	// Rebuild the query document from the address stored in the key tuple.
	vectorHash, _ := m.keyDesc.GetJSONAddr(0, keys[0])
	vectorDoc, err := tree.NewJSONDoc(vectorHash, ns).ToIndexedJSONDocument(ctx)
	require.NoError(t, err)

	matches := 0
	err = m.Get(ctx, vectorDoc, func(foundKey val.Tuple, foundValue val.Tuple) error {
		require.Equal(t, keys[0], foundKey)
		require.Equal(t, values[0], foundValue)
		matches++
		return nil
	})
	require.NoError(t, err)
	// Expected value goes first so a failure reports expected/actual correctly.
	require.Equal(t, 1, matches)
}

// TestDoubleEntryProximityMapGetExact builds a two-entry map and checks that
// an exact lookup of each stored vector returns its own key and value.
func TestDoubleEntryProximityMapGetExact(t *testing.T) {
	ctx := context.Background()
	ns := tree.NewTestNodeStore()
	m, keys, values := createProximityMap(t, ctx, ns, []interface{}{"[0.0, 6.0]", "[3.0, 4.0]"}, []int64{1, 2}, 10)
	matches := 0
	for i, key := range keys {
		vectorHash, _ := m.keyDesc.GetJSONAddr(0, key)
		vectorDoc, err := tree.NewJSONDoc(vectorHash, ns).ToIndexedJSONDocument(ctx)
		// Check the load error before err is reused for the lookup below.
		require.NoError(t, err)
		err = m.Get(ctx, vectorDoc, func(foundKey val.Tuple, foundValue val.Tuple) error {
			require.Equal(t, key, foundKey)
			require.Equal(t, values[i], foundValue)
			matches++
			return nil
		})
		require.NoError(t, err)
	}
	// Expected value goes first so a failure reports expected/actual correctly.
	require.Equal(t, len(keys), matches)
}

// TestDoubleEntryProximityMapGetClosest checks that a nearest-neighbor search
// from the origin with limit 1 reports only [3.0, 4.0], at a squared L2
// distance of 25.
func TestDoubleEntryProximityMapGetClosest(t *testing.T) {
	ctx := context.Background()
	ns := tree.NewTestNodeStore()
	m, keys, values := createProximityMap(t, ctx, ns, []interface{}{"[0.0, 6.0]", "[3.0, 4.0]"}, []int64{1, 2}, 10)
	matches := 0

	cb := func(foundKey val.Tuple, foundValue val.Tuple, distance float64) error {
		require.Equal(t, keys[1], foundKey)
		require.Equal(t, values[1], foundValue)
		// 3*3 + 4*4 = 25; expected value first, then the observed distance.
		require.InDelta(t, 25.0, distance, 0.1)
		matches++
		return nil
	}

	err := m.GetClosest(ctx, newJsonValue(t, "[0.0, 0.0]"), cb, 1)
	require.NoError(t, err)
	// Expected value goes first so a failure reports expected/actual correctly.
	require.Equal(t, 1, matches)
}

// TestMultilevelProximityMap builds a four-entry map with logChunkSize 1 and
// checks that every stored vector can still be found by exact lookup.
func TestMultilevelProximityMap(t *testing.T) {
	ctx := context.Background()
	ns := tree.NewTestNodeStore()
	keyStrings := []interface{}{
		"[0.0, 1.0]",
		"[3.0, 4.0]",
		"[5.0, 6.0]",
		"[7.0, 8.0]",
	}
	valueStrings := []int64{1, 2, 3, 4}
	m, keys, values := createProximityMap(t, ctx, ns, keyStrings, valueStrings, 1)
	matches := 0
	for i, key := range keys {
		vectorHash, _ := m.keyDesc.GetJSONAddr(0, key)
		vectorDoc, err := tree.NewJSONDoc(vectorHash, ns).ToIndexedJSONDocument(ctx)
		require.NoError(t, err)
		err = m.Get(ctx, vectorDoc, func(foundKey val.Tuple, foundValue val.Tuple) error {
			require.Equal(t, key, foundKey)
			require.Equal(t, values[i], foundValue)
			matches++
			return nil
		})
		require.NoError(t, err)
	}
	// Expected value goes first so a failure reports expected/actual correctly.
	require.Equal(t, len(keys), matches)
}

// TestInsertOrderIndependence builds two maps from the same four pairs, once
// in ascending order and once reversed, and requires identical root hashes.
func TestInsertOrderIndependence(t *testing.T) {
	ctx := context.Background()
	ns := tree.NewTestNodeStore()

	forwardVectors := []interface{}{"[0.0, 1.0]", "[3.0, 4.0]", "[5.0, 6.0]", "[7.0, 8.0]"}
	forwardPKs := []int64{1, 2, 3, 4}

	// The same pairs as above, supplied in the opposite order.
	reverseVectors := []interface{}{"[7.0, 8.0]", "[5.0, 6.0]", "[3.0, 4.0]", "[0.0, 1.0]"}
	reversePKs := []int64{4, 3, 2, 1}

	forwardMap, _, _ := createProximityMap(t, ctx, ns, forwardVectors, forwardPKs, 1)
	reverseMap, _, _ := createProximityMap(t, ctx, ns, reverseVectors, reversePKs, 1)

	forwardHash := forwardMap.tuples.Root.HashOf()
	reverseHash := reverseMap.tuples.Root.HashOf()
	require.Equal(t, forwardHash, reverseHash)
}
Loading

0 comments on commit 91a5a5d

Please sign in to comment.