Skip to content

Commit

Permalink
[Move selected_rows PR #4] SelectedRows inherits from TensorBase. (Pa…
Browse files Browse the repository at this point in the history
…ddlePaddle#39162)

* Added selected_rows and rw_lock to pten

* Renamed the unit test target to fix CI

* Removed Class SelectedRows in Fluid, changed include/cmake relationship, use pten::SelectedRows in Fluid

* Remove rw_lock.h,rw_lock_test.cc in fluid

* Use pten::RWLock and pten::AutoRDLock, fix CI

* Use pten::SelectedRows

* Use pten::SelectedRows

* Fix to pass NPU CI

* Selected_Rows inherits from TensorBase

* Use pten::SelectedRows, to pass NPU CI

* To fix NPU CI

* To fix NPU CI again

* Use paddle/pten/core/enforce and polish code
  • Loading branch information
veyron95 authored Jan 26, 2022
1 parent d9acc87 commit 3e80253
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 12 deletions.
2 changes: 1 addition & 1 deletion paddle/pten/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ cc_library(pten_device_context SRCS device_context.cc DEPS tensor_base )

cc_library(meta_tensor SRCS meta_tensor.cc DEPS tensor_base tensor_meta dense_tensor)
cc_library(infermeta_utils SRCS infermeta_utils.cc DEPS meta_tensor)
cc_library(selected_rows SRCS selected_rows.cc DEPS dense_tensor mixed_vector enforce ddim)
cc_library(selected_rows SRCS selected_rows.cc DEPS dense_tensor mixed_vector pten_enforce ddim)

cc_test(unroll_array_ops_test SRCS unroll_array_ops_test.cc)
cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce)
Expand Down
56 changes: 45 additions & 11 deletions paddle/pten/core/selected_rows.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,16 @@ limitations under the License. */
#include "paddle/pten/common/place.h"
#include "paddle/pten/core/ddim.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/enforce.h"
#include "paddle/pten/core/utils/rw_lock.h"

// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"

namespace pten {
class SelectedRows {
class SelectedRows : public TensorBase,
public TypeInfoTraits<TensorBase, SelectedRows> {
/*
* @brief We can use the SelectedRows structure to reproduce a sparse table.
* A sparse table is a key-value structure that the key is an `int64_t`,
Expand All @@ -51,21 +52,19 @@ class SelectedRows {
public:
SelectedRows(const std::vector<int64_t>& rows, const int64_t& height)
: rows_(rows), height_(height) {
value_.reset(new pten::DenseTensor());
value_.reset(new DenseTensor());
rwlock_.reset(new RWLock);
}

SelectedRows() {
height_ = 0;
value_.reset(new pten::DenseTensor());
value_.reset(new DenseTensor());
rwlock_.reset(new RWLock);
}

const pten::Place& place() const { return value_->place(); }
const DenseTensor& value() const { return *value_; }

const pten::DenseTensor& value() const { return *value_; }

pten::DenseTensor* mutable_value() { return value_.get(); }
DenseTensor* mutable_value() { return value_.get(); }

int64_t height() const { return height_; }

Expand Down Expand Up @@ -109,8 +108,8 @@ class SelectedRows {
* @return a list of pair which contains the non-exists key and the index in
* the value
*/
void Get(const pten::DenseTensor& ids,
pten::DenseTensor* value,
void Get(const DenseTensor& ids,
DenseTensor* value,
bool auto_grown = false,
bool is_test = false);

Expand Down Expand Up @@ -149,14 +148,49 @@ class SelectedRows {
return pten::framework::make_ddim(dims);
}

/// \brief Returns the name of the class for type traits.
/// \return The name of the class.
static const char* name() { return "SelectedRows"; }

/// \brief Returns the number of elements contained in tensor.
/// \return The number of elements contained in tensor.
int64_t numel() const override { return value_->numel(); };

/// \brief Returns the dims of the tensor.
/// \return The dims of the tensor.
const DDim& dims() const noexcept override {
return value_->dims();
// return paddle::framework::make_ddim(dims);
}

/// \brief Returns the data type of the tensor.
/// \return The data type of the tensor.
DataType dtype() const noexcept override { return value_->dtype(); }

/// \brief Returns the data layout of the tensor.
/// \return The data layout of the tensor.
DataLayout layout() const noexcept override { return value_->layout(); }

/// \brief Returns the data place of the tensor.
/// \return The data place of the tensor.
const Place& place() const override { return value_->place(); };

/// \brief Test whether the metadata is valid.
/// \return Whether the metadata is valid.
bool valid() const noexcept override { return value_->valid(); }

/// \brief Test whether the storage is allocated.
/// return Whether the storage is allocated.
bool initialized() const override { return value_->initialized(); }

private:
// Notice: rows can be duplicate. We can have {0, 4, 7, 0, 5, 7, 9} here.
// SelectedRows are simply concated when adding together. Until a
// SelectedRows add a Tensor, will the duplicate rows be handled.
paddle::framework::Vector<int64_t> rows_;
std::unordered_map<int64_t, int64_t>
id_to_index_; // should not be used when rows_ has duplicate member
std::unique_ptr<pten::DenseTensor> value_{nullptr};
std::unique_ptr<DenseTensor> value_{nullptr};
int64_t height_; // height indicates the underline tensor's height
std::unique_ptr<RWLock> rwlock_{nullptr};
};
Expand Down

0 comments on commit 3e80253

Please sign in to comment.