21 changes: 21 additions & 0 deletions paddle/common/ddim.cc
@@ -255,6 +255,27 @@ DDim DDim::reshape(std::vector<int>& shape) const {
   return common::make_ddim(shape);
 }
 
+DDim DDim::reshape(std::vector<int64_t>& shape) const {
+  const DDim& in_dims = *this;
+
+  for (int i = 0; i < static_cast<int>(shape.size()); ++i) {
+    if (shape[i] == 0) {
+      shape[i] = static_cast<int64_t>(in_dims.at(i));
+    }
+  }
+
+  // Dim marked as "-1" must be inferred
+  auto it = std::find(shape.begin(), shape.end(), -1);
+  if (it != shape.end()) {
+    int index = static_cast<int>(std::distance(shape.begin(), it));
+    int64_t reshape_out_product =
+        std::accumulate(shape.begin(), shape.end(), -1, std::multiplies<>());
+    shape[index] = static_cast<int64_t>(product(in_dims)) / reshape_out_product;
+  }
+
+  return common::make_ddim(shape);
+}
+
 DDim DDim::transpose(const std::vector<int>& axis) const {
   const DDim& in_dims = *this;
 
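Note: the new overload mirrors the existing int version: a 0 entry keeps the input extent at that position, and a single -1 entry is inferred from the product of the remaining entries. One caveat worth flagging: std::accumulate deduces its accumulator type from the initial value, so the plain int -1 seed above narrows each intermediate product to int; for the very shapes this PR targets, an int64_t seed is the safe spelling. A minimal standalone sketch of the inference rules (hypothetical infer_shape helper, not part of this PR):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical standalone analog of DDim::reshape(std::vector<int64_t>&):
// a 0 entry copies the input extent at the same index, -1 is inferred.
std::vector<int64_t> infer_shape(const std::vector<int64_t>& in_dims,
                                 std::vector<int64_t> shape) {
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == 0) shape[i] = in_dims.at(i);
  }
  auto it = std::find(shape.begin(), shape.end(), int64_t{-1});
  if (it != shape.end()) {
    // Seeding with int64_t{-1} keeps the accumulator 64-bit and cancels
    // the -1 entry, leaving the product of the known extents.
    int64_t known = std::accumulate(
        shape.begin(), shape.end(), int64_t{-1}, std::multiplies<>());
    int64_t total = std::accumulate(
        in_dims.begin(), in_dims.end(), int64_t{1}, std::multiplies<>());
    *it = total / known;
  }
  return shape;
}

int main() {
  // {6, 4} reshaped with {0, -1, 2}: 0 keeps 6, -1 becomes 24 / (6 * 2) = 2.
  assert((infer_shape({6, 4}, {0, -1, 2}) == std::vector<int64_t>{6, 2, 2}));
}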
2 changes: 2 additions & 0 deletions paddle/common/ddim.h
@@ -128,6 +128,8 @@ class TEST_API DDim {
 
   DDim reshape(std::vector<int>& shape) const;  // NOLINT
 
+  DDim reshape(std::vector<int64_t>& shape) const;  // NOLINT
+
   DDim transpose(const std::vector<int>& axis) const;
 
  private:
21 changes: 12 additions & 9 deletions paddle/phi/kernels/funcs/matrix_inverse.h
@@ -27,8 +27,11 @@ namespace funcs {
 
 template <typename Context, typename T>
 struct MapMatrixInverseFunctor {
-  void operator()(
-      const Context& dev_ctx, const T* a_ptr, T* a_inv_ptr, int offset, int n) {
+  void operator()(const Context& dev_ctx,
+                  const T* a_ptr,
+                  T* a_inv_ptr,
+                  int64_t offset,
+                  int64_t n) {
     using Matrix =
         Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
     using EigenMatrixMap = Eigen::Map<Matrix>;
@@ -52,8 +55,8 @@ struct MapMatrixInverseFunctor<Context, phi::dtype::complex<T>> {
   void operator()(const Context& dev_ctx,
                   const phi::dtype::complex<T>* a_ptr,
                   phi::dtype::complex<T>* a_inv_ptr,
-                  int offset,
-                  int n) {
+                  int64_t offset,
+                  int64_t n) {
     using Matrix = Eigen::Matrix<std::complex<T>,
                                  Eigen::Dynamic,
                                  Eigen::Dynamic,
@@ -62,7 +65,7 @@ struct MapMatrixInverseFunctor<Context, phi::dtype::complex<T>> {
     using ConstEigenMatrixMap = Eigen::Map<const Matrix>;
     std::complex<T>* std_ptr = new std::complex<T>[n * n];
     std::complex<T>* std_inv_ptr = new std::complex<T>[n * n];
-    for (int i = 0; i < n * n; i++) {
+    for (int64_t i = 0; i < n * n; i++) {
       *(std_ptr + i) = static_cast<std::complex<T>>(*(a_ptr + offset + i));
     }
     ConstEigenMatrixMap mat(std_ptr, n, n);
@@ -75,7 +78,7 @@ struct MapMatrixInverseFunctor<Context, phi::dtype::complex<T>> {
                       static_cast<std::complex<T>>(0),
                       errors::InvalidArgument("Input is not invertible."));
     mat_inv.noalias() = lu.inverse();
-    for (int i = 0; i < n * n; i++) {
+    for (int64_t i = 0; i < n * n; i++) {
       *(a_inv_ptr + offset + i) =
           static_cast<phi::dtype::complex<T>>(*(std_inv_ptr + i));
     }
@@ -90,8 +93,8 @@ void ComputeInverseEigen(const Context& dev_ctx,
                          DenseTensor* a_inv) {
   const auto& mat_dims = a.dims();
   const int rank = mat_dims.size();
-  int n = mat_dims[rank - 1];
-  int batch_size = rank > 2 ? a.numel() / (n * n) : 1;
+  int64_t n = mat_dims[rank - 1];
+  int64_t batch_size = rank > 2 ? a.numel() / (n * n) : 1;
 
   const T* a_ptr = a.data<T>();
   T* a_inv_ptr = dev_ctx.template Alloc<T>(a_inv);
@@ -100,7 +103,7 @@ void ComputeInverseEigen(const Context& dev_ctx,
   // it's not going to get the right result,
   // so we're going to convert it to std::complex and
   // then we're going to put it into eigen::matrix.
-  for (int i = 0; i < batch_size; ++i) {
+  for (int64_t i = 0; i < batch_size; ++i) {
     MapMatrixInverseFunctor<Context, T> functor;
     functor(dev_ctx, a_ptr, a_inv_ptr, i * n * n, n);
   }
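Note: widening offset and n is not cosmetic here: the per-batch offset is computed as i * n * n, which exceeds INT_MAX for a single matrix of side 46341, and far sooner once batches are involved. A self-contained check of that threshold (illustrative, not Paddle code):

#include <cstdint>
#include <iostream>
#include <limits>

int main() {
  // Smallest square matrix whose element count no longer fits in int32:
  // 46341 * 46341 = 2'147'488'281 > 2'147'483'647 (INT_MAX).
  const int64_t n = 46341;
  const int64_t elems = n * n;  // safe: computed in 64 bits
  std::cout << elems << " > " << std::numeric_limits<int32_t>::max()
            << " -> a 32-bit offset i * n * n would overflow\n";
}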
2 changes: 1 addition & 1 deletion paddle/phi/kernels/funcs/unsqueeze.h
@@ -156,7 +156,7 @@ inline DDim GetUnsqueezeShape(const std::vector<int64_t> unsqz_dims,
 inline const DenseTensor Unsqueeze(const DenseTensor& x, int axis = 0) {
   // don't copy data, only change the dims
   DenseTensor out(x);
-  std::vector<int> out_shape = common::vectorize<int>(x.dims());
+  std::vector<int64_t> out_shape = common::vectorize<int64_t>(x.dims());
   if (axis >= 0) {
     auto index = (out_shape.begin() + axis);
     out_shape.insert(index, 1);
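Note: only the element type of out_shape changes here, keeping it consistent with the int64_t shape vectors used elsewhere in this PR; the insertion logic is untouched. A standalone sketch of that logic (hypothetical unsqueeze_shape helper; the negative-axis branch is assumed to match the part of Unsqueeze elided from this hunk):

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical standalone analog of the shape edit in Unsqueeze:
// insert a 1 at `axis`, counting from the back when axis is negative.
std::vector<int64_t> unsqueeze_shape(std::vector<int64_t> shape, int axis) {
  if (axis >= 0) {
    shape.insert(shape.begin() + axis, 1);
  } else {
    shape.insert(shape.end() + axis + 1, 1);
  }
  return shape;
}

int main() {
  assert((unsqueeze_shape({3, 4}, 0) == std::vector<int64_t>{1, 3, 4}));
  assert((unsqueeze_shape({3, 4}, -1) == std::vector<int64_t>{3, 4, 1}));
}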
20 changes: 10 additions & 10 deletions paddle/phi/kernels/gpu/slogdeterminant_kernel.cu
@@ -88,13 +88,13 @@ __global__ void GetSlogDetFromLUComplex(const T* lu_data,
                                         int64_t n,
                                         int64_t batch_size,
                                         T* out_data) {
-  int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  int64_t idx = threadIdx.x + static_cast<int64_t>(blockIdx.x) * blockDim.x;
   if (idx < batch_size) {
-    int offset_lu = idx * n * n;
-    int offset_ipiv = idx * n;
+    int64_t offset_lu = idx * n * n;
+    int64_t offset_ipiv = idx * n;
     T det_val = T(1.0, 0.0);
     T negative = T(-1.0, 0.0);
-    for (int i = 0; i < n; ++i) {
+    for (int64_t i = 0; i < n; ++i) {
       det_val *= lu_data[offset_lu + i * n + i];
       if (ipiv[offset_ipiv + i] != i + 1) {
         det_val *= negative;
@@ -135,12 +135,12 @@ struct SlogDeterminantFunctor<phi::dtype::complex<T>, Context> {
                      tmp_gpu_mat_data->ptr());
 
     std::vector<const phi::dtype::complex<T>*> cpu_ptrs(batch_count);
-    for (int i = 0; i < batch_count; ++i) {
+    for (int64_t i = 0; i < batch_count; ++i) {
      cpu_ptrs[i] = gpu_mat + i * rank * rank;
     }
 
     // num_ints is for pivot (rank * batch_count) and info (batch_count)
-    int num_ints = batch_count * (rank + 1);
+    int64_t num_ints = batch_count * (rank + 1);
     size_t total_bytes =
         batch_count * sizeof(phi::dtype::complex<T>*) + num_ints * sizeof(int);
     phi::Allocator::AllocationPtr tmp_gpu_ptrs_data = phi::memory_utils::Alloc(
@@ -218,7 +218,7 @@ void SlogDeterminantKernel(const Context& dev_ctx,
   // shape [*, M, M], check whether it contains 0 in '*'.
   if (input_dim.size() > 2) {
     bool size_0 = false;
-    std::vector<int> tmp_dim_vec(input_dim.begin(), input_dim.end() - 2);
+    std::vector<int64_t> tmp_dim_vec(input_dim.begin(), input_dim.end() - 2);
     for (size_t i = 0; i < tmp_dim_vec.size(); ++i) {
       if (tmp_dim_vec[i] == 0) {
         size_0 = true;
@@ -234,7 +234,7 @@ void SlogDeterminantKernel(const Context& dev_ctx,
     }
   }
 
-  auto batch_count = detail::GetBatchCount(x.dims());
+  int64_t batch_count = detail::GetBatchCount(x.dims());
   VLOG(2) << "input dim:" << x.dims();
   PADDLE_ENFORCE_GE(
       input_dim_size,
@@ -245,9 +245,9 @@ void SlogDeterminantKernel(const Context& dev_ctx,
       input_dim[input_dim_size - 1],
       input_dim[input_dim_size - 2],
       errors::InvalidArgument("the input matrix should be square matrix."));
-  auto rank = input_dim[input_dim_size - 1];  // square matrix length
+  int64_t rank = input_dim[input_dim_size - 1];  // square matrix length
   SlogDeterminantFunctor<T, Context>()(dev_ctx, x, rank, batch_count, out);
-  std::vector<int> output_dim_vec(input_dim.begin(), input_dim.end() - 2);
+  std::vector<int64_t> output_dim_vec(input_dim.begin(), input_dim.end() - 2);
   if (input_dim.size() == static_cast<size_t>(2)) {
     // when input is a two-dimension matrix, The det value is a number.
     output_dim_vec = {};
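Note: the cast on blockIdx.x is the load-bearing part of the index fix in GetSlogDetFromLUComplex: the CUDA builtins threadIdx/blockIdx/blockDim are 32-bit unsigned, so their product wraps modulo 2^32 before any assignment to int64_t can widen it. A host-side illustration of the wrap (self-contained sketch with stand-in values):

#include <cstdint>
#include <iostream>

int main() {
  // Mirrors the kernel's index computation with 32-bit unsigned operands.
  uint32_t block_idx = 20000000, block_dim = 256, thread_idx = 0;
  int64_t wrapped = thread_idx + block_idx * block_dim;  // multiply wraps mod 2^32
  int64_t widened =
      thread_idx + static_cast<int64_t>(block_idx) * block_dim;  // widen first
  std::cout << wrapped << " vs " << widened << "\n";
  // prints: 825032704 vs 5120000000
}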
4 changes: 2 additions & 2 deletions paddle/phi/kernels/impl/slogdeterminant_grad_kernel_impl.h
@@ -104,8 +104,8 @@ void SlogDeterminantGradKernel(const Context& dev_ctx,
 
   // remove useless first dimension
   int det_grad_size = det_grad.dims().size();
-  std::vector<int> det_grad_vec;
-  for (int i = 1; i < det_grad_size; ++i) {
+  std::vector<int64_t> det_grad_vec;
+  for (int64_t i = 1; i < det_grad_size; ++i) {
     det_grad_vec.emplace_back(det_grad.dims()[i]);
   }
   det_grad.Resize(det_grad.dims().reshape(det_grad_vec));
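Note: det_grad_vec now collects the int64_t extents dims()[1..] so that the Resize call can route through the new DDim::reshape(std::vector<int64_t>&) overload when dropping the leading dimension. A hypothetical standalone equivalent of that shape edit:

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical analog of the tail of SlogDeterminantGradKernel: drop the
// leading dimension of a shape, keeping the remaining extents as int64_t.
std::vector<int64_t> drop_leading_dim(const std::vector<int64_t>& dims) {
  std::vector<int64_t> out;
  for (size_t i = 1; i < dims.size(); ++i) out.emplace_back(dims[i]);
  return out;
}

int main() {
  assert((drop_leading_dim({2, 3, 4, 4}) == std::vector<int64_t>{3, 4, 4}));
}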