Skip to content

Commit

Permalink
fix include headers
Browse files Browse the repository at this point in the history
  • Loading branch information
Aurelius84 committed Mar 10, 2022
1 parent 67537d6 commit 0de44a0
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 0 deletions.
65 changes: 65 additions & 0 deletions paddle/fluid/operators/tile_op_functor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <vector>

#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace operators {

// Resolves the tile op's repeat factors, checking the three possible sources
// in priority order:
//   1. the single "RepeatTimes" tensor input (all factors in one tensor),
//   2. the "repeat_times_tensor" tensor list (one scalar tensor per factor),
//   3. the "repeat_times" attribute.
// Tensors living on a device (GPU/XPU/NPU) are staged through host memory
// with a synchronous copy before being read, since device pointers cannot
// be dereferenced on the CPU.
inline std::vector<int> get_repeat_times(
    const framework::ExecutionContext& ctx) {
  // Source 1: a single tensor carrying every repeat factor.
  if (ctx.HasInput("RepeatTimes")) {
    auto* repeat_tensor = ctx.Input<framework::LoDTensor>("RepeatTimes");
    framework::Tensor host_copy;  // must outlive `src` below
    const auto& place = repeat_tensor->place();
    auto* src = repeat_tensor->data<int>();
    if (platform::is_gpu_place(place) || platform::is_xpu_place(place) ||
        platform::is_npu_place(place)) {
      // Blocking device -> host copy; reseat `src` onto the host buffer.
      paddle::framework::TensorCopySync(*repeat_tensor, platform::CPUPlace(),
                                        &host_copy);
      src = host_copy.data<int>();
    }
    return std::vector<int>(src, src + repeat_tensor->numel());
  }

  // Source 2: a list of scalar tensors, one repeat factor each.
  auto repeat_tensor_list =
      ctx.MultiInput<framework::Tensor>("repeat_times_tensor");
  if (!repeat_tensor_list.empty()) {
    std::vector<int> repeat_times;
    for (auto tensor : repeat_tensor_list) {
      const auto& place = tensor->place();
      if (platform::is_gpu_place(place) || platform::is_xpu_place(place) ||
          platform::is_npu_place(place)) {
        framework::Tensor host_scalar;
        paddle::framework::TensorCopySync(*tensor, platform::CPUPlace(),
                                          &host_scalar);
        repeat_times.push_back(*host_scalar.data<int32_t>());
      } else {
        repeat_times.push_back(*tensor->data<int32_t>());
      }
    }
    return repeat_times;
  }

  // Source 3: fall back to the static attribute.
  return ctx.Attr<std::vector<int>>("repeat_times");
}

} // namespace operators
} // namespace paddle
1 change: 1 addition & 0 deletions paddle/fluid/operators/tile_op_npu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/tile_op_functor.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"

namespace paddle {
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/operators/tile_op_xpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ limitations under the License. */
#ifdef PADDLE_WITH_XPU

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/tile_op_functor.h"

namespace paddle {
namespace operators {
Expand Down

0 comments on commit 0de44a0

Please sign in to comment.