Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure force_flush at end of AWS Lambda invocation #296

Merged
merged 8 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions logfire-api/logfire_api/_internal/config.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,15 @@ class LogfireConfig(_LogfireConfigData):
def configure(self, base_url: str | None, send_to_logfire: bool | Literal['if-token-present'] | None, token: str | None, project_name: str | None, service_name: str | None, service_version: str | None, trace_sample_rate: float | None, console: ConsoleOptions | Literal[False] | None, show_summary: bool | None, config_dir: Path | None, data_dir: Path | None, collect_system_metrics: bool | None, id_generator: IdGenerator | None, ns_timestamp_generator: Callable[[], int] | None, additional_span_processors: Sequence[SpanProcessor] | None, default_span_processor: Callable[[SpanExporter], SpanProcessor] | None, additional_metric_readers: Sequence[MetricReader] | None, pydantic_plugin: PydanticPlugin | None, fast_shutdown: bool, scrubbing: ScrubbingOptions | Literal[False] | None, inspect_arguments: bool | None, tail_sampling: TailSamplingOptions | None) -> None: ...
def initialize(self) -> ProxyTracerProvider:
"""Configure internals to start exporting traces and metrics."""
def force_flush(self, timeout_millis: int = 30000) -> bool:
    """Flush any pending spans and metrics to the exporters.

    Args:
        timeout_millis: Maximum time to wait for the flush, in milliseconds.

    Returns:
        Whether flushing the spans completed successfully.
    """
def get_tracer_provider(self) -> ProxyTracerProvider:
"""Get a tracer provider from this `LogfireConfig`.

Expand Down
4 changes: 2 additions & 2 deletions logfire-api/logfire_api/_internal/main.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -309,13 +309,13 @@ class Logfire:
A new Logfire instance with the given settings applied.
"""
def force_flush(self, timeout_millis: int = 3000) -> bool:
    """Flush any pending spans and metrics to the exporters.

    Args:
        timeout_millis: Maximum time to wait for the flush, in milliseconds.

    Returns:
        Whether flushing the spans completed successfully.
    """
def log_slow_async_callbacks(self, slow_duration: float = 0.1) -> ContextManager[None]:
"""Log a warning whenever a function running in the asyncio event loop blocks for too long.
Expand Down
57 changes: 57 additions & 0 deletions logfire/_internal/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations as _annotations

import dataclasses
import functools
import json
import os
import re
Expand Down Expand Up @@ -759,8 +760,23 @@ def check_token():

# set up context propagation for ThreadPoolExecutor and ProcessPoolExecutor
instrument_executors()

self._ensure_flush_after_aws_lambda()

return self._tracer_provider

def force_flush(self, timeout_millis: int = 30_000) -> bool:
    """Force flush all spans and metrics.

    Args:
        timeout_millis: The timeout in milliseconds.

    Returns:
        Whether the flush of spans was successful.
    """
    meter_provider = self._meter_provider
    tracer_provider = self._tracer_provider
    # Flush metrics first; the reported result is that of the span flush,
    # as documented in the return description above.
    meter_provider.force_flush(timeout_millis)
    return tracer_provider.force_flush(timeout_millis)

def get_tracer_provider(self) -> ProxyTracerProvider:
"""Get a tracer provider from this `LogfireConfig`.

Expand Down Expand Up @@ -807,6 +823,47 @@ def meter(self) -> metrics.Meter:
def _initialize_credentials_from_token(self, token: str) -> LogfireCredentials | None:
return LogfireCredentials.from_token(token, requests.Session(), self.base_url)

def _ensure_flush_after_aws_lambda(self):
"""Ensure that `force_flush` is called after an AWS Lambda invocation.

This way Logfire will just work in Lambda without the user needing to know anything.
Without the `force_flush`, spans may just remain in the queue when the Lambda runtime is frozen.
"""

def wrap_client_post_invocation_method(client_method: Any): # pragma: no cover
@functools.wraps(client_method)
def wrapper(*args: Any, **kwargs: Any) -> Any:
try:
self.force_flush(timeout_millis=3000)
except Exception:
import traceback

traceback.print_exc()

return client_method(*args, **kwargs)

return wrapper

# This suggests that the lambda runtime module moves around a lot:
# https://github.com/getsentry/sentry-python/blob/eab218c91ae2b894df18751e347fd94972a4fe06/sentry_sdk/integrations/aws_lambda.py#L280-L314
# So we just look for the client class in all modules.
# This feels inefficient but it appears be a tiny fraction of the time `configure` takes anyway.
# We convert the modules to a list in case something gets imported during the loop and the dict gets modified.
for mod in list(sys.modules.values()):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need list here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same idea as #277. Added a comment.

try:
client = getattr(mod, 'LambdaRuntimeClient', None)
except Exception: # pragma: no cover
continue
if not client: # pragma: no branch
continue
try: # pragma: no cover
client.post_invocation_error = wrap_client_post_invocation_method(client.post_invocation_error)
client.post_invocation_result = wrap_client_post_invocation_method(client.post_invocation_result)
except Exception: # pragma: no cover
import traceback

traceback.print_exc()


def _get_default_span_processor(exporter: SpanExporter) -> SpanProcessor:
schedule_delay_millis = _get_int_from_env(OTEL_BSP_SCHEDULE_DELAY) or 500
Expand Down
6 changes: 3 additions & 3 deletions logfire/_internal/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,15 +745,15 @@ def with_settings(
)

def force_flush(self, timeout_millis: int = 3_000) -> bool:  # pragma: no cover
    """Force flush all spans and metrics.

    Args:
        timeout_millis: The timeout in milliseconds.

    Returns:
        Whether the flush of spans was successful.
    """
    # Delegate to the config, which flushes the metric and span providers.
    config = self._config
    return config.force_flush(timeout_millis)

def log_slow_async_callbacks(self, slow_duration: float = 0.1) -> ContextManager[None]:
"""Log a warning whenever a function running in the asyncio event loop blocks for too long.
Expand Down
2 changes: 1 addition & 1 deletion logfire/_internal/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def resource(self) -> Resource: # pragma: no cover
return self.provider.resource
return Resource.create({ResourceAttributes.SERVICE_NAME: self.config.service_name})

def force_flush(self, timeout_millis: int = 30000) -> bool: # pragma: no cover
def force_flush(self, timeout_millis: int = 30000) -> bool:
with self.lock:
if isinstance(self.provider, SDKTracerProvider): # pragma: no branch
return self.provider.force_flush(timeout_millis)
Expand Down
13 changes: 12 additions & 1 deletion tests/test_logfire.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from opentelemetry.proto.common.v1.common_pb2 import AnyValue
from opentelemetry.sdk.metrics.export import InMemoryMetricReader
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export import BatchSpanProcessor, SimpleSpanProcessor
from opentelemetry.trace import StatusCode
from pydantic import BaseModel
from pydantic_core import ValidationError
Expand Down Expand Up @@ -2570,3 +2570,14 @@ def test_otel_status_code(exporter: TestExporter):

assert exporter.exported_spans[0].status.status_code == StatusCode.UNSET
assert exporter.exported_spans[1].status.status_code == StatusCode.ERROR


def test_force_flush(exporter: TestExporter):
    """A BatchSpanProcessor buffers spans until `force_flush` pushes them out."""
    batch_processor = BatchSpanProcessor(exporter)
    logfire.configure(send_to_logfire=False, console=False, additional_span_processors=[batch_processor])
    logfire.info('hi')

    # The batch processor hasn't exported anything yet.
    assert not exporter.exported_spans_as_dict()

    logfire.force_flush()

    # Flushing delivered the single buffered span.
    assert len(exporter.exported_spans_as_dict()) == 1