diff --git a/ailab/db/crawler/__init__.py b/ailab/db/crawler/__init__.py
index bb9ff45..1438890 100644
--- a/ailab/db/crawler/__init__.py
+++ b/ailab/db/crawler/__init__.py
@@ -45,6 +45,62 @@ def fetch_links(cursor, url):
     data['destination_urls'] = [r['url'] for r in cursor.fetchall()]
     return data['destination_urls']
 
+def get_md5hash(cursor, data):
+    cursor.execute(
+        """SELECT md5hash FROM crawl WHERE url = %(url)s
+        ORDER BY last_updated DESC LIMIT 1""",
+        data
+    )
+    return cursor.fetchone()['md5hash']
+
+def get_chunk_id(cursor, data):
+    cursor.execute(
+        """
+        WITH e as(
+            INSERT INTO chunk (title, text_content)
+            VALUES(%(title)s, %(text_content)s)
+            ON CONFLICT DO NOTHING
+            RETURNING id
+        )
+        SELECT id FROM e
+        UNION ALL
+        SELECT id FROM chunk WHERE text_content = %(text_content)s
+        """,
+        data
+    )
+    row = cursor.fetchone()
+    return row['id'] if row is not None else None
+
+def insert_html_content_to_chunk(cursor, data):
+    cursor.execute(
+        """
+        INSERT INTO html_content_to_chunk (md5hash, chunk_id)
+        VALUES(%(md5hash)s, %(chunk_id)s::UUID)
+        ON CONFLICT DO NOTHING
+        """,
+        data)
+
+def get_token_id(cursor, data):
+    cursor.execute(
+        """
+        WITH e as(
+            INSERT INTO token (chunk_id, tokens, encoding)
+            VALUES (%(chunk_id)s::UUID, %(tokens)s, %(encoding)s)
+            ON CONFLICT DO NOTHING
+            RETURNING *
+        )
+        SELECT id FROM e
+        UNION ALL
+        SELECT id FROM token
+        WHERE chunk_id = %(chunk_id)s::UUID
+        and tokens = %(tokens)s::INTEGER[]
+        and encoding = %(encoding)s
+        """,
+        data
+    )
+    res = cursor.fetchone()
+    return res['id'] if res is not None else None
+
 def store_chunk_item(cursor, item):
     """Process a ChunkItem and insert it into the database."""
     try:
@@ -55,60 +111,24 @@ def store_chunk_item(cursor, item):
             'tokens': item["tokens"],
             'encoding': 'cl100k_base'
         }
-        cursor.execute(
-            """SELECT md5hash FROM crawl WHERE url = %(url)s
-            ORDER BY last_updated DESC LIMIT 1""",
-            data
-        )
-        data['md5hash'] = cursor.fetchone()['md5hash']
+        new_md5hash = get_md5hash(cursor, data)
+        if new_md5hash is not None:
+            data['md5hash'] = new_md5hash
+
+        new_chunk_id = get_chunk_id(cursor, data)
+        if new_chunk_id is not None:
+            data['chunk_id'] = new_chunk_id
+
+        insert_html_content_to_chunk(cursor, data)
 
-        # TODO: should probably update the title even if the text_content
-        # is already present as we may have changed how we create the title
-        cursor.execute(
-            """
-            WITH e as(
-                INSERT INTO chunk (title, text_content)
-                VALUES(%(title)s, %(text_content)s)
-                ON CONFLICT DO NOTHING
-                RETURNING id
-            )
-            SELECT id FROM e
-            UNION ALL
-            SELECT id FROM chunk WHERE text_content = %(text_content)s
-            """,
-            data
-        )
-        data['chunk_id'] = cursor.fetchone()['id']
-        cursor.execute(
-            """
-            INSERT INTO html_content_to_chunk (html_content_md5hash, chunk_id)
-            VALUES(%(md5hash)s, %(chunk_id)s::UUID)
-            ON CONFLICT DO NOTHING
-            """,
-            data)
-        cursor.execute(
-            """
-            WITH e as(
-                INSERT INTO token (chunk_id, tokens, encoding)
-                VALUES (%(chunk_id)s::UUID, %(tokens)s, %(encoding)s)
-                ON CONFLICT DO NOTHING
-                RETURNING *
-            )
-            SELECT id FROM e
-            UNION ALL
-            SELECT id FROM token
-            WHERE chunk_id = %(chunk_id)s::UUID
-            and tokens = %(tokens)s::INTEGER[]
-            and encoding = %(encoding)s
-            """,
-            data
-        )
-        data['token_id'] = cursor.fetchone()['id']
+        new_token_id = get_token_id(cursor, data)
+        if new_token_id is not None:
+            data['token_id'] = new_token_id
+
         return data
     except psycopg.IntegrityError as e:
         raise db.DBError("Error storing chunk item for %s" % item['url']) from e
 
-
 def store_crawl_item(cursor, item):
     """Process a CrawlItem and insert it into the database."""
     try:
@@ -127,7 +147,9 @@ def store_crawl_item(cursor, item):
             """,
             item
         )
-        return item
+        cursor.execute("""SELECT * FROM crawl
+            WHERE url = %(url)s ORDER BY last_updated DESC LIMIT 1""", item)
+        return cursor.fetchone()
     except psycopg.IntegrityError as e:
         raise db.DBError("Error storing crawl item for %s" % item['url']) from e
 
@@ -152,7 +174,7 @@ def store_embedding_item(cursor, item):
             query,
             data
         )
-        return item
+        return data['token_id']
     except psycopg.IntegrityError as e:
         raise db.DBError(
             "Error storing embedding item for token %s" % item['token_id']) from e
@@ -163,7 +185,7 @@ def fetch_crawl_ids_without_chunk(cursor):
         """
         SELECT crawl.id FROM crawl
         LEFT JOIN html_content_to_chunk
-        ON crawl.md5hash = html_content_to_chunk.html_content_md5hash
+        ON crawl.md5hash = html_content_to_chunk.md5hash
        WHERE chunk_id IS NULL
         """
     ).as_string(cursor)
@@ -206,15 +228,13 @@ def fetch_crawl_row(cursor, url):
     assert 'html_content' in row.keys()
     return row
 
-def fetch_chunk_token_row(cursor, url):
+def fetch_chunk_token_row(cursor, id):
     """Fetch the most recent chunk token for a given chunk id."""
-    data = db.parse_postgresql_url(url)
+    data = {'id': id}
     cursor.execute(
-        "SELECT chunk.id as chunk_id, token.id as token_id, tokens FROM chunk"
-        " JOIN token ON chunk.id = token.chunk_id"
-        " WHERE chunk.id = %(id)s LIMIT 1",
+        """SELECT chunk.id as chunk_id, token.id as token_id, tokens FROM chunk
+        JOIN token ON chunk.id = token.chunk_id
+        WHERE chunk.id = %(id)s LIMIT 1""",
         data
     )
-    # psycopg.extras.DictRow is not a real dict and will convert
-    # to string as a list so we force convert to dict
     return cursor.fetchone()
diff --git a/tests/test_db_crawler.py b/tests/test_db_crawler.py
index 891b424..dcf18fc 100644
--- a/tests/test_db_crawler.py
+++ b/tests/test_db_crawler.py
@@ -1,6 +1,5 @@
 """test database functions"""
 import unittest
-
 import ailab.db as db
 import ailab.db.crawler as crawler
 import tests.testing_utils as test
@@ -53,20 +52,6 @@ def test_fetch_crawl_row_by_postgresql_url(self):
             row['title'],
             "Sampling procedures - Canadian Food Inspection Agency")
 
-    def test_fetch_chunk_row(self):
-        """sample test to check if fetch_chunk_row works"""
-        url = db.create_postgresql_url(
-            "DBNAME",
-            "chunk",
-            "469812c5-190c-4e56-9f88-c8621592bcb5")
-        with db.cursor(self.connection) as cursor:
-            row = crawler.fetch_chunk_token_row(cursor, url)
-            self.connection.rollback()
-        self.assertTrue(isinstance(row, dict))
-        self.assertEqual(len(row['tokens']), 76)
-        self.assertEqual(str(row['chunk_id']), "469812c5-190c-4e56-9f88-c8621592bcb5")
-        self.assertEqual(str(row['token_id']), 'dbb7b498-2cbf-4ae9-aa10-3169cc72f285')
-
     def test_fetch_chunk_id_without_embedding(self):
         """sample test to check if fetch_chunk_id_without_embedding works"""
         with db.cursor(self.connection) as cursor:
@@ -74,3 +59,90 @@ def test_fetch_chunk_id_without_embedding(self):
             rows = crawler.fetch_chunk_id_without_embedding(cursor, 'test-model')
             _entity_id = rows[0]
             self.connection.rollback()
+
+    def test_store_chunk_item(self):
+        """Test storing a chunk item."""
+        with db.cursor(self.connection) as cursor:
+            item = {
+                "url": "https://inspection.canada.ca/a-propos-de-l-acia/fra/1299008020759/1299008778654",
+                "title": "À propos de l'ACIA - Agence canadienne d'inspection des aliments",
+                "text_content": "This is an example content.",
+                "tokens": [73053,10045,409,326,6,1741,5987,2998,37622,934,6,8629,44618,409,38682,1001,367,3869,348,2328,7330,52760,11,326,6,1741,5987,264,653,348,5642,11837,266,7930,2995,64097,1208,4371,392,1018,978,38450,12267,11,1208,77323,951,4039,12249,11,1208,9313,951,348,19395,978,2629,2249,1880,326,69537,12416,8065,84751,6625,13,17360,535,89,551,5690,6405,13674,33867,14318,3765,91080,1370,2126,22811,11876,459,5979,729,3539,5512,409,326,6,1741,5987,56311,39929,64079,3869,951,90108,35933,46680,969,645,551,2009,85182,40280,3930,7008,90108,1082,3625,459,5979,729,3539,5512,5019,3625,4824,76,1154,25540,5512,1370,514,72601,409,2343,68,2405,10610,953,13,2998,62163,42145,40948,5512,294,6,97675,4149,3462,16848,85046,1880,83229,70,91555,11683,12416,3869,326,6,26125,1880,1208,9313,951,5790,325,625,3808,1732,36527,3459,360,17724,409,326,6,1741,5987,22555,951,24261,288,1880,951,917,2053,3700,34965,11,93084,1880,3057,65,811,57967,1220,294,26248,1088,1759,409,1208,6377,30052,9359,10333,5392,788,95188,4949,11,2126,22811,11,1008,44357,11,95995,409,3729,8471,1880,5790,325,625,3808,65381,10045,409,17317,24789,266,11,9131,11,11376,11,4046,6414,51084,951,97035,13,8245,22139,64829,29696,409,11692,1880,409,85182,77,685,328,5164,409,80080,423,944,59307,80080,11,17889,1354,5860,24985,3946,11,62163,409,5790,325,625,3808,951,35030,3557,1880,3930,586,13,51097,4972,35933,44564,20392,3869,326,58591,73511,7769,3136,13,3744,268,2850,1900,5856,288,9952,3625,447,2053,17317,58673,484,2439,5019,65827,268,404,1880,5201,261,3625,7617,288,409,77463,38450,12267,1880,1097,73511,15171,1208,9313,1880,1208,6225,31617,8082,951,1615,316,18543,1759,13,5856,288,8666,266,22589,1219,13109,8666,50848,294,6,4683,15916,11,1219,13109,8666,1413,3930,49904,265,11,43252,89781,288,1765,3625,13826,25108,4978,409,51304,13,29124,6414,51084,951,97222,1880,951,3600,45629,288,5019,38682,404,15907,22639,9952,3625,2027,38647,11,3625,1615,316,18543,1759,11,326,6,485,592,7379,1880,3625,46106,31957,1821,13,20915,21066,409,1208,26965,13109,44564,3057,1557,409,1208,26965,13109,38682,1001,12267,13,2998,2842,40280,1880,82620,27220,18042,283,35573,514,82620,27220,1880,326,91655,11323,266,2428,13,2998,29033,6672,51097,737,2727,392,645,432,1137,625,3808,1765,3625,737,2727,392,645,1880,46106,21744,10515,5512,409,326,6,1741,5987,13]
+            }
+            stored_item = crawler.store_chunk_item(cursor, item)
+            self.connection.rollback()
+        self.assertEqual(stored_item["url"], item["url"])
+
+    def test_store_crawl_item(self):
+        """Test storing a crawl item."""
+        with db.cursor(self.connection) as cursor:
+            item = {
+                "url": "https://example.com",
+                "title": "Example",
+                "html_content": "This is an example.",
+                "lang": "en",
+                "last_crawled": "2022-01-01",
+                "last_updated": "2022-01-01"
+            }
+            stored_item = crawler.store_crawl_item(cursor, item)
+            self.connection.rollback()
+        self.assertEqual(item["title"], stored_item["title"])
+        self.assertEqual(item["url"], stored_item["url"])
+
+    def test_store_embedding_item(self):
+        """Test storing an embedding item."""
+        with db.cursor(self.connection) as cursor:
+            item = {
+                "token_id": "be612259-9b52-42fd-8d0b-d72120efa3b6",
+                "embedding": test.generate_random_embedding(1536),
+                "embedding_model": "test-model"
+            }
+            stored_item = crawler.store_embedding_item(cursor, item)
+            self.connection.rollback()
+        self.assertEqual(item["token_id"], stored_item)
+
+    def test_fetch_crawl_ids_without_chunk(self):
+        """Test fetching crawl IDs without a chunk."""
+        with db.cursor(self.connection) as cursor:
+            id = crawler.fetch_crawl_ids_without_chunk(cursor)
+            self.connection.rollback()
+        self.assertEqual(id, [])
+
+    def test_fetch_crawl_row(self):
+        """Test fetching a crawl row."""
+        with db.cursor(self.connection) as cursor:
+            row = crawler.fetch_crawl_row(cursor, "https://inspection.canada.ca/a-propos-de-l-acia/structure-organisationnelle/mandat/fra/1299780188624/1319164463699")
+            self.connection.rollback()
+        self.assertEqual(row['title'], "Mandat - Agence canadienne d'inspection des aliments")
+
+    # def test_fetch_crawl_row_with_test_data(self):
+    #     """Test fetching a crawl row."""
+    #     with db.cursor(self.connection) as cursor:
+    #         test_chunk_id = test.test_uuid
+    #         test_crawl_id = test.test_uuid
+    #         test_md5hash = test.test_hash
+
+
+    #         cursor.execute(f"""
+    #             INSERT INTO html_content VALUES ('Test Content', '{test_md5hash}');
+    #             INSERT INTO crawl (id, url, title, lang, last_crawled, last_updated, last_updated_date, md5hash)
+    #             VALUES ('{test_chunk_id}', 'http://example.com', 'Test Title', 'en', NOW(), NOW(), NOW(), '{test_md5hash}');
+    #             INSERT INTO html_content_to_chunk VALUES ('{test_crawl_id}', '{test_md5hash}');
+    #             """
+    #         )
+    #         row = crawler.fetch_crawl_row(cursor, "http://example.com")
+    #         self.connection.rollback()
+    #     self.assertEqual(row['title'], "Test Title")
+
+    def test_fetch_chunk_token_row(self):
+        """Test fetching a chunk token row."""
+        with db.cursor(self.connection) as cursor:
+            row = crawler.fetch_chunk_token_row(cursor, "469812c5-190c-4e56-9f88-c8621592bcb5")
+            self.connection.rollback()
+        self.assertEqual(str(row['chunk_id']), "469812c5-190c-4e56-9f88-c8621592bcb5")
+
+    def test_fetch_crawl_row_with_invalid_url(self):
+        """Test fetching a crawl row with an invalid URL."""
+        with db.cursor(self.connection) as cursor:
+            with self.assertRaises(db.DBError):
+                crawler.fetch_crawl_row(cursor, "invalid_url")
diff --git a/tests/test_db_data.py b/tests/test_db_data.py
index cf04789..82a740c 100644
--- a/tests/test_db_data.py
+++ b/tests/test_db_data.py
@@ -1,6 +1,7 @@
 import unittest
 
 import ailab.db as db
+import testing_utils as test
 
 
 class TestDBData(unittest.TestCase):
@@ -20,14 +21,14 @@ def tearDown(self):
 
     def upgrade_schema(self):
         return
-        # if test.LOUIS_SCHEMA == 'louis_v005':
-        #     self.execute('sql/2023-07-11-hotfix-xml-not-well-formed.sql')
-        #     self.execute('sql/2023-07-11-populate-link.sql')
-        #     self.execute('sql/2023-07-12-score-current.sql')
-        #     self.execute('sql/2023-07-19-modify-score_type-add-similarity.sql')
-        #     self.execute('sql/2023-07-19-modified-documents.sql')
-        #     self.execute('sql/2023-07-19-weighted_search.sql')
-        #     self.execute('sql/2023-07-21-default_chunk.sql')
+        if test.LOUIS_SCHEMA == 'louis_v004':
+            self.execute('sql/2023-07-11-hotfix-xml-not-well-formed.sql')
+            self.execute('sql/2023-07-11-populate-link.sql')
+            self.execute('sql/2023-07-12-score-current.sql')
+            self.execute('sql/2023-07-19-modify-score_type-add-similarity.sql')
+            self.execute('sql/2023-07-19-modified-documents.sql')
+            self.execute('sql/2023-07-19-weighted_search.sql')
+            self.execute('sql/2023-07-21-default_chunk.sql')
 
     def test_well_formed_xml(self):
         self.upgrade_schema()
@@ -39,13 +40,13 @@ def test_well_formed_xml(self):
         result = self.cursor.fetchall()
         self.assertEqual(result[0]['count'], 0, "All xml should be well formed")
 
-    # def test_every_crawl_doc_should_have_at_least_one_chunk(self):
-    #     # self.execute('sql/2023-08-09-issue8-html_content-table.sql')
-    #     self.cursor.execute("""
-    #         select count(*)
-    #         from crawl left join documents on crawl.id = documents.id
-    #         where documents.id is null""")
-    #     result = self.cursor.fetchall()
-    #     self.assertEqual(
-    #         result[0]['count'], 0,
-    #         "Every crawl doc should have at least one chunk")
+    def test_every_crawl_doc_should_have_at_least_one_chunk(self):
+        # self.execute('sql/2023-08-09-issue8-html_content-table.sql')
+        self.cursor.execute("""
+            select count(*)
+            from crawl left join documents on crawl.id = documents.id
+            where documents.id is null""")
+        result = self.cursor.fetchall()
+        self.assertEqual(
+            result[0]['count'], 0,
+            "Every crawl doc should have at least one chunk")
diff --git a/tests/test_db_schema.py b/tests/test_db_schema.py
index cb5318b..3813f46 100644
--- a/tests/test_db_schema.py
+++ b/tests/test_db_schema.py
@@ -1,7 +1,8 @@
 """test database functions"""
 import unittest
 
-#import testing_utils as test
+import psycopg
+
 import ailab.db as db
 
 
@@ -13,6 +14,99 @@ def setUp(self):
     def tearDown(self):
         self.connection.close()
 
+    def test_crawl_exists(self):
+        """Test if crawl table exists in the database and is not empty."""
+        table_name = "crawl"
+        with db.cursor(self.connection) as cursor:
+            cursor.execute(f"""SELECT * FROM {table_name} LIMIT 1;""")
+            result = cursor.fetchone()
+        self.assertIsNotNone(result)
+
+    def test_crawl_has_correct_columns(self):
+        """Test if crawl table has the correct columns."""
+        table_name = "crawl"
+        expected_columns = ["id", "url", "title", "lang", "last_crawled",
+                            "last_updated", "last_updated_date", "md5hash"]
+        with db.cursor(self.connection) as cursor:
+            cursor.execute(f"""SELECT * FROM {table_name} LIMIT 0;""")
+            actual_columns = [desc[0] for desc in cursor.description]
+        self.assertCountEqual(actual_columns, expected_columns)
+
+    def test_chunk_exists(self):
+        """Test if chunk table exists in the database and is not empty."""
+        table_name = "chunk"
+        with db.cursor(self.connection) as cursor:
+            cursor.execute(f"""SELECT * FROM {table_name} LIMIT 1;""")
+            result = cursor.fetchone()
+        self.assertIsNotNone(result)
+
+    def test_chunk_has_correct_columns(self):
+        """Test if chunk table has the correct columns."""
+        table_name = "chunk"
+        expected_columns = ["id", "title", "text_content"]
+        with db.cursor(self.connection) as cursor:
+            cursor.execute(f"""SELECT * FROM {table_name} LIMIT 0;""")
+            actual_columns = [desc[0] for desc in cursor.description]
+        self.assertEqual(actual_columns, expected_columns)
+
+    def test_token_exists(self):
+        """Test if token table exists in the database and is not empty."""
+        table_name = "token"
+        with db.cursor(self.connection) as cursor:
+            cursor.execute(f"""SELECT * FROM {table_name} LIMIT 1;""")
+            result = cursor.fetchone()
+        self.assertIsNotNone(result)
+
+    def test_token_has_correct_columns(self):
+        """Test if token table has the correct columns."""
+        table_name = "token"
+        expected_columns = ["id", "chunk_id", "tokens", "encoding"]
+        with db.cursor(self.connection) as cursor:
+            cursor.execute(f"""SELECT * FROM {table_name} LIMIT 0;""")
+            actual_columns = [desc[0] for desc in cursor.description]
+        self.assertEqual(actual_columns, expected_columns)
+
+    def test_score_exists(self):
+        """Test if score table exists in the database and is not empty."""
+        table_name = "score"
+        with db.cursor(self.connection) as cursor:
+            cursor.execute(f"""SELECT * FROM {table_name} LIMIT 1;""")
+            result = cursor.fetchone()
+        self.assertIsNotNone(result)
+
+    def test_score_has_correct_columns(self):
+        """Test if score table has the correct columns."""
+        table_name = "score"
+        expected_columns = ["entity_id", "score", "score_type"]
+        with db.cursor(self.connection) as cursor:
cursor.execute(f"""SELECT * FROM {table_name} LIMIT 0;""") + actual_columns = [desc[0] for desc in cursor.description] + self.assertEqual(actual_columns, expected_columns) + + def test_html_content_exists(self): + """Test if html_content table exists in the database and is not empty.""" + table_name = "html_content" + with db.cursor(self.connection) as cursor: + cursor.execute(f"""SELECT * FROM {table_name} LIMIT 1;""") + result = cursor.fetchone() + self.assertIsNotNone(result) + + def test_html_content_has_correct_columns(self): + """Test if html_content table has the correct columns.""" + table_name = "html_content" + expected_columns = ["content", "md5hash"] + with db.cursor(self.connection) as cursor: + cursor.execute(f"""SELECT * FROM {table_name} LIMIT 0;""") + actual_columns = [desc[0] for desc in cursor.description] + self.assertEqual(actual_columns, expected_columns) + + def test_false_table_not_exists(self): + """Test if false_table table does NOT exists in the database.""" + table_name = "false_table" + with db.cursor(self.connection) as cursor: + with self.assertRaises(psycopg.errors.UndefinedTable): + cursor.execute(f"""SELECT * FROM {table_name} LIMIT 1;""") + # def test_schema(self): # """sample test to check if the schema is correct and idempotent""" # schema_filename = f"dumps/{test.LOUIS_SCHEMA}/schema.sql" @@ -22,3 +116,14 @@ def tearDown(self): # with db.cursor(self.connection) as cursor: # cursor.execute(schema) # self.connection.rollback() + + # def test_schema_exist(self): + # """sample test to check if the schema exists""" + # with db.cursor(self.connection) as cursor: + # cursor.execute( + # "SELECT EXISTS(SELECT * FROM )", + # (test.LOUIS_SCHEMA,) + # ) + # self.connection.rollback() + # row = cursor.fetchone() + # self.assertTrue(row[0]) diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 9e560cb..2348db2 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -1,4 +1,7 @@ import os +import random +import uuid +import hashlib import dotenv dotenv.load_dotenv() @@ -19,3 +22,16 @@ def raise_error(message): unique(token_id) ); """ + +# Generate a random UUID +test_uuid = uuid.uuid4() +test_item = { + "id": test_uuid, + "title": "Title exemple", + "text_content": "This is an example content.", + } + +test_hash = hashlib.md5("test".encode()).hexdigest()[:31] + +def generate_random_embedding(dimensions=100): + return [random.uniform(0, 100000) for _ in range(dimensions)]