Skip to content

Commit 0e768bf

Browse files
committed
Initial commit to enable XLA on TensorFlow 1.8+.
The implementation is ported from TensorFlow 1.3.
1 parent 1e632e2 commit 0e768bf

38 files changed

+1309
-146
lines changed

configure.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1493,7 +1493,7 @@ def main():
14931493
set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform',
14941494
'with_kafka_support', False, 'kafka')
14951495
set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
1496-
False, 'xla')
1496+
True, 'xla')
14971497
set_build_var(environ_cp, 'TF_NEED_GDR', 'GDR', 'with_gdr_support',
14981498
False, 'gdr')
14991499
set_build_var(environ_cp, 'TF_NEED_VERBS', 'VERBS', 'with_verbs_support',

tensorflow/compiler/jit/BUILD

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
2525
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
2626
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
2727
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
28+
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
29+
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")
2830

2931
# Target that bundles up the XLA CPU and GPU JIT devices.
3032
cc_library(
@@ -40,6 +42,9 @@ cc_library(
4042
] + if_cuda_is_configured([
4143
":xla_gpu_device",
4244
":xla_gpu_jit",
45+
]) + if_rocm_is_configured([
46+
":xla_gpu_device",
47+
":xla_gpu_jit",
4348
]),
4449
alwayslink = 1,
4550
)
@@ -59,12 +64,17 @@ cc_library(
5964
cc_library(
6065
name = "xla_gpu_jit",
6166
visibility = ["//visibility:public"],
62-
deps = if_cuda([
67+
deps = if_cuda_is_configured(if_cuda([
6368
":jit_compilation_passes",
6469
"//tensorflow/compiler/jit/kernels:xla_launch_op",
6570
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
6671
"//tensorflow/compiler/xla/service:gpu_plugin",
67-
]),
72+
])) + if_rocm_is_configured(if_rocm([
73+
":jit_compilation_passes",
74+
"//tensorflow/compiler/jit/kernels:xla_launch_op",
75+
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
76+
"//tensorflow/compiler/xla/service:gpu_plugin",
77+
])),
6878
alwayslink = 1,
6979
)
7080

tensorflow/compiler/jit/kernels/xla_launch_op.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
5151
if (device_type_ == DeviceType(DEVICE_CPU)) {
5252
platform_id_ = se::host::kHostPlatformId;
5353
} else if (device_type_ == DeviceType(DEVICE_GPU)) {
54-
platform_id_ = se::cuda::kCudaPlatformId;
54+
// XXX FIXME devise a way to cope with multiple platforms
55+
//platform_id_ = se::cuda::kCudaPlatformId;
56+
platform_id_ = se::rocm::kROCmPlatformId;
5557
} else {
5658
platform_id_ = nullptr;
5759
}

tensorflow/compiler/jit/xla_gpu_device.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ limitations under the License.
1414
==============================================================================*/
1515

1616
// Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
17-
// operators using XLA via the XLA "CUDA" (GPU) backend.
17+
// operators using XLA via the XLA "CUDA" or "ROCM" (GPU) backend.
1818

1919
#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
2020
#include "tensorflow/compiler/jit/xla_device.h"
@@ -46,6 +46,8 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options,
4646

4747
std::unique_ptr<XlaDevice> device;
4848
Status status =
49+
// XXX FIXME devise a way to cope with multiple platforms
50+
//XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options,
4951
XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options,
5052
name_prefix, registration,
5153
/*transfer_as_literal=*/false, &device);

tensorflow/compiler/xla/service/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,7 @@ cc_library(
711711
"//tensorflow/compiler/xla/service/gpu:gpu_transfer_manager",
712712
"//tensorflow/core:stream_executor_no_cuda",
713713
"//tensorflow/core/platform/default/build_config:stream_executor_cuda",
714+
"//tensorflow/core/platform/default/build_config:stream_executor_rocm",
714715
],
715716
)
716717

tensorflow/compiler/xla/service/computation_placer.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ static bool InitModule() {
147147
stream_executor::host::kHostPlatformId, &CreateComputationPlacer);
148148
xla::ComputationPlacer::RegisterComputationPlacer(
149149
stream_executor::cuda::kCudaPlatformId, &CreateComputationPlacer);
150+
xla::ComputationPlacer::RegisterComputationPlacer(
151+
stream_executor::rocm::kROCmPlatformId, &CreateComputationPlacer);
150152
return true;
151153
}
152154
static bool module_initialized = InitModule();

tensorflow/compiler/xla/service/gpu/BUILD

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ filegroup(
2222
)
2323

2424
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
25+
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
26+
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
27+
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
28+
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")
2529

2630
cc_library(
2731
name = "gpu_constants",
@@ -296,6 +300,7 @@ cc_library(
296300
"//tensorflow/core/platform/default/build_config:cudnn_plugin",
297301
"//tensorflow/core/platform/default/build_config:cufft_plugin",
298302
"//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep
303+
"//tensorflow/core/platform/default/build_config:stream_executor_rocm",
299304
"//tensorflow/stream_executor",
300305
],
301306
)
@@ -490,8 +495,10 @@ cc_library(
490495

491496
cc_library(
492497
name = "gpu_compiler",
493-
srcs = ["gpu_compiler.cc"],
494-
hdrs = ["gpu_compiler.h"],
498+
srcs = if_cuda_is_configured(if_cuda(["nvptx_compiler.cc"])) +
499+
if_rocm_is_configured(if_rocm(["amdgpu_compiler.cc"])),
500+
hdrs = if_cuda_is_configured(if_cuda(["nvptx_compiler.h"])) +
501+
if_rocm_is_configured(if_rocm(["amdgpu_compiler.h"])),
495502
deps = [
496503
":cudnn_convolution_algorithm_picker",
497504
":cudnn_convolution_rewriter",
@@ -545,6 +552,7 @@ cc_library(
545552
"//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend",
546553
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
547554
"//tensorflow/core:cuda_libdevice_path",
555+
"//tensorflow/core:rocm_rocdl_path",
548556
"//tensorflow/core:lib",
549557
"//tensorflow/core:lib_internal",
550558
"//tensorflow/core:regexp_internal",

0 commit comments

Comments (0)