Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Move selected_rows PR #4] SelectedRows inherits from TensorBase. #39162

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
f533aa0
Added selected_rows and rw_lock to pten
veyron95 Jan 20, 2022
832e903
Renamed the unit test target to fix CI
veyron95 Jan 21, 2022
318ef67
Removed Class SelectedRows in Fluid, changed include/cmake relationsh…
veyron95 Jan 21, 2022
40f635d
Remove rw_lock.h,rw_lock_test.cc in fluid
veyron95 Jan 22, 2022
aa5d93f
Use pten::RWLock and pten::AutoRDLock, fix CI
veyron95 Jan 22, 2022
8c51888
Use pten::SelectedRows
veyron95 Jan 22, 2022
15328f3
Use pten::SelectedRows
veyron95 Jan 22, 2022
7416f99
Fix to pass NPU CI
veyron95 Jan 22, 2022
794f7ef
Merge commit 'refs/pull/39128/head' of https://github.com/PaddlePaddl…
veyron95 Jan 24, 2022
4511b17
Merge branch 'develop' into fluid_move_selected_rows_to_pten_3
veyron95 Jan 24, 2022
47e3ccb
Selected_Rows inherits from TensorBase
veyron95 Jan 24, 2022
7afa032
Fix conflict
veyron95 Jan 24, 2022
3942e0f
Use pten::SelectedRows, to pass NPU CI
veyron95 Jan 24, 2022
75de13d
To fix NPU CI
veyron95 Jan 24, 2022
c650f51
Merge commit 'refs/pull/39128/head' of https://github.com/PaddlePaddl…
veyron95 Jan 24, 2022
d241507
To fix NPU CI again
veyron95 Jan 24, 2022
ef71e4a
Merge commit 'refs/pull/39128/head' of https://github.com/PaddlePaddl…
veyron95 Jan 24, 2022
51b0f24
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
veyron95 Jan 25, 2022
ef79f84
Use paddle/pten/core/enforce and polish code
veyron95 Jan 25, 2022
d815960
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
veyron95 Jan 25, 2022
91096d9
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
veyron95 Jan 25, 2022
9eb0a82
Merge branch 'develop' into fluid_move_selected_rows_to_pten_4
veyron95 Jan 26, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/pten/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ cc_library(pten_device_context SRCS device_context.cc DEPS tensor_base )

cc_library(meta_tensor SRCS meta_tensor.cc DEPS tensor_base tensor_meta dense_tensor)
cc_library(infermeta_utils SRCS infermeta_utils.cc DEPS meta_tensor)
cc_library(selected_rows SRCS selected_rows.cc DEPS dense_tensor mixed_vector enforce ddim)
cc_library(selected_rows SRCS selected_rows.cc DEPS dense_tensor mixed_vector pten_enforce ddim)

cc_test(unroll_array_ops_test SRCS unroll_array_ops_test.cc)
cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce)
Expand Down
56 changes: 45 additions & 11 deletions paddle/pten/core/selected_rows.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,16 @@ limitations under the License. */
#include "paddle/pten/common/place.h"
#include "paddle/pten/core/ddim.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/enforce.h"
#include "paddle/pten/core/utils/rw_lock.h"

// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"

namespace pten {
class SelectedRows {
class SelectedRows : public TensorBase,
public TypeInfoTraits<TensorBase, SelectedRows> {
/*
* @brief We can use the SelectedRows structure to reproduce a sparse table.
* A sparse table is a key-value structure that the key is an `int64_t`,
Expand All @@ -51,21 +52,19 @@ class SelectedRows {
public:
SelectedRows(const std::vector<int64_t>& rows, const int64_t& height)
: rows_(rows), height_(height) {
value_.reset(new pten::DenseTensor());
value_.reset(new DenseTensor());
rwlock_.reset(new RWLock);
}

SelectedRows() {
height_ = 0;
value_.reset(new pten::DenseTensor());
value_.reset(new DenseTensor());
rwlock_.reset(new RWLock);
}

const pten::Place& place() const { return value_->place(); }
const DenseTensor& value() const { return *value_; }

const pten::DenseTensor& value() const { return *value_; }

pten::DenseTensor* mutable_value() { return value_.get(); }
DenseTensor* mutable_value() { return value_.get(); }

int64_t height() const { return height_; }

Expand Down Expand Up @@ -109,8 +108,8 @@ class SelectedRows {
* @return a list of pair which contains the non-exists key and the index in
* the value
*/
void Get(const pten::DenseTensor& ids,
pten::DenseTensor* value,
void Get(const DenseTensor& ids,
DenseTensor* value,
bool auto_grown = false,
bool is_test = false);

Expand Down Expand Up @@ -149,14 +148,49 @@ class SelectedRows {
return pten::framework::make_ddim(dims);
}

/// \brief Returns the name of the class for type traits.
/// \return The name of the class.
static const char* name() { return "SelectedRows"; }

/// \brief Returns the number of elements contained in tensor.
/// \return The number of elements contained in tensor.
int64_t numel() const override { return value_->numel(); };

/// \brief Returns the dims of the tensor.
/// \return The dims of the tensor.
const DDim& dims() const noexcept override {
return value_->dims();
// return paddle::framework::make_ddim(dims);
}

/// \brief Returns the data type of the tensor.
/// \return The data type of the tensor.
DataType dtype() const noexcept override { return value_->dtype(); }

/// \brief Returns the data layout of the tensor.
/// \return The data layout of the tensor.
DataLayout layout() const noexcept override { return value_->layout(); }

/// \brief Returns the data place of the tensor.
/// \return The data place of the tensor.
const Place& place() const override { return value_->place(); };

/// \brief Test whether the metadata is valid.
/// \return Whether the metadata is valid.
bool valid() const noexcept override { return value_->valid(); }

/// \brief Test whether the storage is allocated.
/// return Whether the storage is allocated.
bool initialized() const override { return value_->initialized(); }

private:
// Notice: rows can be duplicate. We can have {0, 4, 7, 0, 5, 7, 9} here.
// SelectedRows are simply concated when adding together. Until a
// SelectedRows add a Tensor, will the duplicate rows be handled.
paddle::framework::Vector<int64_t> rows_;
std::unordered_map<int64_t, int64_t>
id_to_index_; // should not be used when rows_ has duplicate member
std::unique_ptr<pten::DenseTensor> value_{nullptr};
std::unique_ptr<DenseTensor> value_{nullptr};
int64_t height_; // height indicates the underline tensor's height
std::unique_ptr<RWLock> rwlock_{nullptr};
};
Expand Down