Skip to content

Commit 9f677db

Browse files
authored
Merge branch 'master' into olruwase/docs
2 parents 01cc36d + f4caed6 commit 9f677db

File tree

7 files changed

+11
-11
lines changed

7 files changed

+11
-11
lines changed

op_builder/builder.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -612,8 +612,8 @@ def compute_capability_args(self, cross_compile_archs=None):
612612
613613
- `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples:
614614
615-
TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6" pip install ...
616-
TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6+PTX" pip install ...
615+
TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6;9.0;10.0" pip install ...
616+
TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6 9.0 10.0+PTX" pip install ...
617617
618618
- `cross_compile_archs` uses ; separator.
619619
@@ -651,9 +651,9 @@ def compute_capability_args(self, cross_compile_archs=None):
651651
args = []
652652
self.enable_bf16 = True
653653
for cc in ccs:
654-
num = cc[0] + cc[2]
654+
num = cc[0] + cc[1].split('+')[0]
655655
args.append(f'-gencode=arch=compute_{num},code=sm_{num}')
656-
if cc.endswith('+PTX'):
656+
if cc[1].endswith('+PTX'):
657657
args.append(f'-gencode=arch=compute_{num},code=compute_{num}')
658658

659659
if int(cc[0]) <= 7:
@@ -666,7 +666,7 @@ def filter_ccs(self, ccs: List[str]):
666666
Prune any compute capabilities that are not compatible with the builder. Should log
667667
which CCs have been pruned.
668668
"""
669-
return ccs
669+
return [cc.split('.') for cc in ccs]
670670

671671
def version_dependent_macros(self):
672672
# Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456

op_builder/fp_quantizer.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -78,7 +78,7 @@ def is_compatible(self, verbose=False):
7878
def filter_ccs(self, ccs):
7979
ccs_retained = []
8080
ccs_pruned = []
81-
for cc in ccs:
81+
for cc in [cc.split('.') for cc in ccs]:
8282
if int(cc[0]) >= 8:
8383
ccs_retained.append(cc)
8484
else:

op_builder/inference_core_ops.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@ def is_compatible(self, verbose=False):
4646
def filter_ccs(self, ccs):
4747
ccs_retained = []
4848
ccs_pruned = []
49-
for cc in ccs:
49+
for cc in [cc.split('.') for cc in ccs]:
5050
if int(cc[0]) >= 6:
5151
ccs_retained.append(cc)
5252
else:

op_builder/inference_cutlass_builder.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,7 @@ def is_compatible(self, verbose=False):
4545
def filter_ccs(self, ccs):
4646
ccs_retained = []
4747
ccs_pruned = []
48-
for cc in ccs:
48+
for cc in [cc.split('.') for cc in ccs]:
4949
if int(cc[0]) >= 8:
5050
# Only support Ampere and newer
5151
ccs_retained.append(cc)

op_builder/ragged_ops.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@ def is_compatible(self, verbose=False):
4646
def filter_ccs(self, ccs):
4747
ccs_retained = []
4848
ccs_pruned = []
49-
for cc in ccs:
49+
for cc in [cc.split('.') for cc in ccs]:
5050
if int(cc[0]) >= 8:
5151
# Blocked flash has a dependency on Ampere + newer
5252
ccs_retained.append(cc)

op_builder/ragged_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@ def is_compatible(self, verbose=False):
4646
def filter_ccs(self, ccs):
4747
ccs_retained = []
4848
ccs_pruned = []
49-
for cc in ccs:
49+
for cc in [cc.split('.') for cc in ccs]:
5050
if int(cc[0]) >= 6:
5151
ccs_retained.append(cc)
5252
else:

op_builder/transformer_inference.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,7 @@ def is_compatible(self, verbose=False):
4444
def filter_ccs(self, ccs):
4545
ccs_retained = []
4646
ccs_pruned = []
47-
for cc in ccs:
47+
for cc in [cc.split('.') for cc in ccs]:
4848
if int(cc[0]) >= 6:
4949
ccs_retained.append(cc)
5050
else:

0 commit comments

Comments
 (0)