From 95644ecb33fd310b7e9fa95e47ea922b4c03415a Mon Sep 17 00:00:00 2001 From: nailo2c Date: Mon, 12 May 2025 19:17:46 -0700 Subject: [PATCH 1/6] feat: add create_collection function with unit tests --- .../airflow/providers/mongo/hooks/mongo.py | 29 +++++++++++++++ .../tests/unit/mongo/hooks/test_mongo.py | 35 +++++++++++++++++++ 2 files changed, 64 insertions(+) diff --git a/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py b/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py index c71ce92ea355d..de0274472120e 100644 --- a/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py +++ b/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py @@ -225,6 +225,35 @@ def get_collection(self, mongo_collection: str, mongo_db: str | None = None) -> return mongo_conn.get_database(mongo_db).get_collection(mongo_collection) + def create_collection( + self, mongo_collection: str, mongo_db: str | None = None, create_if_exists: bool = True, **create_kwargs: Any, + ) -> MongoCollection: + """ + Create the collection (optionally a time‑series collection) and return it. + + https://pymongo.readthedocs.io/en/stable/api/pymongo/database.html#pymongo.database.Database.create_collection + + :param mongo_collection: Name of the collection. + :param mongo_db: Target database; defaults to the schema in the connection string. + :param create_if_exists: If True and the collection already exists, return it instead of raising. + :param create_kwargs: Additional keyword arguments forwarded to ``db.create_collection()``, + e.g. ``timeseries={...}``, ``capped=True``. + """ + from pymongo.errors import CollectionInvalid + + mongo_db = mongo_db or self.connection.schema + mongo_conn: MongoClient = self.get_conn() + db = mongo_conn.get_database(mongo_db) + + try: + db.create_collection(mongo_collection, **create_kwargs) + except CollectionInvalid: + if not create_if_exists: + raise + # Collection already exists – fall through and fetch it. + + return db.get_collection(mongo_collection) + def aggregate( self, mongo_collection: str, aggregate_query: list, mongo_db: str | None = None, **kwargs ) -> CommandCursor: diff --git a/providers/mongo/tests/unit/mongo/hooks/test_mongo.py b/providers/mongo/tests/unit/mongo/hooks/test_mongo.py index 0d646e6e6cbc6..e4f71f1c29a4b 100644 --- a/providers/mongo/tests/unit/mongo/hooks/test_mongo.py +++ b/providers/mongo/tests/unit/mongo/hooks/test_mongo.py @@ -29,6 +29,8 @@ from tests_common.test_utils.compat import connection_as_json +from pymongo.errors import CollectionInvalid + pytestmark = pytest.mark.db_test if TYPE_CHECKING: @@ -387,6 +389,39 @@ def test_distinct_with_filter(self): results = self.hook.distinct(collection, "test_id", {"test_status": "failure"}) assert len(results) == 1 + def test_create_standard_collection(self): + mock_client = mongomock.MongoClient() + self.hook.get_conn = lambda: mock_client + self.hook.connection.schema = "test_db" + + # 不帶 timeseries 的建立,就能在 mongomock 裡跑通 + collection = self.hook.create_collection(mongo_collection="plain_collection") + assert collection.name == "plain_collection" + assert "plain_collection" in mock_client["test_db"].list_collection_names() + + def test_create_if_exists_true_returns_existing(self): + mock_client = mongomock.MongoClient() + self.hook.get_conn = lambda: mock_client + self.hook.connection.schema = "test_db" + + first = self.hook.create_collection(mongo_collection="foo") + second = self.hook.create_collection(mongo_collection="foo", create_if_exists=True) + + assert first.full_name == second.full_name + assert "foo" in mock_client["test_db"].list_collection_names() + + + def test_create_if_exists_false_raises(self): + # Patch get_conn → mongomock client,並指定預設 DB + mock_client = mongomock.MongoClient() + self.hook.get_conn = lambda: mock_client + self.hook.connection.schema = "test_db" + + self.hook.create_collection(mongo_collection="bar") + + with pytest.raises(CollectionInvalid): + self.hook.create_collection(mongo_collection="bar", create_if_exists=False) + def test_context_manager(): with MongoHook(mongo_conn_id="mongo_default") as ctx_hook: From 3c1847dbf9800cfb74c6c2c157e329bee57d2bd9 Mon Sep 17 00:00:00 2001 From: nailo2c Date: Mon, 12 May 2025 19:30:53 -0700 Subject: [PATCH 2/6] rm comments --- providers/mongo/tests/unit/mongo/hooks/test_mongo.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/providers/mongo/tests/unit/mongo/hooks/test_mongo.py b/providers/mongo/tests/unit/mongo/hooks/test_mongo.py index e4f71f1c29a4b..a20dfc508aa8b 100644 --- a/providers/mongo/tests/unit/mongo/hooks/test_mongo.py +++ b/providers/mongo/tests/unit/mongo/hooks/test_mongo.py @@ -394,7 +394,6 @@ def test_create_standard_collection(self): self.hook.get_conn = lambda: mock_client self.hook.connection.schema = "test_db" - # 不帶 timeseries 的建立,就能在 mongomock 裡跑通 collection = self.hook.create_collection(mongo_collection="plain_collection") assert collection.name == "plain_collection" assert "plain_collection" in mock_client["test_db"].list_collection_names() @@ -412,7 +411,6 @@ def test_create_if_exists_true_returns_existing(self): def test_create_if_exists_false_raises(self): - # Patch get_conn → mongomock client,並指定預設 DB mock_client = mongomock.MongoClient() self.hook.get_conn = lambda: mock_client self.hook.connection.schema = "test_db" From 91fa706fee8138df987ff61b5b549ecdc9ed4a40 Mon Sep 17 00:00:00 2001 From: nailo2c Date: Mon, 12 May 2025 20:56:09 -0700 Subject: [PATCH 3/6] fix static checks --- providers/mongo/src/airflow/providers/mongo/hooks/mongo.py | 6 +++++- providers/mongo/tests/unit/mongo/hooks/test_mongo.py | 4 +--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py b/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py index de0274472120e..965b15d1f392c 100644 --- a/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py +++ b/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py @@ -226,7 +226,11 @@ def get_collection(self, mongo_collection: str, mongo_db: str | None = None) -> return mongo_conn.get_database(mongo_db).get_collection(mongo_collection) def create_collection( - self, mongo_collection: str, mongo_db: str | None = None, create_if_exists: bool = True, **create_kwargs: Any, + self, + mongo_collection: str, + mongo_db: str | None = None, + create_if_exists: bool = True, + **create_kwargs: Any, ) -> MongoCollection: """ Create the collection (optionally a time‑series collection) and return it. diff --git a/providers/mongo/tests/unit/mongo/hooks/test_mongo.py b/providers/mongo/tests/unit/mongo/hooks/test_mongo.py index a20dfc508aa8b..efbdf0b4a972f 100644 --- a/providers/mongo/tests/unit/mongo/hooks/test_mongo.py +++ b/providers/mongo/tests/unit/mongo/hooks/test_mongo.py @@ -22,6 +22,7 @@ import pymongo import pytest +from pymongo.errors import CollectionInvalid from airflow.exceptions import AirflowConfigException from airflow.models import Connection @@ -29,8 +30,6 @@ from tests_common.test_utils.compat import connection_as_json -from pymongo.errors import CollectionInvalid - pytestmark = pytest.mark.db_test if TYPE_CHECKING: @@ -409,7 +408,6 @@ def test_create_if_exists_true_returns_existing(self): assert first.full_name == second.full_name assert "foo" in mock_client["test_db"].list_collection_names() - def test_create_if_exists_false_raises(self): mock_client = mongomock.MongoClient() self.hook.get_conn = lambda: mock_client From 3c14eeafd1ffe6a3a99df30b5f4011a21371c916 Mon Sep 17 00:00:00 2001 From: Aaron Chen Date: Mon, 19 May 2025 09:29:22 -0700 Subject: [PATCH 4/6] make the type of `create_kwargs` clearer Co-authored-by: Wei Lee --- providers/mongo/src/airflow/providers/mongo/hooks/mongo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py b/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py index 965b15d1f392c..8f300b20248c5 100644 --- a/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py +++ b/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py @@ -230,7 +230,7 @@ def create_collection( mongo_collection: str, mongo_db: str | None = None, create_if_exists: bool = True, - **create_kwargs: Any, + **create_kwargs: dict[str, Any], ) -> MongoCollection: """ Create the collection (optionally a time‑series collection) and return it. From f6e7f840477716e935169e5600c60da48a9ca251 Mon Sep 17 00:00:00 2001 From: nailo2c Date: Mon, 19 May 2025 09:49:42 -0700 Subject: [PATCH 5/6] modfiy variable name: create_if_exists -> return_if_exists --- .../mongo/src/airflow/providers/mongo/hooks/mongo.py | 8 ++++---- providers/mongo/tests/unit/mongo/hooks/test_mongo.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py b/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py index 8f300b20248c5..f92895342260d 100644 --- a/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py +++ b/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py @@ -229,8 +229,8 @@ def create_collection( self, mongo_collection: str, mongo_db: str | None = None, - create_if_exists: bool = True, - **create_kwargs: dict[str, Any], + return_if_exists: bool = True, + **create_kwargs: Any, ) -> MongoCollection: """ Create the collection (optionally a time‑series collection) and return it. @@ -239,7 +239,7 @@ def create_collection( :param mongo_collection: Name of the collection. :param mongo_db: Target database; defaults to the schema in the connection string. - :param create_if_exists: If True and the collection already exists, return it instead of raising. + :param return_if_exists: If True and the collection already exists, return it instead of raising. :param create_kwargs: Additional keyword arguments forwarded to ``db.create_collection()``, e.g. ``timeseries={...}``, ``capped=True``. """ @@ -252,7 +252,7 @@ def create_collection( try: db.create_collection(mongo_collection, **create_kwargs) except CollectionInvalid: - if not create_if_exists: + if not return_if_exists: raise # Collection already exists – fall through and fetch it. diff --git a/providers/mongo/tests/unit/mongo/hooks/test_mongo.py b/providers/mongo/tests/unit/mongo/hooks/test_mongo.py index efbdf0b4a972f..126617af88fc9 100644 --- a/providers/mongo/tests/unit/mongo/hooks/test_mongo.py +++ b/providers/mongo/tests/unit/mongo/hooks/test_mongo.py @@ -397,18 +397,18 @@ def test_create_standard_collection(self): assert collection.name == "plain_collection" assert "plain_collection" in mock_client["test_db"].list_collection_names() - def test_create_if_exists_true_returns_existing(self): + def test_return_if_exists_true_returns_existing(self): mock_client = mongomock.MongoClient() self.hook.get_conn = lambda: mock_client self.hook.connection.schema = "test_db" first = self.hook.create_collection(mongo_collection="foo") - second = self.hook.create_collection(mongo_collection="foo", create_if_exists=True) + second = self.hook.create_collection(mongo_collection="foo", return_if_exists=True) assert first.full_name == second.full_name assert "foo" in mock_client["test_db"].list_collection_names() - def test_create_if_exists_false_raises(self): + def test_return_if_exists_false_raises(self): mock_client = mongomock.MongoClient() self.hook.get_conn = lambda: mock_client self.hook.connection.schema = "test_db" @@ -416,7 +416,7 @@ def test_create_if_exists_false_raises(self): self.hook.create_collection(mongo_collection="bar") with pytest.raises(CollectionInvalid): - self.hook.create_collection(mongo_collection="bar", create_if_exists=False) + self.hook.create_collection(mongo_collection="bar", return_if_exists=False) def test_context_manager(): From ef421a9de4e14ef9d128e7c36878fd04b47cae14 Mon Sep 17 00:00:00 2001 From: nailo2c Date: Mon, 19 May 2025 09:52:12 -0700 Subject: [PATCH 6/6] move import CollectionInvalid to top --- providers/mongo/src/airflow/providers/mongo/hooks/mongo.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py b/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py index f92895342260d..20dbf9f0dda90 100644 --- a/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py +++ b/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py @@ -25,6 +25,7 @@ import pymongo from pymongo import MongoClient, ReplaceOne +from pymongo.errors import CollectionInvalid from airflow.exceptions import AirflowConfigException from airflow.hooks.base import BaseHook @@ -243,8 +244,6 @@ def create_collection( :param create_kwargs: Additional keyword arguments forwarded to ``db.create_collection()``, e.g. ``timeseries={...}``, ``capped=True``. """ - from pymongo.errors import CollectionInvalid - mongo_db = mongo_db or self.connection.schema mongo_conn: MongoClient = self.get_conn() db = mongo_conn.get_database(mongo_db)