From 42527ae4bed945be9b8ddd6c3b772c28ed523be9 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Thu, 31 Oct 2024 12:11:17 +0200 Subject: [PATCH] Use functools.wraps in @instrument --- logfire/_internal/instrument.py | 4 +++- tests/test_logfire.py | 9 +++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/logfire/_internal/instrument.py b/logfire/_internal/instrument.py index 44a6be282..5f7f7b940 100644 --- a/logfire/_internal/instrument.py +++ b/logfire/_internal/instrument.py @@ -1,6 +1,7 @@ from __future__ import annotations import contextlib +import functools import inspect import warnings from collections.abc import Sequence @@ -105,7 +106,8 @@ def wrapper(*func_args: P.args, **func_kwargs: P.kwargs) -> R: with open_span(*func_args, **func_kwargs): return func(*func_args, **func_kwargs) - return wrapper # type: ignore + wrapper = functools.wraps(func)(wrapper) # type: ignore + return wrapper return decorator diff --git a/tests/test_logfire.py b/tests/test_logfire.py index 213ae276e..6d21332ab 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -712,6 +712,7 @@ async def foo(): ) assert warnings[0].filename.endswith('test_logfire.py') assert warnings[0].lineno == inspect.currentframe().f_lineno - 8 # type: ignore + assert foo.__name__ == 'foo' assert [value async for value in foo()] == [1] @@ -752,6 +753,7 @@ def foo(): ) assert warnings[0].filename.endswith('test_logfire.py') assert warnings[0].lineno == inspect.currentframe().f_lineno - 8 # type: ignore + assert foo.__name__ == 'foo' assert list(foo()) == [1] @@ -897,6 +899,8 @@ def test_instrument_contextmanager_prevent_warning(exporter: TestExporter): def foo(): yield + assert foo.__name__ == 'foo' + with foo(): logfire.info('hello') @@ -944,6 +948,8 @@ async def test_instrument_asynccontextmanager_prevent_warning(exporter: TestExpo async def foo(): yield + assert foo.__name__ == 'foo' + async with foo(): logfire.info('hello') @@ -990,6 +996,7 @@ async def test_instrument_async(exporter: TestExporter): async def foo(): return 456 + assert foo.__name__ == 'foo' assert await foo() == 456 assert exporter.exported_spans_as_dict(_strip_function_qualname=False) == snapshot( @@ -1018,6 +1025,8 @@ def test_instrument_extract_false(exporter: TestExporter): def hello_world(a: int) -> str: return f'hello {a}' + assert hello_world.__name__ == 'hello_world' + assert hello_world(123) == 'hello 123' assert exporter.exported_spans_as_dict(_strip_function_qualname=False) == snapshot(