Skip to content

Commit

Permalink
Merge 3f5d4a5 into fd8f1b3
Browse files Browse the repository at this point in the history
  • Loading branch information
mreso authored Mar 13, 2023
2 parents fd8f1b3 + 3f5d4a5 commit 710ead3
Show file tree
Hide file tree
Showing 13 changed files with 364 additions and 237 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ dist/
*.egg-info/
.idea
*htmlcov*
.cache
.coverage
.github/actions/
.github/.DS_Store
Expand Down
1 change: 1 addition & 0 deletions examples/image_classifier/resnet_18/index_to_name.json

Large diffs are not rendered by default.

34 changes: 34 additions & 0 deletions ts/torch_handler/unit_tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import shutil
import sys
from pathlib import Path

import pytest

from .models.base_model import save_pt_file
from .test_utils.mock_context import MockContext


@pytest.fixture()
def base_model_dir(tmp_path_factory):
model_dir = tmp_path_factory.mktemp("base_model_dir")

shutil.copyfile(
Path(__file__).parents[0] / "models" / "base_model.py", model_dir / "model.py"
)

save_pt_file(model_dir.joinpath("model.pt").as_posix())

sys.path.append(model_dir.as_posix())
yield model_dir
sys.path.pop()


@pytest.fixture()
def base_model_context(base_model_dir):

context = MockContext(
model_name="mnist",
model_dir=base_model_dir.as_posix(),
model_file="model.py",
)
yield context
10 changes: 8 additions & 2 deletions ts/torch_handler/unit_tests/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,16 @@

import torch


class ArgmaxModel(torch.nn.Module):
def forward(self, *input):
return torch.argmax(input[0], 1)

if __name__ == '__main__':

def save_pt_file(filepath="base_model.pt"):
model = ArgmaxModel()
torch.save(model.state_dict(), 'base_model.pt')
torch.save(model.state_dict(), filepath)


if __name__ == "__main__":
save_pt_file()
76 changes: 0 additions & 76 deletions ts/torch_handler/unit_tests/run_unit_tests.sh

This file was deleted.

28 changes: 11 additions & 17 deletions ts/torch_handler/unit_tests/test_base_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,28 @@
Ensures it can load and execute an example model
"""

import sys
import pytest

from ts.torch_handler.base_handler import BaseHandler
from .test_utils.mock_context import MockContext

sys.path.append('ts/torch_handler/unit_tests/models/tmp')

@pytest.fixture()
def model_context():
return MockContext()

def test_initialize(model_context):
def handler(base_model_context):
handler = BaseHandler()
handler.initialize(model_context)
handler.initialize(base_model_context)

assert(True)
return handler

def test_single_handle(model_context):
handler = test_initialize(model_context)

def test_single_handle(handler, base_model_context):
list_data = [[1.0, 2.0]]
processed = handler.handle(list_data, model_context)
processed = handler.handle(list_data, base_model_context)

assert processed == [1]

assert(processed == [1])

def test_batch_handle(model_context):
handler = test_initialize(model_context)
def test_batch_handle(handler, base_model_context):
list_data = [[1.0, 2.0], [4.0, 3.0]]
processed = handler.handle(list_data, model_context)
processed = handler.handle(list_data, base_model_context)

assert(processed == [1, 0])
assert processed == [1, 0]
71 changes: 30 additions & 41 deletions ts/torch_handler/unit_tests/test_envelopes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,79 +5,68 @@
Ensures it can load and execute an example model
"""

import sys
import pytest

from ts.torch_handler.base_handler import BaseHandler
from ts.torch_handler.request_envelope.body import BodyEnvelope
from ts.torch_handler.request_envelope.json import JSONEnvelope
from .test_utils.mock_context import MockContext

sys.path.append('ts/torch_handler/unit_tests/models/tmp')

@pytest.fixture()
def model_context():
return MockContext()

@pytest.fixture()
def handle_fn():
ctx = MockContext()
def handle_fn(base_model_context):
handler = BaseHandler()
handler.initialize(ctx)
handler.initialize(base_model_context)

return handler.handle

def test_json(handle_fn, model_context):
test_data = [{'body':{
'instances': [[1.0, 2.0]]
}}]

def test_json(handle_fn, base_model_context):
test_data = [{"body": {"instances": [[1.0, 2.0]]}}]
expected_result = ['{"predictions": [1]}']

envelope = JSONEnvelope(handle_fn)
results = envelope.handle(test_data, model_context)
assert(results == expected_result)
results = envelope.handle(test_data, base_model_context)
assert results == expected_result


def test_json_batch(handle_fn, model_context):
test_data = [{'body':{
'instances': [[1.0, 2.0], [4.0, 3.0]]
}}]
def test_json_batch(handle_fn, base_model_context):
test_data = [{"body": {"instances": [[1.0, 2.0], [4.0, 3.0]]}}]
expected_result = ['{"predictions": [1, 0]}']

envelope = JSONEnvelope(handle_fn)
results = envelope.handle(test_data, model_context)
assert(results == expected_result)
results = envelope.handle(test_data, base_model_context)
assert results == expected_result

def test_json_double_batch(handle_fn, model_context):

def test_json_double_batch(handle_fn, base_model_context):
"""
More complex test case. Makes sure the model can
mux several batches and return the demuxed results
"""
test_data = [
{'body':{'instances': [[1.0, 2.0]]}},
{'body':{'instances': [[4.0, 3.0], [5.0, 6.0]]}}

{"body": {"instances": [[1.0, 2.0]]}},
{"body": {"instances": [[4.0, 3.0], [5.0, 6.0]]}},
]
expected_result = ['{"predictions": [1]}', '{"predictions": [0, 1]}']

envelope = JSONEnvelope(handle_fn)
results = envelope.handle(test_data, model_context)
results = envelope.handle(test_data, base_model_context)
print(results)
assert(results == expected_result)
assert results == expected_result

def test_body(handle_fn, model_context):
test_data = [{
'body':[1.0, 2.0]
}]

def test_body(handle_fn, base_model_context):
test_data = [{"body": [1.0, 2.0]}]
expected_result = [1]

envelope = BodyEnvelope(handle_fn)
results = envelope.handle(test_data, model_context)
assert(results == expected_result)
results = envelope.handle(test_data, base_model_context)
assert results == expected_result


def test_binary(model_context):
test_data = [{
'instances': [{'b64': 'YQ=='}]
}]
def test_binary(base_model_context):
test_data = [{"instances": [{"b64": "YQ=="}]}]

envelope = JSONEnvelope(lambda x, y: [row.decode('utf-8') for row in x])
results = envelope.handle(test_data, model_context)
assert(results == ['{"predictions": ["a"]}'])
envelope = JSONEnvelope(lambda x, y: [row.decode("utf-8") for row in x])
results = envelope.handle(test_data, base_model_context)
assert results == ['{"predictions": ["a"]}']
Loading

0 comments on commit 710ead3

Please sign in to comment.