Cannot load model #1813
Description
🐛 Describe the bug
I am trying to deploy a locally pretrained model via SageMaker to create an endpoint and use it.
I deployed the model:
from sagemaker.pytorch import PyTorchModel

pytorch_model = PyTorchModel(model_data='model.tar.gz',
                             role=role,
                             entry_point='inference.py',
                             framework_version="1.9.0",
                             py_version="py38")
predictor = pytorch_model.deploy(instance_type='ml.g4dn.xlarge', initial_instance_count=1)
and sent data for prediction:
from PIL import Image
data = Image.open('./samples/inputs/1.jpg')
result = predictor.predict(data)
img = Image.open(result)
img.show()
As a result, I got this error:
ModelError Traceback (most recent call last)
/tmp/ipykernel_4268/3704626012.py in <cell line: 4>()
2
3 data = Image.open('./samples/inputs/1.jpg')
----> 4 result = predictor.predict(data)
5
6 img = Image.open(result)
~/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/sagemaker/predictor.py in predict(self, data, initial_args, target_model, target_variant, inference_id)
159 data, initial_args, target_model, target_variant, inference_id
160 )
--> 161 response = self.sagemaker_session.sagemaker_runtime_client.invoke_endpoint(**request_args)
162 return self._handle_response(response)
163
~/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/botocore/client.py in _api_call(self, *args, **kwargs)
506 )
507 # The "self" in this scope is referring to the BaseClient.
--> 508 return self._make_api_call(operation_name, kwargs)
509
510         _api_call.__name__ = str(py_operation_name)
~/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/botocore/client.py in _make_api_call(self, operation_name, api_params)
913 error_code = parsed_response.get("Error", {}).get("Code")
914 error_class = self.exceptions.from_code(error_code)
--> 915 raise error_class(parsed_response, operation_name)
916 else:
917 return parsed_response
ModelError: An error occurred (ModelError) when calling the InvokeEndpoint operation: Received server error (0) from primary with message "Your invocation timed out while waiting for a response from container primary. Review the latency metrics for each container in Amazon CloudWatch, resolve the issue, and try again.".
I skimmed through the logs in CloudWatch and I am still struggling with this. I need some help.
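Since the worker dies while the model is being loaded, the quickest way to surface the underlying Python exception is to call model_fn directly against the unpacked archive before deploying. A minimal local sketch, assuming model.tar.gz and inference.py (shown under Repro instructions below) sit in the working directory:

import tarfile
import tempfile

from inference import model_fn  # assumes inference.py is importable locally

with tempfile.TemporaryDirectory() as model_dir:
    with tarfile.open("model.tar.gz") as tar:
        tar.extractall(model_dir)      # lay the artifacts out the way SageMaker would
    model = model_fn(model_dir)        # should raise the same error the worker hits
    print(type(model))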
Error logs
timestamp | message | logStreamName |
---|---|---|
1661327528194 | 2022-08-24 07:52:07,987 [INFO ] main org.pytorch.serve.servingsdk.impl.PluginsManager - Initializing plugins manager... | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | 2022-08-24 07:52:08,112 [INFO ] main org.pytorch.serve.ModelServer - | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | Torchserve version: 0.4.2 | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | TS Home: /opt/conda/lib/python3.8/site-packages | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | Current directory: / | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | Temp directory: /home/model-server/tmp | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | Number of GPUs: 1 | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | Number of CPUs: 1 | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | Max heap size: 3234 M | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | Python executable: /opt/conda/bin/python3.8 | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | Config file: /etc/sagemaker-ts.properties | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | Inference address: http://0.0.0.0:8080 | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | Management address: http://0.0.0.0:8080 | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | Metrics address: http://127.0.0.1:8082 | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | Model Store: /.sagemaker/ts/models | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | Initial Models: model.mar | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | Log dir: /logs | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | Metrics dir: /logs | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | Netty threads: 0 | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | Netty client threads: 0 | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | Default workers per model: 1 | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | Blacklist Regex: N/A | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | Maximum Response Size: 6553500 | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | Maximum Request Size: 6553500 | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | Prefer direct buffer: false | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | Allowed Urls: [file://.* \| http(s)?://.*] | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | Custom python dependency for model allowed: false | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | Metrics report format: prometheus | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | Enable metrics API: true | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | Workflow Store: /.sagemaker/ts/models | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | Model config: N/A | AllTraffic/i-0b6f78248b097b6c7 |
1661327528194 | 2022-08-24 07:52:08,120 [INFO ] main org.pytorch.serve.servingsdk.impl.PluginsManager - Loading snapshot serializer plugin... | AllTraffic/i-0b6f78248b097b6c7 |
1661327528444 | 2022-08-24 07:52:08,149 [INFO ] main org.pytorch.serve.ModelServer - Loading initial models: model.mar | AllTraffic/i-0b6f78248b097b6c7 |
1661327528444 | 2022-08-24 07:52:08,353 [INFO ] main org.pytorch.serve.wlm.ModelManager - Model model loaded. | AllTraffic/i-0b6f78248b097b6c7 |
1661327528694 | 2022-08-24 07:52:08,370 [INFO ] main org.pytorch.serve.ModelServer - Initialize Inference server with: EpollServerSocketChannel. | AllTraffic/i-0b6f78248b097b6c7 |
1661327528694 | 2022-08-24 07:52:08,472 [INFO ] main org.pytorch.serve.ModelServer - Inference API bind to: http://0.0.0.0:8080 | AllTraffic/i-0b6f78248b097b6c7 |
1661327528694 | 2022-08-24 07:52:08,473 [INFO ] main org.pytorch.serve.ModelServer - Initialize Metrics server with: EpollServerSocketChannel. | AllTraffic/i-0b6f78248b097b6c7 |
1661327528944 | 2022-08-24 07:52:08,474 [INFO ] main org.pytorch.serve.ModelServer - Metrics API bind to: http://127.0.0.1:8082 | AllTraffic/i-0b6f78248b097b6c7 |
1661327528944 | Model server started. | AllTraffic/i-0b6f78248b097b6c7 |
1661327528944 | 2022-08-24 07:52:08,738 [WARN ] pool-2-thread-1 org.pytorch.serve.metrics.MetricCollector - worker pid is not available yet. | AllTraffic/i-0b6f78248b097b6c7 |
1661327528944 | 2022-08-24 07:52:08,786 [INFO ] pool-2-thread-1 TS_METRICS - CPUUtilization.Percent:0.0 | #Level:Host |
1661327528944 | 2022-08-24 07:52:08,787 [INFO ] pool-2-thread-1 TS_METRICS - DiskAvailable.Gigabytes:24.598094940185547 | #Level:Host |
1661327528944 | 2022-08-24 07:52:08,788 [INFO ] pool-2-thread-1 TS_METRICS - DiskUsage.Gigabytes:27.390167236328125 | #Level:Host |
1661327528944 | 2022-08-24 07:52:08,788 [INFO ] pool-2-thread-1 TS_METRICS - DiskUtilization.Percent:52.7 | #Level:Host |
1661327528944 | 2022-08-24 07:52:08,788 [INFO ] pool-2-thread-1 TS_METRICS - MemoryAvailable.Megabytes:14186.97265625 | #Level:Host |
1661327528944 | 2022-08-24 07:52:08,789 [INFO ] pool-2-thread-1 TS_METRICS - MemoryUsed.Megabytes:1227.640625 | #Level:Host |
1661327529195 | 2022-08-24 07:52:08,789 [INFO ] pool-2-thread-1 TS_METRICS - MemoryUtilization.Percent:9.9 | #Level:Host |
1661327529195 | 2022-08-24 07:52:09,004 [INFO ] W-9000-model_1-stdout MODEL_LOG - Listening on port: /home/model-server/tmp/.ts.sock.9000 | AllTraffic/i-0b6f78248b097b6c7 |
1661327529195 | 2022-08-24 07:52:09,004 [INFO ] W-9000-model_1-stdout MODEL_LOG - [PID]32 | AllTraffic/i-0b6f78248b097b6c7 |
1661327529195 | 2022-08-24 07:52:09,004 [INFO ] W-9000-model_1-stdout MODEL_LOG - Torch worker started. | AllTraffic/i-0b6f78248b097b6c7 |
1661327529195 | 2022-08-24 07:52:09,004 [INFO ] W-9000-model_1-stdout MODEL_LOG - Python runtime: 3.8.10 | AllTraffic/i-0b6f78248b097b6c7 |
1661327529195 | 2022-08-24 07:52:09,011 [INFO ] W-9000-model_1 org.pytorch.serve.wlm.WorkerThread - Connecting to: /home/model-server/tmp/.ts.sock.9000 | AllTraffic/i-0b6f78248b097b6c7 |
1661327529195 | 2022-08-24 07:52:09,021 [INFO ] W-9000-model_1-stdout MODEL_LOG - Connection accepted: /home/model-server/tmp/.ts.sock.9000. | AllTraffic/i-0b6f78248b097b6c7 |
1661327529695 | 2022-08-24 07:52:09,064 [INFO ] W-9000-model_1-stdout MODEL_LOG - model_name: model, batchSize: 1 | AllTraffic/i-0b6f78248b097b6c7 |
1661327529695 | 2022-08-24 07:52:09,605 [INFO ] W-9000-model_1-stdout MODEL_LOG - Backend worker process died. | AllTraffic/i-0b6f78248b097b6c7 |
1661327529695 | 2022-08-24 07:52:09,605 [INFO ] W-9000-model_1-stdout MODEL_LOG - Traceback (most recent call last): | AllTraffic/i-0b6f78248b097b6c7 |
1661327529695 | 2022-08-24 07:52:09,606 [INFO ] W-9000-model_1-stdout MODEL_LOG - File "/opt/conda/lib/python3.8/site-packages/ts/model_service_worker.py", line 183, in | AllTraffic/i-0b6f78248b097b6c7 |
1661327529695 | 2022-08-24 07:52:09,606 [INFO ] W-9000-model_1-stdout MODEL_LOG - worker.run_server() | AllTraffic/i-0b6f78248b097b6c7 |
1661327529695 | 2022-08-24 07:52:09,606 [INFO ] W-9000-model_1-stdout MODEL_LOG - File "/opt/conda/lib/python3.8/site-packages/ts/model_service_worker.py", line 155, in run_server | AllTraffic/i-0b6f78248b097b6c7 |
1661327529695 | 2022-08-24 07:52:09,607 [INFO ] epollEventLoopGroup-5-1 org.pytorch.serve.wlm.WorkerThread - 9000 Worker disconnected. WORKER_STARTED | AllTraffic/i-0b6f78248b097b6c7 |
1661327529695 | 2022-08-24 07:52:09,607 [INFO ] W-9000-model_1-stdout MODEL_LOG - self.handle_connection(cl_socket) | AllTraffic/i-0b6f78248b097b6c7 |
1661327529695 | 2022-08-24 07:52:09,608 [INFO ] W-9000-model_1-stdout MODEL_LOG - File "/opt/conda/lib/python3.8/site-packages/ts/model_service_worker.py", line 117, in handle_connection | AllTraffic/i-0b6f78248b097b6c7 |
1661327529695 | 2022-08-24 07:52:09,608 [WARN ] W-9000-model_1 org.pytorch.serve.wlm.BatchAggregator - Load model failed: model, error: Worker died. | AllTraffic/i-0b6f78248b097b6c7 |
1661327529695 | 2022-08-24 07:52:09,608 [INFO ] W-9000-model_1-stdout MODEL_LOG - service, result, code = self.load_model(msg) | AllTraffic/i-0b6f78248b097b6c7 |
1661327529695 | 2022-08-24 07:52:09,609 [WARN ] W-9000-model_1 org.pytorch.serve.wlm.WorkerLifeCycle - terminateIOStreams() threadName=W-9000-model_1-stderr | AllTraffic/i-0b6f78248b097b6c7 |
1661327529695 | 2022-08-24 07:52:09,609 [WARN ] W-9000-model_1 org.pytorch.serve.wlm.WorkerLifeCycle - terminateIOStreams() threadName=W-9000-model_1-stdout | AllTraffic/i-0b6f78248b097b6c7 |
1661327529695 | 2022-08-24 07:52:09,610 [INFO ] W-9000-model_1-stdout MODEL_LOG - File "/opt/conda/lib/python3.8/site-packages/ts/model_service_worker.py", line 90, in load_model | AllTraffic/i-0b6f78248b097b6c7 |
1661327529695 | 2022-08-24 07:52:09,610 [INFO ] W-9000-model_1-stdout org.pytorch.serve.wlm.WorkerLifeCycle - Stopped Scanner - W-9000-model_1-stdout | AllTraffic/i-0b6f78248b097b6c7 |
1661327529695 | 2022-08-24 07:52:09,610 [INFO ] W-9000-model_1 org.pytorch.serve.wlm.WorkerThread - Retry worker: 9000 in 1 seconds. | AllTraffic/i-0b6f78248b097b6c7 |
1661327531196 | 2022-08-24 07:52:09,628 [INFO ] W-9000-model_1-stderr org.pytorch.serve.wlm.WorkerLifeCycle - Stopped Scanner - W-9000-model_1-stderr | AllTraffic/i-0b6f78248b097b6c7 |
1661327531196 | 2022-08-24 07:52:11,192 [INFO ] W-9000-model_1-stdout MODEL_LOG - Listening on port: /home/model-server/tmp/.ts.sock.9000 | AllTraffic/i-0b6f78248b097b6c7 |
1661327531196 | 2022-08-24 07:52:11,193 [INFO ] W-9000-model_1-stdout MODEL_LOG - [PID]52 | AllTraffic/i-0b6f78248b097b6c7 |
1661327531196 | 2022-08-24 07:52:11,193 [INFO ] W-9000-model_1-stdout MODEL_LOG - Torch worker started. | AllTraffic/i-0b6f78248b097b6c7 |
1661327531196 | 2022-08-24 07:52:11,193 [INFO ] W-9000-model_1 org.pytorch.serve.wlm.WorkerThread - Connecting to: /home/model-server/tmp/.ts.sock.9000 | AllTraffic/i-0b6f78248b097b6c7 |
1661327531196 | 2022-08-24 07:52:11,194 [INFO ] W-9000-model_1-stdout MODEL_LOG - Python runtime: 3.8.10 | AllTraffic/i-0b6f78248b097b6c7 |
1661327531446 | 2022-08-24 07:52:11,195 [INFO ] W-9000-model_1-stdout MODEL_LOG - Connection accepted: /home/model-server/tmp/.ts.sock.9000. | AllTraffic/i-0b6f78248b097b6c7 |
1661327531446 | 2022-08-24 07:52:11,212 [INFO ] W-9000-model_1-stdout MODEL_LOG - model_name: model, batchSize: 1 | AllTraffic/i-0b6f78248b097b6c7 |
1661327531446 | 2022-08-24 07:52:11,368 [INFO ] W-9000-model_1-stdout MODEL_LOG - Backend worker process died. | AllTraffic/i-0b6f78248b097b6c7 |
1661327531446 | 2022-08-24 07:52:11,368 [INFO ] epollEventLoopGroup-5-2 org.pytorch.serve.wlm.WorkerThread - 9000 Worker disconnected. WORKER_STARTED | AllTraffic/i-0b6f78248b097b6c7 |
1661327531446 | 2022-08-24 07:52:11,368 [INFO ] W-9000-model_1-stdout MODEL_LOG - Traceback (most recent call last): | AllTraffic/i-0b6f78248b097b6c7 |
1661327531446 | 2022-08-24 07:52:11,369 [WARN ] W-9000-model_1 org.pytorch.serve.wlm.BatchAggregator - Load model failed: model, error: Worker died. | AllTraffic/i-0b6f78248b097b6c7 |
1661327531446 | 2022-08-24 07:52:11,371 [WARN ] W-9000-model_1 org.pytorch.serve.wlm.WorkerLifeCycle - terminateIOStreams() threadName=W-9000-model_1-stderr | AllTraffic/i-0b6f78248b097b6c7 |
1661327531446 | 2022-08-24 07:52:11,371 [WARN ] W-9000-model_1 org.pytorch.serve.wlm.WorkerLifeCycle - terminateIOStreams() threadName=W-9000-model_1-stdout | AllTraffic/i-0b6f78248b097b6c7 |
1661327531446 | 2022-08-24 07:52:11,371 [INFO ] W-9000-model_1 org.pytorch.serve.wlm.WorkerThread - Retry worker: 9000 in 1 seconds. | AllTraffic/i-0b6f78248b097b6c7 |
1661327531446 | 2022-08-24 07:52:11,371 [INFO ] W-9000-model_1-stdout MODEL_LOG - File "/opt/conda/lib/python3.8/site-packages/ts/model_service_worker.py", line 183, in | AllTraffic/i-0b6f78248b097b6c7 |
1661327531696 | 2022-08-24 07:52:11,372 [INFO ] W-9000-model_1-stdout org.pytorch.serve.wlm.WorkerLifeCycle - Stopped Scanner - W-9000-model_1-stdout | AllTraffic/i-0b6f78248b097b6c7 |
1661327531696 | 2022-08-24 07:52:11,665 [INFO ] W-9000-model_1 ACCESS_LOG - /169.254.178.2:35288 "GET /ping HTTP/1.1" 200 15 | AllTraffic/i-0b6f78248b097b6c7 |
1661327531696 | 2022-08-24 07:52:11,666 [INFO ] W-9000-model_1 TS_METRICS - Requests2XX.Count:1 | #Level:Host |
1661327532947 | 2022-08-24 07:52:11,673 [INFO ] W-9000-model_1-stderr org.pytorch.serve.wlm.WorkerLifeCycle - Stopped Scanner - W-9000-model_1-stderr | AllTraffic/i-0b6f78248b097b6c7 |
1661327532947 | 2022-08-24 07:52:12,892 [INFO ] W-9000-model_1-stdout MODEL_LOG - Listening on port: /home/model-server/tmp/.ts.sock.9000 | AllTraffic/i-0b6f78248b097b6c7 |
1661327532947 | 2022-08-24 07:52:12,892 [INFO ] W-9000-model_1-stdout MODEL_LOG - [PID]65 | AllTraffic/i-0b6f78248b097b6c7 |
1661327532947 | 2022-08-24 07:52:12,892 [INFO ] W-9000-model_1-stdout MODEL_LOG - Torch worker started. | AllTraffic/i-0b6f78248b097b6c7 |
1661327532947 | 2022-08-24 07:52:12,892 [INFO ] W-9000-model_1 org.pytorch.serve.wlm.WorkerThread - Connecting to: /home/model-server/tmp/.ts.sock.9000 | AllTraffic/i-0b6f78248b097b6c7 |
1661327532947 | 2022-08-24 07:52:12,892 [INFO ] W-9000-model_1-stdout MODEL_LOG - Python runtime: 3.8.10 | AllTraffic/i-0b6f78248b097b6c7 |
1661327532947 | 2022-08-24 07:52:12,893 [INFO ] W-9000-model_1-stdout MODEL_LOG - Connection accepted: /home/model-server/tmp/.ts.sock.9000. | AllTraffic/i-0b6f78248b097b6c7 |
1661327533197 | 2022-08-24 07:52:12,894 [INFO ] W-9000-model_1-stdout MODEL_LOG - model_name: model, batchSize: 1 | AllTraffic/i-0b6f78248b097b6c7 |
1661327533197 | 2022-08-24 07:52:13,026 [INFO ] W-9000-model_1-stdout MODEL_LOG - Backend worker process died. | AllTraffic/i-0b6f78248b097b6c7 |
1661327533197 | 2022-08-24 07:52:13,026 [INFO ] W-9000-model_1-stdout MODEL_LOG - Traceback (most recent call last): | AllTraffic/i-0b6f78248b097b6c7 |
1661327533197 | 2022-08-24 07:52:13,027 [INFO ] W-9000-model_1-stdout MODEL_LOG - File "/opt/conda/lib/python3.8/site-packages/ts/model_service_worker.py", line 183, in | AllTraffic/i-0b6f78248b097b6c7 |
1661327533197 | 2022-08-24 07:52:13,027 [INFO ] W-9000-model_1-stdout MODEL_LOG - worker.run_server() | AllTraffic/i-0b6f78248b097b6c7 |
Installation instructions
I am using SageMaker.
Model Packaging
from sagemaker.pytorch import PyTorchModel

pytorch_model = PyTorchModel(model_data='model.tar.gz',
                             role=role,
                             entry_point='inference.py',
                             framework_version="1.9.0",
                             py_version="py38")
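For reference, a minimal sketch of how the archive could be assembled so that model_fn finds model.pt at the root of the extracted model_dir; this assumes the trained weights were saved locally as model.pt and is not necessarily how the original archive was built:

import tarfile

# Keep model.pt at the top level of the archive, since model_fn does
# os.path.join(model_dir, 'model.pt') after SageMaker extracts it.
with tarfile.open("model.tar.gz", "w:gz") as tar:
    tar.add("model.pt", arcname="model.pt")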
config.properties
No response
Versions
framework_version="1.9.0",
py_version="py38"
Torchserve version: 0.4.2
Working on a conda_pytorch_p38 SageMaker notebook instance.
Repro instructions
The inference file (inference.py) that I wrote:
import io
import os
from io import BytesIO

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image


class ConvNormLReLU(nn.Sequential):
    def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1, pad_mode="reflect", groups=1, bias=False):
        pad_layer = {
            "zero": nn.ZeroPad2d,
            "same": nn.ReplicationPad2d,
            "reflect": nn.ReflectionPad2d,
        }
        if pad_mode not in pad_layer:
            raise NotImplementedError
        super(ConvNormLReLU, self).__init__(
            pad_layer[pad_mode](padding),
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=0, groups=groups, bias=bias),
            nn.GroupNorm(num_groups=1, num_channels=out_ch, affine=True),
            nn.LeakyReLU(0.2, inplace=True)
        )
class InvertedResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, expansion_ratio=2):
        super(InvertedResBlock, self).__init__()
        self.use_res_connect = in_ch == out_ch
        bottleneck = int(round(in_ch * expansion_ratio))
        layers = []
        if expansion_ratio != 1:
            layers.append(ConvNormLReLU(in_ch, bottleneck, kernel_size=1, padding=0))
        # dw
        layers.append(ConvNormLReLU(bottleneck, bottleneck, groups=bottleneck, bias=True))
        # pw
        layers.append(nn.Conv2d(bottleneck, out_ch, kernel_size=1, padding=0, bias=False))
        layers.append(nn.GroupNorm(num_groups=1, num_channels=out_ch, affine=True))
        self.layers = nn.Sequential(*layers)

    def forward(self, input):
        out = self.layers(input)
        if self.use_res_connect:
            out = input + out
        return out
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.block_a = nn.Sequential(
            ConvNormLReLU(3, 32, kernel_size=7, padding=3),
            ConvNormLReLU(32, 64, stride=2, padding=(0, 1, 0, 1)),
            ConvNormLReLU(64, 64)
        )
        self.block_b = nn.Sequential(
            ConvNormLReLU(64, 128, stride=2, padding=(0, 1, 0, 1)),
            ConvNormLReLU(128, 128)
        )
        self.block_c = nn.Sequential(
            ConvNormLReLU(128, 128),
            InvertedResBlock(128, 256, 2),
            InvertedResBlock(256, 256, 2),
            InvertedResBlock(256, 256, 2),
            InvertedResBlock(256, 256, 2),
            ConvNormLReLU(256, 128),
        )
        self.block_d = nn.Sequential(
            ConvNormLReLU(128, 128),
            ConvNormLReLU(128, 128)
        )
        self.block_e = nn.Sequential(
            ConvNormLReLU(128, 64),
            ConvNormLReLU(64, 64),
            ConvNormLReLU(64, 32, kernel_size=7, padding=3)
        )
        self.out_layer = nn.Sequential(
            nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0, bias=False),
            nn.Tanh()
        )

    def forward(self, input, align_corners=True):
        out = self.block_a(input)
        half_size = out.size()[-2:]
        out = self.block_b(out)
        out = self.block_c(out)
        if align_corners:
            out = F.interpolate(out, half_size, mode="bilinear", align_corners=True)
        else:
            out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
        out = self.block_d(out)
        if align_corners:
            out = F.interpolate(out, input.size()[-2:], mode="bilinear", align_corners=True)
        else:
            out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
        out = self.block_e(out)
        out = self.out_layer(out)
        return out
def model_fn(model_dir):
    """Load the model and return it.
    Providing this function is optional.
    There is a default_model_fn available, which will load the model
    compiled using SageMaker Neo. You can override the default here.
    The model_fn only needs to be defined if your model needs extra
    steps to load, and can otherwise be left undefined.
    Keyword arguments:
    model_dir -- the directory path where the model artifacts are present
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The compiled model is saved as "model.pt"
    model = Generator()
    model_path = os.path.join(model_dir, 'model.pt')
    with open(model_path, 'rb') as f:
        model.load_state_dict(torch.load(f))
    model.to(device).eval()
    return model
def transform_fn(model, request_body, request_content_type='image/*', response_content_type='image/*'):
    """Run prediction and return the output.
    The function
    1. Pre-processes the input request
    2. Runs prediction
    3. Post-processes the prediction output.
    """
    image_format = "png"  #@param ["jpeg", "png"]
    # preprocess
    img_in = Image.open(io.BytesIO(request_body)).convert("RGB")
    # predict
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    im_out = model(img_in)
    buffer_out = BytesIO()
    im_out.save(buffer_out, format=image_format)
    out = buffer_out.getvalue()
    return out, response_content_type
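Once the model loads, note that transform_fn expects raw image bytes in request_body, so the client side would need to send the file contents with an image content type rather than a PIL Image object. A hedged sketch using boto3, assuming the endpoint created by the deploy call above:

import io

import boto3
from PIL import Image

runtime = boto3.client("sagemaker-runtime")
with open("./samples/inputs/1.jpg", "rb") as f:
    payload = f.read()

response = runtime.invoke_endpoint(
    EndpointName=predictor.endpoint_name,  # endpoint from pytorch_model.deploy(...)
    ContentType="image/jpeg",
    Body=payload,
)
Image.open(io.BytesIO(response["Body"].read())).show()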
Possible Solution
No response