Merge branch 'PaddlePaddle:develop' into cmake_02
zade23 authored Dec 22, 2023
2 parents 9894fd1 + 256dc4d commit 3cd67de
Showing 76 changed files with 909 additions and 426 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/framework/CMakeLists.txt
@@ -957,6 +957,7 @@ cc_library(
layer)

cc_library(type_info SRCS type_info.cc)
target_link_libraries(type_info pir op_dialect)
add_dependencies(type_info framework_proto auto_parallel_proto xxhash)
if(WITH_MKLDNN)
add_dependencies(type_info mkldnn)
@@ -80,7 +80,26 @@ class InstructionBase {
next_instrs_in_same_thread_.push_back(id);
}

const EventInter& EventToRecord() const { return *event_to_record_; }
bool IsForceRecordEvent() const { return force_record_event_; }
void SetForceRecordEvent(bool force_record) {
force_record_event_ = force_record;
}

const std::vector<std::string>& EventsToWaitInfo() const {
return events_to_wait_info_;
}
void SetEventsToWaitInfo(const std::vector<std::string>& info) {
events_to_wait_info_ = info;
}

const std::string& EventToRecordInfo() const { return event_to_record_info_; }
void SetEventToRecordInfo(const std::string& info) {
event_to_record_info_ = info;
}

const std::shared_ptr<EventInter>& EventToRecord() const {
return event_to_record_;
}
void AddEventToRecord(std::shared_ptr<platform::DeviceEvent> event,
platform::DeviceType waiter_type) {
event_to_record_ = std::make_shared<EventInter>(id_, event, waiter_type);
@@ -95,6 +114,10 @@ class InstructionBase {
events_to_wait_.emplace_back(instr_id, event, waiter_type);
}

void AddEventToWait(const EventInter* event_inter) {
events_to_wait_.push_back(*event_inter);
}

void RecordEvent(const Place& place) const;
void WaitEvent(const Place& place) const;

@@ -170,6 +193,12 @@ class InstructionBase {

std::vector<size_t> next_instrs_in_same_thread_;

bool force_record_event_{false};

std::vector<std::string> events_to_wait_info_;

std::string event_to_record_info_{"default"};

std::shared_ptr<EventInter> event_to_record_;

std::vector<EventInter> events_to_wait_;
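For orientation, here is a minimal self-contained sketch of what the three new InstructionBase members model. FakeInstruction and the field names below are hypothetical stand-ins, not Paddle's real class: a producer instruction force-records a named event, and a consumer instruction lists that name in its wait info.

#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in mirroring the bookkeeping added in this hunk.
struct FakeInstruction {
  bool force_record_event{false};
  std::string event_to_record_info{"default"};
  std::vector<std::string> events_to_wait_info;
};

int main() {
  FakeInstruction producer;  // e.g. a send op in sub-program 0
  producer.force_record_event = true;
  producer.event_to_record_info = "stage0_send_done";  // must not stay "default"

  FakeInstruction consumer;  // e.g. a recv op in sub-program 1
  consumer.events_to_wait_info.push_back("stage0_send_done");

  std::cout << "record: " << producer.event_to_record_info
            << ", wait: " << consumer.events_to_wait_info.front() << "\n";
}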
@@ -70,6 +70,26 @@ LegacyKernelInstruction::LegacyKernelInstruction(
SetSchedulingPriority(1);
}
}
if (op_attributes.count("force_record_event") != 0) {
SetForceRecordEvent(op_attributes.at("force_record_event")
.dyn_cast<pir::BoolAttribute>()
.data());
}
if (op_attributes.count("event_to_record") != 0) {
SetEventToRecordInfo(op_attributes.at("event_to_record")
.dyn_cast<pir::StrAttribute>()
.AsString());
}
if (op_attributes.count("events_to_wait") != 0) {
std::vector<std::string> events_to_wait;
auto array_attr = op_attributes.at("events_to_wait")
.dyn_cast<pir::ArrayAttribute>()
.AsVector();
for (auto& attr : array_attr) {
events_to_wait.push_back(attr.dyn_cast<pir::StrAttribute>().AsString());
}
SetEventsToWaitInfo(events_to_wait);
}
VLOG(6) << "finish process dist attributes";

SetKernelType(AnalyseOpFuncType(op, place));
@@ -74,6 +74,26 @@ PhiKernelInstruction::PhiKernelInstruction(
SetSchedulingPriority(1);
}
}
if (op_attributes.count("force_record_event") != 0) {
SetForceRecordEvent(op_attributes.at("force_record_event")
.dyn_cast<pir::BoolAttribute>()
.data());
}
if (op_attributes.count("event_to_record") != 0) {
SetEventToRecordInfo(op_attributes.at("event_to_record")
.dyn_cast<pir::StrAttribute>()
.AsString());
}
if (op_attributes.count("events_to_wait") != 0) {
std::vector<std::string> events_to_wait;
auto array_attr = op_attributes.at("events_to_wait")
.dyn_cast<pir::ArrayAttribute>()
.AsVector();
for (auto& attr : array_attr) {
events_to_wait.push_back(attr.dyn_cast<pir::StrAttribute>().AsString());
}
SetEventsToWaitInfo(events_to_wait);
}
VLOG(6) << "finish process dist attributes";

SetKernelType(AnalyseOpFuncType(op, place));
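LegacyKernelInstruction and PhiKernelInstruction repeat the same attribute parsing. As a rough illustration only, the sketch below mirrors that pattern against a plain std::variant map; FakeAttr, FakeAttrMap, and ParseManualEventAttrs are hypothetical stand-ins for pir::BoolAttribute, pir::StrAttribute, and pir::ArrayAttribute lookups.

#include <iostream>
#include <map>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-in for pir attributes: a bool, a string, or a string list.
using FakeAttr = std::variant<bool, std::string, std::vector<std::string>>;
using FakeAttrMap = std::map<std::string, FakeAttr>;

// Mirrors the duplicated parsing logic: only act on keys that are present.
void ParseManualEventAttrs(const FakeAttrMap& attrs) {
  if (attrs.count("force_record_event") != 0) {
    bool force = std::get<bool>(attrs.at("force_record_event"));
    std::cout << "force_record_event = " << std::boolalpha << force << "\n";
  }
  if (attrs.count("event_to_record") != 0) {
    const auto& name = std::get<std::string>(attrs.at("event_to_record"));
    std::cout << "event_to_record = " << name << "\n";
  }
  if (attrs.count("events_to_wait") != 0) {
    const auto& waits =
        std::get<std::vector<std::string>>(attrs.at("events_to_wait"));
    for (const auto& w : waits) std::cout << "wait on " << w << "\n";
  }
}

int main() {
  FakeAttrMap attrs{{"force_record_event", true},
                    {"event_to_record", std::string("stage0_send_done")},
                    {"events_to_wait",
                     std::vector<std::string>{"stage1_recv_ready"}}};
  ParseManualEventAttrs(attrs);
}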
65 changes: 65 additions & 0 deletions paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc
@@ -718,6 +718,15 @@ void PirStreamAnalyzer::ConstructEvents(
platform::GenerateDeviceEventFlag());
recorder_instr->AddEventToRecord(device_event,
platform::kCUDA /*unused*/);
// If the event name of an operator is not 'default', the event will be
// waited on by other interpreters.
if (recorder_instr->IsForceRecordEvent() == true &&
(*program_force_events_to_wait_)
.count(recorder_instr->EventToRecordInfo()) == 0) {
(*program_force_events_to_wait_)[recorder_instr
->EventToRecordInfo()] =
recorder_instr->EventToRecord();
}
instr2event.emplace(recorder_instr_id, device_event);
}

@@ -729,6 +738,62 @@ void PirStreamAnalyzer::ConstructEvents(
}
}
}
// NOTE(lizhiyu): Manual events are only supported by the program interpreter
// to analyze the streams across the sub-programs. Construct manual events to
// record.
for (auto& instr : instructions) {
// create extra event to record
if (instr->IsForceRecordEvent() && instr->EventToRecord() == nullptr) {
auto place = instr->DeviceContext().GetPlace();
if (platform::is_gpu_place(place)) {
PADDLE_ENFORCE_NE(
instr->EventToRecordInfo(),
"default",
phi::errors::InvalidArgument(
"If the attribute 'force_record_event_' of one "
"operator is 'true', the 'event_to_record_' of this "
"operator can not be 'default'. But the "
"'event_name' of the operator %s is 'default'.",
instr->Name()));
PADDLE_ENFORCE_EQ(
(*program_force_events_to_wait_).find(instr->EventToRecordInfo()),
(*program_force_events_to_wait_).end(),
phi::errors::InvalidArgument(
"The program_force_events_to_wait_ had the event "
"that belongs to the operator : %s before the operator create "
"the event, "
"This is is werid.",
instr->Name()));
std::shared_ptr<DeviceEvent> device_event =
std::make_shared<DeviceEvent>(place,
platform::GenerateDeviceEventFlag());
instr->AddEventToRecord(device_event, platform::kCUDA /*unused*/);
(*program_force_events_to_wait_)[instr->EventToRecordInfo()] =
instr->EventToRecord();
VLOG(6) << "Create mannual event: " << instr->EventToRecordInfo()
<< " for the operator: " << instr->Name();
}
}
// Add extra manual events to wait on.
if (!(instr->EventsToWaitInfo().empty())) {
for (const auto& event_name : instr->EventsToWaitInfo()) {
PADDLE_ENFORCE_NE(
(*program_force_events_to_wait_).find(event_name),
(*program_force_events_to_wait_).end(),
phi::errors::InvalidArgument(
"The program_force_events_to_wait_ don't have the event %s "
"for the operator: %s to wait. The event should had been "
"created by the operator "
"whose event_to_record_ is %s.",
event_name.c_str(),
instr->Name(),
event_name.c_str()));

instr->AddEventToWait(
(*program_force_events_to_wait_)[event_name].get());
}
}
}
}

void PirStreamAnalyzer::AnalyseAllRunType(
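The NOTE(lizhiyu) comment in the hunk above describes a cross-sub-program handshake. Below is a hedged, self-contained sketch of that handshake with stand-in types; FakeEvent, FakeInstr, and EventMap are hypothetical, while the real code uses DeviceEvent/EventInter and program_force_events_to_wait_. The recording instruction registers its named event in the shared map, and any waiting instruction resolves each name from the same map.

#include <cassert>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

struct FakeEvent { std::string name; };
using EventMap = std::unordered_map<std::string, std::shared_ptr<FakeEvent>>;

struct FakeInstr {
  std::string op_name;
  bool force_record_event{false};
  std::string event_to_record_info{"default"};
  std::vector<std::string> events_to_wait_info;
  std::shared_ptr<FakeEvent> recorded;
  std::vector<std::shared_ptr<FakeEvent>> waits;
};

// Producer side of ConstructEvents: create and publish the named event.
void RegisterManualEvent(FakeInstr* instr, EventMap* shared) {
  if (!instr->force_record_event) return;
  assert(instr->event_to_record_info != "default");         // must be named
  assert(shared->count(instr->event_to_record_info) == 0);  // not registered yet
  instr->recorded =
      std::make_shared<FakeEvent>(FakeEvent{instr->event_to_record_info});
  (*shared)[instr->event_to_record_info] = instr->recorded;
}

// Consumer side: every waited-on name must already be in the shared map.
void ResolveManualWaits(FakeInstr* instr, const EventMap& shared) {
  for (const auto& name : instr->events_to_wait_info) {
    auto it = shared.find(name);
    assert(it != shared.end());
    instr->waits.push_back(it->second);
  }
}

int main() {
  EventMap shared;  // plays the role of program_force_events_to_wait_
  FakeInstr send{"pd_op.send", true, "stage0_send_done", {}, nullptr, {}};
  FakeInstr recv{"pd_op.recv", false, "default", {"stage0_send_done"}, nullptr, {}};
  RegisterManualEvent(&send, &shared);
  ResolveManualWaits(&recv, shared);
  std::cout << recv.op_name << " waits on " << recv.waits.front()->name << "\n";
}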
@@ -148,6 +148,12 @@ class PirStreamAnalyzer {

void ShareEventInfoFrom(const PirStreamAnalyzer& src);

void SetForceEventsToWaitInfo(
std::unordered_map<std::string, std::shared_ptr<EventInter>>*
program_force_events_to_wait) {
program_force_events_to_wait_ = program_force_events_to_wait;
}

std::shared_ptr<
std::map<const DeviceContext*, std::map<size_t, std::set<size_t>>>>
GetEventInfo() const;
@@ -174,6 +180,8 @@ class PirStreamAnalyzer {
std::shared_ptr<
std::map<const DeviceContext*, std::map<size_t, std::set<size_t>>>>
event_info_;
std::unordered_map<std::string, std::shared_ptr<EventInter>>*
program_force_events_to_wait_; // not owned
};

} // namespace interpreter
1 change: 1 addition & 0 deletions paddle/fluid/framework/new_executor/pir_interpreter.cc
@@ -1740,6 +1740,7 @@ void PirInterpreter::PreAnalysis() {
BuildInstructionDependences();
VLOG(4) << "Done BuildInstructionDependences";

ir_stream_analyzer_.SetForceEventsToWaitInfo(force_evnets_to_wait_);
ir_stream_analyzer_.ConstructEvents(vec_instruction_base_);
VLOG(4) << "Done ConstructEvents";

14 changes: 14 additions & 0 deletions paddle/fluid/framework/new_executor/pir_interpreter.h
@@ -107,6 +107,17 @@ class PirInterpreter : public InterpreterBaseImpl {
// Only for debug
Variable* DebugVar(const std::string& name) const override;

std::unordered_map<std::string, std::shared_ptr<EventInter>>*
GetForceEventsToWaitInfo() {
return force_evnets_to_wait_;
}

void SetForceEventsToWaitInfo(
std::unordered_map<std::string, std::shared_ptr<EventInter>>*
force_evnets_to_wait) {
force_evnets_to_wait_ = force_evnets_to_wait;
}

private:
// build graph
void UpdateSyncOpNum();
@@ -153,6 +164,9 @@

ExecutionConfig execution_config_;

std::unordered_map<std::string, std::shared_ptr<EventInter>>*
force_evnets_to_wait_;

VariableScope var_scope_;
Scope* scope_{nullptr};
Scope* local_scope_{nullptr}; // not owned
6 changes: 6 additions & 0 deletions paddle/fluid/framework/new_executor/standalone_executor.cc
@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/new_executor/standalone_executor.h"
#include "paddle/fluid/framework/new_executor/feed_fetch_utils.h"
#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
#include "paddle/fluid/framework/new_executor/pir_interpreter.h"
#include "paddle/fluid/framework/new_executor/program_interpreter.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
@@ -118,6 +119,11 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
shared_program->block(),
micro_batch_scopes_[micro_batch_id],
execution_config));
// Note(lizhiyu): Add manual event info.
auto pir_inter = const_cast<PirInterpreter*>(
static_cast<const PirInterpreter*>(interpretercores_.back()->Impl()));
pir_inter->SetForceEventsToWaitInfo(
&(vec_force_events_to_wait_[micro_batch_id]));
} else {
interpretercores_.emplace_back(
std::make_shared<InterpreterCore>(place_,
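A hedged sketch of the ownership pattern wired up in the hunk above, using hypothetical FakeExecutor and FakeInterpreter types: the executor owns one event map per micro-batch (analogous to vec_force_events_to_wait_), and each interpreter only borrows a raw, non-owning pointer to its map through a SetForceEventsToWaitInfo-style setter.

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

struct FakeEvent {};
using EventMap = std::unordered_map<std::string, std::shared_ptr<FakeEvent>>;

class FakeInterpreter {
 public:
  // Mirrors PirInterpreter::SetForceEventsToWaitInfo: the pointer is not owned.
  void SetForceEventsToWaitInfo(EventMap* events) { force_events_to_wait_ = events; }
  void Register(const std::string& name) {
    (*force_events_to_wait_)[name] = std::make_shared<FakeEvent>();
  }

 private:
  EventMap* force_events_to_wait_{nullptr};  // not owned
};

class FakeExecutor {
 public:
  explicit FakeExecutor(size_t micro_batch_num)
      : per_micro_batch_events_(micro_batch_num) {
    for (size_t i = 0; i < micro_batch_num; ++i) {
      interpreters_.emplace_back();
      // Each interpreter of micro-batch i shares the executor-owned map i.
      interpreters_.back().SetForceEventsToWaitInfo(&per_micro_batch_events_[i]);
    }
  }
  FakeInterpreter& interpreter(size_t i) { return interpreters_[i]; }
  size_t EventCount(size_t i) const { return per_micro_batch_events_[i].size(); }

 private:
  std::vector<EventMap> per_micro_batch_events_;  // owner of the event maps
  std::vector<FakeInterpreter> interpreters_;
};

int main() {
  FakeExecutor exec(/*micro_batch_num=*/2);
  exec.interpreter(0).Register("stage0_send_done");
  std::cout << "events in micro-batch 0: " << exec.EventCount(0) << "\n";  // 1
}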
12 changes: 12 additions & 0 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -664,6 +664,18 @@ static void TranslateOpDistAttribute(const OpDesc& op_desc,
"scheduling_priority", dist_attr->scheduling_priority());
(*attr_map)["scheduling_priority"] = new_attr;
}

pir::Attribute force_record_event_attr = attribute_translator(
"force_record_event", dist_attr->force_record_event());
(*attr_map)["force_record_event"] = force_record_event_attr;

pir::Attribute event_to_record_attr =
attribute_translator("event_to_record", dist_attr->event_to_record());
(*attr_map)["event_to_record"] = event_to_record_attr;

pir::Attribute events_to_wait_attr =
attribute_translator("events_to_wait", dist_attr->events_to_wait());
(*attr_map)["events_to_wait"] = events_to_wait_attr;
}
}

3 changes: 2 additions & 1 deletion paddle/fluid/pir/dialect/operator/ir/op_dialect.h
@@ -15,11 +15,12 @@
#pragma once

#include "paddle/pir/core/dialect.h"
#include "paddle/utils/test_macros.h"

namespace paddle {
namespace dialect {

class OperatorDialect : public pir::Dialect {
class TEST_API OperatorDialect : public pir::Dialect {
public:
explicit OperatorDialect(pir::IrContext* context);

10 changes: 6 additions & 4 deletions paddle/fluid/pir/dialect/operator/ir/op_type.h
@@ -18,16 +18,18 @@
#include "paddle/pir/core/builtin_type.h"
#include "paddle/pir/core/builtin_type_interfaces.h"
#include "paddle/pir/core/type.h"
#include "paddle/utils/test_macros.h"

namespace paddle {
namespace dialect {

using DenseTensorType = pir::DenseTensorType;

class SelectedRowsType : public pir::Type::TypeBase<SelectedRowsType,
pir::Type,
SelectedRowsTypeStorage,
pir::ShapedTypeInterface> {
class TEST_API SelectedRowsType
: public pir::Type::TypeBase<SelectedRowsType,
pir::Type,
SelectedRowsTypeStorage,
pir::ShapedTypeInterface> {
public:
using Base::Base;

@@ -16,10 +16,11 @@
#include "paddle/fluid/framework/variable.h"
#include "paddle/pir/core/dialect_interface.h"
#include "paddle/pir/core/parameter.h"
#include "paddle/utils/test_macros.h"

namespace paddle {
namespace dialect {
class ParameterConvertInterface
class TEST_API ParameterConvertInterface
: public pir::DialectInterface::Base<ParameterConvertInterface> {
public:
explicit ParameterConvertInterface(pir::Dialect* dialect) : Base(dialect) {}
36 changes: 36 additions & 0 deletions paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
@@ -65,6 +65,18 @@ pir::Type ConvertOpTypeToKernelType(pir::IrContext* ctx,
"Not support op type %s in ConvertOpTypeToKernelType.", op_type));
}

std::vector<int64_t> GetValueShape(const pir::Value& value) {
if (value.type().isa<DenseTensorType>()) {
return phi::vectorize(value.type().dyn_cast<DenseTensorType>().dims());
} else if (value.type().isa<SelectedRowsType>()) {
return phi::vectorize(value.type().dyn_cast<SelectedRowsType>().dims());
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Currently, we can only get shape for dense "
"tensor."));
}
}

std::unordered_map<std::string, phi::DataType> Str2PhiDataType = {
{"DataType::FLOAT16", phi::DataType::FLOAT16},
{"DataType::BFLOAT16", phi::DataType::BFLOAT16},
@@ -145,6 +157,30 @@ static bool NeedFallBackFromGPUDNN2GPU(pir::Operation* op,
.data() == true)) {
return true;
}
} else if ((op->isa<AffineGridOp>() || op->isa<AffineGridGradOp>()) &&
kernel_key.backend() == phi::Backend::GPUDNN) {
bool use_cudnn = true;
int version = -1;
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
version = platform::DnnVersion();
#endif
if (version >= 6000 && op->attributes()
.at("align_corners")
.dyn_cast<pir::BoolAttribute>()
.data() == true) {
use_cudnn = true;
} else {
use_cudnn = false;
}

auto shape = GetValueShape(op->operand_source(0));
if (shape[1] == 3) {
use_cudnn = false;
}
#if defined(PADDLE_WITH_HIP)
use_cudnn = false;
#endif
return !use_cudnn;
}
return false;
}
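A hedged sketch isolating the affine_grid fallback condition added above; UseCudnnForAffineGrid is a hypothetical name and only the combined boolean logic mirrors the hunk. cuDNN is kept only when the DNN version is at least 6000, align_corners is true, the second dimension of the first operand's shape is not 3, and the build is not HIP; otherwise the pass falls back from GPUDNN to plain GPU.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper; mirrors the fallback logic for AffineGridOp/AffineGridGradOp.
bool UseCudnnForAffineGrid(int dnn_version, bool align_corners,
                           const std::vector<int64_t>& input_shape,
                           bool is_hip_build) {
  bool use_cudnn = (dnn_version >= 6000) && align_corners;
  if (input_shape.size() > 1 && input_shape[1] == 3) use_cudnn = false;
  if (is_hip_build) use_cudnn = false;
  return use_cudnn;
}

int main() {
  // Channel dimension 2: cuDNN stays enabled, prints 1.
  std::cout << UseCudnnForAffineGrid(8600, true, {8, 2, 3}, false) << "\n";
  // Channel dimension 3: fall back to plain GPU, prints 0.
  std::cout << UseCudnnForAffineGrid(8600, true, {8, 3, 4}, false) << "\n";
}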
