Skip to content

Commit

Permalink
Add run_id parameter to the search_trace API (mlflow#13251)
Browse files Browse the repository at this point in the history
Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
  • Loading branch information
B-Step62 authored Sep 27, 2024
1 parent 76776c9 commit ba3d931
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 1 deletion.
15 changes: 14 additions & 1 deletion docs/source/llms/tracing/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,7 @@ For example, in the following code, the traces are generated within the ``start_
mlflow.set_experiment("Run Associated Tracing")
# Start a new MLflow Run
with mlflow.start_run():
with mlflow.start_run() as run:
# Initiate a trace by starting a Span context from within the Run context
with mlflow.start_span(name="Run Span") as parent_span:
parent_span.set_inputs({"input": "a"})
Expand All @@ -904,6 +904,19 @@ well as providing a link to navigate to the run within the MLflow UI. See the be
:width: 100%
:align: center

You can also programmatically retrieve the traces associated with a particular Run by using the :py:meth:`mlflow.client.MlflowClient.search_traces` method.

.. code-block:: python
from mlflow import MlflowClient
client = MlflowClient()
# Retrieve traces associated with a specific Run
traces = client.search_traces(run_id=run.info.run_id)
print(traces)
Q: Can I use the fluent API and the client API together?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
23 changes: 23 additions & 0 deletions mlflow/tracing/fluent.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ def search_traces(
max_results: Optional[int] = None,
order_by: Optional[List[str]] = None,
extract_fields: Optional[List[str]] = None,
run_id: Optional[str] = None,
) -> "pandas.DataFrame":
"""
Return traces that match the given list of search expressions within the experiments.
Expand Down Expand Up @@ -374,6 +375,9 @@ def search_traces(
# span name and field name contain a dot
extract_fields = ["`span.name`.inputs.`field.name`"]
run_id: A run id to scope the search. When a trace is created under an active run,
it will be associated with the run and you can filter on the run id to retrieve the
trace. See the example below for how to filter traces by run id.
Returns:
A Pandas DataFrame containing information about traces that satisfy the search expressions.
Expand Down Expand Up @@ -406,6 +410,24 @@ def search_traces(
mlflow.search_traces(
extract_fields=["non_dict_span.inputs", "non_dict_span.outputs"],
)
.. code-block:: python
:test:
:caption: Search traces by run ID
import mlflow
@mlflow.trace
def traced_func(x):
return x + 1
with mlflow.start_run() as run:
traced_func(1)
mlflow.search_traces(run_id=run.info.run_id)
"""
# Check if pandas is installed early to avoid unnecessary computation
if importlib.util.find_spec("pandas") is None:
Expand All @@ -428,6 +450,7 @@ def search_traces(
def pagination_wrapper_func(number_to_get, next_page_token):
return MlflowClient().search_traces(
experiment_ids=experiment_ids,
run_id=run_id,
max_results=number_to_get,
filter_string=filter_string,
order_by=order_by,
Expand Down
17 changes: 17 additions & 0 deletions mlflow/tracking/_tracking_service/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
)
from mlflow.store.tracking.rest_store import RestStore
from mlflow.tracing.artifact_utils import get_artifact_uri_for_trace
from mlflow.tracing.constant import TraceMetadataKey
from mlflow.tracing.utils import TraceJSONEncoder, exclude_immutable_tags
from mlflow.tracking._tracking_service import utils
from mlflow.tracking.metric_value_conversion_utils import convert_metric_value_to_float_if_possible
Expand Down Expand Up @@ -312,6 +313,7 @@ def search_traces(
max_results: int = SEARCH_TRACES_DEFAULT_MAX_RESULTS,
order_by: Optional[List[str]] = None,
page_token: Optional[str] = None,
run_id: Optional[str] = None,
) -> PagedList[Trace]:
def download_trace_data(trace_info: TraceInfo) -> Optional[Trace]:
"""
Expand All @@ -332,6 +334,21 @@ def download_trace_data(trace_info: TraceInfo) -> Optional[Trace]:
else:
return Trace(trace_info, trace_data)

# If run_id is provided, add it to the filter string
if run_id:
additional_filter = f"metadata.{TraceMetadataKey.SOURCE_RUN} = '{run_id}'"
if filter_string:
if TraceMetadataKey.SOURCE_RUN in filter_string:
raise MlflowException(
"You cannot filter by run_id when it is already part of the filter string."
f"Please remove the {TraceMetadataKey.SOURCE_RUN} filter from the filter "
"string and try again.",
error_code=INVALID_PARAMETER_VALUE,
)
filter_string += f" AND {additional_filter}"
else:
filter_string = additional_filter

traces = []
next_max_results = max_results
next_token = page_token
Expand Down
5 changes: 5 additions & 0 deletions mlflow/tracking/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,17 +487,21 @@ def search_traces(
max_results: int = SEARCH_TRACES_DEFAULT_MAX_RESULTS,
order_by: Optional[List[str]] = None,
page_token: Optional[str] = None,
run_id: Optional[str] = None,
) -> PagedList[Trace]:
"""
Return traces that match the given list of search expressions within the experiments.
Args:
experiment_ids: List of experiment ids to scope the search.
filter_string: A search filter string.
max_results: Maximum number of traces desired.
order_by: List of order_by clauses.
page_token: Token specifying the next page of results. It should be obtained from
a ``search_traces`` call.
run_id: A run id to scope the search. When a trace is created under an active run,
it will be associated with the run and you can filter on the run id to retrieve
the trace.
Returns:
A :py:class:`PagedList <mlflow.store.entities.PagedList>` of
Expand All @@ -513,6 +517,7 @@ def search_traces(
max_results=max_results,
order_by=order_by,
page_token=page_token,
run_id=run_id,
)

get_display_handler().display_traces(traces)
Expand Down
42 changes: 42 additions & 0 deletions tests/tracing/test_fluent.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,7 @@ def test_search_traces(mock_client):
assert len(traces) == 10
mock_client.search_traces.assert_called_once_with(
experiment_ids=["1"],
run_id=None,
filter_string="name = 'foo'",
max_results=10,
order_by=["timestamp DESC"],
Expand Down Expand Up @@ -739,6 +740,7 @@ def test_search_traces_with_pagination(mock_client):
assert len(traces) == 30
common_args = {
"experiment_ids": ["1"],
"run_id": None,
"max_results": SEARCH_TRACES_DEFAULT_MAX_RESULTS,
"filter_string": None,
"order_by": None,
Expand All @@ -759,6 +761,7 @@ def test_search_traces_with_default_experiment_id(mock_client):

mock_client.search_traces.assert_called_once_with(
experiment_ids=["123"],
run_id=None,
filter_string=None,
max_results=SEARCH_TRACES_DEFAULT_MAX_RESULTS,
order_by=None,
Expand Down Expand Up @@ -1057,6 +1060,45 @@ def search_traces(self, experiment_ids, *args, **kwargs):
monkeypatch.setattr("mlflow.tracing.fluent.MlflowClient", MockMlflowClient)


def test_search_traces_with_run_id():
    """Verify that ``mlflow.search_traces`` can scope its results by ``run_id``.

    Two traces are created under a first run and three under a second run. The test
    then checks that:

    - a search without ``run_id`` returns all five traces, most recently created first;
    - passing ``run_id`` returns only the traces created under that run;
    - ``run_id`` can be combined with a tag-based ``filter_string``;
    - combining ``run_id`` with a ``filter_string`` that already filters on
      ``metadata.mlflow.sourceRun`` raises ``MlflowException``.
    """

    def _create_trace(name, tags=None):
        # Open a span with the given name and attach the optional tags to its trace
        # while the span is still active.
        with mlflow.start_span(name=name) as span:
            for k, v in (tags or {}).items():
                mlflow.MlflowClient().set_trace_tag(request_id=span.request_id, key=k, value=v)
        return span.request_id

    def _get_names(traces):
        # Each entry of the "tags" column is a mapping; the trace name is looked up
        # under TraceTagKey.TRACE_NAME. Iterate the entries directly rather than by index.
        return [trace_tags.get(TraceTagKey.TRACE_NAME) for trace_tags in traces["tags"].tolist()]

    with mlflow.start_run() as run1:
        _create_trace(name="tr-1")
        _create_trace(name="tr-2", tags={"fruit": "apple"})

    with mlflow.start_run() as run2:
        _create_trace(name="tr-3")
        _create_trace(name="tr-4", tags={"fruit": "banana"})
        _create_trace(name="tr-5", tags={"fruit": "apple"})

    # No run_id: every trace from both runs is returned.
    traces = mlflow.search_traces()
    assert _get_names(traces) == ["tr-5", "tr-4", "tr-3", "tr-2", "tr-1"]

    # run_id alone: only the traces created under the first run.
    traces = mlflow.search_traces(run_id=run1.info.run_id)
    assert _get_names(traces) == ["tr-2", "tr-1"]

    # run_id combined with an additional tag filter.
    traces = mlflow.search_traces(
        run_id=run2.info.run_id,
        filter_string="tag.fruit = 'apple'",
    )
    assert _get_names(traces) == ["tr-5"]

    # A filter string that already constrains the source run conflicts with run_id.
    with pytest.raises(MlflowException, match="You cannot filter by run_id when it is already"):
        mlflow.search_traces(
            run_id=run2.info.run_id,
            filter_string="metadata.mlflow.sourceRun = '123'",
        )


@pytest.mark.parametrize(
"extract_fields",
[
Expand Down

0 comments on commit ba3d931

Please sign in to comment.