-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Dy2St][PIR] Hold backward program in GradNode #63694
Changes from 11 commits
3223492
c1398d1
2a7f67d
733ac03
5c90eac
64a56b0
b48eaa0
5534c70
8aba5ef
23bbd58
31be3b9
e005597
6277bc2
4972090
4bcbcc6
f5c6e0b
414fbda
1e68e62
cb6cdab
a9946ee
d83476b
0bf8a6a
462cdf7
c7e9825
de20f12
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,7 @@ limitations under the License. */ | |
#include "paddle/phi/common/scalar.h" | ||
#include "paddle/phi/core/infermeta_utils.h" | ||
#include "paddle/pir/include/core/block.h" | ||
#include "paddle/pir/include/core/program.h" | ||
#include "paddle/pir/include/core/value.h" | ||
#include "paddle/utils/blank.h" | ||
#include "paddle/utils/small_vector.h" | ||
|
@@ -40,6 +41,7 @@ class InferShapeContext; | |
class InferVarTypeContext; | ||
class VarDesc; | ||
class BlockDesc; | ||
class ProgramDesc; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个有必要嘛?下面只添加了 |
||
class Variable; | ||
class InferNoNeedBufferVarsFN; | ||
|
||
|
@@ -67,7 +69,8 @@ using Attribute = paddle::variant<paddle::blank, | |
paddle::experimental::Scalar, | ||
std::vector<paddle::experimental::Scalar>, | ||
::pir::Block*, | ||
std::vector<::pir::Value>>; | ||
std::vector<::pir::Value>, | ||
std::shared_ptr<::pir::Program>>; | ||
using AttributeMap = std::unordered_map<std::string, Attribute>; | ||
|
||
using OpCreator = | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,6 +36,7 @@ | |
#include "paddle/phi/common/complex.h" | ||
#include "paddle/pir/include/core/block.h" | ||
#include "paddle/pir/include/core/op_result.h" | ||
#include "paddle/pir/include/core/region.h" | ||
#include "paddle/pir/include/core/value.h" | ||
|
||
namespace paddle { | ||
|
@@ -858,6 +859,17 @@ void CastPyArg2AttrIRBlock(PyObject* obj, | |
attrs[key] = reinterpret_cast<::pir::Block*&>(vh[0]); | ||
} | ||
|
||
void CastPyArg2AttrIRProgram(PyObject* obj, | ||
paddle::framework::AttributeMap& attrs, // NOLINT | ||
const std::string& key, | ||
const std::string& op_type, | ||
ssize_t arg_pos) { | ||
VLOG(1) << "After Process pir::Program*"; | ||
const std::shared_ptr<::pir::Program> program = | ||
::py::handle(obj).cast<std::shared_ptr<::pir::Program>>(); | ||
attrs[key] = program; | ||
} | ||
|
||
void CastPyArg2AttrValues(PyObject* obj, | ||
paddle::framework::AttributeMap& attrs, // NOLINT | ||
const std::string& key, | ||
|
@@ -998,6 +1010,7 @@ void ConstructAttrMapForRunProgram( | |
attr_end)); | ||
|
||
PyObject* obj = nullptr; | ||
attrs["testkey"] = std::string("testvalue"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 忘清了? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
for (ssize_t arg_pos = attr_start; arg_pos < attr_end; arg_pos += 2) { | ||
VLOG(1) << "Start Process " << arg_pos; | ||
Py_ssize_t key_len = 0; | ||
|
@@ -1020,11 +1033,11 @@ void ConstructAttrMapForRunProgram( | |
|
||
if (std::set<std::string>({"cuda_graph_capture_mode"}).count(key)) { | ||
CastPyArg2AttrString(obj, attrs, key, op_type, arg_pos); | ||
} else if (std::set<std::string>({"global_block", | ||
"forward_global_block", | ||
"backward_global_block"}) | ||
.count(key)) { | ||
} else if (std::set<std::string>({"global_block"}).count(key)) { | ||
CastPyArg2AttrIRBlock(obj, attrs, key, op_type, arg_pos); | ||
} else if (std::set<std::string>({"forward_program", "backward_program"}) | ||
.count(key)) { | ||
CastPyArg2AttrIRProgram(obj, attrs, key, op_type, arg_pos); | ||
} else if (std::set<std::string>({"is_test", "use_interpretorcore"}) | ||
.count(key)) { | ||
CastPyArg2AttrBoolean(obj, attrs, key, op_type, arg_pos); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个忘清了?