[DLMED] add more features
Signed-off-by: Nic Ma <nma@nvidia.com>
Nic-Ma committed Jan 17, 2022
1 parent 88b13ec commit 1b01dad
Showing 9 changed files with 203 additions and 19 deletions.
@@ -0,0 +1,4 @@
python ../scripts/inference.py \
    --base_config ../configs/inference.json \
    --config ../configs/inference_v2.json \
    --meta ../configs/metadata.json
@@ -0,0 +1,3 @@
python ../scripts/inference.py \
    --config ../configs/trtinfer.json \
    --meta ../configs/metadata.json
@@ -1,7 +1,7 @@
 {
     "multi_gpu": false,
     "amp": true,
-    "model": "monai.data.load_net_with_metadata('../models/model.ts')[0]",
+    "model": "#monai.data.load_net_with_metadata('../models/model.ts')[0]",
     "network": {
         "name": "UNet",
         "args": {
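The `#` prefix added above marks the `model` string as a Python expression for the config parser to evaluate, rather than a plain value. A minimal sketch of how such a marker might be resolved, using a hypothetical eval-based helper (not the actual ConfigParser internals):

import monai

def resolve_value(value, globals_map=None):
    # plain values pass through; strings starting with '#' are evaluated
    if isinstance(value, str) and value.startswith("#"):
        return eval(value[1:], globals_map or {"monai": monai})
    return value

# requires ../models/model.ts to exist on disk
model = resolve_value("#monai.data.load_net_with_metadata('../models/model.ts')[0]")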
@@ -0,0 +1,23 @@
{
    "amp": false,
    "network": {
        "name": "UNet",
        "args": {
            "spatial_dims": 3,
            "in_channels": 1,
            "out_channels": 2,
            "channels": [32, 64, 128, 256, 512],
            "strides": [2, 2, 2, 2],
            "num_res_units": 2,
            "norm": "group"
        }
    },
    "inferer": {
        "name": "SlidingWindowInferer",
        "args": {
            "roi_size": [96, 96, 96],
            "sw_batch_size": 4,
            "overlap": 0.6
        }
    }
}
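Each component above is declared as a `name` plus `args` pair. A small sketch of how such an entry can map to a constructor call, assuming the class is looked up in a known MONAI module (this mirrors the intent of the config, not the prototype parser's code):

import monai.inferers

inferer_conf = {
    "name": "SlidingWindowInferer",
    "args": {"roi_size": [96, 96, 96], "sw_batch_size": 4, "overlap": 0.6},
}

cls = getattr(monai.inferers, inferer_conf["name"])  # look the class up by name
inferer = cls(**inferer_conf["args"])                # expand args as keyword arguments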
19 changes: 19 additions & 0 deletions modules/model_package/spleen_segmentation/configs/trtinfer.json
@@ -0,0 +1,19 @@
{
    "preprocessing": {
        "ref": {
            "path": "../inference.json/preprocessing"
        }
    },
    "dataset": {
        "ref": {
            "path": "../inference.json/dataset"
        }
    },
    "model": "#load_trt_model(...)",
    "dataloader": {
        "name": "DALIpipeline"
    },
    "inferer": {
        "name": "TensorRTInferer"
    }
}
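The `ref`/`path` entries point back into `../inference.json` so shared sections are defined once. One way such a reference could be resolved, assuming the last path segment names a key inside the referenced file (the split rule here is an assumption about this proposed syntax):

import json
from pathlib import Path

def resolve_ref(path_str, base_dir="."):
    # "../inference.json/preprocessing" -> file "../inference.json", key "preprocessing"
    file_part, _, key = path_str.rpartition("/")
    with open(Path(base_dir) / file_part) as f:
        return json.load(f)[key]

preprocessing_conf = resolve_ref("../inference.json/preprocessing", base_dir="configs")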
9 changes: 5 additions & 4 deletions modules/model_package/spleen_segmentation/scripts/export.py
@@ -14,6 +14,7 @@
 import json
 
 import torch
+from monai.apps import ConfigParser
 from ignite.handlers import Checkpoint
 from monai.data import save_net_with_metadata
 from monai.networks import convert_to_torchscript
@@ -28,15 +29,15 @@ def main():

     # load config file
     with open(args.config, "r") as f:
-        cofnig_dict = json.load(f)
+        config_dict = json.load(f)
     # load meta data
     with open(args.meta, "r") as f:
         meta_dict = json.load(f)
 
     net: torch.nn.Module = None
-    # TODO: parse network definiftion from config file and construct network instance
-    # config_parser = ConfigParser(config_dict, meta_dict)
-    # net = config_parser.get_component("network")
+    config_parser = ConfigParser(config_dict)
+    net = config_parser.get_instance("network")
 
     checkpoint = torch.load(args.weights)
     # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver
@@ -51,7 +52,7 @@
         include_config_vals=False,
         append_timestamp=False,
         meta_values=meta_dict,
-        more_extra_files={args.config: json.dumps(cofnig_dict).encode()},
+        more_extra_files={args.config: json.dumps(config_dict).encode()},
)


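For context, a small round-trip of the TorchScript-with-metadata flow this script relies on; the toy `Linear` net, version string, and file names are placeholders for the real UNet and bundle files:

import json
import torch
from monai.data import load_net_with_metadata, save_net_with_metadata
from monai.networks import convert_to_torchscript

net = torch.nn.Linear(4, 2)        # stand-in for the real network
jit_net = convert_to_torchscript(net)
save_net_with_metadata(
    jit_net,
    filename_prefix="model",       # writes model.ts
    include_config_vals=False,
    append_timestamp=False,
    meta_values={"version": "0.1.0"},
    more_extra_files={"config.json": json.dumps({"amp": True}).encode()},
)
loaded, meta, extra = load_net_with_metadata("model.ts", more_extra_files=("config.json",))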
36 changes: 22 additions & 14 deletions modules/model_package/spleen_segmentation/scripts/inference.py
@@ -14,6 +14,7 @@
 import json
 
 import torch
+from monai.apps import ConfigParser
 from monai.data import decollate_batch
 from monai.inferers import Inferer
 from monai.transforms import Transform
@@ -22,38 +23,45 @@

 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--config', '-c', type=str, help='file path of config file that defines network', required=True)
+    parser.add_argument('--config', '-c', type=str, help='config file that defines components', required=True)
     parser.add_argument('--meta', '-e', type=str, help='file path of the meta data')
     args = parser.parse_args()
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    configs = {}
 
-    # load config file
-    with open(args.config, "r") as f:
-        cofnig_dict = json.load(f)
     # load meta data
     with open(args.meta, "r") as f:
-        meta_dict = json.load(f)
+        configs.update(json.load(f))
+    # load config file, can override meta data in config
+    with open(args.config, "r") as f:
+        configs.update(json.load(f))
 
-    net: torch.nn.Module = None
+    model: torch.nn.Module = None
     dataloader: torch.utils.data.DataLoader = None
     inferer: Inferer = None
     postprocessing: Transform = None
-    # TODO: parse inference config file and construct instances
-    # config_parser = ConfigParser(config_dict, meta_dict)
-    # net = config_parser.get_component("model").to(device)
-    # dataloader = config_parser.get_component("dataloader")
-    # inferer = config_parser.get_component("inferer")
-    # postprocessing = config_parser.get_component("postprocessing")
+    config_parser = ConfigParser(configs)
+
+    # change the JSON config content in Python code (lazy instantiation)
+    model_conf = config_parser.get_config("model")
+    model_conf["disabled"] = False
+    model = config_parser.build(model_conf).to(device)
+
+    # instantiate the components immediately
+    dataloader = config_parser.get_instance("dataloader")
+    inferer = config_parser.get_instance("inferer")
+    postprocessing = config_parser.get_instance("postprocessing")
 
-    net.eval()
+    model.eval()
     with torch.no_grad():
         for d in dataloader:
             images = d[CommonKeys.IMAGE].to(device)
             # define sliding window size and batch size for window inference
-            d[CommonKeys.PRED] = inferer(inputs=images, predictor=net)
+            d[CommonKeys.PRED] = inferer(inputs=images, predictor=model)
             # decollate the batch data into a list of dictionaries, then execute postprocessing transforms
-            d = [postprocessing(i) for i in decollate_batch(d)]
+            [postprocessing(i) for i in decollate_batch(d)]


if __name__ == '__main__':
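The new flow above distinguishes lazy instantiation (`get_config` returns the raw dict, editable before `build`) from immediate instantiation (`get_instance`). A toy parser that mirrors that usage with a made-up `class`/`args` layout; the prototype ConfigParser's real internals differ:

import torch

class ToyConfigParser:
    def __init__(self, configs):
        self.configs = configs

    def get_config(self, key):
        return self.configs[key]  # raw dict, still editable

    def build(self, conf):
        return conf["class"](**conf.get("args", {}))  # instantiate now

    def get_instance(self, key):
        return self.build(self.get_config(key))

parser = ToyConfigParser({"model": {"class": torch.nn.Linear, "args": {"in_features": 4, "out_features": 2}}})
conf = parser.get_config("model")  # lazy: fetch the config first
conf["args"]["out_features"] = 3   # edit it before building
model = parser.build(conf)         # then instantiate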
67 changes: 67 additions & 0 deletions modules/model_package/spleen_segmentation/scripts/inference_v2.py
@@ -0,0 +1,67 @@

# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json

import torch
from monai.apps import ConfigParser
from monai.data import decollate_batch
from monai.inferers import Inferer
from monai.transforms import Transform
from monai.utils.enums import CommonKeys


def main():
    parser = argparse.ArgumentParser()
    # short flag '-b' avoids clashing with '--config' '-c'
    parser.add_argument('--base_config', '-b', type=str, help='file path of base config', required=False)
    parser.add_argument('--config', '-c', type=str, help='config file to override base config', required=True)
    parser.add_argument('--meta', '-e', type=str, help='file path of the meta data')
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    configs = {}

    # load meta data
    with open(args.meta, "r") as f:
        configs.update(json.load(f))
    # load base config file if given, can override meta data in config
    if args.base_config:
        with open(args.base_config, "r") as f:
            configs.update(json.load(f))
    # load config file, add or override the content of base config
    with open(args.config, "r") as f:
        configs.update(json.load(f))

    model: torch.nn.Module = None
    dataloader: torch.utils.data.DataLoader = None
    inferer: Inferer = None
    postprocessing: Transform = None
    # parse the config content and construct component instances
    config_parser = ConfigParser(configs)
    # instantiate the components immediately
    model = config_parser.get_instance("model").to(device)
    dataloader = config_parser.get_instance("dataloader")
    inferer = config_parser.get_instance("inferer")
    postprocessing = config_parser.get_instance("postprocessing")

    model.eval()
    with torch.no_grad():
        for d in dataloader:
            images = d[CommonKeys.IMAGE].to(device)
            # define sliding window size and batch size for window inference
            d[CommonKeys.PRED] = inferer(inputs=images, predictor=model)
            # decollate the batch data into a list of dictionaries, then execute postprocessing transforms
            [postprocessing(i) for i in decollate_batch(d)]


if __name__ == '__main__':
    main()
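The config layering above relies on `dict.update`, so later files win key by key. Note the merge is shallow: a nested section such as `network` is replaced wholesale, not merged field by field:

configs = {}
configs.update({"amp": True, "network": {"name": "UNet", "args": {"spatial_dims": 3}}})  # base
configs.update({"network": {"name": "UNet"}})  # whole "network" section replaced
assert configs == {"amp": True, "network": {"name": "UNet"}}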
59 changes: 59 additions & 0 deletions modules/model_package/spleen_segmentation/scripts/trtinfer.py
@@ -0,0 +1,59 @@

# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json

import torch
from monai.apps import ConfigParser
from monai.data import decollate_batch
from monai.transforms import Transform


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str, help='config file that defines components', required=True)
    parser.add_argument('--meta', '-e', type=str, help='file path of the meta data')
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    configs = {}

    # load meta data
    with open(args.meta, "r") as f:
        configs.update(json.load(f))
    # load config file, can override meta data in config
    with open(args.config, "r") as f:
        configs.update(json.load(f))

    # fake code to simulate TensorRT and DALI logic; these annotated types are placeholders
    model: TRTModel = None
    dataloader: DALIpipeline = None
    inferer: TRTInfer = None
    postprocessing: Transform = None
    # parse the config content and construct component instances
    config_parser = ConfigParser(configs)

    # instantiate the components immediately
    model = config_parser.get_instance("model").to(device)
    dataloader = config_parser.get_instance("dataloader")
    inferer = config_parser.get_instance("inferer")
    postprocessing = config_parser.get_instance("postprocessing")

    # simulate TensorRT and DALI logic
    for d in dataloader:
        r = inferer(inputs=d, predictor=model)
        [postprocessing(i) for i in decollate_batch(r)]


if __name__ == '__main__':
    main()
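Since the TensorRT and DALI names above are placeholders, stubs like the following could stand in to exercise the control flow; they are purely illustrative, not real bindings:

class TRTModel:
    """Would wrap a deserialized TensorRT engine."""
    def to(self, device):
        return self  # device placement is TensorRT's job in a real wrapper

class TRTInfer:
    """Would feed batches through the engine's execution context."""
    def __call__(self, inputs, predictor):
        return inputs  # a real implementation runs the TRT engine here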
