FX: support pytorch 1.11 #16

Merged
merged 1 commit on Mar 30, 2022
21 changes: 18 additions & 3 deletions bmtrain/block_layer.py
@@ -399,7 +399,12 @@ def __init__(self, inner_module : torch.nn.Module):
to_offset_end = offset_end + param_st - storage_st

# copy to buffer
self._storage_params[kw_name].storage()[to_offset_st: to_offset_end].copy_(contiguous_param.storage()[offset_st: offset_end])
# PyTorch 1.11 changed the API of storage.__getitem__
d_dtype = self._storage_params[kw_name].dtype
d_device = self._storage_params[kw_name].device
torch.tensor([], dtype=d_dtype, device=d_device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,))[:] = \
torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, (offset_end - offset_st,))[:]
# self._storage_params[kw_name].storage()[to_offset_st: to_offset_end].copy_(contiguous_param.storage()[offset_st: offset_end])
del contiguous_param

# clear parameter data, but keep the dtype and device
@@ -472,7 +477,12 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
to_offset_end = offset_end + param_st - storage_st

# copy to buffer
self._storage_params[kw_name].storage()[to_offset_st: to_offset_end].copy_(contiguous_param.storage()[offset_st: offset_end])
# PyTorch 1.11 changed the API of storage.__getitem__
d_dtype = self._storage_params[kw_name].dtype
d_device = self._storage_params[kw_name].device
torch.tensor([], dtype=d_dtype, device=d_device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,))[:] = \
torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, (offset_end - offset_st,))[:]
# self._storage_params[kw_name].storage()[to_offset_st: to_offset_end].copy_(contiguous_param.storage()[offset_st: offset_end])
del contiguous_param
elif strict:
missing_keys.append(key)
@@ -527,7 +537,12 @@ def init_parameters(self):
to_offset_end = offset_end + param_st - storage_st

# copy to buffer
self._storage_params[kw_name].storage()[to_offset_st: to_offset_end].copy_(tmp_tensor.storage()[offset_st: offset_end])
# PyTorch 1.11 changed the API of storage.__getitem__
d_dtype = self._storage_params[kw_name].dtype
d_device = self._storage_params[kw_name].device
torch.tensor([], dtype=d_dtype, device=d_device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,))[:] = \
torch.tensor([], dtype=d_dtype, device=d_device).set_(tmp_tensor.storage(), offset_st, (offset_end - offset_st,))[:]
# self._storage_params[kw_name].storage()[to_offset_st: to_offset_end].copy_(tmp_tensor.storage()[offset_st: offset_end])
del tmp_tensor

def _named_members(self, get_members_fn, prefix='', recurse=True):
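The three hunks above apply the same fix: PyTorch 1.11 changed the behavior of storage.__getitem__, so the old storage()[a:b].copy_(...) pattern can no longer be used to copy between storage slices. Instead, a zero-element tensor's set_() is pointed at each storage with an explicit offset and size, and the copy goes through those views. Below is a minimal standalone sketch of the same trick; dst and src are illustrative tensors standing in for the flattened storage parameter and the contiguous parameter, and it assumes PyTorch 1.11 or newer.

import torch

# Illustrative stand-ins for self._storage_params[kw_name] and contiguous_param.
dst = torch.zeros(8)
src = torch.arange(8, dtype=torch.float32)

to_offset_st, to_offset_end = 2, 5   # destination window inside dst's storage
offset_st, offset_end = 4, 7         # source window inside src's storage

# torch.tensor([]).set_(storage, offset, size) builds a view over an existing
# storage without copying; assigning through [:] then copies element-wise,
# matching the old storage()[a:b].copy_(...) behavior.
dst_view = torch.tensor([], dtype=dst.dtype, device=dst.device).set_(
    dst.storage(), to_offset_st, (to_offset_end - to_offset_st,))
src_view = torch.tensor([], dtype=src.dtype, device=src.device).set_(
    src.storage(), offset_st, (offset_end - offset_st,))
dst_view[:] = src_view

print(dst)  # tensor([0., 0., 4., 5., 6., 0., 0., 0.])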
4 changes: 3 additions & 1 deletion bmtrain/param_init.py
@@ -22,7 +22,9 @@ def init_distributed_parameter(params : Iterable[torch.nn.Parameter]):

param._init_method(tmp_tensor)

param.storage().copy_(tmp_storage[partition_size * config['rank'] : partition_size * (config['rank'] + 1)])
# Pytorch 1.11 changed the API of storage.__getitem__
param[:] = torch.tensor([], dtype=param.dtype, device=param.device).set_(tmp_storage)[partition_size * config['rank'] : partition_size * (config['rank'] + 1)]
# param.storage().copy_(tmp_storage[partition_size * config['rank'] : partition_size * (config['rank'] + 1)])

def iterate_parameters(model : torch.nn.Module):
for kw, val in model._parameters.items():
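The change above follows the same idea for parameter initialization: every rank initializes the full-size tmp_tensor, then copies only its own partition_size-sized window of tmp_storage into its local shard through a tensor view instead of slicing the storage directly. A minimal sketch with made-up sizes and a hard-coded rank; world_size, rank, and the shapes are illustrative stand-ins for bmtrain's config values, not the repository's code.

import torch

world_size, rank = 4, 1
full_numel = 16
partition_size = full_numel // world_size

# Full-size tensor initialized identically on every rank; bmtrain would call
# param._init_method(tmp_tensor) here.
tmp_tensor = torch.arange(full_numel, dtype=torch.float32)
tmp_storage = tmp_tensor.storage()

# Local shard held by this rank.
param = torch.empty(partition_size)

# Wrap the whole storage in a tensor view (the PyTorch 1.11-safe replacement
# for storage slicing), then copy this rank's window into the local shard.
param[:] = torch.tensor([], dtype=param.dtype, device=param.device).set_(
    tmp_storage)[partition_size * rank : partition_size * (rank + 1)]

print(param)  # tensor([4., 5., 6., 7.]) on rank 1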
37 changes: 30 additions & 7 deletions setup.py
@@ -11,16 +11,39 @@ def get_avx_flags():
else:
return ["-march=native"]

def get_device_cc():
try:
CC_SET = set()
for i in range(torch.cuda.device_count()):
CC_SET.add(torch.cuda.get_device_capability(i))

if len(CC_SET) == 0:
return None

ret = ""
for it in CC_SET:
if len(ret) > 0:
ret = ret + " "
ret = ret + ("%d.%d" % it)
return ret
except RuntimeError:
return None

avx_flag = get_avx_flags()

if not torch.cuda.is_available():
os.environ["TORCH_CUDA_ARCH_LIST"] = os.environ.get("TORCH_CUDA_ARCH_LIST", "6.0 6.1 7.0 7.5 8.0+PTX")
else:
if torch.version.cuda.startswith("10"):
os.environ["TORCH_CUDA_ARCH_LIST"] = os.environ.get("TORCH_CUDA_ARCH_LIST", "6.0 6.1 7.0 7.5+PTX")
device_cc = get_device_cc()
if device_cc is None:
if not torch.cuda.is_available():
os.environ["TORCH_CUDA_ARCH_LIST"] = os.environ.get("TORCH_CUDA_ARCH_LIST", "6.0 6.1 7.0 7.5 8.0+PTX")
else:
os.environ["TORCH_CUDA_ARCH_LIST"] = os.environ.get("TORCH_CUDA_ARCH_LIST", "6.0 6.1 7.0 7.5 8.0 8.6+PTX")
if torch.version.cuda.startswith("10"):
os.environ["TORCH_CUDA_ARCH_LIST"] = os.environ.get("TORCH_CUDA_ARCH_LIST", "6.0 6.1 7.0 7.5+PTX")
else:
if not torch.version.cuda.startswith("11.0"):
os.environ["TORCH_CUDA_ARCH_LIST"] = os.environ.get("TORCH_CUDA_ARCH_LIST", "6.0 6.1 7.0 7.5 8.0 8.6+PTX")
else:
os.environ["TORCH_CUDA_ARCH_LIST"] = os.environ.get("TORCH_CUDA_ARCH_LIST", "6.0 6.1 7.0 7.5 8.0+PTX")
else:
os.environ["TORCH_CUDA_ARCH_LIST"] = os.environ.get("TORCH_CUDA_ARCH_LIST", device_cc)

setup(
name='bmtrain',
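With this change, setup.py first asks the local GPUs for their compute capabilities via get_device_cc(), which returns a space-separated string such as "7.0 8.6" (or None when no device can be queried), and only falls back to a hard-coded TORCH_CUDA_ARCH_LIST chosen by CUDA toolkit version: CUDA 10.x cannot compile for sm_80, and CUDA 11.0 cannot compile for sm_86. A condensed sketch of the same selection logic as a single function; the name pick_arch_list and the structure are illustrative, not the repository's exact code.

import os
import torch

def pick_arch_list() -> str:
    # Prefer the compute capabilities of the GPUs actually present,
    # e.g. "7.0 8.6" on a machine with a V100 and an A10.
    try:
        ccs = {torch.cuda.get_device_capability(i)
               for i in range(torch.cuda.device_count())}
    except RuntimeError:
        ccs = set()
    if ccs:
        return " ".join("%d.%d" % cc for cc in sorted(ccs))
    # No queryable device: fall back by toolkit version, since CUDA 10
    # cannot target sm_80 and CUDA 11.0 cannot target sm_86.
    if not torch.cuda.is_available():
        return "6.0 6.1 7.0 7.5 8.0+PTX"
    if torch.version.cuda.startswith("10"):
        return "6.0 6.1 7.0 7.5+PTX"
    if torch.version.cuda.startswith("11.0"):
        return "6.0 6.1 7.0 7.5 8.0+PTX"
    return "6.0 6.1 7.0 7.5 8.0 8.6+PTX"

# Respect a user-provided value, otherwise use the detected/derived list.
os.environ.setdefault("TORCH_CUDA_ARCH_LIST", pick_arch_list())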