diff --git a/app/schema.graphql b/app/schema.graphql
index 71dc80571e..aa5009660d 100644
--- a/app/schema.graphql
+++ b/app/schema.graphql
@@ -44,6 +44,11 @@ type Dimension implements Node {
   dataType: DimensionDataType!
   dataQualityMetric(metric: DataQualityMetric!): Float
 
+  """
+  Returns the observed categories of a categorical dimension (usually a dimension of string values) as a list of unique string labels sorted in lexicographical order. Missing values are excluded. Non-categorical dimensions return an empty list.
+  """
+  categories: [String!]!
+
   """
   Returns the time series of the specified metric for data within timeRange. Data points are generated starting at the end time, are separated by the sampling interval. Each data point is labeled by the end instant of and contains data from their respective evaluation window.
   """
diff --git a/src/phoenix/core/dimension.py b/src/phoenix/core/dimension.py
index 0a16af5ffb..36b969c568 100644
--- a/src/phoenix/core/dimension.py
+++ b/src/phoenix/core/dimension.py
@@ -1,4 +1,9 @@
 from dataclasses import dataclass
+from functools import cached_property
+from itertools import chain
+from typing import Any, Callable, List
+
+import pandas as pd
 
 from .dimension_data_type import DimensionDataType
 from .dimension_type import DimensionType
@@ -9,3 +14,19 @@ class Dimension:
     name: str
     data_type: DimensionDataType
     type: DimensionType
+    # Zero-argument callable returning the pandas series that hold this
+    # dimension's values; it is only invoked when ``categories`` is computed.
+    data: Callable[[], List["pd.Series[Any]"]]
+
+    @cached_property
+    def categories(self) -> List[str]:
+        """Return the sorted, unique string values observed for this dimension.
+
+        Non-categorical dimensions yield an empty list. Values that are not
+        ``str`` instances (for example missing values) are dropped. The
+        result is computed on first access and cached on the instance.
+        """
+        if self.data_type != DimensionDataType.CATEGORICAL:
+            return []
+        observed = set(chain.from_iterable(series.unique() for series in self.data()))
+        return sorted(label for label in observed if isinstance(label, str))
diff --git a/src/phoenix/core/model.py b/src/phoenix/core/model.py
index cfbd57ef6e..e0faf3f964 100644
--- a/src/phoenix/core/model.py
+++ b/src/phoenix/core/model.py
@@ -77,6 +77,20 @@ def _get_dimensions(
                         name=name,
                         data_type=self._infer_dimension_data_type(name),
                         type=dimension_type,
+                        # Call the outer lambda immediately so the lazy loader closes
+                        # over the current value of ``name``, not the variable itself.
+                        data=(
+                            lambda column: (
+                                lambda: (
+                                    [primary_dataset.dataframe.loc[:, column]]
+                                    if reference_dataset is None
+                                    else [
+                                        primary_dataset.dataframe.loc[:, column],
+                                        reference_dataset.dataframe.loc[:, column],
+                                    ]
+                                )
+                            )
+                        )(name),
                     )
                 )
 
diff --git a/src/phoenix/server/api/types/Dimension.py b/src/phoenix/server/api/types/Dimension.py
index 56e3de0436..5ec3dddfc8 100644
--- a/src/phoenix/server/api/types/Dimension.py
+++ b/src/phoenix/server/api/types/Dimension.py
@@ -1,5 +1,5 @@
 from datetime import timedelta
-from typing import Optional
+from typing import List, Optional
 
 import strawberry
 from strawberry.types import Info
@@ -41,6 +41,21 @@ async def dataQualityMetric(
             return await info.context.loaders.percent_empty.load(dimension_name)
         raise NotImplementedError(f"Metric {metric} is not implemented.")
 
+    @strawberry.field(
+        description=(
+            "Returns the observed categories of a categorical dimension (usually a dimension of"
+            " string values) as a list of unique string labels sorted in lexicographical order."
+            " Missing values are excluded. Non-categorical dimensions return an empty list."
+        )
+    )  # type: ignore  # https://github.com/strawberry-graphql/strawberry/issues/1929
+    def categories(self, info: Info[Context, None]) -> List[str]:
+        # Use the first model dimension whose name matches this node; when no
+        # dimension matches, fall back to an empty list instead of raising.
+        return next(
+            (dim.categories for dim in info.context.model.dimensions if dim.name == self.name),
+            [],
+        )
+
     @strawberry.field(
         description=(
             "Returns the time series of the specified metric for data within timeRange. Data points"