refactor: rename dataset type to dataset role #461

Merged
merged 3 commits on Mar 29, 2023
4 changes: 2 additions & 2 deletions app/src/@types/dataset.d.ts
@@ -1,5 +1,5 @@
/**
* The type of dataset. E.g. primary is the dataset under test, reference is the
* The role of dataset. E.g. primary is the dataset under test, reference is the
* dataset used for comparison.
*/
declare type DatasetType = "primary" | "reference";
declare type DatasetRole = "primary" | "reference";
2 changes: 1 addition & 1 deletion app/src/components/filter/ReferenceDatasetTimeRange.tsx
@@ -12,7 +12,7 @@ import {

const timeFormatter = timeFormat("%x %X");
type ReferenceDatasetTimeRangeProps = {
datasetType: DatasetType;
datasetRole: DatasetRole;
/**
* The bookend times of the dataset
*/
6 changes: 3 additions & 3 deletions app/src/components/pointcloud/EventItem.tsx
@@ -2,7 +2,7 @@ import React from "react";
import { transparentize } from "polished";
import { css } from "@emotion/react";

import { DatasetType } from "@phoenix/types";
import { DatasetRole } from "@phoenix/types";

type EventItemProps = {
/**
@@ -16,7 +16,7 @@ type EventItemProps = {
/**
* Which dataset the event belongs to
*/
datasetType: DatasetType;
datasetRole: DatasetRole;
/**
* event handler for when the user clicks on the event item
*/
@@ -82,7 +82,7 @@ export function EventItem(props: EventItemProps) {
)}
<div
data-testid="event-association"
data-dataset-type={props.datasetType}
data-dataset-type={props.datasetRole}
css={css`
height: var(--px-gradient-bar-height);
flex: none;
2 changes: 1 addition & 1 deletion app/src/pages/embedding/Embedding.tsx
@@ -238,7 +238,7 @@ function EmbeddingMain() {
<PrimaryDatasetTimeRange />
{referenceDataset ? (
<ReferenceDatasetTimeRange
datasetType="reference"
datasetRole="reference"
timeRange={{
start: new Date(referenceDataset.startTime),
end: new Date(referenceDataset.endTime),
10 changes: 5 additions & 5 deletions app/src/pages/embedding/PointSelectionPanelContent.tsx
@@ -20,7 +20,7 @@ import {
} from "@phoenix/components/pointcloud";
import { SelectionDisplay } from "@phoenix/constants/pointCloudConstants";
import { usePointCloudContext } from "@phoenix/contexts";
import { DatasetType } from "@phoenix/types";
import { DatasetRole } from "@phoenix/types";

import {
PointSelectionPanelContentQuery,
@@ -289,16 +289,16 @@ function SelectionGridView(props: SelectionGridViewProps) {
const data = pointIdToDataMap.get(event.id);
const { rawData = null, linkToData = null } =
data?.embeddingMetadata ?? {};
const datasetType = event.id.includes("PRIMARY")
? DatasetType.primary
: DatasetType.reference;
const datasetRole = event.id.includes("PRIMARY")
? DatasetRole.primary
: DatasetRole.reference;
const color = pointGroupColors[pointIdToGroup[event.id]];
return (
<li key={idx}>
<EventItem
rawData={rawData}
linkToData={linkToData}
datasetType={datasetType}
datasetRole={datasetRole}
onClick={() => {
onItemSelected(event.id);
}}
2 changes: 1 addition & 1 deletion app/src/pages/home/Home.tsx
@@ -48,7 +48,7 @@ export function Home(_props: HomePageProps) {
<PrimaryDatasetTimeRange />
{referenceDataset ? (
<ReferenceDatasetTimeRange
datasetType="reference"
datasetRole="reference"
timeRange={{
start: new Date(referenceDataset.startTime),
end: new Date(referenceDataset.endTime),
2 changes: 1 addition & 1 deletion app/src/types/dataset.ts
@@ -1,4 +1,4 @@
export enum DatasetType {
export enum DatasetRole {
primary = "primary",
reference = "reference",
}
4 changes: 2 additions & 2 deletions src/phoenix/core/embedding_dimension.py
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from typing import Set

from phoenix.datasets.dataset import DatasetType
from phoenix.datasets.dataset import DatasetRole
from phoenix.datasets.event import EventId


@@ -29,7 +29,7 @@ def calculate_drift_ratio(events: Set[EventId]) -> float:
reference_point_count = 0

for event in events:
if event.dataset_id == DatasetType.PRIMARY:
if event.dataset_id == DatasetRole.PRIMARY:
primary_point_count += 1
else:
reference_point_count += 1
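For orientation, here is a minimal sketch of the tally that feeds calculate_drift_ratio after the rename. The EventId fields come from the event.py hunk further down; the events themselves are made up for illustration, and the final ratio computation sits outside this hunk.

```python
from phoenix.datasets.dataset import DatasetRole
from phoenix.datasets.event import EventId

# Illustrative events only.
events = {
    EventId(row_id=0, dataset_id=DatasetRole.PRIMARY),
    EventId(row_id=1, dataset_id=DatasetRole.PRIMARY),
    EventId(row_id=2, dataset_id=DatasetRole.REFERENCE),
}

# Same tally as the loop above, written as a comprehension.
primary_point_count = sum(
    1 for event in events if event.dataset_id == DatasetRole.PRIMARY
)
reference_point_count = len(events) - primary_point_count
print(primary_point_count, reference_point_count)  # 2 1
```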
14 changes: 7 additions & 7 deletions src/phoenix/core/model.py
@@ -7,7 +7,7 @@
from phoenix.datasets import Dataset
from phoenix.datasets.schema import EmbeddingColumnNames, EmbeddingFeatures

from ..datasets.dataset import DatasetType
from ..datasets.dataset import DatasetRole
from .dimension import Dimension
from .dimension_data_type import DimensionDataType
from .dimension_type import DimensionType
@@ -23,8 +23,8 @@ def __init__(self, primary_dataset: Dataset, reference_dataset: Optional[Dataset
self.primary_dataset, self.reference_dataset
)
self.__datasets = {
DatasetType.PRIMARY: primary_dataset,
DatasetType.REFERENCE: reference_dataset,
DatasetRole.PRIMARY: primary_dataset,
DatasetRole.REFERENCE: reference_dataset,
}

@property
@@ -114,7 +114,7 @@ def _infer_dimension_data_type(self, dimension_name: str) -> DimensionDataType:

def export_events_as_parquet_file(
self,
rows: Mapping[DatasetType, Iterable[int]],
rows: Mapping[DatasetRole, Iterable[int]],
parquet_file: BinaryIO,
) -> None:
"""
@@ -123,14 +123,14 @@ def export_events_as_parquet_file(

Parameters
----------
rows: Mapping[DatasetType, Iterable[int]]
rows: Mapping[DatasetRole, Iterable[int]]
mapping of dataset type to list of row numbers
parquet_file: file handle
output parquet file handle
"""
pd.concat(
dataset.get_events(rows.get(dataset_type, ()))
for dataset_type, dataset in self.__datasets.items()
dataset.get_events(rows.get(dataset_role, ()))
for dataset_role, dataset in self.__datasets.items()
if dataset is not None
).to_parquet(parquet_file, index=False)

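Since export_events_as_parquet_file now takes a Mapping keyed by DatasetRole, callers select rows per role rather than per "type". A hedged usage sketch — the model object and the row numbers are hypothetical; only the method and its signature come from the hunk above:

```python
from typing import BinaryIO, Iterable, Mapping

from phoenix.datasets.dataset import DatasetRole


def export_selection(model, parquet_file: BinaryIO) -> None:
    # Hypothetical row selection for an already-constructed model object
    # from phoenix.core.model.
    rows: Mapping[DatasetRole, Iterable[int]] = {
        DatasetRole.PRIMARY: [0, 3, 7],
        DatasetRole.REFERENCE: [2],
    }
    model.export_events_as_parquet_file(rows, parquet_file)
```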
2 changes: 1 addition & 1 deletion src/phoenix/datasets/dataset.py
@@ -519,7 +519,7 @@ def _create_and_normalize_dataframe_and_schema(
return parsed_dataframe, parsed_schema


class DatasetType(Enum):
class DatasetRole(Enum):
PRIMARY = 0
REFERENCE = 1

4 changes: 2 additions & 2 deletions src/phoenix/datasets/event.py
@@ -1,13 +1,13 @@
from typing import NamedTuple

from .dataset import DatasetType
from .dataset import DatasetRole


class EventId(NamedTuple):
"""Identifies an event."""

row_id: int = 0
dataset_id: DatasetType = DatasetType.PRIMARY
dataset_id: DatasetRole = DatasetRole.PRIMARY

def __str__(self) -> str:
return ":".join(map(str, self))
12 changes: 6 additions & 6 deletions src/phoenix/server/api/types/Dataset.py
@@ -9,7 +9,7 @@
from phoenix.core.dimension import Dimension as CoreDimension
from phoenix.core.dimension_type import DimensionType
from phoenix.datasets import Dataset as InternalDataset
from phoenix.datasets.dataset import DatasetType
from phoenix.datasets.dataset import DatasetRole

from ..context import Context
from ..input_types.DimensionInput import DimensionInput
@@ -23,7 +23,7 @@ class Dataset:
start_time: datetime = strawberry.field(description="The start bookend of the data")
end_time: datetime = strawberry.field(description="The end bookend of the data")
dataset: strawberry.Private[InternalDataset]
type: strawberry.Private[DatasetType]
role: strawberry.Private[DatasetRole]

@strawberry.field
def events(
@@ -39,9 +39,9 @@ def events(
if not event_ids:
return []
row_ids = parse_event_ids(event_ids)
if len(row_ids) > 1 or self.type not in row_ids:
if len(row_ids) > 1 or self.role not in row_ids:
raise ValueError("eventIds contains IDs from incorrect dataset.")
row_indexes = row_ids.get(self.type, ())
row_indexes = row_ids.get(self.role, ())
dataframe = self.dataset.dataframe
schema = self.dataset.schema
requested_gql_dimensions = _get_requested_features_and_tags(
@@ -68,7 +68,7 @@ def events(
return [
create_event(
row_index=row_index,
dataset_type=self.type,
dataset_role=self.role,
row=dataframe.iloc[row_index, column_indexes],
schema=schema,
dimensions=requested_gql_dimensions,
@@ -85,7 +85,7 @@ def to_gql_dataset(dataset: InternalDataset, type: Literal["primary", "reference"
name=dataset.name,
start_time=dataset.start_time,
end_time=dataset.end_time,
type=DatasetType.PRIMARY if type == "primary" else DatasetType.REFERENCE,
role=DatasetRole.PRIMARY if type == "primary" else DatasetRole.REFERENCE,
dataset=dataset,
)

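The events resolver now carries the dataset's role and rejects IDs that belong to the other dataset, which is what test_event_ids_from_incorrect_dataset_returns_error below exercises. A condensed sketch of that guard, with the import path assumed from the file layout and illustrative IDs:

```python
from phoenix.datasets.dataset import DatasetRole
from phoenix.server.api.types.Event import parse_event_ids

role = DatasetRole.PRIMARY  # role of the dataset being queried

# All IDs belong to the primary dataset, so the guard passes.
row_ids = parse_event_ids(["0:DatasetRole.PRIMARY", "2:DatasetRole.PRIMARY"])
assert not (len(row_ids) > 1 or role not in row_ids)
row_indexes = row_ids.get(role, ())  # [0, 2]

# Mixed roles trip the guard; the resolver raises
# ValueError("eventIds contains IDs from incorrect dataset.").
mixed = parse_event_ids(["0:DatasetRole.PRIMARY", "1:DatasetRole.REFERENCE"])
assert len(mixed) > 1 or role not in mixed
```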
14 changes: 7 additions & 7 deletions src/phoenix/server/api/types/EmbeddingDimension.py
@@ -12,7 +12,7 @@

from phoenix.core import EmbeddingDimension as CoreEmbeddingDimension
from phoenix.datasets import Dataset
from phoenix.datasets.dataset import DatasetType
from phoenix.datasets.dataset import DatasetRole
from phoenix.datasets.errors import SchemaError
from phoenix.datasets.event import EventId
from phoenix.metrics.timeseries import row_interval_from_sorted_time_index
@@ -197,8 +197,8 @@ def UMAPPoints(
] = DEFAULT_CLUSTER_SELECTION_EPSILON,
) -> UMAPPoints:
datasets = {
DatasetType.PRIMARY: info.context.model.primary_dataset,
DatasetType.REFERENCE: info.context.model.reference_dataset,
DatasetRole.PRIMARY: info.context.model.primary_dataset,
DatasetRole.REFERENCE: info.context.model.reference_dataset,
}

data = {}
@@ -207,7 +207,7 @@ def UMAPPoints(
continue
dataframe = dataset.dataframe
row_id_start, row_id_stop = 0, len(dataframe)
if dataset_id == DatasetType.PRIMARY:
if dataset_id == DatasetRole.PRIMARY:
row_id_start, row_id_stop = row_interval_from_sorted_time_index(
time_index=dataframe.index,
time_start=time_range.start,
@@ -300,11 +300,11 @@ def UMAPPoints(
)
)

has_reference_data = datasets[DatasetType.REFERENCE] is not None
has_reference_data = datasets[DatasetRole.REFERENCE] is not None

return UMAPPoints(
data=points[DatasetType.PRIMARY],
reference_data=points[DatasetType.REFERENCE],
data=points[DatasetRole.PRIMARY],
reference_data=points[DatasetRole.REFERENCE],
clusters=to_gql_clusters(cluster_membership, has_reference_data=has_reference_data),
)

14 changes: 7 additions & 7 deletions src/phoenix/server/api/types/Event.py
@@ -6,7 +6,7 @@
from strawberry import ID

from phoenix.datasets import Schema
from phoenix.datasets.dataset import DatasetType
from phoenix.datasets.dataset import DatasetRole
from phoenix.datasets.event import EventId

from .Dimension import Dimension
@@ -21,21 +21,21 @@ class Event:
dimensions: List[DimensionWithValue]


def parse_event_ids(event_ids: List[ID]) -> Dict[DatasetType, List[int]]:
def parse_event_ids(event_ids: List[ID]) -> Dict[DatasetRole, List[int]]:
"""
Parses event IDs and returns the corresponding row indexes.
"""
row_indexes = defaultdict(list)
for event_id in event_ids:
row_index, dataset_type_str = str(event_id).split(":")
dataset_type = DatasetType[dataset_type_str.split(".")[-1]]
row_indexes[dataset_type].append(int(row_index))
row_index, dataset_role_str = str(event_id).split(":")
dataset_role = DatasetRole[dataset_role_str.split(".")[-1]]
row_indexes[dataset_role].append(int(row_index))
return row_indexes


def create_event(
row_index: int,
dataset_type: "DatasetType",
dataset_role: "DatasetRole",
row: "Series[Any]",
schema: Schema,
dimensions: List[Dimension],
@@ -63,7 +63,7 @@ def create_event(
]

return Event(
id=ID(str(EventId(row_id=row_index, dataset_id=dataset_type))),
id=ID(str(EventId(row_id=row_index, dataset_id=dataset_role))),
eventMetadata=event_metadata,
dimensions=dimensions_with_values,
)
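parse_event_ids splits each ID on ":", takes the enum member name after the final ".", and indexes DatasetRole by that name, so the parsed mapping round-trips the strings produced by EventId.__str__. A small sketch (import path assumed from the file layout):

```python
from phoenix.datasets.dataset import DatasetRole
from phoenix.datasets.event import EventId
from phoenix.server.api.types.Event import parse_event_ids

event_ids = [
    str(EventId(row_id=3, dataset_id=DatasetRole.REFERENCE)),  # "3:DatasetRole.REFERENCE"
    str(EventId(row_id=5, dataset_id=DatasetRole.REFERENCE)),  # "5:DatasetRole.REFERENCE"
]
assert parse_event_ids(event_ids) == {DatasetRole.REFERENCE: [3, 5]}
```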
12 changes: 6 additions & 6 deletions tests/server/api/types/test_dataset.py
@@ -52,7 +52,7 @@ def test_no_input_dimensions_correctly_selects_event_ids_and_all_features_and_ta
query=self._get_events_query("primaryDataset"),
context_value=context_factory(primary_dataset, reference_dataset),
variable_values={
"eventIds": ["0:DatasetType.PRIMARY"],
"eventIds": ["0:DatasetRole.PRIMARY"],
},
)
assert result.errors is None
@@ -90,7 +90,7 @@ def test_input_dimensions_correctly_selects_event_ids_and_dimensions(
query=self._get_events_query("referenceDataset"),
context_value=context_factory(primary_dataset, reference_dataset),
variable_values={
"eventIds": ["1:DatasetType.REFERENCE", "2:DatasetType.REFERENCE"],
"eventIds": ["1:DatasetRole.REFERENCE", "2:DatasetRole.REFERENCE"],
"dimensions": [
{"name": "tag0", "type": "tag"},
],
@@ -165,7 +165,7 @@ def test_empty_input_dimensions_returns_events_with_empty_dimensions(
query=self._get_events_query("referenceDataset"),
context_value=context_factory(primary_dataset, reference_dataset),
variable_values={
"eventIds": ["1:DatasetType.REFERENCE"],
"eventIds": ["1:DatasetRole.REFERENCE"],
"dimensions": [],
},
)
@@ -195,7 +195,7 @@ def test_event_ids_from_incorrect_dataset_returns_error(
query=self._get_events_query("primaryDataset"),
context_value=context_factory(primary_dataset, reference_dataset),
variable_values={
"eventIds": ["0:DatasetType.PRIMARY", "1:DatasetType.REFERENCE"],
"eventIds": ["0:DatasetRole.PRIMARY", "1:DatasetRole.REFERENCE"],
},
)
assert result.errors is not None
@@ -204,7 +204,7 @@ def test_event_ids_from_incorrect_dataset_returns_error(
assert result.data is None

@staticmethod
def _get_events_query(dataset_type: Literal["primaryDataset", "referenceDataset"]) -> str:
def _get_events_query(dataset_role: Literal["primaryDataset", "referenceDataset"]) -> str:
"""
Returns a formatted events query for the input dataset type.
"""
Expand All @@ -231,7 +231,7 @@ def _get_events_query(dataset_type: Literal["primaryDataset", "referenceDataset"
}
}
"""
% dataset_type
% dataset_role
)

@staticmethod
Expand Down