Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor API so that keys are generic #6

Merged
merged 7 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,18 @@ go get github.com/coder/hnsw@main
```

```go
g := hnsw.NewGraph[hnsw.Vector]()
g := hnsw.NewGraph[int]()
g.Add(
hnsw.MakeVector("1", []float32{1, 1, 1}),
hnsw.MakeVector("2", []float32{1, -1, 0.999}),
hnsw.MakeVector("3", []float32{1, 0, -0.5}),
hnsw.MakeNode(1, []float32{1, 1, 1}),
hnsw.MakeNode(2, []float32{1, -1, 0.999}),
hnsw.MakeNode(3, []float32{1, 0, -0.5}),
)

neighbors := g.Search(
[]float32{0.5, 0.5, 0.5},
1,
)
fmt.Printf("best friend: %v\n", neighbors[0].Embedding())
fmt.Printf("best friend: %v\n", neighbors[0].Vec)
// Output: best friend: [1 1 1]
```

Expand All @@ -59,13 +59,13 @@ If you're using a single file as the backend, hnsw provides a convenient `SavedG

```go
path := "some.graph"
g1, err := LoadSavedGraph[hnsw.Vector](path)
g1, err := LoadSavedGraph[int](path)
if err != nil {
panic(err)
}
// Insert some vectors
for i := 0; i < 128; i++ {
g1.Add(MakeVector(strconv.Itoa(i), []float32{float32(i)}))
g1.Add(hnsw.MakeNode(i, []float32{float32(i)}))
}

// Save to disk
Expand All @@ -76,7 +76,7 @@ if err != nil {

// Later...
// g2 is a copy of g1
g2, err := LoadSavedGraph[Vector](path)
g2, err := LoadSavedGraph[int](path)
if err != nil {
panic(err)
}
Expand All @@ -94,10 +94,10 @@ nearly at disk speed. On my M3 Macbook I get these benchmark results:
goos: darwin
goarch: arm64
pkg: github.com/coder/hnsw
BenchmarkGraph_Import-16 2733 369803 ns/op 228.65 MB/s 352041 B/op 9880 allocs/op
BenchmarkGraph_Export-16 6046 194441 ns/op 1076.65 MB/s 261854 B/op 3760 allocs/op
BenchmarkGraph_Import-16 4029 259927 ns/op 796.85 MB/s 496022 B/op 3212 allocs/op
BenchmarkGraph_Export-16 7042 168028 ns/op 1232.49 MB/s 239886 B/op 2388 allocs/op
PASS
ok github.com/coder/hnsw 2.530s
ok github.com/coder/hnsw 2.624s
```

when saving/loading a graph of 100 vectors with 256 dimensions.
Expand Down Expand Up @@ -130,18 +130,18 @@ $$

where:
* $n$ is the number of vectors in the graph
* $\text{size(id)}$ is the average size of the ID in bytes
* $\text{size(key)}$ is the average size of the key in bytes
* $M$ is the maximum number of neighbors each node can have
* $d$ is the dimensionality of the vectors
* $mem_{graph}$ is the memory used by the graph structure across all layers
* $mem_{base}$ is the memory used by the vectors themselves in the base or 0th layer

You can infer that:
* Connectivity ($M$) is very expensive if IDs are large
* If $d \cdot 4$ is far larger than $M \cdot \text{size(id)}$, you should expect linear memory usage spent on representing vector data
* If $d \cdot 4$ is far smaller than $M \cdot \text{size(id)}$, you should expect $n \cdot \log(n)$ memory usage spent on representing graph structure
* Connectivity ($M$) is very expensive if keys are large
* If $d \cdot 4$ is far larger than $M \cdot \text{size(key)}$, you should expect linear memory usage spent on representing vector data
* If $d \cdot 4$ is far smaller than $M \cdot \text{size(key)}$, you should expect $n \cdot \log(n)$ memory usage spent on representing graph structure

In the example of a graph with 256 dimensions, and $M = 16$, with 8 byte IDs, you would see that each vector takes:
In the example of a graph with 256 dimensions and $M = 16$, with 8-byte keys, you would see that each vector takes:

* $256 \cdot 4 = 1024$ data bytes
* $16 \cdot 8 = 128$ metadata bytes
Expand Down
14 changes: 8 additions & 6 deletions analyzer.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package hnsw

import "cmp"

// Analyzer is a struct that holds a graph and provides
// methods for analyzing it. It offers no compatibility guarantee
// as the methods of measuring the graph's health will change
// with the implementation.
type Analyzer[T Embeddable] struct {
Graph *Graph[T]
type Analyzer[K cmp.Ordered] struct {
Graph *Graph[K]
}

func (a *Analyzer[T]) Height() int {
Expand All @@ -17,16 +19,16 @@ func (a *Analyzer[T]) Height() int {
func (a *Analyzer[T]) Connectivity() []float64 {
var layerConnectivity []float64
for _, layer := range a.Graph.layers {
if len(layer.Nodes) == 0 {
if len(layer.nodes) == 0 {
continue
}

var sum float64
for _, node := range layer.Nodes {
for _, node := range layer.nodes {
sum += float64(len(node.neighbors))
}

layerConnectivity = append(layerConnectivity, sum/float64(len(layer.Nodes)))
layerConnectivity = append(layerConnectivity, sum/float64(len(layer.nodes)))
}

return layerConnectivity
Expand All @@ -36,7 +38,7 @@ func (a *Analyzer[T]) Connectivity() []float64 {
func (a *Analyzer[T]) Topography() []int {
var topography []int
for _, layer := range a.Graph.layers {
topography = append(topography, len(layer.Nodes))
topography = append(topography, len(layer.nodes))
}
return topography
}
82 changes: 47 additions & 35 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package hnsw

import (
"bufio"
"cmp"
"encoding/binary"
"fmt"
"io"
Expand Down Expand Up @@ -43,6 +44,16 @@ func binaryRead(r io.Reader, data interface{}) (int, error) {
*v = string(s)
return len(s), err

case *[]float32:
var ln int
_, err := binaryRead(r, &ln)
if err != nil {
return 0, err
}

*v = make([]float32, ln)
return binary.Size(*v), binary.Read(r, byteOrder, *v)

case io.ReaderFrom:
n, err := v.ReadFrom(r)
return int(n), err
Expand Down Expand Up @@ -73,6 +84,12 @@ func binaryWrite(w io.Writer, data any) (int, error) {
}

return n + n2, nil
case []float32:
n, err := binaryWrite(w, len(v))
if err != nil {
return n, err
}
return n + binary.Size(v), binary.Write(w, byteOrder, v)

default:
sz := binary.Size(data)
Expand Down Expand Up @@ -113,7 +130,7 @@ const encodingVersion = 1
// Export writes the graph to a writer.
//
// Keys (K) are written with binaryWrite, so K must be a type it can encode.
func (h *Graph[T]) Export(w io.Writer) error {
func (h *Graph[K]) Export(w io.Writer) error {
distFuncName, ok := distanceFuncToName(h.Distance)
if !ok {
return fmt.Errorf("distance function %v must be registered with RegisterDistanceFunc", h.Distance)
Expand All @@ -134,24 +151,20 @@ func (h *Graph[T]) Export(w io.Writer) error {
return fmt.Errorf("encode number of layers: %w", err)
}
for _, layer := range h.layers {
_, err = binaryWrite(w, len(layer.Nodes))
_, err = binaryWrite(w, len(layer.nodes))
if err != nil {
return fmt.Errorf("encode number of nodes: %w", err)
}
for _, node := range layer.Nodes {
_, err = binaryWrite(w, node.Point)
for _, node := range layer.nodes {
_, err = multiBinaryWrite(w, node.Key, node.Value, len(node.neighbors))
if err != nil {
return fmt.Errorf("encode node point: %w", err)
}

if _, err = binaryWrite(w, len(node.neighbors)); err != nil {
return fmt.Errorf("encode number of neighbors: %w", err)
return fmt.Errorf("encode node data: %w", err)
}

for neighbor := range node.neighbors {
_, err = binaryWrite(w, neighbor)
if err != nil {
return fmt.Errorf("encode neighbor %q: %w", neighbor, err)
return fmt.Errorf("encode neighbor %v: %w", neighbor, err)
}
}
}
Expand All @@ -164,7 +177,7 @@ func (h *Graph[T]) Export(w io.Writer) error {
// Keys (K) are read with binaryRead, so K must be a type it can decode.
// The imported graph does not have to match the exported graph's parameters (except for
// dimensionality). The graph will converge onto the new parameters.
func (h *Graph[T]) Import(r io.Reader) error {
func (h *Graph[K]) Import(r io.Reader) error {
var (
version int
dist string
Expand Down Expand Up @@ -195,55 +208,54 @@ func (h *Graph[T]) Import(r io.Reader) error {
return err
}

h.layers = make([]*layer[T], nLayers)
h.layers = make([]*layer[K], nLayers)
for i := 0; i < nLayers; i++ {
var nNodes int
_, err = binaryRead(r, &nNodes)
if err != nil {
return err
}

nodes := make(map[string]*layerNode[T], nNodes)
nodes := make(map[K]*layerNode[K], nNodes)
for j := 0; j < nNodes; j++ {
var point T
_, err = binaryRead(r, &point)
if err != nil {
return fmt.Errorf("decoding node %d: %w", j, err)
}

var key K
var vec Vector
var nNeighbors int
_, err = binaryRead(r, &nNeighbors)
_, err = multiBinaryRead(r, &key, &vec, &nNeighbors)
if err != nil {
return fmt.Errorf("decoding number of neighbors for node %d: %w", j, err)
return fmt.Errorf("decoding node %d: %w", j, err)
}

neighbors := make([]string, nNeighbors)
neighbors := make([]K, nNeighbors)
for k := 0; k < nNeighbors; k++ {
var neighbor string
var neighbor K
_, err = binaryRead(r, &neighbor)
if err != nil {
return fmt.Errorf("decoding neighbor %d for node %d: %w", k, j, err)
}
neighbors[k] = neighbor
}

node := &layerNode[T]{
Point: point,
neighbors: make(map[string]*layerNode[T]),
node := &layerNode[K]{
Node: Node[K]{
Key: key,
Value: vec,
},
neighbors: make(map[K]*layerNode[K]),
}

nodes[point.ID()] = node
nodes[key] = node
for _, neighbor := range neighbors {
node.neighbors[neighbor] = nil
}
}
// Fill in neighbor pointers
for _, node := range nodes {
for id := range node.neighbors {
node.neighbors[id] = nodes[id]
for key := range node.neighbors {
node.neighbors[key] = nodes[key]
}
}
h.layers[i] = &layer[T]{Nodes: nodes}
h.layers[i] = &layer[K]{nodes: nodes}
}

return nil
Expand All @@ -253,8 +265,8 @@ func (h *Graph[T]) Import(r io.Reader) error {
// changes to a file upon calls to Save. It is more convenient
// but less powerful than calling Graph.Export and Graph.Import
// directly.
type SavedGraph[T Embeddable] struct {
*Graph[T]
type SavedGraph[K cmp.Ordered] struct {
*Graph[K]
Path string
}

Expand All @@ -265,7 +277,7 @@ type SavedGraph[T Embeddable] struct {
//
// It does not hold open a file descriptor, so SavedGraph can be forgotten
// without ever calling Save.
func LoadSavedGraph[T Embeddable](path string) (*SavedGraph[T], error) {
func LoadSavedGraph[K cmp.Ordered](path string) (*SavedGraph[K], error) {
f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600)
if err != nil {
return nil, err
Expand All @@ -276,15 +288,15 @@ func LoadSavedGraph[T Embeddable](path string) (*SavedGraph[T], error) {
return nil, err
}

g := NewGraph[T]()
g := NewGraph[K]()
if info.Size() > 0 {
err = g.Import(bufio.NewReader(f))
if err != nil {
return nil, fmt.Errorf("import: %w", err)
}
}

return &SavedGraph[T]{Graph: g, Path: path}, nil
return &SavedGraph[K]{Graph: g, Path: path}, nil
}

// Save writes the graph to the file.
Expand Down
Loading
Loading