Skip to content

Commit

Permalink
Revert "Simplify size op impl (#45808)" (#46123)
Browse files (browse the repository at this point in the history)
This reverts commit c252b1d.
  • Loading branch information
chenwhql authored Sep 19, 2022
1 parent b273bb4 commit d963e2e
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 23 deletions.
32 changes: 32 additions & 0 deletions paddle/phi/kernels/cpu/size_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/size_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/size_kernel_impl.h"

// Registers phi::SizeKernel under the kernel name `size` for the CPU backend
// with ALL_LAYOUT, once for each element type listed below. The trailing `{}`
// is an empty registration body (no extra per-kernel configuration is done
// here).
// NOTE(review): this list contains uint8_t, while the GPU registration shown
// further down in this diff does not — confirm whether that difference is
// intentional.
PD_REGISTER_KERNEL(size,
CPU,
ALL_LAYOUT,
phi::SizeKernel,
uint8_t,
int16_t,
int,
int64_t,
phi::dtype::float16,
float,
double,
bool) {}
31 changes: 31 additions & 0 deletions paddle/phi/kernels/gpu/size_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/size_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/size_kernel_impl.h"

// Registers phi::SizeKernel under the kernel name `size` for the GPU backend
// with ALL_LAYOUT, once for each element type listed below. The trailing `{}`
// is an empty registration body.
// NOTE(review): unlike the CPU registration earlier in this diff, uint8_t is
// not in this list — confirm whether that omission is intentional.
PD_REGISTER_KERNEL(size,
GPU,
ALL_LAYOUT,
phi::SizeKernel,
int16_t,
int,
int64_t,
phi::dtype::float16,
float,
double,
bool) {}
Original file line number Diff line number Diff line change
Expand Up @@ -12,33 +12,28 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/size_kernel.h"
#pragma once

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"

namespace phi {

// SizeKernel: stores the element count of `input` (input.numel()) as a single
// int64_t in element 0 of `out`.
//
//   ctx   - execution context; used for allocation, for GetPlace(), and as the
//           first argument to phi::Copy.
//   input - tensor whose numel() is reported.
//   out   - destination tensor; its buffer is allocated here as int64_t.
//
// NOTE(review): this span is a rendered diff hunk whose +/- markers were
// stripped, so lines from both sides are interleaved and it does not compile
// as shown. The one-parameter `template <typename Context>` header and the
// HostAlloc / `out_data[0]` pair right after the signature look like the
// removed (simplified) implementation; the `template <typename T, typename
// Context>` header and the place-dependent branch look like the restored one.
// Confirm against the actual file before relying on either.
template <typename Context>
template <typename T, typename Context>
void SizeKernel(const Context& ctx,
const DenseTensor& input,
DenseTensor* out) {
// Variant 1: allocate `out` through HostAlloc and write the count directly.
auto* out_data = ctx.template HostAlloc<int64_t>(out);
out_data[0] = input.numel();
// Variant 2: allocate `out` through Alloc, then branch on the context's place.
auto place = ctx.GetPlace();
auto out_data = ctx.template Alloc<int64_t>(out);
auto cpu_place = phi::CPUPlace();
if (place == cpu_place) {
// The context's place equals CPUPlace: write the count straight into `out`.
out_data[0] = input.numel();
} else {
// Otherwise write the count into a temporary tensor (same dims as `out`,
// allocated via HostAlloc) and transfer it into `out` with phi::Copy.
DenseTensor cpu_tensor;
cpu_tensor.Resize(out->dims());
auto cpu_data = ctx.template HostAlloc<int64_t>(&cpu_tensor);
cpu_data[0] = input.numel();
// The `false` argument is presumably the "blocking" flag of phi::Copy —
// TODO confirm against tensor_utils.h.
phi::Copy(ctx, cpu_tensor, place, false, out);
}
}

} // namespace phi

// Registers phi::SizeKernel<phi::CPUContext> under the kernel name `size` for
// the CPU backend with ALL_LAYOUT and ALL_DTYPE. The registration body sets
// the data type of output 0 to INT64.
// NOTE(review): in this revert these lines appear to be on the removed side of
// the diff — confirm against the actual file.
PD_REGISTER_GENERAL_KERNEL(
size, CPU, ALL_LAYOUT, phi::SizeKernel<phi::CPUContext>, ALL_DTYPE) {
kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
}

// Only compiled when PADDLE_WITH_CUDA or PADDLE_WITH_HIP is defined.
// Registers phi::SizeKernel<phi::GPUContext> under the kernel name `size` for
// the GPU backend with ALL_LAYOUT and ALL_DTYPE. The registration body marks
// output 0 as having backend CPU and data type INT64.
// NOTE(review): in this revert these lines appear to be on the removed side of
// the diff — confirm against the actual file.
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_GENERAL_KERNEL(
size, GPU, ALL_LAYOUT, phi::SizeKernel<phi::GPUContext>, ALL_DTYPE) {
kernel->OutputAt(0)
.SetBackend(phi::Backend::CPU)
.SetDataType(phi::DataType::INT64);
}
#endif
2 changes: 1 addition & 1 deletion paddle/phi/kernels/size_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

namespace phi {

// Declaration of SizeKernel: takes an execution context, an input DenseTensor
// and a pointer to the output DenseTensor. No definition is given here.
// NOTE(review): two template headers appear back to back because this is a
// rendered diff with the +/- markers stripped; `template <typename Context>`
// looks like the removed line and `template <typename T, typename Context>`
// the added one — confirm against the actual header.
template <typename Context>
template <typename T, typename Context>
void SizeKernel(const Context& ctx, const DenseTensor& input, DenseTensor* out);

} // namespace phi
3 changes: 0 additions & 3 deletions python/paddle/distributed/collective.py
Original file line number Diff line number Diff line change
Expand Up @@ -1117,9 +1117,6 @@ def all_gather_object(object_list, obj, group=None):
), "all_gather_object doesn't support static graph mode."

tensor, len_of_tensor = _convert_object_to_tensor(obj)
if paddle.get_device() != "cpu":
len_of_tensor = len_of_tensor._copy_to(
paddle.framework._current_expected_place(), False)

# gather len_of_tensor from all ranks
list_len_of_tensor = []
Expand Down

0 comments on commit d963e2e

Please sign in to comment.