diff --git a/airflow/.dockerignore b/airflow/.dockerignore index c8f3b275..fa8e599f 100644 --- a/airflow/.dockerignore +++ b/airflow/.dockerignore @@ -3,4 +3,3 @@ astro .env airflow_settings.yaml logs/ -dags/ diff --git a/airflow/Dockerfile b/airflow/Dockerfile index 39183c3b..a67c6702 100644 --- a/airflow/Dockerfile +++ b/airflow/Dockerfile @@ -1,5 +1 @@ -# syntax=quay.io/astronomer/airflow-extensions:latest - -FROM quay.io/astronomer/astro-runtime:9.5.0-base - -COPY include/airflow_provider_weaviate-0.0.1-py3-none-any.whl /tmp +FROM quay.io/astronomer/astro-runtime:9.5.0 diff --git a/airflow/dags/ingestion/ask-astro-load-airflow-docs.py b/airflow/dags/ingestion/ask-astro-load-airflow-docs.py new file mode 100644 index 00000000..c0f75851 --- /dev/null +++ b/airflow/dags/ingestion/ask-astro-load-airflow-docs.py @@ -0,0 +1,55 @@ +import os +from datetime import datetime + +from include.tasks import split +from include.tasks.extract import airflow_docs +from include.tasks.extract.utils.weaviate.ask_astro_weaviate_hook import AskAstroWeaviateHook + +from airflow.decorators import dag, task + +ask_astro_env = os.environ.get("ASK_ASTRO_ENV", "dev") + +_WEAVIATE_CONN_ID = f"weaviate_{ask_astro_env}" +WEAVIATE_CLASS = os.environ.get("WEAVIATE_CLASS", "DocsDev") + +ask_astro_weaviate_hook = AskAstroWeaviateHook(_WEAVIATE_CONN_ID) + +airflow_docs_base_url = "https://airflow.apache.org/docs/" + +default_args = {"retries": 3, "retry_delay": 30} + +schedule_interval = "0 5 * * *" if ask_astro_env == "prod" else None + + +@dag( + schedule_interval=schedule_interval, + start_date=datetime(2023, 9, 27), + catchup=False, + is_paused_upon_creation=True, + default_args=default_args, +) +def ask_astro_load_airflow_docs(): + """ + This DAG performs incremental load for any new Airflow docs. Initial load via ask_astro_load_bulk imported + data from a point-in-time data capture. 
By using the upsert logic of the weaviate_import decorator + any existing documents that have been updated will be removed and re-added. + """ + + extracted_airflow_docs = task(airflow_docs.extract_airflow_docs)(docs_base_url=airflow_docs_base_url) + + split_md_docs = task(split.split_html).expand(dfs=[extracted_airflow_docs]) + + _import_data = ( + task(ask_astro_weaviate_hook.ingest_data, retries=10) + .partial( + class_name=WEAVIATE_CLASS, + existing="upsert", + doc_key="docLink", + batch_params={"batch_size": 1000}, + verbose=True, + ) + .expand(dfs=[split_md_docs]) + ) + + +ask_astro_load_airflow_docs() diff --git a/airflow/dags/ingestion/ask-astro-load-blogs.py b/airflow/dags/ingestion/ask-astro-load-blogs.py index 239aad8e..314fef3e 100644 --- a/airflow/dags/ingestion/ask-astro-load-blogs.py +++ b/airflow/dags/ingestion/ask-astro-load-blogs.py @@ -1,37 +1,54 @@ +import datetime import os -from datetime import datetime -from include.tasks import ingest, split +from include.tasks import split from include.tasks.extract import blogs +from include.tasks.extract.utils.weaviate.ask_astro_weaviate_hook import AskAstroWeaviateHook from airflow.decorators import dag, task -ask_astro_env = os.environ.get("ASK_ASTRO_ENV", "") +ask_astro_env = os.environ.get("ASK_ASTRO_ENV", "dev") _WEAVIATE_CONN_ID = f"weaviate_{ask_astro_env}" -WEAVIATE_CLASS = os.environ.get("WEAVIATE_CLASS", "DocsProd") +WEAVIATE_CLASS = os.environ.get("WEAVIATE_CLASS", "DocsDev") +ask_astro_weaviate_hook = AskAstroWeaviateHook(_WEAVIATE_CONN_ID) -blog_cutoff_date = datetime.strptime("2023-01-19", "%Y-%m-%d") +blog_cutoff_date = datetime.date(2023, 1, 19) +default_args = {"retries": 3, "retry_delay": 30} -@dag(schedule_interval="0 5 * * *", start_date=datetime(2023, 9, 27), catchup=False, is_paused_upon_creation=True) +schedule_interval = "0 5 * * *" if ask_astro_env == "prod" else None + + +@dag( + schedule_interval=schedule_interval, + start_date=datetime.datetime(2023, 9, 27), + catchup=False, + 
is_paused_upon_creation=True, + default_args=default_args, +) def ask_astro_load_blogs(): """ - This DAG performs incremental load for any new docs. Initial load via ask_astro_load_bulk imported - data from a point-in-time data capture. By using the upsert logic of the weaviate_import decorator + This DAG performs incremental load for any new docs. Initial load via ask_astro_load_bulk imported + data from a point-in-time data capture. By using the upsert logic of the weaviate_import decorator any existing documents that have been updated will be removed and re-added. """ - blogs_docs = task(blogs.extract_astro_blogs, retries=3)(blog_cutoff_date=blog_cutoff_date) + blogs_docs = task(blogs.extract_astro_blogs)(blog_cutoff_date=blog_cutoff_date) split_md_docs = task(split.split_markdown).expand(dfs=[blogs_docs]) - task.weaviate_import( - ingest.import_upsert_data, - weaviate_conn_id=_WEAVIATE_CONN_ID, - retries=10, - retry_delay=30, - ).partial(class_name=WEAVIATE_CLASS, primary_key="docLink").expand(dfs=[split_md_docs]) + _import_data = ( + task(ask_astro_weaviate_hook.ingest_data, retries=10) + .partial( + class_name=WEAVIATE_CLASS, + existing="upsert", + doc_key="docLink", + batch_params={"batch_size": 1000}, + verbose=True, + ) + .expand(dfs=[split_md_docs]) + ) ask_astro_load_blogs() diff --git a/airflow/dags/ingestion/ask-astro-load-github.py b/airflow/dags/ingestion/ask-astro-load-github.py index 4dbad5a7..8d19741f 100644 --- a/airflow/dags/ingestion/ask-astro-load-github.py +++ b/airflow/dags/ingestion/ask-astro-load-github.py @@ -1,25 +1,26 @@ +import datetime import os -from datetime import datetime -from include.tasks import ingest, split +from include.tasks import split from include.tasks.extract import github +from include.tasks.extract.utils.weaviate.ask_astro_weaviate_hook import AskAstroWeaviateHook from airflow.decorators import dag, task -ask_astro_env = os.environ.get("ASK_ASTRO_ENV", "") +ask_astro_env = os.environ.get("ASK_ASTRO_ENV", "dev") 
_WEAVIATE_CONN_ID = f"weaviate_{ask_astro_env}" _GITHUB_CONN_ID = "github_ro" -WEAVIATE_CLASS = os.environ.get("WEAVIATE_CLASS", "DocsProd") +WEAVIATE_CLASS = os.environ.get("WEAVIATE_CLASS", "DocsDev") + +ask_astro_weaviate_hook = AskAstroWeaviateHook(_WEAVIATE_CONN_ID) + markdown_docs_sources = [ {"doc_dir": "learn", "repo_base": "astronomer/docs"}, {"doc_dir": "astro", "repo_base": "astronomer/docs"}, {"doc_dir": "", "repo_base": "OpenLineage/docs"}, {"doc_dir": "", "repo_base": "OpenLineage/OpenLineage"}, ] -rst_docs_sources = [ - {"doc_dir": "docs", "repo_base": "apache/airflow", "exclude_docs": ["changelog.rst", "commits.rst"]}, -] code_samples_sources = [ {"doc_dir": "code-samples", "repo_base": "astronomer/docs"}, ] @@ -27,53 +28,54 @@ "apache/airflow", ] +default_args = {"retries": 3, "retry_delay": 30} + +schedule_interval = "0 5 * * *" if ask_astro_env == "prod" else None -@dag(schedule_interval="0 5 * * *", start_date=datetime(2023, 9, 27), catchup=False, is_paused_upon_creation=True) + +@dag( + schedule_interval=schedule_interval, + start_date=datetime.datetime(2023, 9, 27), + catchup=False, + is_paused_upon_creation=True, + default_args=default_args, +) def ask_astro_load_github(): """ - This DAG performs incremental load for any new docs. Initial load via ask_astro_load_bulk imported - data from a point-in-time data capture. By using the upsert logic of the weaviate_import decorator + This DAG performs incremental load for any new docs. Initial load via ask_astro_load_bulk imported + data from a point-in-time data capture. By using the upsert logic of the weaviate_import decorator any existing documents that have been updated will be removed and re-added. 
""" md_docs = ( - task(github.extract_github_markdown, retries=3) + task(github.extract_github_markdown) .partial(github_conn_id=_GITHUB_CONN_ID) .expand(source=markdown_docs_sources) ) - rst_docs = ( - task(github.extract_github_rst, retries=3) - .partial(github_conn_id=_GITHUB_CONN_ID) - .expand(source=rst_docs_sources) - ) - issues_docs = ( - task(github.extract_github_issues, retries=3) - .partial(github_conn_id=_GITHUB_CONN_ID) - .expand(repo_base=issues_docs_sources) + task(github.extract_github_issues).partial(github_conn_id=_GITHUB_CONN_ID).expand(repo_base=issues_docs_sources) ) code_samples = ( - task(github.extract_github_python, retries=3) - .partial(github_conn_id=_GITHUB_CONN_ID) - .expand(source=code_samples_sources) + task(github.extract_github_python).partial(github_conn_id=_GITHUB_CONN_ID).expand(source=code_samples_sources) ) - markdown_tasks = [md_docs, rst_docs, issues_docs] - - split_md_docs = task(split.split_markdown).expand(dfs=markdown_tasks) + split_md_docs = task(split.split_markdown).expand(dfs=[md_docs, issues_docs]) split_code_docs = task(split.split_python).expand(dfs=[code_samples]) - task.weaviate_import( - ingest.import_upsert_data, - weaviate_conn_id=_WEAVIATE_CONN_ID, - retries=10, - retry_delay=30, - ).partial(class_name=WEAVIATE_CLASS, primary_key="docLink").expand(dfs=[split_md_docs, split_code_docs]) - - issues_docs >> md_docs >> rst_docs >> code_samples + _import_data = ( + task(ask_astro_weaviate_hook.ingest_data, retries=10) + .partial( + class_name=WEAVIATE_CLASS, + existing="upsert", + doc_key="docLink", + batch_params={"batch_size": 1000}, + verbose=True, + ) + .expand(dfs=[split_md_docs, split_code_docs]) + ) ask_astro_load_github() diff --git a/airflow/dags/ingestion/ask-astro-load-registry.py b/airflow/dags/ingestion/ask-astro-load-registry.py index 08b617c9..c73b59c9 100644 --- a/airflow/dags/ingestion/ask-astro-load-registry.py +++ b/airflow/dags/ingestion/ask-astro-load-registry.py @@ -1,39 +1,57 @@ import os 
from datetime import datetime -from include.tasks import ingest, split +from include.tasks import split from include.tasks.extract import registry +from include.tasks.extract.utils.weaviate.ask_astro_weaviate_hook import AskAstroWeaviateHook from airflow.decorators import dag, task -ask_astro_env = os.environ.get("ASK_ASTRO_ENV", "") +ask_astro_env = os.environ.get("ASK_ASTRO_ENV", "dev") _WEAVIATE_CONN_ID = f"weaviate_{ask_astro_env}" -WEAVIATE_CLASS = os.environ.get("WEAVIATE_CLASS", "DocsProd") +WEAVIATE_CLASS = os.environ.get("WEAVIATE_CLASS", "DocsDev") +ask_astro_weaviate_hook = AskAstroWeaviateHook(_WEAVIATE_CONN_ID) -@dag(schedule_interval="0 5 * * *", start_date=datetime(2023, 9, 27), catchup=False, is_paused_upon_creation=True) +default_args = {"retries": 3, "retry_delay": 30} + +schedule_interval = "0 5 * * *" if ask_astro_env == "prod" else None + + +@dag( + schedule_interval=schedule_interval, + start_date=datetime(2023, 9, 27), + catchup=False, + is_paused_upon_creation=True, + default_args=default_args, +) def ask_astro_load_registry(): """ - This DAG performs incremental load for any new docs. Initial load via ask_astro_load_bulk imported - data from a point-in-time data capture. By using the upsert logic of the weaviate_import decorator + This DAG performs incremental load for any new docs. Initial load via ask_astro_load_bulk imported + data from a point-in-time data capture. By using the upsert logic of the weaviate_import decorator any existing documents that have been updated will be removed and re-added. 
""" - registry_cells_docs = task(registry.extract_astro_registry_cell_types, retries=3)() + registry_cells_docs = task(registry.extract_astro_registry_cell_types)() - registry_dags_docs = task(registry.extract_astro_registry_dags, retries=3)() + registry_dags_docs = task(registry.extract_astro_registry_dags)() split_md_docs = task(split.split_markdown).expand(dfs=[registry_cells_docs]) split_code_docs = task(split.split_python).expand(dfs=[registry_dags_docs]) - task.weaviate_import( - ingest.import_upsert_data, - weaviate_conn_id=_WEAVIATE_CONN_ID, - retries=10, - retry_delay=30, - ).partial(class_name=WEAVIATE_CLASS, primary_key="docLink").expand(dfs=[split_md_docs, split_code_docs]) + _import_data = ( + task(ask_astro_weaviate_hook.ingest_data, retries=10) + .partial( + class_name=WEAVIATE_CLASS, + existing="upsert", + doc_key="docLink", + batch_params={"batch_size": 1000}, + verbose=True, + ) + .expand(dfs=[split_md_docs, split_code_docs]) + ) ask_astro_load_registry() diff --git a/airflow/dags/ingestion/ask-astro-load-slack.py b/airflow/dags/ingestion/ask-astro-load-slack.py index d8804c5c..0066794e 100644 --- a/airflow/dags/ingestion/ask-astro-load-slack.py +++ b/airflow/dags/ingestion/ask-astro-load-slack.py @@ -1,15 +1,17 @@ import os from datetime import datetime -from include.tasks import ingest, split +from include.tasks import split from include.tasks.extract import slack +from include.tasks.extract.utils.weaviate.ask_astro_weaviate_hook import AskAstroWeaviateHook from airflow.decorators import dag, task -ask_astro_env = os.environ.get("ASK_ASTRO_ENV", "") +ask_astro_env = os.environ.get("ASK_ASTRO_ENV", "dev") _WEAVIATE_CONN_ID = f"weaviate_{ask_astro_env}" -WEAVIATE_CLASS = os.environ.get("WEAVIATE_CLASS", "DocsProd") +WEAVIATE_CLASS = os.environ.get("WEAVIATE_CLASS", "DocsDev") +ask_astro_weaviate_hook = AskAstroWeaviateHook(_WEAVIATE_CONN_ID) slack_channel_sources = [ { "channel_name": "troubleshooting", @@ -20,25 +22,40 @@ } ] +default_args = 
{"retries": 3, "retry_delay": 30} -@dag(schedule_interval="0 5 * * *", start_date=datetime(2023, 9, 27), catchup=False, is_paused_upon_creation=True) +schedule_interval = "0 5 * * *" if ask_astro_env == "prod" else None + + +@dag( + schedule_interval=schedule_interval, + start_date=datetime(2023, 9, 27), + catchup=False, + is_paused_upon_creation=True, + default_args=default_args, +) def ask_astro_load_slack(): """ - This DAG performs incremental load for any new slack threads. The slack archive is a point-in-time capture. This - DAG should run nightly to capture threads between archive periods. By using the upsert logic of the + This DAG performs incremental load for any new slack threads. The slack archive is a point-in-time capture. This + DAG should run nightly to capture threads between archive periods. By using the upsert logic of the weaviate_import decorator any existing documents that have been updated will be removed and re-added. """ - slack_docs = task(slack.extract_slack, retries=3).expand(source=slack_channel_sources) + slack_docs = task(slack.extract_slack).expand(source=slack_channel_sources) split_md_docs = task(split.split_markdown).expand(dfs=[slack_docs]) - task.weaviate_import( - ingest.import_upsert_data, - weaviate_conn_id=_WEAVIATE_CONN_ID, - retries=10, - retry_delay=30, - ).partial(class_name=WEAVIATE_CLASS, primary_key="docLink").expand(dfs=[split_md_docs]) + _import_data = ( + task(ask_astro_weaviate_hook.ingest_data, retries=10) + .partial( + class_name=WEAVIATE_CLASS, + existing="upsert", + doc_key="docLink", + batch_params={"batch_size": 1000}, + verbose=True, + ) + .expand(dfs=[split_md_docs]) + ) ask_astro_load_slack() diff --git a/airflow/dags/ingestion/ask-astro-load-stackoverflow.py b/airflow/dags/ingestion/ask-astro-load-stackoverflow.py index 28e43993..66f24a33 100644 --- a/airflow/dags/ingestion/ask-astro-load-stackoverflow.py +++ b/airflow/dags/ingestion/ask-astro-load-stackoverflow.py @@ -1,44 +1,61 @@ import os from 
datetime import datetime -from include.tasks import ingest, split +from include.tasks import split from include.tasks.extract import stack_overflow +from include.tasks.extract.utils.weaviate.ask_astro_weaviate_hook import AskAstroWeaviateHook from airflow.decorators import dag, task -ask_astro_env = os.environ.get("ASK_ASTRO_ENV", "") +ask_astro_env = os.environ.get("ASK_ASTRO_ENV", "dev") _WEAVIATE_CONN_ID = f"weaviate_{ask_astro_env}" -WEAVIATE_CLASS = os.environ.get("WEAVIATE_CLASS", "DocsProd") +WEAVIATE_CLASS = os.environ.get("WEAVIATE_CLASS", "DocsDev") +ask_astro_weaviate_hook = AskAstroWeaviateHook(_WEAVIATE_CONN_ID) stackoverflow_cutoff_date = "2021-09-01" stackoverflow_tags = [ "airflow", ] +default_args = {"retries": 3, "retry_delay": 30} -@dag(schedule_interval=None, start_date=datetime(2023, 9, 27), catchup=False, is_paused_upon_creation=True) +schedule_interval = "0 5 * * *" if ask_astro_env == "prod" else None + + +@dag( + schedule_interval=schedule_interval, + start_date=datetime(2023, 9, 27), + catchup=False, + is_paused_upon_creation=True, + default_args=default_args, +) def ask_astro_load_stackoverflow(): """ - This DAG performs incremental load for any new docs. Initial load via ask_astro_load_bulk imported - data from a point-in-time data capture. By using the upsert logic of the weaviate_import decorator + This DAG performs incremental load for any new docs. Initial load via ask_astro_load_bulk imported + data from a point-in-time data capture. By using the upsert logic of the weaviate_import decorator any existing documents that have been updated will be removed and re-added. 
""" stack_overflow_docs = ( - task(stack_overflow.extract_stack_overflow_archive, retries=3) + task(stack_overflow.extract_stack_overflow_archive) .partial(stackoverflow_cutoff_date=stackoverflow_cutoff_date) .expand(tag=stackoverflow_tags) ) split_md_docs = task(split.split_markdown).expand(dfs=[stack_overflow_docs]) - task.weaviate_import( - ingest.import_upsert_data, - weaviate_conn_id=_WEAVIATE_CONN_ID, - retries=10, - retry_delay=30, - ).partial(class_name=WEAVIATE_CLASS, primary_key="docLink").expand(dfs=[split_md_docs]) + _import_data = ( + task(ask_astro_weaviate_hook.ingest_data, retries=10) + .partial( + class_name=WEAVIATE_CLASS, + existing="upsert", + doc_key="docLink", + batch_params={"batch_size": 1000}, + verbose=True, + ) + .expand(dfs=[split_md_docs]) + ) ask_astro_load_stackoverflow() diff --git a/airflow/dags/ingestion/ask-astro-load.py b/airflow/dags/ingestion/ask-astro-load.py index 5fe5f33b..0cdb554c 100644 --- a/airflow/dags/ingestion/ask-astro-load.py +++ b/airflow/dags/ingestion/ask-astro-load.py @@ -1,21 +1,27 @@ +from __future__ import annotations + +import datetime +import json +import logging import os -from datetime import datetime -from textwrap import dedent +from pathlib import Path import pandas as pd -from include.tasks import ingest, split -from include.tasks.extract import blogs, github, registry, stack_overflow -from weaviate_provider.operators.weaviate import WeaviateCheckSchemaBranchOperator, WeaviateCreateSchemaOperator +from include.tasks import split +from include.tasks.extract import airflow_docs, blogs, github, registry, stack_overflow +from include.tasks.extract.utils.weaviate.ask_astro_weaviate_hook import AskAstroWeaviateHook from airflow.decorators import dag, task seed_baseline_url = None - -ask_astro_env = os.environ.get("ASK_ASTRO_ENV", "") +stackoverflow_cutoff_date = "2021-09-01" +ask_astro_env = os.environ.get("ASK_ASTRO_ENV", "dev") _WEAVIATE_CONN_ID = f"weaviate_{ask_astro_env}" _GITHUB_CONN_ID = "github_ro" 
-WEAVIATE_CLASS = os.environ.get("WEAVIATE_CLASS", "DocsProd") +WEAVIATE_CLASS = os.environ.get("WEAVIATE_CLASS", "DocsDev") + +ask_astro_weaviate_hook = AskAstroWeaviateHook(_WEAVIATE_CONN_ID) markdown_docs_sources = [ {"doc_dir": "learn", "repo_base": "astronomer/docs"}, @@ -23,9 +29,6 @@ {"doc_dir": "", "repo_base": "OpenLineage/docs"}, {"doc_dir": "", "repo_base": "OpenLineage/OpenLineage"}, ] -rst_docs_sources = [ - {"doc_dir": "docs", "repo_base": "apache/airflow", "exclude_docs": ["changelog.rst", "commits.rst"]}, -] code_samples_sources = [ {"doc_dir": "code-samples", "repo_base": "astronomer/docs"}, ] @@ -42,53 +45,90 @@ } ] -blog_cutoff_date = datetime.strptime("2023-01-19", "%Y-%m-%d") +blog_cutoff_date = datetime.date(2023, 1, 19) -stackoverflow_cutoff_date = "2021-09-01" -stackoverflow_tags = [ - "airflow", -] +stackoverflow_tags = [{"airflow": "2021-09-01"}] -schedule_interval = "@daily" if ask_astro_env == "prod" else None +airflow_docs_base_url = "https://airflow.apache.org/docs/" +default_args = {"retries": 3, "retry_delay": 30} +logger = logging.getLogger("airflow.task") -@dag(schedule_interval=schedule_interval, start_date=datetime(2023, 9, 27), catchup=False, is_paused_upon_creation=True) + +@dag( + schedule_interval=None, + start_date=datetime.datetime(2023, 9, 27), + catchup=False, + is_paused_upon_creation=True, + default_args=default_args, +) def ask_astro_load_bulk(): """ This DAG performs the initial load of data from sources. If seed_baseline_url (set above) points to a parquet file with pre-embedded data it will be - ingested. Otherwise new data is extracted, split, embedded and ingested. + ingested. Otherwise, new data is extracted, split, embedded and ingested. - The first time this DAG runs (without seeded baseline) it will take at lease 20 minutes to + The first time this DAG runs (without seeded baseline) it will take at least 90 minutes to extract data from all sources. 
Extracted data is then serialized to disk in the project directory in order to simplify later iterations of ingest with different chunking strategies, vector databases or embedding models. """ - _check_schema = WeaviateCheckSchemaBranchOperator( - task_id="check_schema", - weaviate_conn_id=_WEAVIATE_CONN_ID, - class_object_data="file://include/data/schema.json", - follow_task_ids_if_true=["check_seed_baseline"], - follow_task_ids_if_false=["create_schema"], - doc_md=dedent( - """ - As the Weaviate schema may change over time this task checks if the most - recent schema is in place before ingesting.""" - ), - ) + @task + def get_schema_and_process(schema_file: str) -> list: + """ + Retrieves and processes the schema from a given JSON file. - _create_schema = WeaviateCreateSchemaOperator( - task_id="create_schema", - weaviate_conn_id=_WEAVIATE_CONN_ID, - class_object_data="file://include/data/schema.json", - existing="ignore", - ) + :param schema_file: path to the schema JSON file + """ + try: + class_objects = json.loads(Path(schema_file).read_text()) + except FileNotFoundError: + logger.error(f"Schema file {schema_file} not found.") + raise + except json.JSONDecodeError: + logger.error(f"Invalid JSON in the schema file {schema_file}.") + raise + + class_objects["classes"][0].update({"class": WEAVIATE_CLASS}) + + if "classes" not in class_objects: + class_objects = [class_objects] + else: + class_objects = class_objects["classes"] + + logger.info("Schema processing completed.") + return class_objects + + @task.branch + def check_schema(class_objects: list) -> list[str]: + """ + Check if the current schema includes the requested schema. The current schema could be a superset + so check_schema_subset is used recursively to check that all objects in the requested schema are + represented in the current schema. + + :param class_objects: Class objects to be checked against the current schema. 
+ """ + return ( + ["check_seed_baseline"] + if ask_astro_weaviate_hook.check_schema(class_objects=class_objects) + else ["create_schema"] + ) + + @task(trigger_rule="none_failed") + def create_schema(class_objects: list, existing: str = "ignore") -> None: + """ + Creates or updates the schema in Weaviate based on the given class objects. + + :param class_objects: A list of class objects for schema creation or update. + :param existing: Strategy to handle existing classes ('ignore' or 'replace'). Defaults to 'ignore'. + """ + ask_astro_weaviate_hook.create_schema(class_objects=class_objects, existing=existing) @task.branch(trigger_rule="none_failed") - def check_seed_baseline() -> str: + def check_seed_baseline(seed_baseline_url: str = None) -> str | set: """ Check if we will ingest from pre-embedded baseline or extract each source. """ @@ -96,19 +136,18 @@ def check_seed_baseline() -> str: if seed_baseline_url is not None: return "import_baseline" else: - return [ + return { "extract_github_markdown", - "extract_github_rst", + "extract_airflow_docs", "extract_stack_overflow", - # "extract_slack_archive", "extract_astro_registry_cell_types", "extract_github_issues", "extract_astro_blogs", "extract_github_python", "extract_astro_registry_dags", - ] + } - @task(trigger_rule="none_skipped") + @task(trigger_rule="none_failed") def extract_github_markdown(source: dict): try: df = pd.read_parquet(f"include/data/{source['repo_base']}/{source['doc_dir']}.parquet") @@ -118,28 +157,28 @@ def extract_github_markdown(source: dict): return df - @task(trigger_rule="none_skipped") - def extract_github_rst(source: dict): + @task(trigger_rule="none_failed") + def extract_github_python(source: dict): try: df = pd.read_parquet(f"include/data/{source['repo_base']}/{source['doc_dir']}.parquet") except Exception: - df = github.extract_github_rst(source=source, github_conn_id=_GITHUB_CONN_ID) + df = github.extract_github_python(source, _GITHUB_CONN_ID) 
df.to_parquet(f"include/data/{source['repo_base']}/{source['doc_dir']}.parquet") return df @task(trigger_rule="none_failed") - def extract_github_python(source: dict): + def extract_airflow_docs(): try: - df = pd.read_parquet(f"include/data/{source['repo_base']}/{source['doc_dir']}.parquet") + df = pd.read_parquet("include/data/apache/airflow/docs.parquet") except Exception: - df = github.extract_github_python(source, _GITHUB_CONN_ID) - df.to_parquet(f"include/data/{source['repo_base']}/{source['doc_dir']}.parquet") + df = airflow_docs.extract_airflow_docs(docs_base_url=airflow_docs_base_url)[0] + df.to_parquet("include/data/apache/airflow/docs.parquet") - return df + return [df] @task(trigger_rule="none_failed") - def extract_stack_overflow(tag: str, stackoverflow_cutoff_date: str): + def extract_stack_overflow(tag: str, stackoverflow_cutoff_date: str = stackoverflow_cutoff_date): try: df = pd.read_parquet("include/data/stack_overflow/base.parquet") except Exception: @@ -150,16 +189,6 @@ def extract_stack_overflow(tag: str, stackoverflow_cutoff_date: str): return df - # @task(trigger_rule="none_failed") - # def extract_slack_archive(source: dict): - # try: - # df = pd.read_parquet("include/data/slack/troubleshooting.parquet") - # except Exception: - # df = slack.extract_slack_archive(source) - # df.to_parquet("include/data/slack/troubleshooting.parquet") - # - # return df - @task(trigger_rule="none_failed") def extract_github_issues(repo_base: str): try: @@ -200,66 +229,66 @@ def extract_astro_blogs(): return [df] - _check_seed_baseline = check_seed_baseline() - md_docs = extract_github_markdown.expand(source=markdown_docs_sources) - - rst_docs = extract_github_rst.expand(source=rst_docs_sources) - issues_docs = extract_github_issues.expand(repo_base=issues_docs_sources) - - stackoverflow_docs = extract_stack_overflow.partial(stackoverflow_cutoff_date=stackoverflow_cutoff_date).expand( - tag=stackoverflow_tags - ) - - # slack_docs = 
extract_slack_archive.expand(source=slack_channel_sources) - + stackoverflow_docs = extract_stack_overflow.expand(tag=stackoverflow_tags) registry_cells_docs = extract_astro_registry_cell_types() - blogs_docs = extract_astro_blogs() - registry_dags_docs = extract_astro_registry_dags() + code_samples = extract_github_python.expand(source=code_samples_sources) + _airflow_docs = extract_airflow_docs() - code_samples = extract_github_python.partial().expand(source=code_samples_sources) + _get_schema = get_schema_and_process(schema_file="include/data/schema.json") + _check_schema = check_schema(class_objects=_get_schema) + _create_schema = create_schema(class_objects=_get_schema) + _check_seed_baseline = check_seed_baseline(seed_baseline_url=seed_baseline_url) markdown_tasks = [ md_docs, - rst_docs, issues_docs, stackoverflow_docs, - # slack_docs, blogs_docs, registry_cells_docs, ] + html_tasks = [_airflow_docs] + python_code_tasks = [registry_dags_docs, code_samples] split_md_docs = task(split.split_markdown).expand(dfs=markdown_tasks) split_code_docs = task(split.split_python).expand(dfs=python_code_tasks) - task.weaviate_import(ingest.import_data, weaviate_conn_id=_WEAVIATE_CONN_ID, retries=10, retry_delay=30).partial( - class_name=WEAVIATE_CLASS - ).expand(dfs=[split_md_docs, split_code_docs]) + split_html_docs = task(split.split_html).expand(dfs=html_tasks) + + _import_data = ( + task(ask_astro_weaviate_hook.ingest_data, retries=10) + .partial( + class_name=WEAVIATE_CLASS, + existing="upsert", + doc_key="docLink", + batch_params={"batch_size": 1000}, + verbose=True, + ) + .expand(dfs=[split_md_docs, split_code_docs, split_html_docs]) + ) - _import_baseline = task.weaviate_import( - ingest.import_baseline, trigger_rule="none_failed", weaviate_conn_id=_WEAVIATE_CONN_ID - )(class_name=WEAVIATE_CLASS, seed_baseline_url=seed_baseline_url) + _import_baseline = task(ask_astro_weaviate_hook.import_baseline, trigger_rule="none_failed")( + 
seed_baseline_url=seed_baseline_url, + class_name=WEAVIATE_CLASS, + existing="upsert", + doc_key="docLink", + uuid_column="id", + vector_column="vector", + batch_params={"batch_size": 1000}, + verbose=True, + ) _check_schema >> [_check_seed_baseline, _create_schema] - _create_schema >> markdown_tasks + python_code_tasks + [_check_seed_baseline] - - _check_seed_baseline >> issues_docs >> rst_docs >> md_docs - # ( - # _check_seed_baseline - # >> [stackoverflow_docs, slack_docs, blogs_docs, registry_cells_docs, _import_baseline] + python_code_tasks - # ) + _create_schema >> markdown_tasks + python_code_tasks + html_tasks + [_check_seed_baseline] - ( - _check_seed_baseline - >> [stackoverflow_docs, blogs_docs, registry_cells_docs, _import_baseline] + python_code_tasks - ) + _check_seed_baseline >> markdown_tasks + python_code_tasks + html_tasks + [_import_baseline] ask_astro_load_bulk() diff --git a/airflow/dags/monitor/monitor.py b/airflow/dags/monitor/monitor.py index c38cd631..0cb7b2c0 100644 --- a/airflow/dags/monitor/monitor.py +++ b/airflow/dags/monitor/monitor.py @@ -9,12 +9,12 @@ import firebase_admin import requests -from weaviate_provider.hooks.weaviate import WeaviateHook from airflow.decorators import dag, task from airflow.exceptions import AirflowException from airflow.models import TaskInstance from airflow.providers.slack.operators.slack_webhook import SlackWebhookOperator +from airflow.providers.weaviate.hooks.weaviate import WeaviateHook from airflow.utils.context import Context from airflow.utils.trigger_rule import TriggerRule @@ -204,7 +204,7 @@ def check_weaviate_status(**context) -> None: """ try: weaviate_hook = WeaviateHook(weaviate_conn_id) - client = weaviate_hook.get_conn() + client = weaviate_hook.get_client() schemas = client.query.aggregate(weaviate_class).with_meta_count().do() schema = schemas["data"]["Aggregate"][weaviate_class] count = 0 diff --git a/airflow/include/__init__.py b/airflow/include/__init__.py new file mode 100644 index 
00000000..e69de29b diff --git a/airflow/include/airflow_provider_weaviate-0.0.1-py3-none-any.whl b/airflow/include/airflow_provider_weaviate-0.0.1-py3-none-any.whl deleted file mode 100644 index ae2d05b6..00000000 Binary files a/airflow/include/airflow_provider_weaviate-0.0.1-py3-none-any.whl and /dev/null differ diff --git a/airflow/include/data/apache/apache_license.rst b/airflow/include/data/apache/apache_license.rst deleted file mode 100644 index 106592bd..00000000 --- a/airflow/include/data/apache/apache_license.rst +++ /dev/null @@ -1,16 +0,0 @@ - .. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - - .. http://www.apache.org/licenses/LICENSE-2.0 - - .. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. 
diff --git a/airflow/include/data/schema.json b/airflow/include/data/schema.json index a10e6233..d49b9a6f 100644 --- a/airflow/include/data/schema.json +++ b/airflow/include/data/schema.json @@ -54,6 +54,7 @@ "name": "content", "description": "Document content", "dataType": ["text"], + "tokenization": "word", "moduleConfig": { "text2vec-openai": { "skip": "False", diff --git a/airflow/include/tasks/__init__.py b/airflow/include/tasks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/airflow/include/tasks/extract/__init__.py b/airflow/include/tasks/extract/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/airflow/include/tasks/extract/airflow_docs.py b/airflow/include/tasks/extract/airflow_docs.py new file mode 100644 index 00000000..5c0fbd98 --- /dev/null +++ b/airflow/include/tasks/extract/airflow_docs.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +import re +import urllib.parse + +import pandas as pd +import requests +from bs4 import BeautifulSoup +from weaviate.util import generate_uuid5 + +from include.tasks.extract.utils.html_helpers import get_all_links + + +def extract_airflow_docs(docs_base_url: str) -> list[pd.DataFrame]: + """ + This task scrapes docs from the Airflow website and returns a list of pandas dataframes. Return + type is a list in order to map to upstream dynamic tasks. The code recursively generates a list + of html files relative to 'docs_base_url' and then extracts each as text. + + Note: Only the (class_: body, role: main) tag and children are extracted. + + Note: This code will also pick up https://airflow.apache.org/howto/* + which are also referenced in the docs pages. These are useful for Ask Astro and also they are relatively few + pages so we leave them in. + + param docs_base_url: Base URL to start extract. 
+ type docs_base_url: str + + The returned data includes the following fields: + 'docSource': 'apache/airflow/docs' + 'docLink': URL for the page + 'content': HTML content of the page + 'sha': A UUID from the other fields + """ + + # we exclude the following docs which are not useful and/or too large for easy processing. + exclude_docs = [ + "changelog.html", + "commits.html", + "docs/apache-airflow/stable/release_notes.html", + "docs/stable/release_notes.html", + "_api", + "_modules", + "installing-providers-from-sources.html", + "apache-airflow/1.", + "apache-airflow/2.", + "example", + "cli-and-env-variables-ref.html", + ] + + docs_url_parts = urllib.parse.urlsplit(docs_base_url) + docs_url_base = f"{docs_url_parts.scheme}://{docs_url_parts.netloc}" + + all_links = {docs_base_url} + get_all_links(url=list(all_links)[0], all_links=all_links, exclude_docs=exclude_docs) + + # make sure we didn't accidentally pickup any unrelated links in recursion + non_doc_links = {link if docs_url_base not in link else "" for link in all_links} + docs_links = all_links - non_doc_links + + df = pd.DataFrame(docs_links, columns=["docLink"]) + + df["html_content"] = df["docLink"].apply(lambda x: requests.get(x).content) + + df["content"] = df["html_content"].apply( + lambda x: str(BeautifulSoup(x, "html.parser").find(class_="body", role="main")) + ) + df["content"] = df["content"].apply(lambda x: re.sub("¶", "", x)) + + df["sha"] = df["content"].apply(generate_uuid5) + df["docSource"] = "apache/airflow/docs" + df.reset_index(drop=True, inplace=True) + + # column order matters for uuid generation + df = df[["docSource", "sha", "content", "docLink"]] + + return [df] diff --git a/airflow/include/tasks/extract/blogs.py b/airflow/include/tasks/extract/blogs.py index bec05921..1d8d3cd1 100644 --- a/airflow/include/tasks/extract/blogs.py +++ b/airflow/include/tasks/extract/blogs.py @@ -48,7 +48,7 @@ def extract_astro_blogs(blog_cutoff_date: datetime) -> list[pd.DataFrame]: df =
pd.DataFrame(zip(links, dates), columns=["docLink", "date"]) df["date"] = pd.to_datetime(df["date"]).dt.date - df = df[df["date"] > blog_cutoff_date.date()] + df = df[df["date"] > blog_cutoff_date] df.drop("date", inplace=True, axis=1) df.drop_duplicates(inplace=True) diff --git a/airflow/include/tasks/extract/slack.py b/airflow/include/tasks/extract/slack.py index 17b36093..9bb61215 100644 --- a/airflow/include/tasks/extract/slack.py +++ b/airflow/include/tasks/extract/slack.py @@ -5,10 +5,10 @@ import numpy as np import pandas as pd import requests -from include.tasks.extract.utils.slack_helpers import get_slack_replies from weaviate.util import generate_uuid5 from airflow.providers.slack.hooks.slack import SlackHook +from include.tasks.extract.utils.slack_helpers import get_slack_replies slack_archive_host = "apache-airflow.slack-archives.org" slack_base_url = "https://{slack_archive_host}/v1/messages?size={size}&team={team}&channel={channel}" diff --git a/airflow/include/tasks/extract/stack_overflow.py b/airflow/include/tasks/extract/stack_overflow.py index 4f2cc0f9..264edf26 100644 --- a/airflow/include/tasks/extract/stack_overflow.py +++ b/airflow/include/tasks/extract/stack_overflow.py @@ -1,13 +1,14 @@ from __future__ import annotations import pandas as pd +from weaviate.util import generate_uuid5 + from include.tasks.extract.utils.stack_overflow_helpers import ( process_stack_answers, process_stack_comments, process_stack_posts, process_stack_questions, ) -from weaviate.util import generate_uuid5 def extract_stack_overflow_archive(tag: str, stackoverflow_cutoff_date: str) -> pd.DataFrame: diff --git a/airflow/include/tasks/extract/utils/__init__.py b/airflow/include/tasks/extract/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/airflow/include/tasks/extract/utils/html_helpers.py b/airflow/include/tasks/extract/utils/html_helpers.py new file mode 100644 index 00000000..34029303 --- /dev/null +++ 
b/airflow/include/tasks/extract/utils/html_helpers.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +import logging +import urllib.parse +from time import sleep + +import requests +from bs4 import BeautifulSoup + + +def get_links(url: str, exclude_docs: list) -> set: + """ + Given a HTML url this function scrapes the page for any HTML links (<a> tags) and returns a set of links which: + a) starts with the same base (ie. scheme + netloc) + b) is a relative link from the currently read page + Relative links are converted to absolute links. Note that the absolute link may not be unique due to redirects. + + :param url: The url to scrape for links. + :param exclude_docs: A list of strings to exclude from the returned links. + """ + response = requests.get(url) + data = response.text + soup = BeautifulSoup(data, "lxml") + + url_parts = urllib.parse.urlsplit(url) + url_base = f"{url_parts.scheme}://{url_parts.netloc}" + + links = set() + for link in soup.find_all("a"): + link_url = link.get("href") + + if link_url.endswith(".html"): + if link_url.startswith(url_base) and not any(substring in link_url for substring in exclude_docs): + links.add(link_url) + elif not link_url.startswith("http"): + absolute_url = urllib.parse.urljoin(url, link_url) + if not any(substring in absolute_url for substring in exclude_docs): + links.add(absolute_url) + + return links + + +def get_all_links(url: str, all_links: set, exclude_docs: list, retry_count: int = 0, max_retries: int = 5): + """ + Recursive function to find all sub-pages of a webpage. + + :param url: The url to scrape for links. + :param all_links: A set of all links found so far. + :param exclude_docs: A list of strings to exclude from the returned links. + :param retry_count: Current retry attempt. + :param max_retries: Maximum number of retries allowed for a single URL.
+ """ + try: + links = get_links(url=url, exclude_docs=exclude_docs) + for link in links: + # check if the linked page actually exists and get the redirect which is hopefully unique + + response = requests.head(link, allow_redirects=True) + if response.ok: + redirect_url = response.url + if redirect_url not in all_links: + logging.info(redirect_url) + all_links.add(redirect_url) + get_all_links(url=redirect_url, all_links=all_links, exclude_docs=exclude_docs) + except requests.exceptions.ConnectionError as ce: + if retry_count < max_retries: + logging.warning(f"Connection error for {url}: {ce}. Retrying ({retry_count + 1}/{max_retries})") + sleep(2**retry_count) # Exponential backoff + get_all_links( + url=url, + all_links=all_links, + exclude_docs=exclude_docs, + retry_count=retry_count + 1, + max_retries=max_retries, + ) + else: + logging.warning(f"Max retries reached for {url}. Skipping this URL.") diff --git a/airflow/include/tasks/extract/utils/stack_overflow_helpers.py b/airflow/include/tasks/extract/utils/stack_overflow_helpers.py index cc1b7843..9b171e6e 100644 --- a/airflow/include/tasks/extract/utils/stack_overflow_helpers.py +++ b/airflow/include/tasks/extract/utils/stack_overflow_helpers.py @@ -20,7 +20,7 @@ {body}{answer_comments}""" ) -comment_template = "{user} on {date} [Score: {score}]: {body}\n" +comment_template = "\n{user} on {date} [Score: {score}]: {body}\n" post_types = { "1": "Question", diff --git a/airflow/include/tasks/extract/utils/weaviate/__init__.py b/airflow/include/tasks/extract/utils/weaviate/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/airflow/include/tasks/extract/utils/weaviate/ask_astro_weaviate_hook.py b/airflow/include/tasks/extract/utils/weaviate/ask_astro_weaviate_hook.py new file mode 100644 index 00000000..34764ad8 --- /dev/null +++ b/airflow/include/tasks/extract/utils/weaviate/ask_astro_weaviate_hook.py @@ -0,0 +1,488 @@ +from __future__ import annotations + +import logging +from typing import 
Any + +import pandas as pd +import requests +from weaviate.exceptions import UnexpectedStatusCodeException +from weaviate.util import generate_uuid5 + +from airflow.exceptions import AirflowException +from airflow.providers.weaviate.hooks.weaviate import WeaviateHook + + +class AskAstroWeaviateHook(WeaviateHook): + """ + Extends the WeaviateHook to include specific methods for handling Ask-Astro. + + This hook will directly utilize the functionalities provided by Weaviate providers in + upcoming releases of the `apache-airflow-providers-weaviate` package. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.logger = logging.getLogger("airflow.task") + self.client = self.get_client() + + def compare_schema_subset(self, class_object: Any, class_schema: Any) -> bool: + """ + Recursively check if requested schema/object is a subset of the current schema. + + :param class_object: The class object to check against current schema + :param class_schema: The current schema class object + """ + + # Direct equality check + if class_object == class_schema: + return True + + # Type mismatch early return + if type(class_object) != type(class_schema): + return False + + # Dictionary comparison + if isinstance(class_object, dict): + return all( + k in class_schema and self.compare_schema_subset(v, class_schema[k]) for k, v in class_object.items() + ) + + # List or Tuple comparison + if isinstance(class_object, (list, tuple)): + return len(class_object) == len(class_schema) and all( + self.compare_schema_subset(obj, sch) for obj, sch in zip(class_object, class_schema) + ) + + # Default case for non-matching types or unsupported types + return False + + def is_class_missing(self, class_object: dict) -> bool: + """ + Checks if a class is missing from the schema. + + :param class_object: Class object to be checked against the current schema.
+ """ + try: + class_schema = self.client.schema.get(class_object.get("class", "")) + return not self.compare_schema_subset(class_object=class_object, class_schema=class_schema) + except UnexpectedStatusCodeException as e: + return e.status_code == 404 and "with response body: None." in e.message + except Exception as e: + error_msg = f"Error during schema check {e}" + self.logger.error(error_msg) + raise ValueError(error_msg) + + def check_schema(self, class_objects: list) -> bool: + """ + Verifies if the current schema includes the requested schema. + + :param class_objects: Class objects to be checked against the current schema. + """ + try: + missing_objects = [obj["class"] for obj in class_objects if self.is_class_missing(obj)] + + if missing_objects: + self.logger.warning(f"Classes {missing_objects} are not in the current schema.") + return False + else: + self.logger.info("All classes are present in the current schema.") + return True + except Exception as e: + error_msg = f"Error during schema check {e}" + self.logger.error(error_msg) + raise ValueError(error_msg) + + def create_schema(self, class_objects: list, existing: str = "ignore") -> None: + """ + Creates or updates the schema in Weaviate based on the given class objects. + + :param class_objects: A list of class objects for schema creation or update. + :param existing: Strategy to handle existing classes ('ignore' or 'replace'). Defaults to 'ignore'. 
+ """ + for class_object in class_objects: + class_name = class_object.get("class", "") + self.logger.info(f"Processing schema for class: {class_name}") + + try: + current_class = self.client.schema.get(class_name=class_name) + except Exception as e: + self.logger.error(f"Error retrieving current class schema: {e}") + current_class = None + if current_class is not None and existing == "replace": + self.logger.info(f"Replacing existing class {class_name}") + self.client.schema.delete_class(class_name=class_name) + + if current_class is None or existing == "replace": + self.client.schema.create_class(class_object) + self.logger.info(f"Created/updated class {class_name}") + + def generate_uuids( + self, + df: pd.DataFrame, + class_name: str, + column_subset: list[str] | None = None, + vector_column: str | None = None, + uuid_column: str | None = None, + ) -> tuple[pd.DataFrame, str]: + """ + Adds UUIDs to a DataFrame, useful for upsert operations where UUIDs must be known before ingestion. + By default, UUIDs are generated using a custom function if 'uuid_column' is not specified. + The function can potentially ingest the same data multiple times with different UUIDs. + + :param df: A dataframe with data to generate a UUID from. + :param class_name: The name of the class use as part of the uuid namespace. + :param uuid_column: Name of the column to create. Default is 'id'. + :param column_subset: A list of columns to use for UUID generation. By default, all columns except + vector_column will be used. + :param vector_column: Name of the column containing the vector data. If specified the vector will be + removed prior to generating the uuid. + """ + column_names = df.columns.to_list() + + column_subset = column_subset or column_names + column_subset.sort() + + if uuid_column is None: + self.logger.info(f"No uuid_column provided. 
Generating UUIDs as column name {uuid_column}.") + df = df[column_names] + if "id" in column_names: + raise AirflowException("Property 'id' already in dataset. Consider renaming or specify 'uuid_column'.") + else: + uuid_column = "id" + + if uuid_column in column_names: + raise AirflowException( + f"Property {uuid_column} already in dataset. Consider renaming or specify a different 'uuid_column'." + ) + + df[uuid_column] = ( + df[column_subset] + .drop(columns=[vector_column], inplace=False, errors="ignore") + .apply(lambda row: generate_uuid5(identifier=row.to_dict(), namespace=class_name), axis=1) + ) + + return df, uuid_column + + def identify_upsert_targets( + self, df: pd.DataFrame, class_name: str, doc_key: str, uuid_column: str + ) -> pd.DataFrame: + """ + Handles the 'upsert' operation for data ingestion. + + :param df: The DataFrame containing the data to be upserted. + :param class_name: The name of the class to import data. + :param doc_key: The document key used for upsert operation. This is a property of the data that + uniquely identifies all chunks associated with one document. + :param uuid_column: The column name containing the UUID. + """ + if doc_key is None or doc_key not in df.columns: + raise AirflowException("Specified doc_key is not specified or not in the dataset.") + + if uuid_column is None or uuid_column not in df.columns: + raise AirflowException("Specified uuid_column is not specified or not in the dataset.") + + df = df.drop_duplicates(subset=[doc_key, uuid_column], keep="first") + + current_schema = self.client.schema.get(class_name=class_name) + doc_key_schema = [prop for prop in current_schema["properties"] if prop["name"] == doc_key] + + if not doc_key_schema: + raise AirflowException("doc_key does not exist in current schema.") + elif doc_key_schema[0]["tokenization"] != "field": + raise AirflowException("Tokenization for provided doc_key is not set to 'field'. 
Cannot upsert safely.") + + ids_df = df.groupby(doc_key)[uuid_column].apply(set).reset_index(name="new_ids") + ids_df["existing_ids"] = ids_df[doc_key].apply( + lambda x: self._query_objects(value=x, doc_key=doc_key, uuid_column=uuid_column, class_name=class_name) + ) + + ids_df["objects_to_insert"] = ids_df.apply(lambda x: list(x.new_ids.difference(x.existing_ids)), axis=1) + ids_df["objects_to_delete"] = ids_df.apply(lambda x: list(x.existing_ids.difference(x.new_ids)), axis=1) + ids_df["unchanged_objects"] = ids_df.apply(lambda x: x.new_ids.intersection(x.existing_ids), axis=1) + + return ids_df[[doc_key, "objects_to_insert", "objects_to_delete", "unchanged_objects"]] + + def batch_ingest( + self, + df: pd.DataFrame, + class_name: str, + uuid_column: str, + existing: str, + vector_column: str | None = None, + batch_params: dict = {}, + verbose: bool = False, + ) -> (list, Any): + """ + Processes the DataFrame and batches the data for ingestion into Weaviate. + + :param df: DataFrame containing the data to be ingested. + :param class_name: The name of the class in Weaviate to which data will be ingested. + :param uuid_column: Name of the column containing the UUID. + :param vector_column: Name of the column containing the vector data. + :param batch_params: Parameters for batch configuration. + :param existing: Strategy to handle existing data ('skip', 'replace', 'upsert'). + :param verbose: Whether to print verbose output. + """ + batch = self.client.batch.configure(**batch_params) + batch_errors = [] + + for row_id, row in df.iterrows(): + data_object = row.to_dict() + uuid = data_object.pop(uuid_column) + vector = data_object.pop(vector_column, None) + + try: + if self.client.data_object.exists(uuid=uuid, class_name=class_name) is True: + if existing == "skip": + if verbose is True: + self.logger.warning(f"UUID {uuid} exists. 
Skipping.") + continue + elif existing == "replace": + # Default for weaviate is replace existing + if verbose is True: + self.logger.warning(f"UUID {uuid} exists. Overwriting.") + + except Exception as e: + if verbose: + self.logger.error(f"Failed to add row {row_id} with UUID {uuid}. Error: {e}") + batch_errors.append({"uuid": uuid, "result": {"errors": str(e)}}) + continue + + try: + added_row = batch.add_data_object( + class_name=class_name, uuid=uuid, data_object=data_object, vector=vector + ) + if verbose is True: + self.logger.info(f"Added row {row_id} with UUID {added_row} for batch import.") + + except Exception as e: + if verbose: + self.logger.error(f"Failed to add row {row_id} with UUID {uuid}. Error: {e}") + batch_errors.append({"uuid": uuid, "result": {"errors": str(e)}}) + + results = batch.create_objects() + + if len(results) > 0: + batch_errors += self.process_batch_errors(results=results, verbose=verbose) + + return batch_errors + + def process_batch_errors(self, results: list, verbose: bool) -> list: + """ + Processes the results from batch operation and collects any errors. + + :param results: Results from the batch operation. + :param verbose: Flag to enable verbose logging. + """ + errors = [] + for item in results: + if "errors" in item["result"]: + item_error = {"uuid": item["id"], "errors": item["result"]["errors"]} + if verbose: + self.logger.info( + f"Error occurred in batch process for {item['id']} with error {item['result']['errors']}" + ) + errors.append(item_error) + return errors + + def handle_upsert_rollback( + self, objects_to_upsert: pd.DataFrame, batch_errors: list, class_name: str, verbose: bool + ) -> list: + """ + Handles rollback of inserts in case of errors during upsert operation. + + :param objects_to_upsert: Dictionary of objects to upsert. + :param class_name: Name of the class in Weaviate. + :param verbose: Flag to enable verbose logging. 
+ """ + rollback_errors = [] + + error_uuids = {error["uuid"] for error in batch_errors} + + objects_to_upsert["rollback_doc"] = objects_to_upsert.objects_to_insert.apply( + lambda x: any(error_uuids.intersection(x)) + ) + + objects_to_upsert["successful_doc"] = objects_to_upsert.objects_to_insert.apply( + lambda x: error_uuids.isdisjoint(x) + ) + + rollback_objects = objects_to_upsert[objects_to_upsert.rollback_doc].objects_to_insert.to_list() + rollback_objects = {item for sublist in rollback_objects for item in sublist} + + delete_objects = objects_to_upsert[objects_to_upsert.successful_doc].objects_to_delete.to_list() + delete_objects = {item for sublist in delete_objects for item in sublist} + + for uuid in rollback_objects: + try: + if self.client.data_object.exists(uuid=uuid, class_name=class_name): + self.logger.info(f"Removing id {uuid} for rollback.") + self.client.data_object.delete(uuid=uuid, class_name=class_name, consistency_level="ALL") + elif verbose: + self.logger.info(f"UUID {uuid} does not exist. Skipping deletion during rollback.") + except Exception as e: + rollback_errors.append({"uuid": uuid, "result": {"errors": str(e)}}) + if verbose: + self.logger.info(f"Error in rolling back id {uuid}. Error: {str(e)}") + + for uuid in delete_objects: + try: + if self.client.data_object.exists(uuid=uuid, class_name=class_name): + if verbose: + self.logger.info(f"Deleting id {uuid} for successful upsert.") + self.client.data_object.delete(uuid=uuid, class_name=class_name) + elif verbose: + self.logger.info(f"UUID {uuid} does not exist. Skipping deletion.") + except Exception as e: + rollback_errors.append({"uuid": uuid, "result": {"errors": str(e)}}) + if verbose: + self.logger.info(f"Error in rolling back id {uuid}. 
Error: {str(e)}") + + return rollback_errors + + def ingest_data( + self, + dfs: list[pd.DataFrame] | pd.DataFrame, + class_name: str, + existing: str = "skip", + doc_key: str = None, + uuid_column: str = None, + vector_column: str = None, + batch_params: dict = None, + verbose: bool = True, + ) -> list: + """ + Ingests data into Weaviate, handling upserts and rollbacks, and returns a list of objects that failed to import. + This function ingests data from pandas DataFrame(s) into a specified class in Weaviate. It supports various + modes of handling existing data (upsert, skip, replace). Upsert logic uses 'doc_key' as a unique document + identifier, enabling document-level atomicity during ingestion. Rollback is performed for any document + encountering errors during ingest. The function returns a list of objects that failed to import for further + handling. + + :param dfs: A single pandas DataFrame or a list of pandas DataFrames to be ingested. + :param class_name: Name of the class in Weaviate schema where data is to be ingested. + :param existing: Strategy for handling existing data: 'upsert', 'skip', or 'replace'. Default is 'skip'. + :param doc_key: Column in DataFrame uniquely identifying each document, required for 'upsert' operations. + :param uuid_column: Column with pre-generated UUIDs. If not provided, UUIDs will be generated. + :param vector_column: Column with embedding vectors for pre-embedded data. + :param batch_params: Additional parameters for Weaviate batch configuration. + :param verbose: Flag to enable verbose output during the ingestion process. + """ + + global objects_to_upsert + if existing not in ["skip", "replace", "upsert"]: + raise AirflowException("Invalid parameter for 'existing'. 
Choices are 'skip', 'replace', 'upsert'") + + df = pd.concat(dfs, ignore_index=True) + + if uuid_column is None: + df, uuid_column = self.generate_uuids( + df=df, class_name=class_name, vector_column=vector_column, uuid_column=uuid_column + ) + + if existing == "upsert": + objects_to_upsert = self.identify_upsert_targets( + df=df, class_name=class_name, doc_key=doc_key, uuid_column=uuid_column + ) + + objects_to_insert = {item for sublist in objects_to_upsert.objects_to_insert.tolist() for item in sublist} + + # subset df with only objects that need to be inserted + df = df[df[uuid_column].isin(objects_to_insert)] + + self.logger.info(f"Passing {len(df)} objects for ingest.") + + batch_errors = self.batch_ingest( + df=df, + class_name=class_name, + uuid_column=uuid_column, + vector_column=vector_column, + batch_params=batch_params, + existing=existing, + verbose=verbose, + ) + + if existing == "upsert" and batch_errors: + self.logger.warning("Error during upsert. Rolling back all inserts for docs with errors.") + rollback_errors = self.handle_upsert_rollback( + objects_to_upsert=objects_to_upsert, batch_errors=batch_errors, class_name=class_name, verbose=verbose + ) + + if len(rollback_errors) > 0: + self.logger.error("Errors encountered during rollback.") + raise AirflowException("Errors encountered during rollback.") + + if batch_errors: + self.logger.error("Errors encountered during ingest.") + raise AirflowException("Errors encountered during ingest.") + + def _query_objects(self, value: Any, doc_key: str, class_name: str, uuid_column: str) -> set: + """ + Check for existence of a data_object as a property of a data class and return all object ids. + + :param value: The value of the property to query. + :param doc_key: The name of the property to query. + :param class_name: The name of the class to query. + :param uuid_column: The name of the column containing the UUID. 
+ """ + existing_uuids = ( + self.client.query.get(properties=[doc_key], class_name=class_name) + .with_additional([uuid_column]) + .with_where({"path": doc_key, "operator": "Equal", "valueText": value}) + .do()["data"]["Get"][class_name] + ) + + return {additional["_additional"]["id"] for additional in existing_uuids} + + def import_baseline( + self, + seed_baseline_url: str, + class_name: str, + existing: str = "skip", + doc_key: str = None, + uuid_column: str = None, + vector_column: str = None, + batch_params: dict = None, + verbose: bool = True, + ) -> list: + """ + This task ingests data from a baseline of pre-embedded data. This is useful for evaluation and baselining changes + over time. This function is used as a python_callable with the weaviate_import decorator. The returned + dictionary is passed to the WeaviateImportDataOperator for ingest. The operator returns a list of any objects + that failed to import. seed_baseline_url is a URI for a parquet file of pre-embedded data. Any existing + documents are replaced. The assumption is that this is a first import of data and older data + should be removed. + + :param class_name: The name of the class to import data. Class should be created with weaviate schema. + :param seed_baseline_url: The url of a parquet file containing baseline data to ingest. + :param vector_column: For pre-embedded data specify the name of the column containing the embedding vector + :param uuid_column: For data with pre-generated UUID specify the name of the column containing the UUID + :param batch_params: Additional parameters to pass to the weaviate batch configuration + :param verbose: Whether to print verbose output + :param existing: Whether to 'upsert', 'skip' or 'replace' any existing documents. Default is 'skip'. + :param doc_key: If using upsert you must specify a doc_key which uniquely identifies a document which may or may + not include multiple (unique) chunks. 
+ """ + + seed_filename = f"include/data/{seed_baseline_url.split('/')[-1]}" + + try: + df = pd.read_parquet(seed_filename) + + except Exception: + with open(seed_filename, "wb") as fh: + response = requests.get(seed_baseline_url, stream=True) + fh.writelines(response.iter_content(1024)) + + df = pd.read_parquet(seed_filename) + + return self.ingest_data( + dfs=df, + class_name=class_name, + existing=existing, + doc_key=doc_key, + uuid_column=uuid_column, + vector_column=vector_column, + verbose=verbose, + batch_params=batch_params, + ) diff --git a/airflow/include/tasks/ingest.py b/airflow/include/tasks/ingest.py index 8c7f7b41..2ab14dac 100644 --- a/airflow/include/tasks/ingest.py +++ b/airflow/include/tasks/ingest.py @@ -2,82 +2,6 @@ import pandas as pd import requests -from weaviate.util import generate_uuid5 - - -def import_upsert_data(dfs: list[pd.DataFrame], class_name: str, primary_key: str) -> list: - """ - This task concatenates multiple dataframes from upstream dynamic tasks and vectorizes with import to weaviate. - This function is used as a python_callable with the weaviate_import decorator. The returned dictionary is passed - to the WeaviateImportDataOperator for ingest. The operator returns a list of any objects that failed to import. - - A 'uuid' is generated based on the content and metadata (the git sha, document url, the document source and a - concatenation of the headers). - - Any existing documents with the same primary_key but differing UUID or sha will be deleted prior to import. - - param dfs: A list of dataframes from downstream dynamic tasks - type dfs: list[pd.DataFrame] - - param class_name: The name of the class to import data. Class should be created with weaviate schema. - type class_name: str - - param primary_key: The name of a column to use as a primary key for upsert logic. 
- type primary_key: str - """ - - df = pd.concat(dfs, ignore_index=True) - - df["uuid"] = df.apply(lambda x: generate_uuid5(identifier=x.to_dict(), namespace=class_name), axis=1) - - print(f"Passing {len(df)} objects for import.") - - return { - "data": df, - "class_name": class_name, - "upsert": True, - "primary_key": primary_key, - "uuid_column": "uuid", - "error_threshold": 0, - "verbose": True, - } - - -def import_data(dfs: list[pd.DataFrame], class_name: str) -> list: - """ - This task concatenates multiple dataframes from upstream dynamic tasks and vectorizes with import to weaviate. - This function is used as a python_callable with the weaviate_import decorator. The returned dictionary is passed - to the WeaviateImportDataOperator for ingest. The operator returns a list of any objects that failed to import. - - A 'uuid' is generated based on the content and metadata (the git sha, document url, the document source and a - concatenation of the headers) and Weaviate will create the vectors. - - Any existing documents are not upserted. The assumption is that this is a first - import of data and skipping upsert checks will speed up import. - - param dfs: A list of dataframes from downstream dynamic tasks - type dfs: list[pd.DataFrame] - - param class_name: The name of the class to import data. Class should be created with weaviate schema. 
- type class_name: str - """ - - df = pd.concat(dfs, ignore_index=True) - - df["uuid"] = df.apply(lambda x: generate_uuid5(identifier=x.to_dict(), namespace=class_name), axis=1) - - print(f"Passing {len(df)} objects for import.") - - return { - "data": df, - "class_name": class_name, - "upsert": False, - "uuid_column": "uuid", - "error_threshold": 0, - "batched_mode": True, - "batch_size": 1000, - "verbose": False, - } def import_baseline(class_name: str, seed_baseline_url: str) -> list: @@ -89,8 +13,8 @@ def import_baseline(class_name: str, seed_baseline_url: str) -> list: seed_baseline_url is a URI for a parquet file of pre-embedded data. - Any existing documents are not upserted. The assumption is that this is a first import of data and skipping - upsert checks will speed up import. + Any existing documents are replaced. The assumption is that this is a first import of data and older data + should be removed. param class_name: The name of the class to import data. Class should be created with weaviate schema. type class_name: str @@ -114,7 +38,7 @@ def import_baseline(class_name: str, seed_baseline_url: str) -> list: return { "data": df, "class_name": class_name, - "upsert": False, + "existing": "replace", "uuid_column": "id", "embedding_column": "vector", "error_threshold": 0, diff --git a/airflow/include/tasks/split.py b/airflow/include/tasks/split.py index 57afcd6f..d6720fad 100644 --- a/airflow/include/tasks/split.py +++ b/airflow/include/tasks/split.py @@ -3,6 +3,7 @@ import pandas as pd from langchain.schema import Document from langchain.text_splitter import ( + HTMLHeaderTextSplitter, Language, RecursiveCharacterTextSplitter, ) @@ -68,3 +69,37 @@ def split_python(dfs: list[pd.DataFrame]) -> pd.DataFrame: df.reset_index(inplace=True, drop=True) return df + + +def split_html(dfs: list[pd.DataFrame]) -> pd.DataFrame: + """ + This task concatenates multiple dataframes from upstream dynamic tasks and splits html code before importing + to a vector database. 
+ + param dfs: A list of dataframes from downstream dynamic tasks + type dfs: list[pd.DataFrame] + + Returned dataframe fields are: + 'docSource': ie. 'astro', 'learn', 'docs', etc. + 'sha': the github sha for the document + 'docLink': URL for the specific document in github. + 'content': Chunked content in markdown format. + + """ + + headers_to_split_on = [ + ("h2", "h2"), + ] + + df = pd.concat(dfs, axis=0, ignore_index=True) + + splitter = HTMLHeaderTextSplitter(headers_to_split_on) + + df["doc_chunks"] = df["content"].apply(lambda x: splitter.split_text(text=x)) + df = df.explode("doc_chunks", ignore_index=True) + df["content"] = df["doc_chunks"].apply(lambda x: x.page_content) + + df.drop(["doc_chunks"], inplace=True, axis=1) + df.reset_index(inplace=True, drop=True) + + return df diff --git a/airflow/requirements.txt b/airflow/requirements.txt index d26994bc..04700214 100644 --- a/airflow/requirements.txt +++ b/airflow/requirements.txt @@ -1,4 +1,4 @@ -/tmp/airflow_provider_weaviate-0.0.1-py3-none-any.whl +apache-airflow-providers-weaviate==1.0.0 apache-airflow-providers-github==2.3.1 apache-airflow-providers-slack==7.3.2 streamlit==1.25.0 @@ -13,4 +13,4 @@ openai==0.28.1 bs4==0.0.1 lxml==4.9.3 tiktoken==0.5.1 -firebase-admin +firebase-admin==6.2.0 diff --git a/airflow/ruff.toml b/airflow/ruff.toml new file mode 100644 index 00000000..424b740e --- /dev/null +++ b/airflow/ruff.toml @@ -0,0 +1,5 @@ +line-length = 100 + +[format] +quote-style = "single" +indent-style = "tab"