This repository has been archived by the owner on Dec 18, 2024. It is now read-only.

Commit b1528a6

make code ready for release

ranftlr committed Mar 22, 2021
1 parent 2c4ae0b commit b1528a6
Showing 4 changed files with 33 additions and 60 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -107,4 +107,5 @@ venv.bak/
 *.pfm
 *.jpg
 *.jpeg
-*.pt
+*.pt
+.DS_Store
27 changes: 5 additions & 22 deletions midas/midas_net.py
@@ -49,6 +49,7 @@ def __init__(self, path=None, features=256, backbone="vitb_rn50_384", monodepth=
         self.dropout_rate = 0.0
 
         self.groups = 1
+        self.expand = False
 
         self.bn = False
         if "batch_norm" in self.blocks and self.blocks["batch_norm"] == True:
@@ -70,28 +71,10 @@ def __init__(self, path=None, features=256, backbone="vitb_rn50_384", monodepth=
         if ('shift' in self.blocks):
             self.shift = self.blocks['shift']
 
-        self.features1=features
-        self.features2=features
-        self.features3=features
-        self.features4=features
-
-        features1=features
-        features2=features
-        features3=features
-        features4=features
-        self.expand = False
-        if "expand" in self.blocks and self.blocks['expand'] == True:
-            self.expand = True
-            features1=features
-            features2=features*2
-            features3=features*4
-            features4=features*8
-
         self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups,
             expand=self.expand, exportable=exportable, hooks = self.hooks, use_readout=self.use_readout)
 
-        print(" features(1,2,3,4) = ", features1, features2, features3, features4)
 
         if "activation" not in self.blocks:
             blocks['activation'] = None
 
@@ -106,10 +89,10 @@ def __init__(self, path=None, features=256, backbone="vitb_rn50_384", monodepth=
         else:
             self.scratch.activation = nn.Identity()
 
-        self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=self.bn, expand=self.expand, align_corners=align_corners)
-        self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=self.bn, expand=self.expand, align_corners=align_corners)
-        self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=self.bn, expand=self.expand, align_corners=align_corners)
-        self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=self.bn, align_corners=align_corners)
+        self.scratch.refinenet4 = FeatureFusionBlock_custom(features, self.scratch.activation, deconv=False, bn=self.bn, expand=self.expand, align_corners=align_corners)
+        self.scratch.refinenet3 = FeatureFusionBlock_custom(features, self.scratch.activation, deconv=False, bn=self.bn, expand=self.expand, align_corners=align_corners)
+        self.scratch.refinenet2 = FeatureFusionBlock_custom(features, self.scratch.activation, deconv=False, bn=self.bn, expand=self.expand, align_corners=align_corners)
+        self.scratch.refinenet1 = FeatureFusionBlock_custom(features, self.scratch.activation, deconv=False, bn=self.bn, align_corners=align_corners)
 
         if self.monodepth == True:
             self.scratch.output_conv = nn.Sequential(
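Note: the block deleted above derived per-level decoder widths from the optional 'expand' setting; with self.expand now hard-coded to False, all four FeatureFusionBlock_custom instances are built with the same width (features, 256 by default). A rough sketch of the logic this commit removes, for illustration only and not part of the commit:

    # Removed behaviour (sketch): with blocks={'expand': True}, deeper decoder
    # levels used progressively wider feature maps; otherwise one width everywhere.
    if "expand" in blocks and blocks["expand"] == True:
        features1, features2, features3, features4 = features, features * 2, features * 4, features * 8
    else:
        features1 = features2 = features3 = features4 = features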
63 changes: 26 additions & 37 deletions run.py
@@ -27,23 +27,21 @@ def run(input_path, output_path, model_path, model_type="large", optimize=True):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print("device: %s" % device)
 
+    net_w = net_h = 384
+
     # load network
-    if model_type == "vit_large":
+    if model_type == "dpt_large":
         model = MidasNet(model_path, backbone="vitl16_384", blocks={'hooks': [5, 11, 17, 23], 'use_readout': 'project', 'activation': 'relu'}, non_negative=True)
-        net_w, net_h = 384, 384
-    elif model_type == "vit_hybrid":
+        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+    elif model_type == "dpt_hybrid":
         model = MidasNet(model_path, backbone="vitb_rn50_384", blocks={'hooks': [0, 1, 8, 11], 'use_readout': 'project', 'activation': 'relu'}, non_negative=True)
-        net_w, net_h = 384, 384
-    elif model_type == "large":
+        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+    elif model_type == "midas_v21":
         model = MidasNet_large(model_path, non_negative=True)
-        net_w, net_h = 384, 384
-    elif model_type == "small":
-        model = MidasNet(model_path, features=64, backbone="efficientnet_lite3", exportable=True, non_negative=True, blocks={'expand': True, 'activation': 'relu'})
-        net_w, net_h = 256, 256
+        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
     else:
-        print(f"model_type '{model_type}' not implemented, use: --model_type large")
-        assert False
-
+        assert False, f"model_type '{model_type}' not implemented, use: --model_type [dpt_large|dpt_hybrid|midas_v21]"
+
     transform = Compose(
         [
             Resize(
@@ -52,26 +50,19 @@ def run(input_path, output_path, model_path, model_type="large", optimize=True):
                 resize_target=None,
                 keep_aspect_ratio=True,
                 ensure_multiple_of=32,
-                resize_method="upper_bound",
+                resize_method="minimal",
                 image_interpolation_method=cv2.INTER_CUBIC,
             ),
-            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) if model_type=="large" or model_type=="small" else
-            NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+            normalization,
             PrepareForNet(),
         ]
     )
 
     model.eval()
 
-    if optimize==True:
-        # rand_example = torch.rand(1, 3, net_h, net_w)
-        # model(rand_example)
-        # traced_script_module = torch.jit.trace(model, rand_example)
-        # model = traced_script_module
-
-        if device == torch.device("cuda"):
-            model = model.to(memory_format=torch.channels_last)
-            model = model.half()
-
+    if optimize == True and device == torch.device("cuda"):
+        model = model.to(memory_format=torch.channels_last)
+        model = model.half()
+
     model.to(device)
 
@@ -83,11 +74,9 @@ def run(input_path, output_path, model_path, model_type="large", optimize=True):
     os.makedirs(output_path, exist_ok=True)
 
     print("start processing")
 
     for ind, img_name in enumerate(img_names):
-
         print(" processing {} ({}/{})".format(img_name, ind + 1, num_images))
-
         # input
 
         img = utils.read_image(img_name)
@@ -96,9 +85,11 @@ def run(input_path, output_path, model_path, model_type="large", optimize=True):
         # compute
         with torch.no_grad():
             sample = torch.from_numpy(img_input).to(device).unsqueeze(0)
-            if optimize==True and device == torch.device("cuda"):
-                sample = sample.to(memory_format=torch.channels_last)
 
+            if optimize == True and device == torch.device("cuda"):
+                sample = sample.to(memory_format=torch.channels_last)
+                sample = sample.half()
+
             prediction = model.forward(sample)
             prediction = (
                 torch.nn.functional.interpolate(
@@ -135,15 +126,14 @@ def run(input_path, output_path, model_path, model_type="large", optimize=True):
     )
 
     parser.add_argument('-m', '--model_weights',
         #default='model-f6b98070.pt',
         default=None,
-        help='path to the trained weights of model'
+        help='path to model weights'
     )
 
-    # 'large', 'small', 'vit_large', 'vit_hybrid'
     parser.add_argument('-t', '--model_type',
-        default='vit_hybrid',
-        help='model type: large or small'
+        default='dpt_hybrid',
+        help='model type'
     )
 
     parser.add_argument('--optimize', dest='optimize', action='store_true')
@@ -153,10 +143,9 @@ def run(input_path, output_path, model_path, model_type="large", optimize=True):
     args = parser.parse_args()
 
     default_models = {
-        'large': 'model-f6b98070.pt',
-        'small': 'model-small-70d6b9c8.pt',
-        'vit_large': 'vit_large-2f21e586.pt',
-        'vit_hybrid': 'vit_hybrid-501f0c75.pt',
+        "midas_v21": "weights/midas_v21-f6b98070.pt",
+        "dpt_large": "weights/dpt_large-midas-2f21e586.pt",
+        "dpt_hybrid": "weights/dpt_hybrid-midas-501f0c75.pt",
     }
 
     if args.model_weights is None:
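Note: taken together, the run.py changes rename the model types (vit_large becomes dpt_large, vit_hybrid becomes dpt_hybrid, large becomes midas_v21, and the small model is dropped), attach the matching input normalization to each branch, simplify the half-precision/channels-last optimization path, and point the default checkpoints at the new weights/ directory. A typical invocation after this commit would look something like the following; the input and output path arguments keep their defaults from the unchanged part of the argument parser, which is not shown in this diff:

    python run.py --model_type dpt_hybrid --model_weights weights/dpt_hybrid-midas-501f0c75.pt --optimize

If --model_weights is omitted, the condition above suggests run.py falls back to the default_models entry for the chosen --model_type, so placing the released checkpoints under weights/ with the listed file names should be sufficient.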
Empty file added weights/.placeholder
Empty file.
