@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/core/kernels/roll_op.h"
 #include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/op.h"
@@ -26,8 +27,87 @@ limitations under the License.
 
 namespace tensorflow {
 
-#define EIGEN_USE_THREADS
-using CPUDevice = Eigen::ThreadPoolDevice;
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
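+// Implements the Roll op, which circularly shifts the elements of a tensor
+// along one or more axes, np.roll style: rolling [0, 1, 2, 3, 4] by
+// shift = 2 along axis 0 gives [3, 4, 0, 1, 2].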
+template <typename Device, typename T, typename Tshift, typename Taxis>
+class RollOp : public OpKernel {
+ public:
+  explicit RollOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    // Grab the input tensor
+    const Tensor& input = context->input(0);
+    const Tensor& shift = context->input(1);
+    const Tensor& axis = context->input(2);
+
+    auto shift_flat = shift.flat<Tshift>();
+    auto axis_flat = axis.flat<Taxis>();
+
+    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(input.shape()),
+                errors::InvalidArgument("input must be 1-D or higher"));
+    OP_REQUIRES(context, shift.shape().dims() <= 1,
+                errors::InvalidArgument(
+                    "shift must be a scalar or a 1-D vector. Found: ",
+                    shift.shape().DebugString()));
+    OP_REQUIRES(context, axis.shape().dims() <= 1,
+                errors::InvalidArgument(
+                    "axis must be a scalar or a 1-D vector. Found: ",
+                    axis.shape().DebugString()));
+    OP_REQUIRES(
+        context, shift.shape() == axis.shape(),
+        errors::InvalidArgument("shift and axis must have the same size"));
+    const int64 num_elements = input.NumElements();
+    const int num_shifts = static_cast<int>(shift_flat.size());
+    const int num_dims = input.dims();
+
+    // if there are any duplicate axes, shift_mod_sum will have the
+    // total modulo sum of shifts for each dimension
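+    // e.g. axis = [0, 0] with shift = [2, 3] on a dimension of size 4
+    // accumulates to shift_mod_sum[0] = (2 + 3) % 4 = 1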
+    gtl::InlinedVector<int32, 4> shift_mod_sum(num_dims, 0);
+    for (int i = 0; i < num_shifts; i++) {
+      int axis = axis_flat(i);
+      if (axis < 0) {
+        axis += num_dims;
+      }
+      OP_REQUIRES(context, FastBoundsCheck(axis, num_dims),
+                  errors::InvalidArgument("axis ", axis, " is out of range"));
+      const int ds = std::max<int>(static_cast<int>(input.dim_size(axis)), 1);
+      const int sum = shift_mod_sum[axis] + static_cast<int>(shift_flat(i));
+      // modulo that works with negatives: ((x % y) + y) % y
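+      // e.g. sum = -2, ds = 5: plain C++ gives -2 % 5 == -2, but
+      // ((-2 % 5) + 5) % 5 == 3, the intended wrapped shift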
+      shift_mod_sum[axis] = (sum % ds + ds) % ds;
+    }
+    // the size of each dimension
+    gtl::InlinedVector<int32, 4> dim_size(num_dims);
+    // threshold[i] is the index at which the roll starts to wrap back to the
+    // front
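+    // e.g. a dimension of size 5 with total shift 2 has threshold
+    // (5 - 2) % 5 = 3, so input indices 3..4 wrap to the front of the output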
+    gtl::InlinedVector<int32, 4> threshold(num_dims);
+    // dim_range[i] is the number of indices in the flattened tensor you need
+    // to skip to get from one side of dimension i to the other. Used to make
+    // the shifts wrap around after a threshold.
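+    // e.g. a tensor of shape [2, 3, 5] has dim_range = {30, 15, 5}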
+    gtl::InlinedVector<int64, 4> dim_range(num_dims);
+    int64 dim_size_prod = 1;  // dimension size product
+    // inner shift dimension (innermost shifted dimension)
+    int64 isd = 0;
+    for (int i = num_dims - 1; i >= 0; i--) {
+      if (isd == 0 && shift_mod_sum[i] != 0) isd = i;
+      const int ds = std::max<int>(static_cast<int>(input.dim_size(i)), 1);
+      dim_size[i] = ds;
+      threshold[i] = (ds - shift_mod_sum[i]) % ds;
+      dim_size_prod *= static_cast<int64>(input.dim_size(i));
+      dim_range[i] = dim_size_prod;
+    }
+
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, input.shape(), &output));
+    auto input_flat = input.flat<T>().data();
+    auto output_flat = output->flat<T>().data();
+
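+    // Dispatch to the device-specific Roll implementation (the CPU
+    // specialization below, or the GPU functor when built with CUDA).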
+    functor::Roll<Device, T>()(context, num_elements, num_dims, dim_size,
+                               input_flat, output_flat, threshold, dim_range,
+                               isd);
+  }
+};
+
+namespace functor {
 
 // dim_size - the size of each dimension
 // dim_range - the number of indices over in the flattened tensor
@@ -36,9 +116,9 @@ using CPUDevice = Eigen::ThreadPoolDevice;
 // threshold - the index for each dimension that the roll starts to wrap
 //             back to the front
 template <typename T>
-void DoRoll(OpKernelContext* context, const int64 num_elements,
-            const int num_dims, const gtl::ArraySlice<int>& dim_size,
-            const T* input, T* output, const gtl::ArraySlice<int>& threshold,
+void DoRoll(const OpKernelContext* context, const int64 num_elements,
+            const int num_dims, const gtl::ArraySlice<int32>& dim_size,
+            const T* input, T* output, const gtl::ArraySlice<int32>& threshold,
             const gtl::ArraySlice<int64>& dim_range) {
   auto work = [input, output, num_dims, &dim_size, &threshold, &dim_range](
                   int64 start, int64 end) {
@@ -99,10 +179,10 @@ void DoRoll(OpKernelContext* context, const int64 num_elements,
 // isd - inner shift dimension
 template <typename T>
 // Use memcpy to copy memory in groups when the data type supports memcpy
-void DoRollWithMemcpy(OpKernelContext* context, const int64 num_elements,
-                      const int num_dims, const gtl::ArraySlice<int>& dim_size,
+void DoRollWithMemcpy(const OpKernelContext* context, const int64 num_elements,
+                      const int num_dims, const gtl::ArraySlice<int32>& dim_size,
                       const T* input, T* output,
-                      const gtl::ArraySlice<int>& threshold,
+                      const gtl::ArraySlice<int32>& threshold,
                       const gtl::ArraySlice<int64>& dim_range,
                       const int64 isd) {
   auto work = [input, output, num_dims, &dim_size, &threshold, &dim_range, isd](
@@ -220,119 +300,103 @@ void DoRollWithMemcpy(OpKernelContext* context, const int64 num_elements,
                         cost_per_group, std::move(work));
 }
 
-template <typename Device, typename T, typename Tshift, typename Taxis>
-class RollOp : public OpKernel {
- public:
-  explicit RollOp(OpKernelConstruction* context) : OpKernel(context) {}
-
-  void Compute(OpKernelContext* context) override {
-    // Grab the input tensor
-    const Tensor& input = context->input(0);
-    const Tensor& shift = context->input(1);
-    const Tensor& axis = context->input(2);
-
-    auto shift_flat = shift.flat<Tshift>();
-    auto axis_flat = axis.flat<Taxis>();
-
-    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(input.shape()),
-                errors::InvalidArgument("input must be 1-D or higher"));
-    OP_REQUIRES(context, shift.shape().dims() <= 1,
-                errors::InvalidArgument(
-                    "shift must be a scalar or a 1-D vector. Found: ",
-                    shift.shape().DebugString()));
-    OP_REQUIRES(context, axis.shape().dims() <= 1,
-                errors::InvalidArgument(
-                    "axis must be a scalar or a 1-D vector. Found: ",
-                    axis.shape().DebugString()));
-    OP_REQUIRES(
-        context, shift.shape() == axis.shape(),
-        errors::InvalidArgument("shift and axis must have the same size"));
-    const int64 num_elements = input.NumElements();
-    const int num_shifts = static_cast<int>(shift_flat.size());
-    const int num_dims = input.dims();
-
-    // if there are any duplicate axes, shift_mod_sum will have the
-    // total modulo sum of shifts for each dimension
-    gtl::InlinedVector<int, 4> shift_mod_sum(num_dims, 0);
-    for (int i = 0; i < num_shifts; i++) {
-      int axis = axis_flat(i);
-      if (axis < 0) {
-        axis += num_dims;
-      }
-      OP_REQUIRES(context, FastBoundsCheck(axis, num_dims),
-                  errors::InvalidArgument("axis ", axis, " is out of range"));
-      const int ds = std::max<int>(static_cast<int>(input.dim_size(axis)), 1);
-      const int sum = shift_mod_sum[axis] + static_cast<int>(shift_flat(i));
-      // modulo that works with negatives: ((x % y) + y) % y
-      shift_mod_sum[axis] = (sum % ds + ds) % ds;
-    }
-    // the size of each dimension
-    gtl::InlinedVector<int, 4> dim_size(num_dims);
-    // threshold[i] is the index that the roll starts to wrap back to the front
-    gtl::InlinedVector<int, 4> threshold(num_dims);
-    // dim_range is the number of indices over in the flattened tensor
-    // you need to skip in order to make it over from one side of a dimension
-    // to the other. Used to make the shifts wrap around after a threshold.
-    gtl::InlinedVector<int64, 4> dim_range(num_dims);
-    int64 dim_size_prod = 1;  // dimension size product
-    // inner shift dimension (inner most shifted dimension)
-    int64 isd = 0;
-    for (int i = num_dims - 1; i >= 0; i--) {
-      if (isd == 0 && shift_mod_sum[i] != 0) isd = i;
-      const int ds = std::max<int>(static_cast<int>(input.dim_size(i)), 1);
-      dim_size[i] = ds;
-      threshold[i] = (ds - shift_mod_sum[i]) % ds;
-      dim_size_prod *= static_cast<int64>(input.dim_size(i));
-      dim_range[i] = dim_size_prod;
-    }
-
-    Tensor* output = nullptr;
-    OP_REQUIRES_OK(context,
-                   context->allocate_output(0, input.shape(), &output));
-    auto input_flat = input.flat<T>().data();
-    auto output_flat = output->flat<T>().data();
-
-    if (std::is_same<Device, CPUDevice>::value) {
-      if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
-        // V2 copies memory in groups instead of element by element
-        DoRollWithMemcpy<T>(context, num_elements, num_dims, dim_size,
-                            input_flat, output_flat, threshold, dim_range,
-                            isd);
-      } else {
-        // incase memcpy does not work for current data type
-        DoRoll<T>(context, num_elements, num_dims, dim_size, input_flat,
-                  output_flat, threshold, dim_range);
-      }
-    }
-  }
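+// CPU specialization of the Roll functor: uses the grouped-memcpy path when
+// the element type supports memcpy, and the element-by-element path otherwise.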
+template <typename T>
+struct Roll<CPUDevice, T> {
+  void operator()(const OpKernelContext* context, const int64 num_elements,
+                  const int num_dims, const gtl::ArraySlice<int32> dim_size,
+                  const T* input, T* output,
+                  const gtl::ArraySlice<int32> threshold,
+                  const gtl::ArraySlice<int64> dim_range, const int64 isd) {
+    if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
+      // V2 copies memory in groups instead of element by element
+      DoRollWithMemcpy<T>(context, num_elements, num_dims, dim_size, input,
+                          output, threshold, dim_range, isd);
+    } else {
+      // in case memcpy does not work for the current data type
+      DoRoll<T>(context, num_elements, num_dims, dim_size, input, output,
+                threshold, dim_range);
+    }
+  }
 };
+}  // namespace functor
 
 // Register the CPU kernels.
 #define REGISTER_CPU(type) \
   REGISTER_KERNEL_BUILDER(Name("Roll") \
                               .Device(DEVICE_CPU) \
                               .TypeConstraint<type>("T") \
                               .TypeConstraint<int32>("Tshift") \
-                              .TypeConstraint<int32>("Taxis"), \
+                              .TypeConstraint<int32>("Taxis") \
+                              .HostMemory("shift") \
+                              .HostMemory("axis"), \
                           RollOp<CPUDevice, type, int32, int32>) \
   REGISTER_KERNEL_BUILDER(Name("Roll") \
                               .Device(DEVICE_CPU) \
                               .TypeConstraint<type>("T") \
                               .TypeConstraint<int64>("Tshift") \
-                              .TypeConstraint<int32>("Taxis"), \
+                              .TypeConstraint<int32>("Taxis") \
+                              .HostMemory("shift") \
+                              .HostMemory("axis"), \
                           RollOp<CPUDevice, type, int64, int32>) \
   REGISTER_KERNEL_BUILDER(Name("Roll") \
                               .Device(DEVICE_CPU) \
                               .TypeConstraint<type>("T") \
                               .TypeConstraint<int32>("Tshift") \
-                              .TypeConstraint<int64>("Taxis"), \
+                              .TypeConstraint<int64>("Taxis") \
+                              .HostMemory("shift") \
+                              .HostMemory("axis"), \
                           RollOp<CPUDevice, type, int32, int64>) \
   REGISTER_KERNEL_BUILDER(Name("Roll") \
                               .Device(DEVICE_CPU) \
                               .TypeConstraint<type>("T") \
                               .TypeConstraint<int64>("Tshift") \
-                              .TypeConstraint<int64>("Taxis"), \
+                              .TypeConstraint<int64>("Taxis") \
+                              .HostMemory("shift") \
+                              .HostMemory("axis"), \
                           RollOp<CPUDevice, type, int64, int64>)
 
 TF_CALL_ALL_TYPES(REGISTER_CPU);
 #undef REGISTER_CPU
+
+#if GOOGLE_CUDA
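+// Register the GPU kernels. shift and axis are kept in host memory, since
+// Compute reads their values on the host to build the shift thresholds.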
+#define REGISTER_KERNEL(type) \
+  REGISTER_KERNEL_BUILDER(Name("Roll") \
+                              .Device(DEVICE_GPU) \
+                              .TypeConstraint<type>("T") \
+                              .TypeConstraint<int32>("Tshift") \
+                              .TypeConstraint<int32>("Taxis") \
+                              .HostMemory("shift") \
+                              .HostMemory("axis"), \
+                          RollOp<GPUDevice, type, int32, int32>) \
+  REGISTER_KERNEL_BUILDER(Name("Roll") \
+                              .Device(DEVICE_GPU) \
+                              .TypeConstraint<type>("T") \
+                              .TypeConstraint<int64>("Tshift") \
+                              .TypeConstraint<int32>("Taxis") \
+                              .HostMemory("shift") \
+                              .HostMemory("axis"), \
+                          RollOp<GPUDevice, type, int64, int32>) \
+  REGISTER_KERNEL_BUILDER(Name("Roll") \
+                              .Device(DEVICE_GPU) \
+                              .TypeConstraint<type>("T") \
+                              .TypeConstraint<int32>("Tshift") \
+                              .TypeConstraint<int64>("Taxis") \
+                              .HostMemory("shift") \
+                              .HostMemory("axis"), \
+                          RollOp<GPUDevice, type, int32, int64>) \
+  REGISTER_KERNEL_BUILDER(Name("Roll") \
+                              .Device(DEVICE_GPU) \
+                              .TypeConstraint<type>("T") \
+                              .TypeConstraint<int64>("Tshift") \
+                              .TypeConstraint<int64>("Taxis") \
+                              .HostMemory("shift") \
+                              .HostMemory("axis"), \
+                          RollOp<GPUDevice, type, int64, int64>)
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNEL);
+#undef REGISTER_KERNEL
+#endif  // GOOGLE_CUDA
 }  // namespace tensorflow