Skip to content

Commit

Permalink
prepare support for selecting device_type
Browse files Browse the repository at this point in the history
currently only "gpu" is supported.
In future, cpu,tpu.
Use `net = Yolact(device_type="gpu")`
`
  • Loading branch information
breznak committed Feb 12, 2020
1 parent eb61db9 commit 236352d
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions yolact.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,11 @@ class Yolact(nn.Module):

def __init__(self,
config_name="yolact_base_config",
device_type="gpu"
):
"""
@param config_name: string name of used config, choose from ./data/config.py, default "yolact_base"
@param device_type: string, type of devices used, choose from "gpu","cpu","tpu". Default "gpu".
"""
super().__init__()

Expand Down Expand Up @@ -135,8 +137,12 @@ def __init__(self,

# GPU
#TODO try half: net = net.half()
self.cuda()
torch.set_default_tensor_type('torch.cuda.FloatTensor')
assert(device_type == "gpu" or device_type == "cpu" or device_type == "tpu")
assert(device_type != "tpu"), "TPU not yet supported!"
self.device_type = device_type
if self.device_type == "gpu":
self.cuda()
torch.set_default_tensor_type('torch.cuda.FloatTensor')


def save_weights(self, path):
Expand Down

0 comments on commit 236352d

Please sign in to comment.