Skip to content

Convert Modeldata to Dict #12

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def load(modelData):
# during runtime.

# Any variables returned here, will be passed as the secondary argument to your 'algorithm' function
modelData.user_data['payload'] = "Loading has been completed."
modelData['payload'] = "Loading has been completed."
return modelData


Expand Down Expand Up @@ -176,9 +176,9 @@ def infer_image(image_url, n, globals):

def load(modelData):

modelData.user_data["SMID_ALGO"] = "algo://util/SmartImageDownloader/0.2.x"
modelData.user_data["model"] = load_model(modelData.get_model("squeezenet"))
modelData.user_data["labels"] = load_labels(modelData.get_model("labels"))
modelData["SMID_ALGO"] = "algo://util/SmartImageDownloader/0.2.x"
modelData["model"] = load_model(modelData.get_model("squeezenet"))
modelData["labels"] = load_labels(modelData.get_model("labels"))
return modelData


Expand All @@ -190,10 +190,10 @@ def apply(input, modelData):
n = 3
if "data" in input:
if isinstance(input["data"], str):
output = infer_image(input["data"], n, modelData.user_data)
output = infer_image(input["data"], n, modelData)
elif isinstance(input["data"], list):
for row in input["data"]:
row["predictions"] = infer_image(row["image_url"], n, modelData.user_data)
row["predictions"] = infer_image(row["image_url"], n, modelData)
output = input["data"]
else:
raise Exception("\"data\" must be a image url or a list of image urls (with labels)")
Expand Down Expand Up @@ -257,4 +257,4 @@ Verify that it works on pytest, then:
```commandline
python -m twine upload -r pypi dist/*
```
and you're done :)
and you're done :)
19 changes: 17 additions & 2 deletions adk/modeldata.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,23 @@ def __init__(self, client, model_manifest_path):
self.manifest_data = get_manifest(self.manifest_freeze_path)
self.client = client
self.models = {}
self.user_data = {}
self.system_data = {}
self.usr_key = "__user__"

def __getitem__(self, key):
return getattr(self, self.usr_key + key)

def __setitem__(self, key, value):
setattr(self, self.usr_key + key, value)

def data(self):
__dict = self.__dict__
output = {}
for key in __dict.keys():
if self.usr_key in key:
without_usr_key = key.split(self.usr_key)[1]
output[without_usr_key] = __dict[key]
return output


def available(self):
if self.manifest_data:
Expand Down
2 changes: 1 addition & 1 deletion examples/loaded_state_hello_world/src/Algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def load(modelData):
# during runtime.

# Any variables returned here, will be passed as the secondary argument to your 'algorithm' function
modelData.user_data['payload'] = "Loading has been completed."
modelData['payload'] = "Loading has been completed."
return modelData


Expand Down
10 changes: 5 additions & 5 deletions examples/pytorch_image_classification/src/Algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ def infer_image(image_url, n, globals):

def load(modelData):

modelData.user_data["SMID_ALGO"] = "algo://util/SmartImageDownloader/0.2.x"
modelData.user_data["model"] = load_model(modelData.get_model("squeezenet"))
modelData.user_data["labels"] = load_labels(modelData.get_model("labels"))
modelData["SMID_ALGO"] = "algo://util/SmartImageDownloader/0.2.x"
modelData["model"] = load_model(modelData.get_model("squeezenet"))
modelData["labels"] = load_labels(modelData.get_model("labels"))
return modelData


Expand All @@ -67,10 +67,10 @@ def apply(input, modelData):
n = 3
if "data" in input:
if isinstance(input["data"], str):
output = infer_image(input["data"], n, modelData.user_data)
output = infer_image(input["data"], n, modelData)
elif isinstance(input["data"], list):
for row in input["data"]:
row["predictions"] = infer_image(row["image_url"], n, modelData.user_data)
row["predictions"] = infer_image(row["image_url"], n, modelData)
output = input["data"]
else:
raise Exception("\"data\" must be a image url or a list of image urls (with labels)")
Expand Down
14 changes: 7 additions & 7 deletions tests/adk_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def apply_binary(input):

def apply_input_or_context(input, model_data=None):
if model_data:
return model_data.user_data
return model_data.data()
else:
return "hello " + input

Expand All @@ -30,7 +30,7 @@ def apply_successful_manifest_parsing(input, model_data):

# -- Loading functions --- #
def loading_text(modelData):
modelData.user_data['message'] = 'This message was loaded prior to runtime'
modelData['message'] = 'This message was loaded prior to runtime'
return modelData


Expand All @@ -39,14 +39,14 @@ def loading_exception(modelData):


def loading_file_from_algorithmia(modelData):
modelData.user_data['data_url'] = 'data://demo/collection/somefile.json'
modelData.user_data['data'] = modelData.client.file(modelData.user_data['data_url']).getJson()
modelData['data_url'] = 'data://demo/collection/somefile.json'
modelData['data'] = modelData.client.file(modelData['data_url']).getJson()
return modelData


def loading_with_manifest(modelData):
modelData.user_data["squeezenet"] = modelData.get_model("squeezenet")
modelData.user_data['labels'] = modelData.get_model("labels")
modelData["squeezenet"] = modelData.get_model("squeezenet")
modelData['labels'] = modelData.get_model("labels")
# optional model
modelData.user_data['mobilenet'] = modelData.get_model("mobilenet")
modelData['mobilenet'] = modelData.get_model("mobilenet")
return modelData