Skip to content
This repository has been archived by the owner on Mar 1, 2024. It is now read-only.

Commit

Permalink
chained model evaluation (#58)
Browse files Browse the repository at this point in the history
* multi model evaluation

* fix test

* chained model evaluation

* cache model output

* v bump

* support dict model output

* rename

* add get_model_output support

* fix

* 0.0.18-beta1
  • Loading branch information
wintonzheng committed Sep 14, 2023
1 parent 8dc9cef commit 904a622
Show file tree
Hide file tree
Showing 12 changed files with 292 additions and 126 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "wyvern-ai"
version = "0.0.17"
version = "0.0.18-beta1"
description = ""
authors = ["Wyvern AI <info@wyvern.ai>"]
readme = "README.md"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(self):
entity_store={},
events=[],
feature_map=FeatureMap(feature_map={}),
model_output_map={},
),
)
return await pipeline.execute(request)
Expand Down
9 changes: 4 additions & 5 deletions tests/scenarios/test_product_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@
RealtimeFeatureComponent,
RealtimeFeatureRequest,
)
from wyvern.components.models.model_component import (
ModelComponent,
ModelInput,
ModelOutput,
)
from wyvern.components.models.model_component import ModelComponent
from wyvern.components.pipeline_component import PipelineComponent
from wyvern.config import settings
from wyvern.core.compression import wyvern_encode
Expand All @@ -26,6 +22,7 @@
from wyvern.entities.feature_entities import FeatureData, FeatureMap
from wyvern.entities.identifier import Identifier
from wyvern.entities.identifier_entities import ProductEntity, WyvernEntity
from wyvern.entities.model_entities import ModelInput, ModelOutput
from wyvern.entities.request import BaseWyvernRequest
from wyvern.service import WyvernService
from wyvern.wyvern_request import WyvernRequest
Expand Down Expand Up @@ -387,6 +384,7 @@ async def test_hydrate(mock_redis):
json=json_input,
headers={},
entity_store={},
model_output_map={},
events=[],
feature_map=FeatureMap(feature_map={}),
)
Expand Down Expand Up @@ -450,6 +448,7 @@ async def test_hydrate__duplicate_brand(mock_redis__duplicate_brand):
entity_store={},
events=[],
feature_map=FeatureMap(feature_map={}),
model_output_map={},
)
request_context.set(test_wyvern_request)

Expand Down
9 changes: 4 additions & 5 deletions wyvern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@
from wyvern.components.features.realtime_features_component import (
RealtimeFeatureComponent,
)
from wyvern.components.models.model_component import (
ModelComponent,
ModelInput,
ModelOutput,
)
from wyvern.components.models.model_chain_component import ModelChainComponent
from wyvern.components.models.model_component import ModelComponent
from wyvern.components.pipeline_component import PipelineComponent
from wyvern.components.ranking_pipeline import (
RankingPipeline,
Expand All @@ -23,6 +20,7 @@
WyvernDataModel,
WyvernEntity,
)
from wyvern.entities.model_entities import ModelInput, ModelOutput
from wyvern.feature_store.feature_server import generate_wyvern_store_app
from wyvern.service import WyvernService
from wyvern.wyvern_logging import setup_logging
Expand All @@ -41,6 +39,7 @@
"FeatureMap",
"Identifier",
"IdentifierType",
"ModelChainComponent",
"ModelComponent",
"ModelInput",
"ModelOutput",
Expand Down
17 changes: 16 additions & 1 deletion wyvern/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
from enum import Enum
from functools import cached_property
from typing import Dict, Generic, Optional, Set
from typing import Dict, Generic, List, Optional, Set, Union
from uuid import uuid4

from wyvern import request_context
Expand Down Expand Up @@ -177,3 +177,18 @@ def get_all_features(
if not feature_data:
return {}
return feature_data.features

def get_model_output(
self,
model_name: str,
identifier: Identifier,
) -> Optional[
Union[
float,
str,
List[float],
Dict[str, Optional[Union[float, str, list[float]]]],
]
]:
current_request = request_context.ensure_current_request()
return current_request.get_model_output(model_name, identifier)
54 changes: 54 additions & 0 deletions wyvern/components/models/model_chain_component.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
from functools import cached_property
from typing import Optional, Set

from wyvern.components.models.model_component import ModelComponent
from wyvern.entities.model_entities import MODEL_INPUT, MODEL_OUTPUT, ChainedModelInput
from wyvern.exceptions import MissingModelChainOutputError


class ModelChainComponent(ModelComponent[MODEL_INPUT, MODEL_OUTPUT]):
"""
Model chaining allows you to chain models together so that the output of one model can be the input to another model
For all the models in the chain, all the request and entities in the model input are the same
"""

def __init__(self, *upstreams: ModelComponent, name: Optional[str] = None):
super().__init__(*upstreams, name=name)
self.chain = upstreams

@cached_property
def manifest_feature_names(self) -> Set[str]:
feature_names: Set[str] = set()
for model in self.chain:
feature_names = feature_names.union(model.manifest_feature_names)
return feature_names

async def inference(self, input: MODEL_INPUT, **kwargs) -> MODEL_OUTPUT:
output = None
prev_model: Optional[ModelComponent] = None
for model in self.chain:
curr_input: ChainedModelInput
if prev_model is not None and output is not None:
curr_input = ChainedModelInput(
request=input.request,
entities=input.entities,
upstream_model_name=prev_model.name,
upstream_model_output=output.data,
)
else:
curr_input = ChainedModelInput(
request=input.request,
entities=input.entities,
upstream_model_name=None,
upstream_model_output={},
)
output = await model.execute(curr_input, **kwargs)
prev_model = model

if output is None:
raise MissingModelChainOutputError()

# TODO: do type checking to make sure the output is of the correct type
return output
153 changes: 46 additions & 107 deletions wyvern/components/models/model_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,41 +3,19 @@
import logging
from datetime import datetime
from functools import cached_property
from typing import (
Dict,
Generic,
List,
Optional,
Sequence,
Set,
Type,
TypeVar,
Union,
get_args,
)
from typing import Dict, List, Optional, Sequence, Set, Type, Union, get_args

from pydantic import BaseModel
from pydantic.generics import GenericModel

from wyvern import request_context
from wyvern.components.component import Component
from wyvern.components.events.events import EventType, LoggedEvent
from wyvern.config import settings
from wyvern.entities.identifier import Identifier
from wyvern.entities.identifier_entities import WyvernEntity
from wyvern.entities.model_entities import MODEL_INPUT, MODEL_OUTPUT
from wyvern.entities.request import BaseWyvernRequest
from wyvern.event_logging import event_logger
from wyvern.exceptions import WyvernModelInputError
from wyvern.wyvern_typing import GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY

MODEL_OUTPUT_DATA_TYPE = TypeVar(
"MODEL_OUTPUT_DATA_TYPE",
bound=Union[float, str, List[float]],
)
"""
MODEL_OUTPUT_DATA_TYPE is the type of the output of the model. It can be a float, a string, or a list of floats
(e.g. a list of probabilities, embeddings, etc.)
"""

logger = logging.getLogger(__name__)

Expand All @@ -58,6 +36,7 @@ class ModelEventData(BaseModel):
model_output: str
entity_identifier: Optional[str] = None
entity_identifier_type: Optional[str] = None
target: Optional[str] = None


class ModelEvent(LoggedEvent[ModelEventData]):
Expand All @@ -71,74 +50,6 @@ class ModelEvent(LoggedEvent[ModelEventData]):
event_type: EventType = EventType.MODEL


class ModelOutput(GenericModel, Generic[MODEL_OUTPUT_DATA_TYPE]):
"""
This class defines the output of a model.
Args:
data: A dictionary mapping entity identifiers to model outputs. The model outputs can also be None.
model_name: The name of the model. This is optional.
"""

data: Dict[Identifier, Optional[MODEL_OUTPUT_DATA_TYPE]]
model_name: Optional[str] = None

def get_entity_output(
self,
identifier: Identifier,
) -> Optional[MODEL_OUTPUT_DATA_TYPE]:
"""
Get the model output for a given entity identifier.
Args:
identifier: The identifier of the entity.
Returns:
The model output for the given entity identifier. This can also be None if the model output is None.
"""
return self.data.get(identifier)


class ModelInput(GenericModel, Generic[GENERALIZED_WYVERN_ENTITY, REQUEST_ENTITY]):
"""
This class defines the input to a model.
Args:
request: The request that will be used to generate the model input.
entities: A list of entities that will be used to generate the model input.
"""

request: REQUEST_ENTITY
entities: List[GENERALIZED_WYVERN_ENTITY] = []

@property
def first_entity(self) -> GENERALIZED_WYVERN_ENTITY:
"""
Get the first entity in the list of entities. This is useful when you know that there is only one entity.
Returns:
The first entity in the list of entities.
"""
if not self.entities:
raise WyvernModelInputError(model_input=self)
return self.entities[0]

@property
def first_identifier(self) -> Identifier:
"""
Get the identifier of the first entity in the list of entities. This is useful when you know that there is only
one entity.
Returns:
The identifier of the first entity in the list of entities.
"""
return self.first_entity.identifier


MODEL_INPUT = TypeVar("MODEL_INPUT", bound=ModelInput)
MODEL_OUTPUT = TypeVar("MODEL_OUTPUT", bound=ModelOutput)


class ModelComponent(
Component[
MODEL_INPUT,
Expand All @@ -155,11 +66,14 @@ def __init__(
self,
*upstreams,
name: Optional[str] = None,
cache_output: bool = True,
):
super().__init__(*upstreams, name=name)
self.model_input_type = self.get_type_args_simple(0)
self.model_output_type = self.get_type_args_simple(1)

self.cache_output = cache_output

@classmethod
def get_type_args_simple(cls, index: int) -> Type:
"""
Expand All @@ -185,26 +99,51 @@ async def execute(self, input: MODEL_INPUT, **kwargs) -> MODEL_OUTPUT:
"""
The model_name and model_score will be automatically logged
"""
api_source = request_context.ensure_current_request().url_path
wyvern_request = request_context.ensure_current_request()
api_source = wyvern_request.url_path
request_id = input.request.request_id
model_output = await self.inference(input, **kwargs)

if self.cache_output:
wyvern_request.cache_model_output(self.name, model_output.data)

def events_generator() -> List[ModelEvent]:
timestamp = datetime.utcnow()
return [
ModelEvent(
request_id=request_id,
api_source=api_source,
event_timestamp=timestamp,
event_data=ModelEventData(
model_name=model_output.model_name or self.__class__.__name__,
model_output=str(output),
entity_identifier=identifier.identifier,
entity_identifier_type=identifier.identifier_type,
),
)
for identifier, output in model_output.data.items()
]
all_events: List[ModelEvent] = []
for identifier, output in model_output.data.items():
if isinstance(output, dict):
for key, value in output.items():
all_events.append(
ModelEvent(
request_id=request_id,
api_source=api_source,
event_timestamp=timestamp,
event_data=ModelEventData(
model_name=model_output.model_name
or self.__class__.__name__,
model_output=str(value),
target=key,
entity_identifier=identifier.identifier,
entity_identifier_type=identifier.identifier_type,
),
),
)
else:
all_events.append(
ModelEvent(
request_id=request_id,
api_source=api_source,
event_timestamp=timestamp,
event_data=ModelEventData(
model_name=model_output.model_name
or self.__class__.__name__,
model_output=str(output),
entity_identifier=identifier.identifier,
entity_identifier_type=identifier.identifier_type,
),
),
)
return all_events

event_logger.log_events(events_generator) # type: ignore

Expand Down
7 changes: 2 additions & 5 deletions wyvern/components/models/modelbit_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,12 @@
from functools import cached_property
from typing import Any, Dict, List, Optional, Set, Tuple, TypeAlias, Union

from wyvern.components.models.model_component import (
MODEL_INPUT,
MODEL_OUTPUT,
ModelComponent,
)
from wyvern.components.models.model_component import ModelComponent
from wyvern.config import settings
from wyvern.core.http import aiohttp_client
from wyvern.entities.identifier import Identifier
from wyvern.entities.identifier_entities import WyvernEntity
from wyvern.entities.model_entities import MODEL_INPUT, MODEL_OUTPUT
from wyvern.entities.request import BaseWyvernRequest
from wyvern.exceptions import (
WyvernModelbitTokenMissingError,
Expand Down
Loading

0 comments on commit 904a622

Please sign in to comment.