Skip to content

Commit

Permalink
Fix the vector tests to match the final PR (#440)
Browse files Browse the repository at this point in the history
  • Loading branch information
msullivan authored Jun 9, 2023
1 parent 874a071 commit c517e72
Showing 1 changed file with 17 additions and 27 deletions.
44 changes: 17 additions & 27 deletions tests/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,106 +36,96 @@ def setUp(self):

if not self.client.query_required_single('''
select exists (
select sys::ExtensionPackage filter .name = 'vector'
select sys::ExtensionPackage filter .name = 'pgvector'
)
'''):
self.skipTest("feature not implemented")

self.client.execute('''
create extension vector version '1.0'
create extension pgvector;
''')

def tearDown(self):
try:
self.client.execute('''
drop extension vector version '1.0'
drop extension pgvector;
''')
finally:
super().tearDown()

async def test_vector_01(self):
# if not self.client.query_required_single('''
# select exists (
# select sys::ExtensionPackage filter .name = 'vector'
# )
# '''):
# self.skipTest("feature not implemented")

# self.client.execute('''
# create extension vector version '1.0'
# ''')

val = self.client.query_single('''
select <vector::vector>'[1.5,2.0,3.8]'
select <ext::pgvector::vector>[1.5,2.0,3.8]
''')
self.assertTrue(isinstance(val, array.array))
self.assertEqual(val, array.array('f', [1.5, 2.0, 3.8]))

val = self.client.query_single(
'''
select <str><vector::vector>$0
select <json><ext::pgvector::vector>$0
''',
[3.0, 9.0, -42.5],
)
self.assertEqual(val, '[3,9,-42.5]')
self.assertEqual(val, '[3, 9, -42.5]')

val = self.client.query_single(
'''
select <str><vector::vector>$0
select <json><ext::pgvector::vector>$0
''',
array.array('f', [3.0, 9.0, -42.5])
)
self.assertEqual(val, '[3,9,-42.5]')
self.assertEqual(val, '[3, 9, -42.5]')

val = self.client.query_single(
'''
select <str><vector::vector>$0
select <json><ext::pgvector::vector>$0
''',
array.array('i', [1, 2, 3]),
)
self.assertEqual(val, '[1,2,3]')
self.assertEqual(val, '[1, 2, 3]')

# Test that the fast-path works: if the encoder tries to
# call __getitem__ on this brokenarray, it will fail.
val = self.client.query_single(
'''
select <str><vector::vector>$0
select <json><ext::pgvector::vector>$0
''',
brokenarray('f', [3.0, 9.0, -42.5])
)
self.assertEqual(val, '[3,9,-42.5]')
self.assertEqual(val, '[3, 9, -42.5]')

# I don't think it's worth adding a dependency to test this,
# but this works too:
# import numpy as np
# val = self.client.query_single(
# '''
# select <str><vector::vector>$0
# select <json><ext::pgvector::vector>$0
# ''',
# np.asarray([3.0, 9.0, -42.5], dtype=np.float32),
# )
# self.assertEqual(val, '[3,9,-42.5]')

# Some sad path tests
with self.assertRaises(edgedb.InvalidArgumentError):
self.client.query_single(
'''
select <vector::vector>$0
select <ext::pgvector::vector>$0
''',
[3.0, None, -42.5],
)

with self.assertRaises(edgedb.InvalidArgumentError):
self.client.query_single(
'''
select <vector::vector>$0
select <ext::pgvector::vector>$0
''',
[3.0, 'x', -42.5],
)

with self.assertRaises(edgedb.InvalidArgumentError):
self.client.query_single(
'''
select <vector::vector>$0
select <ext::pgvector::vector>$0
''',
'foo',
)

0 comments on commit c517e72

Please sign in to comment.