Skip to content

Commit

Permalink
optimized expand_as_kernel (#57509)
Browse files Browse the repository at this point in the history
  • Loading branch information
tianhaodongbd authored Sep 22, 2023
1 parent f10dede commit acad271
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions paddle/phi/kernels/gpu/expand_as_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/expand_kernel.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"

namespace phi {
Expand Down Expand Up @@ -70,11 +71,7 @@ void ExpandAsKernel(const Context& ctx,
}
}

out->Resize(phi::make_ddim(target_shape));
ctx.template Alloc<T>(out);
std::vector<const DenseTensor*> ins = {&x};
std::vector<DenseTensor*> outs = {out};
phi::funcs::BroadcastKernel<T>(ctx, ins, &outs, kps::IdentityFunctor<T>());
ExpandKernel<T, Context>(ctx, x, target_shape, out);
}

} // namespace phi
Expand Down

0 comments on commit acad271

Please sign in to comment.