Skip to content

Commit

Permalink
issue #31 - correction of lint test
Browse files Browse the repository at this point in the history
  • Loading branch information
melanie-fressard authored Nov 24, 2023
1 parent 27bec9f commit ffb7126
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 37 deletions.
40 changes: 8 additions & 32 deletions tests/test_db_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,51 +3,27 @@
import json

import ailab.db as db
import ailab.db.api as api

MATCH_THRESHOLD = 0.5
MATCH_COUNT = 10
import tests.testing_utils as test

class TestDBAPI(unittest.TestCase):
"""Test the database functions"""
def setUp(self):
self.connection = db.connect_db()
self.cursor = db.cursor(self.connection)
# Refresh materialized view
self.cursor.execute("REFRESH MATERIALIZED VIEW default_chunk")
self.connection.commit()

def tearDown(self):
self.connection.rollback()
self.connection.close()

def test_match_documents_text_query(self):
with db.cursor(self.connection) as cursor:
docs = api.match_documents_from_text_query(
cursor,
'what are the cooking temperatures for e.coli?')
self.connection.rollback()
self.assertEqual(len(docs), 10)

# obsoleted by weighted search
# def test_president_of_cfia(self):
# with db.cursor(self.connection) as cursor:
# docs = api.match_documents_from_text_query(
# cursor, 'who is the president of the CFIA?')
# self.connection.rollback()
# self.assertEqual(
# docs[0]['title'],
# 'Dr. Harpreet S. Kochhar - Canadian Food Inspection Agency')

def test_weighted_search(self):
with open('tests/embeddings/president.json') as f:
embeddings = json.load(f)
query = 'who is the president of the CFIA?'
weights = json.dumps(
{'similarity': 0.6, 'recency': 0.2, 'traffic': 0.0, 'current': 0.1})
{'similarity': 1.0, 'typicality': 0.2, 'recency': 1.0, 'traffic': 1.0, 'current': 0.5})
self.cursor.execute(
"SELECT * FROM search(%s, %s::vector, %s::float, %s::integer, %s::jsonb)", (
query, embeddings, MATCH_THRESHOLD, MATCH_COUNT, weights))
query, embeddings, test.MATCH_THRESHOLD, test.MATCH_COUNT, weights))
results = self.cursor.fetchall()
result = results[0]['search']
self.assertEqual(
Expand All @@ -61,17 +37,17 @@ def test_weighted_search(self):
self.assertEqual(result[0]['query'], query)
result_embedding = result[0]['embedding']
self.assertAlmostEqual(result_embedding[0], embeddings[0])
self.assertEqual(len(result[0]['result']), MATCH_COUNT)
self.assertEqual(len(result[0]['result']), test.MATCH_COUNT)

def test_weighted_search_with_empty_query(self):
weights = json.dumps({ 'recency': 0.4, 'traffic': 0.4, 'current': 0.2})
weights = json.dumps({'similarity': 1.0, 'typicality': 0.2, 'recency': 1.0, 'traffic': 1.0, 'current': 0.5})
self.cursor.execute(
"SELECT * FROM search(%s, %s::vector, %s::float, %s::integer, %s::jsonb)", (
None, None, MATCH_THRESHOLD, MATCH_COUNT, weights))
None, None, test.MATCH_THRESHOLD, test.MATCH_COUNT, weights))
result = self.cursor.fetchall()[0]['search']
self.assertEqual(len(result), MATCH_COUNT, "Should return 10 results")
self.assertEqual(len(result), test.MATCH_COUNT, "Should return 10 results")
urls = dict([(r['url'], True) for r in result])
self.assertEqual(
len(urls.keys()),
MATCH_COUNT,
test.MATCH_COUNT,
"All urls should be unique")
3 changes: 1 addition & 2 deletions tests/test_db_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@

import ailab.db as db
import ailab.db.crawler as crawler

import testing_utils as test
import tests.testing_utils as test

class TestDBCrawler(unittest.TestCase):
"""Test the database functions"""
Expand Down
4 changes: 1 addition & 3 deletions tests/test_db_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
"""test database functions"""
import unittest

import testing_utils as test
import ailab.db as db

import tests.testing_utils as test

class TestDBUtils(unittest.TestCase):
"""Test the database functions"""
Expand Down

0 comments on commit ffb7126

Please sign in to comment.