This repository was archived by the owner on Aug 7, 2025. It is now read-only.

Cannot load model #1813

@LiJell

Description

🐛 Describe the bug

I am trying to deploy a locally pretrained model via SageMaker to create an endpoint and use it.

I deployed the model:

```python
from sagemaker.pytorch import PyTorchModel

pytorch_model = PyTorchModel(model_data='model.tar.gz',
                             role=role,
                             entry_point='inference.py',
                             framework_version="1.9.0",
                             py_version="py38")

predictor = pytorch_model.deploy(instance_type='ml.g4dn.xlarge', initial_instance_count=1)
```

and ran prediction:

```python
from PIL import Image

data = Image.open('./samples/inputs/1.jpg')
result = predictor.predict(data)
img = Image.open(result)
img.show()
```

As a result, I got this error:

```
ModelError Traceback (most recent call last)
/tmp/ipykernel_4268/3704626012.py in <cell line: 4>()
2
3 data = Image.open('./samples/inputs/1.jpg')
----> 4 result = predictor.predict(data)
5
6 img = Image.open(result)

~/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/sagemaker/predictor.py in predict(self, data, initial_args, target_model, target_variant, inference_id)
159 data, initial_args, target_model, target_variant, inference_id
160 )
--> 161 response = self.sagemaker_session.sagemaker_runtime_client.invoke_endpoint(**request_args)
162 return self._handle_response(response)
163

~/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/botocore/client.py in _api_call(self, *args, **kwargs)
506 )
507 # The "self" in this scope is referring to the BaseClient.
--> 508 return self._make_api_call(operation_name, kwargs)
509
510 _api_call.__name__ = str(py_operation_name)

~/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/botocore/client.py in _make_api_call(self, operation_name, api_params)
913 error_code = parsed_response.get("Error", {}).get("Code")
914 error_class = self.exceptions.from_code(error_code)
--> 915 raise error_class(parsed_response, operation_name)
916 else:
917 return parsed_response

ModelError: An error occurred (ModelError) when calling the InvokeEndpoint operation: Received server error (0) from primary with message "Your invocation timed out while waiting for a response from container primary. Review the latency metrics for each container in Amazon CloudWatch, resolve the issue, and try again.".
```

I skimmed through the logs via CloudWatch and am still struggling with this. I need some help.
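
One thing that stands out above: predictor.predict(data) is handed a PIL Image object rather than raw bytes. For comparison, a minimal sketch of sending the file bytes with explicit serializers (IdentitySerializer and BytesDeserializer are standard SageMaker SDK classes; the application/x-image content type is an assumption, chosen to match the transform_fn shown below):

```python
# Sketch: send raw JPEG bytes and read back the raw image bytes.
# IdentitySerializer/BytesDeserializer are standard SageMaker SDK classes;
# the content type is an assumption matching the transform_fn below.
import io
from PIL import Image
from sagemaker.serializers import IdentitySerializer
from sagemaker.deserializers import BytesDeserializer

predictor.serializer = IdentitySerializer(content_type="application/x-image")
predictor.deserializer = BytesDeserializer()

with open('./samples/inputs/1.jpg', 'rb') as f:
    payload = f.read()

result = predictor.predict(payload)
Image.open(io.BytesIO(result)).show()
```

This would not fix the timeout by itself, since the logs below show the worker dying before any request is served, but it avoids a second failure once the model loads.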

Error logs


```
timestamp message logStreamName
1661327528194 2022-08-24 07:52:07,987 [INFO ] main org.pytorch.serve.servingsdk.impl.PluginsManager - Initializing plugins manager... AllTraffic/i-0b6f78248b097b6c7
1661327528194 2022-08-24 07:52:08,112 [INFO ] main org.pytorch.serve.ModelServer - AllTraffic/i-0b6f78248b097b6c7
1661327528194 Torchserve version: 0.4.2 AllTraffic/i-0b6f78248b097b6c7
1661327528194 TS Home: /opt/conda/lib/python3.8/site-packages AllTraffic/i-0b6f78248b097b6c7
1661327528194 Current directory: / AllTraffic/i-0b6f78248b097b6c7
1661327528194 Temp directory: /home/model-server/tmp AllTraffic/i-0b6f78248b097b6c7
1661327528194 Number of GPUs: 1 AllTraffic/i-0b6f78248b097b6c7
1661327528194 Number of CPUs: 1 AllTraffic/i-0b6f78248b097b6c7
1661327528194 Max heap size: 3234 M AllTraffic/i-0b6f78248b097b6c7
1661327528194 Python executable: /opt/conda/bin/python3.8 AllTraffic/i-0b6f78248b097b6c7
1661327528194 Config file: /etc/sagemaker-ts.properties AllTraffic/i-0b6f78248b097b6c7
1661327528194 Inference address: http://0.0.0.0:8080 AllTraffic/i-0b6f78248b097b6c7
1661327528194 Management address: http://0.0.0.0:8080 AllTraffic/i-0b6f78248b097b6c7
1661327528194 Metrics address: http://127.0.0.1:8082 AllTraffic/i-0b6f78248b097b6c7
1661327528194 Model Store: /.sagemaker/ts/models AllTraffic/i-0b6f78248b097b6c7
1661327528194 Initial Models: model.mar AllTraffic/i-0b6f78248b097b6c7
1661327528194 Log dir: /logs AllTraffic/i-0b6f78248b097b6c7
1661327528194 Metrics dir: /logs AllTraffic/i-0b6f78248b097b6c7
1661327528194 Netty threads: 0 AllTraffic/i-0b6f78248b097b6c7
1661327528194 Netty client threads: 0 AllTraffic/i-0b6f78248b097b6c7
1661327528194 Default workers per model: 1 AllTraffic/i-0b6f78248b097b6c7
1661327528194 Blacklist Regex: N/A AllTraffic/i-0b6f78248b097b6c7
1661327528194 Maximum Response Size: 6553500 AllTraffic/i-0b6f78248b097b6c7
1661327528194 Maximum Request Size: 6553500 AllTraffic/i-0b6f78248b097b6c7
1661327528194 Prefer direct buffer: false AllTraffic/i-0b6f78248b097b6c7
1661327528194 Allowed Urls: [file://.* http(s)?://.*]
1661327528194 Custom python dependency for model allowed: false AllTraffic/i-0b6f78248b097b6c7
1661327528194 Metrics report format: prometheus AllTraffic/i-0b6f78248b097b6c7
1661327528194 Enable metrics API: true AllTraffic/i-0b6f78248b097b6c7
1661327528194 Workflow Store: /.sagemaker/ts/models AllTraffic/i-0b6f78248b097b6c7
1661327528194 Model config: N/A AllTraffic/i-0b6f78248b097b6c7
1661327528194 2022-08-24 07:52:08,120 [INFO ] main org.pytorch.serve.servingsdk.impl.PluginsManager - Loading snapshot serializer plugin... AllTraffic/i-0b6f78248b097b6c7
1661327528444 2022-08-24 07:52:08,149 [INFO ] main org.pytorch.serve.ModelServer - Loading initial models: model.mar AllTraffic/i-0b6f78248b097b6c7
1661327528444 2022-08-24 07:52:08,353 [INFO ] main org.pytorch.serve.wlm.ModelManager - Model model loaded. AllTraffic/i-0b6f78248b097b6c7
1661327528694 2022-08-24 07:52:08,370 [INFO ] main org.pytorch.serve.ModelServer - Initialize Inference server with: EpollServerSocketChannel. AllTraffic/i-0b6f78248b097b6c7
1661327528694 2022-08-24 07:52:08,472 [INFO ] main org.pytorch.serve.ModelServer - Inference API bind to: http://0.0.0.0:8080 AllTraffic/i-0b6f78248b097b6c7
1661327528694 2022-08-24 07:52:08,473 [INFO ] main org.pytorch.serve.ModelServer - Initialize Metrics server with: EpollServerSocketChannel. AllTraffic/i-0b6f78248b097b6c7
1661327528944 2022-08-24 07:52:08,474 [INFO ] main org.pytorch.serve.ModelServer - Metrics API bind to: http://127.0.0.1:8082 AllTraffic/i-0b6f78248b097b6c7
1661327528944 Model server started. AllTraffic/i-0b6f78248b097b6c7
1661327528944 2022-08-24 07:52:08,738 [WARN ] pool-2-thread-1 org.pytorch.serve.metrics.MetricCollector - worker pid is not available yet. AllTraffic/i-0b6f78248b097b6c7
1661327528944 2022-08-24 07:52:08,786 [INFO ] pool-2-thread-1 TS_METRICS - CPUUtilization.Percent:0.0 #Level:Host
1661327528944 2022-08-24 07:52:08,787 [INFO ] pool-2-thread-1 TS_METRICS - DiskAvailable.Gigabytes:24.598094940185547 #Level:Host
1661327528944 2022-08-24 07:52:08,788 [INFO ] pool-2-thread-1 TS_METRICS - DiskUsage.Gigabytes:27.390167236328125 #Level:Host
1661327528944 2022-08-24 07:52:08,788 [INFO ] pool-2-thread-1 TS_METRICS - DiskUtilization.Percent:52.7 #Level:Host
1661327528944 2022-08-24 07:52:08,788 [INFO ] pool-2-thread-1 TS_METRICS - MemoryAvailable.Megabytes:14186.97265625 #Level:Host
1661327528944 2022-08-24 07:52:08,789 [INFO ] pool-2-thread-1 TS_METRICS - MemoryUsed.Megabytes:1227.640625 #Level:Host
1661327529195 2022-08-24 07:52:08,789 [INFO ] pool-2-thread-1 TS_METRICS - MemoryUtilization.Percent:9.9 #Level:Host
1661327529195 2022-08-24 07:52:09,004 [INFO ] W-9000-model_1-stdout MODEL_LOG - Listening on port: /home/model-server/tmp/.ts.sock.9000 AllTraffic/i-0b6f78248b097b6c7
1661327529195 2022-08-24 07:52:09,004 [INFO ] W-9000-model_1-stdout MODEL_LOG - [PID]32 AllTraffic/i-0b6f78248b097b6c7
1661327529195 2022-08-24 07:52:09,004 [INFO ] W-9000-model_1-stdout MODEL_LOG - Torch worker started. AllTraffic/i-0b6f78248b097b6c7
1661327529195 2022-08-24 07:52:09,004 [INFO ] W-9000-model_1-stdout MODEL_LOG - Python runtime: 3.8.10 AllTraffic/i-0b6f78248b097b6c7
1661327529195 2022-08-24 07:52:09,011 [INFO ] W-9000-model_1 org.pytorch.serve.wlm.WorkerThread - Connecting to: /home/model-server/tmp/.ts.sock.9000 AllTraffic/i-0b6f78248b097b6c7
1661327529195 2022-08-24 07:52:09,021 [INFO ] W-9000-model_1-stdout MODEL_LOG - Connection accepted: /home/model-server/tmp/.ts.sock.9000. AllTraffic/i-0b6f78248b097b6c7
1661327529695 2022-08-24 07:52:09,064 [INFO ] W-9000-model_1-stdout MODEL_LOG - model_name: model, batchSize: 1 AllTraffic/i-0b6f78248b097b6c7
1661327529695 2022-08-24 07:52:09,605 [INFO ] W-9000-model_1-stdout MODEL_LOG - Backend worker process died. AllTraffic/i-0b6f78248b097b6c7
1661327529695 2022-08-24 07:52:09,605 [INFO ] W-9000-model_1-stdout MODEL_LOG - Traceback (most recent call last): AllTraffic/i-0b6f78248b097b6c7
1661327529695 2022-08-24 07:52:09,606 [INFO ] W-9000-model_1-stdout MODEL_LOG - File "/opt/conda/lib/python3.8/site-packages/ts/model_service_worker.py", line 183, in <module> AllTraffic/i-0b6f78248b097b6c7
1661327529695 2022-08-24 07:52:09,606 [INFO ] W-9000-model_1-stdout MODEL_LOG - worker.run_server() AllTraffic/i-0b6f78248b097b6c7
1661327529695 2022-08-24 07:52:09,606 [INFO ] W-9000-model_1-stdout MODEL_LOG - File "/opt/conda/lib/python3.8/site-packages/ts/model_service_worker.py", line 155, in run_server AllTraffic/i-0b6f78248b097b6c7
1661327529695 2022-08-24 07:52:09,607 [INFO ] epollEventLoopGroup-5-1 org.pytorch.serve.wlm.WorkerThread - 9000 Worker disconnected. WORKER_STARTED AllTraffic/i-0b6f78248b097b6c7
1661327529695 2022-08-24 07:52:09,607 [INFO ] W-9000-model_1-stdout MODEL_LOG - self.handle_connection(cl_socket) AllTraffic/i-0b6f78248b097b6c7
1661327529695 2022-08-24 07:52:09,608 [INFO ] W-9000-model_1-stdout MODEL_LOG - File "/opt/conda/lib/python3.8/site-packages/ts/model_service_worker.py", line 117, in handle_connection AllTraffic/i-0b6f78248b097b6c7
1661327529695 2022-08-24 07:52:09,608 [WARN ] W-9000-model_1 org.pytorch.serve.wlm.BatchAggregator - Load model failed: model, error: Worker died. AllTraffic/i-0b6f78248b097b6c7
1661327529695 2022-08-24 07:52:09,608 [INFO ] W-9000-model_1-stdout MODEL_LOG - service, result, code = self.load_model(msg) AllTraffic/i-0b6f78248b097b6c7
1661327529695 2022-08-24 07:52:09,609 [WARN ] W-9000-model_1 org.pytorch.serve.wlm.WorkerLifeCycle - terminateIOStreams() threadName=W-9000-model_1-stderr AllTraffic/i-0b6f78248b097b6c7
1661327529695 2022-08-24 07:52:09,609 [WARN ] W-9000-model_1 org.pytorch.serve.wlm.WorkerLifeCycle - terminateIOStreams() threadName=W-9000-model_1-stdout AllTraffic/i-0b6f78248b097b6c7
1661327529695 2022-08-24 07:52:09,610 [INFO ] W-9000-model_1-stdout MODEL_LOG - File "/opt/conda/lib/python3.8/site-packages/ts/model_service_worker.py", line 90, in load_model AllTraffic/i-0b6f78248b097b6c7
1661327529695 2022-08-24 07:52:09,610 [INFO ] W-9000-model_1-stdout org.pytorch.serve.wlm.WorkerLifeCycle - Stopped Scanner - W-9000-model_1-stdout AllTraffic/i-0b6f78248b097b6c7
1661327529695 2022-08-24 07:52:09,610 [INFO ] W-9000-model_1 org.pytorch.serve.wlm.WorkerThread - Retry worker: 9000 in 1 seconds. AllTraffic/i-0b6f78248b097b6c7
1661327531196 2022-08-24 07:52:09,628 [INFO ] W-9000-model_1-stderr org.pytorch.serve.wlm.WorkerLifeCycle - Stopped Scanner - W-9000-model_1-stderr AllTraffic/i-0b6f78248b097b6c7
1661327531196 2022-08-24 07:52:11,192 [INFO ] W-9000-model_1-stdout MODEL_LOG - Listening on port: /home/model-server/tmp/.ts.sock.9000 AllTraffic/i-0b6f78248b097b6c7
1661327531196 2022-08-24 07:52:11,193 [INFO ] W-9000-model_1-stdout MODEL_LOG - [PID]52 AllTraffic/i-0b6f78248b097b6c7
1661327531196 2022-08-24 07:52:11,193 [INFO ] W-9000-model_1-stdout MODEL_LOG - Torch worker started. AllTraffic/i-0b6f78248b097b6c7
1661327531196 2022-08-24 07:52:11,193 [INFO ] W-9000-model_1 org.pytorch.serve.wlm.WorkerThread - Connecting to: /home/model-server/tmp/.ts.sock.9000 AllTraffic/i-0b6f78248b097b6c7
1661327531196 2022-08-24 07:52:11,194 [INFO ] W-9000-model_1-stdout MODEL_LOG - Python runtime: 3.8.10 AllTraffic/i-0b6f78248b097b6c7
1661327531446 2022-08-24 07:52:11,195 [INFO ] W-9000-model_1-stdout MODEL_LOG - Connection accepted: /home/model-server/tmp/.ts.sock.9000. AllTraffic/i-0b6f78248b097b6c7
1661327531446 2022-08-24 07:52:11,212 [INFO ] W-9000-model_1-stdout MODEL_LOG - model_name: model, batchSize: 1 AllTraffic/i-0b6f78248b097b6c7
1661327531446 2022-08-24 07:52:11,368 [INFO ] W-9000-model_1-stdout MODEL_LOG - Backend worker process died. AllTraffic/i-0b6f78248b097b6c7
1661327531446 2022-08-24 07:52:11,368 [INFO ] epollEventLoopGroup-5-2 org.pytorch.serve.wlm.WorkerThread - 9000 Worker disconnected. WORKER_STARTED AllTraffic/i-0b6f78248b097b6c7
1661327531446 2022-08-24 07:52:11,368 [INFO ] W-9000-model_1-stdout MODEL_LOG - Traceback (most recent call last): AllTraffic/i-0b6f78248b097b6c7
1661327531446 2022-08-24 07:52:11,369 [WARN ] W-9000-model_1 org.pytorch.serve.wlm.BatchAggregator - Load model failed: model, error: Worker died. AllTraffic/i-0b6f78248b097b6c7
1661327531446 2022-08-24 07:52:11,371 [WARN ] W-9000-model_1 org.pytorch.serve.wlm.WorkerLifeCycle - terminateIOStreams() threadName=W-9000-model_1-stderr AllTraffic/i-0b6f78248b097b6c7
1661327531446 2022-08-24 07:52:11,371 [WARN ] W-9000-model_1 org.pytorch.serve.wlm.WorkerLifeCycle - terminateIOStreams() threadName=W-9000-model_1-stdout AllTraffic/i-0b6f78248b097b6c7
1661327531446 2022-08-24 07:52:11,371 [INFO ] W-9000-model_1 org.pytorch.serve.wlm.WorkerThread - Retry worker: 9000 in 1 seconds. AllTraffic/i-0b6f78248b097b6c7
1661327531446 2022-08-24 07:52:11,371 [INFO ] W-9000-model_1-stdout MODEL_LOG - File "/opt/conda/lib/python3.8/site-packages/ts/model_service_worker.py", line 183, in <module> AllTraffic/i-0b6f78248b097b6c7
1661327531696 2022-08-24 07:52:11,372 [INFO ] W-9000-model_1-stdout org.pytorch.serve.wlm.WorkerLifeCycle - Stopped Scanner - W-9000-model_1-stdout AllTraffic/i-0b6f78248b097b6c7
1661327531696 2022-08-24 07:52:11,665 [INFO ] W-9000-model_1 ACCESS_LOG - /169.254.178.2:35288 "GET /ping HTTP/1.1" 200 15 AllTraffic/i-0b6f78248b097b6c7
1661327531696 2022-08-24 07:52:11,666 [INFO ] W-9000-model_1 TS_METRICS - Requests2XX.Count:1 #Level:Host
1661327532947 2022-08-24 07:52:11,673 [INFO ] W-9000-model_1-stderr org.pytorch.serve.wlm.WorkerLifeCycle - Stopped Scanner - W-9000-model_1-stderr AllTraffic/i-0b6f78248b097b6c7
1661327532947 2022-08-24 07:52:12,892 [INFO ] W-9000-model_1-stdout MODEL_LOG - Listening on port: /home/model-server/tmp/.ts.sock.9000 AllTraffic/i-0b6f78248b097b6c7
1661327532947 2022-08-24 07:52:12,892 [INFO ] W-9000-model_1-stdout MODEL_LOG - [PID]65 AllTraffic/i-0b6f78248b097b6c7
1661327532947 2022-08-24 07:52:12,892 [INFO ] W-9000-model_1-stdout MODEL_LOG - Torch worker started. AllTraffic/i-0b6f78248b097b6c7
1661327532947 2022-08-24 07:52:12,892 [INFO ] W-9000-model_1 org.pytorch.serve.wlm.WorkerThread - Connecting to: /home/model-server/tmp/.ts.sock.9000 AllTraffic/i-0b6f78248b097b6c7
1661327532947 2022-08-24 07:52:12,892 [INFO ] W-9000-model_1-stdout MODEL_LOG - Python runtime: 3.8.10 AllTraffic/i-0b6f78248b097b6c7
1661327532947 2022-08-24 07:52:12,893 [INFO ] W-9000-model_1-stdout MODEL_LOG - Connection accepted: /home/model-server/tmp/.ts.sock.9000. AllTraffic/i-0b6f78248b097b6c7
1661327533197 2022-08-24 07:52:12,894 [INFO ] W-9000-model_1-stdout MODEL_LOG - model_name: model, batchSize: 1 AllTraffic/i-0b6f78248b097b6c7
1661327533197 2022-08-24 07:52:13,026 [INFO ] W-9000-model_1-stdout MODEL_LOG - Backend worker process died. AllTraffic/i-0b6f78248b097b6c7
1661327533197 2022-08-24 07:52:13,026 [INFO ] W-9000-model_1-stdout MODEL_LOG - Traceback (most recent call last): AllTraffic/i-0b6f78248b097b6c7
1661327533197 2022-08-24 07:52:13,027 [INFO ] W-9000-model_1-stdout MODEL_LOG - File "/opt/conda/lib/python3.8/site-packages/ts/model_service_worker.py", line 183, in <module> AllTraffic/i-0b6f78248b097b6c7
1661327533197 2022-08-24 07:52:13,027 [INFO ] W-9000-model_1-stdout MODEL_LOG - worker.run_server() AllTraffic/i-0b6f78248b097b6c7
```
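
The traceback in these logs is cut off before the underlying exception, but every restart dies inside load_model, so model_fn is the first place to look. A minimal local smoke test (a sketch, assuming inference.py is importable and model.tar.gz is in the working directory) surfaces the full stack trace without going through the endpoint:

```python
# Sketch: call model_fn outside SageMaker so the real exception is not
# swallowed by the TorchServe worker restart loop. Paths are assumptions.
import tarfile
import tempfile

from inference import model_fn  # the entry_point script shown below

with tempfile.TemporaryDirectory() as model_dir:
    with tarfile.open("model.tar.gz") as tar:  # same artifact given to PyTorchModel
        tar.extractall(model_dir)
    model = model_fn(model_dir)  # raises with the full stack trace on failure
    print(type(model))
```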

Installation instructions

I am using SageMaker.

Model Packaging

```python
from sagemaker.pytorch import PyTorchModel

pytorch_model = PyTorchModel(model_data='model.tar.gz',
                             role=role,
                             entry_point='inference.py',
                             framework_version="1.9.0",
                             py_version="py38")
```
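
For what it's worth, the "invocation timed out" response appears whenever no healthy worker answers in time. If the model were merely slow to load or respond, the SageMaker inference toolkit reads a SAGEMAKER_MODEL_SERVER_TIMEOUT environment variable (worth verifying for this container version; the value below is an illustrative assumption), which can be set at packaging time:

```python
# Sketch: raising the model server timeout via the container environment.
# SAGEMAKER_MODEL_SERVER_TIMEOUT is read by the SageMaker inference toolkit;
# the 3600-second value is an illustrative assumption, not a recommendation.
pytorch_model = PyTorchModel(model_data='model.tar.gz',
                             role=role,
                             entry_point='inference.py',
                             framework_version="1.9.0",
                             py_version="py38",
                             env={"SAGEMAKER_MODEL_SERVER_TIMEOUT": "3600"})
```

Here, though, the logs show the worker crashing rather than running slowly, so a longer timeout alone would not help.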

config.properties

No response

Versions

framework_version="1.9.0"
py_version="py38"
TorchServe version: 0.4.2
Working on a conda_pytorch_p38 SageMaker notebook instance.

Repro instructions

The inference file that I wrote:

```python
import io
import os
from io import BytesIO

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image


class ConvNormLReLU(nn.Sequential):
    def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1,
                 pad_mode="reflect", groups=1, bias=False):
        pad_layer = {
            "zero":    nn.ZeroPad2d,
            "same":    nn.ReplicationPad2d,
            "reflect": nn.ReflectionPad2d,
        }
        if pad_mode not in pad_layer:
            raise NotImplementedError

        super(ConvNormLReLU, self).__init__(
            pad_layer[pad_mode](padding),
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride,
                      padding=0, groups=groups, bias=bias),
            nn.GroupNorm(num_groups=1, num_channels=out_ch, affine=True),
            nn.LeakyReLU(0.2, inplace=True)
        )


class InvertedResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, expansion_ratio=2):
        super(InvertedResBlock, self).__init__()

        self.use_res_connect = in_ch == out_ch
        bottleneck = int(round(in_ch * expansion_ratio))
        layers = []
        if expansion_ratio != 1:
            layers.append(ConvNormLReLU(in_ch, bottleneck, kernel_size=1, padding=0))

        # dw
        layers.append(ConvNormLReLU(bottleneck, bottleneck, groups=bottleneck, bias=True))
        # pw
        layers.append(nn.Conv2d(bottleneck, out_ch, kernel_size=1, padding=0, bias=False))
        layers.append(nn.GroupNorm(num_groups=1, num_channels=out_ch, affine=True))

        self.layers = nn.Sequential(*layers)

    def forward(self, input):
        out = self.layers(input)
        if self.use_res_connect:
            out = input + out
        return out


class Generator(nn.Module):
    def __init__(self):
        super().__init__()

        self.block_a = nn.Sequential(
            ConvNormLReLU(3,  32, kernel_size=7, padding=3),
            ConvNormLReLU(32, 64, stride=2, padding=(0, 1, 0, 1)),
            ConvNormLReLU(64, 64)
        )

        self.block_b = nn.Sequential(
            ConvNormLReLU(64,  128, stride=2, padding=(0, 1, 0, 1)),
            ConvNormLReLU(128, 128)
        )

        self.block_c = nn.Sequential(
            ConvNormLReLU(128, 128),
            InvertedResBlock(128, 256, 2),
            InvertedResBlock(256, 256, 2),
            InvertedResBlock(256, 256, 2),
            InvertedResBlock(256, 256, 2),
            ConvNormLReLU(256, 128),
        )

        self.block_d = nn.Sequential(
            ConvNormLReLU(128, 128),
            ConvNormLReLU(128, 128)
        )

        self.block_e = nn.Sequential(
            ConvNormLReLU(128, 64),
            ConvNormLReLU(64,  64),
            ConvNormLReLU(64,  32, kernel_size=7, padding=3)
        )

        self.out_layer = nn.Sequential(
            nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0, bias=False),
            nn.Tanh()
        )

    def forward(self, input, align_corners=True):
        out = self.block_a(input)
        half_size = out.size()[-2:]
        out = self.block_b(out)
        out = self.block_c(out)

        if align_corners:
            out = F.interpolate(out, half_size, mode="bilinear", align_corners=True)
        else:
            out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
        out = self.block_d(out)

        if align_corners:
            out = F.interpolate(out, input.size()[-2:], mode="bilinear", align_corners=True)
        else:
            out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
        out = self.block_e(out)

        out = self.out_layer(out)
        return out


def model_fn(model_dir):
    """Load the model and return it.

    Providing this function is optional. There is a default_model_fn
    available, which will load the model compiled using SageMaker Neo.
    You can override the default here. The model_fn only needs to be
    defined if your model needs extra steps to load, and can otherwise
    be left undefined.

    Keyword arguments:
    model_dir -- the directory path where the model artifacts are present
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The compiled model is saved as "model.pt"
    model = Generator()
    model_path = os.path.join(model_dir, 'model.pt')
    # NOTE: model_path already ends in 'model.pt', so this second join
    # opens 'model.pt/model.pt' -- a plausible cause of the load failure.
    with open(os.path.join(model_path, 'model.pt'), 'rb') as f:
        model.load_state_dict(torch.load(f))

    model.to(device).eval()

    return model


def transform_fn(model, request_body, request_content_type='image/',
                 response_content_type='image/'):
    """Run prediction and return the output.

    The function
    1. Pre-processes the input request
    2. Runs prediction
    3. Post-processes the prediction output.
    """
    image_format = "png"  # @param ["jpeg", "png"]

    # preprocess
    img_in = Image.open(io.BytesIO(request_body)).convert("RGB")

    # predict
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    im_out = model(img_in)
    buffer_out = BytesIO()
    im_out.save(buffer_out, format=image_format)
    out = buffer_out.getvalue()

    return out, response_content_type
```
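
Separately from the load failure, transform_fn above passes the PIL image straight into the model and calls .save() on the returned tensor; both would raise once loading succeeds. A sketch of the tensor round-trip it would need (torchvision-based; the [-1, 1] scaling is an assumption based on the generator's Tanh output layer):

```python
# Sketch of the tensor round-trip for transform_fn. torchvision's
# to_tensor/to_pil_image are standard; the [-1, 1] scaling is an
# assumption based on the generator's Tanh output layer.
import io

import torch
from PIL import Image
from torchvision.transforms.functional import to_tensor, to_pil_image

def preprocess(request_body, device):
    img = Image.open(io.BytesIO(request_body)).convert("RGB")
    x = to_tensor(img) * 2 - 1          # [0, 1] -> assumed [-1, 1] input range
    return x.unsqueeze(0).to(device)    # add a batch dimension

def postprocess(y, image_format="png"):
    y = (y.squeeze(0).clamp(-1, 1) + 1) / 2   # back to [0, 1]
    buf = io.BytesIO()
    to_pil_image(y.cpu()).save(buf, format=image_format)
    return buf.getvalue()

# Inside transform_fn this would be:
#     x = preprocess(request_body, device)
#     with torch.no_grad():
#         y = model(x)
#     out = postprocess(y)
```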

Possible Solution

No response

Metadata

Labels: triaged_wait (Waiting for the Reporter's response)