@@ -820,6 +820,43 @@ static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
     indices_int64.push_back(indice);
   }
 
+  // AMP Logic
+  if (egr::Controller::Instance().GetAMPLevel() !=
+      paddle::imperative::AmpLevel::O0) {
+    auto op_name = phi::TransToFluidOpName("index_elementwise_get");
+    paddle::small_vector<std::vector<paddle::Tensor>,
+                         egr::kSlotSmallVectorSize>
+        amp_tensors_vector = {{self_tensor}};
+
+    auto amp_dst_dtype =
+        paddle::imperative::GetAmpDestDtype(op_name, amp_tensors_vector);
+
+    auto new_self_tensor = paddle::imperative::AmpAutoCast(
+        "self_tensor", self_tensor, amp_dst_dtype, op_name);
+    auto new_tensor = paddle::imperative::AmpAutoCast(
+        "tensor", tensor, amp_dst_dtype, op_name);
+
+    {
+      paddle::imperative::AutoCastGuard guard(
+          egr::Controller::Instance().GetCurrentAmpAttrs(),
+          paddle::imperative::AmpLevel::O0);
+
+      AdvancedIndex ad = AdvancedIndex(new_tensor, indices_int64);
+      const bool is_combined = false;
+      const bool accumulate = false;
+
+      return index_elementwise_get_ad_func(new_self_tensor,
+                                           ad.indices,
+                                           ad.src_sizes,
+                                           ad.src_strides,
+                                           ad.indexed_sizes,
+                                           ad.indexed_strides,
+                                           slice_offset,
+                                           accumulate,
+                                           is_combined);
+    }
+  }
+
   AdvancedIndex ad = AdvancedIndex(tensor, indices_int64);
   const bool is_combined = false;
   const bool accumulate = false;
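A note for readers unfamiliar with this pattern: each input is cast once with `AmpAutoCast` at the op boundary, and the RAII `AutoCastGuard` then pins the AMP level to O0 so that nothing inside the scope, including `index_elementwise_get_ad_func`, re-casts the already-converted tensors. Below is a minimal standalone sketch of that guard pattern; `AmpLevel`, `AutoCastGuard`, `AmpAutoCast`, and `Kernel` here are simplified mocks, not Paddle's real types.

```cpp
#include <iostream>
#include <string>

// Simplified stand-in for paddle::imperative::AmpLevel.
enum class AmpLevel { O0, O1 };  // O0 = autocast off, O1 = mixed precision

thread_local AmpLevel amp_level = AmpLevel::O1;

// RAII guard: saves the current level, sets a new one, restores on scope exit.
struct AutoCastGuard {
  AmpLevel saved;
  explicit AutoCastGuard(AmpLevel level) : saved(amp_level) {
    amp_level = level;
  }
  ~AutoCastGuard() { amp_level = saved; }
};

// Mock of AmpAutoCast: casts only while autocast is active.
std::string AmpAutoCast(const std::string& name, const std::string& dtype,
                        const std::string& dst_dtype) {
  if (amp_level != AmpLevel::O0 && dtype != dst_dtype) {
    std::cout << "cast " << name << ": " << dtype << " -> " << dst_dtype
              << "\n";
    return dst_dtype;
  }
  return dtype;
}

// Inside the guard amp_level is O0, so this nested autocast is a no-op
// and the already-cast input passes through unchanged.
std::string Kernel(const std::string& dtype) {
  return AmpAutoCast("kernel_input", dtype, "float32");
}

int main() {
  std::string x_dtype = "float32";
  if (amp_level != AmpLevel::O0) {
    auto cast_dtype = AmpAutoCast("x", x_dtype, "float16");  // cast once
    AutoCastGuard guard(AmpLevel::O0);  // disable autocast inside the scope
    std::cout << "kernel ran in " << Kernel(cast_dtype) << "\n";
  }
  return 0;
}
```

Running the sketch prints one cast at the boundary and none inside `Kernel`, which is exactly what the guard buys in the hunk above.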
@@ -1287,6 +1324,45 @@ static void ApplyGetitem(const int index_size,
                      transed_tensor,
                      &transed_index_int64);
 
+  // AMP Logic
+  if (egr::Controller::Instance().GetAMPLevel() !=
+      paddle::imperative::AmpLevel::O0) {
+    auto op_name = phi::TransToFluidOpName("index_elementwise_get");
+    paddle::small_vector<std::vector<paddle::Tensor>,
+                         egr::kSlotSmallVectorSize>
+        amp_tensors_vector = {{*self_tensor}};
+
+    auto amp_dst_dtype =
+        paddle::imperative::GetAmpDestDtype(op_name, amp_tensors_vector);
+
+    auto new_self_tensor = paddle::imperative::AmpAutoCast(
+        "self_tensor", *self_tensor, amp_dst_dtype, op_name);
+    auto new_transed_tensor = paddle::imperative::AmpAutoCast(
+        "transed_tensor", *transed_tensor, amp_dst_dtype, op_name);
+
+    {
+      paddle::imperative::AutoCastGuard guard(
+          egr::Controller::Instance().GetCurrentAmpAttrs(),
+          paddle::imperative::AmpLevel::O0);
+
+      AdvancedIndex ad =
+          AdvancedIndex(new_transed_tensor, transed_index_int64);
+
+      const bool is_combined = (index_size == 1) ? false : true;
+      const bool accumulate = true;
+      *out = index_elementwise_get_ad_func(new_self_tensor,
+                                           ad.indices,
+                                           ad.src_sizes,
+                                           ad.src_strides,
+                                           ad.indexed_sizes,
+                                           ad.indexed_strides,
+                                           slice_offset,
+                                           accumulate,
+                                           is_combined);
+    }
+    return;
+  }
+
   AdvancedIndex ad = AdvancedIndex(*transed_tensor, transed_index_int64);
   // is_combined:
   // Distinguishes between regular indexing (single index) and combined
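Both hunks take the destination dtype from `GetAmpDestDtype(op_name, amp_tensors_vector)`, which consults the op name and the collected input tensors. As a rough sketch of how such a helper can decide (a hypothetical policy for illustration only, not Paddle's actual allow/block lists or implementation):

```cpp
#include <iostream>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for GetAmpDestDtype: ops on an allow-list may run
// in float16; unknown or block-listed ops stay in float32. The lists and
// fallback policy here are invented for illustration.
std::string GetAmpDestDtypeMock(const std::string& op_name,
                                const std::vector<std::string>& input_dtypes) {
  static const std::set<std::string> allow_list = {"matmul", "conv2d"};
  if (allow_list.count(op_name) == 0) {
    return "float32";  // conservative fallback for numerical safety
  }
  for (const auto& dtype : input_dtypes) {
    if (dtype == "float64") {
      return "float64";  // never down-cast double-precision inputs
    }
  }
  return "float16";
}

int main() {
  std::cout << GetAmpDestDtypeMock("matmul", {"float16", "float32"}) << "\n";
  std::cout << GetAmpDestDtypeMock("index_elementwise_get", {"float16"})
            << "\n";
  return 0;
}
```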