Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ limitations under the License.
#include "llvm/IR/DataLayout.h"
#include "tensorflow/compiler/xla/literal_util.h"
// XXX figure out how to cope with both platforms
#if GOOGLE_CUDA
#include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h"
#elif TENSORFLOW_USE_ROCM
#if TENSORFLOW_USE_ROCM
#include "tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h"
#else
#include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h"
#endif
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
Expand All @@ -46,13 +46,11 @@ namespace xla {
GpuTransferManager::GpuTransferManager(se::Platform::Id id)
: GenericTransferManager(
id,
// XXX figure out how to cope with both platforms
#if GOOGLE_CUDA
/*pointer_size=*/llvm::DataLayout(gpu::NVPTXCompiler::kDataLayout)
#elif TENSORFLOW_USE_ROCM
/*pointer_size=*/llvm::DataLayout(gpu::AMDGPUCompiler::kDataLayout)
#if TENSORFLOW_USE_ROCM
llvm::DataLayout(gpu::AMDGPUCompiler::kDataLayout).getPointerSize(0)){}
#else
llvm::DataLayout(gpu::NVPTXCompiler::kDataLayout).getPointerSize(0)){}
#endif
.getPointerSize(0 /* default address space */)) {}

Status GpuTransferManager::TransferLiteralToInfeed(
se::StreamExecutor* executor, const LiteralSlice& literal) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h"

#include <map>
#include <memory>
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/compiler/xla/service/gpu/nvptx_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/gpu/nvptx_executable.h"

#include <set>
#include <utility>
Expand Down Expand Up @@ -45,7 +45,7 @@ NVPTXExecutable::NVPTXExecutable(
std::unique_ptr<const BufferAssignment> assignment,
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
: GpuExecutable(std::move(text), std::move(think_schedule),
: GpuExecutable(std::move(text), std::move(thunk_schedule),
std::move(hlo_module), std::move(assignment),
std::move(hlo_profile_printer_data),
std::move(hlo_profile_index_map)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
auto bias_ptr = AsDeviceMemory(bias.template flat<BiasType>().data(),
bias.template flat<BiasType>().size());

static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit(
static int64 ConvolveScratchSize = GetDnnWorkspaceLimit(
// default value is in bytes despite the name of the environment variable
"TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB
);
Expand Down Expand Up @@ -551,7 +551,7 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
for (auto profile_algorithm : algorithms) {
// TODO(zhengxq): profile each algorithm multiple times to better
// accuracy.
CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
dnn::ProfileResult profile_result;
bool cudnn_launch_status =
stream
Expand Down Expand Up @@ -591,7 +591,7 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
algorithm_config);
}

CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
bool cudnn_launch_status =
stream
->ThenFusedConvolveWithAlgorithm(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h"
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

Expand All @@ -32,11 +32,11 @@ namespace functor {
#define GPUReduceSliceFunctorReduceop(reduceop, beginning) \
template <typename T, typename Index> \
__global__ void ReduceSliceDeviceKernel##reduceop( \
Cuda3DLaunchConfig config, Index indices_width, Index bound, \
Gpu3DLaunchConfig config, Index indices_width, Index bound, \
const T begin, const Index *indices, const T *input, T *out) { \
CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count.x, X) { \
CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count.y, Y) { \
CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count.z, Z) { \
GPU_AXIS_KERNEL_LOOP(x, config.virtual_thread_count.x, X) { \
GPU_AXIS_KERNEL_LOOP(y, config.virtual_thread_count.y, Y) { \
GPU_AXIS_KERNEL_LOOP(z, config.virtual_thread_count.z, Z) { \
Index outidx = x * config.virtual_thread_count.y * \
config.virtual_thread_count.z + \
y * config.virtual_thread_count.z + z; \
Expand Down Expand Up @@ -68,7 +68,7 @@ namespace functor {
if (sizex * sizey * sizez == 0) { \
return; \
} \
Cuda3DLaunchConfig config = GetCuda3DLaunchConfig( \
Gpu3DLaunchConfig config = GetGpu3DLaunchConfig( \
sizex, sizey, sizez, d, ReduceSliceDeviceKernel##reduceop<T, Index>, \
0, 0); \
\
Expand Down
18 changes: 9 additions & 9 deletions tensorflow/contrib/resampler/kernels/resampler_ops_gpu.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
#include <cmath>

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

Expand All @@ -43,7 +43,7 @@ __global__ void Resampler2DKernel(const T* __restrict__ data,
const int data_channels,
const int num_sampling_points) {
const int output_data_size = batch_size * num_sampling_points * data_channels;
CUDA_1D_KERNEL_LOOP(index, output_data_size) {
GPU_1D_KERNEL_LOOP(index, output_data_size) {
const int out_index = index;

// Get (idxSample, channel, point) from the index.
Expand Down Expand Up @@ -117,8 +117,8 @@ struct Resampler2DFunctor<GPUDevice, T> {
const int data_channels, const int num_sampling_points) {
const int output_data_size =
batch_size * num_sampling_points * data_channels;
::tensorflow::CudaLaunchConfig config =
::tensorflow::GetCudaLaunchConfig(output_data_size, d);
::tensorflow::GpuLaunchConfig config =
::tensorflow::GetGpuLaunchConfig(output_data_size, d);
Resampler2DKernel<T>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
data, warp, output, batch_size, data_height, data_width,
Expand Down Expand Up @@ -149,7 +149,7 @@ __global__ void ResamplerGrad2DKernel(
const int num_sampling_points) {
const int resampler_output_size =
batch_size * num_sampling_points * data_channels;
CUDA_1D_KERNEL_LOOP(index, resampler_output_size) {
GPU_1D_KERNEL_LOOP(index, resampler_output_size) {
const int out_index = index;

// Get (idxSample, channel, point) from the index.
Expand Down Expand Up @@ -252,20 +252,20 @@ struct ResamplerGrad2DFunctor<GPUDevice, T> {
const int grad_data_size =
batch_size * data_height * data_width * data_channels;

::tensorflow::CudaLaunchConfig config =
::tensorflow::GetCudaLaunchConfig(grad_warp_size, d);
::tensorflow::GpuLaunchConfig config =
::tensorflow::GetGpuLaunchConfig(grad_warp_size, d);
::tensorflow::
SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
grad_warp_size, grad_warp);

config = ::tensorflow::GetCudaLaunchConfig(grad_data_size, d);
config = ::tensorflow::GetGpuLaunchConfig(grad_data_size, d);
::tensorflow::
SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
grad_data_size, grad_data);

const int resampler_output_size =
batch_size * num_sampling_points * data_channels;
config = ::tensorflow::GetCudaLaunchConfig(resampler_output_size, d);
config = ::tensorflow::GetGpuLaunchConfig(resampler_output_size, d);
ResamplerGrad2DKernel<T>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
data, warp, grad_output, grad_data, grad_warp, batch_size,
Expand Down
6 changes: 3 additions & 3 deletions tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/eigen_activations.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {
namespace functor {
Expand Down Expand Up @@ -186,7 +186,7 @@ void LSTMBlockCellFpropWithCUDA(
typename TTypes<T>::Matrix co, typename TTypes<T>::Matrix icfo,
typename TTypes<T>::Matrix h, int batch_size, int cell_size,
int input_size) {
const cudaStream_t& cu_stream = GetCudaStream(ctx);
const cudaStream_t& cu_stream = GetGpuStream(ctx);

// Concatenate xh = [x, h].
//
Expand Down Expand Up @@ -321,7 +321,7 @@ void LSTMBlockCellBpropWithCUDA(
typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,
typename TTypes<T>::Vec wco_grad, const int batch_size, const int cell_size,
const bool use_peephole) {
const cudaStream_t& cu_stream = GetCudaStream(ctx);
const cudaStream_t& cu_stream = GetGpuStream(ctx);

dim3 block_dim_2d(std::min(batch_size, 8), 32);
dim3 grid_dim_2d(Eigen::divup(batch_size, static_cast<int>(block_dim_2d.x)),
Expand Down
6 changes: 3 additions & 3 deletions tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ limitations under the License.
#define EIGEN_USE_GPU

#include "tensorflow/contrib/seq2seq/kernels/beam_search_ops.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {
namespace functor {
Expand All @@ -31,7 +31,7 @@ __global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time,
const T* parent_ids,
const int32* max_sequence_lengths,
const T end_token, T* beams) {
CUDA_1D_KERNEL_LOOP(i, batch_size * beam_width) {
GPU_1D_KERNEL_LOOP(i, batch_size * beam_width) {
const int32 batch = i / beam_width;
const int32 beam = i % beam_width;

Expand Down Expand Up @@ -90,7 +90,7 @@ struct GatherTree<GPUDevice, T> {
// First kernel launch to "zero" things out
beams.device(d) = beams.constant(end_token);

CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * beam_width, d);
GpuLaunchConfig config = GetGpuLaunchConfig(batch_size * beam_width, d);
// clang-format off
GatherTreeOpKernel<T>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/adjust_hue_op_gpu.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ void AdjustHueGPU::operator()(GPUDevice* device, const int64 number_of_elements,
const int threads_per_block = config.thread_per_block;
const int block_count =
(number_of_elements + threads_per_block - 1) / threads_per_block;
GPU_LAUNCH_KERNEL(internal::adjust_hsv_nhwc<true, false, false>,
GPU_LAUNCH_KERNEL((internal::adjust_hsv_nhwc<true, false, false>),
dim3(block_count), dim3(threads_per_block), 0, stream,
number_of_elements, input, output, delta, nullptr, nullptr);
}
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/adjust_saturation_op_gpu.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void AdjustSaturationGPU::operator()(GPUDevice* device,
const int threads_per_block = config.thread_per_block;
const int block_count =
(number_of_elements + threads_per_block - 1) / threads_per_block;
GPU_LAUNCH_KERNEL(internal::adjust_hsv_nhwc<false, true, false>,
GPU_LAUNCH_KERNEL((internal::adjust_hsv_nhwc<false, true, false>),
dim3(block_count), dim3(threads_per_block), 0, stream,
number_of_elements, input, output, nullptr, scale, nullptr);
}
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/bucketize_op_gpu.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,13 @@ struct BucketizeFunctor<GPUDevice, T> {
const int32 kMaxSharedMemBytes = 16384;
if (shared_mem_size < d.sharedMemPerBlock() &&
shared_mem_size < kMaxSharedMemBytes) {
GPU_LAUNCH_KERNEL(BucketizeCustomKernel<T, true>,
GPU_LAUNCH_KERNEL((BucketizeCustomKernel<T, true>),
dim3(config.block_count), dim3(config.thread_per_block),
shared_mem_size, d.stream(),
input.size(), input.data(), boundaries_vector.size(),
boundaries_array.data(), output.data());
} else {
GPU_LAUNCH_KERNEL(BucketizeCustomKernel<T, false>,
GPU_LAUNCH_KERNEL((BucketizeCustomKernel<T, false>),
dim3(config.block_count), dim3(config.thread_per_block), 0,
d.stream(),
input.size(), input.data(), boundaries_vector.size(),
Expand Down
6 changes: 3 additions & 3 deletions tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ void ConcatGPUImpl(const Eigen::GpuDevice& gpu_device,
output->dimension(0), gpu_device);

if (fixed_size) {
GPU_LAUNCH_KERNEL(concat_fixed_kernel<T, IntType>,
GPU_LAUNCH_KERNEL((concat_fixed_kernel<T, IntType>),
dim3(config.block_count), dim3(config.thread_per_block), 0,
gpu_device.stream(),
input_ptrs,
Expand All @@ -164,13 +164,13 @@ void ConcatGPUImpl(const Eigen::GpuDevice& gpu_device,
// 4096 inputs is a lot, most code will take the smem path
const int32 kMaxSmemBytesPerformance = 16384;
if (smem_usage < smem_max && smem_usage < kMaxSmemBytesPerformance) {
GPU_LAUNCH_KERNEL(concat_variable_kernel<T, IntType, true>,
GPU_LAUNCH_KERNEL((concat_variable_kernel<T, IntType, true>),
dim3(config.block_count), dim3(config.thread_per_block), smem_usage,
gpu_device.stream(),
input_ptrs, output_scan, output->dimension(0), output->dimension(1),
output->data());
} else {
GPU_LAUNCH_KERNEL(concat_variable_kernel<T, IntType, false>,
GPU_LAUNCH_KERNEL((concat_variable_kernel<T, IntType, false>),
dim3(config.block_count), dim3(config.thread_per_block), 0,
gpu_device.stream(),
input_ptrs, output_scan, output->dimension(0), output->dimension(1),
Expand Down
20 changes: 10 additions & 10 deletions tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -506,12 +506,12 @@ struct PadInput<GPUDevice, T, int, NDIMS> {
const Dimension<NDIMS - 2> padding_left_dim(padding_left);

if (format == FORMAT_NHWC) {
GPU_LAUNCH_KERNEL(PadInputCustomKernelNHWC<T, NDIMS>,
GPU_LAUNCH_KERNEL((PadInputCustomKernelNHWC<T, NDIMS>),
dim3(config.block_count), dim3(config.thread_per_block), 0, d.stream(),
config.virtual_thread_count, in.data(), input_dims, out.data(),
output_dims, padding_left_dim);
} else if (format == FORMAT_NCHW) {
GPU_LAUNCH_KERNEL(PadInputCustomKernelNCHW<T, NDIMS>,
GPU_LAUNCH_KERNEL((PadInputCustomKernelNCHW<T, NDIMS>),
dim3(config.block_count), dim3(config.thread_per_block), 0, d.stream(),
config.virtual_thread_count, in.data(), input_dims, out.data(),
output_dims, padding_left_dim);
Expand Down Expand Up @@ -623,13 +623,13 @@ void LaunchBatchNarrowMatrixTransposeKernel(
const T* input, const Dimension<3>& input_dims, T* output) {
constexpr int NumThreads = TileLongSide;
if (tile_size_i <= TileLongSide && tile_size_j <= TileShortSide) {
GPU_LAUNCH_KERNEL(SwapDimension1And2InTensor3UsingTiles<T, NumThreads,
TileLongSide, TileShortSide>,
GPU_LAUNCH_KERNEL((SwapDimension1And2InTensor3UsingTiles<T, NumThreads,
TileLongSide, TileShortSide>),
dim3(total_tiles_count), dim3(NumThreads), 0, d.stream(),
input, input_dims, output);
} else {
GPU_LAUNCH_KERNEL(SwapDimension1And2InTensor3UsingTiles<T, NumThreads,
TileShortSide, TileLongSide>,
GPU_LAUNCH_KERNEL((SwapDimension1And2InTensor3UsingTiles<T, NumThreads,
TileShortSide, TileLongSide>),
dim3(total_tiles_count), dim3(NumThreads), 0, d.stream(),
input, input_dims, output);
}
Expand Down Expand Up @@ -932,8 +932,8 @@ void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input,

int total_tiles_count = input_dims_in_tiles[0] * input_dims_in_tiles[1] *
input_dims_in_tiles[2];
GPU_LAUNCH_KERNEL(SwapDimension1And2InTensor3UsingTiles<T, kNumThreads,
kTileSize, kTileSize, conjugate>,
GPU_LAUNCH_KERNEL((SwapDimension1And2InTensor3UsingTiles<T, kNumThreads,
kTileSize, kTileSize, conjugate>),
dim3(total_tiles_count), dim3(kNumThreads), 0, d.stream(),
input, input_dims, output);

Expand All @@ -943,7 +943,7 @@ void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input,
} else {
int total_element_count = input_dims[0] * input_dims[1] * input_dims[2];
GpuLaunchConfig config = GetGpuLaunchConfig(total_element_count, d);
GPU_LAUNCH_KERNEL(SwapDimension1And2InTensor3Simple<T, conjugate>,
GPU_LAUNCH_KERNEL((SwapDimension1And2InTensor3Simple<T, conjugate>),
dim3(config.block_count), dim3(config.thread_per_block), 0, d.stream(),
config.virtual_thread_count, input, input_dims, output);
}
Expand Down Expand Up @@ -975,7 +975,7 @@ struct SwapDimension0And2InTensor3<GPUDevice, T, conjugate> {
static_cast<int>(combined_dims[2])};
size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2];
GpuLaunchConfig config = GetGpuLaunchConfig(total_size, d);
GPU_LAUNCH_KERNEL(SwapDimension0And2InTensor3Simple<T, conjugate>,
GPU_LAUNCH_KERNEL((SwapDimension0And2InTensor3Simple<T, conjugate>),
dim3(config.block_count), dim3(config.thread_per_block), 0, d.stream(),
config.virtual_thread_count, in, input_dims, out);
}
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ struct CropAndResizeBackpropImage<GPUDevice, T> {
total_count = batch * image_height * image_width * depth;
if (total_count > 0) {
config = GetGpuLaunchConfig(total_count, d);
GPU_LAUNCH_KERNEL(SetZero<T>,
GPU_LAUNCH_KERNEL(SetZero,
dim3(config.block_count), dim3(config.thread_per_block), 0,
d.stream(),
config.virtual_thread_count, grads_image.data());
Expand Down Expand Up @@ -458,7 +458,7 @@ struct CropAndResizeBackpropBoxes<GPUDevice, T> {
total_count = num_boxes * 4;
if (total_count > 0) {
config = GetGpuLaunchConfig(total_count, d);
GPU_LAUNCH_KERNEL(SetZero<T>,
GPU_LAUNCH_KERNEL(SetZero,
dim3(config.block_count), dim3(config.thread_per_block), 0,
d.stream(),
config.virtual_thread_count, grads_boxes.data());
Expand Down
Loading