Commit 7f37699

Merge branch 'develop' into i41
2 parents fa6a3bd + 4ad8416

26 files changed: +619 -58 lines changed


paddle/fluid/distributed/collective/deep_ep/include/event_pool.h

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ namespace deep_ep::detail {
 
 class EventPool {
  public:
-  EventPool() = default;
+  EventPool();
   EventPool(const EventPool&) = delete;
   EventPool(EventPool&&) = delete;
   ~EventPool();

paddle/fluid/distributed/collective/deep_ep/src/event_pool.cc

Lines changed: 10 additions & 0 deletions
@@ -22,6 +22,16 @@ EventPool &EventPool::Instance() {
   return pool;
 }
 
+EventPool::EventPool() {
+  for (size_t i = 0; i < 1000; ++i) {
+    cudaEvent_t new_event;
+    CUDA_CHECK(cudaEventCreate(&new_event));
+
+    cudaEventRecord(new_event, 0);
+    incomplished_events_.push(new_event);
+  }
+}
+
 EventPool::~EventPool() {
   const auto &DestroyEvent = [](cudaEvent_t event) {
     cudaError_t e = cudaEventDestroy(event);

paddle/fluid/eager/auto_code_generator/generator/python_c_gen.py

Lines changed: 1 addition & 1 deletion
@@ -503,7 +503,6 @@ def GeneratePythonCFunction(self, no_input_out_tensor=False):
         get_input_out_str = ""
         if (
             not no_input_out_tensor
-            and not forward_inplace_map
             and len(self.forward_outputs_position_map) == 1
             and next(iter(self.forward_outputs_position_map.values()))[0]
             == "Tensor"
@@ -573,6 +572,7 @@ def GeneratePythonCFunction(self, no_input_out_tensor=False):
                     namespace,
                     GetForwardFunctionName(inplaced_forward_api_name),
                 )
+                dygraph_function_call_str = ",".join(dygraph_function_call_list)
 
                 inplace_noamp_dygraph_function_str = (
                     NOAMP_DYGRAPH_FUNCTION_TEMPLATE.format(

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 33 additions & 14 deletions
@@ -2342,7 +2342,9 @@ PDNode *patterns::QuantConv::operator()(const std::string &conv_type) {
   auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op(conv_type);
   conv_op->assert_more([&](Node *node) {
     return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
-           "bfloat16";
+               "bfloat16" ||
+           node->Op()->GetAttrIfExists<std::string>("onednn_data_type") ==
+               "bfloat16";
   });
 
   quant_op->LinksFrom({quant_in}).LinksTo({conv_in});
@@ -3172,7 +3174,8 @@ PDNode *patterns::QuantizePlacement::operator()(
   auto *op =
       pattern->NewNode(op_repr())->assert_is_ops(quantize_enabled_op_types);
   op->assert_more([&](Node *node) {
-    return node->Op()->GetAttrIfExists<bool>("use_mkldnn");
+    return node->Op()->GetAttrIfExists<bool>("use_mkldnn") ||
+           node->Op()->GetAttrIfExists<bool>("use_onednn");
   });
   return op;
 }
@@ -3218,6 +3221,7 @@ PDNode *patterns::Bfloat16Placement::operator()(
   auto *op = pattern->NewNode(op_repr())->assert_is_ops(supported_op_types);
   op->assert_more([&](Node *node) {
     return node->Op()->GetAttrIfExists<bool>("use_mkldnn") ||
+           node->Op()->GetAttrIfExists<bool>("use_onednn") ||
            node->Op()->Type() == "reshape2";
   });
   op->LinksFrom({op_in});
@@ -3227,25 +3231,35 @@
 PDNode *patterns::OrphanedBfloat16::operator()() {
   auto *prev_op = pattern->NewNode(prev_op_repr())->assert_is_op();
   prev_op->assert_more([&](Node *node) {
-    bool data_type_is_missing = !node->Op()->HasAttr("mkldnn_data_type");
-    bool data_type_is_fp32 = node->Op()->GetAttrIfExists<std::string>(
-                                 "mkldnn_data_type") == "float32";
+    bool data_type_is_missing = !node->Op()->HasAttr("mkldnn_data_type") &&
+                                !node->Op()->HasAttr("onednn_data_type");
+    bool data_type_is_fp32 =
+        node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
+            "float32" ||
+        node->Op()->GetAttrIfExists<std::string>("onednn_data_type") ==
+            "float32";
     return data_type_is_missing || data_type_is_fp32;
   });
   auto *prev_out = pattern->NewNode(prev_out_repr())->AsOutput();
 
   auto *op = pattern->NewNode(op_repr())->assert_is_op();
   op->assert_more([&](Node *node) {
     return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
-           "bfloat16";
+               "bfloat16" ||
+           node->Op()->GetAttrIfExists<std::string>("onednn_data_type") ==
+               "bfloat16";
   });
   auto *op_out = pattern->NewNode(op_out_repr())->AsOutput();
 
   auto *next_op = pattern->NewNode(next_op_repr())->assert_is_op();
   next_op->assert_more([&](Node *node) {
-    bool data_type_is_missing = !node->Op()->HasAttr("mkldnn_data_type");
-    bool data_type_is_fp32 = node->Op()->GetAttrIfExists<std::string>(
-                                 "mkldnn_data_type") == "float32";
+    bool data_type_is_missing = !node->Op()->HasAttr("mkldnn_data_type") &&
+                                !node->Op()->HasAttr("onednn_data_type");
+    bool data_type_is_fp32 =
+        node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
+            "float32" ||
+        node->Op()->GetAttrIfExists<std::string>("onednn_data_type") ==
+            "float32";
     return data_type_is_missing || data_type_is_fp32;
   });
 
@@ -3258,14 +3272,17 @@ PDNode *patterns::OrphanedBfloat16::operator()() {
 PDNode *patterns::UnsupportedBfloat16::operator()() {
   auto *prev_op = pattern->NewNode(prev_op_repr())->assert_is_op();
   prev_op->assert_more([&](Node *node) {
-    return node->Op()->HasAttr("mkldnn_data_type") == false;
+    return node->Op()->HasAttr("mkldnn_data_type") == false &&
+           node->Op()->HasAttr("onednn_data_type") == false;
   });
   auto *prev_out = pattern->NewNode(prev_out_repr())->AsOutput();
 
   auto *op = pattern->NewNode(op_repr())->assert_is_op();
   op->assert_more([&](Node *node) {
     return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
-           "bfloat16";
+               "bfloat16" ||
+           node->Op()->GetAttrIfExists<std::string>("onednn_data_type") ==
+               "bfloat16";
   });
   prev_op->LinksTo({prev_out});
   op->LinksFrom({prev_out});
@@ -3276,7 +3293,9 @@ PDNode *patterns::Bloat16Ops::operator()() {
   auto op = pattern->NewNode(op_repr())->assert_is_op();
   op->assert_more([&](Node *node) {
     return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
-           "bfloat16";
+               "bfloat16" ||
+           node->Op()->GetAttrIfExists<std::string>("onednn_data_type") ==
+               "bfloat16";
   });
   return op;
 }
@@ -3298,8 +3317,8 @@ PDNode *patterns::ONEDNNInPlace::operator()() {
   auto next_op = pattern->NewNode(next_op_repr())->assert_is_op();
   auto next_output = pattern->NewNode(next_op_out_repr())->AsOutput();
 
-  // Check if op is MKL-DNN enabled
-  possible_inplace_op->assert_op_attr("use_mkldnn", true);
+  // Check if op is ONE-DNN enabled
+  possible_inplace_op->assert_op_attr_or("use_mkldnn", "use_onednn", true);
 
   // linked structure
   possible_inplace_op->LinksTo({output});

paddle/fluid/framework/ir/graph_pattern_detector.h

Lines changed: 14 additions & 0 deletions
@@ -168,6 +168,20 @@ struct PDNode {
     return this;
   }
 
+  template <typename T>
+  PDNode* assert_op_attr_or(const std::string& attr_name1,
+                            const std::string& attr_name2,
+                            const T& attr) {
+    asserts_.emplace_back([=](Node* x) {
+      return x && x->IsOp() &&
+             ((x->Op()->HasAttr(attr_name1) &&
+               PADDLE_GET_CONST(T, x->Op()->GetAttr(attr_name1)) == attr) ||
+              (x->Op()->HasAttr(attr_name2) &&
+               PADDLE_GET_CONST(T, x->Op()->GetAttr(attr_name2)) == attr));
+    });
+    return this;
+  }
+
  private:
   PDNode(PDPattern* pattern,
          const std::string& name = "",

paddle/phi/kernels/impl/slogdeterminant_grad_kernel_impl.h

Lines changed: 48 additions & 2 deletions
@@ -82,8 +82,54 @@ void SlogDeterminantGradKernel(const Context& dev_ctx,
   inverse_A.Resize(x.dims());
   dev_ctx.template Alloc<T>(&inverse_A);
 
-  phi::funcs::MatrixInverseFunctor<Context, T> mat_inv;
-  mat_inv(dev_ctx, x, &inverse_A);
+  const auto& mat_dims = x.dims();
+  const int rank = mat_dims.size();
+  int n = mat_dims[rank - 1];
+  int64_t total_batch_size = rank > 2 ? x.numel() / (n * n) : 1;
+
+  // Divide the batch into chunks because of cublasMatInv limitation
+  if (total_batch_size <= 65536) {
+    phi::funcs::MatrixInverseFunctor<Context, T> mat_inv;
+    mat_inv(dev_ctx, x, &inverse_A);
+  } else {
+    constexpr int64_t max_batch_size = 65536;
+    int64_t processed = 0;
+
+    VLOG(3) << "Large batch size detected (" << total_batch_size
+            << "), processing in chunks of " << max_batch_size;
+
+    while (processed < total_batch_size) {
+      int64_t current_batch =
+          std::min(max_batch_size, total_batch_size - processed);
+
+      // Extract current batch data
+      DenseTensor x_batch;
+      x_batch.ShareDataWith(x);
+      x_batch.Resize({total_batch_size, n, n});
+      x_batch = x_batch.Slice(processed, processed + current_batch);
+      x_batch.Resize({current_batch, n, n});
+
+      DenseTensor inverse_batch;
+      inverse_batch.Resize({current_batch, n, n});
+      dev_ctx.template Alloc<T>(&inverse_batch);
+
+      // Compute the inverse matrix for the current batch
+      phi::funcs::MatrixInverseFunctor<Context, T> mat_inv;
+      mat_inv(dev_ctx, x_batch, &inverse_batch);
+
+      // Copy the result to the output tensor
+      DenseTensor output_slice;
+      output_slice.ShareDataWith(inverse_A);
+      output_slice.Resize({total_batch_size, n, n});
+      output_slice = output_slice.Slice(processed, processed + current_batch);
+      output_slice.Resize({current_batch, n, n});
+
+      phi::Copy(
+          dev_ctx, inverse_batch, dev_ctx.GetPlace(), false, &output_slice);
+
+      processed += current_batch;
+    }
+  }
 
   VLOG(3) << "inverse(A) dims: " << inverse_A.dims();
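
Not part of the diff: a minimal Python sketch of a workload that would reach the chunked path above, i.e. a slogdet backward over more than 65536 square matrices. It assumes a CUDA build of Paddle and that indexing the paddle.linalg.slogdet result at [1] yields log|det| per matrix.

import paddle

# Assumes a CUDA build; the cublasMatInv batch limit only applies on GPU.
paddle.device.set_device('gpu')

# 70000 2x2 matrices: total_batch_size = 70000 > 65536, so the kernel above
# would invert them in two chunks (65536 + 4464).
x = paddle.randn([70000, 2, 2], dtype='float32')
x.stop_gradient = False

out = paddle.linalg.slogdet(x)   # assumption: out[1] holds log|det| per matrix
out[1].sum().backward()          # backward runs SlogDeterminantGradKernel

print(x.grad.shape)              # [70000, 2, 2]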

paddle/phi/ops/yaml/op_compat.yaml

Lines changed: 1 addition & 1 deletion
@@ -1765,7 +1765,7 @@
   attrs :
     {scale_data : Scale_data, shift_data : Shift_data, scale_weights : Scale_weights}
   extra :
-    attrs : [bool use_mkldnn = true, bool use_onednn = false, str mkldnn_data_type = "float32"]
+    attrs : [bool use_mkldnn = false, bool use_onednn = false, str mkldnn_data_type = "float32"]
 
 - op : fusion_repeated_fc_relu
   inputs :

python/paddle/__init__.py

Lines changed: 10 additions & 0 deletions
@@ -129,6 +129,12 @@
     tensor as tensor,
     utils as utils,
 )
+from .amp import (
+    get_autocast_cpu_dtype,
+    get_autocast_dtype,
+    get_autocast_gpu_dtype,
+    is_autocast_enabled,
+)
 from .autograd import (
     enable_grad,
     grad,
@@ -1233,6 +1239,10 @@
     'nan',
     'pi',
     'e',
+    'is_autocast_enabled',
+    'get_autocast_dtype',
+    'get_autocast_cpu_dtype',
+    'get_autocast_gpu_dtype',
 ]
 
 import os

python/paddle/amp/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -33,6 +33,8 @@
     amp_guard,
     auto_cast,
     decorate,
+    get_autocast_dtype,
+    is_autocast_enabled,
 )
 from .grad_scaler import (  # noqa: F401
     AmpScaler,
@@ -46,8 +48,15 @@
     'decorate',
     'is_float16_supported',
     'is_bfloat16_supported',
+    'is_autocast_enabled',
+    'get_autocast_dtype',
+    'get_autocast_cpu_dtype',
+    'get_autocast_gpu_dtype',
 ]
 
+get_autocast_cpu_dtype = get_autocast_dtype
+get_autocast_gpu_dtype = get_autocast_dtype
+
 
 def is_float16_supported(device: str | None = None) -> bool:
     """

python/paddle/amp/auto_cast.py

Lines changed: 71 additions & 0 deletions
@@ -48,6 +48,7 @@
     from typing_extensions import TypeAlias, TypeGuard
 
     from paddle import Tensor
+    from paddle._typing import PlaceLike
     from paddle._typing.dtype_like import _DTypeLiteral
     from paddle.nn import Layer
     from paddle.nn.layer.layers import _StateDict
@@ -1322,3 +1323,73 @@ def decorate(
         master_grad,
         excluded_layers,
     )
+
+
+def is_autocast_enabled(device_type: PlaceLike | None = None) -> bool:
+    """
+    Check whether auto-mixed-precision is enabled in the current context.
+
+    Args:
+        device_type (PlaceLike, optional): The device type to check. This argument is ignored because all devices share the same AMP state in PaddlePaddle.
+
+    Returns:
+        bool: True if auto-mixed-precision is enabled, False otherwise.
+
+    Examples:
+        .. code-block:: python
+
+            >>> # doctest: +REQUIRES(env:GPU)
+            >>> # Demo1: Check if auto-mixed-precision is enabled by default
+            >>> import paddle
+            >>> paddle.device.set_device('gpu')
+            >>> print(paddle.is_autocast_enabled())
+            False
+
+            >>> # Demo2: Enable auto-mixed-precision and check again
+            >>> with paddle.amp.auto_cast():
+            ...     print(paddle.is_autocast_enabled())
+            True
+    """
+    if in_pir_mode():
+        amp_attrs = core._get_amp_attrs()
+        return amp_attrs._amp_level != AMP_LEVEL.O0
+    else:
+        tracer = _dygraph_tracer()
+        if tracer:
+            return tracer._amp_level != core.AmpLevel.O0
+        return False
+
+
+def get_autocast_dtype(device_type: PlaceLike | None = None) -> _DTypeLiteral:
+    """
+    Get the auto-mixed-precision dtype in the current context if autocast is enabled; otherwise return the default AMP dtype (float16).
+
+    Args:
+        device_type (PlaceLike, optional): The device type to check. This argument is ignored because all devices share the same AMP state in PaddlePaddle.
+
+    Returns:
+        _DTypeLiteral: The current AMP dtype.
+
+    Examples:
+        .. code-block:: python
+
+            >>> # doctest: +REQUIRES(env:GPU)
+            >>> # Demo1: Get the default auto-mixed-precision dtype
+            >>> import paddle
+            >>> paddle.device.set_device('gpu')
+            >>> print(paddle.get_autocast_dtype())
+            float16
+
+            >>> # Demo2: Enable auto-mixed-precision and get the dtype
+            >>> with paddle.amp.auto_cast():
+            ...     print(paddle.get_autocast_dtype())
+            float16
+    """
+    if not is_autocast_enabled():
+        return "float16"
+    if in_pir_mode():
+        amp_attrs = core._get_amp_attrs()
+        return amp_attrs._amp_dtype
+    else:
+        tracer = _dygraph_tracer()
+        return tracer._amp_dtype
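
Not part of the diff: a short usage sketch of the query helpers added above, including the get_autocast_cpu_dtype / get_autocast_gpu_dtype aliases defined in python/paddle/amp/__init__.py. It mirrors the docstring examples and assumes a GPU build.

import paddle

paddle.device.set_device('gpu')

print(paddle.is_autocast_enabled())   # False outside any auto_cast context
print(paddle.get_autocast_dtype())    # float16, the default AMP dtype

with paddle.amp.auto_cast():
    print(paddle.is_autocast_enabled())  # True
    print(paddle.get_autocast_dtype())   # float16 under the default O1 setting
    # The cpu/gpu variants are plain aliases of get_autocast_dtype.
    print(paddle.get_autocast_gpu_dtype() == paddle.get_autocast_dtype())  # True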
