
Commit

Add PyTorch and NumPy ANN backends, close #468, close #469
davidmezzetti committed May 3, 2023
1 parent e9d18d0 commit 85125c2
Showing 5 changed files with 163 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/python/txtai/ann/__init__.py
@@ -7,3 +7,5 @@
from .factory import ANNFactory
from .faiss import Faiss
from .hnsw import HNSW
from .numpy import NumPy
from .torch import Torch
6 changes: 6 additions & 0 deletions src/python/txtai/ann/factory.py
@@ -7,6 +7,8 @@
from .annoy import Annoy
from .faiss import Faiss
from .hnsw import HNSW
from .numpy import NumPy
from .torch import Torch


class ANNFactory:
@@ -37,6 +39,10 @@ def create(config):
            ann = Faiss(config)
        elif backend == "hnsw":
            ann = HNSW(config)
        elif backend == "numpy":
            ann = NumPy(config)
        elif backend == "torch":
            ann = Torch(config)
        else:
            ann = ANNFactory.resolve(backend, config)

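With the factory wired up, either backend can be selected through the existing "backend" setting. A minimal sketch, not part of this commit, assuming the usual txtai Embeddings config keys and an illustrative model path:

from txtai.embeddings import Embeddings

# Sketch only: "backend" selects the ANN implementation registered above;
# the model path is illustrative
embeddings = Embeddings({"path": "sentence-transformers/all-MiniLM-L6-v2", "backend": "numpy"})
embeddings.index([(0, "the first document", None), (1, "another document", None)])

# Returns a list of (id, score) tuples
print(embeddings.search("first", 1))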
61 changes: 61 additions & 0 deletions src/python/txtai/ann/numpy.py
@@ -0,0 +1,61 @@
"""
NumPy module
"""

import pickle

import numpy as np

from .. import __pickle__

from .base import ANN


class NumPy(ANN):
"""
Builds an ANN index backed by a NumPy array.
"""

def load(self, path):
# Load array from file
with open(path, "rb") as handle:
self.backend = pickle.load(handle)

def index(self, embeddings):
# Create index
self.backend = embeddings

# Add id offset and index build metadata
self.config["offset"] = embeddings.shape[0]

def append(self, embeddings):
new = embeddings.shape[0]

# Append new data to array
self.backend = np.concatenate((self.backend, embeddings), axis=0)

# Update id offset
self.config["offset"] += new

def delete(self, ids):
# Filter any index greater than size of array
ids = [x for x in ids if x < self.backend.shape[0]]

# Clear specified ids
self.backend[ids] = np.zeros((len(ids), self.backend.shape[1]))

def search(self, queries, limit):
# Dot product on normalized vectors is equal to cosine similarity
scores = np.dot(queries, self.backend.T).tolist()

# Add index and sort desc based on score
return [sorted(enumerate(score), key=lambda x: x[1], reverse=True)[:limit] for score in scores]

def count(self):
# Get count of non-zero rows (ignores deleted rows)
return self.backend[~np.all(self.backend == 0, axis=1)].shape[0]

def save(self, path):
# Save array to file
with open(path, "wb") as handle:
pickle.dump(self.backend, handle, protocol=__pickle__)
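A rough standalone usage sketch of this backend via ANNFactory, not part of the commit. Assumptions: create accepts a plain config dict and vectors are pre-normalized, since search relies on the dot product equaling cosine similarity.

import numpy as np

from txtai.ann import ANNFactory

# Normalize random vectors so dot product == cosine similarity (assumed input contract)
data = np.random.rand(100, 300).astype(np.float32)
data /= np.linalg.norm(data, axis=1)[:, None]

ann = ANNFactory.create({"backend": "numpy"})
ann.index(data)

# Each query returns a list of (id, score) tuples
print(ann.search(data[:1], 3))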
80 changes: 80 additions & 0 deletions src/python/txtai/ann/torch.py
@@ -0,0 +1,80 @@
"""
PyTorch module
"""

import pickle

import numpy as np
import torch

from .. import __pickle__

from .base import ANN


class Torch(ANN):
"""
Builds an ANN index backed by a PyTorch array.
"""

def load(self, path):
# Load array from file
with open(path, "rb") as handle:
self.backend = self.tensor(pickle.load(handle))

def index(self, embeddings):
# Create index
self.backend = self.tensor(embeddings)

# Add id offset and index build metadata
self.config["offset"] = embeddings.shape[0]

def append(self, embeddings):
new = embeddings.shape[0]

# Append new data to array
self.backend = torch.cat((self.backend, self.tensor(embeddings)), 0)

# Update id offset
self.config["offset"] += new

def delete(self, ids):
# Filter any index greater than size of array
ids = [x for x in ids if x < self.backend.shape[0]]

# Clear specified ids
self.backend[ids] = self.tensor(torch.zeros((len(ids), self.backend.shape[1])))

def search(self, queries, limit):
# Dot product on normalized vectors is equal to cosine similarity
scores = torch.mm(self.tensor(queries), self.backend.T).tolist()

# Add index and sort desc based on score
return [sorted(enumerate(score), key=lambda x: x[1], reverse=True)[:limit] for score in scores]

def count(self):
# Get count of non-zero rows (ignores deleted rows)
return self.backend[~torch.all(self.backend == 0, axis=1)].shape[0]

def save(self, path):
# Save array to file
with open(path, "wb") as handle:
pickle.dump(self.backend, handle, protocol=__pickle__)

def tensor(self, array):
"""
Loads array as a Tensor. Loads to GPU device, if available.
Args:
array: data array
Returns:
Tensor
"""

# Convert array to Tensor
if isinstance(array, np.ndarray):
array = torch.from_numpy(array)

# Load to GPU device, if available
return array.cuda() if torch.cuda.is_available() else array
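Both new backends compute exact (brute-force) dot-product scores rather than approximate ones, so on the same normalized vectors they should return the same ids. A small sketch of that equivalence, assuming the vectors fit in memory:

import numpy as np

from txtai.ann import NumPy, Torch

data = np.random.rand(50, 128).astype(np.float32)
data /= np.linalg.norm(data, axis=1)[:, None]

npann, ptann = NumPy({}), Torch({})
npann.index(data)
ptann.index(data)

# Compare returned ids; scores can differ slightly in floating point
npids = [uid for uid, _ in npann.search(data[:1], 5)[0]]
ptids = [uid for uid, _ in ptann.search(data[:1], 5)[0]]
print(npids == ptids)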
14 changes: 14 additions & 0 deletions test/python/testann.py
@@ -100,6 +100,20 @@ def testNotImplemented(self):
        self.assertRaises(NotImplementedError, ann.count)
        self.assertRaises(NotImplementedError, ann.save, None)

    def testNumPy(self):
        """
        Test NumPy backend
        """

        self.runTests("numpy")

    def testTorch(self):
        """
        Test Torch backend
        """

        self.runTests("torch")

    def runTests(self, name, params=None, update=True):
        """
        Runs a series of standard backend tests.

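To run only the new tests locally, one option is a small unittest loader script; the test class name used here is an assumption, not taken from the diff:

import unittest

# Hypothetical: load only the two new backend tests by name
# (assumes the suite lives in a TestANN class within testann.py)
suite = unittest.defaultTestLoader.loadTestsFromNames(
    ["testann.TestANN.testNumPy", "testann.TestANN.testTorch"]
)
unittest.TextTestRunner(verbosity=2).run(suite)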