Skip to content

Commit

Permalink
Merge pull request #273 from Visual-Behavior/issue271-vb_folder
Browse files Browse the repository at this point in the history
change vb_fodler to vb_folder
  • Loading branch information
Data-Iab authored Nov 17, 2022
2 parents f16022e + 27bbb99 commit 16d318f
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 21 deletions.
11 changes: 7 additions & 4 deletions alonet/common/pl_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@
parser = ArgumentParser()


def vb_folder():
def vb_folder(create_if_not_found=False):
home = os.getenv("HOME")
alofolder = os.path.join(home, ".aloception")
if not os.path.exists(alofolder):
raise Exception(
f"{alofolder} do not exist. Please, create the folder with the appropriate files. (Checkout documentation)"
)
if create_if_not_found:
os.mkdir(alofolder)
else:
raise Exception(
f"{alofolder} do not exist. Please, create the folder with the appropriate files. (Checkout documentation)"
)
return alofolder


Expand Down
11 changes: 2 additions & 9 deletions alonet/common/weights.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
import requests
import os
from alonet.common.pl_helpers import vb_folder

WEIGHT_NAME_TO_FILES = {
"detr-r50": ["https://storage.googleapis.com/visualbehavior-publicweights/detr-r50/detr-r50.pth"],
Expand Down Expand Up @@ -30,14 +31,6 @@
}


def vb_fodler():
home = os.getenv("HOME")
alofolder = os.path.join(home, ".aloception")
if not os.path.exists(alofolder):
os.mkdir(alofolder)
return alofolder


def load_weights(model, weights, device, strict_load_weights=True):
"""Load and/or download weights from public cloud
Expand All @@ -50,7 +43,7 @@ def load_weights(model, weights, device, strict_load_weights=True):
device: torch.device
Device to load the weights into
"""
weights_dir = os.path.join(vb_fodler(), "weights")
weights_dir = os.path.join(vb_folder(), "weights")

if not os.path.exists(weights_dir):
os.makedirs(weights_dir)
Expand Down
4 changes: 2 additions & 2 deletions alonet/deformable_detr/trt_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def prepare_sample_inputs(self):

if __name__ == "__main__":
# test script
from alonet.common.weights import vb_fodler
from alonet.common.pl_helpers import vb_folder

load_trt_plugins_for_deformable_detr()
device = torch.device("cuda")
Expand All @@ -143,7 +143,7 @@ def prepare_sample_inputs(self):
model = DeformableDetrR50(weights=model_name, tracing=True, aux_loss=False).eval()

if args.onnx_path is None:
args.onnx_path = os.path.join(vb_fodler(), "weights", model_name, model_name + ".onnx")
args.onnx_path = os.path.join(vb_folder(), "weights", model_name, model_name + ".onnx")

input_shape = [3] + list(args.HW)

Expand Down
4 changes: 2 additions & 2 deletions alonet/deformable_detr_panoptic/trt_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def prepare_sample_inputs(self):

if __name__ == "__main__":
raise NotImplementedError("Not implemented yet")
from alonet.common.weights import vb_fodler
from alonet.common.pl_helpers import vb_folder
from alonet.detr_panoptic import PanopticHead
from alonet.detr import DetrR50
from alonet.detr.trt_exporter import DetrTRTExporter
Expand All @@ -104,7 +104,7 @@ def prepare_sample_inputs(self):
args = parser.parse_args()
device = torch.device("cpu") if args.cpu else torch.device("cuda")
input_shape = [3] + list(args.HW)
pan_onnx_path = os.path.join(vb_fodler(), "weights", "detr-r50-panoptic")
pan_onnx_path = os.path.join(vb_folder(), "weights", "detr-r50-panoptic")
if args.split_engines:
pan_onnx_path = os.path.join(pan_onnx_path, "panoptic-head.onnx")
else:
Expand Down
4 changes: 2 additions & 2 deletions alonet/detr/trt_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def prepare_sample_inputs(self):


if __name__ == "__main__":
from alonet.common.weights import vb_fodler
from alonet.common.pl_helpers import vb_folder

# test script
parser = argparse.ArgumentParser()
Expand All @@ -62,7 +62,7 @@ def prepare_sample_inputs(self):
# parser.add_argument("--image_chw")
args = parser.parse_args()
if args.onnx_path is None:
args.onnx_path = os.path.join(vb_fodler(), "weights", "detr-r50", "detr-r50.onnx")
args.onnx_path = os.path.join(vb_folder(), "weights", "detr-r50", "detr-r50.onnx")
device = torch.device("cpu") if args.cpu else torch.device("cuda")

input_shape = [3] + list(args.HW)
Expand Down
4 changes: 2 additions & 2 deletions alonet/detr_panoptic/trt_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def prepare_sample_inputs(self):


if __name__ == "__main__":
from alonet.common.weights import vb_fodler
from alonet.common.pl_helpers import vb_folder
from alonet.detr_panoptic import PanopticHead
from alonet.detr import DetrR50
from alonet.detr.trt_exporter import DetrTRTExporter
Expand All @@ -103,7 +103,7 @@ def prepare_sample_inputs(self):
args = parser.parse_args()
device = torch.device("cpu") if args.cpu else torch.device("cuda")
input_shape = [3] + list(args.HW)
pan_onnx_path = os.path.join(vb_fodler(), "weights", "detr-r50-panoptic")
pan_onnx_path = os.path.join(vb_folder(), "weights", "detr-r50-panoptic")
if args.split_engines:
pan_onnx_path = os.path.join(pan_onnx_path, "panoptic-head.onnx")
else:
Expand Down

0 comments on commit 16d318f

Please sign in to comment.