Skip to content

Commit 9b7b8f1

Browse files
chsiggtensorflower-gardener
authored andcommitted
Support compiling for a separate set of virtual and real CUDA compute architectures.
We currently use the following setup to select which compute architectures to compile for: - ./configure allows specifying a set of CUDA compute architectures to compile for, e.g. '5.2,6.0'. - .tf_configure.bazelrc maps this to an environment variable (TF_CUDA_COMPUTE_CAPABILITIES=5.2,6.0) - cuda_configure.bzl turns this into compiler flags (copts) for clang, which the crosstool maps to nvcc if needed. - The kernels are always compiled to both the virtual (ptx) and the real (sass) architecture. This change adds support for specifying just real (sm_xy) or both virtual and real (compute_xy) compute architectures in TF_CUDA_COMPUTE_CAPABILITIES. ./configure is left unchanged, the old 'x.y' strings are mapped to 'compute_xy' in cuda_configure.bzl. PiperOrigin-RevId: 313359468 Change-Id: I96c5b8b0a02b2ce62df27df7cc5272ddd42217aa
1 parent f0ef163 commit 9b7b8f1

File tree

5 files changed

+60
-35
lines changed

5 files changed

+60
-35
lines changed

tensorflow/core/kernels/cubin_headers/build_defs.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ def _gen_kernel_image_hdr_impl(ctx):
2222
cubins = []
2323
images = []
2424
for arch in ctx.attr.gpu_archs:
25+
# TODO(b/152737872): 'compute_' should generate both SASS and PTX.
26+
arch = arch.replace("compute_", "sm_")
2527
filename = "%s.%s.cubin" % (name, arch)
2628
cubin = ctx.actions.declare_file(filename)
2729
ctx.actions.run(

third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,12 @@ def InvokeNvcc(argv, log=False):
221221
nvccopts = '-D_FORCE_INLINES '
222222
for capability in GetOptionValue(argv, "--cuda-gpu-arch"):
223223
capability = capability[len('sm_'):]
224-
nvccopts += r'-gencode=arch=compute_%s,\"code=sm_%s,compute_%s\" ' % (
225-
capability, capability, capability)
224+
nvccopts += r'-gencode=arch=compute_%s,\"code=sm_%s\" ' % (capability,
225+
capability)
226+
for capability in GetOptionValue(argv, '--cuda-include-ptx'):
227+
capability = capability[len('sm_'):]
228+
nvccopts += r'-gencode=arch=compute_%s,\"code=compute_%s\" ' % (capability,
229+
capability)
226230
nvccopts += nvcc_compiler_options
227231
nvccopts += undefines
228232
nvccopts += defines

third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,18 @@ def InvokeNvcc(argv, log=False):
138138
nvccopts = ['-D_FORCE_INLINES']
139139
compute_capabilities, argv = GetOptionValue(argv, "--cuda-gpu-arch")
140140
for capability in compute_capabilities:
141-
print(capability)
142141
capability = capability[len('sm_'):]
143-
nvccopts += [r'-gencode=arch=compute_%s,"code=sm_%s,compute_%s"' % (
144-
capability, capability, capability)]
142+
nvccopts += [
143+
r'-gencode=arch=compute_%s,"code=sm_%s"' % (capability, capability)
144+
]
145+
compute_capabilities, argv = GetOptionValue(argv, '--cuda-include-ptx')
146+
for capability in compute_capabilities:
147+
capability = capability[len('sm_'):]
148+
nvccopts += [
149+
r'-gencode=arch=compute_%s,"code=compute_%s"' % (capability, capability)
150+
]
151+
_, argv = GetOptionValue(argv, '--no-cuda-include-ptx')
152+
145153
nvccopts += nvcc_compiler_options
146154
nvccopts += undefines
147155
nvccopts += defines

third_party/gpus/cuda_configure.bzl

Lines changed: 40 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,6 @@ _TF_CUDA_CONFIG_REPO = "TF_CUDA_CONFIG_REPO"
6666
_TF_DOWNLOAD_CLANG = "TF_DOWNLOAD_CLANG"
6767
_PYTHON_BIN_PATH = "PYTHON_BIN_PATH"
6868

69-
_DEFAULT_CUDA_COMPUTE_CAPABILITIES = ["3.5", "5.2"]
70-
7169
def to_list_of_strings(elements):
7270
"""Convert the list of ["a", "b", "c"] into '"a", "b", "c"'.
7371
@@ -410,18 +408,40 @@ _NVCC_VERSION_PREFIX = "Cuda compilation tools, release "
410408
_DEFINE_CUDNN_MAJOR = "#define CUDNN_MAJOR"
411409

412410
def compute_capabilities(repository_ctx):
413-
"""Returns a list of strings representing cuda compute capabilities."""
414-
capabilities_str = get_host_environ(repository_ctx, _TF_CUDA_COMPUTE_CAPABILITIES)
415-
if capabilities_str == None:
416-
return _DEFAULT_CUDA_COMPUTE_CAPABILITIES
417-
capabilities = capabilities_str.split(",")
418-
for capability in capabilities:
419-
# Workaround for Skylark's lack of support for regex. This check should
420-
# be equivalent to checking:
421-
# if re.match("[0-9]+.[0-9]+", capability) == None:
411+
"""Returns a list of strings representing cuda compute capabilities.
412+
413+
Args:
414+
repository_ctx: the repo rule's context.
415+
Returns: list of cuda architectures to compile for. 'compute_xy' refers to
416+
both PTX and SASS, 'sm_xy' refers to SASS only.
417+
"""
418+
capabilities = get_host_environ(
419+
repository_ctx,
420+
_TF_CUDA_COMPUTE_CAPABILITIES,
421+
"compute_35,compute_52",
422+
).split(",")
423+
424+
# Map old 'x.y' capabilities to 'compute_xy'.
425+
for i, capability in enumerate(capabilities):
422426
parts = capability.split(".")
423-
if len(parts) != 2 or not parts[0].isdigit() or not parts[1].isdigit():
427+
if len(parts) != 2:
428+
continue
429+
capabilities[i] = "compute_%s%s" % (parts[0], parts[1])
430+
431+
# Make list unique
432+
capabilities = dict(zip(capabilities, capabilities)).keys()
433+
434+
# Validate capabilities.
435+
for capability in capabilities:
436+
if not capability.startswith(("compute_", "sm_")):
424437
auto_configure_fail("Invalid compute capability: %s" % capability)
438+
for prefix in ["compute_", "sm_"]:
439+
if not capability.startswith(prefix):
440+
continue
441+
if len(capability) == len(prefix) + 2 and capability[-2:].isdigit():
442+
continue
443+
auto_configure_fail("Invalid compute capability: %s" % capability)
444+
425445
return capabilities
426446

427447
def lib_name(base_name, cpu_value, version = None, static = False):
@@ -849,21 +869,14 @@ def _tf_sysroot(repository_ctx):
849869
return get_host_environ(repository_ctx, _TF_SYSROOT, "")
850870

851871
def _compute_cuda_extra_copts(repository_ctx, compute_capabilities):
852-
capability_flags = [
853-
"--cuda-gpu-arch=sm_" + cap.replace(".", "")
854-
for cap in compute_capabilities
855-
]
856-
return str(capability_flags)
857-
858-
def _compute_cuda_gpu_architectures(repository_ctx, compute_capabilities):
859-
gpu_architectures = [
860-
"sm_" + capability.replace(".", "")
861-
for capability in compute_capabilities
862-
]
872+
capability_flags = ["--no-cuda-include-ptx=all"]
873+
for capability in compute_capabilities:
874+
if capability.startswith("compute_"):
875+
capability = capability.replace("compute_", "sm_")
876+
capability_flags.append("--cuda-include-ptx=%s" % capability)
877+
capability_flags.append("--cuda-gpu-arch=%s" % capability)
863878

864-
# Make the list unique.
865-
gpu_architectures = dict(zip(gpu_architectures, gpu_architectures)).keys()
866-
return str(gpu_architectures)
879+
return str(capability_flags)
867880

868881
def _tpl_path(repository_ctx, filename):
869882
return repository_ctx.path(Label("//third_party/gpus/%s.tpl" % filename))
@@ -996,10 +1009,7 @@ def _create_local_cuda_repository(repository_ctx):
9961009
repository_ctx,
9971010
cuda_config.compute_capabilities,
9981011
),
999-
"%{cuda_gpu_architectures}": _compute_cuda_gpu_architectures(
1000-
repository_ctx,
1001-
cuda_config.compute_capabilities,
1002-
),
1012+
"%{cuda_gpu_architectures}": str(cuda_config.compute_capabilities),
10031013
},
10041014
)
10051015

third_party/nccl/build_defs.bzl.tpl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def _device_link_impl(ctx):
8484
cubins = []
8585
images = []
8686
for arch in ctx.attr.gpu_archs:
87+
arch = arch.replace("compute_", "sm_") # PTX is JIT-linked at runtime.
8788
cubin = ctx.actions.declare_file("%s_%s.cubin" % (name, arch))
8889
register_h = ctx.actions.declare_file("%s_register_%s.h" % (name, arch))
8990
ctx.actions.run(

0 commit comments

Comments
 (0)