Skip to content

Commit

Permalink
Use functools.wraps in @Instrument
Browse files Browse the repository at this point in the history
  • Loading branch information
alexmojaki committed Oct 31, 2024
1 parent 70d93fe commit 42527ae
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
4 changes: 3 additions & 1 deletion logfire/_internal/instrument.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextlib
import functools
import inspect
import warnings
from collections.abc import Sequence
Expand Down Expand Up @@ -105,7 +106,8 @@ def wrapper(*func_args: P.args, **func_kwargs: P.kwargs) -> R:
with open_span(*func_args, **func_kwargs):
return func(*func_args, **func_kwargs)

return wrapper # type: ignore
wrapper = functools.wraps(func)(wrapper) # type: ignore
return wrapper

return decorator

Expand Down
9 changes: 9 additions & 0 deletions tests/test_logfire.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,7 @@ async def foo():
)
assert warnings[0].filename.endswith('test_logfire.py')
assert warnings[0].lineno == inspect.currentframe().f_lineno - 8 # type: ignore
assert foo.__name__ == 'foo'

assert [value async for value in foo()] == [1]

Expand Down Expand Up @@ -752,6 +753,7 @@ def foo():
)
assert warnings[0].filename.endswith('test_logfire.py')
assert warnings[0].lineno == inspect.currentframe().f_lineno - 8 # type: ignore
assert foo.__name__ == 'foo'

assert list(foo()) == [1]

Expand Down Expand Up @@ -897,6 +899,8 @@ def test_instrument_contextmanager_prevent_warning(exporter: TestExporter):
def foo():
yield

assert foo.__name__ == 'foo'

with foo():
logfire.info('hello')

Expand Down Expand Up @@ -944,6 +948,8 @@ async def test_instrument_asynccontextmanager_prevent_warning(exporter: TestExpo
async def foo():
yield

assert foo.__name__ == 'foo'

async with foo():
logfire.info('hello')

Expand Down Expand Up @@ -990,6 +996,7 @@ async def test_instrument_async(exporter: TestExporter):
async def foo():
return 456

assert foo.__name__ == 'foo'
assert await foo() == 456

assert exporter.exported_spans_as_dict(_strip_function_qualname=False) == snapshot(
Expand Down Expand Up @@ -1018,6 +1025,8 @@ def test_instrument_extract_false(exporter: TestExporter):
def hello_world(a: int) -> str:
return f'hello {a}'

assert hello_world.__name__ == 'hello_world'

assert hello_world(123) == 'hello 123'

assert exporter.exported_spans_as_dict(_strip_function_qualname=False) == snapshot(
Expand Down

0 comments on commit 42527ae

Please sign in to comment.