From 904a6221b7437d146c5b328b00031f69db40e217 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Thu, 14 Sep 2023 03:17:00 -0700 Subject: [PATCH] chained model evaluation (#58) * multi model evaluation * fix test * chained model evaluation * cache model output * v bump * support dict model output * rename * add get_model_output support * fix * 0.0.18-beta1 --- pyproject.toml | 2 +- .../test_pinning_business_logic.py | 1 + tests/scenarios/test_product_ranking.py | 9 +- wyvern/__init__.py | 9 +- wyvern/components/component.py | 17 +- .../models/model_chain_component.py | 54 +++++++ wyvern/components/models/model_component.py | 153 ++++++------------ .../components/models/modelbit_component.py | 7 +- wyvern/components/ranking_pipeline.py | 3 +- wyvern/entities/model_entities.py | 105 ++++++++++++ wyvern/exceptions.py | 4 + wyvern/wyvern_request.py | 54 ++++++- 12 files changed, 292 insertions(+), 126 deletions(-) create mode 100644 wyvern/components/models/model_chain_component.py create mode 100644 wyvern/entities/model_entities.py diff --git a/pyproject.toml b/pyproject.toml index 6924830..db33e04 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "wyvern-ai" -version = "0.0.17" +version = "0.0.18-beta1" description = "" authors = ["Wyvern AI "] readme = "README.md" diff --git a/tests/components/business_logic/test_pinning_business_logic.py b/tests/components/business_logic/test_pinning_business_logic.py index f0b945f..edc30a7 100644 --- a/tests/components/business_logic/test_pinning_business_logic.py +++ b/tests/components/business_logic/test_pinning_business_logic.py @@ -66,6 +66,7 @@ def __init__(self): entity_store={}, events=[], feature_map=FeatureMap(feature_map={}), + model_output_map={}, ), ) return await pipeline.execute(request) diff --git a/tests/scenarios/test_product_ranking.py b/tests/scenarios/test_product_ranking.py index 68ff3fe..a5c233a 100644 --- a/tests/scenarios/test_product_ranking.py +++ b/tests/scenarios/test_product_ranking.py @@ -13,11 +13,7 @@ RealtimeFeatureComponent, RealtimeFeatureRequest, ) -from wyvern.components.models.model_component import ( - ModelComponent, - ModelInput, - ModelOutput, -) +from wyvern.components.models.model_component import ModelComponent from wyvern.components.pipeline_component import PipelineComponent from wyvern.config import settings from wyvern.core.compression import wyvern_encode @@ -26,6 +22,7 @@ from wyvern.entities.feature_entities import FeatureData, FeatureMap from wyvern.entities.identifier import Identifier from wyvern.entities.identifier_entities import ProductEntity, WyvernEntity +from wyvern.entities.model_entities import ModelInput, ModelOutput from wyvern.entities.request import BaseWyvernRequest from wyvern.service import WyvernService from wyvern.wyvern_request import WyvernRequest @@ -387,6 +384,7 @@ async def test_hydrate(mock_redis): json=json_input, headers={}, entity_store={}, + model_output_map={}, events=[], feature_map=FeatureMap(feature_map={}), ) @@ -450,6 +448,7 @@ async def test_hydrate__duplicate_brand(mock_redis__duplicate_brand): entity_store={}, events=[], feature_map=FeatureMap(feature_map={}), + model_output_map={}, ) request_context.set(test_wyvern_request) diff --git a/wyvern/__init__.py b/wyvern/__init__.py index 08823ba..cc7906c 100644 --- a/wyvern/__init__.py +++ b/wyvern/__init__.py @@ -2,11 +2,8 @@ from wyvern.components.features.realtime_features_component import ( RealtimeFeatureComponent, ) -from wyvern.components.models.model_component import ( - ModelComponent, - ModelInput, - ModelOutput, -) +from wyvern.components.models.model_chain_component import ModelChainComponent +from wyvern.components.models.model_component import ModelComponent from wyvern.components.pipeline_component import PipelineComponent from wyvern.components.ranking_pipeline import ( RankingPipeline, @@ -23,6 +20,7 @@ WyvernDataModel, WyvernEntity, ) +from wyvern.entities.model_entities import ModelInput, ModelOutput from wyvern.feature_store.feature_server import generate_wyvern_store_app from wyvern.service import WyvernService from wyvern.wyvern_logging import setup_logging @@ -41,6 +39,7 @@ "FeatureMap", "Identifier", "IdentifierType", + "ModelChainComponent", "ModelComponent", "ModelInput", "ModelOutput", diff --git a/wyvern/components/component.py b/wyvern/components/component.py index b660097..ded81ca 100644 --- a/wyvern/components/component.py +++ b/wyvern/components/component.py @@ -5,7 +5,7 @@ import logging from enum import Enum from functools import cached_property -from typing import Dict, Generic, Optional, Set +from typing import Dict, Generic, List, Optional, Set, Union from uuid import uuid4 from wyvern import request_context @@ -177,3 +177,18 @@ def get_all_features( if not feature_data: return {} return feature_data.features + + def get_model_output( + self, + model_name: str, + identifier: Identifier, + ) -> Optional[ + Union[ + float, + str, + List[float], + Dict[str, Optional[Union[float, str, list[float]]]], + ] + ]: + current_request = request_context.ensure_current_request() + return current_request.get_model_output(model_name, identifier) diff --git a/wyvern/components/models/model_chain_component.py b/wyvern/components/models/model_chain_component.py new file mode 100644 index 0000000..a7d10e1 --- /dev/null +++ b/wyvern/components/models/model_chain_component.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +from functools import cached_property +from typing import Optional, Set + +from wyvern.components.models.model_component import ModelComponent +from wyvern.entities.model_entities import MODEL_INPUT, MODEL_OUTPUT, ChainedModelInput +from wyvern.exceptions import MissingModelChainOutputError + + +class ModelChainComponent(ModelComponent[MODEL_INPUT, MODEL_OUTPUT]): + """ + Model chaining allows you to chain models together so that the output of one model can be the input to another model + + For all the models in the chain, all the request and entities in the model input are the same + """ + + def __init__(self, *upstreams: ModelComponent, name: Optional[str] = None): + super().__init__(*upstreams, name=name) + self.chain = upstreams + + @cached_property + def manifest_feature_names(self) -> Set[str]: + feature_names: Set[str] = set() + for model in self.chain: + feature_names = feature_names.union(model.manifest_feature_names) + return feature_names + + async def inference(self, input: MODEL_INPUT, **kwargs) -> MODEL_OUTPUT: + output = None + prev_model: Optional[ModelComponent] = None + for model in self.chain: + curr_input: ChainedModelInput + if prev_model is not None and output is not None: + curr_input = ChainedModelInput( + request=input.request, + entities=input.entities, + upstream_model_name=prev_model.name, + upstream_model_output=output.data, + ) + else: + curr_input = ChainedModelInput( + request=input.request, + entities=input.entities, + upstream_model_name=None, + upstream_model_output={}, + ) + output = await model.execute(curr_input, **kwargs) + prev_model = model + + if output is None: + raise MissingModelChainOutputError() + + # TODO: do type checking to make sure the output is of the correct type + return output diff --git a/wyvern/components/models/model_component.py b/wyvern/components/models/model_component.py index 6aac719..3eaa4f4 100644 --- a/wyvern/components/models/model_component.py +++ b/wyvern/components/models/model_component.py @@ -3,21 +3,9 @@ import logging from datetime import datetime from functools import cached_property -from typing import ( - Dict, - Generic, - List, - Optional, - Sequence, - Set, - Type, - TypeVar, - Union, - get_args, -) +from typing import Dict, List, Optional, Sequence, Set, Type, Union, get_args from pydantic import BaseModel -from pydantic.generics import GenericModel from wyvern import request_context from wyvern.components.component import Component @@ -25,19 +13,9 @@ from wyvern.config import settings from wyvern.entities.identifier import Identifier from wyvern.entities.identifier_entities import WyvernEntity +from wyvern.entities.model_entities import MODEL_INPUT, MODEL_OUTPUT from wyvern.entities.request import BaseWyvernRequest from wyvern.event_logging import event_logger -from wyvern.exceptions import WyvernModelInputError -from wyvern.wyvern_typing import GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY - -MODEL_OUTPUT_DATA_TYPE = TypeVar( - "MODEL_OUTPUT_DATA_TYPE", - bound=Union[float, str, List[float]], -) -""" -MODEL_OUTPUT_DATA_TYPE is the type of the output of the model. It can be a float, a string, or a list of floats -(e.g. a list of probabilities, embeddings, etc.) -""" logger = logging.getLogger(__name__) @@ -58,6 +36,7 @@ class ModelEventData(BaseModel): model_output: str entity_identifier: Optional[str] = None entity_identifier_type: Optional[str] = None + target: Optional[str] = None class ModelEvent(LoggedEvent[ModelEventData]): @@ -71,74 +50,6 @@ class ModelEvent(LoggedEvent[ModelEventData]): event_type: EventType = EventType.MODEL -class ModelOutput(GenericModel, Generic[MODEL_OUTPUT_DATA_TYPE]): - """ - This class defines the output of a model. - - Args: - data: A dictionary mapping entity identifiers to model outputs. The model outputs can also be None. - model_name: The name of the model. This is optional. - """ - - data: Dict[Identifier, Optional[MODEL_OUTPUT_DATA_TYPE]] - model_name: Optional[str] = None - - def get_entity_output( - self, - identifier: Identifier, - ) -> Optional[MODEL_OUTPUT_DATA_TYPE]: - """ - Get the model output for a given entity identifier. - - Args: - identifier: The identifier of the entity. - - Returns: - The model output for the given entity identifier. This can also be None if the model output is None. - """ - return self.data.get(identifier) - - -class ModelInput(GenericModel, Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY]): - """ - This class defines the input to a model. - - Args: - request: The request that will be used to generate the model input. - entities: A list of entities that will be used to generate the model input. - """ - - request: REQUEST_ENTITY - entities: List[GENERALIZED_WYVERN_ENTITY] = [] - - @property - def first_entity(self) -> GENERALIZED_WYVERN_ENTITY: - """ - Get the first entity in the list of entities. This is useful when you know that there is only one entity. - - Returns: - The first entity in the list of entities. - """ - if not self.entities: - raise WyvernModelInputError(model_input=self) - return self.entities[0] - - @property - def first_identifier(self) -> Identifier: - """ - Get the identifier of the first entity in the list of entities. This is useful when you know that there is only - one entity. - - Returns: - The identifier of the first entity in the list of entities. - """ - return self.first_entity.identifier - - -MODEL_INPUT = TypeVar("MODEL_INPUT", bound=ModelInput) -MODEL_OUTPUT = TypeVar("MODEL_OUTPUT", bound=ModelOutput) - - class ModelComponent( Component[ MODEL_INPUT, @@ -155,11 +66,14 @@ def __init__( self, *upstreams, name: Optional[str] = None, + cache_output: bool = True, ): super().__init__(*upstreams, name=name) self.model_input_type = self.get_type_args_simple(0) self.model_output_type = self.get_type_args_simple(1) + self.cache_output = cache_output + @classmethod def get_type_args_simple(cls, index: int) -> Type: """ @@ -185,26 +99,51 @@ async def execute(self, input: MODEL_INPUT, **kwargs) -> MODEL_OUTPUT: """ The model_name and model_score will be automatically logged """ - api_source = request_context.ensure_current_request().url_path + wyvern_request = request_context.ensure_current_request() + api_source = wyvern_request.url_path request_id = input.request.request_id model_output = await self.inference(input, **kwargs) + if self.cache_output: + wyvern_request.cache_model_output(self.name, model_output.data) + def events_generator() -> List[ModelEvent]: timestamp = datetime.utcnow() - return [ - ModelEvent( - request_id=request_id, - api_source=api_source, - event_timestamp=timestamp, - event_data=ModelEventData( - model_name=model_output.model_name or self.__class__.__name__, - model_output=str(output), - entity_identifier=identifier.identifier, - entity_identifier_type=identifier.identifier_type, - ), - ) - for identifier, output in model_output.data.items() - ] + all_events: List[ModelEvent] = [] + for identifier, output in model_output.data.items(): + if isinstance(output, dict): + for key, value in output.items(): + all_events.append( + ModelEvent( + request_id=request_id, + api_source=api_source, + event_timestamp=timestamp, + event_data=ModelEventData( + model_name=model_output.model_name + or self.__class__.__name__, + model_output=str(value), + target=key, + entity_identifier=identifier.identifier, + entity_identifier_type=identifier.identifier_type, + ), + ), + ) + else: + all_events.append( + ModelEvent( + request_id=request_id, + api_source=api_source, + event_timestamp=timestamp, + event_data=ModelEventData( + model_name=model_output.model_name + or self.__class__.__name__, + model_output=str(output), + entity_identifier=identifier.identifier, + entity_identifier_type=identifier.identifier_type, + ), + ), + ) + return all_events event_logger.log_events(events_generator) # type: ignore diff --git a/wyvern/components/models/modelbit_component.py b/wyvern/components/models/modelbit_component.py index 92d60c5..9b1ada1 100644 --- a/wyvern/components/models/modelbit_component.py +++ b/wyvern/components/models/modelbit_component.py @@ -4,15 +4,12 @@ from functools import cached_property from typing import Any, Dict, List, Optional, Set, Tuple, TypeAlias, Union -from wyvern.components.models.model_component import ( - MODEL_INPUT, - MODEL_OUTPUT, - ModelComponent, -) +from wyvern.components.models.model_component import ModelComponent from wyvern.config import settings from wyvern.core.http import aiohttp_client from wyvern.entities.identifier import Identifier from wyvern.entities.identifier_entities import WyvernEntity +from wyvern.entities.model_entities import MODEL_INPUT, MODEL_OUTPUT from wyvern.entities.request import BaseWyvernRequest from wyvern.exceptions import ( WyvernModelbitTokenMissingError, diff --git a/wyvern/components/ranking_pipeline.py b/wyvern/components/ranking_pipeline.py index bdbbafd..3625b40 100644 --- a/wyvern/components/ranking_pipeline.py +++ b/wyvern/components/ranking_pipeline.py @@ -13,7 +13,7 @@ ImpressionEventLoggingComponent, ImpressionEventLoggingRequest, ) -from wyvern.components.models.model_component import ModelComponent, ModelInput +from wyvern.components.models.model_component import ModelComponent from wyvern.components.pagination.pagination_component import ( PaginationComponent, PaginationRequest, @@ -22,6 +22,7 @@ from wyvern.components.pipeline_component import PipelineComponent from wyvern.entities.candidate_entities import ScoredCandidate from wyvern.entities.identifier_entities import QueryEntity +from wyvern.entities.model_entities import ModelInput from wyvern.entities.request import BaseWyvernRequest from wyvern.event_logging import event_logger from wyvern.wyvern_typing import WYVERN_ENTITY diff --git a/wyvern/entities/model_entities.py b/wyvern/entities/model_entities.py new file mode 100644 index 0000000..71f5e8b --- /dev/null +++ b/wyvern/entities/model_entities.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- +from typing import Dict, Generic, List, Optional, TypeVar, Union + +from pydantic.generics import GenericModel + +from wyvern.entities.identifier import Identifier +from wyvern.exceptions import WyvernModelInputError +from wyvern.wyvern_typing import GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY + +MODEL_OUTPUT_DATA_TYPE = TypeVar( + "MODEL_OUTPUT_DATA_TYPE", + bound=Union[ + float, + str, + List[float], + Dict[str, Optional[Union[float, str, list[float]]]], + ], +) +""" +MODEL_OUTPUT_DATA_TYPE is the type of the output of the model. It can be a float, a string, or a list of floats +(e.g. a list of probabilities, embeddings, etc.) +""" + + +class ModelOutput(GenericModel, Generic[MODEL_OUTPUT_DATA_TYPE]): + """ + This class defines the output of a model. + + Args: + data: A dictionary mapping entity identifiers to model outputs. The model outputs can also be None. + model_name: The name of the model. This is optional. + """ + + data: Dict[Identifier, Optional[MODEL_OUTPUT_DATA_TYPE]] + model_name: Optional[str] = None + + def get_entity_output( + self, + identifier: Identifier, + ) -> Optional[MODEL_OUTPUT_DATA_TYPE]: + """ + Get the model output for a given entity identifier. + + Args: + identifier: The identifier of the entity. + + Returns: + The model output for the given entity identifier. This can also be None if the model output is None. + """ + return self.data.get(identifier) + + +class ModelInput(GenericModel, Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY]): + """ + This class defines the input to a model. + + Args: + request: The request that will be used to generate the model input. + entities: A list of entities that will be used to generate the model input. + """ + + request: REQUEST_ENTITY + entities: List[GENERALIZED_WYVERN_ENTITY] = [] + + @property + def first_entity(self) -> GENERALIZED_WYVERN_ENTITY: + """ + Get the first entity in the list of entities. This is useful when you know that there is only one entity. + + Returns: + The first entity in the list of entities. + """ + if not self.entities: + raise WyvernModelInputError(model_input=self) + return self.entities[0] + + @property + def first_identifier(self) -> Identifier: + """ + Get the identifier of the first entity in the list of entities. This is useful when you know that there is only + one entity. + + Returns: + The identifier of the first entity in the list of entities. + """ + return self.first_entity.identifier + + +MODEL_INPUT = TypeVar("MODEL_INPUT", bound=ModelInput) +MODEL_OUTPUT = TypeVar("MODEL_OUTPUT", bound=ModelOutput) + + +class ChainedModelInput(ModelInput, Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY]): + upstream_model_output: Dict[ + Identifier, + Optional[ + Union[ + float, + str, + List[float], + Dict[str, Optional[Union[float, str, list[float]]]], + ] + ], + ] + upstream_model_name: Optional[str] = None diff --git a/wyvern/exceptions.py b/wyvern/exceptions.py index 2007d88..68cee4e 100644 --- a/wyvern/exceptions.py +++ b/wyvern/exceptions.py @@ -154,3 +154,7 @@ class ExperimentationClientInitializationError(WyvernError): class EntityColumnMissingError(WyvernError): message = "Entity column {entity} is missing in the entity data" + + +class MissingModelChainOutputError(WyvernError): + message = "Model chain output is missing" diff --git a/wyvern/wyvern_request.py b/wyvern/wyvern_request.py index f80342e..690d84a 100644 --- a/wyvern/wyvern_request.py +++ b/wyvern/wyvern_request.py @@ -2,7 +2,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union from urllib.parse import urlparse import fastapi @@ -10,6 +10,7 @@ from wyvern.components.events.events import LoggedEvent from wyvern.entities.feature_entities import FeatureMap +from wyvern.entities.identifier import Identifier @dataclass @@ -44,6 +45,21 @@ class WyvernRequest: feature_map: FeatureMap + # the key is the name of the model and the value is a map of the identifier to the model score + model_output_map: Dict[ + str, + Dict[ + Identifier, + Union[ + float, + str, + List[float], + Dict[str, Optional[Union[float, str, list[float]]]], + None, + ], + ], + ] + request_id: Optional[str] = None # TODO: params @@ -75,5 +91,41 @@ def parse_fastapi_request( entity_store={}, events=[], feature_map=FeatureMap(feature_map={}), + model_output_map={}, request_id=request_id, ) + + def cache_model_output( + self, + model_name: str, + data: Dict[ + Identifier, + Union[ + float, + str, + List[float], + Dict[str, Optional[Union[float, str, list[float]]]], + None, + ], + ], + ) -> None: + if model_name not in self.model_output_map: + self.model_output_map[model_name] = {} + self.model_output_map[model_name].update(data) + + def get_model_output( + self, + model_name: str, + identifier: Identifier, + ) -> Optional[ + Union[ + float, + str, + List[float], + Dict[str, Optional[Union[float, str, list[float]]]], + None, + ] + ]: + if model_name not in self.model_output_map: + return None + return self.model_output_map[model_name].get(identifier)