Skip to content
This repository has been archived by the owner on Mar 1, 2024. It is now read-only.

Commit

Permalink
Merge pull request #6 from Wyvern-AI/shu/fix_modelbit_request_and_Wyv…
Browse files Browse the repository at this point in the history
…ernFeature

Fix modelbit component and support boolean WyvernFeature
  • Loading branch information
wintonzheng committed Aug 9, 2023
2 parents 5f5e63a + fa28cdc commit ccd2967
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 11 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "wyvern-ai"
version = "0.0.3"
version = "0.0.4"
description = ""
authors = ["Wyvern AI <info@wyvern.ai>"]
readme = "README.md"
Expand Down
16 changes: 7 additions & 9 deletions wyvern/components/models/modelbit_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ def __init__(
self._auth_token = auth_token or self.AUTH_TOKEN
self._modelbit_url = url or self.URL
self.headers = {
"Authorization": f"Bearer {self._auth_token}",
"Authorization": self._auth_token,
"Content-Type": "application/json",
}

# TODO shu: test out the model_input_type
self.model_input_type = self.get_type_args_simple(0)
self.model_ouput_type = self.get_type_args_simple(1)
self.model_output_type = self.get_type_args_simple(1)

if not self._auth_token:
raise WyvernModelbitTokenMissingError()
Expand Down Expand Up @@ -87,12 +87,10 @@ async def build_requests(
all_requests = [
[
idx + 1,
{
"features": [
self.get_feature(identifier, feature_name)
for feature_name in self.modelbit_features
],
},
[
self.get_feature(identifier, feature_name)
for feature_name in self.modelbit_features
],
]
for idx, identifier in enumerate(target_identifiers)
]
Expand Down Expand Up @@ -139,7 +137,7 @@ async def inference(self, input: MODEL_INPUT, **kwargs) -> MODEL_OUTPUT:
target_identifiers[batch_idx * settings.MODELBIT_BATCH_SIZE + idx]
] = individual_output[1]

return self.model_ouput_type(
return self.model_output_type(
data=output_data,
model_name=self.name,
)
11 changes: 11 additions & 0 deletions wyvern/entities/request.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# -*- coding: utf-8 -*-
from typing import Optional

from pydantic import PrivateAttr

from wyvern.entities.identifier import Identifier
from wyvern.entities.identifier_entities import WyvernDataModel

Expand All @@ -9,8 +11,17 @@ class BaseWyvernRequest(WyvernDataModel):
request_id: str
include_events: Optional[bool] = False

_identifier: Identifier = PrivateAttr()

def __init__(self, **kwargs):
super().__init__(**kwargs)
self._identifier = self.generate_identifier()

@property
def identifier(self) -> Identifier:
return self._identifier

def generate_identifier(self) -> Identifier:
return Identifier(
identifier=self.request_id,
identifier_type="request",
Expand Down
2 changes: 1 addition & 1 deletion wyvern/wyvern_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@
REQUEST_SCHEMA = TypeVar("REQUEST_SCHEMA", bound=BaseModel)
RESPONSE_SCHEMA = TypeVar("RESPONSE_SCHEMA", bound=BaseModel)

WyvernFeature = Union[float, str, List[float], None]
WyvernFeature = Union[bool, float, str, List[float], None]

0 comments on commit ccd2967

Please sign in to comment.