Skip to content

Commit a4652d5

Browse files
committed
Initial commit to enable XLA on TensorFlow 1.8+.
The implementation is ported from TensorFlow 1.3.
1 parent 1e632e2 commit a4652d5

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

43 files changed

+1899
-210
lines changed

configure.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1493,7 +1493,7 @@ def main():
14931493
set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform',
14941494
'with_kafka_support', False, 'kafka')
14951495
set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
1496-
False, 'xla')
1496+
True, 'xla')
14971497
set_build_var(environ_cp, 'TF_NEED_GDR', 'GDR', 'with_gdr_support',
14981498
False, 'gdr')
14991499
set_build_var(environ_cp, 'TF_NEED_VERBS', 'VERBS', 'with_verbs_support',

tensorflow/compiler/jit/BUILD

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
2525
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
2626
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
2727
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
28+
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
29+
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")
2830

2931
# Target that bundles up the XLA CPU and GPU JIT devices.
3032
cc_library(
@@ -40,6 +42,9 @@ cc_library(
4042
] + if_cuda_is_configured([
4143
":xla_gpu_device",
4244
":xla_gpu_jit",
45+
]) + if_rocm_is_configured([
46+
":xla_gpu_device",
47+
":xla_gpu_jit",
4348
]),
4449
alwayslink = 1,
4550
)
@@ -59,12 +64,17 @@ cc_library(
5964
cc_library(
6065
name = "xla_gpu_jit",
6166
visibility = ["//visibility:public"],
62-
deps = if_cuda([
67+
deps = if_cuda_is_configured(if_cuda([
6368
":jit_compilation_passes",
6469
"//tensorflow/compiler/jit/kernels:xla_launch_op",
6570
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
6671
"//tensorflow/compiler/xla/service:gpu_plugin",
67-
]),
72+
])) + if_rocm_is_configured(if_rocm([
73+
":jit_compilation_passes",
74+
"//tensorflow/compiler/jit/kernels:xla_launch_op",
75+
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
76+
"//tensorflow/compiler/xla/service:gpu_plugin",
77+
])),
6878
alwayslink = 1,
6979
)
7080

tensorflow/compiler/jit/kernels/xla_launch_op.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
5151
if (device_type_ == DeviceType(DEVICE_CPU)) {
5252
platform_id_ = se::host::kHostPlatformId;
5353
} else if (device_type_ == DeviceType(DEVICE_GPU)) {
54-
platform_id_ = se::cuda::kCudaPlatformId;
54+
// XXX FIXME devise a way to cope with multiple platforms
55+
//platform_id_ = se::cuda::kCudaPlatformId;
56+
platform_id_ = se::rocm::kROCmPlatformId;
5557
} else {
5658
platform_id_ = nullptr;
5759
}

tensorflow/compiler/jit/xla_gpu_device.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ limitations under the License.
1414
==============================================================================*/
1515

1616
// Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
17-
// operators using XLA via the XLA "CUDA" (GPU) backend.
17+
// operators using XLA via the XLA "CUDA" or "ROCM" (GPU) backend.
1818

1919
#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
2020
#include "tensorflow/compiler/jit/xla_device.h"
@@ -46,6 +46,8 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options,
4646

4747
std::unique_ptr<XlaDevice> device;
4848
Status status =
49+
// XXX FIXME devise a way to cope with multiple platforms
50+
//XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options,
4951
XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options,
5052
name_prefix, registration,
5153
/*transfer_as_literal=*/false, &device);

tensorflow/compiler/tf2xla/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ package(
2323
)
2424

2525
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
26+
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")
2627
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
2728

2829
cc_library(
@@ -141,6 +142,8 @@ cc_library(
141142
"xla_cpu_backend.cc",
142143
] + if_cuda_is_configured([
143144
"xla_gpu_backend.cc",
145+
]) + if_rocm_is_configured([
146+
"xla_gpu_backend.cc",
144147
]),
145148
hdrs = [
146149
"const_analysis.h",

tensorflow/compiler/xla/service/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,7 @@ cc_library(
711711
"//tensorflow/compiler/xla/service/gpu:gpu_transfer_manager",
712712
"//tensorflow/core:stream_executor_no_cuda",
713713
"//tensorflow/core/platform/default/build_config:stream_executor_cuda",
714+
"//tensorflow/core/platform/default/build_config:stream_executor_rocm",
714715
],
715716
)
716717

tensorflow/compiler/xla/service/computation_placer.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ static bool InitModule() {
147147
stream_executor::host::kHostPlatformId, &CreateComputationPlacer);
148148
xla::ComputationPlacer::RegisterComputationPlacer(
149149
stream_executor::cuda::kCudaPlatformId, &CreateComputationPlacer);
150+
xla::ComputationPlacer::RegisterComputationPlacer(
151+
stream_executor::rocm::kROCmPlatformId, &CreateComputationPlacer);
150152
return true;
151153
}
152154
static bool module_initialized = InitModule();

tensorflow/compiler/xla/service/gpu/BUILD

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ filegroup(
2222
)
2323

2424
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
25+
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
26+
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
27+
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
28+
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")
2529

2630
cc_library(
2731
name = "gpu_constants",
@@ -190,6 +194,7 @@ cc_library(
190194
srcs = ["elemental_ir_emitter.cc"],
191195
hdrs = ["elemental_ir_emitter.h"],
192196
deps = [
197+
":ir_emission_utils",
193198
"//tensorflow/compiler/xla:literal_util",
194199
"//tensorflow/compiler/xla:shape_util",
195200
"//tensorflow/compiler/xla:status_macros",
@@ -246,7 +251,8 @@ cc_library(
246251
"thunk_schedule.cc",
247252
"tuple_thunk.cc",
248253
"while_thunk.cc",
249-
],
254+
] + if_cuda_is_configured(if_cuda(["nvptx_executable.cc"])) +
255+
if_rocm_is_configured(if_rocm(["amdgpu_executable.cc"])),
250256
hdrs = [
251257
"conditional_thunk.h",
252258
"convolution_thunk.h",
@@ -264,7 +270,8 @@ cc_library(
264270
"thunk_schedule.h",
265271
"tuple_thunk.h",
266272
"while_thunk.h",
267-
],
273+
] + if_cuda_is_configured(if_cuda(["nvptx_executable.h"])) +
274+
if_rocm_is_configured(if_rocm(["amdgpu_executable.h"])),
268275
deps = [
269276
":buffer_allocations",
270277
":cudnn_convolution_runner",
@@ -296,6 +303,7 @@ cc_library(
296303
"//tensorflow/core/platform/default/build_config:cudnn_plugin",
297304
"//tensorflow/core/platform/default/build_config:cufft_plugin",
298305
"//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep
306+
"//tensorflow/core/platform/default/build_config:stream_executor_rocm",
299307
"//tensorflow/stream_executor",
300308
],
301309
)
@@ -490,8 +498,10 @@ cc_library(
490498

491499
cc_library(
492500
name = "gpu_compiler",
493-
srcs = ["gpu_compiler.cc"],
494-
hdrs = ["gpu_compiler.h"],
501+
srcs = if_cuda_is_configured(if_cuda(["nvptx_compiler.cc"])) +
502+
if_rocm_is_configured(if_rocm(["amdgpu_compiler.cc"])),
503+
hdrs = if_cuda_is_configured(if_cuda(["nvptx_compiler.h"])) +
504+
if_rocm_is_configured(if_rocm(["amdgpu_compiler.h"])),
495505
deps = [
496506
":cudnn_convolution_algorithm_picker",
497507
":cudnn_convolution_rewriter",
@@ -545,6 +555,7 @@ cc_library(
545555
"//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend",
546556
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
547557
"//tensorflow/core:cuda_libdevice_path",
558+
"//tensorflow/core:rocm_rocdl_path",
548559
"//tensorflow/core:lib",
549560
"//tensorflow/core:lib_internal",
550561
"//tensorflow/core:regexp_internal",

0 commit comments

Comments (0)