Skip to content

Commit 5a211af

Browse files
Adds missing return type to collect_traces function (#442)
* Adds missing return type hint to the `collect_traces` function * Creates type safety tests to `collect_traces` function * Removes `@contextmanager` decorator from `collect_traces` function Now we pass `collect_traces` directly to `contextmanager` function at end of file to avoid return type problems. * Fixes type safety tests for `collect_traces` function * Adds `noqa: WPS501` flag in `try` block on `collect_traces` function * Modifies `tracing.py` To fix type hint errors we separate all the possible `collect_traces` function signatures using `@overload` decorator The implementation of `collect_traces`, now, acts like a factory! * Updates safety tests of `collect_traces` function * Updates documentation of `collect_traces` function * Changes type var name from `Function` to `_FunctionType` * Changes `factory` function to use the `@contextmanager` decorator instead of passing `factory` through the `contextmanager` function * Updates safety tests of `collect_traces` function, specifically the `collect_traces_context_manager_return_type_two` case
1 parent 3e8b1b9 commit 5a211af

File tree

3 files changed

+72
-11
lines changed

3 files changed

+72
-11
lines changed

docs/pages/development.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ Or as a decorator:
5454
>>> from returns.result import Failure, Result
5555
>>> from returns.primitives.tracing import collect_traces
5656
57-
>>> @collect_traces()
57+
>>> @collect_traces
5858
... def traced_function(value: str) -> IOResult[str, str]:
5959
... return IOFailure(value)
6060

returns/primitives/tracing.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,35 @@
11
import types
22
from contextlib import contextmanager
33
from inspect import FrameInfo, stack
4-
from typing import List, Optional
4+
from typing import (
5+
Callable,
6+
ContextManager,
7+
Iterator,
8+
List,
9+
Optional,
10+
TypeVar,
11+
Union,
12+
overload,
13+
)
514

615
from returns.result import _Failure
716

17+
_FunctionType = TypeVar('_FunctionType', bound=Callable)
818

9-
@contextmanager
10-
def collect_traces():
19+
20+
@overload
21+
def collect_traces() -> ContextManager[None]:
22+
"""Context Manager to active traces collect to the Failures."""
23+
24+
25+
@overload
26+
def collect_traces(function: _FunctionType) -> _FunctionType:
27+
"""Decorator to active traces collect to the Failures."""
28+
29+
30+
def collect_traces(
31+
function: Optional[_FunctionType] = None,
32+
) -> Union[_FunctionType, ContextManager[None]]: # noqa: DAR101, DAR201, DAR301
1133
"""
1234
Context Manager/Decorator to active traces collect to the Failures.
1335
@@ -36,13 +58,16 @@ def collect_traces():
3658
# doctest: # noqa: DAR301, E501
3759
3860
"""
39-
unpatched_get_trace = getattr(_Failure, '_get_trace') # noqa: B009
40-
substitute_get_trace = types.MethodType(_get_trace, _Failure)
41-
setattr(_Failure, '_get_trace', substitute_get_trace) # noqa: B010
42-
try:
43-
yield
44-
finally:
45-
setattr(_Failure, '_get_trace', unpatched_get_trace) # noqa: B010
61+
@contextmanager
62+
def factory() -> Iterator[None]:
63+
unpatched_get_trace = getattr(_Failure, '_get_trace') # noqa: B009
64+
substitute_get_trace = types.MethodType(_get_trace, _Failure)
65+
setattr(_Failure, '_get_trace', substitute_get_trace) # noqa: B010
66+
try: # noqa: WPS501
67+
yield
68+
finally:
69+
setattr(_Failure, '_get_trace', unpatched_get_trace) # noqa: B010
70+
return factory()(function) if function else factory()
4671

4772

4873
def _get_trace(_self: _Failure) -> Optional[List[FrameInfo]]:
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
- case: collect_traces_context_manager_return_type_one
2+
disable_cache: true
3+
main: |
4+
from returns.primitives.tracing import collect_traces
5+
6+
reveal_type(collect_traces) # N: Revealed type is 'Overload(def () -> typing.ContextManager[None], def [_FunctionType <: def (*Any, **Any) -> Any] (function: _FunctionType`-1) -> _FunctionType`-1)'
7+
8+
- case: collect_traces_context_manager_return_type_two
9+
disable_cache: true
10+
main: |
11+
from returns.primitives.tracing import collect_traces
12+
13+
with reveal_type(collect_traces()): # N: Revealed type is 'typing.ContextManager[None]'
14+
pass
15+
16+
- case: collect_traces_decorated_function_return_type
17+
disable_cache: true
18+
main: |
19+
from returns.primitives.tracing import collect_traces
20+
21+
@collect_traces
22+
def function() -> int:
23+
return 0
24+
25+
reveal_type(function) # N: Revealed type is 'def () -> builtins.int'
26+
27+
- case: collect_traces_decorated_function_with_argument_return_type
28+
disable_cache: true
29+
main: |
30+
from returns.primitives.tracing import collect_traces
31+
32+
@collect_traces
33+
def function(number: int) -> str:
34+
return str(number)
35+
36+
reveal_type(function) # N: Revealed type is 'def (number: builtins.int) -> builtins.str'

0 commit comments

Comments
 (0)