
[Pten] Refactor the implementation of custom operator #37122

Merged · 30 commits · Nov 15, 2021

Conversation

@chenwhql (Contributor) commented Nov 11, 2021

PR types

Function optimization

PR changes

OPs

Describe

[Pten] Refactor the implementation of custom operator

The Paddle Tensor compute library (hereafter "pten") needs to provide more C++ computation APIs for custom operators, so that the cost of developing custom operators can be reduced further. This requires connecting the custom operator implementation with pten. This PR makes the following changes:

  1. Merge the original custom-operator Tensor implementation into the pten API Tensor, replacing the relevant methods with the compute library's implementations and removing the original custom-operator Tensor. This reduces the number of externally visible Tensor concepts and lowers long-term maintenance cost; going forward, custom operators, the dynamic graph, and the Python & C++ APIs will share a single pten API Tensor.
  2. Merge the original custom-operator C++ APIs into the compute library, removing the original DataType API and reusing pten's DataType.
  3. Update the Tensor conversion logic in the original custom_operator.cc to hook into pten.
  4. Adjust the externally exposed header files, making the pten API the primary external interface.
  5. Reorganize the corresponding unit tests to match the pten implementation.
  6. For compatibility, temporarily keep some data structures and methods of the original custom operator, such as PlaceType and the methods that use it; these will be gradually deprecated.

After this integration of custom operators with pten, developing custom operators becomes noticeably easier, and the amount of code needed for complex kernels drops significantly. Take the forward kernel of a linear operator as an example:

1. Previous approach (hand-written basic implementation logic; fairly complex)

Note: this is only the main body of paddle's internal matmul forward kernel; it does not include linear's add computation, whose internal implementation runs to over a thousand more lines (#37034), too much to paste here. The sample code has not been compiled and is only meant to illustrate the change in code volume. Moreover, writing an efficient matmul and elementwise_add inside an external custom operator is very difficult, for reasons including but not limited to:

  • On one hand, the code calls paddle's internal wrappers around third-party libraries such as eigen and blas, so the actual amount of code far exceeds what is listed here.
  • On the other hand, external custom operators currently cannot use our internally wrapped eigen and blas methods. Even if users reimplemented them externally, performance would suffer because global host and device memory management would not be unified.
std::vector<paddle::Tensor> CustomLinearForward(const paddle::Tensor& x,
                                              const paddle::Tensor& weight,
                                              const paddle::Tensor& bias) {
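  // NOTE: the body below is pasted from paddle's internal matmul forward
  // kernel purely to compare code volume (see the note above); identifiers
  // such as X, Y, Out, dev_ctx, trans_x, trans_y, flag, T and DeviceContext
  // come from that kernel's context and are not declared by this signature.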
  PADDLE_ENFORCE_NE(paddle::framework::product(x.dims()),
                    0,
                    paddle::platform::errors::InvalidArgument(
                        "The Input(X) dims size must not be equal to 0,"
                        " but received dims size is 0. "));
  PADDLE_ENFORCE_NE(paddle::framework::product(y.dims()),
                    0,
                    paddle::platform::errors::InvalidArgument(
                        "The Input(Y) dims size must not be equal to 0,"
                        " but received dims size is 0. "));
  const std::vector<std::int64_t> x_dims = vectorize(X.dims());
  const std::vector<std::int64_t> y_dims = vectorize(Y.dims());

  const int x_ndim = x_dims.size();
  const int y_ndim = y_dims.size();

  // Get data ptr
  const T* x_data = X.data<T>();
  const T* y_data = Y.data<T>();

  if (x_ndim == 1 && y_ndim == 1) {
    PADDLE_ENFORCE_EQ(
        X.numel(),
        Y.numel(),
        paddle::platform::errors::InvalidArgument(
            "X's numbers must be equal to Y's numbers,"
            "when X/Y's dims =1. But received X has [%d] elements,"
            "received Y has [%d] elements",
            X.numel(),
            Y.numel()));
    VLOG(3) << "MatMul's case 1";
    Out->Resize({1});
    Out->mutable_data<T>();
    auto out_eigen = EigenScalar<T>::From(*Out);
    auto x_eigen = EigenVector<T>::Flatten(X);
    auto y_eigen = EigenVector<T>::Flatten(Y);

    auto& dev = *dev_ctx.eigen_device();
    if (flag) {
      out_eigen.device(dev) = (x_eigen * y_eigen).sum() + out_eigen;
    } else {
      out_eigen.device(dev) = (x_eigen * y_eigen).sum();
    }
    return;
  }

  auto blas = paddle::operators::math::GetBlas<DeviceContext, T>(dev_ctx);

  if (x_ndim == 1) {
    const int N = X.numel();
    if (trans_y) {
      PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1],
                        N,
                        paddle::platform::errors::InvalidArgument(
                            "Input(Y) has error dim."
                            "Y'dims[%d] must be equal to %d"
                            "But received Y'dims[%d] is %d",
                            y_ndim - 1,
                            N,
                            y_ndim - 1,
                            y_dims[y_ndim - 1]));
    } else {
      PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2],
                        N,
                        paddle::platform::errors::InvalidArgument(
                            "Input(Y) has error dim."
                            "Y'dims[%d] must be equal to %d"
                            "But received Y'dims[%d] is %d",
                            y_ndim - 2,
                            N,
                            y_ndim - 2,
                            y_dims[y_ndim - 2]));
    }
    std::vector<std::int64_t> out_dims(y_ndim - 1);
    if (trans_y) {
      std::copy_n(y_dims.cbegin(), y_ndim - 1, out_dims.begin());
    } else {
      std::copy_n(y_dims.cbegin(), y_ndim - 2, out_dims.begin());
      out_dims.back() = y_dims.back();
    }
    Out->Resize(paddle::framework::make_ddim(out_dims));
    Out->mutable_data<T>();
    if (trans_y) {
      const int M = Y.numel() / N;
      VLOG(3) << "MatMul's case 2";
      blas.GEMV(false,
                M,
                N,
                static_cast<T>(1),
                y_data,
                x_data,
                static_cast<T>(flag),
                Out->mutable_data<T>());
    } else {
      const int M = y_dims[y_ndim - 1];
      const int batch_size = Y.numel() / (M * N);
      if (batch_size == 1) {
        VLOG(3) << "MatMul's case 3";
        blas.GEMV(true,
                  N,
                  M,
                  static_cast<T>(1),
                  y_data,
                  x_data,
                  static_cast<T>(flag),
                  Out->mutable_data<T>());
      } else {
        VLOG(3) << "MatMul's case 4";
        blas.BatchedGEMM(CblasTrans,
                         CblasNoTrans,
                         M,
                         1,
                         N,
                         static_cast<T>(1),
                         y_data,
                         x_data,
                         static_cast<T>(flag),
                         Out->mutable_data<T>(),
                         batch_size,
                         M * N,
                         0);
      }
    }
    return;
  }

  if (y_ndim == 1) {
    const int N = Y.numel();
    if (trans_x) {
      PADDLE_ENFORCE_EQ(x_dims[x_ndim - 2],
                        N,
                        paddle::platform::errors::InvalidArgument(
                            "Input(X) has error dim."
                            "X'dims[%d] must be equal to %d"
                            "But received X'dims[%d] is %d",
                            x_ndim - 2,
                            N,
                            x_ndim - 2,
                            x_dims[x_ndim - 2]));
    } else {
      PADDLE_ENFORCE_EQ(x_dims[x_ndim - 1],
                        N,
                        paddle::platform::errors::InvalidArgument(
                            "Input(X) has error dim."
                            "X'dims[%d] must be equal to %d"
                            "But received X'dims[%d] is %d",
                            x_ndim - 1,
                            N,
                            x_ndim - 1,
                            x_dims[x_ndim - 1]));
    }
    std::vector<std::int64_t> out_dims(x_ndim - 1);
    if (trans_x) {
      std::copy_n(x_dims.cbegin(), x_ndim - 2, out_dims.begin());
      out_dims.back() = x_dims.back();
    } else {
      std::copy_n(x_dims.cbegin(), x_ndim - 1, out_dims.begin());
    }
    Out->Resize(paddle::framework::make_ddim(out_dims));
    Out->mutable_data<T>();

    if (trans_x) {
      const int M = x_dims[x_ndim - 1];
      const int batch_size = X.numel() / (M * N);
      if (batch_size == 1) {
        VLOG(3) << "MatMul's case 5";
        blas.GEMV(true,
                  N,
                  M,
                  static_cast<T>(1),
                  x_data,
                  y_data,
                  static_cast<T>(flag),
                  Out->mutable_data<T>());
      } else {
        VLOG(3) << "MatMul's case 6";
        blas.BatchedGEMM(CblasTrans,
                         CblasNoTrans,
                         M,
                         1,
                         N,
                         static_cast<T>(1),
                         x_data,
                         y_data,
                         static_cast<T>(flag),
                         Out->mutable_data<T>(),
                         batch_size,
                         M * N,
                         0);
      }
    } else {
      const int M = X.numel() / N;
      VLOG(3) << "MatMul's case 7";
      blas.GEMV(false,
                M,
                N,
                static_cast<T>(1),
                x_data,
                y_data,
                static_cast<T>(flag),
                Out->mutable_data<T>());
    }
    return;
  }

  const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2];
  const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1];
  if (trans_y) {
    PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1],
                      K,
                      paddle::platform::errors::InvalidArgument(
                          "Input(Y) has error dim."
                          "Y'dims[%d] must be equal to %d"
                          "But received Y'dims[%d] is %d",
                          y_ndim - 1,
                          K,
                          y_ndim - 1,
                          y_dims[y_ndim - 1]));
  } else {
    PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2],
                      K,
                      paddle::platform::errors::InvalidArgument(
                          "Input(Y) has error dim."
                          "Y'dims[%d] must be equal to %d"
                          "But received Y'dims[%d] is %d",
                          y_ndim - 2,
                          K,
                          y_ndim - 2,
                          y_dims[y_ndim - 2]));
  }
  const int N = trans_y ? y_dims[y_ndim - 2] : y_dims[y_ndim - 1];
  const int ndim = (std::max)(x_ndim, y_ndim);
  std::vector<std::int64_t> x_broadcast_dims(ndim);
  std::vector<std::int64_t> y_broadcast_dims(ndim);
  std::vector<std::int64_t> out_broadcast_dims(ndim);

  GetBroadcastFromDims(x_ndim - 2,
                       x_dims.data(),
                       y_ndim - 2,
                       y_dims.data(),
                       x_broadcast_dims.data(),
                       y_broadcast_dims.data(),
                       out_broadcast_dims.data());
  out_broadcast_dims[ndim - 2] = M;
  out_broadcast_dims[ndim - 1] = N;

  Out->Resize(paddle::framework::make_ddim(out_broadcast_dims));
  Out->mutable_data<T>();

  const int batch_dim = ndim - 2;
  // broadcast message
  const bool is_broadcast_dims =
      !std::equal(x_broadcast_dims.cbegin(),
                  x_broadcast_dims.cbegin() + batch_dim,
                  y_broadcast_dims.cbegin());

  const std::int64_t x_batch_size =
      std::accumulate(x_broadcast_dims.cbegin(),
                      x_broadcast_dims.cbegin() + batch_dim,
                      1LL,
                      std::multiplies<std::int64_t>());
  const std::int64_t y_batch_size =
      std::accumulate(y_broadcast_dims.cbegin(),
                      y_broadcast_dims.cbegin() + batch_dim,
                      1LL,
                      std::multiplies<std::int64_t>());
  const std::int64_t out_batch_size =
      std::accumulate(out_broadcast_dims.cbegin(),
                      out_broadcast_dims.cbegin() + batch_dim,
                      1LL,
                      std::multiplies<std::int64_t>());
  if (out_batch_size == 0) return;
  if (x_batch_size == 1 && y_batch_size == 1) {
    VLOG(3) << "MatMul's case 8";
    blas.GEMM(trans_x ? CblasTrans : CblasNoTrans,
              trans_y ? CblasTrans : CblasNoTrans,
              M,
              N,
              K,
              static_cast<T>(1),
              x_data,
              y_data,
              static_cast<T>(flag),
              Out->mutable_data<T>());
  } else if (x_batch_size == 1) {
    if (M == 1 && trans_y) {
      VLOG(3) << "MatMul's case 9";
      blas.GEMV(false,
                y_batch_size * N,
                K,
                static_cast<T>(1),
                y_data,
                x_data,
                static_cast<T>(flag),
                Out->mutable_data<T>());
    } else {
      VLOG(3) << "MatMul's case 10";
      blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans,
                       trans_y ? CblasTrans : CblasNoTrans,
                       M,
                       N,
                       K,
                       static_cast<T>(1),
                       x_data,
                       y_data,
                       static_cast<T>(flag),
                       Out->mutable_data<T>(),
                       out_batch_size,
                       0,
                       K * N);
    }
  } else if (y_batch_size == 1) {
    if (!trans_x) {
      VLOG(3) << "MatMul's case 11";
      blas.GEMM(CblasNoTrans,
                trans_y ? CblasTrans : CblasNoTrans,
                x_batch_size * M,
                N,
                K,
                static_cast<T>(1),
                x_data,
                y_data,
                static_cast<T>(flag),
                Out->mutable_data<T>());
    } else {
      VLOG(3) << "MatMul's case 12";
      blas.BatchedGEMM(CblasTrans,
                       trans_y ? CblasTrans : CblasNoTrans,
                       M,
                       N,
                       K,
                       static_cast<T>(1),
                       x_data,
                       y_data,
                       static_cast<T>(flag),
                       Out->mutable_data<T>(),
                       out_batch_size,
                       M * K,
                       0);
    }
  } else if (!is_broadcast_dims) {
    VLOG(3) << "MatMul's case 13";
    blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans,
                     trans_y ? CblasTrans : CblasNoTrans,
                     M,
                     N,
                     K,
                     static_cast<T>(1),
                     x_data,
                     y_data,
                     static_cast<T>(flag),
                     Out->mutable_data<T>(),
                     out_batch_size,
                     M * K,
                     K * N);
  } else {
    // in the case, can't use stridedgemm
    std::vector<const T*> x_ptr(out_batch_size);
    std::vector<const T*> y_ptr(out_batch_size);
    std::vector<T*> out_ptr(out_batch_size);
    std::vector<std::int64_t> index(batch_dim, 0);
    for (std::int64_t i = 0; i < out_batch_size; ++i) {
      // using the index to get offset
      const std::int64_t x_index =
          GetIndexMessage(batch_dim, x_broadcast_dims.data(), index.data());
      const std::int64_t y_index =
          GetIndexMessage(batch_dim, y_broadcast_dims.data(), index.data());

      x_ptr[i] = x_data + x_index * M * K;
      y_ptr[i] = y_data + y_index * K * N;
      out_ptr[i] = Out->mutable_data<T>() + i * M * N;
      IndexIncreaseFromDims(batch_dim, out_broadcast_dims.data(), index.data());
    }
    VLOG(3) << "MatMul's case 14";
    blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans,
                     trans_y ? CblasTrans : CblasNoTrans,
                     M,
                     N,
                     K,
                     static_cast<T>(1),
                     x_ptr.data(),
                     y_ptr.data(),
                     static_cast<T>(flag),
                     out_ptr.data(),
                     out_batch_size);
  }

  // The Add part still involves a lot of code, omitted here
}

2. This PR's approach (written in one line, supports multiple devices, verified in unit tests)

Note: the C++ API used here stays consistent with the corresponding Python-side API in all of the following respects:

  • Interface path (experimental is temporarily inserted after paddle and will be removed later)
  • Naming
  • Parameter list (minus the name parameter)
  • Supported parameter types
  • API functionality

Users can therefore use the C++ API by directly consulting the Python API documentation, with no extra learning cost, and its performance has already been tuned by paddle's own engineers.

// The linear implemented here requires that bias be passed in
std::vector<paddle::Tensor> PtenLinearForward(const paddle::Tensor& x,
                                              const paddle::Tensor& weight,
                                              const paddle::Tensor& bias) {
  return {paddle::experimental::add(paddle::experimental::matmul(x, weight), bias)};
}
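For context, registering this kernel as a custom operator would use the existing custom-op macros roughly as sketched below. This is a minimal illustration, not code from this PR: the op name custom_linear and the two inference functions are assumptions, and the shape inference assumes 2-D x and weight.

#include "paddle/extension.h"

// Hypothetical shape inference: x is [batch, in], weight is [in, out].
std::vector<std::vector<int64_t>> LinearInferShape(
    std::vector<int64_t> x_shape,
    std::vector<int64_t> weight_shape,
    std::vector<int64_t> bias_shape) {
  return {{x_shape[0], weight_shape[1]}};
}

// Hypothetical dtype inference: the output follows x's dtype.
std::vector<paddle::DataType> LinearInferDtype(paddle::DataType x_dtype,
                                               paddle::DataType weight_dtype,
                                               paddle::DataType bias_dtype) {
  return {x_dtype};
}

PD_BUILD_OP(custom_linear)
    .Inputs({"X", "Weight", "Bias"})
    .Outputs({"Out"})
    .SetKernelFn(PD_KERNEL(PtenLinearForward))
    .SetInferShapeFn(PD_INFER_SHAPE(LinearInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(LinearInferDtype));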

Of course, both approaches above are currently supported; merging this PR does not stop the original style from working.

TODO items

  1. The custom-operator C++ API is an officially exposed interface, so its compatibility must be preserved. This PR takes that into account, but given the workload, part of the job is split out. The reshape, copy_to, slice, and cast methods of the original custom-operator Tensor were implementations written specifically for custom operators; after this integration they will directly reuse pten's C++ APIs and kernels. Since the relevant pten kernels are still being migrated, these APIs will be completed in a follow-up PR (before 11.20); this PR temporarily disables them.
  2. The newly introduced APIs temporarily use paddle::experimental as the namespace prefix, since they are still experimental and subject to change. The original custom-operator APIs were opened up too hastily, directly exposing some insufficiently considered data structures and APIs to users that now cannot be dropped for compatibility reasons, for example the original PlaceType, Tensor's incomplete constructors, and Tensor::stream(). Deprecated warnings will be added to these poorly designed interfaces later.

@XieYunshen (Contributor) previously approved these changes Nov 13, 2021:

LGTM
The unit-test reorganization was flagged as removing unit tests.

Comment on lines 75 to 77
PD_THROW("Data type ",
static_cast<int>(data_type),
" is not supported by tensor.");
Contributor:

The error message style here is slightly inconsistent with line 176 below. Are both styles acceptable?

Author:

Done, quotes were added here as well.
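Presumably the adjusted message now quotes the offending value, along these lines (a sketch of the style, not the verbatim diff):

PD_THROW("Data type `",
         static_cast<int>(data_type),
         "` is not supported by tensor.");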

Comment on lines +156 to +162
* @brief Return the shape (dimensions) of Tensor.
* The compatible method of `Tensor::dims()`.
* This is a deprecated method and may be removed in the future!
*
* @return std::vector<int64_t>
*/
std::vector<int64_t> shape() const;
Contributor:

If the shape function is removed in the future, won't the C++ Tensor interface become inconsistent with Python's?

Author:

The plan here is not to remove the interface but to do an incompatible upgrade, with a concrete prompt given in the interface later.

Contributor:

shape is better than dims

Author:

The shape interface is currently kept for compatibility as well.

bool is_cpu() const;

/**
* @brief Determine whether the tensor device is CPU
Contributor:

CPU here should be GPU.

Author:

done, thx

*/
bool is_cpu() const { return paddle::platform::is_cpu_place(place()); }
bool is_cuda() const { return paddle::platform::is_gpu_place(place()); }
paddle::platform::Place inner_place() const;
Contributor:

If the place() interface above is removed in the future, can inner_place() here be renamed to place()?

Author:

Yes, that also requires an incompatible upgrade.

* This is a deprecated method and may be removed in the future!
*
* @tparam T
* @param target_place of target place, of which the tensor will copy to.
Contributor:

The grammar here reads a bit oddly.

Author:

done, thx

* @return None
* @brief Transfer the current Tensor to the specified device and return.
*
* @param place of target place, of which the tensor will copy to.
Contributor:

Same as above.

Author:

done, thx


/**
* @brief Determine whether Tensor is initialized.
* This is a deprecated method and may be removed in the future!
Contributor:

For these interfaces slated for removal, could the comments mention the replacement for the current interface?

Author:

A warning will be emitted directly when these interfaces are called; comments don't mean much to users. This will be supplemented later; for now this wording is used uniformly.

@Aurelius84 (Contributor) previously approved these changes Nov 15, 2021:

LGTM overall


template <typename T>
T *Tensor::mutable_data() {
if (impl_->type_info().name() == "DenseTensor") {
Contributor:

Not important: could the DenseTensor check here be factored into a separate function later? I see it called in several places; if the implementation changes later, the interface layer then wouldn't need to change.

Author:

Done, thx. For now the check is wrapped in an inline helper function; a better approach can be considered later.
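The helper is presumably along the lines of the sketch below; the parameter type is an assumption, and the diff later shows it as detail::IsDenseTensor with a planned move to pten::DenseTensor::classof() to avoid RTTI:

namespace detail {
// Sketch of the factored-out check; centralizing it means the interface
// layer does not change if the underlying test is later reimplemented.
inline bool IsDenseTensor(
    const std::shared_ptr<pten::TensorBase>& tensor_impl) {
  return tensor_impl->type_info().name() == "DenseTensor";
}
}  // namespace detail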

DataType dtype = DataType::UNDEFINED,
Backend backend = Backend::UNDEFINED,
DataLayout layout = DataLayout::UNDEFINED);
PD_DLL_DECL Tensor full(const std::vector<int64_t>& shape,
Contributor:

What does PD_DLL_DECL mean, and why is it needed?

Author:

On Windows, symbols have to be exported manually, otherwise they cannot be found at runtime; this differs from Unix.
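For reference, such a macro typically follows the standard Windows export/import pattern, roughly as below; this is a generic sketch, and the guard macro name is an assumption rather than paddle's exact definition:

#if defined(_WIN32)
#ifdef PADDLE_API_EXPORT  // assumed guard, defined when building the library
#define PD_DLL_DECL __declspec(dllexport)
#else
#define PD_DLL_DECL __declspec(dllimport)
#endif
#else
#define PD_DLL_DECL  // expands to nothing on non-Windows platforms
#endif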

@chenwhql chenwhql dismissed stale reviews from Aurelius84 and XieYunshen via 5d10182 November 15, 2021 03:41
int numel = x.size();
int block = 512;
int grid = (numel + block - 1) / block;
PD_DISPATCH_FLOATING_AND_HALF_TYPES(
x.type(), "relu_cuda_forward_kernel", ([&] {
auto cpu_input = x.copy_to<data_t>(paddle::PlaceType::kCPU);
Contributor:

Shouldn't copy_to be tested here?

Author:

copy_to is temporarily disabled, as are reshape, cast, and slice; they will be added back in the next PR, along with the corresponding unit tests.

Author:

See the explanation in the TODO items of the PR description.

PD_DLL_DECL int RegisterSymbolsFor##name() { return 0; }

#define PT_DECLARE_API(name) \
extern int RegisterSymbolsFor##name(); \
Contributor:

Add PD_DLL_DECL here as well.

Author:

done, thx


#define PT_DECLARE_API(name) \
extern int RegisterSymbolsFor##name(); \
UNUSED static int use_pten_api_##name = RegisterSymbolsFor##name()
Contributor:

What does UNUSED do?

Author:

[screenshot of the UNUSED macro definition]
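The screenshot presumably shows the usual warning-suppression attribute: the static registration variable exists only for its initializer's side effect, so the compiler must be told it is deliberately unused. A generic sketch:

#if defined(__GNUC__) || defined(__clang__)
#define UNUSED __attribute__((unused))
#else
#define UNUSED
#endif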

@zhwesky2010 (Contributor) left a comment:

Externally exposed free functions and classes all need PD_DLL_DECL on their declarations.


namespace detail {

inline bool IsDenseTensor(
Contributor:

[TODO] This will later be changed to a no-rtti form:
pten::DenseTensor::classof(derived_a.get());

Author:

OK

@@ -20,4 +25,5 @@ endif()
if(WITH_XPU)
set(PTEN_DEPS ${PTEN_DEPS} manipulation_xpu)
endif()

Contributor:

remove additional blank line

Author:

This is unimportant; the blank line just visually separates this block from the one above.

platform::errors::InvalidArgument(
"TensorImpl with nullptr is not supported"));
}
explicit Tensor(const PlaceType& place);
Contributor:

Is this a good way to stay compatible with the original custom tensor? These two constructors are not safe.

Author:

Neither constructor is safe, but they are needed for compatibility now. Deprecated warnings will be added later, and they can be removed after some time, for example in release 2.4.

*/
paddle::experimental::DataType type() const { return impl_->data_type(); }
DataType type() const;
Contributor:

type is not a good name... how about var_type

Author:

1. The type interface is also kept for custom-operator compatibility. 2. var_type is not a good name either, since there is no Variable in the current system.


} // namespace detail

paddle::platform::DeviceContext* GetDeviceContextByBackend(
Contributor:

const&

Author:

An enum is a POD type and does not need to be passed by const reference. When the custom-operator attributes were designed earlier, int and float were required to be passed as const&, which also wasn't really necessary.
[screenshot]
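In other words, small trivially-copyable types are best passed by value; a sketch of the two styles (the Backend parameter type here is inferred from the thread, not quoted from the diff):

// Preferred: pass enums (and int, float, etc.) by value.
paddle::platform::DeviceContext* GetDeviceContextByBackend(Backend backend);
// Unnecessary: a const reference to a POD adds indirection and saves nothing.
paddle::platform::DeviceContext* GetDeviceContextByBackend(const Backend& backend);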

*
* @return int64_t
*/
int64_t size() const;
Contributor:

I'd prefer using size() instead of numel

Author:

OK, I'll remove numel in the next PR.

Author:

see #37237


@JiabinYang (Contributor) left a comment:

LGTM

@raindrops2sea (Collaborator) left a comment:

LGTM
