diff --git a/newrelic/config.py b/newrelic/config.py index 9b15737fa9..7083cc872b 100644 --- a/newrelic/config.py +++ b/newrelic/config.py @@ -2309,6 +2309,11 @@ def _process_module_builtin_defaults(): "newrelic.hooks.datastore_firestore", "instrument_google_cloud_firestore_v1_bulk_batch", ) + _process_module_definition( + "google.cloud.firestore_v1.transaction", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_transaction", + ) _process_module_definition( "ariadne.asgi", diff --git a/newrelic/hooks/datastore_firestore.py b/newrelic/hooks/datastore_firestore.py index 9c3997bb3e..b591176a98 100644 --- a/newrelic/hooks/datastore_firestore.py +++ b/newrelic/hooks/datastore_firestore.py @@ -133,6 +133,7 @@ def instrument_google_cloud_firestore_v1_batch(module): module, "WriteBatch.%s" % method, product="Firestore", target=None, operation=method ) + def instrument_google_cloud_firestore_v1_bulk_batch(module): if hasattr(module, "BulkWriteBatch"): class_ = module.BulkWriteBatch @@ -141,3 +142,14 @@ def instrument_google_cloud_firestore_v1_bulk_batch(module): wrap_datastore_trace( module, "BulkWriteBatch.%s" % method, product="Firestore", target=None, operation=method ) + + +def instrument_google_cloud_firestore_v1_transaction(module): + if hasattr(module, "Transaction"): + class_ = module.Transaction + for method in ("_commit", "_rollback"): + if hasattr(class_, method): + operation = method[1:] # Trim leading underscore + wrap_datastore_trace( + module, "Transaction.%s" % method, product="Firestore", target=None, operation=operation + ) diff --git a/tests/datastore_firestore/test_transaction.py b/tests/datastore_firestore/test_transaction.py new file mode 100644 index 0000000000..3e462b3244 --- /dev/null +++ b/tests/datastore_firestore/test_transaction.py @@ -0,0 +1,122 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from newrelic.api.time_trace import current_trace +from newrelic.api.datastore_trace import DatastoreTrace +from testing_support.db_settings import firestore_settings +from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics +from newrelic.api.background_task import background_task +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) + + +@pytest.fixture(autouse=True) +def sample_data(collection, reset_firestore): + # reset_firestore must be run before, not after this fixture + for x in range(1, 4): + collection.add({"x": x}, "doc%d" % x) + + +def _exercise_transaction_commit(client, collection): + from google.cloud.firestore_v1.transaction import transactional + + @transactional + def _exercise(transaction): + # get a DocumentReference + list(transaction.get(collection.document("doc1"))) + + # get a Query + query = collection.select("x").where(field_path="x", op_string=">", value=2) + assert len(list(transaction.get(query))) == 1 + + # get_all on a list of DocumentReferences + all_docs = transaction.get_all([collection.document("doc%d" % x) for x in range(1, 4)]) + assert len(list(all_docs)) == 3 + + # set and delete methods + transaction.set(collection.document("doc2"), {"x": 0}) + transaction.delete(collection.document("doc3")) + + _exercise(client.transaction()) + assert len(list(collection.list_documents())) == 2 + + +def _exercise_transaction_rollback(client, collection): + from google.cloud.firestore_v1.transaction import transactional + + @transactional + def _exercise(transaction): + # set and delete methods + transaction.set(collection.document("doc2"), {"x": 99}) + transaction.delete(collection.document("doc1")) + raise RuntimeError() + + with pytest.raises(RuntimeError): + _exercise(client.transaction()) + assert len(list(collection.list_documents())) == 3 + + +def test_firestore_transaction_commit(client, collection): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/commit", 1), + ("Datastore/operation/Firestore/get_all", 2), + ("Datastore/statement/Firestore/%s/stream" % collection.id, 1), + ("Datastore/statement/Firestore/%s/list_documents" % collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/operation/Firestore/list_documents", 1), + ("Datastore/all", 5), + ("Datastore/allOther", 5), + ] + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_transaction", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_transaction") + def _test(): + _exercise_transaction_commit(client, collection) + + _test() + + +def test_firestore_transaction_rollback(client, collection): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/rollback", 1), + ("Datastore/statement/Firestore/%s/list_documents" % collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/list_documents", 1), + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ] + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_transaction", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_transaction") + def _test(): + _exercise_transaction_rollback(client, collection) + + _test()