Skip to content
This repository has been archived by the owner on Mar 1, 2024. It is now read-only.

Commit

Permalink
transform modelbit response
Browse files Browse the repository at this point in the history
  • Loading branch information
wintonzheng committed Sep 13, 2023
1 parent ee85723 commit e201dac
Showing 1 changed file with 26 additions and 2 deletions.
28 changes: 26 additions & 2 deletions wyvern/components/models/modelbit_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import asyncio
import logging
from functools import cached_property
from typing import Any, Dict, List, Optional, Set, Tuple, TypeAlias, Union
from typing import Any, Dict, List, Optional, Set, Tuple, TypeAlias, Union, cast

from wyvern.components.models.model_component import ModelComponent
from wyvern.config import settings
Expand Down Expand Up @@ -146,9 +146,33 @@ async def inference(self, input: MODEL_INPUT, **kwargs) -> MODEL_OUTPUT:
# individual_output[1] is the actual output
output_data[
target_identifiers[batch_idx * settings.MODELBIT_BATCH_SIZE + idx]
] = individual_output[1]
] = self.transform_response(individual_output[1])

return self.model_output_type(
data=output_data,
model_name=self.name,
)

def transform_response(
self,
modelbit_resp: Any,
) -> Optional[Union[float, str, List[float]]]:
"""
This method parses the response from Modelbit.
"""
if isinstance(modelbit_resp, list):
return cast(List[float], modelbit_resp)
if isinstance(modelbit_resp, bool):
return float(modelbit_resp)
if isinstance(modelbit_resp, dict):
return self.transform_dict_response(modelbit_resp)
return modelbit_resp

def transform_dict_response(
self,
modelbit_resp: Dict[str, Any],
) -> Optional[Union[float, str, List[float]]]:
"""
This method parses the response from Modelbit and return the data format that's supported by wyvern.
"""
raise NotImplementedError

0 comments on commit e201dac

Please sign in to comment.