Skip to content

Commit 9790d34

Browse files
add cuda implmentation of roll
change roll_op.cc to the original implemantation
1 parent d1db2e6 commit 9790d34

File tree

4 files changed

+319
-97
lines changed

4 files changed

+319
-97
lines changed

tensorflow/core/kernels/BUILD

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3234,7 +3234,12 @@ cc_library(
32343234

32353235
tf_kernel_library(
32363236
name = "roll_op",
3237-
prefix = "roll_op",
3237+
srcs = ["roll_op.cc"],
3238+
hdrs = ["roll_op.h"],
3239+
gpu_srcs = [
3240+
"roll_op_gpu.cu.cc",
3241+
"roll_op.h",
3242+
],
32383243
deps = [
32393244
":bounds_check",
32403245
"//tensorflow/core:framework",

tensorflow/core/kernels/roll_op.cc

Lines changed: 160 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16+
#include "tensorflow/core/kernels/roll_op.h"
1617
#include "tensorflow/core/framework/bounds_check.h"
1718
#include "tensorflow/core/framework/common_shape_fns.h"
1819
#include "tensorflow/core/framework/op.h"
@@ -26,8 +27,87 @@ limitations under the License.
2627

2728
namespace tensorflow {
2829

29-
#define EIGEN_USE_THREADS
30-
using CPUDevice = Eigen::ThreadPoolDevice;
30+
typedef Eigen::ThreadPoolDevice CPUDevice;
31+
typedef Eigen::GpuDevice GPUDevice;
32+
33+
template <typename Device, typename T, typename Tshift, typename Taxis>
34+
class RollOp : public OpKernel {
35+
public:
36+
explicit RollOp(OpKernelConstruction* context) : OpKernel(context) {}
37+
38+
void Compute(OpKernelContext* context) override {
39+
// Grab the input tensor
40+
const Tensor& input = context->input(0);
41+
const Tensor& shift = context->input(1);
42+
const Tensor& axis = context->input(2);
43+
44+
auto shift_flat = shift.flat<Tshift>();
45+
auto axis_flat = axis.flat<Taxis>();
46+
47+
OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(input.shape()),
48+
errors::InvalidArgument("input must be 1-D or higher"));
49+
OP_REQUIRES(context, shift.shape().dims() <= 1,
50+
errors::InvalidArgument(
51+
"shift must be a scalar or a 1-D vector. Found: ",
52+
shift.shape().DebugString()));
53+
OP_REQUIRES(context, axis.shape().dims() <= 1,
54+
errors::InvalidArgument(
55+
"axis must be a scalar or a 1-D vector. Found: ",
56+
axis.shape().DebugString()));
57+
OP_REQUIRES(
58+
context, shift.shape() == axis.shape(),
59+
errors::InvalidArgument("shift and axis must have the same size"));
60+
const int64 num_elements = input.NumElements();
61+
const int num_shifts = static_cast<int>(shift_flat.size());
62+
const int num_dims = input.dims();
63+
64+
// if there are any duplicate axes, shift_mod_sum will have the
65+
// total modulo sum of shifts for each dimension
66+
gtl::InlinedVector<int32, 4> shift_mod_sum(num_dims, 0);
67+
for (int i = 0; i < num_shifts; i++) {
68+
int axis = axis_flat(i);
69+
if (axis < 0) {
70+
axis += num_dims;
71+
}
72+
OP_REQUIRES(context, FastBoundsCheck(axis, num_dims),
73+
errors::InvalidArgument("axis ", axis, " is out of range"));
74+
const int ds = std::max<int>(static_cast<int>(input.dim_size(axis)), 1);
75+
const int sum = shift_mod_sum[axis] + static_cast<int>(shift_flat(i));
76+
// modulo that works with negatives: ((x % y) + y) % y
77+
shift_mod_sum[axis] = (sum % ds + ds) % ds;
78+
}
79+
// the size of each dimension
80+
gtl::InlinedVector<int32, 4> dim_size(num_dims);
81+
// threshold[i] is the index that the roll starts to wrap back to the front
82+
gtl::InlinedVector<int32, 4> threshold(num_dims);
83+
// dim_range is the number of indices over in the flattened tensor
84+
// you need to skip in order to make it over from one side of a dimension
85+
// to the other. Used to make the shifts wrap around after a threshold.
86+
gtl::InlinedVector<int64, 4> dim_range(num_dims);
87+
int64 dim_size_prod = 1; // dimension size product
88+
// inner shift dimension (inner most shifted dimension)
89+
int64 isd = 0;
90+
for (int i = num_dims - 1; i >= 0; i--) {
91+
if (isd == 0 && shift_mod_sum[i] != 0) isd = i;
92+
const int ds = std::max<int>(static_cast<int>(input.dim_size(i)), 1);
93+
dim_size[i] = ds;
94+
threshold[i] = (ds - shift_mod_sum[i]) % ds;
95+
dim_size_prod *= static_cast<int64>(input.dim_size(i));
96+
dim_range[i] = dim_size_prod;
97+
}
98+
99+
Tensor* output = nullptr;
100+
OP_REQUIRES_OK(context,
101+
context->allocate_output(0, input.shape(), &output));
102+
auto input_flat = input.flat<T>().data();
103+
auto output_flat = output->flat<T>().data();
104+
105+
functor::Roll<Device, T>()(context, num_elements, num_dims, dim_size,
106+
input_flat, output_flat, threshold, dim_range, isd);
107+
}
108+
};
109+
110+
namespace functor {
31111

32112
// dim_size - the size of each dimension
33113
// dim_range - the number of indices over in the flattened tensor
@@ -36,9 +116,9 @@ using CPUDevice = Eigen::ThreadPoolDevice;
36116
// threshold - the index for each dimension that the roll starts to wrap
37117
// back to the front
38118
template <typename T>
39-
void DoRoll(OpKernelContext* context, const int64 num_elements,
40-
const int num_dims, const gtl::ArraySlice<int>& dim_size,
41-
const T* input, T* output, const gtl::ArraySlice<int>& threshold,
119+
void DoRoll(const OpKernelContext* context, const int64 num_elements,
120+
const int num_dims, const gtl::ArraySlice<int32>& dim_size,
121+
const T* input, T* output, const gtl::ArraySlice<int32>& threshold,
42122
const gtl::ArraySlice<int64>& dim_range) {
43123
auto work = [input, output, num_dims, &dim_size, &threshold, &dim_range](
44124
int64 start, int64 end) {
@@ -99,10 +179,10 @@ void DoRoll(OpKernelContext* context, const int64 num_elements,
99179
// isd - inner shift dimension
100180
template <typename T>
101181
// Use memcpy to copy memory in groups when the data type supports memcpy
102-
void DoRollWithMemcpy(OpKernelContext* context, const int64 num_elements,
103-
const int num_dims, const gtl::ArraySlice<int>& dim_size,
182+
void DoRollWithMemcpy(const OpKernelContext* context, const int64 num_elements,
183+
const int num_dims, const gtl::ArraySlice<int32>& dim_size,
104184
const T* input, T* output,
105-
const gtl::ArraySlice<int>& threshold,
185+
const gtl::ArraySlice<int32>& threshold,
106186
const gtl::ArraySlice<int64>& dim_range,
107187
const int64 isd) {
108188
auto work = [input, output, num_dims, &dim_size, &threshold, &dim_range, isd](
@@ -220,119 +300,103 @@ void DoRollWithMemcpy(OpKernelContext* context, const int64 num_elements,
220300
cost_per_group, std::move(work));
221301
}
222302

223-
template <typename Device, typename T, typename Tshift, typename Taxis>
224-
class RollOp : public OpKernel {
225-
public:
226-
explicit RollOp(OpKernelConstruction* context) : OpKernel(context) {}
227-
228-
void Compute(OpKernelContext* context) override {
229-
// Grab the input tensor
230-
const Tensor& input = context->input(0);
231-
const Tensor& shift = context->input(1);
232-
const Tensor& axis = context->input(2);
233-
234-
auto shift_flat = shift.flat<Tshift>();
235-
auto axis_flat = axis.flat<Taxis>();
236-
237-
OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(input.shape()),
238-
errors::InvalidArgument("input must be 1-D or higher"));
239-
OP_REQUIRES(context, shift.shape().dims() <= 1,
240-
errors::InvalidArgument(
241-
"shift must be a scalar or a 1-D vector. Found: ",
242-
shift.shape().DebugString()));
243-
OP_REQUIRES(context, axis.shape().dims() <= 1,
244-
errors::InvalidArgument(
245-
"axis must be a scalar or a 1-D vector. Found: ",
246-
axis.shape().DebugString()));
247-
OP_REQUIRES(
248-
context, shift.shape() == axis.shape(),
249-
errors::InvalidArgument("shift and axis must have the same size"));
250-
const int64 num_elements = input.NumElements();
251-
const int num_shifts = static_cast<int>(shift_flat.size());
252-
const int num_dims = input.dims();
253-
254-
// if there are any duplicate axes, shift_mod_sum will have the
255-
// total modulo sum of shifts for each dimension
256-
gtl::InlinedVector<int, 4> shift_mod_sum(num_dims, 0);
257-
for (int i = 0; i < num_shifts; i++) {
258-
int axis = axis_flat(i);
259-
if (axis < 0) {
260-
axis += num_dims;
261-
}
262-
OP_REQUIRES(context, FastBoundsCheck(axis, num_dims),
263-
errors::InvalidArgument("axis ", axis, " is out of range"));
264-
const int ds = std::max<int>(static_cast<int>(input.dim_size(axis)), 1);
265-
const int sum = shift_mod_sum[axis] + static_cast<int>(shift_flat(i));
266-
// modulo that works with negatives: ((x % y) + y) % y
267-
shift_mod_sum[axis] = (sum % ds + ds) % ds;
268-
}
269-
// the size of each dimension
270-
gtl::InlinedVector<int, 4> dim_size(num_dims);
271-
// threshold[i] is the index that the roll starts to wrap back to the front
272-
gtl::InlinedVector<int, 4> threshold(num_dims);
273-
// dim_range is the number of indices over in the flattened tensor
274-
// you need to skip in order to make it over from one side of a dimension
275-
// to the other. Used to make the shifts wrap around after a threshold.
276-
gtl::InlinedVector<int64, 4> dim_range(num_dims);
277-
int64 dim_size_prod = 1; // dimension size product
278-
// inner shift dimension (inner most shifted dimension)
279-
int64 isd = 0;
280-
for (int i = num_dims - 1; i >= 0; i--) {
281-
if (isd == 0 && shift_mod_sum[i] != 0) isd = i;
282-
const int ds = std::max<int>(static_cast<int>(input.dim_size(i)), 1);
283-
dim_size[i] = ds;
284-
threshold[i] = (ds - shift_mod_sum[i]) % ds;
285-
dim_size_prod *= static_cast<int64>(input.dim_size(i));
286-
dim_range[i] = dim_size_prod;
287-
}
288-
289-
Tensor* output = nullptr;
290-
OP_REQUIRES_OK(context,
291-
context->allocate_output(0, input.shape(), &output));
292-
auto input_flat = input.flat<T>().data();
293-
auto output_flat = output->flat<T>().data();
294-
295-
if (std::is_same<Device, CPUDevice>::value) {
296-
if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
297-
// V2 copies memory in groups instead of element by element
298-
DoRollWithMemcpy<T>(context, num_elements, num_dims, dim_size,
299-
input_flat, output_flat, threshold, dim_range, isd);
300-
} else {
301-
// incase memcpy does not work for current data type
302-
DoRoll<T>(context, num_elements, num_dims, dim_size, input_flat,
303-
output_flat, threshold, dim_range);
304-
}
305-
}
306-
}
303+
template<typename T>
304+
struct Roll<CPUDevice, T> {
305+
void operator()(const OpKernelContext *context,
306+
const int64 num_elements,
307+
const int num_dims,
308+
const gtl::ArraySlice<int32> dim_size,
309+
const T *input, T *output,
310+
const gtl::ArraySlice<int32> threshold,
311+
const gtl::ArraySlice<int64> dim_range,
312+
const int64 isd) {
313+
if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
314+
// V2 copies memory in groups instead of element by element
315+
DoRollWithMemcpy<T>(context, num_elements, num_dims, dim_size,
316+
input, output, threshold, dim_range, isd);
317+
} else {
318+
// incase memcpy does not work for current data type
319+
DoRoll<T>(context, num_elements, num_dims, dim_size, input,
320+
output, threshold, dim_range);
321+
}
322+
};
307323
};
324+
}
308325

309326
// Register the CPU kernels.
310327
#define REGISTER_CPU(type) \
311328
REGISTER_KERNEL_BUILDER(Name("Roll") \
312329
.Device(DEVICE_CPU) \
313330
.TypeConstraint<type>("T") \
314331
.TypeConstraint<int32>("Tshift") \
315-
.TypeConstraint<int32>("Taxis"), \
332+
.TypeConstraint<int32>("Taxis") \
333+
.HostMemory("shift") \
334+
.HostMemory("axis"), \
316335
RollOp<CPUDevice, type, int32, int32>) \
317336
REGISTER_KERNEL_BUILDER(Name("Roll") \
318337
.Device(DEVICE_CPU) \
319338
.TypeConstraint<type>("T") \
320339
.TypeConstraint<int64>("Tshift") \
321-
.TypeConstraint<int32>("Taxis"), \
340+
.TypeConstraint<int32>("Taxis") \
341+
.HostMemory("shift") \
342+
.HostMemory("axis"), \
322343
RollOp<CPUDevice, type, int64, int32>) \
323344
REGISTER_KERNEL_BUILDER(Name("Roll") \
324345
.Device(DEVICE_CPU) \
325346
.TypeConstraint<type>("T") \
326347
.TypeConstraint<int32>("Tshift") \
327-
.TypeConstraint<int64>("Taxis"), \
348+
.TypeConstraint<int64>("Taxis") \
349+
.HostMemory("shift") \
350+
.HostMemory("axis"), \
328351
RollOp<CPUDevice, type, int32, int64>) \
329352
REGISTER_KERNEL_BUILDER(Name("Roll") \
330353
.Device(DEVICE_CPU) \
331354
.TypeConstraint<type>("T") \
332355
.TypeConstraint<int64>("Tshift") \
333-
.TypeConstraint<int64>("Taxis"), \
356+
.TypeConstraint<int64>("Taxis") \
357+
.HostMemory("shift") \
358+
.HostMemory("axis"), \
334359
RollOp<CPUDevice, type, int64, int64>)
335360

336361
TF_CALL_ALL_TYPES(REGISTER_CPU);
337362
#undef REGISTER_CPU
363+
364+
#if GOOGLE_CUDA
365+
#define REGISTER_KERNEL(type) \
366+
REGISTER_KERNEL_BUILDER(Name("Roll") \
367+
.Device(DEVICE_GPU) \
368+
.TypeConstraint<type>("T") \
369+
.TypeConstraint<int32>("Tshift") \
370+
.TypeConstraint<int32>("Taxis") \
371+
.HostMemory("shift") \
372+
.HostMemory("axis"), \
373+
RollOp<GPUDevice, type, int32, int32>) \
374+
REGISTER_KERNEL_BUILDER(Name("Roll") \
375+
.Device(DEVICE_GPU) \
376+
.TypeConstraint<type>("T") \
377+
.TypeConstraint<int64>("Tshift") \
378+
.TypeConstraint<int32>("Taxis") \
379+
.HostMemory("shift") \
380+
.HostMemory("axis"), \
381+
RollOp<GPUDevice, type, int64, int32>) \
382+
REGISTER_KERNEL_BUILDER(Name("Roll") \
383+
.Device(DEVICE_GPU) \
384+
.TypeConstraint<type>("T") \
385+
.TypeConstraint<int32>("Tshift") \
386+
.TypeConstraint<int64>("Taxis") \
387+
.HostMemory("shift") \
388+
.HostMemory("axis"), \
389+
RollOp<GPUDevice, type, int32, int64>) \
390+
REGISTER_KERNEL_BUILDER(Name("Roll") \
391+
.Device(DEVICE_GPU) \
392+
.TypeConstraint<type>("T") \
393+
.TypeConstraint<int64>("Tshift") \
394+
.TypeConstraint<int64>("Taxis") \
395+
.HostMemory("shift") \
396+
.HostMemory("axis"), \
397+
RollOp<GPUDevice, type, int64, int64>)
398+
399+
TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNEL);
400+
#undef REGISTER_KERNEL
401+
#endif // GOOGLE_CUDA
338402
} // namespace tensorflow

tensorflow/core/kernels/roll_op.h

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef TENSORFLOW_ROLL_H
17+
#define TENSORFLOW_ROLL_H
18+
19+
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
20+
#include "tensorflow/core/framework/numeric_types.h"
21+
#include "tensorflow/core/framework/op_kernel.h"
22+
#include "tensorflow/core/framework/tensor_types.h"
23+
24+
namespace tensorflow {
25+
namespace functor {
26+
27+
template <typename Device, typename T>
28+
struct Roll {
29+
// dim_size - the size of each dimension
30+
// dim_range - the number of indices over in the flattened tensor
31+
// you need to skip in order to make it over from one side of a dimension
32+
// to the other. Used to make the shifts wrap around after a threshold.
33+
// threshold - the index for each dimension that the roll starts to wrap
34+
// back to the front
35+
// isd - inner shift dimension
36+
void operator()(const OpKernelContext* context,
37+
const int64 num_elements,
38+
const int num_dims,
39+
const gtl::ArraySlice<int32> dim_size,
40+
const T* input, T* output,
41+
const gtl::ArraySlice<int32> threshold,
42+
const gtl::ArraySlice<int64> dim_range,
43+
const int64 isd);
44+
};
45+
46+
}
47+
}
48+
49+
50+
#endif //TENSORFLOW_ROLLL_H

0 commit comments

Comments
 (0)