Skip to content

Commit

Permalink
WIP: rm Embeddable for real types
Browse files Browse the repository at this point in the history
  • Loading branch information
ammario committed May 31, 2024
1 parent d6543ed commit 241409c
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 117 deletions.
4 changes: 2 additions & 2 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ func (h *Graph[T]) Import(r io.Reader) error {
}

node := &layerNode[T]{
Point: point,
vec: point,
neighbors: make(map[string]*layerNode[T]),
}

Expand All @@ -243,7 +243,7 @@ func (h *Graph[T]) Import(r io.Reader) error {
node.neighbors[id] = nodes[id]
}
}
h.layers[i] = &layer[T]{Nodes: nodes}
h.layers[i] = &layer[T]{nodes: nodes}
}

return nil
Expand Down
105 changes: 49 additions & 56 deletions graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,45 +12,38 @@ import (
"golang.org/x/exp/maps"
)

type Embedding = []float32

// Embeddable describes a type that can be embedded in a HNSW graph.
type Embeddable[K cmp.Ordered] interface {
// ID returns a unique identifier for the object.
ID() K
// Embedding returns the embedding of the object.
// float32 is used for compatibility with OpenAI embeddings.
Embedding() Embedding
}
type Vector = []float32

// layerNode is a node in a layer of the graph.
type layerNode[K cmp.Ordered, V Embeddable[K]] struct {
Point Embeddable[K]
type layerNode[K cmp.Ordered] struct {
id K
vec Vector

// neighbors is a map of neighbor IDs to neighbor nodes.
// It is a map and not a slice to allow for efficient deletes, esp.
// when M is high.
neighbors map[K]*layerNode[K, V]
neighbors map[K]*layerNode[K]
}

// addNeighbor adds a neighbor to the node, replacing the neighbor
// with the worst distance if the neighbor set is full.
func (n *layerNode[K, V]) addNeighbor(newNode *layerNode[K, V], m int, dist DistanceFunc) {
func (n *layerNode[K]) addNeighbor(newNode *layerNode[K], m int, dist DistanceFunc) {
if n.neighbors == nil {
n.neighbors = make(map[K]*layerNode[K, V], m)
n.neighbors = make(map[K]*layerNode[K], m)
}

n.neighbors[newNode.Point.ID()] = newNode
n.neighbors[newNode.id] = newNode
if len(n.neighbors) <= m {
return
}

// Find the neighbor with the worst distance.
var (
worstDist = float32(math.Inf(-1))
worst *layerNode[K, V]
worst *layerNode[K]
)
for _, neighbor := range n.neighbors {
d := dist(neighbor.Point.Embedding(), n.Point.Embedding())
d := dist(neighbor.vec, n.vec)
// d > worstDist may always be false if the distance function
// returns NaN, e.g., when the embeddings are zero.
if d > worstDist || worst == nil {
Expand All @@ -59,49 +52,49 @@ func (n *layerNode[K, V]) addNeighbor(newNode *layerNode[K, V], m int, dist Dist
}
}

delete(n.neighbors, worst.Point.ID())
delete(n.neighbors, worst.id)
// Delete backlink from the worst neighbor.
delete(worst.neighbors, n.Point.ID())
delete(worst.neighbors, n.id)
worst.replenish(m)
}

type searchCandidate[K cmp.Ordered, V Embeddable[K]] struct {
node *layerNode[K, V]
type searchCandidate[K cmp.Ordered] struct {
node *layerNode[K]
dist float32
}

func (s searchCandidate[K, V]) Less(o searchCandidate[K, V]) bool {
func (s searchCandidate[K]) Less(o searchCandidate[K]) bool {
return s.dist < o.dist
}

// search returns the layer node closest to the target node
// within the same layer.
func (n *layerNode[K, V]) search(
func (n *layerNode[K]) search(
// k is the number of candidates in the result set.
k int,
efSearch int,
target Embedding,
target Vector,
distance DistanceFunc,
) []searchCandidate[K, V] {
) []searchCandidate[K] {
// This is a basic greedy algorithm to find the entry point at the given level
// that is closest to the target node.
candidates := heap.Heap[searchCandidate[K, V]]{}
candidates.Init(make([]searchCandidate[K, V], 0, efSearch))
candidates := heap.Heap[searchCandidate[K]]{}
candidates.Init(make([]searchCandidate[K], 0, efSearch))
candidates.Push(
searchCandidate[K, V]{
searchCandidate[K]{
node: n,
dist: distance(n.Point.Embedding(), target),
dist: distance(n.vec, target),
},
)
var (
result = heap.Heap[searchCandidate[K, V]]{}
result = heap.Heap[searchCandidate[K]]{}
visited = make(map[K]bool)
)
result.Init(make([]searchCandidate[K, V], 0, k))
result.Init(make([]searchCandidate[K], 0, k))

// Begin with the entry node in the result set.
result.Push(candidates.Min())
visited[n.Point.ID()] = true
visited[n.id] = true

for candidates.Len() > 0 {
var (
Expand All @@ -120,16 +113,16 @@ func (n *layerNode[K, V]) search(
}
visited[neighborID] = true

dist := distance(neighbor.Point.Embedding(), target)
dist := distance(neighbor.vec, target)
improved = improved || dist < result.Min().dist
if result.Len() < k {
result.Push(searchCandidate[K, V]{node: neighbor, dist: dist})
result.Push(searchCandidate[K]{node: neighbor, dist: dist})
} else if dist < result.Max().dist {
result.PopLast()
result.Push(searchCandidate[K, V]{node: neighbor, dist: dist})
result.Push(searchCandidate[K]{node: neighbor, dist: dist})
}

candidates.Push(searchCandidate[K, V]{node: neighbor, dist: dist})
candidates.Push(searchCandidate[K]{node: neighbor, dist: dist})
// Always store candidates if we haven't reached the limit.
if candidates.Len() > efSearch {
candidates.PopLast()
Expand All @@ -146,7 +139,7 @@ func (n *layerNode[K, V]) search(
return result.Slice()
}

func (n *layerNode[K, V]) replenish(m int) {
func (n *layerNode[K]) replenish(m int) {
if len(n.neighbors) >= m {
return
}
Expand All @@ -173,46 +166,46 @@ func (n *layerNode[K, V]) replenish(m int) {

// isolate removes the node from the graph by removing all connections
// to neighbors.
func (n *layerNode[K, V]) isolate(m int) {
func (n *layerNode[K]) isolate(m int) {
for _, neighbor := range n.neighbors {
delete(neighbor.neighbors, n.Point.ID())
delete(neighbor.neighbors, n.id)
neighbor.replenish(m)
}
}

type layer[T Embeddable] struct {
// Nodes is a map of node IDs to Nodes.
// All Nodes in a higher layer are also in the lower layers, an essential
type layer[K cmp.Ordered] struct {
// nodes is a map of node IDs to nodes.
// All nodes in a higher layer are also in the lower layers, an essential
// property of the graph.
//
// Nodes is exported for interop with encoding/gob.
Nodes map[string]*layerNode[T]
// nodes was exported (as Nodes) for interop with encoding/gob; it is now unexported.
nodes map[string]*layerNode[K]
}

// entry returns the entry node of the layer.
// It doesn't matter which node is returned, or even whether the
// entry node is consistent, so we just return the first node
// in the map to avoid tracking extra state.
func (l *layer[T]) entry() *layerNode[T] {
func (l *layer[K]) entry() *layerNode[K] {
if l == nil {
return nil
}
for _, node := range l.Nodes {
for _, node := range l.nodes {
return node
}
return nil
}

func (l *layer[T]) size() int {
func (l *layer[K]) size() int {
if l == nil {
return 0
}
return len(l.Nodes)
return len(l.nodes)
}

// Graph is a Hierarchical Navigable Small World graph.
// All public parameters must be set before adding nodes to the graph.
type Graph[T Embeddable] struct {
type Graph[K cmp.Ordered] struct {
// Distance is the distance function used to compare embeddings.
Distance DistanceFunc

Expand All @@ -235,7 +228,7 @@ type Graph[T Embeddable] struct {
EfSearch int

// layers is a slice of layers in the graph.
layers []*layer[T]
layers []*layer[K]
}

func defaultRand() *rand.Rand {
Expand All @@ -244,8 +237,8 @@ func defaultRand() *rand.Rand {

// NewGraph returns a new graph with default parameters, roughly designed for
// storing OpenAI embeddings.
func NewGraph[T Embeddable]() *Graph[T] {
return &Graph[T]{
func NewGraph[K cmp.Ordered, V Embeddable[K]]() *Graph[K, V] {

Check failure on line 240 in graph.go

View workflow job for this annotation

GitHub Actions / test

undefined: Embeddable
return &Graph[K, V]{
M: 16,
Ml: 0.25,
Distance: CosineDistance,
Expand Down Expand Up @@ -298,7 +291,7 @@ func (h *Graph[T]) randomLevel() int {
return max
}

func (g *Graph[T]) assertDims(n Embedding) {
func (g *Graph[T]) assertDims(n Vector) {
if len(g.layers) == 0 {
return
}
Expand Down Expand Up @@ -340,7 +333,7 @@ func (g *Graph[T]) Add(nodes ...T) {
for i := len(g.layers) - 1; i >= 0; i-- {
layer := g.layers[i]
newNode := &layerNode[T]{
Point: n,
vec: n,
}

// Insert the new node into the layer.
Expand Down Expand Up @@ -395,7 +388,7 @@ func (g *Graph[T]) Add(nodes ...T) {
}

// Search finds the k nearest neighbors from the target node.
func (h *Graph[T]) Search(near Embedding, k int) []T {
func (h *Graph[T]) Search(near Vector, k int) []T {
h.assertDims(near)
if len(h.layers) == 0 {
return nil
Expand Down
12 changes: 6 additions & 6 deletions graph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,22 @@ func (n basicPoint) Embedding() []float32 {

func Test_layerNode_search(t *testing.T) {
entry := &layerNode[basicPoint]{
Point: basicPoint(0),
vec: basicPoint(0),
neighbors: map[string]*layerNode[basicPoint]{
"1": {
Point: basicPoint(1),
vec: basicPoint(1),
},
"2": {
Point: basicPoint(2),
vec: basicPoint(2),
},
"3": {
Point: basicPoint(3),
vec: basicPoint(3),
neighbors: map[string]*layerNode[basicPoint]{
"3.8": {
Point: basicPoint(3.8),
vec: basicPoint(3.8),
},
"4.3": {
Point: basicPoint(4.3),
vec: basicPoint(4.3),
},
},
},
Expand Down
53 changes: 0 additions & 53 deletions vector.go

This file was deleted.

0 comments on commit 241409c

Please sign in to comment.