From ffb7126b6c8be2b7a875762d25913b41034048b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A9lanie=20F?= <73828657+melanie-fressard@users.noreply.github.com> Date: Fri, 24 Nov 2023 16:24:45 +0000 Subject: [PATCH] issue #31 - correction of lint test --- tests/test_db_api.py | 40 ++++++++-------------------------------- tests/test_db_crawler.py | 3 +-- tests/test_db_utils.py | 4 +--- 3 files changed, 10 insertions(+), 37 deletions(-) diff --git a/tests/test_db_api.py b/tests/test_db_api.py index 4cc8b23..505df26 100644 --- a/tests/test_db_api.py +++ b/tests/test_db_api.py @@ -3,51 +3,27 @@ import json import ailab.db as db -import ailab.db.api as api - -MATCH_THRESHOLD = 0.5 -MATCH_COUNT = 10 +import tests.testing_utils as test class TestDBAPI(unittest.TestCase): """Test the database functions""" def setUp(self): self.connection = db.connect_db() self.cursor = db.cursor(self.connection) - # Refresh materialized view - self.cursor.execute("REFRESH MATERIALIZED VIEW default_chunk") - self.connection.commit() def tearDown(self): self.connection.rollback() self.connection.close() - def test_match_documents_text_query(self): - with db.cursor(self.connection) as cursor: - docs = api.match_documents_from_text_query( - cursor, - 'what are the cooking temperatures for e.coli?') - self.connection.rollback() - self.assertEqual(len(docs), 10) - - # obsoleted by weighted search - # def test_president_of_cfia(self): - # with db.cursor(self.connection) as cursor: - # docs = api.match_documents_from_text_query( - # cursor, 'who is the president of the CFIA?') - # self.connection.rollback() - # self.assertEqual( - # docs[0]['title'], - # 'Dr. Harpreet S. Kochhar - Canadian Food Inspection Agency') - def test_weighted_search(self): with open('tests/embeddings/president.json') as f: embeddings = json.load(f) query = 'who is the president of the CFIA?' weights = json.dumps( - {'similarity': 0.6, 'recency': 0.2, 'traffic': 0.0, 'current': 0.1}) + {'similarity': 1.0, 'typicality': 0.2, 'recency': 1.0, 'traffic': 1.0, 'current': 0.5}) self.cursor.execute( "SELECT * FROM search(%s, %s::vector, %s::float, %s::integer, %s::jsonb)", ( - query, embeddings, MATCH_THRESHOLD, MATCH_COUNT, weights)) + query, embeddings, test.MATCH_THRESHOLD, test.MATCH_COUNT, weights)) results = self.cursor.fetchall() result = results[0]['search'] self.assertEqual( @@ -61,17 +37,17 @@ def test_weighted_search(self): self.assertEqual(result[0]['query'], query) result_embedding = result[0]['embedding'] self.assertAlmostEqual(result_embedding[0], embeddings[0]) - self.assertEqual(len(result[0]['result']), MATCH_COUNT) + self.assertEqual(len(result[0]['result']), test.MATCH_COUNT) def test_weighted_search_with_empty_query(self): - weights = json.dumps({ 'recency': 0.4, 'traffic': 0.4, 'current': 0.2}) + weights = json.dumps({'similarity': 1.0, 'typicality': 0.2, 'recency': 1.0, 'traffic': 1.0, 'current': 0.5}) self.cursor.execute( "SELECT * FROM search(%s, %s::vector, %s::float, %s::integer, %s::jsonb)", ( - None, None, MATCH_THRESHOLD, MATCH_COUNT, weights)) + None, None, test.MATCH_THRESHOLD, test.MATCH_COUNT, weights)) result = self.cursor.fetchall()[0]['search'] - self.assertEqual(len(result), MATCH_COUNT, "Should return 10 results") + self.assertEqual(len(result), test.MATCH_COUNT, "Should return 10 results") urls = dict([(r['url'], True) for r in result]) self.assertEqual( len(urls.keys()), - MATCH_COUNT, + test.MATCH_COUNT, "All urls should be unique") diff --git a/tests/test_db_crawler.py b/tests/test_db_crawler.py index 2eb9df5..891b424 100644 --- a/tests/test_db_crawler.py +++ b/tests/test_db_crawler.py @@ -3,8 +3,7 @@ import ailab.db as db import ailab.db.crawler as crawler - -import testing_utils as test +import tests.testing_utils as test class TestDBCrawler(unittest.TestCase): """Test the database functions""" diff --git a/tests/test_db_utils.py b/tests/test_db_utils.py index ba71638..b641278 100644 --- a/tests/test_db_utils.py +++ b/tests/test_db_utils.py @@ -1,9 +1,7 @@ """test database functions""" import unittest - -import testing_utils as test import ailab.db as db - +import tests.testing_utils as test class TestDBUtils(unittest.TestCase): """Test the database functions"""