Skip to content
Merged
2 changes: 2 additions & 0 deletions temporalio/nexus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
in_operation,
info,
logger,
metric_meter,
)
from ._token import WorkflowHandle

Expand All @@ -29,5 +30,6 @@
"in_operation",
"info",
"logger",
"metric_meter",
"WorkflowHandle",
)
68 changes: 47 additions & 21 deletions temporalio/nexus/_operation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,16 @@
TYPE_CHECKING,
Any,
Concatenate,
Optional,
Union,
Generic,
TypeVar,
overload,
)

from nexusrpc.handler import CancelOperationContext, StartOperationContext
from nexusrpc.handler import (
CancelOperationContext,
OperationContext,
StartOperationContext,
)

import temporalio.api.common.v1
import temporalio.api.workflowservice.v1
Expand Down Expand Up @@ -97,6 +101,11 @@ def client() -> temporalio.client.Client:
return _temporal_context().client


def metric_meter() -> temporalio.common.MetricMeter:
"""Get the metric meter for the current Nexus operation."""
return _temporal_context().metric_meter


def _temporal_context() -> (
_TemporalStartOperationContext | _TemporalCancelOperationContext
):
Expand Down Expand Up @@ -129,18 +138,39 @@ def _in_nexus_backing_workflow_start_context() -> bool:
return _temporal_nexus_backing_workflow_start_context.get(False)


@dataclass
class _TemporalStartOperationContext:
"""Context for a Nexus start operation being handled by a Temporal Nexus Worker."""
_OperationCtxT = TypeVar("_OperationCtxT", bound=OperationContext)

nexus_context: StartOperationContext
"""Nexus-specific start operation context."""

@dataclass(kw_only=True)
class _TemporalOperationCtx(Generic[_OperationCtxT]):
client: temporalio.client.Client
"""The Temporal client in use by the worker handling the current Nexus operation."""

info: Callable[[], Info]
"""Temporal information about the running Nexus operation."""

client: temporalio.client.Client
"""The Temporal client in use by the worker handling this Nexus operation."""
nexus_context: _OperationCtxT
"""Nexus-specific start operation context."""

_runtime_metric_meter: temporalio.common.MetricMeter
_metric_meter: temporalio.common.MetricMeter | None = None

@property
def metric_meter(self) -> temporalio.common.MetricMeter:
if not self._metric_meter:
self._metric_meter = self._runtime_metric_meter.with_additional_attributes(
{
"nexus_service": self.nexus_context.service,
"nexus_operation": self.nexus_context.operation,
"task_queue": self.info().task_queue,
}
)
return self._metric_meter


@dataclass
class _TemporalStartOperationContext(_TemporalOperationCtx[StartOperationContext]):
"""Context for a Nexus start operation being handled by a Temporal Nexus Worker."""

@classmethod
def get(cls) -> _TemporalStartOperationContext:
Expand Down Expand Up @@ -227,6 +257,11 @@ def _from_start_operation_context(
**{f.name: getattr(ctx, f.name) for f in dataclasses.fields(ctx)},
)

@property
def metric_meter(self) -> temporalio.common.MetricMeter:
"""The metric meter"""
return self._temporal_context.metric_meter

# Overload for no-param workflow
@overload
async def start_workflow(
Expand Down Expand Up @@ -480,19 +515,10 @@ class NexusCallback:
"""Header to attach to callback request."""


@dataclass(frozen=True)
class _TemporalCancelOperationContext:
@dataclass
class _TemporalCancelOperationContext(_TemporalOperationCtx[CancelOperationContext]):
"""Context for a Nexus cancel operation being handled by a Temporal Nexus Worker."""

nexus_context: CancelOperationContext
"""Nexus-specific cancel operation context."""

info: Callable[[], Info]
"""Temporal information about the running Nexus cancel operation."""

client: temporalio.client.Client
"""The Temporal client in use by the worker handling the current Nexus operation."""

@classmethod
def get(cls) -> _TemporalCancelOperationContext:
ctx = _temporal_cancel_operation_context.get(None)
Expand Down
44 changes: 37 additions & 7 deletions temporalio/worker/_nexus.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@

import asyncio
import concurrent.futures
import contextvars
import json
import threading
from collections.abc import Callable, Mapping, Sequence
from dataclasses import dataclass
from typing import (
Any,
NoReturn,
Optional,
Type,
Union,
ParamSpec,
TypeVar,
)

import google.protobuf.json_format
Expand Down Expand Up @@ -64,19 +64,25 @@ def __init__(
data_converter: temporalio.converter.DataConverter,
interceptors: Sequence[Interceptor],
metric_meter: temporalio.common.MetricMeter,
executor: concurrent.futures.Executor | None,
executor: concurrent.futures.ThreadPoolExecutor | None,
) -> None:
# TODO: make it possible to query task queue of bridge worker instead of passing
# unused task_queue into _NexusWorker, _ActivityWorker, etc?
self._bridge_worker = bridge_worker
self._client = client
self._task_queue = task_queue
self._handler = Handler(service_handlers, executor)

self._metric_meter = metric_meter

# If an executor is provided, we wrap the executor with one that will
# copy the contextvars.Context to the thread on submit
handler_executor = _ContextPropagatingExecutor(executor) if executor else None

self._handler = Handler(service_handlers, handler_executor)
self._data_converter = data_converter
# TODO(nexus-preview): interceptors
self._interceptors = interceptors
# TODO(nexus-preview): metric_meter
self._metric_meter = metric_meter

self._running_tasks: dict[bytes, _RunningNexusTask] = {}
self._fail_worker_exception_queue: asyncio.Queue[Exception] = asyncio.Queue()

Expand Down Expand Up @@ -204,6 +210,7 @@ async def _handle_cancel_operation_task(
info=lambda: Info(task_queue=self._task_queue),
nexus_context=ctx,
client=self._client,
_runtime_metric_meter=self._metric_meter,
).set()
try:
try:
Expand Down Expand Up @@ -321,6 +328,7 @@ async def _start_operation(
nexus_context=ctx,
client=self._client,
info=lambda: Info(task_queue=self._task_queue),
_runtime_metric_meter=self._metric_meter,
).set()
input = LazyValue(
serializer=_DummyPayloadSerializer(
Expand Down Expand Up @@ -595,3 +603,25 @@ def cancel(self, reason: str) -> bool:
self._thread_evt.set()
self._async_evt.set()
return True


_P = ParamSpec("_P")
_T = TypeVar("_T")


class _ContextPropagatingExecutor(concurrent.futures.Executor):
def __init__(self, executor: concurrent.futures.ThreadPoolExecutor) -> None:
self._executor = executor

def submit(
self, fn: Callable[_P, _T], /, *args: _P.args, **kwargs: _P.kwargs
) -> concurrent.futures.Future[_T]:
ctx = contextvars.copy_context()

def wrapped(*a: _P.args, **k: _P.kwargs) -> _T:
return ctx.run(fn, *a, **k)

return self._executor.submit(wrapped, *args, **kwargs)

def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None:
return self._executor.shutdown(wait=wait, cancel_futures=cancel_futures)
7 changes: 3 additions & 4 deletions temporalio/worker/_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def __init__(
workflows: Sequence[type] = [],
activity_executor: concurrent.futures.Executor | None = None,
workflow_task_executor: concurrent.futures.ThreadPoolExecutor | None = None,
nexus_task_executor: concurrent.futures.Executor | None = None,
nexus_task_executor: concurrent.futures.ThreadPoolExecutor | None = None,
workflow_runner: WorkflowRunner = SandboxedWorkflowRunner(),
unsandboxed_workflow_runner: WorkflowRunner = UnsandboxedWorkflowRunner(),
plugins: Sequence[Plugin] = [],
Expand Down Expand Up @@ -186,8 +186,7 @@ def __init__(
the worker is shut down.
nexus_task_executor: Executor to use for non-async
Nexus operations. This is required if any operation start methods
are non-``async def``. :py:class:`concurrent.futures.ThreadPoolExecutor`
is recommended.
are non-``async def``.

.. warning::
This parameter is experimental and unstable.
Expand Down Expand Up @@ -893,7 +892,7 @@ class WorkerConfig(TypedDict, total=False):
workflows: Sequence[type]
activity_executor: concurrent.futures.Executor | None
workflow_task_executor: concurrent.futures.ThreadPoolExecutor | None
nexus_task_executor: concurrent.futures.Executor | None
nexus_task_executor: concurrent.futures.ThreadPoolExecutor | None
workflow_runner: WorkflowRunner
unsandboxed_workflow_runner: WorkflowRunner
plugins: Sequence[Plugin]
Expand Down
30 changes: 30 additions & 0 deletions tests/helpers/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from collections.abc import Mapping


class PromMetricMatcher:
def __init__(self, prom_lines: list[str]) -> None:
self._prom_lines = prom_lines

# Intentionally naive metric checker
def matches_metric_line(
self, line: str, name: str, at_least_labels: Mapping[str, str], value: int
) -> bool:
# Must have metric name
if not line.startswith(name + "{"):
return False
# Must have labels (don't escape for this test)
for k, v in at_least_labels.items():
if f'{k}="{v}"' not in line:
return False
return line.endswith(f" {value}")

def assert_metric_exists(
self, name: str, at_least_labels: Mapping[str, str], value: int
) -> None:
assert any(
self.matches_metric_line(line, name, at_least_labels, value)
for line in self._prom_lines
)

def assert_description_exists(self, name: str, description: str) -> None:
assert f"# HELP {name} {description}" in self._prom_lines
Loading
Loading