Skip to content

Commit

Permalink
Fix error in checking for generators in auto-tracing (#498)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexmojaki authored Oct 14, 2024
1 parent aa65f2d commit fe69568
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 20 deletions.
6 changes: 1 addition & 5 deletions logfire-api/logfire_api/_internal/auto_trace/rewrite_ast.pyi
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import ast
from ..ast_utils import BaseTransformer as BaseTransformer, LogfireArgs as LogfireArgs
from ..main import Logfire as Logfire
from _typeshed import Incomplete
from dataclasses import dataclass
from typing import Any, Callable, ContextManager, TypeVar

Expand Down Expand Up @@ -51,7 +50,4 @@ def no_auto_trace(x: T) -> T:
This decorator simply returns the argument unchanged, so there is zero runtime overhead.
"""

GENERATOR_CODE_FLAGS: Incomplete

def is_generator_function(func_def: ast.FunctionDef | ast.AsyncFunctionDef): ...
def has_yield(node: ast.AST): ...
24 changes: 11 additions & 13 deletions logfire/_internal/auto_trace/rewrite_ast.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from __future__ import annotations

import ast
import inspect
import types
import uuid
from collections import deque
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, ContextManager, TypeVar
Expand Down Expand Up @@ -97,7 +96,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef | ast.AsyncFunctionDef):
return super().visit_FunctionDef(node)

def rewrite_function(self, node: ast.FunctionDef | ast.AsyncFunctionDef, qualname: str) -> ast.AST:
if is_generator_function(node):
if has_yield(node):
return node

return super().rewrite_function(node, qualname)
Expand Down Expand Up @@ -168,13 +167,12 @@ def no_auto_trace(x: T) -> T:
return x # pragma: no cover


GENERATOR_CODE_FLAGS = inspect.CO_GENERATOR | inspect.CO_ASYNC_GENERATOR


def is_generator_function(func_def: ast.FunctionDef | ast.AsyncFunctionDef):
module_node = ast.parse('')
module_node.body = [func_def]
code = compile(module_node, '<string>', 'exec')
return any(
isinstance(const, types.CodeType) and (const.co_flags & GENERATOR_CODE_FLAGS) for const in code.co_consts
)
def has_yield(node: ast.AST):
queue = deque([node])
while queue:
node = queue.popleft()
for child in ast.iter_child_nodes(node):
if isinstance(child, (ast.Yield, ast.YieldFrom)):
return True
if not isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda)):
queue.append(child)
3 changes: 1 addition & 2 deletions tests/test_auto_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,13 +417,12 @@ def test_no_auto_trace():
)


# language=Python
generators_sample = """
def make_gen():
def gen():
async def foo():
async def bar():
pass
return lambda: (yield 1)
yield bar()
yield from foo()
return gen
Expand Down

0 comments on commit fe69568

Please sign in to comment.