Skip to content

Commit

Permalink
ScriptModuleWrapper in a separate file
Browse files Browse the repository at this point in the history
and imported where needed
  • Loading branch information
breznak committed Feb 12, 2020
1 parent cf2d4e6 commit fccd89b
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 18 deletions.
11 changes: 1 addition & 10 deletions layers/modules/fast_mask_iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,7 @@
#locals
from data.config import Config
from utils.functions import make_net


# As of March 10, 2019, Pytorch DataParallel still doesn't support JIT Script Modules
use_jit = torch.cuda.device_count() <= 1
if not use_jit:
print('Multiple GPUs detected! Turning off JIT.')

ScriptModuleWrapper = torch.jit.ScriptModule if use_jit else nn.Module
script_method_wrapper = torch.jit.script_method if use_jit else lambda fn, _rcn=None: fn

from utils.script_module_wrapper import ScriptModuleWrapper, script_method_wrapper

class FastMaskIoUNet(ScriptModuleWrapper):

Expand Down
9 changes: 1 addition & 8 deletions layers/modules/fpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,7 @@

#local imports
from data.config import Config

# As of March 10, 2019, Pytorch DataParallel still doesn't support JIT Script Modules
use_jit = torch.cuda.device_count() <= 1
if not use_jit:
print('Multiple GPUs detected! Turning off JIT.')

ScriptModuleWrapper = torch.jit.ScriptModule if use_jit else nn.Module
script_method_wrapper = torch.jit.script_method if use_jit else lambda fn, _rcn=None: fn
from utils.script_module_wrapper import ScriptModuleWrapper, script_method_wrapper



Expand Down
10 changes: 10 additions & 0 deletions utils/script_module_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Compatibility shim selecting between JIT ScriptModule and plain nn.Module.

Exports:
    use_jit               -- True when at most one GPU is visible (JIT usable).
    ScriptModuleWrapper   -- base class for modules: torch.jit.ScriptModule
                             under JIT, torch.nn.Module otherwise.
    script_method_wrapper -- method decorator: torch.jit.script_method under
                             JIT, identity function otherwise.
"""
import torch
import torch.nn as nn  # BUG FIX: was `import torch.nn`, which leaves bare `nn` undefined

# As of March 10, 2019, Pytorch DataParallel still doesn't support JIT Script Modules,
# so fall back to plain nn.Module when more than one GPU is present.
use_jit = torch.cuda.device_count() <= 1
if not use_jit:
    print('Multiple GPUs detected! Turning off JIT.')

ScriptModuleWrapper = torch.jit.ScriptModule if use_jit else nn.Module #TODO remove once nn.Module supports JIT script modules
# When JIT is off, the wrapper is a no-op decorator; `_rcn` mirrors the
# (ignored) second argument torch.jit.script_method can receive.
script_method_wrapper = torch.jit.script_method if use_jit else lambda fn, _rcn=None: fn

0 comments on commit fccd89b

Please sign in to comment.