Skip to content

Commit

Permalink
collection unit tests (#263)
Browse files Browse the repository at this point in the history
* add tests to collections
  • Loading branch information
jairad26 authored Jun 28, 2024
1 parent e6c310a commit 89d86fc
Show file tree
Hide file tree
Showing 5 changed files with 282 additions and 0 deletions.
101 changes: 101 additions & 0 deletions collections/in_mem/sequential/vector_index_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package sequential

import (
"context"
"fmt"
"hmruntime/collections/utils"
"sync"
"testing"
)

// TestMultipleSequentialVectorIndexes creates numIndexes separate
// SequentialVectorIndex instances, one per goroutine, fills each with data
// derived from a shared base set, and verifies that:
//   - every inserted key can be read back with GetVector,
//   - a SearchWithKey for that key returns a result whose vector matches, and
//   - the checkpoint ID and last indexed text ID equal the last inserted text ID.
//
// Failures are reported with t.Errorf (never t.Fatalf) because FailNow-style
// helpers must only be called from the goroutine running the test function.
func TestMultipleSequentialVectorIndexes(t *testing.T) {
	ctx := context.Background()

	// Base data; each goroutine derives its own shifted copy from it.
	baseTextIds := []int64{1, 2, 3}
	baseKeys := []string{"key1", "key2", "key3"}
	baseVecs := [][]float32{
		{0.1, 0.2, 0.3},
		{0.4, 0.5, 0.6},
		{0.7, 0.8, 0.9},
	}

	// Wait group used to block until every goroutine has finished.
	var wg sync.WaitGroup

	// Number of indexes (and goroutines) to create.
	numIndexes := 20

	for i := 0; i < numIndexes; i++ {
		wg.Add(1)

		go func(i int) {
			defer wg.Done()

			// Each goroutine owns its own index instance.
			index := NewSequentialVectorIndex("collection"+fmt.Sprint(i), "searchMethod"+fmt.Sprint(i), "embedder"+fmt.Sprint(i))

			// Generate data unique to this index: IDs are offset by
			// i*len(base), keys get the index number appended, and every
			// vector component is shifted by i/10.
			textIds := make([]int64, len(baseTextIds))
			keys := make([]string, len(baseKeys))
			vecs := make([][]float32, len(baseVecs))
			for j := range baseTextIds {
				textIds[j] = baseTextIds[j] + int64(i*len(baseTextIds))
				keys[j] = baseKeys[j] + fmt.Sprint(i)
				// Copy so the shared base vectors are never mutated.
				vecs[j] = append([]float32{}, baseVecs[j]...)
				for k := range vecs[j] {
					vecs[j][k] += float32(i) / 10
				}
			}

			// textIds is deliberately passed for both ID arguments.
			if err := index.InsertVectorsToMemory(ctx, textIds, textIds, keys, vecs); err != nil {
				t.Errorf("Failed to insert vectors into index: %v", err)
				// Nothing was stored, so the verification below is meaningless.
				return
			}

			// Verify the vectors were inserted correctly.
			for _, key := range keys {
				expectedVec, err := index.GetVector(ctx, key)
				if err != nil {
					t.Errorf("index %d: Failed to get expected vector from index: %v", i, err)
					continue
				}
				objs, err := index.SearchWithKey(ctx, key, 1, nil)
				if err != nil {
					t.Errorf("index %d: Failed to search vector in index: %v", i, err)
					continue
				}
				if len(objs) == 0 {
					t.Errorf("index %d: Expected obj with length 1, got %v", i, len(objs))
					// Indexing objs[0] on an empty result would panic and
					// abort the whole test binary; skip to the next key.
					continue
				}
				resVec, err := index.GetVector(ctx, objs[0].GetIndex())
				if err != nil {
					t.Errorf("index %d: Failed to get result vector from index: %v", i, err)
					continue
				}

				if !utils.EqualFloat32Slices(expectedVec, resVec) {
					t.Errorf("index %d: Expected vector %v, got %v", i, expectedVec, resVec)
				}
			}

			// These checks do not depend on the key, so they run once per
			// index rather than once per key.
			lastTextID := textIds[len(textIds)-1]

			checkpointId, err := index.GetCheckpointId(ctx)
			if err != nil {
				t.Errorf("Failed to get checkpoint ID: %v", err)
			} else if checkpointId != lastTextID {
				t.Errorf("Expected checkpoint ID %v, got %v", lastTextID, checkpointId)
			}

			lastIndexedTextID, err := index.GetLastIndexedTextId(ctx)
			if err != nil {
				t.Errorf("Failed to get last indexed text ID: %v", err)
			} else if lastIndexedTextID != lastTextID {
				t.Errorf("Expected last indexed text ID %v, got %v", lastTextID, lastIndexedTextID)
			}
		}(i)
	}

	// Wait for all goroutines to finish.
	wg.Wait()
}
73 changes: 73 additions & 0 deletions collections/in_mem/text_index_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package in_mem

import (
"context"
"fmt"
"sync"
"testing"
)

// TestMultipleInMemCollections spins up numCollections goroutines, each of
// which creates its own collection, inserts three texts, and then checks that
// every text, every external ID, and the checkpoint ID read back match what
// was inserted.
func TestMultipleInMemCollections(t *testing.T) {
	ctx := context.Background()

	// Number of collections (one goroutine each).
	numCollections := 10

	// Blocks the test until every worker goroutine is done.
	var wg sync.WaitGroup

	// exercise runs the full insert-and-verify cycle for collection number n.
	exercise := func(n int) {
		defer wg.Done()

		// Data for this collection; IDs are consecutive and start at n*3+1.
		ids := []int64{int64(n*3 + 1), int64(n*3 + 2), int64(n*3 + 3)}
		keys := []string{fmt.Sprintf("key%d_1", n), fmt.Sprintf("key%d_2", n), fmt.Sprintf("key%d_3", n)}
		texts := []string{fmt.Sprintf("text%d_1", n), fmt.Sprintf("text%d_2", n), fmt.Sprintf("text%d_3", n)}

		// Each goroutine works on its own collection instance.
		coll := NewCollection("collection" + fmt.Sprint(n))

		if insertErr := coll.InsertTextsToMemory(ctx, ids, keys, texts); insertErr != nil {
			t.Errorf("Failed to insert texts into collection: %v", insertErr)
		}

		// Read every key back and compare against what was inserted.
		for pos := 0; pos < len(keys); pos++ {
			key := keys[pos]

			gotText, textErr := coll.GetText(ctx, key)
			if textErr != nil {
				t.Errorf("Failed to get text from collection: %v", textErr)
			}
			if wantText := texts[pos]; gotText != wantText {
				t.Errorf("Expected text %s, got %s", wantText, gotText)
			}

			gotID, idErr := coll.GetExternalId(ctx, key)
			if idErr != nil {
				t.Errorf("Failed to get external ID from collection: %v", idErr)
			}
			if wantID := ids[pos]; gotID != wantID {
				t.Errorf("Expected external ID %d, got %d", wantID, gotID)
			}
		}

		// The checkpoint is expected to equal the last inserted ID.
		wantCheckpoint := ids[len(ids)-1]
		gotCheckpoint, chkErr := coll.GetCheckpointId(ctx)
		if chkErr != nil {
			t.Errorf("Failed to get checkpoint ID from collection: %v", chkErr)
		}
		if gotCheckpoint != wantCheckpoint {
			t.Errorf("Expected checkpoint ID %d, got %d", wantCheckpoint, gotCheckpoint)
		}
	}

	for n := 0; n < numCollections; n++ {
		wg.Add(1)
		go exercise(n)
	}

	wg.Wait()
}
60 changes: 60 additions & 0 deletions collections/utils/heap_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package utils

import (
"container/heap"
"testing"
)

// TestHeap drives MinTupleHeap through container/heap: it checks Len on an
// empty heap, pushes three elements out of order, checks Len and Less, then
// pops everything and expects the elements in ascending value order.
func TestHeap(t *testing.T) {
	h := &MinTupleHeap{}
	heap.Init(h)

	// A freshly initialised heap must be empty.
	if got := h.Len(); got != 0 {
		t.Errorf("Expected length of 0, got %d", got)
	}

	// Push in non-sorted order so the heap has to reorder.
	for _, elem := range []MinHeapElement{
		{value: 3.0, index: "three"},
		{value: 1.0, index: "one"},
		{value: 2.0, index: "two"},
	} {
		heap.Push(h, elem)
	}

	if got := h.Len(); got != 3 {
		t.Errorf("Expected length of 3, got %d", got)
	}

	// The root must compare less than the element at position 1.
	if rootIsSmaller := h.Less(0, 1); !rootIsSmaller {
		t.Errorf("Expected h[0] < h[1], got h[0] = %v, h[1] = %v", (*h)[0], (*h)[1])
	}

	// Expected pop order, smallest value first.
	wantValues := []float64{1.0, 2.0, 3.0}
	wantIndices := []string{"one", "two", "three"}

	// Capture the size up front: h.Len() shrinks with every Pop.
	remaining := h.Len()
	for n := 0; n < remaining; n++ {
		got := heap.Pop(h).(MinHeapElement)
		matches := got.value == wantValues[n] && got.index == wantIndices[n]
		if !matches {
			t.Errorf("Expected pop value of %v and index '%s', got %v and '%s'", wantValues[n], wantIndices[n], got.value, got.index)
		}
	}

	// Everything has been popped, so the heap must be empty again.
	if got := h.Len(); got != 0 {
		t.Errorf("Expected length of 0, got %d", got)
	}
}

func TestHeapSwap(t *testing.T) {
h := &MinTupleHeap{
MinHeapElement{value: 1.0, index: "one"},
MinHeapElement{value: 2.0, index: "two"},
}
h.Swap(0, 1)

if (*h)[0].value != 2.0 || (*h)[0].index != "two" || (*h)[1].value != 1.0 || (*h)[1].index != "one" {
t.Errorf("Expected heap to be swapped, got %v", h)
}
}
14 changes: 14 additions & 0 deletions collections/utils/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package utils
import (
"errors"
"fmt"
"math"

"github.com/chewxy/math32"
)
Expand Down Expand Up @@ -102,3 +103,16 @@ func ConvertToFloat32_2DArray(result any) ([][]float32, error) {
}
return textVecs, nil
}

// EqualFloat32Slices reports whether a and b have the same length and every
// pair of corresponding elements is equal to within an absolute tolerance of
// 1e-9.
//
// Elements that compare exactly equal (including same-signed infinities) are
// always considered equal. A NaN element is never equal to anything, including
// another NaN. Nil and empty slices are considered equal to each other.
func EqualFloat32Slices(a, b []float32) bool {
	const epsilon = 1e-9
	if len(a) != len(b) {
		return false
	}
	for i := range a {
		// Exact match; this also covers +Inf vs +Inf and -Inf vs -Inf, whose
		// difference would otherwise be NaN.
		if a[i] == b[i] {
			continue
		}
		diff := math.Abs(float64(a[i] - b[i]))
		// A NaN difference must be rejected explicitly: every ordered
		// comparison with NaN is false, so "diff > epsilon" alone would let
		// NaN compare equal to any value.
		if math.IsNaN(diff) || diff > epsilon {
			return false
		}
	}
	return true
}
34 changes: 34 additions & 0 deletions collections/utils/helper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package utils

import (
"reflect"
"testing"
)

// TestConvertToFloat32_2DArray checks that ConvertToFloat32_2DArray turns a
// nested slice of mixed float64/float32 values into a [][]float32, and that it
// returns an error when an element is not a float.
func TestConvertToFloat32_2DArray(t *testing.T) {
	// Valid input: rows mixing float64 and float32 elements.
	validRows := []any{
		[]any{float64(1.0), float32(2.0)},
		[]any{float64(3.0), float32(4.0)},
	}
	want := [][]float32{
		{1.0, 2.0},
		{3.0, 4.0},
	}

	got, convErr := ConvertToFloat32_2DArray(validRows)
	if convErr != nil {
		t.Errorf("Expected no error, got %v", convErr)
	}
	if same := reflect.DeepEqual(got, want); !same {
		t.Errorf("Expected %v, got %v", want, got)
	}

	// Invalid input: a string where a float is required.
	invalidRows := []any{
		[]any{float64(1.0), "invalid"},
	}
	if _, badErr := ConvertToFloat32_2DArray(invalidRows); badErr == nil {
		t.Errorf("Expected error, got nil")
	}
}

0 comments on commit 89d86fc

Please sign in to comment.