Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions providers/mongo/src/airflow/providers/mongo/hooks/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import pymongo
from pymongo import MongoClient, ReplaceOne
from pymongo.errors import CollectionInvalid

from airflow.exceptions import AirflowConfigException
from airflow.hooks.base import BaseHook
Expand Down Expand Up @@ -225,6 +226,37 @@ def get_collection(self, mongo_collection: str, mongo_db: str | None = None) ->

return mongo_conn.get_database(mongo_db).get_collection(mongo_collection)

def create_collection(
self,
mongo_collection: str,
mongo_db: str | None = None,
return_if_exists: bool = True,
**create_kwargs: Any,
) -> MongoCollection:
"""
Create the collection (optionally a time‑series collection) and return it.

https://pymongo.readthedocs.io/en/stable/api/pymongo/database.html#pymongo.database.Database.create_collection

:param mongo_collection: Name of the collection.
:param mongo_db: Target database; defaults to the schema in the connection string.
:param return_if_exists: If True and the collection already exists, return it instead of raising.
:param create_kwargs: Additional keyword arguments forwarded to ``db.create_collection()``,
e.g. ``timeseries={...}``, ``capped=True``.
"""
mongo_db = mongo_db or self.connection.schema
mongo_conn: MongoClient = self.get_conn()
db = mongo_conn.get_database(mongo_db)

try:
db.create_collection(mongo_collection, **create_kwargs)
except CollectionInvalid:
if not return_if_exists:
raise
# Collection already exists – fall through and fetch it.

return db.get_collection(mongo_collection)

def aggregate(
self, mongo_collection: str, aggregate_query: list, mongo_db: str | None = None, **kwargs
) -> CommandCursor:
Expand Down
31 changes: 31 additions & 0 deletions providers/mongo/tests/unit/mongo/hooks/test_mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import pymongo
import pytest
from pymongo.errors import CollectionInvalid

from airflow.exceptions import AirflowConfigException
from airflow.models import Connection
Expand Down Expand Up @@ -387,6 +388,36 @@ def test_distinct_with_filter(self):
results = self.hook.distinct(collection, "test_id", {"test_status": "failure"})
assert len(results) == 1

def test_create_standard_collection(self):
mock_client = mongomock.MongoClient()
self.hook.get_conn = lambda: mock_client
self.hook.connection.schema = "test_db"

collection = self.hook.create_collection(mongo_collection="plain_collection")
assert collection.name == "plain_collection"
assert "plain_collection" in mock_client["test_db"].list_collection_names()

def test_return_if_exists_true_returns_existing(self):
mock_client = mongomock.MongoClient()
self.hook.get_conn = lambda: mock_client
self.hook.connection.schema = "test_db"

first = self.hook.create_collection(mongo_collection="foo")
second = self.hook.create_collection(mongo_collection="foo", return_if_exists=True)

assert first.full_name == second.full_name
assert "foo" in mock_client["test_db"].list_collection_names()

def test_return_if_exists_false_raises(self):
mock_client = mongomock.MongoClient()
self.hook.get_conn = lambda: mock_client
self.hook.connection.schema = "test_db"

self.hook.create_collection(mongo_collection="bar")

with pytest.raises(CollectionInvalid):
self.hook.create_collection(mongo_collection="bar", return_if_exists=False)


def test_context_manager():
with MongoHook(mongo_conn_id="mongo_default") as ctx_hook:
Expand Down