From 27bbb99b5c02c45d60c0818b471a576f890151eb Mon Sep 17 00:00:00 2001 From: Anhtu Nguyen Date: Thu, 17 Nov 2022 15:37:52 +0100 Subject: [PATCH] change vb_fodler to vb_folder --- alonet/common/pl_helpers.py | 11 +++++++---- alonet/common/weights.py | 11 ++--------- alonet/deformable_detr/trt_exporter.py | 4 ++-- alonet/deformable_detr_panoptic/trt_exporter.py | 4 ++-- alonet/detr/trt_exporter.py | 4 ++-- alonet/detr_panoptic/trt_exporter.py | 4 ++-- 6 files changed, 17 insertions(+), 21 deletions(-) diff --git a/alonet/common/pl_helpers.py b/alonet/common/pl_helpers.py index dee86cb8..78551a7f 100644 --- a/alonet/common/pl_helpers.py +++ b/alonet/common/pl_helpers.py @@ -10,13 +10,16 @@ parser = ArgumentParser() -def vb_folder(): +def vb_folder(create_if_not_found=False): home = os.getenv("HOME") alofolder = os.path.join(home, ".aloception") if not os.path.exists(alofolder): - raise Exception( - f"{alofolder} do not exist. Please, create the folder with the appropriate files. (Checkout documentation)" - ) + if create_if_not_found: + os.mkdir(alofolder) + else: + raise Exception( + f"{alofolder} do not exist. Please, create the folder with the appropriate files. (Checkout documentation)" + ) return alofolder diff --git a/alonet/common/weights.py b/alonet/common/weights.py index fa88f60b..085d7d3b 100644 --- a/alonet/common/weights.py +++ b/alonet/common/weights.py @@ -1,6 +1,7 @@ import torch import requests import os +from alonet.common.pl_helpers import vb_folder WEIGHT_NAME_TO_FILES = { "detr-r50": ["https://storage.googleapis.com/visualbehavior-publicweights/detr-r50/detr-r50.pth"], @@ -30,14 +31,6 @@ } -def vb_fodler(): - home = os.getenv("HOME") - alofolder = os.path.join(home, ".aloception") - if not os.path.exists(alofolder): - os.mkdir(alofolder) - return alofolder - - def load_weights(model, weights, device, strict_load_weights=True): """Load and/or download weights from public cloud @@ -50,7 +43,7 @@ def load_weights(model, weights, device, strict_load_weights=True): device: torch.device Device to load the weights into """ - weights_dir = os.path.join(vb_fodler(), "weights") + weights_dir = os.path.join(vb_folder(), "weights") if not os.path.exists(weights_dir): os.makedirs(weights_dir) diff --git a/alonet/deformable_detr/trt_exporter.py b/alonet/deformable_detr/trt_exporter.py index 45f344d5..acccacdb 100644 --- a/alonet/deformable_detr/trt_exporter.py +++ b/alonet/deformable_detr/trt_exporter.py @@ -121,7 +121,7 @@ def prepare_sample_inputs(self): if __name__ == "__main__": # test script - from alonet.common.weights import vb_fodler + from alonet.common.pl_helpers import vb_folder load_trt_plugins_for_deformable_detr() device = torch.device("cuda") @@ -143,7 +143,7 @@ def prepare_sample_inputs(self): model = DeformableDetrR50(weights=model_name, tracing=True, aux_loss=False).eval() if args.onnx_path is None: - args.onnx_path = os.path.join(vb_fodler(), "weights", model_name, model_name + ".onnx") + args.onnx_path = os.path.join(vb_folder(), "weights", model_name, model_name + ".onnx") input_shape = [3] + list(args.HW) diff --git a/alonet/deformable_detr_panoptic/trt_exporter.py b/alonet/deformable_detr_panoptic/trt_exporter.py index 44eeba47..fa899ae6 100644 --- a/alonet/deformable_detr_panoptic/trt_exporter.py +++ b/alonet/deformable_detr_panoptic/trt_exporter.py @@ -83,7 +83,7 @@ def prepare_sample_inputs(self): if __name__ == "__main__": raise NotImplementedError("Not implemented yet") - from alonet.common.weights import vb_fodler + from alonet.common.pl_helpers import vb_folder from alonet.detr_panoptic import PanopticHead from alonet.detr import DetrR50 from alonet.detr.trt_exporter import DetrTRTExporter @@ -104,7 +104,7 @@ def prepare_sample_inputs(self): args = parser.parse_args() device = torch.device("cpu") if args.cpu else torch.device("cuda") input_shape = [3] + list(args.HW) - pan_onnx_path = os.path.join(vb_fodler(), "weights", "detr-r50-panoptic") + pan_onnx_path = os.path.join(vb_folder(), "weights", "detr-r50-panoptic") if args.split_engines: pan_onnx_path = os.path.join(pan_onnx_path, "panoptic-head.onnx") else: diff --git a/alonet/detr/trt_exporter.py b/alonet/detr/trt_exporter.py index a622a096..18cfd833 100644 --- a/alonet/detr/trt_exporter.py +++ b/alonet/detr/trt_exporter.py @@ -50,7 +50,7 @@ def prepare_sample_inputs(self): if __name__ == "__main__": - from alonet.common.weights import vb_fodler + from alonet.common.pl_helpers import vb_folder # test script parser = argparse.ArgumentParser() @@ -62,7 +62,7 @@ def prepare_sample_inputs(self): # parser.add_argument("--image_chw") args = parser.parse_args() if args.onnx_path is None: - args.onnx_path = os.path.join(vb_fodler(), "weights", "detr-r50", "detr-r50.onnx") + args.onnx_path = os.path.join(vb_folder(), "weights", "detr-r50", "detr-r50.onnx") device = torch.device("cpu") if args.cpu else torch.device("cuda") input_shape = [3] + list(args.HW) diff --git a/alonet/detr_panoptic/trt_exporter.py b/alonet/detr_panoptic/trt_exporter.py index e7336142..c4e10580 100644 --- a/alonet/detr_panoptic/trt_exporter.py +++ b/alonet/detr_panoptic/trt_exporter.py @@ -82,7 +82,7 @@ def prepare_sample_inputs(self): if __name__ == "__main__": - from alonet.common.weights import vb_fodler + from alonet.common.pl_helpers import vb_folder from alonet.detr_panoptic import PanopticHead from alonet.detr import DetrR50 from alonet.detr.trt_exporter import DetrTRTExporter @@ -103,7 +103,7 @@ def prepare_sample_inputs(self): args = parser.parse_args() device = torch.device("cpu") if args.cpu else torch.device("cuda") input_shape = [3] + list(args.HW) - pan_onnx_path = os.path.join(vb_fodler(), "weights", "detr-r50-panoptic") + pan_onnx_path = os.path.join(vb_folder(), "weights", "detr-r50-panoptic") if args.split_engines: pan_onnx_path = os.path.join(pan_onnx_path, "panoptic-head.onnx") else: