From 662ee41933b817056b73bb2947731772c87fbb91 Mon Sep 17 00:00:00 2001 From: thejumpman2323 Date: Wed, 19 Jul 2023 15:56:14 +0530 Subject: [PATCH 1/4] Add unit tests for mongo queries --- superduperdb/datalayer/mongodb/query.py | 9 ++ .../datalayer/mongodb/test_queries.py | 147 ++++++++++++++++++ 2 files changed, 156 insertions(+) create mode 100644 tests/unittests/datalayer/mongodb/test_queries.py diff --git a/superduperdb/datalayer/mongodb/query.py b/superduperdb/datalayer/mongodb/query.py index af639394f..b92ceb559 100644 --- a/superduperdb/datalayer/mongodb/query.py +++ b/superduperdb/datalayer/mongodb/query.py @@ -331,6 +331,9 @@ class FindOne(SelectOne): type_id: t.Literal['mongodb.FindOne'] = 'mongodb.FindOne' + def add_fold(self, fold: str) -> 'Select': + raise NotImplementedError + def __call__(self, db): if self.collection is not None: return SuperDuperCursor.wrap_document( @@ -361,6 +364,12 @@ class Aggregate(Select): type_id: t.Literal['mongodb.Aggregate'] = 'mongodb.Aggregate' + def add_fold(self, fold: str) -> 'Select': + raise NotImplementedError + + def is_trivial(self) -> bool: + raise NotImplementedError + @property def select_ids(self) -> 'Select': raise NotImplementedError diff --git a/tests/unittests/datalayer/mongodb/test_queries.py b/tests/unittests/datalayer/mongodb/test_queries.py new file mode 100644 index 000000000..f6c2767ad --- /dev/null +++ b/tests/unittests/datalayer/mongodb/test_queries.py @@ -0,0 +1,147 @@ +import PIL.PngImagePlugin +import pytest +import torch + +from superduperdb.core.documents import Document +from superduperdb.datalayer.mongodb.query import Collection + +n_data_points = 250 + +IMAGE_URL = 'https://www.superduperdb.com/logos/white.png' + + +def test_delete_many(random_data): + r = random_data.execute(Collection(name='documents').find_one()) + random_data.execute(Collection(name='documents').delete_many({'_id': r['_id']})) + with pytest.raises(StopIteration): + next(random_data.execute(Collection(name='documents').find({'_id': r['_id']}))) + + +def test_replace(random_data): + r = next(random_data.execute(Collection(name='documents').find())) + x = torch.randn(32) + t = random_data.encoders['torch.float32[32]'] + r['x'] = t(x) + random_data.execute( + Collection(name='documents').replace_one( + {'_id': r['_id']}, + r, + ) + ) + + +def test_insert_from_uris(empty, image_type): + to_insert = [ + Document( + { + 'item': { + '_content': { + 'uri': IMAGE_URL, + 'encoder': 'pil_image', + } + }, + 'other': { + 'item': { + '_content': { + 'uri': IMAGE_URL, + 'encoder': 'pil_image', + } + } + }, + } + ) + for _ in range(2) + ] + empty.execute(Collection(name='documents').insert_many(to_insert)) + r = empty.execute(Collection(name='documents').find_one()) + assert isinstance(r['item'].x, PIL.PngImagePlugin.PngImageFile) + assert isinstance(r['other']['item'].x, PIL.PngImagePlugin.PngImageFile) + + +def test_update_many(random_data, a_watcher): + to_update = torch.randn(32) + t = random_data.encoders['torch.float32[32]'] + random_data.execute( + Collection(name='documents').update_many( + {}, Document({'$set': {'x': t(to_update)}}) + ) + ) + cur = random_data.execute(Collection(name='documents').find()) + r = next(cur) + s = next(cur) + + assert all(r['x'].x == to_update) + assert all(s['x'].x == to_update) + assert ( + r['_outputs']['x']['linear_a'].x.tolist() + == s['_outputs']['x']['linear_a'].x.tolist() + ) + + +def test_insert_many(random_data, a_watcher, an_update): + random_data.execute(Collection(name='documents').insert_many(an_update)) + r 
= next(random_data.execute(Collection(name='documents').find({'update': True}))) + assert 'linear_a' in r['_outputs']['x'] + assert ( + len(list(random_data.execute(Collection(name='documents').find()))) + == n_data_points + 10 + ) + + +def test_like(with_vector_index): + db = with_vector_index + r = db.execute(Collection(name='documents').find_one()) + query = Collection(name='documents').like( + r=Document({'x': r['x']}), + vector_index='test_vector_search', + ) + s = next(db.execute(query)) + assert r['_id'] == s['_id'] + + +def test_insert_one(random_data, a_watcher, a_single_insert): + out, _ = random_data.execute( + Collection(name='documents').insert_one(a_single_insert) + ) + r = random_data.execute( + Collection(name='documents').find({'_id': out.inserted_ids[0]}) + ) + docs = list(r) + assert docs[0]['x'].x.tolist() == a_single_insert['x'].x.tolist() + + +def test_delete_one(random_data): + r = random_data.execute(Collection(name='documents').find_one()) + random_data.execute(Collection(name='documents').delete_one({'_id': r['_id']})) + with pytest.raises(StopIteration): + next(random_data.execute(Collection(name='documents').find({'_id': r['_id']}))) + + +def test_find(random_data): + r = random_data.execute(Collection(name='documents').find().limit(1)) + assert len(list(r)) == 1 + + +def test_find_one(random_data): + r = random_data.execute(Collection(name='documents').find_one()) + assert isinstance(r, Document) + + +def test_aggregate(random_data): + r = random_data.execute( + Collection(name='documents').aggregate([{'$sample': {'size': 1}}]) + ) + assert len(list(r)) == 1 + + +def test_replace_one(random_data): + new_x = torch.randn(32) + t = random_data.encoders['torch.float32[32]'] + r = random_data.execute(Collection(name='documents').find_one()) + random_data.execute( + Collection(name='documents').replace_one( + {'_id': r['_id']}, Document({'x': t(new_x)}) + ) + ) + doc = random_data.execute(Collection(name='documents').find_one({'_id': r['_id']})) + assert doc.unpack()['x'].tolist() == new_x.tolist() From 5452d465aaa9bf4201072242610d7bef925049f7 Mon Sep 17 00:00:00 2001 From: thejumpman2323 Date: Wed, 19 Jul 2023 20:49:02 +0530 Subject: [PATCH 2/4] Add tests for vector search --- superduperdb/core/documents.py | 2 +- superduperdb/vector_search/faiss_index.py | 1 - superduperdb/vector_search/table_scan.py | 4 +- tests/unittests/models/test_langchain.py | 3 +- tests/unittests/models/test_torch.py | 4 +- .../vector_search/faiss/test_hashes.py | 27 ------ tests/unittests/vector_search/test_base.py | 87 +++++++++++++++++++ tests/unittests/vector_search/test_faiss.py | 48 ++++++++++ .../vector_search/test_vanillaindex.py | 16 ++++ 9 files changed, 158 insertions(+), 34 deletions(-) create mode 100644 tests/unittests/vector_search/test_base.py create mode 100644 tests/unittests/vector_search/test_faiss.py create mode 100644 tests/unittests/vector_search/test_vanillaindex.py diff --git a/superduperdb/core/documents.py b/superduperdb/core/documents.py index b952f9bda..38345b565 100644 --- a/superduperdb/core/documents.py +++ b/superduperdb/core/documents.py @@ -11,7 +11,7 @@ class Document: that resource to a mix of jsonable content or `bytes` """ - _DEFAULT_ID_KEY = '_id' + _DEFAULT_ID_KEY: str = '_id' def __init__(self, content: t.Dict): self.content = content diff --git a/superduperdb/vector_search/faiss_index.py b/superduperdb/vector_search/faiss_index.py index 09336057d..878644bf5 100644 --- a/superduperdb/vector_search/faiss_index.py +++ 
b/superduperdb/vector_search/faiss_index.py @@ -46,7 +46,6 @@ def __init__(self, h, index, measure='l2', faiss_index=None): def find_nearest_from_arrays(self, h, n=100): import torch - if isinstance(h, list): h = numpy.array(h).astype('float32') if isinstance(h, torch.Tensor): diff --git a/superduperdb/vector_search/table_scan.py b/superduperdb/vector_search/table_scan.py index 8c2d3ccd5..99e5ff622 100644 --- a/superduperdb/vector_search/table_scan.py +++ b/superduperdb/vector_search/table_scan.py @@ -15,7 +15,7 @@ class VanillaVectorIndex(BaseVectorIndex): name = 'vanilla' - def __init__(self, h, index, measure='cosine'): + def __init__(self, h, index, measure='css'): if isinstance(measure, str): measure = measures[measure] super().__init__(h, index, measure) @@ -49,4 +49,4 @@ def cosine(x, y): return dot(x, y) -measures = {'cosine': cosine, 'dot': dot, 'l2': l2} +measures = {'css': cosine, 'dot': dot, 'l2': l2} diff --git a/tests/unittests/models/test_langchain.py b/tests/unittests/models/test_langchain.py index e9405bae8..b06d68a66 100644 --- a/tests/unittests/models/test_langchain.py +++ b/tests/unittests/models/test_langchain.py @@ -1,6 +1,6 @@ -from langchain import OpenAI import os import numpy + import pytest from superduperdb.models.sentence_transformers.wrapper import SentenceTransformer @@ -18,6 +18,7 @@ @pytest.mark.skipif(SKIP_PAID, reason='don\'t test paid API') def test_db_qa_with_sources_chain(nursery_rhymes): + from langchain import OpenAI nursery_rhymes.add(array(numpy.float32, shape=(1024,))) pl = SentenceTransformer(model_name_or_path='all-MiniLM-L6-v2', encoder='array') nursery_rhymes.add(pl) diff --git a/tests/unittests/models/test_torch.py b/tests/unittests/models/test_torch.py index ca525040f..7a99bb796 100644 --- a/tests/unittests/models/test_torch.py +++ b/tests/unittests/models/test_torch.py @@ -103,11 +103,11 @@ def test_ensemble(si_validation, metric): validation_interval=5, loader_kwargs={'batch_size': 10, 'num_workers': 0}, compute_metrics=VectorSearchPerformance( - measure='cosine', + measure='css', predict_kwargs={'batch_size': 10}, index_key='x', ), - kwargs={'hash_set_cls': Artifact(VanillaVectorIndex), 'measure': 'cosine'}, + kwargs={'hash_set_cls': Artifact(VanillaVectorIndex), 'measure': 'css'}, ) m = TorchModelEnsemble( diff --git a/tests/unittests/vector_search/faiss/test_hashes.py b/tests/unittests/vector_search/faiss/test_hashes.py index 84ddca376..e69de29bb 100644 --- a/tests/unittests/vector_search/faiss/test_hashes.py +++ b/tests/unittests/vector_search/faiss/test_hashes.py @@ -1,27 +0,0 @@ -import os -from scipy.spatial.distance import cdist -import torch -import uuid - -from superduperdb.vector_search.table_scan import VanillaVectorIndex -from superduperdb.vector_search.faiss_index import FaissVectorIndex - -os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' - - -def test_faiss_hash_set(): - x = torch.randn(1000, 32) - ids = [uuid.uuid4() for _ in range(x.shape[0])] - - def l2(x, y): - return -cdist(x, y) - - h1 = FaissVectorIndex(x, ids, 'l2') - h2 = VanillaVectorIndex(x, ids, l2) - - y = torch.randn(32) - - res1, _ = h1.find_nearest_from_array(y) - res2, _ = h2.find_nearest_from_array(y) - - assert res1[0] == res2[0] diff --git a/tests/unittests/vector_search/test_base.py b/tests/unittests/vector_search/test_base.py new file mode 100644 index 000000000..82eadbff4 --- /dev/null +++ b/tests/unittests/vector_search/test_base.py @@ -0,0 +1,87 @@ +import numpy as np +import pytest +import torch +from unittest import mock + +from 
superduperdb.vector_search.base import BaseVectorIndex, to_numpy + + +class TestBaseVectorIndex: + @pytest.fixture + def base_vector_index(self): + # Create a sample BaseVectorIndex instance for testing + h = np.random.rand(10, 3) # Sample data + index = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'] # Sample index + measure = 'euclidean' # Sample measure + return BaseVectorIndex(h, index, measure) + + def test_init(self, base_vector_index): + assert isinstance(base_vector_index.h, np.ndarray) + assert isinstance(base_vector_index.index, list) + assert isinstance(base_vector_index.lookup, dict) + assert isinstance(base_vector_index.measure, str) + + def test_shape(self, base_vector_index): + assert base_vector_index.shape == base_vector_index.h.shape + + @mock.patch('superduperdb.vector_search.base.to_numpy') + def test_find_nearest_from_id(self, mock_to_numpy, base_vector_index): + # Mock to_numpy function to avoid its actual execution + mock_to_numpy.return_value = np.array([[0.1, 0.2, 0.3]]) + + _id = 'a' + n = 100 + with pytest.raises(NotImplementedError): + result = base_vector_index.find_nearest_from_id(_id, n) + + @mock.patch('superduperdb.vector_search.base.to_numpy') + def test_find_nearest_from_ids(self, mock_to_numpy, base_vector_index): + # Mock to_numpy function to avoid its actual execution + mock_to_numpy.return_value = np.array([[0.1, 0.2, 0.3]]) + + _ids = ['a', 'b'] + n = 100 + with pytest.raises(NotImplementedError): + result = base_vector_index.find_nearest_from_ids(_ids, n) + + @mock.patch('superduperdb.vector_search.base.to_numpy') + def test_find_nearest_from_array(self, mock_to_numpy, base_vector_index): + # Mock to_numpy function to avoid its actual execution + mock_to_numpy.return_value = np.array([[0.1, 0.2, 0.3]]) + + h = np.random.rand(1, 3) # Sample array + n = 100 + with pytest.raises(NotImplementedError): + result = base_vector_index.find_nearest_from_array(h, n) + + @mock.patch('superduperdb.vector_search.base.to_numpy') + def test_find_nearest_from_arrays(self, mock_to_numpy, base_vector_index): + # Mock to_numpy function to avoid its actual execution + mock_to_numpy.return_value = np.array([[0.1, 0.2, 0.3]]) + + h = np.random.rand(2, 3) # Sample array + n = 100 + with pytest.raises(NotImplementedError): + base_vector_index.find_nearest_from_arrays(h, n) + + def test_getitem(self, base_vector_index): + with pytest.raises(NotImplementedError): + base_vector_index['item'] + + +def test_to_numpy(): + # Test to_numpy function with different input types + x = np.array([[1, 2, 3], [4, 5, 6]]) + result = to_numpy(x) + assert isinstance(result, np.ndarray) + assert np.array_equal(result, x) + + x = torch.tensor([[1, 2, 3], [4, 5, 6]]) + result = to_numpy(x) + assert isinstance(result, np.ndarray) + assert np.array_equal(result, x.numpy()) + + x = [[1, 2, 3], [4, 5, 6]] + result = to_numpy(x) + assert isinstance(result, np.ndarray) + assert np.array_equal(result, np.array(x)) diff --git a/tests/unittests/vector_search/test_faiss.py b/tests/unittests/vector_search/test_faiss.py new file mode 100644 index 000000000..6ccbffbbf --- /dev/null +++ b/tests/unittests/vector_search/test_faiss.py @@ -0,0 +1,48 @@ +import os +import uuid + +from scipy.spatial.distance import cdist +import torch +import pytest + +from superduperdb.vector_search.table_scan import VanillaVectorIndex +from superduperdb.vector_search.faiss_index import FaissVectorIndex + +os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' + + +@pytest.mark.parametrize('metric', ['l2', 'css', 'dot']) +def 
test_faiss_hash_set(metric): + x = torch.randn(1000, 32) + ids = [uuid.uuid4() for _ in range(x.shape[0])] + + def l2(x, y): + return -cdist(x, y) + + h1 = FaissVectorIndex(x, ids, metric) + h2 = VanillaVectorIndex(x, ids, metric) + + y = torch.randn(32) + + res1, _ = h1.find_nearest_from_array(y) + res2, _ = h2.find_nearest_from_array(y) + + assert res1[0] == res2[0] + +@pytest.mark.skip(reason="Faiss doesn't support batched queries") +def test_faiss_from_arrays(): + x = torch.randn(1000, 4) + ids = [uuid.uuid4() for _ in range(x.shape[0])] + + def l2(x, y): + return -cdist(x, y) + + h1 = FaissVectorIndex(x, ids, 'l2') + h2 = VanillaVectorIndex(x, ids, l2) + + y = torch.randn(2, 4) + + res1, _ = h1.find_nearest_from_arrays(y, 2) + res2, _ = h2.find_nearest_from_arrays(y, 2) + + assert res1[0] == res2[0] diff --git a/tests/unittests/vector_search/test_vanillaindex.py b/tests/unittests/vector_search/test_vanillaindex.py new file mode 100644 index 000000000..15b5c8601 --- /dev/null +++ b/tests/unittests/vector_search/test_vanillaindex.py @@ -0,0 +1,16 @@ +import uuid + +import pytest +import numpy as np + +from superduperdb.vector_search.table_scan import VanillaVectorIndex +from superduperdb.vector_search.table_scan import l2, dot, cosine + +@pytest.mark.parametrize("measure", [l2, dot, cosine]) +def test_vaniila_index(measure): + x = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]]) + ids = [uuid.uuid4() for _ in range(x.shape[0])] + h2 = VanillaVectorIndex(x, ids, measure) + y = np.array([0, 0.5, 0.5]) + res2, _ = h2.find_nearest_from_array(y, 1) + assert res2[0] == ids[0] From d6a15ce8006d25be828759e4092f0f0b1415f300 Mon Sep 17 00:00:00 2001 From: thejumpman2323 Date: Thu, 20 Jul 2023 14:48:16 +0530 Subject: [PATCH 3/4] Add unit tests for models --- superduperdb/core/model.py | 6 +- superduperdb/datalayer/mongodb/query.py | 4 +- superduperdb/models/vanilla/wrapper.py | 12 ++-- superduperdb/vector_search/faiss_index.py | 5 +- superduperdb/vector_search/table_scan.py | 4 +- tests/integration/conftest.py | 6 ++ tests/unittests/models/test_langchain.py | 5 +- tests/unittests/models/test_torch.py | 4 +- tests/unittests/models/test_torch_utils.py | 67 +++++++++++++++++++ tests/unittests/models/test_vanilla.py | 11 +++ tests/unittests/vector_search/test_base.py | 6 +- tests/unittests/vector_search/test_faiss.py | 3 +- tests/unittests/vector_search/test_lancedb.py | 23 +++++-- .../vector_search/test_vanillaindex.py | 1 + 14 files changed, 130 insertions(+), 27 deletions(-) create mode 100644 tests/unittests/models/test_torch_utils.py create mode 100644 tests/unittests/models/test_vanilla.py diff --git a/superduperdb/core/model.py b/superduperdb/core/model.py index 74ef40ba4..82f83c711 100644 --- a/superduperdb/core/model.py +++ b/superduperdb/core/model.py @@ -1,4 +1,3 @@ -import inspect from dask.distributed import Future import dataclasses as dc import typing as t @@ -142,10 +141,7 @@ def predict( return else: - if 'one' in inspect.signature(self._predict).parameters: - return self._predict(X, one=one, **kwargs) - else: - return self._predict(X, **kwargs) + return self._predict(X, **kwargs) @dc.dataclass diff --git a/superduperdb/datalayer/mongodb/query.py b/superduperdb/datalayer/mongodb/query.py index b92ceb559..3509d2c53 100644 --- a/superduperdb/datalayer/mongodb/query.py +++ b/superduperdb/datalayer/mongodb/query.py @@ -498,7 +498,6 @@ class InsertMany(Insert): verbose: bool = True args: t.List = dc.field(default_factory=list) kwargs: t.Dict = dc.field(default_factory=dict) - valid_prob: float 
= 0.05 encoders: t.List = dc.field(default_factory=list) type_id: t.Literal['mongodb.InsertMany'] = 'mongodb.InsertMany' @@ -515,11 +514,12 @@ def select_using_ids(self, ids): return Find(collection=self.collection, args=[{'_id': {'$in': ids}}]) def __call__(self, db): + valid_prob = self.kwargs.get('valid_prob', 0.5) for e in self.encoders: db.add(e) documents = [r.encode() for r in self.documents] for r in documents: - if random.random() < self.valid_prob: + if random.random() < valid_prob: r['_fold'] = 'valid' else: r['_fold'] = 'train' diff --git a/superduperdb/models/vanilla/wrapper.py b/superduperdb/models/vanilla/wrapper.py index a5a4a66f3..bbb81bbba 100644 --- a/superduperdb/models/vanilla/wrapper.py +++ b/superduperdb/models/vanilla/wrapper.py @@ -3,18 +3,22 @@ class Function(Model): + vanilla = True + def predict_one(self, x, **kwargs): - return self.object.a(x, **kwargs) + return self.object.artifact(x, **kwargs) - def _predict(self, docs, num_workers=0): + def _predict(self, docs, num_workers=0, **kwargs): outputs = [] + if not isinstance(docs, list): + return self.predict_one(docs) if num_workers: pool = multiprocessing.Pool(processes=num_workers) - for r in pool.map(self.object, docs): + for r in pool.map(self.object.artifact, docs): outputs.append(r) pool.close() pool.join() else: for r in docs: - outputs.append(self.object(r)) + outputs.append(self.object.artifact(r)) return outputs diff --git a/superduperdb/vector_search/faiss_index.py b/superduperdb/vector_search/faiss_index.py index 878644bf5..b730d28b8 100644 --- a/superduperdb/vector_search/faiss_index.py +++ b/superduperdb/vector_search/faiss_index.py @@ -27,13 +27,13 @@ def __init__(self, h, index, measure='l2', faiss_index=None): super().__init__(h, index, measure) self.h = self.h.astype('float32') if faiss_index is None: - if measure == 'css': + if measure == 'cosine': self.h = self.h / (numpy.linalg.norm(self.h, axis=1)[:, None]) if measure == 'l2': faiss_index = faiss.index_factory( self.h.shape[1], 'Flat', faiss.METRIC_L2 ) - elif measure in {'css', 'dot'}: + elif measure in {'cosine', 'dot'}: faiss_index = faiss.index_factory( self.h.shape[1], 'Flat', faiss.METRIC_INNER_PRODUCT ) @@ -46,6 +46,7 @@ def __init__(self, h, index, measure='l2', faiss_index=None): def find_nearest_from_arrays(self, h, n=100): import torch + if isinstance(h, list): h = numpy.array(h).astype('float32') if isinstance(h, torch.Tensor): diff --git a/superduperdb/vector_search/table_scan.py b/superduperdb/vector_search/table_scan.py index 99e5ff622..8c2d3ccd5 100644 --- a/superduperdb/vector_search/table_scan.py +++ b/superduperdb/vector_search/table_scan.py @@ -15,7 +15,7 @@ class VanillaVectorIndex(BaseVectorIndex): name = 'vanilla' - def __init__(self, h, index, measure='css'): + def __init__(self, h, index, measure='cosine'): if isinstance(measure, str): measure = measures[measure] super().__init__(h, index, measure) @@ -49,4 +49,4 @@ def cosine(x, y): return dot(x, y) -measures = {'css': cosine, 'dot': dot, 'l2': l2} +measures = {'cosine': cosine, 'dot': dot, 'l2': l2} diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 58d8ce603..af2c56e4a 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,4 +1,5 @@ import random +import numpy as np import time from threading import Thread from unittest import mock @@ -32,6 +33,11 @@ as much as possible. This will make it easier to understand the test suite. 
''' +# Set the seeds +random.seed(42) +torch.manual_seed(42) +np.random.seed(42) + mongodb_test_config = { 'host': '0.0.0.0', diff --git a/tests/unittests/models/test_langchain.py b/tests/unittests/models/test_langchain.py index b06d68a66..251c74a63 100644 --- a/tests/unittests/models/test_langchain.py +++ b/tests/unittests/models/test_langchain.py @@ -6,7 +6,9 @@ from superduperdb.models.sentence_transformers.wrapper import SentenceTransformer from superduperdb.core.watcher import Watcher from superduperdb.core.vector_index import VectorIndex -from superduperdb.models.langchain.retriever import DBQAWithSourcesChain +from superduperdb.models.langchain.retriever import ( + DBQAWithSourcesChain, +) from superduperdb.datalayer.mongodb.query import Collection from superduperdb.encoders.numpy.array import array @@ -19,6 +21,7 @@ @pytest.mark.skipif(SKIP_PAID, reason='don\'t test paid API') def test_db_qa_with_sources_chain(nursery_rhymes): from langchain import OpenAI + nursery_rhymes.add(array(numpy.float32, shape=(1024,))) pl = SentenceTransformer(model_name_or_path='all-MiniLM-L6-v2', encoder='array') nursery_rhymes.add(pl) diff --git a/tests/unittests/models/test_torch.py b/tests/unittests/models/test_torch.py index 7a99bb796..ca525040f 100644 --- a/tests/unittests/models/test_torch.py +++ b/tests/unittests/models/test_torch.py @@ -103,11 +103,11 @@ def test_ensemble(si_validation, metric): validation_interval=5, loader_kwargs={'batch_size': 10, 'num_workers': 0}, compute_metrics=VectorSearchPerformance( - measure='css', + measure='cosine', predict_kwargs={'batch_size': 10}, index_key='x', ), - kwargs={'hash_set_cls': Artifact(VanillaVectorIndex), 'measure': 'css'}, + kwargs={'hash_set_cls': Artifact(VanillaVectorIndex), 'measure': 'cosine'}, ) m = TorchModelEnsemble( diff --git a/tests/unittests/models/test_torch_utils.py b/tests/unittests/models/test_torch_utils.py new file mode 100644 index 000000000..192c890a9 --- /dev/null +++ b/tests/unittests/models/test_torch_utils.py @@ -0,0 +1,67 @@ +import pytest +import torch +from superduperdb.models.torch.utils import device_of, eval, set_device, to_device + + +@pytest.fixture +def model(): + return torch.nn.Linear(10, 2) + + +def test_device_of_cpu(model): + device = device_of(model) + assert device.type == 'cpu' + + +def test_device_of_cuda(model): + if torch.cuda.is_available(): + model.to(torch.device('cuda')) + device = device_of(model) + assert device == 'cuda' + + +def test_eval_context_manager(model): + with eval(model): + assert not model.training + + +def test_set_device_context_manager(model): + device_before = device_of(model) + if torch.cuda.is_available(): + with set_device(model, torch.device('cuda')): + device_after = device_of(model) + assert device_after == 'cuda' + assert device_of(model) == device_before + + +def test_to_device_tensor(model): + tensor = torch.tensor([1, 2, 3]) + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + tensor_device = to_device(tensor, device) + assert tensor_device.device == device + + +def test_to_device_nested_list(model): + nested_list = [ + torch.tensor([1, 2, 3]), + [torch.tensor([4, 5]), torch.tensor([6, 7])], + ] + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + nested_list_device = to_device(nested_list, device) + for item in nested_list_device: + if isinstance(item, list): + assert all(i.device == device for i in item) + else: + assert item.device == device + + +def test_to_device_nested_dict(model): + nested_dict = 
{'a': torch.tensor([1, 2, 3]), 'b': {'c': torch.tensor([4, 5])}} + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + nested_dict_device = to_device(nested_dict, device) + for item in nested_dict_device.values(): + if isinstance(item, dict): + for i in item.values(): + assert i.device == device + else: + assert item.device == device diff --git a/tests/unittests/models/test_vanilla.py b/tests/unittests/models/test_vanilla.py new file mode 100644 index 000000000..4d94f9504 --- /dev/null +++ b/tests/unittests/models/test_vanilla.py @@ -0,0 +1,11 @@ +from superduperdb.models.vanilla.wrapper import Function + + +def test_function_predict_one(): + function = Function(object=lambda x: x, identifier='test') + assert function.predict(1) == 1 + + +def test_function_predict(): + function = Function(object=lambda x: x, identifier='test') + assert function.predict([1, 1]) == [1, 1] diff --git a/tests/unittests/vector_search/test_base.py b/tests/unittests/vector_search/test_base.py index 82eadbff4..e69682627 100644 --- a/tests/unittests/vector_search/test_base.py +++ b/tests/unittests/vector_search/test_base.py @@ -32,7 +32,7 @@ def test_find_nearest_from_id(self, mock_to_numpy, base_vector_index): _id = 'a' n = 100 with pytest.raises(NotImplementedError): - result = base_vector_index.find_nearest_from_id(_id, n) + base_vector_index.find_nearest_from_id(_id, n) @mock.patch('superduperdb.vector_search.base.to_numpy') def test_find_nearest_from_ids(self, mock_to_numpy, base_vector_index): @@ -42,7 +42,7 @@ def test_find_nearest_from_ids(self, mock_to_numpy, base_vector_index): _ids = ['a', 'b'] n = 100 with pytest.raises(NotImplementedError): - result = base_vector_index.find_nearest_from_ids(_ids, n) + base_vector_index.find_nearest_from_ids(_ids, n) @mock.patch('superduperdb.vector_search.base.to_numpy') def test_find_nearest_from_array(self, mock_to_numpy, base_vector_index): @@ -52,7 +52,7 @@ def test_find_nearest_from_array(self, mock_to_numpy, base_vector_index): h = np.random.rand(1, 3) # Sample array n = 100 with pytest.raises(NotImplementedError): - result = base_vector_index.find_nearest_from_array(h, n) + base_vector_index.find_nearest_from_array(h, n) @mock.patch('superduperdb.vector_search.base.to_numpy') def test_find_nearest_from_arrays(self, mock_to_numpy, base_vector_index): diff --git a/tests/unittests/vector_search/test_faiss.py b/tests/unittests/vector_search/test_faiss.py index 6ccbffbbf..6194a20d8 100644 --- a/tests/unittests/vector_search/test_faiss.py +++ b/tests/unittests/vector_search/test_faiss.py @@ -11,7 +11,7 @@ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' -@pytest.mark.parametrize('metric', ['l2', 'css', 'dot']) +@pytest.mark.parametrize('metric', ['l2', 'cosine', 'dot']) def test_faiss_hash_set(metric): x = torch.randn(1000, 32) ids = [uuid.uuid4() for _ in range(x.shape[0])] @@ -29,6 +29,7 @@ def l2(x, y): assert res1[0] == res2[0] + @pytest.mark.skip(reason="Faiss doesn't support batched queries") def test_faiss_from_arrays(): x = torch.randn(1000, 4) diff --git a/tests/unittests/vector_search/test_lancedb.py b/tests/unittests/vector_search/test_lancedb.py index 6f08ead44..5880e84f9 100644 --- a/tests/unittests/vector_search/test_lancedb.py +++ b/tests/unittests/vector_search/test_lancedb.py @@ -7,6 +7,7 @@ from unittest.mock import MagicMock import pyarrow as pa +from superduperdb.vector_search.base import VectorCollectionConfig from superduperdb import CFG from superduperdb.vector_search.base import VectorCollectionItem from 
superduperdb.misc.config import LanceDB @@ -72,7 +73,6 @@ def test_create_table_new(lance_client): assert table.measure == measure -@pytest.mark.skip(reason="`tantivy` package needs to be installed") def test_add(lance_table): data = [ {"id": "1", "vector": [2, 3]}, @@ -81,10 +81,11 @@ def test_add(lance_table): ] data = [VectorCollectionItem(**d) for d in data] lance_table.add(data) - assert lance_table.get("1") == data[0] + data = lance_table.find_nearest_from_array([2, 3]) + assert data[0].id == 1 -@pytest.mark.skip(reason="`tantivy` package needs to be installed") +@pytest.mark.skip(reason="Not implemented") def test_find_nearest_from_id(lance_table): identifier = "1" limit = 100 @@ -114,7 +115,8 @@ def test_find_nearest_from_array(lance_table): assert result[0].id == 1 -def test_create_schema(): +@pytest.mark.parametrize("measure", ["cosine", "euclidean"]) +def test_create_schema(measure): dimensions = 3 expected_schema = pa.schema( [ @@ -122,6 +124,17 @@ def test_create_schema(): pa.field("id", pa.string()), ] ) - vector_index = LanceVectorIndex(config=CFG.vector_search.type) + vector_index = LanceVectorIndex(config=CFG.vector_search.type, measure=measure) schema = vector_index._create_schema(dimensions) assert schema.equals(expected_schema) + + +def test_vector_index_get_table(): + vector_index = LanceVectorIndex(config=CFG.vector_search.type) + table = vector_index.get_table( + VectorCollectionConfig( + id="1", + dimensions=3, + ) + ) + assert isinstance(table, LanceTable) diff --git a/tests/unittests/vector_search/test_vanillaindex.py b/tests/unittests/vector_search/test_vanillaindex.py index 15b5c8601..e4959da5a 100644 --- a/tests/unittests/vector_search/test_vanillaindex.py +++ b/tests/unittests/vector_search/test_vanillaindex.py @@ -6,6 +6,7 @@ from superduperdb.vector_search.table_scan import VanillaVectorIndex from superduperdb.vector_search.table_scan import l2, dot, cosine + @pytest.mark.parametrize("measure", [l2, dot, cosine]) def test_vaniila_index(measure): x = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]]) From 0e3d95b28e41569b2bf01bea788a5d4de05f5e13 Mon Sep 17 00:00:00 2001 From: thejumpman2323 Date: Thu, 20 Jul 2023 14:48:33 +0530 Subject: [PATCH 4/4] Add test for transformers --- requirements/requirements-test.txt | 9 +++ requirements/requirements.in | 1 + requirements/requirements.txt | 9 +++ superduperdb/core/model.py | 6 +- superduperdb/core/vector_index.py | 1 - superduperdb/models/transformers/wrapper.py | 25 ++++--- tests/integration/conftest.py | 1 - tests/unittests/models/test_transformers.py | 79 +++++++++++++++++++++ 8 files changed, 118 insertions(+), 13 deletions(-) create mode 100644 tests/unittests/models/test_transformers.py diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index efcf787f4..06c5b327e 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -4,6 +4,8 @@ # # pip-compile requirements-test.in # +accelerate==0.21.0 + # via -r requirements.in aiohttp==3.8.4 # via # fsspec @@ -132,6 +134,8 @@ fsspec[http]==2023.6.0 # huggingface-hub # lightning # pytorch-lightning +greenlet==2.0.2 + # via sqlalchemy grpcio==1.49.1 # via # pymilvus @@ -242,6 +246,7 @@ numexpr==2.8.4 numpy==1.24.4 # via # -r requirements.in + # accelerate # langchain # lightning # numexpr @@ -263,6 +268,7 @@ ordered-set==4.1.0 # via deepdiff packaging==23.1 # via + # accelerate # dask # distributed # huggingface-hub @@ -293,6 +299,7 @@ protobuf==4.23.4 # ray psutil==5.9.5 # via + # accelerate # distributed # 
lightning py==1.11.0 @@ -346,6 +353,7 @@ pytz==2023.3 # pandas pyyaml==6.0.1 # via + # accelerate # dask # distributed # huggingface-hub @@ -449,6 +457,7 @@ toolz==0.12.0 torch==2.0.0 # via # -r requirements.in + # accelerate # lightning # pytorch-lightning # torchmetrics diff --git a/requirements/requirements.in b/requirements/requirements.in index 9a4a00782..42c09bccb 100644 --- a/requirements/requirements.in +++ b/requirements/requirements.in @@ -26,3 +26,4 @@ ray>=2.4.0 scikit-learn>=1.1.3 torch>=2.0.0,!=2.0.1 transformers>=4.29.1 +accelerate>=0.20.1 diff --git a/requirements/requirements.txt b/requirements/requirements.txt index a75dc29d0..0d1ec5754 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -4,6 +4,8 @@ # # pip-compile requirements.in # +accelerate==0.21.0 + # via -r requirements.in aiohttp==3.8.4 # via # fsspec @@ -127,6 +129,8 @@ fsspec[http]==2023.6.0 # huggingface-hub # lightning # pytorch-lightning +greenlet==2.0.2 + # via sqlalchemy grpcio==1.49.1 # via # pymilvus @@ -226,6 +230,7 @@ numexpr==2.8.4 numpy==1.24.4 # via # -r requirements.in + # accelerate # langchain # lightning # numexpr @@ -247,6 +252,7 @@ ordered-set==4.1.0 # via deepdiff packaging==23.1 # via + # accelerate # dask # distributed # huggingface-hub @@ -274,6 +280,7 @@ protobuf==4.23.4 # ray psutil==5.9.5 # via + # accelerate # distributed # lightning py==1.11.0 @@ -321,6 +328,7 @@ pytz==2023.3 # pandas pyyaml==6.0.1 # via + # accelerate # dask # distributed # huggingface-hub @@ -417,6 +425,7 @@ toolz==0.12.0 torch==2.0.0 # via # -r requirements.in + # accelerate # lightning # pytorch-lightning # torchmetrics diff --git a/superduperdb/core/model.py b/superduperdb/core/model.py index 82f83c711..8ad332ee5 100644 --- a/superduperdb/core/model.py +++ b/superduperdb/core/model.py @@ -1,4 +1,5 @@ from dask.distributed import Future +import inspect import dataclasses as dc import typing as t @@ -141,7 +142,10 @@ def predict( return else: - return self._predict(X, **kwargs) + if 'one' in inspect.signature(self._predict).parameters: + return self._predict(X, one=one, **kwargs) + else: + return self._predict(X, **kwargs) @dc.dataclass diff --git a/superduperdb/core/vector_index.py b/superduperdb/core/vector_index.py index 7d135aa86..58e65ec02 100644 --- a/superduperdb/core/vector_index.py +++ b/superduperdb/core/vector_index.py @@ -130,7 +130,6 @@ def get_nearest( models, keys = self.models_keys if len(models) != len(keys): raise ValueError(f'len(models={models}) != len(keys={keys})') - within_ids = ids or () if db.db.id_field in like.content: # type: ignore diff --git a/superduperdb/models/transformers/wrapper.py b/superduperdb/models/transformers/wrapper.py index 35f4d9d47..069002775 100644 --- a/superduperdb/models/transformers/wrapper.py +++ b/superduperdb/models/transformers/wrapper.py @@ -31,13 +31,13 @@ def TransformersTrainerConfiguration(identifier: str, *args, **kwargs): class Pipeline(Model): tokenizer: t.Optional[t.Callable] = None - def __post_init__(self, db): + def __post_init__(self): if not self.device: self.device = "cuda" if torch.cuda.is_available() else "cpu" self.object.to(self.device) if not isinstance(self.tokenizer, Artifact): - self.tokenizer = Artifact(_artifact=self.tokenizer) - super().__post_init__(db) + self.tokenizer = Artifact(artifact=self.tokenizer) + super().__post_init__() @property def pipeline(self): @@ -57,7 +57,7 @@ def _get_data( **tokenizer_kwargs, ): tokenizing_function = TokenizingFunction( - self.tokenizer.a, key=X_key, **tokenizer_kwargs 
+ self.tokenizer.artifact, key=X_key, **tokenizer_kwargs ) train_data = query_dataset_factory( select=self.training_select, @@ -104,7 +104,7 @@ def _fit( # type: ignore[override] prefetch_size: int = _DEFAULT_PREFETCH_SIZE, tokenizer_kwargs: t.Dict[str, t.Any] = {}, **kwargs, - ): + ) -> t.Optional[t.Dict[str, t.Any]]: if configuration is not None: self.configuration = configuration if select is not None: @@ -114,6 +114,8 @@ def _fit( # type: ignore[override] if metrics is not None: self.metrics = metrics + evaluate = kwargs.pop('evaluate', True) + if isinstance(X, str): train_data, valid_data = self._get_data( db, @@ -131,18 +133,21 @@ def _fit( # type: ignore[override] eval_dataset=valid_data, **kwargs, ) + evaluation = None try: trainer.train() - evaluation = trainer.evaluate() + if evaluate: + evaluation = trainer.evaluate() except Exception as exc: - log.error(f"Training could not finish :: {exc}") - - return evaluation + log.exception(f"Training could not finish :: {exc}") + raise + else: + return evaluation def _predict_one(self, input: str, **kwargs): tokenized_input = self.tokenizer.a(input, return_tensors='pt').to(self.device) - return self.object.a(**tokenized_input, **kwargs) + return self.object.artifact(**tokenized_input, **kwargs) def _predict(self, input: str, **kwargs): if not isinstance(input, list): diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index af2c56e4a..e2b1c6704 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -35,7 +35,6 @@ # Set the seeds random.seed(42) -torch.manual_seed(42) np.random.seed(42) diff --git a/tests/unittests/models/test_transformers.py b/tests/unittests/models/test_transformers.py new file mode 100644 index 000000000..ab47827c6 --- /dev/null +++ b/tests/unittests/models/test_transformers.py @@ -0,0 +1,79 @@ +import pytest + +from superduperdb.core.documents import Document as D +from superduperdb.datalayer.mongodb.query import Collection + +from superduperdb.models.transformers.wrapper import ( + TransformersTrainerConfiguration, + Pipeline, +) + + +@pytest.fixture(scope="function") +def trainer(random_data): + from transformers import AutoModelForSequenceClassification + from transformers import AutoTokenizer + + data = [ + {'text': 'dummy text 1', 'label': 1}, + {'text': 'dummy text 2', 'label': 0}, + {'text': 'dummy text 1', 'label': 1}, + ] + data = [D(d) for d in data] + random_data.execute(Collection('train_documents').insert_many(data)) + tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased") + model = AutoModelForSequenceClassification.from_pretrained( + "distilbert-base-uncased", num_labels=2 + ) + + trainer = Pipeline( + identifier='my-sentiment-analysis', + tokenizer=tokenizer, + object=model, + train_X='text', + train_y='label', + device='cpu', + ) + yield trainer, tokenizer + + +def test_transformer(trainer): + pass + + +def test_tranformers_trainer(trainer, random_data): + trainer, tokenizer = trainer + + from transformers import DataCollatorWithPadding + from superduperdb.core.dataset import Dataset + + data_collator = DataCollatorWithPadding(tokenizer=tokenizer) + repo_name = "test-superduperdb-sentiment-analysis" + training_args = TransformersTrainerConfiguration( + identifier=repo_name, + output_dir=repo_name, + learning_rate=2e-5, + per_device_train_batch_size=1, + per_device_eval_batch_size=1, + num_train_epochs=1, + weight_decay=0.01, + save_strategy="epoch", + use_mps_device=False, + ) + trainer.fit( + X='text', + y='label', + db=random_data, + 
select=Collection('train_documents').find(), + configuration=training_args, + validation_sets=[ + Dataset( + identifier='my-eval', + select=Collection(name='train_documents').find({'_fold': 'valid'}), + ) + ], + data_collator=data_collator, + data_prefetch=False, + tokenizer_kwargs={'truncation': True}, + evaluate=False, + )
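
A minimal usage sketch of the vector-search API exercised by the new tests, assuming
only the module paths, constructor signatures and measure names ('l2', 'cosine',
'dot') that appear in the patches above; return values are unpacked the same way the
tests do:

    import uuid

    import numpy as np

    from superduperdb.vector_search.faiss_index import FaissVectorIndex
    from superduperdb.vector_search.table_scan import VanillaVectorIndex

    # random 32-d vectors keyed by opaque ids, as in test_faiss_hash_set
    h = np.random.randn(1000, 32).astype('float32')
    ids = [uuid.uuid4() for _ in range(h.shape[0])]

    # both back-ends take (vectors, ids, measure); 'cosine' is the spelling
    # used after PATCH 3 (PATCH 2 briefly called the same measure 'css')
    faiss_index = FaissVectorIndex(h, ids, 'cosine')
    table_scan = VanillaVectorIndex(h, ids, 'cosine')

    query = np.random.randn(32).astype('float32')
    # first element of the returned pair is the list of nearest ids, as asserted
    # in the tests; the second is ignored here
    nearest_faiss, _ = faiss_index.find_nearest_from_array(query)
    nearest_scan, _ = table_scan.find_nearest_from_array(query)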