-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[hybrid] seed and dropout op support force-cpu #35820
Changes from 6 commits
382b18f
2c56ee1
c0a90b6
6d633eb
64aa034
3f88bb7
f9b31b8
3d1c0c2
91a0b92
d23b1c4
62c7b11
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -42,6 +42,18 @@ class DropoutOp : public framework::OperatorWithKernel { | |
return framework::OpKernelType( | ||
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); | ||
} | ||
|
||
framework::OpKernelType GetKernelTypeForVar( | ||
const std::string& var_name, const Tensor& tensor, | ||
const framework::OpKernelType& expected_kernel_type) const override { | ||
if (var_name == "Seed" && platform::is_cpu_place(tensor.place())) { | ||
VLOG(10) << "var_name:" << var_name << " need not to transform"; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 建议VLOG加上在什么op, does not need to transform in dropout op There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||
return expected_kernel_type; | ||
} | ||
|
||
return framework::OpKernelType(expected_kernel_type.data_type_, | ||
tensor.place(), tensor.layout()); | ||
} | ||
}; | ||
|
||
class DropoutOpMaker : public framework::OpProtoAndCheckerMaker { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,6 +39,11 @@ class SeedOpMaker : public framework::OpProtoAndCheckerMaker { | |
void Make() override { | ||
AddOutput("Out", "The output of seed op."); | ||
AddAttr<int>("seed", "Dropout random seed.").SetDefault(0); | ||
AddAttr<bool>("force_cpu", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 看下这个op预测会不会用,可能需要加上 AddCheckpoint 保证预测的兼容性 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已加上AddCheckpoint There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 如果预测不需要是不是还得加AsExtra(),新出的规范 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已经确认,并已加上AsExtra() |
||
"(bool, default false) Force fill output variable to cpu " | ||
"memory. Otherwise, fill output variable to the running " | ||
"device") | ||
.SetDefault(false); | ||
AddComment(R"DOC( | ||
Seed Operator. | ||
)DOC"); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/operators/math/math_function.h" | ||
#include "paddle/fluid/operators/seed_op.h" | ||
|
||
namespace paddle { | ||
|
@@ -20,22 +21,37 @@ namespace operators { | |
template <typename Place, typename T> | ||
class GPUSeedKernel : public framework::OpKernel<T> { | ||
public: | ||
void Compute(const framework::ExecutionContext& context) const override { | ||
auto* out = context.Output<Tensor>("Out"); | ||
auto* out_data = out->mutable_data<T>(context.GetPlace()); | ||
void Compute(const framework::ExecutionContext &context) const override { | ||
auto *out = context.Output<Tensor>("Out"); | ||
int user_seed = context.Attr<int>("seed"); | ||
auto force_cpu = context.Attr<bool>("force_cpu"); | ||
std::random_device rnd; | ||
int seed; | ||
if (user_seed != 0) { | ||
seed = user_seed; | ||
} else { | ||
seed = rnd(); | ||
} | ||
auto target_gpu_place = | ||
BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()); | ||
auto stream = context.cuda_device_context().stream(); | ||
memory::Copy(target_gpu_place, out_data, platform::CPUPlace(), &seed, | ||
sizeof(int), stream); | ||
|
||
bool cpu_place = force_cpu || context.GetPlace() == platform::CPUPlace(); | ||
if (cpu_place) { | ||
platform::DeviceContextPool &pool = | ||
platform::DeviceContextPool::Instance(); | ||
auto &dev_ctx = *pool.Get(context.GetPlace()); | ||
out->mutable_data<T>(platform::CPUPlace(), | ||
framework::proto::VarType::SIZE_T); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个SIZE_T是干啥的 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已删除 |
||
// out_data[0] = seed; | ||
math::SetConstant<platform::CPUDeviceContext, T> functor; | ||
functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx), | ||
out, static_cast<T>(seed)); | ||
} else { | ||
out->mutable_data<T>(context.GetPlace()); | ||
auto target_gpu_place = | ||
BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()); | ||
auto stream = context.cuda_device_context().stream(); | ||
memory::Copy(target_gpu_place, out, platform::CPUPlace(), &seed, | ||
sizeof(int), stream); | ||
} | ||
} | ||
}; | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -202,8 +202,11 @@ def modify_forward_desc_for_recompute(self): | |
type='seed', | ||
inputs={}, | ||
outputs={'Out': [added_var]}, | ||
attrs={'seed': seed, | ||
'op_device': op_device}) | ||
attrs={ | ||
'seed': seed, | ||
'op_device': op_device, | ||
'force_cpu': True | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 加一点点注释,为啥设置为True There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已加注释 |
||
}) | ||
self.ops.insert(op_idx, added_op) | ||
# modify dropout op desc so that it accept a seed var as input | ||
op.desc.set_input("Seed", [var_unique_name]) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是不是不需要判断is_cpu_place了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已经修改