def create_collection(
    self,
    mongo_collection: str,
    mongo_db: str | None = None,
    return_if_exists: bool = True,
    **create_kwargs: Any,
) -> MongoCollection:
    """
    Create a collection in the target database and return a handle to it.

    Thin wrapper around ``Database.create_collection`` that can optionally
    tolerate the collection already being present.

    https://pymongo.readthedocs.io/en/stable/api/pymongo/database.html#pymongo.database.Database.create_collection

    :param mongo_collection: Name of the collection to create.
    :param mongo_db: Database to create it in; when falsy, the schema of the
        hook's connection is used instead.
    :param return_if_exists: When True (the default), a ``CollectionInvalid``
        raised by the create call is suppressed and the existing collection is
        returned. When False, that error is re-raised to the caller.
    :param create_kwargs: Extra keyword arguments passed straight through to
        ``db.create_collection()``, e.g. ``timeseries={...}`` or ``capped=True``.
        They have no effect when the create call fails because the collection
        already exists.
    :return: The collection object fetched from the database after the create attempt.
    :raises CollectionInvalid: If the create call fails with this error and
        ``return_if_exists`` is False.
    """
    # Resolve the database name before opening the connection, as the original does.
    database_name = mongo_db if mongo_db else self.connection.schema
    database = self.get_conn().get_database(database_name)

    try:
        database.create_collection(mongo_collection, **create_kwargs)
    except CollectionInvalid:
        if return_if_exists:
            # Already present and the caller accepts that: fetch it below.
            pass
        else:
            raise

    return database.get_collection(mongo_collection)
test_context_manager(): with MongoHook(mongo_conn_id="mongo_default") as ctx_hook: