From 0c2630be22610b4bc79305f73759ac5334ded540 Mon Sep 17 00:00:00 2001 From: Nikita Popov Date: Thu, 19 Jan 2023 15:55:26 +0100 Subject: [PATCH 01/61] [InstCombine] Drop incorrect test (NFC) I made a typo here, this was supposed to be !align rather than !aligned. But then !align can only be applied to loads, not calls (where one would use the return attribute instead). And freeze can't be pushed through loads anyway, so there's no way to test this case (same as !nonnull). --- llvm/test/Transforms/InstCombine/freeze.ll | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/llvm/test/Transforms/InstCombine/freeze.ll b/llvm/test/Transforms/InstCombine/freeze.ll index 5e87dd7c7a320a..21a89d873b92bd 100644 --- a/llvm/test/Transforms/InstCombine/freeze.ll +++ b/llvm/test/Transforms/InstCombine/freeze.ll @@ -1115,19 +1115,6 @@ define i32 @freeze_ctpop(i32 %x) { ret i32 %fr } -declare ptr @llvm.ptrmask.p0.i64(ptr, i64) - -define ptr @freeze_ptrmask_align(ptr %p, i64 noundef %m) { -; CHECK-LABEL: @freeze_ptrmask_align( -; CHECK-NEXT: [[P_FR:%.*]] = freeze ptr [[P:%.*]] -; CHECK-NEXT: [[MASK:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P_FR]], i64 [[M:%.*]]), !aligned !1 -; CHECK-NEXT: ret ptr [[MASK]] -; - %mask = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 %m), !aligned !{i64 4} - %fr = freeze ptr %mask - ret ptr %fr -} - !0 = !{} !1 = !{i64 4} !2 = !{i32 0, i32 100} From adcc90aefeabee3bd659520c754bd86db6e1d389 Mon Sep 17 00:00:00 2001 From: Nikita Popov Date: Thu, 19 Jan 2023 16:05:45 +0100 Subject: [PATCH 02/61] [InstCombine] Add non-logical variants for some ctpop and/or tests (NFC) And regenerate test checks to pick up new names. 
--- llvm/test/Transforms/InstCombine/ispow2.ll | 166 ++++++++++++--------- 1 file changed, 98 insertions(+), 68 deletions(-) diff --git a/llvm/test/Transforms/InstCombine/ispow2.ll b/llvm/test/Transforms/InstCombine/ispow2.ll index cbc247f3c38171..c9ffd96e393830 100644 --- a/llvm/test/Transforms/InstCombine/ispow2.ll +++ b/llvm/test/Transforms/InstCombine/ispow2.ll @@ -181,8 +181,8 @@ declare <2 x i8> @llvm.ctpop.v2i8(<2 x i8>) define i1 @is_pow2_ctpop(i32 %x) { ; CHECK-LABEL: @is_pow2_ctpop( ; CHECK-NEXT: [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]] -; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i32 [[T0]], 1 -; CHECK-NEXT: ret i1 [[TMP1]] +; CHECK-NEXT: [[R:%.*]] = icmp eq i32 [[T0]], 1 +; CHECK-NEXT: ret i1 [[R]] ; %t0 = tail call i32 @llvm.ctpop.i32(i32 %x) %cmp = icmp ult i32 %t0, 2 @@ -194,8 +194,8 @@ define i1 @is_pow2_ctpop(i32 %x) { define i1 @is_pow2_ctpop_logical(i32 %x) { ; CHECK-LABEL: @is_pow2_ctpop_logical( ; CHECK-NEXT: [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]] -; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i32 [[T0]], 1 -; CHECK-NEXT: ret i1 [[TMP1]] +; CHECK-NEXT: [[R:%.*]] = icmp eq i32 [[T0]], 1 +; CHECK-NEXT: ret i1 [[R]] ; %t0 = tail call i32 @llvm.ctpop.i32(i32 %x) %cmp = icmp ult i32 %t0, 2 @@ -214,8 +214,8 @@ define i1 @is_pow2_ctpop_extra_uses(i32 %x) { ; CHECK-NEXT: call void @use_i1(i1 [[CMP]]) ; CHECK-NEXT: [[NOTZERO:%.*]] = icmp ne i32 [[X]], 0 ; CHECK-NEXT: call void @use_i1(i1 [[NOTZERO]]) -; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i32 [[T0]], 1 -; CHECK-NEXT: ret i1 [[TMP1]] +; CHECK-NEXT: [[R:%.*]] = icmp eq i32 [[T0]], 1 +; CHECK-NEXT: ret i1 [[R]] ; %t0 = tail call i32 @llvm.ctpop.i32(i32 %x) %cmp = icmp ult i32 %t0, 2 @@ -233,8 +233,8 @@ define i1 @is_pow2_ctpop_extra_uses_logical(i32 %x) { ; CHECK-NEXT: call void @use_i1(i1 [[CMP]]) ; CHECK-NEXT: [[NOTZERO:%.*]] = icmp ne i32 [[X]], 0 ; CHECK-NEXT: call void @use_i1(i1 [[NOTZERO]]) -; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i32 [[T0]], 1 -; 
CHECK-NEXT: ret i1 [[TMP1]] +; CHECK-NEXT: [[R:%.*]] = icmp eq i32 [[T0]], 1 +; CHECK-NEXT: ret i1 [[R]] ; %t0 = tail call i32 @llvm.ctpop.i32(i32 %x) %cmp = icmp ult i32 %t0, 2 @@ -250,8 +250,8 @@ define i1 @is_pow2_ctpop_extra_uses_logical(i32 %x) { define <2 x i1> @is_pow2_ctpop_commute_vec(<2 x i8> %x) { ; CHECK-LABEL: @is_pow2_ctpop_commute_vec( ; CHECK-NEXT: [[T0:%.*]] = tail call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> [[X:%.*]]) -; CHECK-NEXT: [[TMP1:%.*]] = icmp eq <2 x i8> [[T0]], -; CHECK-NEXT: ret <2 x i1> [[TMP1]] +; CHECK-NEXT: [[R:%.*]] = icmp eq <2 x i8> [[T0]], +; CHECK-NEXT: ret <2 x i1> [[R]] ; %t0 = tail call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> %x) %cmp = icmp ult <2 x i8> %t0, @@ -393,8 +393,8 @@ define i1 @is_pow2_ctpop_wrong_pred2_logical(i32 %x) { define i1 @isnot_pow2_ctpop(i32 %x) { ; CHECK-LABEL: @isnot_pow2_ctpop( ; CHECK-NEXT: [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]] -; CHECK-NEXT: [[TMP1:%.*]] = icmp ne i32 [[T0]], 1 -; CHECK-NEXT: ret i1 [[TMP1]] +; CHECK-NEXT: [[R:%.*]] = icmp ne i32 [[T0]], 1 +; CHECK-NEXT: ret i1 [[R]] ; %t0 = tail call i32 @llvm.ctpop.i32(i32 %x) %cmp = icmp ugt i32 %t0, 1 @@ -406,8 +406,8 @@ define i1 @isnot_pow2_ctpop(i32 %x) { define i1 @isnot_pow2_ctpop_logical(i32 %x) { ; CHECK-LABEL: @isnot_pow2_ctpop_logical( ; CHECK-NEXT: [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]] -; CHECK-NEXT: [[TMP1:%.*]] = icmp ne i32 [[T0]], 1 -; CHECK-NEXT: ret i1 [[TMP1]] +; CHECK-NEXT: [[R:%.*]] = icmp ne i32 [[T0]], 1 +; CHECK-NEXT: ret i1 [[R]] ; %t0 = tail call i32 @llvm.ctpop.i32(i32 %x) %cmp = icmp ugt i32 %t0, 1 @@ -425,8 +425,8 @@ define i1 @isnot_pow2_ctpop_extra_uses(i32 %x) { ; CHECK-NEXT: call void @use_i1(i1 [[CMP]]) ; CHECK-NEXT: [[ISZERO:%.*]] = icmp eq i32 [[X]], 0 ; CHECK-NEXT: call void @use_i1(i1 [[ISZERO]]) -; CHECK-NEXT: [[TMP1:%.*]] = icmp ne i32 [[T0]], 1 -; CHECK-NEXT: ret i1 [[TMP1]] +; CHECK-NEXT: [[R:%.*]] = icmp ne i32 [[T0]], 1 +; CHECK-NEXT: 
ret i1 [[R]] ; %t0 = tail call i32 @llvm.ctpop.i32(i32 %x) %cmp = icmp ugt i32 %t0, 1 @@ -444,8 +444,8 @@ define i1 @isnot_pow2_ctpop_extra_uses_logical(i32 %x) { ; CHECK-NEXT: call void @use_i1(i1 [[CMP]]) ; CHECK-NEXT: [[ISZERO:%.*]] = icmp eq i32 [[X]], 0 ; CHECK-NEXT: call void @use_i1(i1 [[ISZERO]]) -; CHECK-NEXT: [[TMP1:%.*]] = icmp ne i32 [[T0]], 1 -; CHECK-NEXT: ret i1 [[TMP1]] +; CHECK-NEXT: [[R:%.*]] = icmp ne i32 [[T0]], 1 +; CHECK-NEXT: ret i1 [[R]] ; %t0 = tail call i32 @llvm.ctpop.i32(i32 %x) %cmp = icmp ugt i32 %t0, 1 @@ -461,8 +461,8 @@ define i1 @isnot_pow2_ctpop_extra_uses_logical(i32 %x) { define <2 x i1> @isnot_pow2_ctpop_commute_vec(<2 x i8> %x) { ; CHECK-LABEL: @isnot_pow2_ctpop_commute_vec( ; CHECK-NEXT: [[T0:%.*]] = tail call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> [[X:%.*]]) -; CHECK-NEXT: [[TMP1:%.*]] = icmp ne <2 x i8> [[T0]], -; CHECK-NEXT: ret <2 x i1> [[TMP1]] +; CHECK-NEXT: [[R:%.*]] = icmp ne <2 x i8> [[T0]], +; CHECK-NEXT: ret <2 x i1> [[R]] ; %t0 = tail call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> %x) %cmp = icmp ugt <2 x i8> %t0, @@ -570,8 +570,8 @@ define i1 @isnot_pow2_ctpop_wrong_pred2_logical(i32 %x) { define i1 @is_pow2_negate_op(i32 %x) { ; CHECK-LABEL: @is_pow2_negate_op( ; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]] -; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i32 [[TMP1]], 1 -; CHECK-NEXT: ret i1 [[TMP2]] +; CHECK-NEXT: [[R:%.*]] = icmp eq i32 [[TMP1]], 1 +; CHECK-NEXT: ret i1 [[R]] ; %neg = sub i32 0, %x %and = and i32 %neg, %x @@ -584,8 +584,8 @@ define i1 @is_pow2_negate_op(i32 %x) { define i1 @is_pow2_negate_op_logical(i32 %x) { ; CHECK-LABEL: @is_pow2_negate_op_logical( ; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]] -; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i32 [[TMP1]], 1 -; CHECK-NEXT: ret i1 [[TMP2]] +; CHECK-NEXT: [[R:%.*]] = icmp eq i32 [[TMP1]], 1 +; CHECK-NEXT: ret i1 [[R]] ; %neg = sub i32 0, %x %and = and i32 %neg, %x @@ -598,8 +598,8 @@ define i1 
@is_pow2_negate_op_logical(i32 %x) { define <2 x i1> @is_pow2_negate_op_vec(<2 x i32> %x) { ; CHECK-LABEL: @is_pow2_negate_op_vec( ; CHECK-NEXT: [[TMP1:%.*]] = call <2 x i32> @llvm.ctpop.v2i32(<2 x i32> [[X:%.*]]) -; CHECK-NEXT: [[TMP2:%.*]] = icmp eq <2 x i32> [[TMP1]], -; CHECK-NEXT: ret <2 x i1> [[TMP2]] +; CHECK-NEXT: [[R:%.*]] = icmp eq <2 x i32> [[TMP1]], +; CHECK-NEXT: ret <2 x i1> [[R]] ; %neg = sub <2 x i32> zeroinitializer, %x %and = and <2 x i32> %neg, %x @@ -612,8 +612,8 @@ define <2 x i1> @is_pow2_negate_op_vec(<2 x i32> %x) { define i1 @is_pow2_decrement_op(i8 %x) { ; CHECK-LABEL: @is_pow2_decrement_op( ; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.ctpop.i8(i8 [[X:%.*]]), !range [[RNG1]] -; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[TMP1]], 1 -; CHECK-NEXT: ret i1 [[TMP2]] +; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[TMP1]], 1 +; CHECK-NEXT: ret i1 [[R]] ; %dec = add i8 %x, -1 %and = and i8 %dec, %x @@ -626,8 +626,8 @@ define i1 @is_pow2_decrement_op(i8 %x) { define i1 @is_pow2_decrement_op_logical(i8 %x) { ; CHECK-LABEL: @is_pow2_decrement_op_logical( ; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.ctpop.i8(i8 [[X:%.*]]), !range [[RNG1]] -; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[TMP1]], 1 -; CHECK-NEXT: ret i1 [[TMP2]] +; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[TMP1]], 1 +; CHECK-NEXT: ret i1 [[R]] ; %dec = add i8 %x, -1 %and = and i8 %dec, %x @@ -640,8 +640,8 @@ define i1 @is_pow2_decrement_op_logical(i8 %x) { define <2 x i1> @is_pow2_decrement_op_vec(<2 x i8> %x) { ; CHECK-LABEL: @is_pow2_decrement_op_vec( ; CHECK-NEXT: [[TMP1:%.*]] = call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> [[X:%.*]]) -; CHECK-NEXT: [[TMP2:%.*]] = icmp eq <2 x i8> [[TMP1]], -; CHECK-NEXT: ret <2 x i1> [[TMP2]] +; CHECK-NEXT: [[R:%.*]] = icmp eq <2 x i8> [[TMP1]], +; CHECK-NEXT: ret <2 x i1> [[R]] ; %dec = add <2 x i8> %x, %and = and <2 x i8> %dec, %x @@ -654,8 +654,8 @@ define <2 x i1> @is_pow2_decrement_op_vec(<2 x i8> %x) { define i1 @isnot_pow2_negate_op(i32 %x) { ; CHECK-LABEL: 
@isnot_pow2_negate_op( ; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]] -; CHECK-NEXT: [[TMP2:%.*]] = icmp ne i32 [[TMP1]], 1 -; CHECK-NEXT: ret i1 [[TMP2]] +; CHECK-NEXT: [[R:%.*]] = icmp ne i32 [[TMP1]], 1 +; CHECK-NEXT: ret i1 [[R]] ; %neg = sub i32 0, %x %and = and i32 %neg, %x @@ -668,8 +668,8 @@ define i1 @isnot_pow2_negate_op(i32 %x) { define i1 @isnot_pow2_negate_op_logical(i32 %x) { ; CHECK-LABEL: @isnot_pow2_negate_op_logical( ; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]] -; CHECK-NEXT: [[TMP2:%.*]] = icmp ne i32 [[TMP1]], 1 -; CHECK-NEXT: ret i1 [[TMP2]] +; CHECK-NEXT: [[R:%.*]] = icmp ne i32 [[TMP1]], 1 +; CHECK-NEXT: ret i1 [[R]] ; %neg = sub i32 0, %x %and = and i32 %neg, %x @@ -682,8 +682,8 @@ define i1 @isnot_pow2_negate_op_logical(i32 %x) { define <2 x i1> @isnot_pow2_negate_op_vec(<2 x i32> %x) { ; CHECK-LABEL: @isnot_pow2_negate_op_vec( ; CHECK-NEXT: [[TMP1:%.*]] = call <2 x i32> @llvm.ctpop.v2i32(<2 x i32> [[X:%.*]]) -; CHECK-NEXT: [[TMP2:%.*]] = icmp ne <2 x i32> [[TMP1]], -; CHECK-NEXT: ret <2 x i1> [[TMP2]] +; CHECK-NEXT: [[R:%.*]] = icmp ne <2 x i32> [[TMP1]], +; CHECK-NEXT: ret <2 x i1> [[R]] ; %neg = sub <2 x i32> zeroinitializer, %x %and = and <2 x i32> %neg, %x @@ -696,8 +696,8 @@ define <2 x i1> @isnot_pow2_negate_op_vec(<2 x i32> %x) { define i1 @isnot_pow2_decrement_op(i8 %x) { ; CHECK-LABEL: @isnot_pow2_decrement_op( ; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.ctpop.i8(i8 [[X:%.*]]), !range [[RNG1]] -; CHECK-NEXT: [[TMP2:%.*]] = icmp ne i8 [[TMP1]], 1 -; CHECK-NEXT: ret i1 [[TMP2]] +; CHECK-NEXT: [[R:%.*]] = icmp ne i8 [[TMP1]], 1 +; CHECK-NEXT: ret i1 [[R]] ; %dec = add i8 %x, -1 %and = and i8 %dec, %x @@ -710,8 +710,8 @@ define i1 @isnot_pow2_decrement_op(i8 %x) { define i1 @isnot_pow2_decrement_op_logical(i8 %x) { ; CHECK-LABEL: @isnot_pow2_decrement_op_logical( ; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.ctpop.i8(i8 [[X:%.*]]), !range [[RNG1]] -; 
CHECK-NEXT: [[TMP2:%.*]] = icmp ne i8 [[TMP1]], 1 -; CHECK-NEXT: ret i1 [[TMP2]] +; CHECK-NEXT: [[R:%.*]] = icmp ne i8 [[TMP1]], 1 +; CHECK-NEXT: ret i1 [[R]] ; %dec = add i8 %x, -1 %and = and i8 %dec, %x @@ -724,8 +724,8 @@ define i1 @isnot_pow2_decrement_op_logical(i8 %x) { define <2 x i1> @isnot_pow2_decrement_op_vec(<2 x i8> %x) { ; CHECK-LABEL: @isnot_pow2_decrement_op_vec( ; CHECK-NEXT: [[TMP1:%.*]] = call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> [[X:%.*]]) -; CHECK-NEXT: [[TMP2:%.*]] = icmp ne <2 x i8> [[TMP1]], -; CHECK-NEXT: ret <2 x i1> [[TMP2]] +; CHECK-NEXT: [[R:%.*]] = icmp ne <2 x i8> [[TMP1]], +; CHECK-NEXT: ret <2 x i1> [[R]] ; %dec = add <2 x i8> %x, %and = and <2 x i8> %dec, %x @@ -740,8 +740,8 @@ define <2 x i1> @isnot_pow2_decrement_op_vec(<2 x i8> %x) { define i1 @is_pow2or0_ctpop(i32 %x) { ; CHECK-LABEL: @is_pow2or0_ctpop( ; CHECK-NEXT: [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]] -; CHECK-NEXT: [[TMP1:%.*]] = icmp ult i32 [[T0]], 2 -; CHECK-NEXT: ret i1 [[TMP1]] +; CHECK-NEXT: [[R:%.*]] = icmp ult i32 [[T0]], 2 +; CHECK-NEXT: ret i1 [[R]] ; %t0 = tail call i32 @llvm.ctpop.i32(i32 %x) %cmp = icmp eq i32 %t0, 1 @@ -753,8 +753,8 @@ define i1 @is_pow2or0_ctpop(i32 %x) { define i1 @is_pow2or0_ctpop_swap_cmp(i32 %x) { ; CHECK-LABEL: @is_pow2or0_ctpop_swap_cmp( ; CHECK-NEXT: [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]] -; CHECK-NEXT: [[TMP1:%.*]] = icmp ult i32 [[T0]], 2 -; CHECK-NEXT: ret i1 [[TMP1]] +; CHECK-NEXT: [[R:%.*]] = icmp ult i32 [[T0]], 2 +; CHECK-NEXT: ret i1 [[R]] ; %t0 = tail call i32 @llvm.ctpop.i32(i32 %x) %cmp = icmp eq i32 %t0, 1 @@ -766,8 +766,8 @@ define i1 @is_pow2or0_ctpop_swap_cmp(i32 %x) { define i1 @is_pow2or0_ctpop_logical(i32 %x) { ; CHECK-LABEL: @is_pow2or0_ctpop_logical( ; CHECK-NEXT: [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]] -; CHECK-NEXT: [[TMP1:%.*]] = icmp ult i32 [[T0]], 2 -; CHECK-NEXT: ret i1 [[TMP1]] +; CHECK-NEXT: [[R:%.*]] = 
icmp ult i32 [[T0]], 2 +; CHECK-NEXT: ret i1 [[R]] ; %t0 = tail call i32 @llvm.ctpop.i32(i32 %x) %cmp = icmp eq i32 %t0, 1 @@ -779,8 +779,8 @@ define i1 @is_pow2or0_ctpop_logical(i32 %x) { define <2 x i1> @is_pow2or0_ctpop_commute_vec(<2 x i8> %x) { ; CHECK-LABEL: @is_pow2or0_ctpop_commute_vec( ; CHECK-NEXT: [[T0:%.*]] = tail call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> [[X:%.*]]) -; CHECK-NEXT: [[TMP1:%.*]] = icmp ult <2 x i8> [[T0]], -; CHECK-NEXT: ret <2 x i1> [[TMP1]] +; CHECK-NEXT: [[R:%.*]] = icmp ult <2 x i8> [[T0]], +; CHECK-NEXT: ret <2 x i1> [[R]] ; %t0 = tail call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> %x) %cmp = icmp eq <2 x i8> %t0, @@ -799,8 +799,8 @@ define i1 @is_pow2or0_ctpop_extra_uses(i32 %x) { ; CHECK-NEXT: call void @use_i1(i1 [[CMP]]) ; CHECK-NEXT: [[ISZERO:%.*]] = icmp eq i32 [[X]], 0 ; CHECK-NEXT: call void @use_i1(i1 [[ISZERO]]) -; CHECK-NEXT: [[TMP1:%.*]] = icmp ult i32 [[T0]], 2 -; CHECK-NEXT: ret i1 [[TMP1]] +; CHECK-NEXT: [[R:%.*]] = icmp ult i32 [[T0]], 2 +; CHECK-NEXT: ret i1 [[R]] ; %t0 = tail call i32 @llvm.ctpop.i32(i32 %x) call void @use(i32 %t0) @@ -820,8 +820,8 @@ define i1 @is_pow2or0_ctpop_logical_extra_uses(i32 %x) { ; CHECK-NEXT: call void @use_i1(i1 [[CMP]]) ; CHECK-NEXT: [[ISZERO:%.*]] = icmp eq i32 [[X]], 0 ; CHECK-NEXT: call void @use_i1(i1 [[ISZERO]]) -; CHECK-NEXT: [[TMP1:%.*]] = icmp ult i32 [[T0]], 2 -; CHECK-NEXT: ret i1 [[TMP1]] +; CHECK-NEXT: [[R:%.*]] = icmp ult i32 [[T0]], 2 +; CHECK-NEXT: ret i1 [[R]] ; %t0 = tail call i32 @llvm.ctpop.i32(i32 %x) call void @use(i32 %t0) @@ -897,6 +897,21 @@ define i1 @is_pow2or0_ctpop_wrong_pred1(i32 %x) { ret i1 %r } +define i1 @is_pow2or0_ctpop_wrong_pred2(i32 %x) { +; CHECK-LABEL: @is_pow2or0_ctpop_wrong_pred2( +; CHECK-NEXT: [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]] +; CHECK-NEXT: [[CMP:%.*]] = icmp ne i32 [[T0]], 1 +; CHECK-NEXT: [[ISZERO:%.*]] = icmp ne i32 [[X]], 0 +; CHECK-NEXT: [[R:%.*]] = or i1 [[ISZERO]], [[CMP]] +; CHECK-NEXT: ret i1 [[R]] +; 
+ %t0 = tail call i32 @llvm.ctpop.i32(i32 %x) + %cmp = icmp ne i32 %t0, 1 + %iszero = icmp ne i32 %x, 0 + %r = or i1 %iszero, %cmp + ret i1 %r +} + define i1 @is_pow2or0_ctpop_wrong_pred2_logical(i32 %x) { ; CHECK-LABEL: @is_pow2or0_ctpop_wrong_pred2_logical( ; CHECK-NEXT: [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]] @@ -932,8 +947,8 @@ define <2 x i1> @is_pow2or0_ctpop_commute_vec_wrong_pred3(<2 x i8> %x) { define i1 @isnot_pow2nor0_ctpop(i32 %x) { ; CHECK-LABEL: @isnot_pow2nor0_ctpop( ; CHECK-NEXT: [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]] -; CHECK-NEXT: [[TMP1:%.*]] = icmp ugt i32 [[T0]], 1 -; CHECK-NEXT: ret i1 [[TMP1]] +; CHECK-NEXT: [[R:%.*]] = icmp ugt i32 [[T0]], 1 +; CHECK-NEXT: ret i1 [[R]] ; %t0 = tail call i32 @llvm.ctpop.i32(i32 %x) %cmp = icmp ne i32 %t0, 1 @@ -945,8 +960,8 @@ define i1 @isnot_pow2nor0_ctpop(i32 %x) { define i1 @isnot_pow2nor0_ctpop_swap_cmp(i32 %x) { ; CHECK-LABEL: @isnot_pow2nor0_ctpop_swap_cmp( ; CHECK-NEXT: [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]] -; CHECK-NEXT: [[TMP1:%.*]] = icmp ugt i32 [[T0]], 1 -; CHECK-NEXT: ret i1 [[TMP1]] +; CHECK-NEXT: [[R:%.*]] = icmp ugt i32 [[T0]], 1 +; CHECK-NEXT: ret i1 [[R]] ; %t0 = tail call i32 @llvm.ctpop.i32(i32 %x) %cmp = icmp ne i32 %t0, 1 @@ -958,8 +973,8 @@ define i1 @isnot_pow2nor0_ctpop_swap_cmp(i32 %x) { define i1 @isnot_pow2nor0_ctpop_logical(i32 %x) { ; CHECK-LABEL: @isnot_pow2nor0_ctpop_logical( ; CHECK-NEXT: [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]] -; CHECK-NEXT: [[TMP1:%.*]] = icmp ugt i32 [[T0]], 1 -; CHECK-NEXT: ret i1 [[TMP1]] +; CHECK-NEXT: [[R:%.*]] = icmp ugt i32 [[T0]], 1 +; CHECK-NEXT: ret i1 [[R]] ; %t0 = tail call i32 @llvm.ctpop.i32(i32 %x) %cmp = icmp ne i32 %t0, 1 @@ -971,8 +986,8 @@ define i1 @isnot_pow2nor0_ctpop_logical(i32 %x) { define <2 x i1> @isnot_pow2nor0_ctpop_commute_vec(<2 x i8> %x) { ; CHECK-LABEL: 
@isnot_pow2nor0_ctpop_commute_vec( ; CHECK-NEXT: [[T0:%.*]] = tail call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> [[X:%.*]]) -; CHECK-NEXT: [[TMP1:%.*]] = icmp ugt <2 x i8> [[T0]], -; CHECK-NEXT: ret <2 x i1> [[TMP1]] +; CHECK-NEXT: [[R:%.*]] = icmp ugt <2 x i8> [[T0]], +; CHECK-NEXT: ret <2 x i1> [[R]] ; %t0 = tail call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> %x) %cmp = icmp ne <2 x i8> %t0, @@ -991,8 +1006,8 @@ define i1 @isnot_pow2nor0_ctpop_extra_uses(i32 %x) { ; CHECK-NEXT: call void @use_i1(i1 [[CMP]]) ; CHECK-NEXT: [[NOTZERO:%.*]] = icmp ne i32 [[X]], 0 ; CHECK-NEXT: call void @use_i1(i1 [[NOTZERO]]) -; CHECK-NEXT: [[TMP1:%.*]] = icmp ugt i32 [[T0]], 1 -; CHECK-NEXT: ret i1 [[TMP1]] +; CHECK-NEXT: [[R:%.*]] = icmp ugt i32 [[T0]], 1 +; CHECK-NEXT: ret i1 [[R]] ; %t0 = tail call i32 @llvm.ctpop.i32(i32 %x) call void @use(i32 %t0) @@ -1012,8 +1027,8 @@ define i1 @isnot_pow2nor0_ctpop_logical_extra_uses(i32 %x) { ; CHECK-NEXT: call void @use_i1(i1 [[CMP]]) ; CHECK-NEXT: [[NOTZERO:%.*]] = icmp ne i32 [[X]], 0 ; CHECK-NEXT: call void @use_i1(i1 [[NOTZERO]]) -; CHECK-NEXT: [[TMP1:%.*]] = icmp ugt i32 [[T0]], 1 -; CHECK-NEXT: ret i1 [[TMP1]] +; CHECK-NEXT: [[R:%.*]] = icmp ugt i32 [[T0]], 1 +; CHECK-NEXT: ret i1 [[R]] ; %t0 = tail call i32 @llvm.ctpop.i32(i32 %x) call void @use(i32 %t0) @@ -1089,6 +1104,21 @@ define i1 @isnot_pow2nor0_ctpop_wrong_pred1(i32 %x) { ret i1 %r } +define i1 @isnot_pow2nor0_ctpop_wrong_pred2(i32 %x) { +; CHECK-LABEL: @isnot_pow2nor0_ctpop_wrong_pred2( +; CHECK-NEXT: [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]] +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[T0]], 1 +; CHECK-NEXT: [[NOTZERO:%.*]] = icmp eq i32 [[X]], 0 +; CHECK-NEXT: [[R:%.*]] = and i1 [[NOTZERO]], [[CMP]] +; CHECK-NEXT: ret i1 [[R]] +; + %t0 = tail call i32 @llvm.ctpop.i32(i32 %x) + %cmp = icmp eq i32 %t0, 1 + %notzero = icmp eq i32 %x, 0 + %r = and i1 %notzero, %cmp + ret i1 %r +} + define i1 @isnot_pow2nor0_ctpop_wrong_pred2_logical(i32 %x) { ; CHECK-LABEL: 
@isnot_pow2nor0_ctpop_wrong_pred2_logical( ; CHECK-NEXT: [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]] From 03c45f14bf639c7d2346b956cd1ec61a669236e0 Mon Sep 17 00:00:00 2001 From: Ben Mudd Date: Thu, 19 Jan 2023 15:05:04 +0000 Subject: [PATCH 03/61] [Dexter] Add on_line parameter to DexExpectStepOrder command DexExpectStepOrder uses the line to expect a debugger step from the actual line of the command in the Dexter source file. Now that Dexter scripts have mainly moved to their own script files instead of the actual source, there should be an option to override this behaviour to choose your own debugger step location. Reviewed By: Orlando Differential Revision: https://reviews.llvm.org/D142099 --- .../debuginfo-tests/dexter/Commands.md | 7 +++++-- .../dex/command/commands/DexExpectStepOrder.py | 11 ++++++++++- .../DebuggerControllers/ControllerHelpers.py | 4 ++-- .../dexter/dex/heuristic/Heuristic.py | 2 +- .../commands/perfect/expect_step_order.cpp | 14 +++++++++++--- 5 files changed, 29 insertions(+), 9 deletions(-) diff --git a/cross-project-tests/debuginfo-tests/dexter/Commands.md b/cross-project-tests/debuginfo-tests/dexter/Commands.md index d5bd999f5adf02..a98261a5000949 100644 --- a/cross-project-tests/debuginfo-tests/dexter/Commands.md +++ b/cross-project-tests/debuginfo-tests/dexter/Commands.md @@ -101,13 +101,16 @@ frame.
--- ## DexExpectStepOrder - DexExpectStepOrder(*order) + DexExpectStepOrder(*order [,**on_line]) Arg list: order (int): One or more indices. + Keyword args: + on_line (int): Expect this line to be stepped on in the order given. + ### Description -Expect the line every `DexExpectStepOrder` is found on to be stepped on in +Expect the line every `DexExpectStepOrder` is found on, or given from `on_line`, to be stepped on in `order`. Each instance must have a set of unique ascending indices. ### Heuristic diff --git a/cross-project-tests/debuginfo-tests/dexter/dex/command/commands/DexExpectStepOrder.py b/cross-project-tests/debuginfo-tests/dexter/dex/command/commands/DexExpectStepOrder.py index 700dc542043187..d5cfc3c82f4150 100644 --- a/cross-project-tests/debuginfo-tests/dexter/dex/command/commands/DexExpectStepOrder.py +++ b/cross-project-tests/debuginfo-tests/dexter/dex/command/commands/DexExpectStepOrder.py @@ -18,10 +18,16 @@ class DexExpectStepOrder(CommandBase): See Commands.md for more info. 
""" - def __init__(self, *args): + def __init__(self, *args, **kwargs): if not args: raise TypeError('Need at least one order number') + if 'on_line' in kwargs: + try: + on_line = kwargs.pop('on_line') + self.on_line = int(on_line) + except ValueError: + raise ValueError('on_line value \'{0}\' cannot be parsed to an integer'.format(on_line)) self.sequence = [int(x) for x in args] super(DexExpectStepOrder, self).__init__() @@ -29,6 +35,9 @@ def __init__(self, *args): def get_name(): return __class__.__name__ + def get_line(self): + return self.on_line if hasattr(self, 'on_line') else self.lineno + def eval(self, step_info): return {'DexExpectStepOrder': ValueIR(expression=str(step_info.current_location.lineno), value=str(step_info.step_index), type_name=None, diff --git a/cross-project-tests/debuginfo-tests/dexter/dex/debugger/DebuggerControllers/ControllerHelpers.py b/cross-project-tests/debuginfo-tests/dexter/dex/debugger/DebuggerControllers/ControllerHelpers.py index 014f15da2af07d..8044f3982f40e6 100644 --- a/cross-project-tests/debuginfo-tests/dexter/dex/debugger/DebuggerControllers/ControllerHelpers.py +++ b/cross-project-tests/debuginfo-tests/dexter/dex/debugger/DebuggerControllers/ControllerHelpers.py @@ -19,8 +19,8 @@ def in_source_file(source_files, step_info): for f in source_files) def have_hit_line(watch, loc): - if hasattr(watch, '_on_line'): - return watch._on_line == loc.lineno + if hasattr(watch, 'on_line'): + return watch.on_line == loc.lineno elif hasattr(watch, '_from_line'): return watch._from_line <= loc.lineno and watch._to_line >= loc.lineno elif watch.lineno == loc.lineno: diff --git a/cross-project-tests/debuginfo-tests/dexter/dex/heuristic/Heuristic.py b/cross-project-tests/debuginfo-tests/dexter/dex/heuristic/Heuristic.py index 1582e7b3b706d4..52ba7e1e897c6d 100644 --- a/cross-project-tests/debuginfo-tests/dexter/dex/heuristic/Heuristic.py +++ b/cross-project-tests/debuginfo-tests/dexter/dex/heuristic/Heuristic.py @@ -253,7 +253,7 @@ def 
__init__(self, context, steps): cmds = steps.commands['DexExpectStepOrder'] # Form a list of which line/cmd we _should_ have seen - cmd_num_lst = [(x, c.lineno) for c in cmds + cmd_num_lst = [(x, c.get_line()) for c in cmds for x in c.sequence] # Order them by the sequence number cmd_num_lst.sort(key=lambda t: t[0]) diff --git a/cross-project-tests/debuginfo-tests/dexter/feature_tests/commands/perfect/expect_step_order.cpp b/cross-project-tests/debuginfo-tests/dexter/feature_tests/commands/perfect/expect_step_order.cpp index cbcb5063c10f1c..c6e992df7ad9c7 100644 --- a/cross-project-tests/debuginfo-tests/dexter/feature_tests/commands/perfect/expect_step_order.cpp +++ b/cross-project-tests/debuginfo-tests/dexter/feature_tests/commands/perfect/expect_step_order.cpp @@ -9,8 +9,16 @@ int main() { - volatile int x = 1; // DexExpectStepOrder(1) - volatile int y = 1; // DexExpectStepOrder(2) - volatile int z = 1; // DexExpectStepOrder(3) + volatile int a = 1; // DexExpectStepOrder(1) + volatile int b = 1; // DexExpectStepOrder(2) + volatile int c = 1; // DexExpectStepOrder(3) + + volatile int x = 1; + volatile int y = 1; + volatile int z = 1; return 0; } + +// DexExpectStepOrder(4, on_line=16); +// DexExpectStepOrder(5, on_line=17); +// DexExpectStepOrder(6, on_line=18); From e0b3b6cec7cf1c03bde8b65dcd2f9233839ad0a9 Mon Sep 17 00:00:00 2001 From: Guilherme Valarini Date: Thu, 19 Jan 2023 12:11:20 -0300 Subject: [PATCH 04/61] [OpenMP][Fix] Track all threads that may delete an entry The entries inside a "target data end" are processed in three steps: 1. Query internal data maps for the entries and dispatch any necessary device-side operations (i.e., data retrieval); 2. Synchronize such operations; 3. Update the host-side pointers and remove any entry whose reference counter reached zero. Such steps may be executed by multiple threads which may even operate on the same entries. 
The current implementation (D121058) tries to synchronize these threads by tracking the "owner" for the deletion of each entry using their thread ID. Unfortunately it may fail to do so because of the following reasons: 1. The owner is always assigned at the first step only if the reference count is 0 when the map is queried. This does not work when such owner thread is faster than a previous one that is also processing the same entry on another "target data end", leading to use-after-free problems. 2. The entry is only added for post-processing (step 3) if its reference count was 0 at query time (step 1). This does not allow for threads to exchange responsibility for the deletion, leading again to use-after-free problems. 3. An entry may appear multiple times in the arguments array of a "target data end", which may lead to deleting the entry prematurely, leading, again, to use-after-free problems. This patch addresses these problems by tracking all the threads that are using an entry at "target data end" region through a counter, ensuring only the last one deletes it when needed. It also ensures that all entries that are successfully found inside the data maps in step 1 are also processed in step 3, regardless of whether their reference count was zeroed or not at query time. This ensures the deletion ownership may be passed to any thread that is using such an entry. 
Reviewed By: ye-luo Differential Revision: https://reviews.llvm.org/D132676 --- openmp/libomptarget/include/device.h | 73 +++++--- openmp/libomptarget/src/device.cpp | 70 ++++---- openmp/libomptarget/src/omptarget.cpp | 168 +++++++++--------- .../test/mapping/map_back_race.cpp | 13 -- 4 files changed, 162 insertions(+), 162 deletions(-) diff --git a/openmp/libomptarget/include/device.h b/openmp/libomptarget/include/device.h index c7513e522ff667..a223475515bff7 100644 --- a/openmp/libomptarget/include/device.h +++ b/openmp/libomptarget/include/device.h @@ -61,8 +61,7 @@ struct HostDataToTargetTy { struct StatesTy { StatesTy(uint64_t DRC, uint64_t HRC) : DynRefCount(DRC), HoldRefCount(HRC), - MayContainAttachedPointers(false), DeleteThreadId(std::thread::id()) { - } + MayContainAttachedPointers(false) {} /// The dynamic reference count is the standard reference count as of OpenMP /// 4.5. The hold reference count is an OpenMP extension for the sake of /// OpenACC support. @@ -101,13 +100,10 @@ struct HostDataToTargetTy { /// should be written as void *Event[2]. void *Event = nullptr; - /// The id of the thread responsible for deleting this entry. This thread - /// set the reference count to zero *last*. Other threads might reuse the - /// entry while it is marked for deletion but not yet deleted (e.g., the - /// data is still being moved back). If another thread reuses the entry we - /// will have a non-zero reference count *or* the thread will have changed - /// this id, effectively taking over deletion responsibility. - std::thread::id DeleteThreadId; + /// Number of threads currently holding a reference to the entry at a + /// targetDataEnd. This is used to ensure that only the last thread that + /// references this entry will actually delete it. + int32_t DataEndThreadCount = 0; }; // When HostDataToTargetTy is used by std::set, std::set::iterator is const // use unique_ptr to make States mutable. 
@@ -148,13 +144,17 @@ struct HostDataToTargetTy { /// Returns OFFLOAD_FAIL if something went wrong, OFFLOAD_SUCCESS otherwise. int addEventIfNecessary(DeviceTy &Device, AsyncInfoTy &AsyncInfo) const; - /// Indicate that the current thread expected to delete this entry. - void setDeleteThreadId() const { - States->DeleteThreadId = std::this_thread::get_id(); + /// Functions that manages the number of threads referencing the entry in a + /// targetDataEnd. + void incDataEndThreadCount() { ++States->DataEndThreadCount; } + + [[nodiscard]] int32_t decDataEndThreadCount() { + return --States->DataEndThreadCount; } - /// Return the thread id of the thread expected to delete this entry. - std::thread::id getDeleteThreadId() const { return States->DeleteThreadId; } + [[nodiscard]] int32_t getDataEndThreadCount() const { + return States->DataEndThreadCount; + } /// Set the event bound to this data map. void setEvent(void *Event) const { States->Event = Event; } @@ -377,20 +377,37 @@ struct DeviceTy { void *getTgtPtrBegin(HDTTMapAccessorTy &HDTTMap, void *HstPtrBegin, int64_t Size); - TargetPointerResultTy getTgtPtrBegin(void *HstPtrBegin, int64_t Size, - bool &IsLast, bool UpdateRefCount, - bool UseHoldRefCount, bool &IsHostPtr, - bool MustContain = false, - bool ForceDelete = false); - - /// Deallocate \p LR and remove the entry. Assume the total reference count is - /// zero and the calling thread is the deleting thread for \p LR. \p HDTTMap - /// ensure the caller holds exclusive access and can modify the map. Return \c - /// OFFLOAD_SUCCESS if the map entry existed, and return \c OFFLOAD_FAIL if - /// not. It is the caller's responsibility to skip calling this function if - /// the map entry is not expected to exist because \p HstPtrBegin uses shared - /// memory. - int deallocTgtPtr(HDTTMapAccessorTy &HDTTMap, LookupResult LR, int64_t Size); + /// Return the target pointer begin (where the data will be moved). 
+ /// Used by targetDataBegin, targetDataEnd, targetDataUpdate and target. + /// - \p UpdateRefCount and \p UseHoldRefCount controls which and if the entry + /// reference counters will be decremented. + /// - \p MustContain enforces that the query must not extend beyond an already + /// mapped entry to be valid. + /// - \p ForceDelete deletes the entry regardless of its reference counting + /// (unless it is infinite). + /// - \p FromDataEnd tracks the number of threads referencing the entry at + /// targetDataEnd for delayed deletion purpose. + [[nodiscard]] TargetPointerResultTy + getTgtPtrBegin(void *HstPtrBegin, int64_t Size, bool &IsLast, + bool UpdateRefCount, bool UseHoldRefCount, bool &IsHostPtr, + bool MustContain = false, bool ForceDelete = false, + bool FromDataEnd = false); + + /// Remove the \p Entry from the data map. Expect the entry's total reference + /// count to be zero and the caller thread to be the last one using it. \p + /// HDTTMap ensure the caller holds exclusive access and can modify the map. + /// Return \c OFFLOAD_SUCCESS if the map entry existed, and return \c + /// OFFLOAD_FAIL if not. It is the caller's responsibility to skip calling + /// this function if the map entry is not expected to exist because \p + /// HstPtrBegin uses shared memory. + [[nodiscard]] int eraseMapEntry(HDTTMapAccessorTy &HDTTMap, + HostDataToTargetTy *Entry, int64_t Size); + + /// Deallocate the \p Entry from the device memory and delete it. Return \c + /// OFFLOAD_SUCCESS if the deallocation operations executed successfully, and + /// return \c OFFLOAD_FAIL otherwise. 
+ [[nodiscard]] int deallocTgtPtrAndEntry(HostDataToTargetTy *Entry, + int64_t Size); int associatePtr(void *HstPtrBegin, void *TgtPtrBegin, int64_t Size); int disassociatePtr(void *HstPtrBegin); diff --git a/openmp/libomptarget/src/device.cpp b/openmp/libomptarget/src/device.cpp index 6c09ec1bd5c341..65c8a325128586 100644 --- a/openmp/libomptarget/src/device.cpp +++ b/openmp/libomptarget/src/device.cpp @@ -359,13 +359,11 @@ TargetPointerResultTy DeviceTy::getTargetPointer( return {{IsNew, IsHostPtr, IsPresent}, Entry, TargetPointer}; } -// Used by targetDataBegin, targetDataEnd, targetDataUpdate and target. -// Return the target pointer begin (where the data will be moved). -// Decrement the reference counter if called from targetDataEnd. TargetPointerResultTy DeviceTy::getTgtPtrBegin(void *HstPtrBegin, int64_t Size, bool &IsLast, bool UpdateRefCount, bool UseHoldRefCount, - bool &IsHostPtr, bool MustContain, bool ForceDelete) { + bool &IsHostPtr, bool MustContain, bool ForceDelete, + bool FromDataEnd) { HDTTMapAccessorTy HDTTMap = HostDataToTargetMap.getExclusiveAccessor(); void *TargetPointer = NULL; @@ -386,15 +384,18 @@ DeviceTy::getTgtPtrBegin(void *HstPtrBegin, int64_t Size, bool &IsLast, "expected correct IsLast prediction for reset"); } + // Increment the number of threads that is using the entry on a + // targetDataEnd, tracking the number of possible "deleters". A thread may + // come to own the entry deletion even if it was not the last one querying + // for it. Thus, we must track every query on targetDataEnds to ensure only + // the last thread that holds a reference to an entry actually deletes it. + if (FromDataEnd) + HT.incDataEndThreadCount(); + const char *RefCountAction; if (!UpdateRefCount) { RefCountAction = " (update suppressed)"; } else if (IsLast) { - // Mark the entry as to be deleted by this thread. Another thread might - // reuse the entry and take "ownership" for the deletion while this thread - // is waiting for data transfers. 
That is fine and the current thread will - // simply skip the deletion step then. - HT.setDeleteThreadId(); HT.decRefCount(UseHoldRefCount); assert(HT.getTotalRefCount() == 0 && "Expected zero reference count when deletion is scheduled"); @@ -450,41 +451,42 @@ void *DeviceTy::getTgtPtrBegin(HDTTMapAccessorTy &HDTTMap, void *HstPtrBegin, return NULL; } -int DeviceTy::deallocTgtPtr(HDTTMapAccessorTy &HDTTMap, LookupResult LR, - int64_t Size) { - // Check if the pointer is contained in any sub-nodes. - if (!(LR.Flags.IsContained || LR.Flags.ExtendsBefore || - LR.Flags.ExtendsAfter)) { - REPORT("Section to delete (hst addr " DPxMOD ") does not exist in the" - " allocated memory\n", - DPxPTR(LR.Entry->HstPtrBegin)); - return OFFLOAD_FAIL; - } - - auto &HT = *LR.Entry; - // Verify this thread is still in charge of deleting the entry. - assert(HT.getTotalRefCount() == 0 && - HT.getDeleteThreadId() == std::this_thread::get_id() && +int DeviceTy::eraseMapEntry(HDTTMapAccessorTy &HDTTMap, + HostDataToTargetTy *Entry, int64_t Size) { + assert(Entry && "Trying to delete a null entry from the HDTT map."); + assert(Entry->getTotalRefCount() == 0 && Entry->getDataEndThreadCount() == 0 && "Trying to delete entry that is in use or owned by another thread."); - DP("Deleting tgt data " DPxMOD " of size %" PRId64 "\n", - DPxPTR(HT.TgtPtrBegin), Size); - deleteData((void *)HT.TgtPtrBegin); INFO(OMP_INFOTYPE_MAPPING_CHANGED, DeviceID, "Removing map entry with HstPtrBegin=" DPxMOD ", TgtPtrBegin=" DPxMOD ", Size=%" PRId64 ", Name=%s\n", - DPxPTR(HT.HstPtrBegin), DPxPTR(HT.TgtPtrBegin), Size, - (HT.HstPtrName) ? getNameFromMapping(HT.HstPtrName).c_str() : "unknown"); - void *Event = LR.Entry->getEvent(); - HDTTMap->erase(LR.Entry); - delete LR.Entry; + DPxPTR(Entry->HstPtrBegin), DPxPTR(Entry->TgtPtrBegin), Size, + (Entry->HstPtrName) ? 
getNameFromMapping(Entry->HstPtrName).c_str() + : "unknown"); + + if (HDTTMap->erase(Entry) == 0) { + REPORT("Trying to remove a non-existent map entry\n"); + return OFFLOAD_FAIL; + } - int Ret = OFFLOAD_SUCCESS; + return OFFLOAD_SUCCESS; +} + +int DeviceTy::deallocTgtPtrAndEntry(HostDataToTargetTy *Entry, int64_t Size) { + assert(Entry && "Trying to deallocate a null entry."); + + DP("Deleting tgt data " DPxMOD " of size %" PRId64 "\n", + DPxPTR(Entry->TgtPtrBegin), Size); + + void *Event = Entry->getEvent(); if (Event && destroyEvent(Event) != OFFLOAD_SUCCESS) { REPORT("Failed to destroy event " DPxMOD "\n", DPxPTR(Event)); - Ret = OFFLOAD_FAIL; + return OFFLOAD_FAIL; } + int Ret = deleteData((void *)Entry->TgtPtrBegin); + delete Entry; + return Ret; } diff --git a/openmp/libomptarget/src/omptarget.cpp b/openmp/libomptarget/src/omptarget.cpp index 27eca027d96c53..580b0c88e4375e 100644 --- a/openmp/libomptarget/src/omptarget.cpp +++ b/openmp/libomptarget/src/omptarget.cpp @@ -751,17 +751,16 @@ struct PostProcessingInfo { /// The mapping type (bitfield). int64_t ArgType; + /// Index of the argument in the data mapping scheme. + int32_t ArgIndex; + /// The target pointer information. TargetPointerResultTy TPR; - /// Are we expecting to delete this entry or not. Even if set, we might not - /// delete the entry if another thread reused the entry in the meantime. 
- bool DelEntry; - - PostProcessingInfo(void *HstPtr, int64_t Size, int64_t ArgType, bool DelEntry, - TargetPointerResultTy TPR) - : HstPtrBegin(HstPtr), DataSize(Size), ArgType(ArgType), TPR(TPR), - DelEntry(DelEntry) {} + PostProcessingInfo(void *HstPtr, int64_t Size, int64_t ArgType, + int32_t ArgIndex, TargetPointerResultTy TPR) + : HstPtrBegin(HstPtr), DataSize(Size), ArgType(ArgType), + ArgIndex(ArgIndex), TPR(TPR) {} }; /// Apply \p CB to the shadow map pointer entries in the range \p Begin, to @@ -811,42 +810,55 @@ static void applyToShadowMapEntries(DeviceTy &Device, CBTy CB, void *Begin, /// data end. This includes the update of pointers at the host and removal of /// device buffer when needed. It returns OFFLOAD_FAIL or OFFLOAD_SUCCESS /// according to the successfulness of the operations. -static int +[[nodiscard]] static int postProcessingTargetDataEnd(DeviceTy *Device, SmallVector EntriesInfo, - void *FromMapperBase) { + bool FromMapper) { int Ret = OFFLOAD_SUCCESS; + void *FromMapperBase = nullptr; + + for (auto &[HstPtrBegin, DataSize, ArgType, ArgIndex, TPR] : EntriesInfo) { + bool DelEntry = !TPR.isHostPointer(); + + // If the last element from the mapper (for end transfer args comes in + // reverse order), do not remove the partial entry, the parent struct still + // exists. + if ((ArgType & OMP_TGT_MAPTYPE_MEMBER_OF) && + !(ArgType & OMP_TGT_MAPTYPE_PTR_AND_OBJ)) { + DelEntry = false; // protect parent struct from being deallocated + } + + if (DelEntry && FromMapper && ArgIndex == 0) { + DelEntry = false; + FromMapperBase = HstPtrBegin; + } - for (PostProcessingInfo &Info : EntriesInfo) { // If we marked the entry to be deleted we need to verify no other // thread reused it by now. If deletion is still supposed to happen by // this thread LR will be set and exclusive access to the HDTT map // will avoid another thread reusing the entry now. 
Note that we do - // not request (exclusive) access to the HDTT map if Info.DelEntry is + // not request (exclusive) access to the HDTT map if DelEntry is // not set. - LookupResult LR; DeviceTy::HDTTMapAccessorTy HDTTMap = - Device->HostDataToTargetMap.getExclusiveAccessor(!Info.DelEntry); - - if (Info.DelEntry) { - LR = Device->lookupMapping(HDTTMap, Info.HstPtrBegin, Info.DataSize); - if (LR.Entry->getTotalRefCount() != 0 || - LR.Entry->getDeleteThreadId() != std::this_thread::get_id()) { - // The thread is not in charge of deletion anymore. Give up access - // to the HDTT map and unset the deletion flag. - HDTTMap.destroy(); - Info.DelEntry = false; - } + Device->HostDataToTargetMap.getExclusiveAccessor(!DelEntry); + + const bool IsNotLastUser = TPR.Entry->decDataEndThreadCount() != 0; + if (DelEntry && (TPR.Entry->getTotalRefCount() != 0 || IsNotLastUser)) { + // The thread is not in charge of deletion anymore. Give up access + // to the HDTT map and unset the deletion flag. + HDTTMap.destroy(); + DelEntry = false; } // If we copied back to the host a struct/array containing pointers, // we need to restore the original host pointer values from their // shadow copies. If the struct is going to be deallocated, remove any // remaining shadow pointer entries for this struct. + const bool HasFrom = ArgType & OMP_TGT_MAPTYPE_FROM; auto CB = [&](ShadowPtrListTy::iterator &Itr) { // If we copied the struct to the host, we need to restore the // pointer. - if (Info.ArgType & OMP_TGT_MAPTYPE_FROM) { + if (HasFrom) { void **ShadowHstPtrAddr = (void **)Itr->first; *ShadowHstPtrAddr = Itr->second.HstPtrVal; DP("Restoring original host pointer value " DPxMOD " for host " @@ -854,7 +866,7 @@ postProcessingTargetDataEnd(DeviceTy *Device, DPxPTR(Itr->second.HstPtrVal), DPxPTR(ShadowHstPtrAddr)); } // If the struct is to be deallocated, remove the shadow entry. 
- if (Info.DelEntry) { + if (DelEntry) { DP("Removing shadow pointer " DPxMOD "\n", DPxPTR((void **)Itr->first)); auto OldItr = Itr; Itr++; @@ -864,19 +876,20 @@ postProcessingTargetDataEnd(DeviceTy *Device, } return OFFLOAD_SUCCESS; }; - applyToShadowMapEntries(*Device, CB, Info.HstPtrBegin, Info.DataSize, - Info.TPR); + applyToShadowMapEntries(*Device, CB, HstPtrBegin, DataSize, TPR); + + if (!DelEntry || (FromMapperBase && FromMapperBase == HstPtrBegin)) + continue; // If we are deleting the entry the DataMapMtx is locked and we own // the entry. - if (Info.DelEntry) { - if (!FromMapperBase || FromMapperBase != Info.HstPtrBegin) - Ret = Device->deallocTgtPtr(HDTTMap, LR, Info.DataSize); - - if (Ret != OFFLOAD_SUCCESS) { - REPORT("Deallocating data from device failed.\n"); - break; - } + Ret = Device->eraseMapEntry(HDTTMap, TPR.Entry, DataSize); + // Entry is already remove from the map, we can unlock it now. + HDTTMap.destroy(); + Ret |= Device->deallocTgtPtrAndEntry(TPR.Entry, DataSize); + if (Ret != OFFLOAD_SUCCESS) { + REPORT("Deallocating data from device failed.\n"); + break; } } @@ -950,7 +963,7 @@ int targetDataEnd(ident_t *Loc, DeviceTy &Device, int32_t ArgNum, // If PTR_AND_OBJ, HstPtrBegin is address of pointee TargetPointerResultTy TPR = Device.getTgtPtrBegin( HstPtrBegin, DataSize, IsLast, UpdateRef, HasHoldModifier, IsHostPtr, - !IsImplicit, ForceDelete); + !IsImplicit, ForceDelete, /*FromDataEnd=*/true); void *TgtPtrBegin = TPR.TargetPointer; if (!TPR.isPresent() && !TPR.isHostPointer() && (DataSize || HasPresentModifier)) { @@ -991,61 +1004,42 @@ int targetDataEnd(ident_t *Loc, DeviceTy &Device, int32_t ArgNum, if (!TPR.isPresent()) continue; - bool DelEntry = IsLast; - - // If the last element from the mapper (for end transfer args comes in - // reverse order), do not remove the partial entry, the parent struct still - // exists. 
- if ((ArgTypes[I] & OMP_TGT_MAPTYPE_MEMBER_OF) && - !(ArgTypes[I] & OMP_TGT_MAPTYPE_PTR_AND_OBJ)) { - DelEntry = false; // protect parent struct from being deallocated - } - - if ((ArgTypes[I] & OMP_TGT_MAPTYPE_FROM) || DelEntry) { - // Move data back to the host - if (ArgTypes[I] & OMP_TGT_MAPTYPE_FROM) { - bool Always = ArgTypes[I] & OMP_TGT_MAPTYPE_ALWAYS; - if ((Always || IsLast) && !IsHostPtr) { - DP("Moving %" PRId64 " bytes (tgt:" DPxMOD ") -> (hst:" DPxMOD ")\n", - DataSize, DPxPTR(TgtPtrBegin), DPxPTR(HstPtrBegin)); - - std::lock_guard LG(*TPR.Entry); - // Wait for any previous transfer if an event is present. - if (void *Event = TPR.Entry->getEvent()) { - if (Device.waitEvent(Event, AsyncInfo) != OFFLOAD_SUCCESS) { - REPORT("Failed to wait for event " DPxMOD ".\n", DPxPTR(Event)); - return OFFLOAD_FAIL; - } - } - - Ret = Device.retrieveData(HstPtrBegin, TgtPtrBegin, DataSize, - AsyncInfo); - if (Ret != OFFLOAD_SUCCESS) { - REPORT("Copying data from device failed.\n"); - return OFFLOAD_FAIL; - } - - // As we are expecting to delete the entry the d2h copy might race - // with another one that also tries to delete the entry. This happens - // as the entry can be reused and the reuse might happen after the - // copy-back was issued but before it completed. Since the reuse might - // also copy-back a value we would race. - if (IsLast) { - if (TPR.Entry->addEventIfNecessary(Device, AsyncInfo) != - OFFLOAD_SUCCESS) - return OFFLOAD_FAIL; - } + // Move data back to the host + const bool HasAlways = ArgTypes[I] & OMP_TGT_MAPTYPE_ALWAYS; + const bool HasFrom = ArgTypes[I] & OMP_TGT_MAPTYPE_FROM; + if (HasFrom && (HasAlways || IsLast) && !IsHostPtr) { + DP("Moving %" PRId64 " bytes (tgt:" DPxMOD ") -> (hst:" DPxMOD ")\n", + DataSize, DPxPTR(TgtPtrBegin), DPxPTR(HstPtrBegin)); + + std::lock_guard LG(*TPR.Entry); + // Wait for any previous transfer if an event is present. 
+ if (void *Event = TPR.Entry->getEvent()) { + if (Device.waitEvent(Event, AsyncInfo) != OFFLOAD_SUCCESS) { + REPORT("Failed to wait for event " DPxMOD ".\n", DPxPTR(Event)); + return OFFLOAD_FAIL; } } - if (DelEntry && FromMapper && I == 0) { - DelEntry = false; - FromMapperBase = HstPtrBegin; + + Ret = Device.retrieveData(HstPtrBegin, TgtPtrBegin, DataSize, AsyncInfo); + if (Ret != OFFLOAD_SUCCESS) { + REPORT("Copying data from device failed.\n"); + return OFFLOAD_FAIL; } - // Add pointer to the buffer for post-synchronize processing. - PostProcessingPtrs.emplace_back(HstPtrBegin, DataSize, ArgTypes[I], - DelEntry && !IsHostPtr, TPR); + // As we are expecting to delete the entry the d2h copy might race + // with another one that also tries to delete the entry. This happens + // as the entry can be reused and the reuse might happen after the + // copy-back was issued but before it completed. Since the reuse might + // also copy-back a value we would race. + if (IsLast) { + if (TPR.Entry->addEventIfNecessary(Device, AsyncInfo) != + OFFLOAD_SUCCESS) + return OFFLOAD_FAIL; + } } + + // Add pointer to the buffer for post-synchronize processing. 
+ PostProcessingPtrs.emplace_back(HstPtrBegin, DataSize, ArgTypes[I], I, TPR); } // Add post-processing functions diff --git a/openmp/libomptarget/test/mapping/map_back_race.cpp b/openmp/libomptarget/test/mapping/map_back_race.cpp index dd8ed10f21367c..8a988d3be3b4f9 100644 --- a/openmp/libomptarget/test/mapping/map_back_race.cpp +++ b/openmp/libomptarget/test/mapping/map_back_race.cpp @@ -2,19 +2,6 @@ // Taken from https://github.com/llvm/llvm-project/issues/54216 -// UNSUPPORTED: aarch64-unknown-linux-gnu -// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO -// UNSUPPORTED: amdgcn-amd-amdhsa -// UNSUPPORTED: amdgcn-amd-amdhsa-LTO -// UNSUPPORTED: powerpc64le-ibm-linux-gnu -// UNSUPPORTED: powerpc64le-ibm-linux-gnu-LTO -// UNSUPPORTED: powerpc64-ibm-linux-gnu -// UNSUPPORTED: powerpc64-ibm-linux-gnu-LTO -// UNSUPPORTED: x86_64-pc-linux-gnu -// UNSUPPORTED: x86_64-pc-linux-gnu-LTO -// UNSUPPORTED: nvptx64-nvidia-cuda -// UNSUPPORTED: nvptx64-nvidia-cuda-LTO - #include #include #include From ee700dec052a0336798fb2570faec31719b53f8d Mon Sep 17 00:00:00 2001 From: David Green Date: Thu, 19 Jan 2023 15:37:50 +0000 Subject: [PATCH 05/61] [AArch64] Armv9-A implies FP16 As Armv9-a implies SVE2 it implies SVE (added in D141411) and so it should also imply FP16, which this patch adds. This helps get the target features correct when using `target("arch=armv9-a")` attributes. There is also an adjustment to AssertSameExtensionFlags in this patch to make it print cpu names, useful when the TargetParser unit tests are run through lit to distinguish which cpu is failing. 
Differential Revision: https://reviews.llvm.org/D142087 --- clang/test/CodeGen/aarch64-targetattr-arch.c | 7 ++ clang/test/CodeGen/aarch64-targetattr.c | 2 +- .../llvm/TargetParser/AArch64TargetParser.h | 4 +- .../TargetParser/TargetParserTest.cpp | 70 +++++++++++-------- 4 files changed, 50 insertions(+), 33 deletions(-) diff --git a/clang/test/CodeGen/aarch64-targetattr-arch.c b/clang/test/CodeGen/aarch64-targetattr-arch.c index d70d68c6e6a215..86ddeac0b9e626 100644 --- a/clang/test/CodeGen/aarch64-targetattr-arch.c +++ b/clang/test/CodeGen/aarch64-targetattr-arch.c @@ -8,6 +8,7 @@ #endif #include +#include #include __attribute__((target("arch=armv8.1-a"))) @@ -22,6 +23,12 @@ svint8_t test_svadd_attr(svbool_t pg, svint8_t op1, svint8_t op2) return svadd_s8_z(pg, op1, op2); } +__attribute__((target("arch=armv9-a"))) +float16_t test_fp16_on_v9(float16_t x, float16_t y) +{ + return vabdh_f16(x, y); +} + void test_errors() { #ifdef HAS8 diff --git a/clang/test/CodeGen/aarch64-targetattr.c b/clang/test/CodeGen/aarch64-targetattr.c index 8730ecfbd1343e..6ce3c72c44b72a 100644 --- a/clang/test/CodeGen/aarch64-targetattr.c +++ b/clang/test/CodeGen/aarch64-targetattr.c @@ -93,7 +93,7 @@ void nosimd() {} // CHECK: attributes #1 = { {{.*}} "target-features"="+crc,+fp-armv8,+fullfp16,+lse,+neon,+ras,+rdm,+sve,+v8.1a,+v8.2a,+v8a" } // CHECK: attributes #2 = { {{.*}} "target-features"="+crc,+fp-armv8,+fullfp16,+lse,+neon,+ras,+rdm,+sve,+sve2,+v8.1a,+v8.2a,+v8a" } // CHECK: attributes #3 = { {{.*}} "target-features"="+aes,+bf16,+crc,+dotprod,+fp-armv8,+fullfp16,+i8mm,+lse,+neon,+ras,+rcpc,+rdm,+sha2,+sha3,+sm4,+sve,+sve2,+v8.1a,+v8.2a,+v8.3a,+v8.4a,+v8.5a,+v8.6a,+v8a" } -// CHECK: attributes #4 = { {{.*}} "target-cpu"="cortex-a710" "target-features"="+bf16,+crc,+dotprod,+flagm,+fp-armv8,+fp16fml,+i8mm,+lse,+mte,+neon,+pauth,+ras,+rcpc,+rdm,+sb,+sve,+sve2,+sve2-bitperm" } +// CHECK: attributes #4 = { {{.*}} "target-cpu"="cortex-a710" 
"target-features"="+bf16,+crc,+dotprod,+flagm,+fp-armv8,+fp16fml,+fullfp16,+i8mm,+lse,+mte,+neon,+pauth,+ras,+rcpc,+rdm,+sb,+sve,+sve2,+sve2-bitperm" } // CHECK: attributes #5 = { {{.*}} "tune-cpu"="cortex-a710" } // CHECK: attributes #6 = { {{.*}} "target-cpu"="generic" } // CHECK: attributes #7 = { {{.*}} "tune-cpu"="generic" } diff --git a/llvm/include/llvm/TargetParser/AArch64TargetParser.h b/llvm/include/llvm/TargetParser/AArch64TargetParser.h index 3c416e4576e561..ff6a629c8faa03 100644 --- a/llvm/include/llvm/TargetParser/AArch64TargetParser.h +++ b/llvm/include/llvm/TargetParser/AArch64TargetParser.h @@ -314,7 +314,7 @@ inline constexpr ArchInfo ARMV8_6A = { VersionTuple{8, 6}, AProfile, "armv8.6-a inline constexpr ArchInfo ARMV8_7A = { VersionTuple{8, 7}, AProfile, "armv8.7-a", "+v8.7a", (ARMV8_6A.DefaultExts)}; inline constexpr ArchInfo ARMV8_8A = { VersionTuple{8, 8}, AProfile, "armv8.8-a", "+v8.8a", (ARMV8_7A.DefaultExts | AArch64::AEK_MOPS | AArch64::AEK_HBC)}; inline constexpr ArchInfo ARMV8_9A = { VersionTuple{8, 9}, AProfile, "armv8.9-a", "+v8.9a", (ARMV8_8A.DefaultExts)}; -inline constexpr ArchInfo ARMV9A = { VersionTuple{9, 0}, AProfile, "armv9-a", "+v9a", (BaseNoCrypto | AArch64::AEK_SVE | AArch64::AEK_SVE2)}; +inline constexpr ArchInfo ARMV9A = { VersionTuple{9, 0}, AProfile, "armv9-a", "+v9a", (BaseNoCrypto | AArch64::AEK_FP16 | AArch64::AEK_SVE | AArch64::AEK_SVE2)}; inline constexpr ArchInfo ARMV9_1A = { VersionTuple{9, 1}, AProfile, "armv9.1-a", "+v9.1a", (ARMV9A.DefaultExts | AArch64::AEK_BF16 | AArch64::AEK_I8MM)}; inline constexpr ArchInfo ARMV9_2A = { VersionTuple{9, 2}, AProfile, "armv9.2-a", "+v9.2a", (ARMV9_1A.DefaultExts)}; inline constexpr ArchInfo ARMV9_3A = { VersionTuple{9, 3}, AProfile, "armv9.3-a", "+v9.3a", (ARMV9_2A.DefaultExts | AArch64::AEK_MOPS | AArch64::AEK_HBC)}; @@ -509,4 +509,4 @@ uint64_t getCpuSupportsMask(ArrayRef FeatureStrs); } // namespace AArch64 } // namespace llvm -#endif \ No newline at end of file +#endif 
diff --git a/llvm/unittests/TargetParser/TargetParserTest.cpp b/llvm/unittests/TargetParser/TargetParserTest.cpp index 7bf20ef87436b3..39de9e25c8ccad 100644 --- a/llvm/unittests/TargetParser/TargetParserTest.cpp +++ b/llvm/unittests/TargetParser/TargetParserTest.cpp @@ -68,19 +68,26 @@ std::string FormatExtensionFlags(uint64_t Flags) { return llvm::join(Features, ", "); } -template -testing::AssertionResult -AssertSameExtensionFlags(const char *m_expr, const char *n_expr, - uint64_t ExpectedFlags, uint64_t GotFlags) { - if (ExpectedFlags == GotFlags) - return testing::AssertionSuccess(); - - return testing::AssertionFailure() << llvm::formatv( - "Expected extension flags: {0} ({1:x})\n" - " Got extension flags: {2} ({3:x})\n", - FormatExtensionFlags(ExpectedFlags), ExpectedFlags, - FormatExtensionFlags(GotFlags), GotFlags); -} +template struct AssertSameExtensionFlags { + AssertSameExtensionFlags(StringRef CPUName) : CPUName(CPUName) {} + + testing::AssertionResult operator()(const char *m_expr, const char *n_expr, + uint64_t ExpectedFlags, + uint64_t GotFlags) { + if (ExpectedFlags == GotFlags) + return testing::AssertionSuccess(); + + return testing::AssertionFailure() << llvm::formatv( + "CPU: {4}\n" + "Expected extension flags: {0} ({1:x})\n" + " Got extension flags: {2} ({3:x})\n", + FormatExtensionFlags(ExpectedFlags), ExpectedFlags, + FormatExtensionFlags(GotFlags), GotFlags, CPUName); + } + +private: + StringRef CPUName; +}; struct ARMCPUTestParams { ARMCPUTestParams(StringRef CPUName, StringRef ExpectedArch, @@ -116,8 +123,9 @@ TEST_P(ARMCPUTestFixture, ARMCPUTests) { EXPECT_EQ(params.ExpectedFPU, ARM::getFPUName(FPUKind)); uint64_t default_extensions = ARM::getDefaultExtensions(params.CPUName, AK); - EXPECT_PRED_FORMAT2(AssertSameExtensionFlags, - params.ExpectedFlags, default_extensions); + EXPECT_PRED_FORMAT2( + AssertSameExtensionFlags(params.CPUName), + params.ExpectedFlags, default_extensions); EXPECT_EQ(params.CPUAttr, ARM::getCPUAttr(AK)); } @@ 
-958,8 +966,9 @@ TEST_P(AArch64CPUTestFixture, testAArch64CPU) { uint64_t default_extensions = AArch64::getDefaultExtensions(params.CPUName, AI); - EXPECT_PRED_FORMAT2(AssertSameExtensionFlags, - params.ExpectedFlags, default_extensions); + EXPECT_PRED_FORMAT2( + AssertSameExtensionFlags(params.CPUName), + params.ExpectedFlags, default_extensions); } INSTANTIATE_TEST_SUITE_P( @@ -997,7 +1006,8 @@ INSTANTIATE_TEST_SUITE_P( AArch64::AEK_SVE | AArch64::AEK_SVE2 | AArch64::AEK_SVE2BITPERM | AArch64::AEK_PAUTH | AArch64::AEK_MTE | AArch64::AEK_SSBS | - AArch64::AEK_FP16FML | AArch64::AEK_SB, + AArch64::AEK_FP16 | AArch64::AEK_FP16FML | + AArch64::AEK_SB, "9-A"), ARMCPUTestParams("cortex-a57", "armv8-a", "crypto-neon-fp-armv8", AArch64::AEK_CRC | AArch64::AEK_CRYPTO | @@ -1076,17 +1086,16 @@ INSTANTIATE_TEST_SUITE_P( AArch64::AEK_PROFILE | AArch64::AEK_FLAGM | AArch64::AEK_PAUTH | AArch64::AEK_FP16FML, "8.2-A"), - ARMCPUTestParams("cortex-a710", "armv9-a", "neon-fp-armv8", - AArch64::AEK_CRC | AArch64::AEK_FP | - AArch64::AEK_SIMD | AArch64::AEK_RAS | - AArch64::AEK_LSE | AArch64::AEK_RDM | - AArch64::AEK_RCPC | AArch64::AEK_DOTPROD | - AArch64::AEK_MTE | AArch64::AEK_FP16FML | - AArch64::AEK_SVE | AArch64::AEK_SVE2 | - AArch64::AEK_SVE2BITPERM | AArch64::AEK_PAUTH | - AArch64::AEK_FLAGM | AArch64::AEK_SB | - AArch64::AEK_I8MM | AArch64::AEK_BF16, - "9-A"), + ARMCPUTestParams( + "cortex-a710", "armv9-a", "neon-fp-armv8", + AArch64::AEK_CRC | AArch64::AEK_FP | AArch64::AEK_SIMD | + AArch64::AEK_RAS | AArch64::AEK_LSE | AArch64::AEK_RDM | + AArch64::AEK_RCPC | AArch64::AEK_DOTPROD | AArch64::AEK_MTE | + AArch64::AEK_FP16 | AArch64::AEK_FP16FML | AArch64::AEK_SVE | + AArch64::AEK_SVE2 | AArch64::AEK_SVE2BITPERM | + AArch64::AEK_PAUTH | AArch64::AEK_FLAGM | AArch64::AEK_SB | + AArch64::AEK_I8MM | AArch64::AEK_BF16, + "9-A"), ARMCPUTestParams("cortex-a715", "armv9-a", "neon-fp-armv8", AArch64::AEK_CRC | AArch64::AEK_FP | AArch64::AEK_BF16 | AArch64::AEK_SIMD | 
AArch64::AEK_RAS | @@ -1155,7 +1164,8 @@ INSTANTIATE_TEST_SUITE_P( AArch64::AEK_I8MM | AArch64::AEK_BF16 | AArch64::AEK_SVE | AArch64::AEK_SVE2 | AArch64::AEK_SVE2BITPERM | AArch64::AEK_SSBS | - AArch64::AEK_SB | AArch64::AEK_FP16FML, + AArch64::AEK_SB | AArch64::AEK_FP16 | + AArch64::AEK_FP16FML, "9-A"), ARMCPUTestParams("cortex-x3", "armv9-a", "neon-fp-armv8", AArch64::AEK_CRC | AArch64::AEK_FP | AArch64::AEK_BF16 | From 9c0f9572e21e8ef9cdbab4ead88e02cc53e7ad1f Mon Sep 17 00:00:00 2001 From: Nikita Popov Date: Thu, 19 Jan 2023 16:50:02 +0100 Subject: [PATCH 06/61] [InstCombine] Add test for load type conversion with !noundef (NFC) The !noundef metadata is currently dropped. --- .../InstCombine/loadstore-metadata.ll | 24 +++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/llvm/test/Transforms/InstCombine/loadstore-metadata.ll b/llvm/test/Transforms/InstCombine/loadstore-metadata.ll index d0c62b1eee02b4..5205b7d3f539c8 100644 --- a/llvm/test/Transforms/InstCombine/loadstore-metadata.ll +++ b/llvm/test/Transforms/InstCombine/loadstore-metadata.ll @@ -76,8 +76,8 @@ define ptr @test_load_cast_combine_align(ptr %ptr) { ; metadata. ; CHECK-LABEL: @test_load_cast_combine_align( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[L1:%.*]] = load ptr, ptr [[PTR:%.*]], align 8, !align !8 -; CHECK-NEXT: ret ptr [[L1]] +; CHECK-NEXT: [[L:%.*]] = load ptr, ptr [[PTR:%.*]], align 8, !align !8 +; CHECK-NEXT: ret ptr [[L]] ; entry: %l = load ptr, ptr %ptr, !align !9 @@ -89,8 +89,8 @@ define ptr @test_load_cast_combine_deref(ptr %ptr) { ; metadata. 
; CHECK-LABEL: @test_load_cast_combine_deref( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[L1:%.*]] = load ptr, ptr [[PTR:%.*]], align 8, !dereferenceable !8 -; CHECK-NEXT: ret ptr [[L1]] +; CHECK-NEXT: [[L:%.*]] = load ptr, ptr [[PTR:%.*]], align 8, !dereferenceable !8 +; CHECK-NEXT: ret ptr [[L]] ; entry: %l = load ptr, ptr %ptr, !dereferenceable !9 @@ -102,8 +102,8 @@ define ptr @test_load_cast_combine_deref_or_null(ptr %ptr) { ; dereferenceable_or_null metadata. ; CHECK-LABEL: @test_load_cast_combine_deref_or_null( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[L1:%.*]] = load ptr, ptr [[PTR:%.*]], align 8, !dereferenceable_or_null !8 -; CHECK-NEXT: ret ptr [[L1]] +; CHECK-NEXT: [[L:%.*]] = load ptr, ptr [[PTR:%.*]], align 8, !dereferenceable_or_null !8 +; CHECK-NEXT: ret ptr [[L]] ; entry: %l = load ptr, ptr %ptr, !dereferenceable_or_null !9 @@ -122,7 +122,7 @@ define void @test_load_cast_combine_loop(ptr %src, ptr %dst, i32 %n) { ; CHECK-NEXT: [[SRC_GEP:%.*]] = getelementptr inbounds float, ptr [[SRC:%.*]], i64 [[TMP0]] ; CHECK-NEXT: [[TMP1:%.*]] = sext i32 [[I]] to i64 ; CHECK-NEXT: [[DST_GEP:%.*]] = getelementptr inbounds i32, ptr [[DST:%.*]], i64 [[TMP1]] -; CHECK-NEXT: [[L1:%.*]] = load i32, ptr [[SRC_GEP]], align 4, !llvm.access.group !9 +; CHECK-NEXT: [[L1:%.*]] = load i32, ptr [[SRC_GEP]], align 4, !llvm.access.group [[ACC_GRP9:![0-9]+]] ; CHECK-NEXT: store i32 [[L1]], ptr [[DST_GEP]], align 4 ; CHECK-NEXT: [[I_NEXT]] = add i32 [[I]], 1 ; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[I_NEXT]], [[N:%.*]] @@ -163,6 +163,16 @@ entry: ret void } +define i32 @test_load_cast_combine_noundef(ptr %ptr) { +; CHECK-LABEL: @test_load_cast_combine_noundef( +; CHECK-NEXT: [[L1:%.*]] = load i32, ptr [[PTR:%.*]], align 4 +; CHECK-NEXT: ret i32 [[L1]] +; + %l = load float, ptr %ptr, !noundef !{} + %c = bitcast float %l to i32 + ret i32 %c +} + !0 = !{!1, !1, i64 0} !1 = !{!"scalar type", !2} !2 = !{!"root"} From f13d6001324e9c9653d8568c1d86e182b217e272 Mon Sep 17 00:00:00 2001 From: 
Kelvin Li Date: Thu, 19 Jan 2023 10:52:10 -0500 Subject: [PATCH 07/61] [flang] Add semantic check for multiple part-ref with non-zero rank This patch is to diagnose the case when a type bound procedure is passed as an actual procedure argument. call sub0(t%t3%t2%t%info1) Fix: https://github.com/llvm/llvm-project/issues/55826 Committed on behalf of DanielCChen Differential Revision: https://reviews.llvm.org/D141506 --- flang/lib/Semantics/expression.cpp | 3 +++ flang/test/Semantics/expr-errors04.f90 | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp index a6fc90693aae3b..a4c18436007153 100644 --- a/flang/lib/Semantics/expression.cpp +++ b/flang/lib/Semantics/expression.cpp @@ -207,6 +207,9 @@ MaybeExpr ExpressionAnalyzer::Designate(DataRef &&ref) { last.name()); } if (auto *component{std::get_if(&ref.u)}) { + if (!CheckDataRef(ref)) { + return std::nullopt; + } return Expr{ProcedureDesignator{std::move(*component)}}; } else if (!std::holds_alternative(ref.u)) { DIE("unexpected alternative in DataRef"); diff --git a/flang/test/Semantics/expr-errors04.f90 b/flang/test/Semantics/expr-errors04.f90 index b8e0f221215b9d..be794c7c78f6d1 100644 --- a/flang/test/Semantics/expr-errors04.f90 +++ b/flang/test/Semantics/expr-errors04.f90 @@ -92,6 +92,8 @@ subroutine real_generic() !ERROR: Reference to whole rank-2 component 't1' of rank-1 array of derived type is not allowed call sub0(t%t3%t2%t1%info1(i)) !ERROR: Reference to whole rank-2 component 't1' of rank-1 array of derived type is not allowed + call sub0(t%t3%t2%t1%info1) + !ERROR: Reference to whole rank-2 component 't1' of rank-1 array of derived type is not allowed call t%t3%t2%t1%info2 !ERROR: Reference to whole rank-2 component 't1' of rank-1 array of derived type is not allowed call t%t3%t2%t1%g1 @@ -99,6 +101,8 @@ subroutine real_generic() !ERROR: Reference to rank-2 object 't1' has 1 subscripts call sub0(t%t3%t2%t1(1)%info1(i)) !ERROR: 
Reference to rank-2 object 't1' has 1 subscripts + call sub0(t%t3%t2%t1(1)%info1) + !ERROR: Reference to rank-2 object 't1' has 1 subscripts call t%t3%t2%t1(1)%info2 !ERROR: Reference to rank-2 object 't1' has 1 subscripts call t%t3%t2%t1(1)%g1 @@ -106,6 +110,8 @@ subroutine real_generic() !ERROR: Reference to rank-2 object 't1' has 1 subscripts call sub0(t%t3%t2%t1(1:)%info1(i)) !ERROR: Reference to rank-2 object 't1' has 1 subscripts + call sub0(t%t3%t2%t1(1:)%info1) + !ERROR: Reference to rank-2 object 't1' has 1 subscripts call t%t3%t2%t1(1:)%info2 !ERROR: Reference to rank-2 object 't1' has 1 subscripts call t%t3%t2%t1(1:)%g1 From b3b049a824a09749e106ee477718a6145fe4c23f Mon Sep 17 00:00:00 2001 From: Nikita Popov Date: Thu, 19 Jan 2023 16:55:23 +0100 Subject: [PATCH 08/61] [Local] Preserve noundef metadata in copyMetadataForLoad() If we're only changing the type of the load, preserve the noundef metadata. --- llvm/lib/Transforms/Utils/Local.cpp | 1 + llvm/test/Transforms/InstCombine/loadstore-metadata.ll | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp index 00cbee9a25e708..a57035d165ed52 100644 --- a/llvm/lib/Transforms/Utils/Local.cpp +++ b/llvm/lib/Transforms/Utils/Local.cpp @@ -2752,6 +2752,7 @@ void llvm::copyMetadataForLoad(LoadInst &Dest, const LoadInst &Source) { case LLVMContext::MD_nontemporal: case LLVMContext::MD_mem_parallel_loop_access: case LLVMContext::MD_access_group: + case LLVMContext::MD_noundef: // All of these directly apply. 
Dest.setMetadata(ID, N); break; diff --git a/llvm/test/Transforms/InstCombine/loadstore-metadata.ll b/llvm/test/Transforms/InstCombine/loadstore-metadata.ll index 5205b7d3f539c8..1ee3c3a25e2243 100644 --- a/llvm/test/Transforms/InstCombine/loadstore-metadata.ll +++ b/llvm/test/Transforms/InstCombine/loadstore-metadata.ll @@ -165,7 +165,7 @@ entry: define i32 @test_load_cast_combine_noundef(ptr %ptr) { ; CHECK-LABEL: @test_load_cast_combine_noundef( -; CHECK-NEXT: [[L1:%.*]] = load i32, ptr [[PTR:%.*]], align 4 +; CHECK-NEXT: [[L1:%.*]] = load i32, ptr [[PTR:%.*]], align 4, !noundef !6 ; CHECK-NEXT: ret i32 [[L1]] ; %l = load float, ptr %ptr, !noundef !{} From 3508f69174e663d276b58c55e176688f828591ac Mon Sep 17 00:00:00 2001 From: Jean Perier Date: Thu, 19 Jan 2023 16:56:56 +0100 Subject: [PATCH 09/61] [flang][hlfir] Lower whole allocatable or pointer component ref Compared to other component ref lowering, the hlfir.designate result type computation is different, and the allocatable/pointer/contiguous attributes must be set on the hlfir.designate so that the component attributes are kept in the IR. 
Differential Revision: https://reviews.llvm.org/D142111 --- flang/include/flang/Lower/ConvertVariable.h | 6 + flang/lib/Lower/ConvertExprToHLFIR.cpp | 41 ++++-- flang/lib/Lower/ConvertVariable.cpp | 9 +- .../HLFIR/Transforms/ConvertToFIR.cpp | 14 ++- .../allocatable-and-pointer-components.f90 | 119 ++++++++++++++++++ 5 files changed, 172 insertions(+), 17 deletions(-) create mode 100644 flang/test/Lower/HLFIR/allocatable-and-pointer-components.f90 diff --git a/flang/include/flang/Lower/ConvertVariable.h b/flang/include/flang/Lower/ConvertVariable.h index 0e2932fa1ab850..a77dcedea9020e 100644 --- a/flang/include/flang/Lower/ConvertVariable.h +++ b/flang/include/flang/Lower/ConvertVariable.h @@ -25,6 +25,7 @@ namespace fir { class ExtendedValue; class FirOpBuilder; class GlobalOp; +class FortranVariableFlagsAttr; } // namespace fir namespace Fortran ::lower { @@ -110,5 +111,10 @@ void createRuntimeTypeInfoGlobal(Fortran::lower::AbstractConverter &converter, mlir::Location loc, const Fortran::semantics::Symbol &typeInfoSym); +/// Translate the Fortran attributes of \p sym into the FIR variable attribute +/// representation. 
+fir::FortranVariableFlagsAttr +translateSymbolAttributes(mlir::MLIRContext *mlirContext, + const Fortran::semantics::Symbol &sym); } // namespace Fortran::lower #endif // FORTRAN_LOWER_CONVERT_VARIABLE_H diff --git a/flang/lib/Lower/ConvertExprToHLFIR.cpp b/flang/lib/Lower/ConvertExprToHLFIR.cpp index 566840c62ba1f0..09c4c7a20e6fd2 100644 --- a/flang/lib/Lower/ConvertExprToHLFIR.cpp +++ b/flang/lib/Lower/ConvertExprToHLFIR.cpp @@ -17,6 +17,7 @@ #include "flang/Lower/ConvertCall.h" #include "flang/Lower/ConvertConstant.h" #include "flang/Lower/ConvertType.h" +#include "flang/Lower/ConvertVariable.h" #include "flang/Lower/IntrinsicCall.h" #include "flang/Lower/StatementContext.h" #include "flang/Lower/SymbolMap.h" @@ -136,12 +137,17 @@ class HlfirDesignatorBuilder { const T &designatorNode) { mlir::Type designatorType = computeDesignatorType(resultValueType, partInfo, designatorNode); + return genDesignate(designatorType, partInfo, /*attributes=*/{}); + } + fir::FortranVariableOpInterface + genDesignate(mlir::Type designatorType, PartInfo &partInfo, + fir::FortranVariableFlagsAttr attributes) { std::optional complexPart; auto designate = getBuilder().create( getLoc(), designatorType, partInfo.base.getBase(), partInfo.componentName, partInfo.componentShape, partInfo.subscripts, partInfo.substring, complexPart, partInfo.resultShape, - partInfo.typeParams); + partInfo.typeParams, attributes); return mlir::cast( designate.getOperation()); } @@ -157,6 +163,17 @@ class HlfirDesignatorBuilder { fir::FortranVariableOpInterface gen(const Fortran::evaluate::Component &component) { PartInfo partInfo; + if (Fortran::semantics::IsAllocatableOrPointer(component.GetLastSymbol())) { + // Generate whole allocatable or pointer component reference. The + // hlfir.designate result will be a pointer/allocatable. 
+ auto [_, componentType] = visitComponentImpl( + component, partInfo, /*dereferencePointerAndAllocComponents=*/false); + mlir::Type designatorType = fir::ReferenceType::get(componentType); + fir::FortranVariableFlagsAttr attributes = + Fortran::lower::translateSymbolAttributes(getBuilder().getContext(), + component.GetLastSymbol()); + return genDesignate(designatorType, partInfo, attributes); + } mlir::Type resultType = visit(component, partInfo); return genDesignate(resultType, partInfo, component); } @@ -280,7 +297,8 @@ class HlfirDesignatorBuilder { PartInfo &partInfo) { mlir::Type baseType; if (const auto *component = arrayRef.base().UnwrapComponent()) - baseType = visitComponentImpl(*component, partInfo).second; + baseType = hlfir::getFortranElementOrSequenceType( + visitComponentImpl(*component, partInfo).second); else baseType = visit(arrayRef.base().GetLastSymbol(), partInfo); @@ -428,6 +446,8 @@ class HlfirDesignatorBuilder { // array. The code below determines the shape of the component reference if // any. auto [baseType, componentType] = visitComponentImpl(component, partInfo); + mlir::Type componentBaseType = + hlfir::getFortranElementOrSequenceType(componentType); if (partInfo.base.isArray()) { // For array%scalar_comp, the result shape is // the one of the base. Compute it here. Note that the lower bounds of the @@ -436,13 +456,13 @@ class HlfirDesignatorBuilder { partInfo.resultShape = hlfir::genShape(loc, getBuilder(), partInfo.base); assert(!partInfo.componentShape && "Fortran designators can only have one ranked part"); - return changeElementType(baseType, componentType); + return changeElementType(baseType, componentBaseType); } // scalar%array_comp or scalar%scalar. In any case the shape of this // part-ref is coming from the component. 
partInfo.resultShape = partInfo.componentShape; partInfo.componentShape = {}; - return componentType; + return componentBaseType; } // Returns the pair, computes partInfo.base, @@ -451,7 +471,8 @@ class HlfirDesignatorBuilder { // processing a following ArrayRef, if any, and in "visit" otherwise. std::pair visitComponentImpl(const Fortran::evaluate::Component &component, - PartInfo &partInfo) { + PartInfo &partInfo, + bool dereferencePointerAndAllocComponents = true) { fir::FirOpBuilder &builder = getBuilder(); // Break the Designator visit here: if the base is an array-ref, a // coarray-ref, or another component, this creates another hlfir.designate @@ -473,10 +494,11 @@ class HlfirDesignatorBuilder { if (recordType.isDependentType()) TODO(getLoc(), "Designate derived type with length parameters in HLFIR"); mlir::Type fieldType = recordType.getType(partInfo.componentName); - fieldType = hlfir::getFortranElementOrSequenceType(fieldType); - partInfo.componentShape = genComponentShape(componentSym, fieldType); + mlir::Type fieldBaseType = + hlfir::getFortranElementOrSequenceType(fieldType); + partInfo.componentShape = genComponentShape(componentSym, fieldBaseType); - mlir::Type fieldEleType = hlfir::getFortranElementType(fieldType); + mlir::Type fieldEleType = hlfir::getFortranElementType(fieldBaseType); if (fir::isRecordWithTypeParameters(fieldEleType)) TODO(loc, "lower a component that is a parameterized derived type to HLFIR"); @@ -496,7 +518,8 @@ class HlfirDesignatorBuilder { // For pointers and allocatables, if there is a substring, complex part or // array ref, the designator should be broken here and the pointer or // allocatable dereferenced. 
- if (Fortran::semantics::IsAllocatableOrPointer(componentSym)) + if (Fortran::semantics::IsAllocatableOrPointer(componentSym) && + dereferencePointerAndAllocComponents) TODO(loc, "lowering ref to allocatable or pointer component to HLFIR"); return {baseType, fieldType}; diff --git a/flang/lib/Lower/ConvertVariable.cpp b/flang/lib/Lower/ConvertVariable.cpp index fff076bb054985..83b56567d5f242 100644 --- a/flang/lib/Lower/ConvertVariable.cpp +++ b/flang/lib/Lower/ConvertVariable.cpp @@ -1311,9 +1311,8 @@ recoverShapeVector(llvm::ArrayRef shapeVec, mlir::Value initVal) { return result; } -static fir::FortranVariableFlagsAttr -translateSymbolAttributes(mlir::MLIRContext *mlirContext, - const Fortran::semantics::Symbol &sym) { +fir::FortranVariableFlagsAttr Fortran::lower::translateSymbolAttributes( + mlir::MLIRContext *mlirContext, const Fortran::semantics::Symbol &sym) { fir::FortranVariableFlagsEnum flags = fir::FortranVariableFlagsEnum::None; const auto &attrs = sym.attrs(); if (attrs.test(Fortran::semantics::Attr::ALLOCATABLE)) @@ -1372,7 +1371,7 @@ static void genDeclareSymbol(Fortran::lower::AbstractConverter &converter, lenParams.emplace_back(len); auto name = Fortran::lower::mangle::mangleName(sym); fir::FortranVariableFlagsAttr attributes = - translateSymbolAttributes(builder.getContext(), sym); + Fortran::lower::translateSymbolAttributes(builder.getContext(), sym); auto newBase = builder.create( loc, base, name, shapeOrShift, lenParams, attributes); symMap.addVariableDefinition(sym, newBase, force); @@ -1411,7 +1410,7 @@ static void genDeclareSymbol(Fortran::lower::AbstractConverter &converter, fir::FirOpBuilder &builder = converter.getFirOpBuilder(); const mlir::Location loc = genLocation(converter, sym); fir::FortranVariableFlagsAttr attributes = - translateSymbolAttributes(builder.getContext(), sym); + Fortran::lower::translateSymbolAttributes(builder.getContext(), sym); auto name = Fortran::lower::mangle::mangleName(sym); hlfir::EntityWithAttributes 
declare = hlfir::genDeclare(loc, builder, exv, name, attributes); diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp index c426cddbd135e4..5487defb69dc57 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp @@ -214,11 +214,19 @@ class DesignateOpConversion // - scalar%array(indices) [substring| complex part] mlir::Type componentType = baseEleTy.cast().getType( designate.getComponent().value()); - if (componentType.isa()) - TODO(loc, - "addressing parametrized derived type automatic components"); mlir::Type coorTy = fir::ReferenceType::get(componentType); base = builder.create(loc, coorTy, base, fieldIndex); + if (componentType.isa()) { + auto variableInterface = mlir::cast( + designate.getOperation()); + if (variableInterface.isAllocatable() || + variableInterface.isPointer()) { + rewriter.replaceOp(designate, base); + return mlir::success(); + } + TODO(loc, + "addressing parametrized derived type automatic components"); + } baseEleTy = hlfir::getFortranElementType(componentType); shape = designate.getComponentShape(); } else { diff --git a/flang/test/Lower/HLFIR/allocatable-and-pointer-components.f90 b/flang/test/Lower/HLFIR/allocatable-and-pointer-components.f90 new file mode 100644 index 00000000000000..c29b62aef251df --- /dev/null +++ b/flang/test/Lower/HLFIR/allocatable-and-pointer-components.f90 @@ -0,0 +1,119 @@ +! Test lowering of whole allocatable and pointer components to HLFIR +! 
RUN: bbc -emit-fir -hlfir -o - %s -I nw | FileCheck %s + +module def_test_types + type t1 + real, pointer :: p(:) + end type + type t2 + real, allocatable :: a(:) + end type + type t3 + real, pointer, contiguous :: p_contiguous(:) + end type + type t4 + character(:), pointer :: char_p(:) + end type + type t5 + character(10), allocatable :: char_a(:) + end type + interface + subroutine takes_pointer(y) + real, pointer :: y(:) + end subroutine + subroutine takes_contiguous_pointer(y) + real, pointer, contiguous :: y(:) + end subroutine + subroutine takes_allocatable(y) + real, allocatable :: y(:) + end subroutine + subroutine takes_char_pointer(y) + character(:), pointer :: y(:) + end subroutine + subroutine takes_char_alloc_cst_len(y) + character(10), allocatable :: y(:) + end subroutine + subroutine takes_array(y) + real :: y(*) + end subroutine + subroutine takes_char_array(y) + character(*) :: y(*) + end subroutine + + end interface +end module + +subroutine passing_pointer(x) + use def_test_types + implicit none + type(t1) :: x + call takes_pointer(x%p) +end subroutine +! CHECK-LABEL: func.func @_QPpassing_pointer( +! CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0:[a-z0-9]*]] {{.*}}Ex +! CHECK: %[[VAL_2:.*]] = hlfir.designate %[[VAL_1]]#0{"p"} {fortran_attrs = #fir.var_attrs} : (!fir.ref>>}>>) -> !fir.ref>>> +! CHECK: fir.call @_QPtakes_pointer(%[[VAL_2]]) {{.*}}: (!fir.ref>>>) -> () + +subroutine passing_allocatable(x) + use def_test_types + implicit none + type(t2) :: x + call takes_allocatable(x%a) + call takes_array(x%a) +end subroutine +! CHECK-LABEL: func.func @_QPpassing_allocatable( +! CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0:[a-z0-9]*]] {{.*}}Ex +! CHECK: %[[VAL_2:.*]] = hlfir.designate %[[VAL_1]]#0{"a"} {fortran_attrs = #fir.var_attrs} : (!fir.ref>>}>>) -> !fir.ref>>> +! CHECK: fir.call @_QPtakes_allocatable(%[[VAL_2]]) {{.*}}: (!fir.ref>>>) -> () +! 
CHECK: %[[VAL_3:.*]] = hlfir.designate %[[VAL_1]]#0{"a"} {fortran_attrs = #fir.var_attrs} : (!fir.ref>>}>>) -> !fir.ref>>> +! CHECK: %[[VAL_4:.*]] = fir.load %[[VAL_3]] : !fir.ref>>> +! CHECK: %[[VAL_5:.*]] = fir.box_addr %[[VAL_4]] : (!fir.box>>) -> !fir.heap> +! CHECK: %[[VAL_6:.*]] = fir.convert %[[VAL_5]] : (!fir.heap>) -> !fir.ref> +! CHECK: fir.call @_QPtakes_array(%[[VAL_6]]) {{.*}}: (!fir.ref>) -> () + +subroutine passing_contiguous_pointer(x) + use def_test_types + type(t3) :: x + call takes_contiguous_pointer(x%p_contiguous) + call takes_array(x%p_contiguous) +end subroutine +! CHECK-LABEL: func.func @_QPpassing_contiguous_pointer( +! CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0:[a-z0-9]*]] {{.*}}Ex +! CHECK: %[[VAL_2:.*]] = hlfir.designate %[[VAL_1]]#0{"p_contiguous"} {fortran_attrs = #fir.var_attrs} : (!fir.ref>>}>>) -> !fir.ref>>> +! CHECK: fir.call @_QPtakes_contiguous_pointer(%[[VAL_2]]) {{.*}}: (!fir.ref>>>) -> () +! CHECK: %[[VAL_3:.*]] = hlfir.designate %[[VAL_1]]#0{"p_contiguous"} {fortran_attrs = #fir.var_attrs} : (!fir.ref>>}>>) -> !fir.ref>>> +! CHECK: %[[VAL_4:.*]] = fir.load %[[VAL_3]] : !fir.ref>>> +! CHECK: %[[VAL_5:.*]] = fir.box_addr %[[VAL_4]] : (!fir.box>>) -> !fir.ptr> +! CHECK: %[[VAL_6:.*]] = fir.convert %[[VAL_5]] : (!fir.ptr>) -> !fir.ref> +! CHECK: fir.call @_QPtakes_array(%[[VAL_6]]) {{.*}}: (!fir.ref>) -> () + +subroutine passing_char_pointer(x) + use def_test_types + implicit none + type(t4) :: x + call takes_char_pointer(x%char_p) +end subroutine +! CHECK-LABEL: func.func @_QPpassing_char_pointer( +! CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0:[a-z0-9]*]] {{.*}}Ex +! CHECK: %[[VAL_2:.*]] = hlfir.designate %[[VAL_1]]#0{"char_p"} {fortran_attrs = #fir.var_attrs} : (!fir.ref>>>}>>) -> !fir.ref>>>> +! 
CHECK: fir.call @_QPtakes_char_pointer(%[[VAL_2]]) {{.*}}: (!fir.ref>>>>) -> () + +subroutine passing_char_alloc_cst_len(x) + use def_test_types + implicit none + type(t5) :: x + call takes_char_alloc_cst_len(x%char_a) + call takes_char_array(x%char_a) +end subroutine +! CHECK-LABEL: func.func @_QPpassing_char_alloc_cst_len( +! CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0:[a-z0-9]*]] {{.*}}Ex +! CHECK: %[[VAL_2:.*]] = arith.constant 10 : index +! CHECK: %[[VAL_3:.*]] = hlfir.designate %[[VAL_1]]#0{"char_a"} typeparams %[[VAL_2]] {fortran_attrs = #fir.var_attrs} : (!fir.ref>>>}>>, index) -> !fir.ref>>>> +! CHECK: fir.call @_QPtakes_char_alloc_cst_len(%[[VAL_3]]) {{.*}}: (!fir.ref>>>>) -> () +! CHECK: %[[VAL_4:.*]] = arith.constant 10 : index +! CHECK: %[[VAL_5:.*]] = hlfir.designate %[[VAL_1]]#0{"char_a"} typeparams %[[VAL_4]] {fortran_attrs = #fir.var_attrs} : (!fir.ref>>>}>>, index) -> !fir.ref>>>> +! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref>>>> +! CHECK: %[[VAL_7:.*]] = fir.box_addr %[[VAL_6]] : (!fir.box>>>) -> !fir.heap>> +! CHECK: %[[VAL_8:.*]] = fir.convert %[[VAL_7]] : (!fir.heap>>) -> !fir.ref> +! CHECK: %[[VAL_9:.*]] = fir.emboxchar %[[VAL_8]], %[[VAL_4]] : (!fir.ref>, index) -> !fir.boxchar<1> +! CHECK: fir.call @_QPtakes_char_array(%[[VAL_9]]) {{.*}}: (!fir.boxchar<1>) -> () From daa316bcaf717e1dacdfee443f2c325a783d2c70 Mon Sep 17 00:00:00 2001 From: Yitzhak Mandelbaum Date: Fri, 13 Jan 2023 18:33:52 +0000 Subject: [PATCH 10/61] [clang][dataflow] Fix bug in joining bool values. Currently, the code assumes that all boolean-typed values are an instance of `BoolValue` (or its subclasses). Yet, lvalues violate this assumption. This patch drops the assumption and strengthens the check to confirm the shape of both values being joined. The patch also notes as FIXMES a number of problems discovered fixing this bug. 
Differential Revision: https://reviews.llvm.org/D141709 --- .../FlowSensitive/DataflowEnvironment.cpp | 22 +++++++------ clang/lib/Analysis/FlowSensitive/Transfer.cpp | 2 ++ .../TypeErasedDataflowAnalysisTest.cpp | 31 +++++++++++++++++++ 3 files changed, 46 insertions(+), 9 deletions(-) diff --git a/clang/lib/Analysis/FlowSensitive/DataflowEnvironment.cpp b/clang/lib/Analysis/FlowSensitive/DataflowEnvironment.cpp index 064d0f92e71a4b..cc3992805cc78c 100644 --- a/clang/lib/Analysis/FlowSensitive/DataflowEnvironment.cpp +++ b/clang/lib/Analysis/FlowSensitive/DataflowEnvironment.cpp @@ -93,18 +93,20 @@ static Value *mergeDistinctValues(QualType Type, Value &Val1, Environment::ValueModel &Model) { // Join distinct boolean values preserving information about the constraints // in the respective path conditions. - if (Type->isBooleanType()) { - // FIXME: The type check above is a workaround and should be unnecessary. - // However, right now we can end up with BoolValue's in integer-typed - // variables due to our incorrect handling of boolean-to-integer casts (we - // just propagate the BoolValue to the result of the cast). For example: + if (isa(&Val1) && isa(&Val2)) { + // FIXME: Checking both values should be unnecessary, since they should have + // a consistent shape. However, right now we can end up with BoolValue's in + // integer-typed variables due to our incorrect handling of + // boolean-to-integer casts (we just propagate the BoolValue to the result + // of the cast). So, a join can encounter an integer in one branch but a + // bool in the other. 
+ // For example: + // ``` // std::optional o; - // - // // int x; - // if (o.has_value()) { + // if (o.has_value()) // x = o.value(); - // } + // ``` auto *Expr1 = cast(&Val1); auto *Expr2 = cast(&Val2); auto &MergedVal = MergedEnv.makeAtomicBoolValue(); @@ -118,6 +120,8 @@ static Value *mergeDistinctValues(QualType Type, Value &Val1, // FIXME: Consider destroying `MergedValue` immediately if `ValueModel::merge` // returns false to avoid storing unneeded values in `DACtx`. + // FIXME: Creating the value based on the type alone creates misshapen values + // for lvalues, since the type does not reflect the need for `ReferenceValue`. if (Value *MergedVal = MergedEnv.createValue(Type)) if (Model.merge(Type, Val1, Env1, Val2, Env2, *MergedVal, MergedEnv)) return MergedVal; diff --git a/clang/lib/Analysis/FlowSensitive/Transfer.cpp b/clang/lib/Analysis/FlowSensitive/Transfer.cpp index 259b82d6479814..9fa17c86d2d9de 100644 --- a/clang/lib/Analysis/FlowSensitive/Transfer.cpp +++ b/clang/lib/Analysis/FlowSensitive/Transfer.cpp @@ -110,6 +110,8 @@ static BoolValue &unpackValue(BoolValue &V, Environment &Env) { // Unpacks the value (if any) associated with `E` and updates `E` to the new // value, if any unpacking occured. static Value *maybeUnpackLValueExpr(const Expr &E, Environment &Env) { + // FIXME: this is too flexible: it _allows_ a reference, while it should + // _require_ one, since lvalues should always be wrapped in `ReferenceValue`. 
auto *Loc = Env.getStorageLocation(E, SkipPast::Reference); if (Loc == nullptr) return nullptr; diff --git a/clang/unittests/Analysis/FlowSensitive/TypeErasedDataflowAnalysisTest.cpp b/clang/unittests/Analysis/FlowSensitive/TypeErasedDataflowAnalysisTest.cpp index 06b68afe095861..0b0c77558df926 100644 --- a/clang/unittests/Analysis/FlowSensitive/TypeErasedDataflowAnalysisTest.cpp +++ b/clang/unittests/Analysis/FlowSensitive/TypeErasedDataflowAnalysisTest.cpp @@ -129,6 +129,37 @@ TEST(DataflowAnalysisTest, NonConvergingAnalysis) { "maximum number of iterations reached"); } +// Regression test for joins of bool-typed lvalue expressions. The first loop +// results in two passes through the code that follows. Each pass results in a +// different `ReferenceValue` for the pointee of `v`. Then, the second loop +// causes a join at the loop head where the two environments map expression +// `*v` to different `ReferenceValue`s. +// +// An earlier version crashed for this condition (for boolean-typed lvalues), so +// this test only verifies that the analysis runs successfully, without +// examining any details of the results. 
+TEST(DataflowAnalysisTest, JoinBoolLValues) { + std::string Code = R"( + void target() { + for (int x = 1; x; x = 0) + (void)x; + bool *v; + if (*v) + for (int x = 1; x; x = 0) + (void)x; + } + )"; + ASSERT_THAT_ERROR( + runAnalysis(Code, + [](ASTContext &C) { + auto EnableBuiltIns = DataflowAnalysisOptions{ + DataflowAnalysisContext::Options{}}; + return NoopAnalysis(C, EnableBuiltIns); + }) + .takeError(), + llvm::Succeeded()); +} + struct FunctionCallLattice { using FunctionSet = llvm::SmallSet; FunctionSet CalledFunctions; From 48df06f1d00c6accb396438c04133fb7fdd99d2c Mon Sep 17 00:00:00 2001 From: Matt Devereau Date: Mon, 16 Jan 2023 14:21:18 +0000 Subject: [PATCH 11/61] [AArch64] Allow poison elements of fixed-vectors to be duplicated as a widened element Expanding upon https://reviews.llvm.org/D138203, allow null indices in InsertElts to be matched with any value and be duplicated if the fixed vector the scalar values are inserted into is poison, and the scalable vector the subvector is being inserted into is poison. 
Differential Revision: https://reviews.llvm.org/D141846 --- .../AArch64/AArch64TargetTransformInfo.cpp | 25 ++++++-- .../AArch64/sve-intrinsic-dupqlane.ll | 62 +++++++++++++++++-- 2 files changed, 75 insertions(+), 12 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index 471b05ba379b27..916eefc3441121 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -1436,7 +1436,7 @@ static std::optional instCombineSVESDIV(InstCombiner &IC, return std::nullopt; } -bool SimplifyValuePattern(SmallVector &Vec) { +bool SimplifyValuePattern(SmallVector &Vec, bool AllowPoison) { size_t VecSize = Vec.size(); if (VecSize == 1) return true; @@ -1446,13 +1446,20 @@ bool SimplifyValuePattern(SmallVector &Vec) { for (auto LHS = Vec.begin(), RHS = Vec.begin() + HalfVecSize; RHS != Vec.end(); LHS++, RHS++) { - if (*LHS != nullptr && *RHS != nullptr && *LHS == *RHS) - continue; - return false; + if (*LHS != nullptr && *RHS != nullptr) { + if (*LHS == *RHS) + continue; + else + return false; + } + if (!AllowPoison) + return false; + if (*LHS == nullptr && *RHS != nullptr) + *LHS = *RHS; } Vec.resize(HalfVecSize); - SimplifyValuePattern(Vec); + SimplifyValuePattern(Vec, AllowPoison); return true; } @@ -1476,7 +1483,9 @@ static std::optional instCombineSVEDupqLane(InstCombiner &IC, CurrentInsertElt = InsertElt->getOperand(0); } - if (!SimplifyValuePattern(Elts)) + bool AllowPoison = + isa(CurrentInsertElt) && isa(Default); + if (!SimplifyValuePattern(Elts, AllowPoison)) return std::nullopt; // Rebuild the simplified chain of InsertElements. e.g. 
(a, b, a, b) as (a, b) @@ -1484,9 +1493,13 @@ static std::optional instCombineSVEDupqLane(InstCombiner &IC, Builder.SetInsertPoint(&II); Value *InsertEltChain = PoisonValue::get(CurrentInsertElt->getType()); for (size_t I = 0; I < Elts.size(); I++) { + if (Elts[I] == nullptr) + continue; InsertEltChain = Builder.CreateInsertElement(InsertEltChain, Elts[I], Builder.getInt64(I)); } + if (InsertEltChain == nullptr) + return std::nullopt; // Splat the simplified sequence, e.g. (f16 a, f16 b, f16 c, f16 d) as one i64 // value or (f16 a, f16 b) as one i32 value. This requires an InsertSubvector diff --git a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-dupqlane.ll b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-dupqlane.ll index 9b375836c0fdf0..d059670451f24a 100644 --- a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-dupqlane.ll +++ b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-dupqlane.ll @@ -96,12 +96,11 @@ define dso_local @dupq_f16_abcnull_pattern(half %a, half %b, ; CHECK-NEXT: [[TMP1:%.*]] = insertelement <8 x half> poison, half [[A:%.*]], i64 0 ; CHECK-NEXT: [[TMP2:%.*]] = insertelement <8 x half> [[TMP1]], half [[B:%.*]], i64 1 ; CHECK-NEXT: [[TMP3:%.*]] = insertelement <8 x half> [[TMP2]], half [[C:%.*]], i64 2 -; CHECK-NEXT: [[TMP4:%.*]] = insertelement <8 x half> [[TMP3]], half [[A]], i64 4 -; CHECK-NEXT: [[TMP5:%.*]] = insertelement <8 x half> [[TMP4]], half [[B]], i64 5 -; CHECK-NEXT: [[TMP6:%.*]] = insertelement <8 x half> [[TMP5]], half [[C]], i64 6 -; CHECK-NEXT: [[TMP7:%.*]] = tail call @llvm.vector.insert.nxv8f16.v8f16( poison, <8 x half> [[TMP6]], i64 0) -; CHECK-NEXT: [[TMP8:%.*]] = tail call @llvm.aarch64.sve.dupq.lane.nxv8f16( [[TMP7]], i64 0) -; CHECK-NEXT: ret [[TMP8]] +; CHECK-NEXT: [[TMP4:%.*]] = call @llvm.vector.insert.nxv8f16.v8f16( poison, <8 x half> [[TMP3]], i64 0) +; CHECK-NEXT: [[TMP5:%.*]] = bitcast [[TMP4]] to +; CHECK-NEXT: [[TMP6:%.*]] = shufflevector [[TMP5]], poison, zeroinitializer +; 
CHECK-NEXT: [[TMP7:%.*]] = bitcast [[TMP6]] to +; CHECK-NEXT: ret [[TMP7]] ; %1 = insertelement <8 x half> poison, half %a, i64 0 %2 = insertelement <8 x half> %1, half %b, i64 1 @@ -114,6 +113,57 @@ define dso_local @dupq_f16_abcnull_pattern(half %a, half %b, ret %8 } +define dso_local @dupq_f16_abnull_pattern(half %a, half %b) { +; CHECK-LABEL: @dupq_f16_abnull_pattern( +; CHECK-NEXT: [[TMP1:%.*]] = insertelement <8 x half> poison, half [[A:%.*]], i64 0 +; CHECK-NEXT: [[TMP2:%.*]] = insertelement <8 x half> [[TMP1]], half [[B:%.*]], i64 1 +; CHECK-NEXT: [[TMP3:%.*]] = call @llvm.vector.insert.nxv8f16.v8f16( poison, <8 x half> [[TMP2]], i64 0) +; CHECK-NEXT: [[TMP4:%.*]] = bitcast [[TMP3]] to +; CHECK-NEXT: [[TMP5:%.*]] = shufflevector [[TMP4]], poison, zeroinitializer +; CHECK-NEXT: [[TMP6:%.*]] = bitcast [[TMP5]] to +; CHECK-NEXT: ret [[TMP6]] +; + %1 = insertelement <8 x half> poison, half %a, i64 0 + %2 = insertelement <8 x half> %1, half %b, i64 1 + %3 = tail call @llvm.vector.insert.nxv8f16.v8f16( poison, <8 x half> %2, i64 0) + %4 = tail call @llvm.aarch64.sve.dupq.lane.nxv8f16( %3, i64 0) + ret %4 +} + +define dso_local @neg_dupq_f16_non_poison_fixed(half %a, half %b, <8 x half> %v) { +; CHECK-LABEL: @neg_dupq_f16_non_poison_fixed( +; CHECK-NEXT: [[TMP1:%.*]] = insertelement <8 x half> [[V:%.*]], half [[A:%.*]], i64 0 +; CHECK-NEXT: [[TMP2:%.*]] = insertelement <8 x half> [[TMP1]], half [[B:%.*]], i64 1 +; CHECK-NEXT: [[TMP3:%.*]] = tail call @llvm.vector.insert.nxv8f16.v8f16( poison, <8 x half> [[TMP2]], i64 0) +; CHECK-NEXT: [[TMP4:%.*]] = tail call @llvm.aarch64.sve.dupq.lane.nxv8f16( [[TMP3]], i64 0) +; CHECK-NEXT: ret [[TMP4]] +; + %1 = insertelement <8 x half> %v, half %a, i64 0 + %2 = insertelement <8 x half> %1, half %b, i64 1 + %3 = insertelement <8 x half> %2, half %a, i64 0 + %4 = insertelement <8 x half> %3, half %b, i64 1 + %5 = tail call @llvm.vector.insert.nxv8f16.v8f16( poison, <8 x half> %4, i64 0) + %6 = tail call 
@llvm.aarch64.sve.dupq.lane.nxv8f16( %5, i64 0) + ret %6 +} + +define dso_local @neg_dupq_f16_into_non_poison_scalable(half %a, half %b, %v) { +; CHECK-LABEL: @neg_dupq_f16_into_non_poison_scalable( +; CHECK-NEXT: [[TMP1:%.*]] = insertelement <8 x half> poison, half [[A:%.*]], i64 0 +; CHECK-NEXT: [[TMP2:%.*]] = insertelement <8 x half> [[TMP1]], half [[B:%.*]], i64 1 +; CHECK-NEXT: [[TMP3:%.*]] = tail call @llvm.vector.insert.nxv8f16.v8f16( [[V:%.*]], <8 x half> [[TMP2]], i64 0) +; CHECK-NEXT: [[TMP4:%.*]] = tail call @llvm.aarch64.sve.dupq.lane.nxv8f16( [[TMP3]], i64 0) +; CHECK-NEXT: ret [[TMP4]] +; + %1 = insertelement <8 x half> poison, half %a, i64 0 + %2 = insertelement <8 x half> %1, half %b, i64 1 + %3 = insertelement <8 x half> %2, half %a, i64 0 + %4 = insertelement <8 x half> %3, half %b, i64 1 + %5 = tail call @llvm.vector.insert.nxv8f16.v8f16( %v, <8 x half> %4, i64 0) + %6 = tail call @llvm.aarch64.sve.dupq.lane.nxv8f16( %5, i64 0) + ret %6 +} + ; Insert %c to override the last element in the insertelement chain, which will fail to combine define dso_local @neg_dupq_f16_abcd_pattern_double_insert(half %a, half %b, half %c, half %d) { From e84d69f52d9a9fab9162128d8fe8ebec99ea60da Mon Sep 17 00:00:00 2001 From: Haojian Wu Date: Thu, 19 Jan 2023 17:03:08 +0100 Subject: [PATCH 12/61] [clangd] Fix clangd-fuzzer build. 
--- clang-tools-extra/clangd/fuzzer/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/clang-tools-extra/clangd/fuzzer/CMakeLists.txt b/clang-tools-extra/clangd/fuzzer/CMakeLists.txt index eaf4478585290c..72feb52e1f02ab 100644 --- a/clang-tools-extra/clangd/fuzzer/CMakeLists.txt +++ b/clang-tools-extra/clangd/fuzzer/CMakeLists.txt @@ -14,6 +14,7 @@ clang_target_link_libraries(clangd-fuzzer clangBasic clangFormat clangFrontend + clangIncludeCleaner clangSema clangTooling clangToolingCore From 72dc033fa64257429da787ef6fd82ecb42f72d20 Mon Sep 17 00:00:00 2001 From: Nikita Popov Date: Thu, 19 Jan 2023 17:10:09 +0100 Subject: [PATCH 13/61] [SROA] Check TBAA metadata in tests (NFC) By switching to --check-globals. Also make sure that the !tbaa.struct metadata mapping is preserved. --- llvm/test/Transforms/SROA/tbaa-struct.ll | 13 +++++++++++-- llvm/test/Transforms/SROA/tbaa-struct2.ll | 21 +++++++++++++++++---- llvm/test/Transforms/SROA/tbaa-subload.ll | 9 ++++++++- llvm/utils/UpdateTestChecks/common.py | 1 + 4 files changed, 37 insertions(+), 7 deletions(-) diff --git a/llvm/test/Transforms/SROA/tbaa-struct.ll b/llvm/test/Transforms/SROA/tbaa-struct.ll index 6d15a3a4adc991..29892cb84d8efe 100644 --- a/llvm/test/Transforms/SROA/tbaa-struct.ll +++ b/llvm/test/Transforms/SROA/tbaa-struct.ll @@ -1,4 +1,4 @@ -; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals ; RUN: opt -S -passes='sroa' %s | FileCheck %s --check-prefixes=CHECK,CHECK-PRESERVE-CFG ; RUN: opt -S -passes='sroa' %s | FileCheck %s --check-prefixes=CHECK,CHECK-MODIFY-CFG @@ -11,7 +11,7 @@ declare <2 x float> @foo(ptr %0) define void @bar(ptr %y2) { ; CHECK-LABEL: @bar( ; CHECK-NEXT: [[X14:%.*]] = call <2 x float> @foo(ptr [[Y2:%.*]]) -; CHECK-NEXT: store <2 x float> [[X14]], ptr [[Y2]], align 4, !tbaa.struct !0 +; CHECK-NEXT: store <2 x float> [[X14]], ptr [[Y2]], align 4, 
!tbaa.struct [[TBAA_STRUCT0:![0-9]+]] ; CHECK-NEXT: ret void ; %x7 = alloca %vector @@ -27,6 +27,15 @@ define void @bar(ptr %y2) { !8 = !{!"float", !4, i64 0} !10 = !{i64 0, i64 4, !11, i64 4, i64 4, !11} !11 = !{!8, !8, i64 0} +;. +; CHECK: attributes #[[ATTR0:[0-9]+]] = { nocallback nofree nounwind willreturn memory(argmem: readwrite) } +;. +; CHECK: [[TBAA_STRUCT0]] = !{i64 0, i64 4, !1, i64 4, i64 4, !1} +; CHECK: [[META1:![0-9]+]] = !{!2, !2, i64 0} +; CHECK: [[META2:![0-9]+]] = !{!"float", !3, i64 0} +; CHECK: [[META3:![0-9]+]] = !{!"omnipotent char", !4, i64 0} +; CHECK: [[META4:![0-9]+]] = !{!"Simple C++ TBAA"} +;. ;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line: ; CHECK-MODIFY-CFG: {{.*}} ; CHECK-PRESERVE-CFG: {{.*}} diff --git a/llvm/test/Transforms/SROA/tbaa-struct2.ll b/llvm/test/Transforms/SROA/tbaa-struct2.ll index f2e4cf86fd2739..1fd37e82d67775 100644 --- a/llvm/test/Transforms/SROA/tbaa-struct2.ll +++ b/llvm/test/Transforms/SROA/tbaa-struct2.ll @@ -1,4 +1,4 @@ -; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals ; RUN: opt -S -passes='sroa' %s | FileCheck %s --check-prefixes=CHECK,CHECK-PRESERVE-CFG ; RUN: opt -S -passes='sroa' %s | FileCheck %s --check-prefixes=CHECK,CHECK-MODIFY-CFG @@ -11,11 +11,11 @@ declare double @subcall(double %g, i32 %m) define double @bar(ptr %wishart) { ; CHECK-LABEL: @bar( ; CHECK-NEXT: [[TMP_SROA_3:%.*]] = alloca [4 x i8], align 4 -; CHECK-NEXT: [[TMP_SROA_0_0_COPYLOAD:%.*]] = load double, ptr [[WISHART:%.*]], align 8, !tbaa.struct !0 +; CHECK-NEXT: [[TMP_SROA_0_0_COPYLOAD:%.*]] = load double, ptr [[WISHART:%.*]], align 8, !tbaa.struct [[TBAA_STRUCT0:![0-9]+]] ; CHECK-NEXT: [[TMP_SROA_2_0_WISHART_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr [[WISHART]], i64 8 -; CHECK-NEXT: [[TMP_SROA_2_0_COPYLOAD:%.*]] = load i32, ptr 
[[TMP_SROA_2_0_WISHART_SROA_IDX]], align 8, !tbaa.struct !7 +; CHECK-NEXT: [[TMP_SROA_2_0_COPYLOAD:%.*]] = load i32, ptr [[TMP_SROA_2_0_WISHART_SROA_IDX]], align 8, !tbaa.struct [[TBAA_STRUCT7:![0-9]+]] ; CHECK-NEXT: [[TMP_SROA_3_0_WISHART_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr [[WISHART]], i64 12 -; CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr align 4 [[TMP_SROA_3]], ptr align 4 [[TMP_SROA_3_0_WISHART_SROA_IDX]], i64 4, i1 false), !tbaa.struct !8 +; CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr align 4 [[TMP_SROA_3]], ptr align 4 [[TMP_SROA_3_0_WISHART_SROA_IDX]], i64 4, i1 false), !tbaa.struct [[TBAA_STRUCT8:![0-9]+]] ; CHECK-NEXT: [[CALL:%.*]] = call double @subcall(double [[TMP_SROA_0_0_COPYLOAD]], i32 [[TMP_SROA_2_0_COPYLOAD]]) ; CHECK-NEXT: ret double [[CALL]] ; @@ -35,6 +35,19 @@ define double @bar(ptr %wishart) { !6 = !{!"Simple C++ TBAA"} !7 = !{!8, !8, i64 0} !8 = !{!"int", !5, i64 0} +;. +; CHECK: attributes #[[ATTR0:[0-9]+]] = { nocallback nofree nounwind willreturn memory(argmem: readwrite) } +;. +; CHECK: [[TBAA_STRUCT0]] = !{i64 0, i64 8, !1, i64 8, i64 4, !5} +; CHECK: [[META1:![0-9]+]] = !{!2, !2, i64 0} +; CHECK: [[META2:![0-9]+]] = !{!"double", !3, i64 0} +; CHECK: [[META3:![0-9]+]] = !{!"omnipotent char", !4, i64 0} +; CHECK: [[META4:![0-9]+]] = !{!"Simple C++ TBAA"} +; CHECK: [[META5:![0-9]+]] = !{!6, !6, i64 0} +; CHECK: [[META6:![0-9]+]] = !{!"int", !3, i64 0} +; CHECK: [[TBAA_STRUCT7]] = !{i64 0, i64 4, !5} +; CHECK: [[TBAA_STRUCT8]] = !{} +;. ;; NOTE: These prefixes are unused and the list is autogenerated. 
Do not add tests below this line: ; CHECK-MODIFY-CFG: {{.*}} ; CHECK-PRESERVE-CFG: {{.*}} diff --git a/llvm/test/Transforms/SROA/tbaa-subload.ll b/llvm/test/Transforms/SROA/tbaa-subload.ll index 3851203f072ae0..b07874da7ab03c 100644 --- a/llvm/test/Transforms/SROA/tbaa-subload.ll +++ b/llvm/test/Transforms/SROA/tbaa-subload.ll @@ -1,4 +1,4 @@ -; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals ; RUN: opt -S -passes='sroa' %s | FileCheck %s --check-prefixes=CHECK,CHECK-PRESERVE-CFG ; RUN: opt -S -passes='sroa' %s | FileCheck %s --check-prefixes=CHECK,CHECK-MODIFY-CFG @@ -35,6 +35,13 @@ declare void @use(ptr %this) !7 = !{!8, !3, i64 8} !8 = !{!"_ZTSZN2ax2baEMS_FvvE2an2arE3$_0", !9, i64 0, !3, i64 8} !9 = !{!"_ZTS2ar"} +;. +; CHECK: [[TBAA0]] = !{!1, !3, i64 8} +; CHECK: [[META1:![0-9]+]] = !{!"_ZTSZN2ax2baEMS_FvvE2an2arE3$_0", !2, i64 0, !3, i64 8} +; CHECK: [[META2:![0-9]+]] = !{!"_ZTS2ar"} +; CHECK: [[META3:![0-9]+]] = !{!"omnipotent char", !4, i64 0} +; CHECK: [[META4:![0-9]+]] = !{!"Simple C++ TBAA"} +;. ;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line: ; CHECK-MODIFY-CFG: {{.*}} ; CHECK-PRESERVE-CFG: {{.*}} diff --git a/llvm/utils/UpdateTestChecks/common.py b/llvm/utils/UpdateTestChecks/common.py index aec8a97115c598..0566d225e78495 100644 --- a/llvm/utils/UpdateTestChecks/common.py +++ b/llvm/utils/UpdateTestChecks/common.py @@ -744,6 +744,7 @@ def get_value_use(self, var, match, var_prefix=None): NamelessValue(r'DIASSIGNID' , '!' , r'!DIAssignID ' , None , None , r'![0-9]+' , None , False) , NamelessValue(r'PROF' , '!' , r'!prof ' , None , None , r'![0-9]+' , None , False) , NamelessValue(r'TBAA' , '!' , r'!tbaa ' , None , None , r'![0-9]+' , None , False) , + NamelessValue(r'TBAA_STRUCT', '!' , r'!tbaa.struct ' , None , None , r'![0-9]+' , None , False) , NamelessValue(r'RNG' , '!' 
, r'!range ' , None , None , r'![0-9]+' , None , False) , NamelessValue(r'LOOP' , '!' , r'!llvm.loop ' , None , None , r'![0-9]+' , None , False) , NamelessValue(r'META' , '!' , r'metadata ' , None , None , r'![0-9]+' , None , False) , From 22e8525dfdd7d78d05ad3873cf94476021e4a08f Mon Sep 17 00:00:00 2001 From: Mark de Wever Date: Thu, 5 May 2022 18:57:32 +0200 Subject: [PATCH 14/61] [libc++][format] Implements range_formatter Implements parts of - P2286R8 Formatting Ranges - P2585R0 Improving default container formatting Depends on D140651 Reviewed By: ldionne, #libc Differential Revision: https://reviews.llvm.org/D140653 --- libcxx/docs/Status/FormatPaper.csv | 2 +- libcxx/include/CMakeLists.txt | 1 + libcxx/include/__format/buffer.h | 69 + libcxx/include/__format/format_context.h | 78 +- libcxx/include/__format/formatter_output.h | 11 + .../include/__format/parser_std_format_spec.h | 1 + .../__format/range_default_formatter.h | 25 +- libcxx/include/__format/range_formatter.h | 245 ++++ libcxx/include/__memory/allocate_at_least.h | 1 + libcxx/include/format | 6 + libcxx/include/module.modulemap.in | 1 + libcxx/test/libcxx/private_headers.verify.cpp | 1 + .../test/libcxx/transitive_includes/cxx03.csv | 1 + .../test/libcxx/transitive_includes/cxx11.csv | 1 + .../test/libcxx/transitive_includes/cxx14.csv | 1 + .../test/libcxx/transitive_includes/cxx17.csv | 1 + .../test/libcxx/transitive_includes/cxx20.csv | 2 + .../test/libcxx/transitive_includes/cxx2b.csv | 2 + .../concept.formattable.compile.pass.cpp | 14 +- .../format.functions.format.pass.cpp | 58 + .../format.functions.tests.h | 1287 +++++++++++++++++ .../format.functions.vformat.pass.cpp | 71 + .../format.range.formatter/format.pass.cpp | 75 + .../format.range.formatter/parse.pass.cpp | 75 + .../set_brackets.pass.cpp | 73 + .../set_separator.pass.cpp | 72 + .../underlying.pass.cpp | 58 + .../format.functions.format.pass.cpp | 3 + .../format.tuple/format.functions.tests.h | 2 +- 
libcxx/test/support/format.functions.common.h | 54 + 30 files changed, 2280 insertions(+), 11 deletions(-) create mode 100644 libcxx/include/__format/range_formatter.h create mode 100644 libcxx/test/std/utilities/format/format.range/format.range.formatter/format.functions.format.pass.cpp create mode 100644 libcxx/test/std/utilities/format/format.range/format.range.formatter/format.functions.tests.h create mode 100644 libcxx/test/std/utilities/format/format.range/format.range.formatter/format.functions.vformat.pass.cpp create mode 100644 libcxx/test/std/utilities/format/format.range/format.range.formatter/format.pass.cpp create mode 100644 libcxx/test/std/utilities/format/format.range/format.range.formatter/parse.pass.cpp create mode 100644 libcxx/test/std/utilities/format/format.range/format.range.formatter/set_brackets.pass.cpp create mode 100644 libcxx/test/std/utilities/format/format.range/format.range.formatter/set_separator.pass.cpp create mode 100644 libcxx/test/std/utilities/format/format.range/format.range.formatter/underlying.pass.cpp diff --git a/libcxx/docs/Status/FormatPaper.csv b/libcxx/docs/Status/FormatPaper.csv index 4a5f1c4b1e2dd4..9fb7cd09dc4995 100644 --- a/libcxx/docs/Status/FormatPaper.csv +++ b/libcxx/docs/Status/FormatPaper.csv @@ -30,7 +30,7 @@ Section,Description,Dependencies,Assignee,Status,First released version `P2286R8 `__,"Formatting ranges" `[format.syn] `_,"Concept ``formattable``",,Mark de Wever,|Complete|, Clang 16 `[format.string.std] `_,"std-format-spec ``type`` debug",,Mark de Wever,|Complete|,Clang 16 -`[format.range] `_,"Formatting for ranges: sequences",,Mark de Wever,|In Progress|, +`[format.range] `_,"Formatting for ranges: sequences",,Mark de Wever,|Complete|,Clang 16 `[format.range] `_,"Formatting for ranges: associative",,Mark de Wever,, `[format.range] `_,"Formatting for ranges: container adaptors",,Mark de Wever,, `[format.range] `_,"Formatting for ranges: ``pair`` and ``tuple``",,Mark de Wever,|Complete|,Clang 16 diff 
--git a/libcxx/include/CMakeLists.txt b/libcxx/include/CMakeLists.txt index ea52f205d4980d..81d488ffc4b20e 100644 --- a/libcxx/include/CMakeLists.txt +++ b/libcxx/include/CMakeLists.txt @@ -327,6 +327,7 @@ set(files __format/formatter_tuple.h __format/parser_std_format_spec.h __format/range_default_formatter.h + __format/range_formatter.h __format/unicode.h __functional/binary_function.h __functional/binary_negate.h diff --git a/libcxx/include/__format/buffer.h b/libcxx/include/__format/buffer.h index 60c1f8093c61d4..ddfe76728e9b81 100644 --- a/libcxx/include/__format/buffer.h +++ b/libcxx/include/__format/buffer.h @@ -31,6 +31,7 @@ #include #include #include +#include #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) # pragma GCC system_header @@ -493,6 +494,74 @@ struct _LIBCPP_TEMPLATE_VIS __format_to_n_buffer final return {_VSTD::move(this->__writer_).__out_it(), this->__size_}; } }; + +// A dynamically growing buffer intended to be used for retargeting a context. +// +// P2286 Formatting ranges adds range formatting support. It allows the user to +// specify the minimum width for the entire formatted range. The width of the +// range is not known until the range is formatted. Formatting is done to an +// output_iterator so there's no guarantee it would be possible to add the fill +// to the front of the output. Instead the range is formatted to a temporary +// buffer and that buffer is formatted as a string. +// +// There is an issue with that approach, the format context used in +// std::formatter::format contains the output iterator used as part of its +// type. So using this output iterator means there needs to be a new format +// context and the format arguments need to be retargeted to the new context. +// This retargeting is done by a basic_format_context specialized for the +// __iterator of this container. 
+template <__fmt_char_type _CharT> +class _LIBCPP_TEMPLATE_VIS __retarget_buffer { +public: + using value_type = _CharT; + + struct __iterator { + using difference_type = ptrdiff_t; + + _LIBCPP_HIDE_FROM_ABI constexpr explicit __iterator(__retarget_buffer& __buffer) + : __buffer_(std::addressof(__buffer)) {} + _LIBCPP_HIDE_FROM_ABI constexpr __iterator& operator=(const _CharT& __c) { + __buffer_->push_back(__c); + return *this; + } + _LIBCPP_HIDE_FROM_ABI constexpr __iterator& operator=(_CharT&& __c) { + __buffer_->push_back(__c); + return *this; + } + + _LIBCPP_HIDE_FROM_ABI constexpr __iterator& operator*() { return *this; } + _LIBCPP_HIDE_FROM_ABI constexpr __iterator& operator++() { return *this; } + _LIBCPP_HIDE_FROM_ABI constexpr __iterator operator++(int) { return *this; } + __retarget_buffer* __buffer_; + }; + + _LIBCPP_HIDE_FROM_ABI explicit __retarget_buffer(size_t __size_hint) { __buffer_.reserve(__size_hint); } + + _LIBCPP_HIDE_FROM_ABI __iterator __make_output_iterator() { return __iterator{*this}; } + + _LIBCPP_HIDE_FROM_ABI void push_back(_CharT __c) { __buffer_.push_back(__c); } + + template <__fmt_char_type _InCharT> + _LIBCPP_HIDE_FROM_ABI void __copy(basic_string_view<_InCharT> __str) { + __buffer_.insert(__buffer_.end(), __str.begin(), __str.end()); + } + + template <__fmt_char_type _InCharT, class _UnaryOperation> + _LIBCPP_HIDE_FROM_ABI void __transform(const _InCharT* __first, const _InCharT* __last, _UnaryOperation __operation) { + _LIBCPP_ASSERT(__first <= __last, "not a valid range"); + std::transform(__first, __last, std::back_inserter(__buffer_), std::move(__operation)); + } + + _LIBCPP_HIDE_FROM_ABI void __fill(size_t __n, _CharT __value) { __buffer_.insert(__buffer_.end(), __n, __value); } + + _LIBCPP_HIDE_FROM_ABI basic_string_view<_CharT> __view() { return {__buffer_.data(), __buffer_.size()}; } + +private: + // Use vector instead of string to avoid adding zeros after every append + // operation. 
The buffer is exposed as a string_view and not as a c-string. + vector<_CharT> __buffer_; +}; + } // namespace __format #endif //_LIBCPP_STD_VER > 17 diff --git a/libcxx/include/__format/format_context.h b/libcxx/include/__format/format_context.h index 882a6049bb1eb5..85e00eb222cf4c 100644 --- a/libcxx/include/__format/format_context.h +++ b/libcxx/include/__format/format_context.h @@ -11,13 +11,19 @@ #define _LIBCPP___FORMAT_FORMAT_CONTEXT_H #include <__availability> +#include <__concepts/same_as.h> #include <__config> #include <__format/buffer.h> +#include <__format/format_arg.h> +#include <__format/format_arg_store.h> #include <__format/format_args.h> +#include <__format/format_error.h> #include <__format/format_fwd.h> #include <__iterator/back_insert_iterator.h> #include <__iterator/concepts.h> +#include <__memory/addressof.h> #include <__utility/move.h> +#include <__variant/monostate.h> #include #ifndef _LIBCPP_HAS_NO_LOCALIZATION @@ -138,8 +144,78 @@ class : __out_it_(_VSTD::move(__out_it)), __args_(__args) {} #endif }; -_LIBCPP_CTAD_SUPPORTED_FOR_TYPE(basic_format_context); +// A specialization for __retarget_buffer +// +// See __retarget_buffer for the motivation for this specialization. +// +// This context holds a reference to the instance of the basic_format_context +// that is retargeted. It converts a formatting argument when it is requested +// during formatting. It is expected that the usage of the arguments is rare so +// the lookups are not expected to be used often. An alternative would be to +// convert all elements during construction. +// +// The elements of the retargeted context are only used when an underlying +// formatter uses locale-specific formatting or a formatting argument is +// part of the format spec. For example +// format("{:256:{}}", input, 8); +// Here the width of an element in input is determined dynamically. +// Note when the top-level element has no width the retargeting is not needed.
+template +class _LIBCPP_TEMPLATE_VIS _LIBCPP_AVAILABILITY_FORMAT + basic_format_context::__iterator, _CharT> { +public: + using iterator = typename __format::__retarget_buffer<_CharT>::__iterator; + using char_type = _CharT; + template + using formatter_type = formatter<_Tp, _CharT>; + + template + _LIBCPP_HIDE_FROM_ABI explicit basic_format_context(iterator __out_it, _Context& __ctx) + : __out_it_(std::move(__out_it)), +# ifndef _LIBCPP_HAS_NO_LOCALIZATION + __loc_([](void* __c) { return static_cast<_Context*>(__c)->locale(); }), +# endif + __ctx_(std::addressof(__ctx)), + __arg_([](void* __c, size_t __id) { + return std::visit_format_arg( + [&](auto __arg) -> basic_format_arg { + if constexpr (same_as) + return {}; + else if constexpr (same_as::handle>) + // At the moment it's not possible for formatting to use a re-targeted handle. + // TODO FMT add this when support is needed. + std::__throw_format_error("Re-targeting handle not supported"); + else + return basic_format_arg{ + __format::__determine_arg_t(), + __basic_format_arg_value(__arg)}; + }, + static_cast<_Context*>(__c)->arg(__id)); + }) { + } + + _LIBCPP_HIDE_FROM_ABI basic_format_arg arg(size_t __id) const noexcept { + return __arg_(__ctx_, __id); + } +# ifndef _LIBCPP_HAS_NO_LOCALIZATION + _LIBCPP_HIDE_FROM_ABI _VSTD::locale locale() { return __loc_(__ctx_); } +# endif + _LIBCPP_HIDE_FROM_ABI iterator out() { return std::move(__out_it_); } + _LIBCPP_HIDE_FROM_ABI void advance_to(iterator __it) { __out_it_ = std::move(__it); } + +private: + iterator __out_it_; + +# ifndef _LIBCPP_HAS_NO_LOCALIZATION + std::locale (*__loc_)(void* __ctx); +# endif + + void* __ctx_; + basic_format_arg (*__arg_)(void* __ctx, size_t __id); +}; + +_LIBCPP_CTAD_SUPPORTED_FOR_TYPE(basic_format_context); #endif //_LIBCPP_STD_VER > 17 _LIBCPP_END_NAMESPACE_STD diff --git a/libcxx/include/__format/formatter_output.h b/libcxx/include/__format/formatter_output.h index 70eae1564cbccc..467692559ce918 100644 --- 
a/libcxx/include/__format/formatter_output.h +++ b/libcxx/include/__format/formatter_output.h @@ -102,6 +102,10 @@ _LIBCPP_HIDE_FROM_ABI auto __copy(basic_string_view<_CharT> __str, output_iterat if constexpr (_VSTD::same_as>>) { __out_it.__get_container()->__copy(__str); return __out_it; + } else if constexpr (_VSTD::same_as::__iterator>) { + __out_it.__buffer_->__copy(__str); + return __out_it; } else { return std::ranges::copy(__str, _VSTD::move(__out_it)).out; } @@ -132,6 +136,10 @@ __transform(const _CharT* __first, if constexpr (_VSTD::same_as>>) { __out_it.__get_container()->__transform(__first, __last, _VSTD::move(__operation)); return __out_it; + } else if constexpr (_VSTD::same_as::__iterator>) { + __out_it.__buffer_->__transform(__first, __last, _VSTD::move(__operation)); + return __out_it; } else { return std::ranges::transform(__first, __last, _VSTD::move(__out_it), __operation).out; } @@ -145,6 +153,9 @@ _LIBCPP_HIDE_FROM_ABI _OutIt __fill(_OutIt __out_it, size_t __n, _CharT __value) if constexpr (_VSTD::same_as>>) { __out_it.__get_container()->__fill(__n, __value); return __out_it; + } else if constexpr (_VSTD::same_as::__iterator>) { + __out_it.__buffer_->__fill(__n, __value); + return __out_it; } else { return std::ranges::fill_n(_VSTD::move(__out_it), __n, __value); } diff --git a/libcxx/include/__format/parser_std_format_spec.h b/libcxx/include/__format/parser_std_format_spec.h index 36f6505e7debf8..c03cec97963644 100644 --- a/libcxx/include/__format/parser_std_format_spec.h +++ b/libcxx/include/__format/parser_std_format_spec.h @@ -139,6 +139,7 @@ inline constexpr __fields __fields_pointer{.__type_ = true}; # if _LIBCPP_STD_VER > 20 inline constexpr __fields __fields_tuple{.__type_ = false, .__allow_colon_in_fill_ = true}; +inline constexpr __fields __fields_range{.__type_ = false, .__allow_colon_in_fill_ = true}; # endif enum class _LIBCPP_ENUM_VIS __alignment : uint8_t { diff --git a/libcxx/include/__format/range_default_formatter.h 
b/libcxx/include/__format/range_default_formatter.h index 56558f3ffb6c45..ee1dc52e37526e 100644 --- a/libcxx/include/__format/range_default_formatter.h +++ b/libcxx/include/__format/range_default_formatter.h @@ -19,9 +19,11 @@ #include <__config> #include <__format/concepts.h> #include <__format/formatter.h> +#include <__format/range_formatter.h> #include <__ranges/concepts.h> #include <__type_traits/remove_cvref.h> #include <__utility/pair.h> +#include #include _LIBCPP_BEGIN_NAMESPACE_STD @@ -104,7 +106,28 @@ struct _LIBCPP_TEMPLATE_VIS _LIBCPP_AVAILABILITY_FORMAT __range_default_formatte template struct _LIBCPP_TEMPLATE_VIS _LIBCPP_AVAILABILITY_FORMAT __range_default_formatter { - __range_default_formatter() = delete; // TODO FMT Implement +private: + using __maybe_const_r = __fmt_maybe_const<_Rp, _CharT>; + range_formatter>, _CharT> __underlying_; + +public: + _LIBCPP_HIDE_FROM_ABI constexpr void set_separator(basic_string_view<_CharT> __separator) { + __underlying_.set_separator(__separator); + } + _LIBCPP_HIDE_FROM_ABI constexpr void + set_brackets(basic_string_view<_CharT> __opening_bracket, basic_string_view<_CharT> __closing_bracket) { + __underlying_.set_brackets(__opening_bracket, __closing_bracket); + } + + template + _LIBCPP_HIDE_FROM_ABI constexpr typename _ParseContext::iterator parse(_ParseContext& __ctx) { + return __underlying_.parse(__ctx); + } + + template + _LIBCPP_HIDE_FROM_ABI typename FormatContext::iterator format(__maybe_const_r& __range, FormatContext& __ctx) const { + return __underlying_.format(__range, __ctx); + } }; template diff --git a/libcxx/include/__format/range_formatter.h b/libcxx/include/__format/range_formatter.h new file mode 100644 index 00000000000000..f2a4342e2ec665 --- /dev/null +++ b/libcxx/include/__format/range_formatter.h @@ -0,0 +1,245 @@ +// -*- C++ -*- +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef _LIBCPP___FORMAT_RANGE_FORMATTER_H +#define _LIBCPP___FORMAT_RANGE_FORMATTER_H + +#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) +# pragma GCC system_header +#endif + +#include <__algorithm/ranges_copy.h> +#include <__availability> +#include <__chrono/statically_widen.h> +#include <__concepts/same_as.h> +#include <__config> +#include <__format/buffer.h> +#include <__format/concepts.h> +#include <__format/format_args.h> +#include <__format/format_context.h> +#include <__format/format_error.h> +#include <__format/formatter.h> +#include <__format/formatter_output.h> +#include <__format/parser_std_format_spec.h> +#include <__iterator/back_insert_iterator.h> +#include <__ranges/concepts.h> +#include <__type_traits/remove_cvref.h> +#include + +_LIBCPP_BEGIN_NAMESPACE_STD + +#if _LIBCPP_STD_VER > 20 + +template + requires same_as, _Tp> && formattable<_Tp, _CharT> +struct _LIBCPP_TEMPLATE_VIS _LIBCPP_AVAILABILITY_FORMAT range_formatter { + _LIBCPP_HIDE_FROM_ABI constexpr void set_separator(basic_string_view<_CharT> __separator) { + __separator_ = __separator; + } + _LIBCPP_HIDE_FROM_ABI constexpr void + set_brackets(basic_string_view<_CharT> __opening_bracket, basic_string_view<_CharT> __closing_bracket) { + __opening_bracket_ = __opening_bracket; + __closing_bracket_ = __closing_bracket; + } + + _LIBCPP_HIDE_FROM_ABI constexpr formatter<_Tp, _CharT>& underlying() { return __underlying_; } + _LIBCPP_HIDE_FROM_ABI constexpr const formatter<_Tp, _CharT>& underlying() const { return __underlying_; } + + template + _LIBCPP_HIDE_FROM_ABI constexpr typename _ParseContext::iterator parse(_ParseContext& __parse_ctx) { + const _CharT* __begin = __parser_.__parse(__parse_ctx, __format_spec::__fields_range); + const _CharT* __end = __parse_ctx.end(); + 
if (__begin == __end) + return __begin; + + // The n field overrides a possible m type, therefore delay applying the + // effect of n until the type has been processed. + bool __clear_brackets = (*__begin == _CharT('n')); + if (__clear_brackets) { + ++__begin; + if (__begin == __end) { + // Since there is no more data, clear the brackets before returning. + set_brackets({}, {}); + return __begin; + } + } + + __parse_type(__begin, __end); + if (__clear_brackets) + set_brackets({}, {}); + if (__begin == __end) + return __begin; + + bool __has_range_underlying_spec = *__begin == _CharT(':'); + if (__parser_.__type_ != __format_spec::__type::__default) { + // [format.range.formatter]/6 + // If the range-type is s or ?s, then there shall be no n option and no + // range-underlying-spec. + if (__clear_brackets) { + if (__parser_.__type_ == __format_spec::__type::__string) + std::__throw_format_error("The n option and type s can't be used together"); + std::__throw_format_error("The n option and type ?s can't be used together"); + } + if (__has_range_underlying_spec) { + if (__parser_.__type_ == __format_spec::__type::__string) + std::__throw_format_error("Type s and an underlying format specification can't be used together"); + std::__throw_format_error("Type ?s and an underlying format specification can't be used together"); + } + } else if (!__has_range_underlying_spec) + std::__set_debug_format(__underlying_); + + if (__has_range_underlying_spec) { + // range-underlying-spec: + // : format-spec + ++__begin; + if (__begin == __end) + return __begin; + + __parse_ctx.advance_to(__begin); + __begin = __underlying_.parse(__parse_ctx); + } + + if (__begin != __end && *__begin != _CharT('}')) + std::__throw_format_error("The format-spec should consume the input or end with a '}'"); + + return __begin; + } + + template + requires formattable, _CharT> && + same_as>, _Tp> + _LIBCPP_HIDE_FROM_ABI typename _FormatContext::iterator format(_Rp&& __range, _FormatContext& __ctx) const
{ + __format_spec::__parsed_specifications<_CharT> __specs = __parser_.__get_parsed_std_specifications(__ctx); + + if (!__specs.__has_width()) + return __format_range(__range, __ctx, __specs); + + // The size of the buffer needed is: + // - open bracket characters + // - close bracket character + // - n elements where every element may have a different size + // - (n -1) separators + // The size of the element is hard to predict, knowing the type helps but + // it depends on the format-spec. As an initial estimate we guess 6 + // characters. + // Typically both brackets are 1 character and the separator is 2 + // characters. Which means there will be + // (n - 1) * 2 + 1 + 1 = n * 2 character + // So estimate 8 times the range size as buffer. + __format::__retarget_buffer<_CharT> __buffer{8 * ranges::size(__range)}; + basic_format_context::__iterator, _CharT> __c{ + __buffer.__make_output_iterator(), __ctx}; + + __format_range(__range, __c, __specs); + + return __formatter::__write_string_no_precision(__buffer.__view(), __ctx.out(), __specs); + } + + template + typename _FormatContext::iterator _LIBCPP_HIDE_FROM_ABI + __format_range(_Rp&& __range, _FormatContext& __ctx, __format_spec::__parsed_specifications<_CharT> __specs) const { + if constexpr (same_as<_Tp, _CharT>) { + switch (__specs.__std_.__type_) { + case __format_spec::__type::__string: + case __format_spec::__type::__debug: + return __format_as_string(__range, __ctx, __specs.__std_.__type_ == __format_spec::__type::__debug); + default: + return __format_as_sequence(__range, __ctx); + } + } else + return __format_as_sequence(__range, __ctx); + } + + template + _LIBCPP_HIDE_FROM_ABI typename _FormatContext::iterator + __format_as_string(_Rp&& __range, _FormatContext& __ctx, bool __debug_format) const { + // When the range is contiguous use a basic_string_view instead to avoid a + // copy of the underlying data. The basic_string_view formatter + // specialization is the "basic" string formatter in libc++. 
+ if constexpr (ranges::contiguous_range<_Rp>) { + std::formatter, _CharT> __formatter; + if (__debug_format) + __formatter.set_debug_format(); + return __formatter.format(basic_string_view<_CharT>{__range.data(), __range.size()}, __ctx); + } else { + std::formatter, _CharT> __formatter; + if (__debug_format) + __formatter.set_debug_format(); + // P2106's from_range has not been implemented yet. Instead use a simple + // copy operation. + // TODO FMT use basic_string's "from_range" constructor. + // return std::formatter, _CharT>{}.format(basic_string<_CharT>{from_range, __range}, __ctx); + basic_string<_CharT> __str; + ranges::copy(__range, back_insert_iterator{__str}); + return __formatter.format(__str, __ctx); + } + } + + template + _LIBCPP_HIDE_FROM_ABI typename _FormatContext::iterator + __format_as_sequence(_Rp&& __range, _FormatContext& __ctx) const { + __ctx.advance_to(ranges::copy(__opening_bracket_, __ctx.out()).out); + bool __use_separator = false; + for (auto&& __e : __range) { + if (__use_separator) + __ctx.advance_to(ranges::copy(__separator_, __ctx.out()).out); + else + __use_separator = true; + + __ctx.advance_to(__underlying_.format(__e, __ctx)); + } + + return ranges::copy(__closing_bracket_, __ctx.out()).out; + } + + __format_spec::__parser<_CharT> __parser_{.__alignment_ = __format_spec::__alignment::__left}; + +private: + _LIBCPP_HIDE_FROM_ABI constexpr void __parse_type(const _CharT*& __begin, const _CharT* __end) { + switch (*__begin) { + case _CharT('m'): + if constexpr (__fmt_pair_like<_Tp>) { + set_brackets(_LIBCPP_STATICALLY_WIDEN(_CharT, "{"), _LIBCPP_STATICALLY_WIDEN(_CharT, "}")); + set_separator(_LIBCPP_STATICALLY_WIDEN(_CharT, ", ")); + ++__begin; + } else + std::__throw_format_error("The range-format-spec type m requires two elements for a pair or tuple"); + break; + + case _CharT('s'): + if constexpr (same_as<_Tp, _CharT>) { + __parser_.__type_ = __format_spec::__type::__string; + ++__begin; + } else + 
std::__throw_format_error("The range-format-spec type s requires formatting a character type"); + break; + + case _CharT('?'): + ++__begin; + if (__begin == __end || *__begin != _CharT('s')) + std::__throw_format_error("The format-spec should consume the input or end with a '}'"); + if constexpr (same_as<_Tp, _CharT>) { + __parser_.__type_ = __format_spec::__type::__debug; + ++__begin; + } else + std::__throw_format_error("The range-format-spec type ?s requires formatting a character type"); + } + } + + formatter<_Tp, _CharT> __underlying_; + basic_string_view<_CharT> __separator_ = _LIBCPP_STATICALLY_WIDEN(_CharT, ", "); + basic_string_view<_CharT> __opening_bracket_ = _LIBCPP_STATICALLY_WIDEN(_CharT, "["); + basic_string_view<_CharT> __closing_bracket_ = _LIBCPP_STATICALLY_WIDEN(_CharT, "]"); +}; + +#endif //_LIBCPP_STD_VER > 20 + +_LIBCPP_END_NAMESPACE_STD + +#endif // _LIBCPP___FORMAT_RANGE_FORMATTER_H diff --git a/libcxx/include/__memory/allocate_at_least.h b/libcxx/include/__memory/allocate_at_least.h index 7ce588a25d1be3..ef205f855cb2ae 100644 --- a/libcxx/include/__memory/allocate_at_least.h +++ b/libcxx/include/__memory/allocate_at_least.h @@ -25,6 +25,7 @@ struct allocation_result { _Pointer ptr; size_t count; }; +_LIBCPP_CTAD_SUPPORTED_FOR_TYPE(allocation_result); template [[nodiscard]] _LIBCPP_HIDE_FROM_ABI constexpr diff --git a/libcxx/include/format b/libcxx/include/format index 900f27c9e03e5c..aaac95b39c4c9b 100644 --- a/libcxx/include/format +++ b/libcxx/include/format @@ -130,6 +130,11 @@ namespace std { requires same_as> constexpr range_format format_kind = see below; // since C++23 + // [format.range.formatter], class template range_formatter + template + requires same_as, T> && formattable + class range_formatter; // since C++23 + // [format.range.fmtdef], class template range-default-formatter template struct range-default-formatter; // exposition only, since C++23 @@ -194,6 +199,7 @@ namespace std { #include <__format/formatter_tuple.h> 
#include <__format/parser_std_format_spec.h> #include <__format/range_default_formatter.h> +#include <__format/range_formatter.h> #include <__format/unicode.h> #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) diff --git a/libcxx/include/module.modulemap.in b/libcxx/include/module.modulemap.in index 5d4cf53aa334ea..81f503e29c6a0c 100644 --- a/libcxx/include/module.modulemap.in +++ b/libcxx/include/module.modulemap.in @@ -871,6 +871,7 @@ module std [system] { module formatter_tuple { private header "__format/formatter_tuple.h" } module parser_std_format_spec { private header "__format/parser_std_format_spec.h" } module range_default_formatter { private header "__format/range_default_formatter.h" } + module range_formatter { private header "__format/range_formatter.h" } module unicode { private header "__format/unicode.h" } } } diff --git a/libcxx/test/libcxx/private_headers.verify.cpp b/libcxx/test/libcxx/private_headers.verify.cpp index f88ee37062d462..27b1c90ff5bfa5 100644 --- a/libcxx/test/libcxx/private_headers.verify.cpp +++ b/libcxx/test/libcxx/private_headers.verify.cpp @@ -359,6 +359,7 @@ END-SCRIPT #include <__format/formatter_tuple.h> // expected-error@*:* {{use of private header from outside its module: '__format/formatter_tuple.h'}} #include <__format/parser_std_format_spec.h> // expected-error@*:* {{use of private header from outside its module: '__format/parser_std_format_spec.h'}} #include <__format/range_default_formatter.h> // expected-error@*:* {{use of private header from outside its module: '__format/range_default_formatter.h'}} +#include <__format/range_formatter.h> // expected-error@*:* {{use of private header from outside its module: '__format/range_formatter.h'}} #include <__format/unicode.h> // expected-error@*:* {{use of private header from outside its module: '__format/unicode.h'}} #include <__functional/binary_function.h> // expected-error@*:* {{use of private header from outside its module: '__functional/binary_function.h'}} #include 
<__functional/binary_negate.h> // expected-error@*:* {{use of private header from outside its module: '__functional/binary_negate.h'}} diff --git a/libcxx/test/libcxx/transitive_includes/cxx03.csv b/libcxx/test/libcxx/transitive_includes/cxx03.csv index 08f449b9e24253..c245d11b5a4e06 100644 --- a/libcxx/test/libcxx/transitive_includes/cxx03.csv +++ b/libcxx/test/libcxx/transitive_includes/cxx03.csv @@ -359,6 +359,7 @@ format string format string_view format tuple format type_traits +format vector format version forward_list algorithm forward_list atomic diff --git a/libcxx/test/libcxx/transitive_includes/cxx11.csv b/libcxx/test/libcxx/transitive_includes/cxx11.csv index 6ecca62114a46c..74bf87284072bc 100644 --- a/libcxx/test/libcxx/transitive_includes/cxx11.csv +++ b/libcxx/test/libcxx/transitive_includes/cxx11.csv @@ -359,6 +359,7 @@ format string format string_view format tuple format type_traits +format vector format version forward_list algorithm forward_list atomic diff --git a/libcxx/test/libcxx/transitive_includes/cxx14.csv b/libcxx/test/libcxx/transitive_includes/cxx14.csv index fb50933e5233aa..7ee16527026eba 100644 --- a/libcxx/test/libcxx/transitive_includes/cxx14.csv +++ b/libcxx/test/libcxx/transitive_includes/cxx14.csv @@ -361,6 +361,7 @@ format string format string_view format tuple format type_traits +format vector format version forward_list algorithm forward_list atomic diff --git a/libcxx/test/libcxx/transitive_includes/cxx17.csv b/libcxx/test/libcxx/transitive_includes/cxx17.csv index fb50933e5233aa..7ee16527026eba 100644 --- a/libcxx/test/libcxx/transitive_includes/cxx17.csv +++ b/libcxx/test/libcxx/transitive_includes/cxx17.csv @@ -361,6 +361,7 @@ format string format string_view format tuple format type_traits +format vector format version forward_list algorithm forward_list atomic diff --git a/libcxx/test/libcxx/transitive_includes/cxx20.csv b/libcxx/test/libcxx/transitive_includes/cxx20.csv index 7e7661d30958f0..f48e46896b7ea3 100644 --- 
a/libcxx/test/libcxx/transitive_includes/cxx20.csv +++ b/libcxx/test/libcxx/transitive_includes/cxx20.csv @@ -125,6 +125,7 @@ chrono string chrono string_view chrono tuple chrono type_traits +chrono vector chrono version cinttypes cstdint cmath type_traits @@ -369,6 +370,7 @@ format string format string_view format tuple format type_traits +format vector format version forward_list algorithm forward_list atomic diff --git a/libcxx/test/libcxx/transitive_includes/cxx2b.csv b/libcxx/test/libcxx/transitive_includes/cxx2b.csv index 21b32bf401a290..7401de433dbcac 100644 --- a/libcxx/test/libcxx/transitive_includes/cxx2b.csv +++ b/libcxx/test/libcxx/transitive_includes/cxx2b.csv @@ -95,6 +95,7 @@ chrono string chrono string_view chrono tuple chrono type_traits +chrono vector chrono version cinttypes cstdint cmath version @@ -277,6 +278,7 @@ format string format string_view format tuple format type_traits +format vector format version forward_list compare forward_list cstddef diff --git a/libcxx/test/std/utilities/format/format.formattable/concept.formattable.compile.pass.cpp b/libcxx/test/std/utilities/format/format.formattable/concept.formattable.compile.pass.cpp index 4e9e526bf032e0..e17a6d0da2d47f 100644 --- a/libcxx/test/std/utilities/format/format.formattable/concept.formattable.compile.pass.cpp +++ b/libcxx/test/std/utilities/format/format.formattable/concept.formattable.compile.pass.cpp @@ -200,11 +200,11 @@ void test_P1636() { // TODO validate whether the test is correct after the paper has been accepted. 
template void test_P2286() { - assert_is_not_formattable, CharT>(); - assert_is_not_formattable, CharT>(); - assert_is_not_formattable, CharT>(); - assert_is_not_formattable, CharT>(); - assert_is_not_formattable, CharT>(); + assert_is_formattable, CharT>(); + assert_is_formattable, CharT>(); + assert_is_formattable, CharT>(); + assert_is_formattable, CharT>(); + assert_is_formattable, CharT>(); assert_is_not_formattable, CharT>(); assert_is_not_formattable, CharT>(); @@ -220,9 +220,9 @@ void test_P2286() { assert_is_not_formattable, CharT>(); assert_is_not_formattable, CharT>(); - assert_is_not_formattable, CharT>(); + assert_is_formattable, CharT>(); - assert_is_not_formattable, CharT>(); + assert_is_formattable, CharT>(); assert_is_formattable, CharT>(); assert_is_formattable, CharT>(); diff --git a/libcxx/test/std/utilities/format/format.range/format.range.formatter/format.functions.format.pass.cpp b/libcxx/test/std/utilities/format/format.range/format.range.formatter/format.functions.format.pass.cpp new file mode 100644 index 00000000000000..78e04fe367f481 --- /dev/null +++ b/libcxx/test/std/utilities/format/format.range/format.range.formatter/format.functions.format.pass.cpp @@ -0,0 +1,58 @@ +//===----------------------------------------------------------------------===// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// UNSUPPORTED: c++03, c++11, c++14, c++17, c++20 +// UNSUPPORTED: libcpp-has-no-incomplete-format + +// TODO FMT Fix this test using GCC, it currently times out. +// UNSUPPORTED: gcc-12 + +// This test requires the dylib support introduced in D92214. 
+// XFAIL: use_system_cxx_lib && target={{.+}}-apple-macosx10.{{.+}} +// XFAIL: use_system_cxx_lib && target={{.+}}-apple-macosx11.{{.+}} + +// + +// template +// requires same_as, T> && formattable +// class range_formatter; + +// template +// string format(format_string fmt, Args&&... args); +// template +// wstring format(wformat_string fmt, Args&&... args); + +#include +#include + +#include "format.functions.tests.h" +#include "test_format_string.h" +#include "test_macros.h" +#include "assert_macros.h" + +auto test = []( + std::basic_string_view expected, test_format_string fmt, Args&&... args) { + std::basic_string out = std::format(fmt, std::forward(args)...); + TEST_REQUIRE( + out == expected, + test_concat_message("\nFormat string ", fmt, "\nExpected output ", expected, "\nActual output ", out, '\n')); +}; + +auto test_exception = [](std::string_view, std::basic_string_view, Args&&...) { + // After P2216 most exceptions thrown by std::format become ill-formed. + // Therefore this test does nothing. +}; + +int main(int, char**) { + format_tests(test, test_exception); + +#ifndef TEST_HAS_NO_WIDE_CHARACTERS + format_tests(test, test_exception); +#endif + + return 0; +} diff --git a/libcxx/test/std/utilities/format/format.range/format.range.formatter/format.functions.tests.h b/libcxx/test/std/utilities/format/format.range/format.range.formatter/format.functions.tests.h new file mode 100644 index 00000000000000..b2627d596fe2f3 --- /dev/null +++ b/libcxx/test/std/utilities/format/format.range/format.range.formatter/format.functions.tests.h @@ -0,0 +1,1287 @@ +//===----------------------------------------------------------------------===// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TEST_STD_UTILITIES_FORMAT_FORMAT_RANGE_FORMAT_RANGE_FORMATTER_FORMAT_FUNCTIONS_TESTS_H +#define TEST_STD_UTILITIES_FORMAT_FORMAT_RANGE_FORMAT_RANGE_FORMATTER_FORMAT_FUNCTIONS_TESTS_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "format.functions.common.h" +#include "make_string.h" +#include "platform_support.h" // locale name macros +#include "test_iterators.h" +#include "test_macros.h" + +// +// Char +// + +template +void test_char_default(TestFunction check, ExceptionTest check_exception, auto&& input) { + // Note when no range-underlying-spec is present the char is escaped, + check(SV("['H', 'e', 'l', 'l', 'o']"), SV("{}"), input); + + // when one is present there is no escaping, + check(SV("[H, e, l, l, o]"), SV("{::}"), input); + // unless forced by the type specifier. + check(SV("['H', 'e', 'l', 'l', 'o']"), SV("{::?}"), input); + + // ***** underlying has no format-spec + + // *** align-fill & width *** + check(SV("['H', 'e', 'l', 'l', 'o'] "), SV("{:30}"), input); + check(SV("['H', 'e', 'l', 'l', 'o']*****"), SV("{:*<30}"), input); + check(SV("__['H', 'e', 'l', 'l', 'o']___"), SV("{:_^30}"), input); + check(SV("#####['H', 'e', 'l', 'l', 'o']"), SV("{:#>30}"), input); + + check(SV("['H', 'e', 'l', 'l', 'o'] "), SV("{:{}}"), input, 30); + check(SV("['H', 'e', 'l', 'l', 'o']*****"), SV("{:*<{}}"), input, 30); + check(SV("__['H', 'e', 'l', 'l', 'o']___"), SV("{:_^{}}"), input, 30); + check(SV("#####['H', 'e', 'l', 'l', 'o']"), SV("{:#>{}}"), input, 30); + + check_exception("The format-spec range-fill field contains an invalid character", SV("{:}<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{:{<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", 
SV("{::<}"), input); + + // *** sign *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:-}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{:+}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{: }"), input); + + // *** alternate form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:#}"), input); + + // *** zero-padding *** + check_exception("A format-spec width field shouldn't have a leading zero", SV("{:0}"), input); + + // *** precision *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:.}"), input); + + // *** locale-specific form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:L}"), input); + + // *** n + check(SV("__'H', 'e', 'l', 'l', 'o'___"), SV("{:_^28n}"), input); + + // *** type *** + check_exception("The range-format-spec type m requires two elements for a pair or tuple", SV("{:m}"), input); + for (std::basic_string_view fmt : fmt_invalid_types("s")) + check_exception("The format-spec should consume the input or end with a '}'", fmt, input); + + // ***** Only underlying has a format-spec + check(SV("[H , e , l , l , o ]"), SV("{::4}"), input); + check(SV("[H***, e***, l***, l***, o***]"), SV("{::*<4}"), input); + check(SV("[_H__, _e__, _l__, _l__, _o__]"), SV("{::_^4}"), input); + check(SV("[:::H, :::e, :::l, :::l, :::o]"), SV("{:::>4}"), input); + + check(SV("[H , e , l , l , o ]"), SV("{::{}}"), input, 4); + check(SV("[H***, e***, l***, l***, o***]"), SV("{::*<{}}"), input, 4); + check(SV("[_H__, _e__, _l__, _l__, _o__]"), SV("{::_^{}}"), input, 4); + check(SV("[:::H, :::e, :::l, :::l, :::o]"), SV("{:::>{}}"), input, 4); + + check_exception("The format-spec fill field contains an invalid character", SV("{::}<}"), input); + check_exception("The format-spec fill field contains an invalid character", 
SV("{::{<}"), input); + + // *** sign *** + check_exception("A sign field isn't allowed in this format-spec", SV("{::-}"), input); + check_exception("A sign field isn't allowed in this format-spec", SV("{::+}"), input); + check_exception("A sign field isn't allowed in this format-spec", SV("{:: }"), input); + + check(SV("[72, 101, 108, 108, 111]"), SV("{::-d}"), input); + check(SV("[+72, +101, +108, +108, +111]"), SV("{::+d}"), input); + check(SV("[ 72, 101, 108, 108, 111]"), SV("{:: d}"), input); + + // *** alternate form *** + check_exception("An alternate form field isn't allowed in this format-spec", SV("{::#}"), input); + + check(SV("[0x48, 0x65, 0x6c, 0x6c, 0x6f]"), SV("{::#x}"), input); + + // *** zero-padding *** + check_exception("A zero-padding field isn't allowed in this format-spec", SV("{::05}"), input); + + check(SV("[00110, 00145, 00154, 00154, 00157]"), SV("{::05o}"), input); + + // *** precision *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{::.}"), input); + + // *** locale-specific form *** + check(SV("[H, e, l, l, o]"), SV("{::L}"), input); + + // *** type *** + for (std::basic_string_view fmt : fmt_invalid_nested_types("bBcdoxX?")) + check_exception("The format-spec type has a type not supported for a char argument", fmt, input); + + // ***** Both have a format-spec + check(SV("^^[:H, :e, :l, :l, :o]^^^"), SV("{:^^25::>2}"), input); + check(SV("^^[:H, :e, :l, :l, :o]^^^"), SV("{:^^{}::>2}"), input, 25); + check(SV("^^[:H, :e, :l, :l, :o]^^^"), SV("{:^^{}::>{}}"), input, 25, 2); + + check_exception("Argument index out of bounds", SV("{:^^{}::>2}"), input); + check_exception("Argument index out of bounds", SV("{:^^{}::>{}}"), input, 25); +} + +template +void test_char_string(TestFunction check, ExceptionTest check_exception, auto&& input) { + check(SV("Hello"), SV("{:s}"), input); + + // ***** underlying has no format-spec + + // *** align-fill & width *** + check(SV("Hello "), SV("{:8s}"), input); + 
check(SV("Hello***"), SV("{:*<8s}"), input); + check(SV("_Hello__"), SV("{:_^8s}"), input); + check(SV("###Hello"), SV("{:#>8s}"), input); + + check(SV("Hello "), SV("{:{}s}"), input, 8); + check(SV("Hello***"), SV("{:*<{}s}"), input, 8); + check(SV("_Hello__"), SV("{:_^{}s}"), input, 8); + check(SV("###Hello"), SV("{:#>{}s}"), input, 8); + + check_exception("The format-spec range-fill field contains an invalid character", SV("{:} fmt : fmt_invalid_nested_types("bBcdoxX?")) + check_exception("The format-spec type has a type not supported for a char argument", fmt, input); + + // ***** Both have a format-spec + check_exception("Type s and an underlying format specification can't be used together", SV("{:5s:5}"), input); +} + +template +void test_char_escaped_string(TestFunction check, ExceptionTest check_exception, auto&& input) { + check(SV(R"("\"Hello'")"), SV("{:?s}"), input); + + // ***** underlying has no format-spec + + // *** align-fill & width *** + check(SV(R"("\"Hello'" )"), SV("{:13?s}"), input); + check(SV(R"("\"Hello'"***)"), SV("{:*<13?s}"), input); + check(SV(R"(_"\"Hello'"__)"), SV("{:_^13?s}"), input); + check(SV(R"(###"\"Hello'")"), SV("{:#>13?s}"), input); + + check(SV(R"("\"Hello'" )"), SV("{:{}?s}"), input, 13); + check(SV(R"("\"Hello'"***)"), SV("{:*<{}?s}"), input, 13); + check(SV(R"(_"\"Hello'"__)"), SV("{:_^{}?s}"), input, 13); + check(SV(R"(###"\"Hello'")"), SV("{:#>{}?s}"), input, 13); + + check_exception("The format-spec range-fill field contains an invalid character", SV("{:} +void test_char(TestFunction check, ExceptionTest check_exception) { + test_char_default( + check, check_exception, std::array{CharT('H'), CharT('e'), CharT('l'), CharT('l'), CharT('o')}); + + // This tests two different implementations in libc++. A basic_string_view + // formatter if the range is contiguous, a basic_string otherwise. 
+ test_char_escaped_string( + check, + check_exception, + std::array{CharT('"'), CharT('H'), CharT('e'), CharT('l'), CharT('l'), CharT('o'), CharT('\'')}); + test_char_escaped_string( + check, + check_exception, + std::list{CharT('"'), CharT('H'), CharT('e'), CharT('l'), CharT('l'), CharT('o'), CharT('\'')}); + + // This tests two different implementations in libc++. A basic_string_view + // formatter if the range is contiguous, a basic_string otherwise. + test_char_string( + check, check_exception, std::array{CharT('H'), CharT('e'), CharT('l'), CharT('l'), CharT('o')}); + test_char_string( + check, check_exception, std::list{CharT('H'), CharT('e'), CharT('l'), CharT('l'), CharT('o')}); +} + +// +// char -> wchar_t +// + +#ifndef _LIBCPP_HAS_NO_WIDE_CHARACTERS +template +void test_char_to_wchar(TestFunction check, ExceptionTest check_exception) { + test_char_default(check, check_exception, std::array{'H', 'e', 'l', 'l', 'o'}); + + // The types s and ?s may only be used when using range_formatter + // where the types T and charT are the same. This means this can't be used for + // range_formatter even when formatter has a + // debug-enabled specialization. 
+ + using CharT = wchar_t; + check_exception("The range-format-spec type s requires formatting a character type", + SV("{:s}"), + std::array{'H', 'e', 'l', 'l', 'o'}); + check_exception("The range-format-spec type ?s requires formatting a character type", + SV("{:?s}"), + std::array{'H', 'e', 'l', 'l', 'o'}); +} +#endif + +// +// Bool +// + +template +void test_bool(TestFunction check, ExceptionTest check_exception) { + std::array input{true, true, false}; + + check(SV("[true, true, false]"), SV("{}"), input); + + // ***** underlying has no format-spec + + // *** align-fill & width *** + check(SV("[true, true, false] "), SV("{:24}"), input); + check(SV("[true, true, false]*****"), SV("{:*<24}"), input); + check(SV("__[true, true, false]___"), SV("{:_^24}"), input); + check(SV("#####[true, true, false]"), SV("{:#>24}"), input); + + check(SV("[true, true, false] "), SV("{:{}}"), input, 24); + check(SV("[true, true, false]*****"), SV("{:*<{}}"), input, 24); + check(SV("__[true, true, false]___"), SV("{:_^{}}"), input, 24); + check(SV("#####[true, true, false]"), SV("{:#>{}}"), input, 24); + + check_exception("The format-spec range-fill field contains an invalid character", SV("{:}<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{:{<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{::<}"), input); + + // *** sign *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:-}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{:+}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{: }"), input); + + // *** alternate form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:#}"), input); + + // *** zero-padding *** + check_exception("A format-spec width field shouldn't have a leading zero", SV("{:0}"), input); + + 
// *** precision *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:.}"), input); + + // *** locale-specific form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:L}"), input); + + // *** n + check(SV("__true, true, false___"), SV("{:_^22n}"), input); + + // *** type *** + check_exception("The range-format-spec type m requires two elements for a pair or tuple", SV("{:m}"), input); + check_exception("The range-format-spec type s requires formatting a character type", SV("{:s}"), input); + check_exception("The range-format-spec type ?s requires formatting a character type", SV("{:?s}"), input); + for (std::basic_string_view fmt : fmt_invalid_types("s")) + check_exception("The format-spec should consume the input or end with a '}'", fmt, input); + + // ***** Only underlying has a format-spec + check(SV("[true , true , false ]"), SV("{::7}"), input); + check(SV("[true***, true***, false**]"), SV("{::*<7}"), input); + check(SV("[_true__, _true__, _false_]"), SV("{::_^7}"), input); + check(SV("[:::true, :::true, ::false]"), SV("{:::>7}"), input); + + check(SV("[true , true , false ]"), SV("{::{}}"), input, 7); + check(SV("[true***, true***, false**]"), SV("{::*<{}}"), input, 7); + check(SV("[_true__, _true__, _false_]"), SV("{::_^{}}"), input, 7); + check(SV("[:::true, :::true, ::false]"), SV("{:::>{}}"), input, 7); + + check_exception("The format-spec fill field contains an invalid character", SV("{::}<}"), input); + check_exception("The format-spec fill field contains an invalid character", SV("{::{<}"), input); + + // *** sign *** + check_exception("A sign field isn't allowed in this format-spec", SV("{::-}"), input); + check_exception("A sign field isn't allowed in this format-spec", SV("{::+}"), input); + check_exception("A sign field isn't allowed in this format-spec", SV("{:: }"), input); + + check(SV("[1, 1, 0]"), SV("{::-d}"), input); + check(SV("[+1, +1, +0]"), SV("{::+d}"), input); 
+ check(SV("[ 1, 1, 0]"), SV("{:: d}"), input); + + // *** alternate form *** + check_exception("An alternate form field isn't allowed in this format-spec", SV("{::#}"), input); + + check(SV("[0x1, 0x1, 0x0]"), SV("{::#x}"), input); + + // *** zero-padding *** + check_exception("A zero-padding field isn't allowed in this format-spec", SV("{::05}"), input); + + check(SV("[00001, 00001, 00000]"), SV("{::05o}"), input); + + // *** precision *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{::.}"), input); + + // *** locale-specific form *** + check(SV("[true, true, false]"), SV("{::L}"), input); + + // *** type *** + for (std::basic_string_view fmt : fmt_invalid_nested_types("bBdosxX")) + check_exception("The format-spec type has a type not supported for a bool argument", fmt, input); + + // ***** Both have a format-spec + check(SV("^^[:::true, :::true, ::false]^^^"), SV("{:^^32::>7}"), input); + check(SV("^^[:::true, :::true, ::false]^^^"), SV("{:^^{}::>7}"), input, 32); + check(SV("^^[:::true, :::true, ::false]^^^"), SV("{:^^{}::>{}}"), input, 32, 7); + + check_exception("Argument index out of bounds", SV("{:^^{}::>5}"), input); + check_exception("Argument index out of bounds", SV("{:^^{}::>{}}"), input, 32); +} + +// +// Integral +// + +template +void test_int(TestFunction check, ExceptionTest check_exception, auto&& input) { + check(SV("[1, 2, 42, -42]"), SV("{}"), input); + + // ***** underlying has no format-spec + + // *** align-fill & width *** + check(SV("[1, 2, 42, -42] "), SV("{:20}"), input); + check(SV("[1, 2, 42, -42]*****"), SV("{:*<20}"), input); + check(SV("__[1, 2, 42, -42]___"), SV("{:_^20}"), input); + check(SV("#####[1, 2, 42, -42]"), SV("{:#>20}"), input); + + check(SV("[1, 2, 42, -42] "), SV("{:{}}"), input, 20); + check(SV("[1, 2, 42, -42]*****"), SV("{:*<{}}"), input, 20); + check(SV("__[1, 2, 42, -42]___"), SV("{:_^{}}"), input, 20); + check(SV("#####[1, 2, 42, -42]"), SV("{:#>{}}"), input, 20); + + 
check_exception("The format-spec range-fill field contains an invalid character", SV("{:}<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{:{<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{::<}"), input); + + // *** sign *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:-}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{:+}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{: }"), input); + + // *** alternate form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:#}"), input); + + // *** zero-padding *** + check_exception("A format-spec width field shouldn't have a leading zero", SV("{:0}"), input); + + // *** precision *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:.}"), input); + + // *** locale-specific form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:L}"), input); + + // *** n + check(SV("__1, 2, 42, -42___"), SV("{:_^18n}"), input); + + // *** type *** + check_exception("The range-format-spec type m requires two elements for a pair or tuple", SV("{:m}"), input); + check_exception("The range-format-spec type s requires formatting a character type", SV("{:s}"), input); + check_exception("The range-format-spec type ?s requires formatting a character type", SV("{:?s}"), input); + for (std::basic_string_view fmt : fmt_invalid_types("s")) + check_exception("The format-spec should consume the input or end with a '}'", fmt, input); + + // ***** Only underlying has a format-spec + check(SV("[ 1, 2, 42, -42]"), SV("{::5}"), input); + check(SV("[1****, 2****, 42***, -42**]"), SV("{::*<5}"), input); + check(SV("[__1__, __2__, _42__, _-42_]"), SV("{::_^5}"), input); + check(SV("[::::1, ::::2, :::42, ::-42]"), 
SV("{:::>5}"), input); + + check(SV("[ 1, 2, 42, -42]"), SV("{::{}}"), input, 5); + check(SV("[1****, 2****, 42***, -42**]"), SV("{::*<{}}"), input, 5); + check(SV("[__1__, __2__, _42__, _-42_]"), SV("{::_^{}}"), input, 5); + check(SV("[::::1, ::::2, :::42, ::-42]"), SV("{:::>{}}"), input, 5); + + check_exception("The format-spec fill field contains an invalid character", SV("{::}<}"), input); + check_exception("The format-spec fill field contains an invalid character", SV("{::{<}"), input); + + // *** sign *** + check(SV("[1, 2, 42, -42]"), SV("{::-}"), input); + check(SV("[+1, +2, +42, -42]"), SV("{::+}"), input); + check(SV("[ 1, 2, 42, -42]"), SV("{:: }"), input); + + // *** alternate form *** + check(SV("[0x1, 0x2, 0x2a, -0x2a]"), SV("{::#x}"), input); + + // *** zero-padding *** + check(SV("[00001, 00002, 00042, -0042]"), SV("{::05}"), input); + check(SV("[00001, 00002, 0002a, -002a]"), SV("{::05x}"), input); + check(SV("[0x001, 0x002, 0x02a, -0x2a]"), SV("{::#05x}"), input); + + // *** precision *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{::.}"), input); + + // *** locale-specific form *** + check(SV("[1, 2, 42, -42]"), SV("{::L}"), input); // does nothing in this test, but is accepted. 
+ + // *** type *** + for (std::basic_string_view fmt : fmt_invalid_nested_types("bBcdoxX")) + check_exception("The format-spec type has a type not supported for an integer argument", fmt, input); + + // ***** Both have a format-spec + check(SV("^^[::::1, ::::2, :::42, ::-42]^^^"), SV("{:^^33::>5}"), input); + check(SV("^^[::::1, ::::2, :::42, ::-42]^^^"), SV("{:^^{}::>5}"), input, 33); + check(SV("^^[::::1, ::::2, :::42, ::-42]^^^"), SV("{:^^{}::>{}}"), input, 33, 5); + + check_exception("Argument index out of bounds", SV("{:^^{}::>5}"), input); + check_exception("Argument index out of bounds", SV("{:^^{}::>{}}"), input, 33); +} + +template +void test_int(TestFunction check, ExceptionTest check_exception) { + test_int(check, check_exception, std::array{1, 2, 42, -42}); + test_int(check, check_exception, std::list{1, 2, 42, -42}); + test_int(check, check_exception, std::vector{1, 2, 42, -42}); + std::array input{1, 2, 42, -42}; + test_int(check, check_exception, std::span{input}); +} + +// +// Floating point +// + +template +void test_floating_point(TestFunction check, ExceptionTest check_exception, auto&& input) { + check(SV("[-42.5, 0, 1.25, 42.5]"), SV("{}"), input); + + // ***** underlying has no format-spec + + // *** align-fill & width *** + check(SV("[-42.5, 0, 1.25, 42.5] "), SV("{:27}"), input); + check(SV("[-42.5, 0, 1.25, 42.5]*****"), SV("{:*<27}"), input); + check(SV("__[-42.5, 0, 1.25, 42.5]___"), SV("{:_^27}"), input); + check(SV("#####[-42.5, 0, 1.25, 42.5]"), SV("{:#>27}"), input); + + check(SV("[-42.5, 0, 1.25, 42.5] "), SV("{:{}}"), input, 27); + check(SV("[-42.5, 0, 1.25, 42.5]*****"), SV("{:*<{}}"), input, 27); + check(SV("__[-42.5, 0, 1.25, 42.5]___"), SV("{:_^{}}"), input, 27); + check(SV("#####[-42.5, 0, 1.25, 42.5]"), SV("{:#>{}}"), input, 27); + + check_exception("The format-spec range-fill field contains an invalid character", SV("{:}<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", 
SV("{:{<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{::<}"), input); + + // *** sign *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:-}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{:+}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{: }"), input); + + // *** alternate form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:#}"), input); + + // *** zero-padding *** + check_exception("A format-spec width field shouldn't have a leading zero", SV("{:0}"), input); + + // *** precision *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:.}"), input); + + // *** locale-specific form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:L}"), input); + + // *** n + check(SV("__-42.5, 0, 1.25, 42.5___"), SV("{:_^25n}"), input); + + // *** type *** + check_exception("The range-format-spec type m requires two elements for a pair or tuple", SV("{:m}"), input); + check_exception("The range-format-spec type s requires formatting a character type", SV("{:s}"), input); + check_exception("The range-format-spec type ?s requires formatting a character type", SV("{:?s}"), input); + for (std::basic_string_view fmt : fmt_invalid_types("s")) + check_exception("The format-spec should consume the input or end with a '}'", fmt, input); + + // ***** Only underlying has a format-spec + check(SV("[-42.5, 0, 1.25, 42.5]"), SV("{::5}"), input); + check(SV("[-42.5, 0****, 1.25*, 42.5*]"), SV("{::*<5}"), input); + check(SV("[-42.5, __0__, 1.25_, 42.5_]"), SV("{::_^5}"), input); + check(SV("[-42.5, ::::0, :1.25, :42.5]"), SV("{:::>5}"), input); + + check(SV("[-42.5, 0, 1.25, 42.5]"), SV("{::{}}"), input, 5); + check(SV("[-42.5, 0****, 1.25*, 42.5*]"), SV("{::*<{}}"), input, 5); + 
check(SV("[-42.5, __0__, 1.25_, 42.5_]"), SV("{::_^{}}"), input, 5); + check(SV("[-42.5, ::::0, :1.25, :42.5]"), SV("{:::>{}}"), input, 5); + + check_exception("The format-spec fill field contains an invalid character", SV("{::}<}"), input); + check_exception("The format-spec fill field contains an invalid character", SV("{::{<}"), input); + + // *** sign *** + check(SV("[-42.5, 0, 1.25, 42.5]"), SV("{::-}"), input); + check(SV("[-42.5, +0, +1.25, +42.5]"), SV("{::+}"), input); + check(SV("[-42.5, 0, 1.25, 42.5]"), SV("{:: }"), input); + + // *** alternate form *** + check(SV("[-42.5, 0., 1.25, 42.5]"), SV("{::#}"), input); + + // *** zero-padding *** + check(SV("[-42.5, 00000, 01.25, 042.5]"), SV("{::05}"), input); + check(SV("[-42.5, 0000., 01.25, 042.5]"), SV("{::#05}"), input); + + // *** precision *** + check(SV("[-42, 0, 1.2, 42]"), SV("{::.2}"), input); + check(SV("[-42.500, 0.000, 1.250, 42.500]"), SV("{::.3f}"), input); + + check(SV("[-42, 0, 1.2, 42]"), SV("{::.{}}"), input, 2); + check(SV("[-42.500, 0.000, 1.250, 42.500]"), SV("{::.{}f}"), input, 3); + + check_exception("The format-spec precision field doesn't contain a value or arg-id", SV("{::.}"), input); + + // *** locale-specific form *** + check(SV("[-42.5, 0, 1.25, 42.5]"), SV("{::L}"), input); // does not require locales present +#ifndef TEST_HAS_NO_LOCALIZATION + std::locale::global(std::locale(LOCALE_fr_FR_UTF_8)); + check(SV("[-42,5, 0, 1,25, 42,5]"), SV("{::L}"), input); + + std::locale::global(std::locale(LOCALE_en_US_UTF_8)); + check(SV("[-42.5, 0, 1.25, 42.5]"), SV("{::L}"), input); + + std::locale::global(std::locale::classic()); +#endif // TEST_HAS_NO_LOCALIZATION + + // *** type *** + for (std::basic_string_view fmt : fmt_invalid_nested_types("aAeEfFgG")) + check_exception("The format-spec type has a type not supported for a floating-point argument", fmt, input); + + // ***** Both have a format-spec + check(SV("^^[-42.5, ::::0, :1.25, :42.5]^^^"), SV("{:^^33::>5}"), input); + 
check(SV("^^[-42.5, ::::0, :1.25, :42.5]^^^"), SV("{:^^{}::>5}"), input, 33); + check(SV("^^[-42.5, ::::0, :1.25, :42.5]^^^"), SV("{:^^{}::>{}}"), input, 33, 5); + + check(SV("^^[::-42, ::::0, ::1.2, :::42]^^^"), SV("{:^^33::>5.2}"), input); + check(SV("^^[::-42, ::::0, ::1.2, :::42]^^^"), SV("{:^^{}::>5.2}"), input, 33); + check(SV("^^[::-42, ::::0, ::1.2, :::42]^^^"), SV("{:^^{}::>{}.2}"), input, 33, 5); + check(SV("^^[::-42, ::::0, ::1.2, :::42]^^^"), SV("{:^^{}::>{}.{}}"), input, 33, 5, 2); + + check_exception("Argument index out of bounds", SV("{:^^{}::>5.2}"), input); + check_exception("Argument index out of bounds", SV("{:^^{}::>{}.2}"), input, 33); + check_exception("Argument index out of bounds", SV("{:^^{}::>{}.{}}"), input, 33, 5); +} + +template +void test_floating_point(TestFunction check, ExceptionTest check_exception) { + test_floating_point(check, check_exception, std::array{-42.5f, 0.0f, 1.25f, 42.5f}); + test_floating_point(check, check_exception, std::vector{-42.5, 0.0, 1.25, 42.5}); + + std::array input{-42.5l, 0.0l, 1.25l, 42.5l}; + test_floating_point(check, check_exception, std::span{input}); +} + +// +// Pointer +// + +template +void test_pointer(TestFunction check, ExceptionTest check_exception, auto&& input) { + check(SV("[0x0]"), SV("{}"), input); + + // ***** underlying has no format-spec + + // *** align-fill & width *** + check(SV("[0x0] "), SV("{:10}"), input); + check(SV("[0x0]*****"), SV("{:*<10}"), input); + check(SV("__[0x0]___"), SV("{:_^10}"), input); + check(SV("#####[0x0]"), SV("{:#>10}"), input); + + check(SV("[0x0] "), SV("{:{}}"), input, 10); + check(SV("[0x0]*****"), SV("{:*<{}}"), input, 10); + check(SV("__[0x0]___"), SV("{:_^{}}"), input, 10); + check(SV("#####[0x0]"), SV("{:#>{}}"), input, 10); + + check_exception("The format-spec range-fill field contains an invalid character", SV("{:}<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{:{<}"), input); + 
check_exception("The format-spec range-fill field contains an invalid character", SV("{::<}"), input); + + // *** sign *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:#}"), input); + + // *** alternate form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:#}"), input); + + // *** zero-padding *** + check_exception("A format-spec width field shouldn't have a leading zero", SV("{:0}"), input); + + // *** precision *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:.}"), input); + + // *** locale-specific form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:L}"), input); + + // *** n + check(SV("_0x0_"), SV("{:_^5n}"), input); + + // *** type *** + check_exception("The range-format-spec type m requires two elements for a pair or tuple", SV("{:m}"), input); + check_exception("The range-format-spec type s requires formatting a character type", SV("{:s}"), input); + check_exception("The range-format-spec type ?s requires formatting a character type", SV("{:?s}"), input); + for (std::basic_string_view fmt : fmt_invalid_types("s")) + check_exception("The format-spec should consume the input or end with a '}'", fmt, input); + + // ***** Only underlying has a format-spec + check(SV("[ 0x0]"), SV("{::5}"), input); + check(SV("[0x0**]"), SV("{::*<5}"), input); + check(SV("[_0x0_]"), SV("{::_^5}"), input); + check(SV("[::0x0]"), SV("{:::>5}"), input); + + check(SV("[ 0x0]"), SV("{::{}}"), input, 5); + check(SV("[0x0**]"), SV("{::*<{}}"), input, 5); + check(SV("[_0x0_]"), SV("{::_^{}}"), input, 5); + check(SV("[::0x0]"), SV("{:::>{}}"), input, 5); + + check_exception("The format-spec fill field contains an invalid character", SV("{::}<}"), input); + check_exception("The format-spec fill field contains an invalid character", SV("{::{<}"), input); + + // *** sign *** + check_exception("The format-spec should consume the 
input or end with a '}'", SV("{::-}"), input); + + // *** alternate form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{::#}"), input); + + // *** zero-padding *** + check_exception("A format-spec width field shouldn't have a leading zero", SV("{::05}"), input); + + // *** precision *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{::.}"), input); + + // *** locale-specific form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{::L}"), input); + + // *** type *** + for (std::basic_string_view fmt : fmt_invalid_nested_types("p")) + check_exception("The format-spec type has a type not supported for a pointer argument", fmt, input); + + // ***** Both have a format-spec + check(SV("^^[::0x0]^^^"), SV("{:^^12::>5}"), input); + check(SV("^^[::0x0]^^^"), SV("{:^^{}::>5}"), input, 12); + check(SV("^^[::0x0]^^^"), SV("{:^^{}::>{}}"), input, 12, 5); + + check(SV("^^[::0x0]^^^"), SV("{:^^12::>5}"), input); + check(SV("^^[::0x0]^^^"), SV("{:^^{}::>5}"), input, 12); + check(SV("^^[::0x0]^^^"), SV("{:^^{}::>{}}"), input, 12, 5); + + check_exception("Argument index out of bounds", SV("{:^^{}::>5}"), input); + check_exception("Argument index out of bounds", SV("{:^^{}::>{}}"), input, 12); +} + +template +void test_pointer(TestFunction check, ExceptionTest check_exception) { + test_pointer(check, check_exception, std::array{nullptr}); + test_pointer(check, check_exception, std::array{static_cast(0)}); + test_pointer(check, check_exception, std::array{static_cast(0)}); +} + +// +// String +// + +template +void test_string(TestFunction check, ExceptionTest check_exception, auto&& input) { + check(SV(R"(["Hello", "world"])"), SV("{}"), input); + + // ***** underlying has no format-spec + + // *** align-fill & width *** + check(SV(R"(["Hello", "world"] )"), SV("{:23}"), input); + check(SV(R"(["Hello", "world"]*****)"), SV("{:*<23}"), input); + check(SV(R"(__["Hello", 
"world"]___)"), SV("{:_^23}"), input); + check(SV(R"(#####["Hello", "world"])"), SV("{:#>23}"), input); + + check(SV(R"(["Hello", "world"] )"), SV("{:{}}"), input, 23); + check(SV(R"(["Hello", "world"]*****)"), SV("{:*<{}}"), input, 23); + check(SV(R"(__["Hello", "world"]___)"), SV("{:_^{}}"), input, 23); + check(SV(R"(#####["Hello", "world"])"), SV("{:#>{}}"), input, 23); + + check_exception("The format-spec range-fill field contains an invalid character", SV("{:}<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{:{<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{::<}"), input); + + // *** sign *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:#}"), input); + + // *** alternate form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:#}"), input); + + // *** zero-padding *** + check_exception("A format-spec width field shouldn't have a leading zero", SV("{:0}"), input); + + // *** precision *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:.}"), input); + + // *** locale-specific form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:L}"), input); + + // *** n + check(SV(R"(_"Hello", "world"_)"), SV("{:_^18n}"), input); + + // *** type *** + check_exception("The range-format-spec type m requires two elements for a pair or tuple", SV("{:m}"), input); + check_exception("The range-format-spec type s requires formatting a character type", SV("{:s}"), input); + check_exception("The range-format-spec type ?s requires formatting a character type", SV("{:?s}"), input); + for (std::basic_string_view fmt : fmt_invalid_types("s")) + check_exception("The format-spec should consume the input or end with a '}'", fmt, input); + + // ***** Only underlying has a format-spec + check(SV(R"([Hello , world ])"), SV("{::8}"), 
input); + check(SV(R"([Hello***, world***])"), SV("{::*<8}"), input); + check(SV(R"([_Hello__, _world__])"), SV("{::_^8}"), input); + check(SV(R"([:::Hello, :::world])"), SV("{:::>8}"), input); + + check(SV(R"([Hello , world ])"), SV("{::{}}"), input, 8); + check(SV(R"([Hello***, world***])"), SV("{::*<{}}"), input, 8); + check(SV(R"([_Hello__, _world__])"), SV("{::_^{}}"), input, 8); + check(SV(R"([:::Hello, :::world])"), SV("{:::>{}}"), input, 8); + + check_exception("The format-spec fill field contains an invalid character", SV("{::}<}"), input); + check_exception("The format-spec fill field contains an invalid character", SV("{::{<}"), input); + + // *** sign *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{::-}"), input); + + // *** alternate form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{::#}"), input); + + // *** zero-padding *** + check_exception("A format-spec width field shouldn't have a leading zero", SV("{::05}"), input); + + // *** precision *** + check(SV(R"([Hel, wor])"), SV("{::.3}"), input); + + check(SV(R"([Hel, wor])"), SV("{::.{}}"), input, 3); + + check_exception("The format-spec precision field doesn't contain a value or arg-id", SV("{::.}"), input); + + // *** locale-specific form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{::L}"), input); + + // *** type *** + for (std::basic_string_view fmt : fmt_invalid_nested_types("s?")) + check_exception("The format-spec type has a type not supported for a string argument", fmt, input); + + // ***** Both have a format-spec + check(SV(R"(^^[:::Hello, :::world]^^^)"), SV("{:^^25::>8}"), input); + check(SV(R"(^^[:::Hello, :::world]^^^)"), SV("{:^^{}::>8}"), input, 25); + check(SV(R"(^^[:::Hello, :::world]^^^)"), SV("{:^^{}::>{}}"), input, 25, 8); + + check(SV(R"(^^[:::Hello, :::world]^^^)"), SV("{:^^25::>8}"), input); + check(SV(R"(^^[:::Hello, :::world]^^^)"), 
SV("{:^^{}::>8}"), input, 25); + check(SV(R"(^^[:::Hello, :::world]^^^)"), SV("{:^^{}::>{}}"), input, 25, 8); + + check_exception("Argument index out of bounds", SV("{:^^{}::>8}"), input); + check_exception("Argument index out of bounds", SV("{:^^{}::>{}}"), input, 25); +} + +template +void test_string(TestFunction check, ExceptionTest check_exception) { + test_string(check, check_exception, std::array{CSTR("Hello"), CSTR("world")}); + test_string(check, check_exception, std::array{STR("Hello"), STR("world")}); + test_string(check, check_exception, std::array{SV("Hello"), SV("world")}); +} + +// +// Handle +// + +template +void test_status(TestFunction check, ExceptionTest check_exception) { + std::array input{status::foo, status::bar, status::foobar}; + + check(SV("[0xaaaa, 0x5555, 0xaa55]"), SV("{}"), input); + + // ***** underlying has no format-spec + + // *** align-fill & width *** + check(SV("[0xaaaa, 0x5555, 0xaa55] "), SV("{:29}"), input); + check(SV("[0xaaaa, 0x5555, 0xaa55]*****"), SV("{:*<29}"), input); + check(SV("__[0xaaaa, 0x5555, 0xaa55]___"), SV("{:_^29}"), input); + check(SV("#####[0xaaaa, 0x5555, 0xaa55]"), SV("{:#>29}"), input); + + check(SV("[0xaaaa, 0x5555, 0xaa55] "), SV("{:{}}"), input, 29); + check(SV("[0xaaaa, 0x5555, 0xaa55]*****"), SV("{:*<{}}"), input, 29); + check(SV("__[0xaaaa, 0x5555, 0xaa55]___"), SV("{:_^{}}"), input, 29); + check(SV("#####[0xaaaa, 0x5555, 0xaa55]"), SV("{:#>{}}"), input, 29); + + check_exception("The format-spec range-fill field contains an invalid character", SV("{:}<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{:{<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{::<}"), input); + + // *** sign *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:-}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{:+}"), input); + 
check_exception("The format-spec should consume the input or end with a '}'", SV("{: }"), input); + + // *** alternate form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:#}"), input); + + // *** zero-padding *** + check_exception("A format-spec width field shouldn't have a leading zero", SV("{:0}"), input); + + // *** precision *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:.}"), input); + + // *** locale-specific form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:L}"), input); + + // *** n + check(SV("__0xaaaa, 0x5555, 0xaa55___"), SV("{:_^27n}"), input); + + // *** type *** + check_exception("The range-format-spec type m requires two elements for a pair or tuple", SV("{:m}"), input); + check_exception("The range-format-spec type s requires formatting a character type", SV("{:s}"), input); + check_exception("The range-format-spec type ?s requires formatting a character type", SV("{:?s}"), input); + for (std::basic_string_view fmt : fmt_invalid_types("s")) + check_exception("The format-spec should consume the input or end with a '}'", fmt, input); + + // ***** Only underlying has a format-spec + check_exception("The format-spec type has a type not supported for a status argument", SV("{::*<7}"), input); + for (std::basic_string_view fmt : fmt_invalid_nested_types("sxX")) + check_exception("The format-spec type has a type not supported for a status argument", fmt, input); + + check(SV("[0xaaaa, 0x5555, 0xaa55]"), SV("{::x}"), input); + check(SV("[0XAAAA, 0X5555, 0XAA55]"), SV("{::X}"), input); + check(SV("[foo, bar, foobar]"), SV("{::s}"), input); + + // ***** Both have a format-spec + check(SV("^^[0XAAAA, 0X5555, 0XAA55]^^^"), SV("{:^^29:X}"), input); + check(SV("^^[0XAAAA, 0X5555, 0XAA55]^^^"), SV("{:^^{}:X}"), input, 29); + + check_exception("Argument index out of bounds", SV("{:^^{}:X}"), input); +} + +// +// Pair +// + +template 
+void test_pair_tuple(TestFunction check, ExceptionTest check_exception, auto&& input) { + // [format.range.formatter]/3 + // For range_formatter, the format-spec in a + // range-underlying-spec, if any, is interpreted by formatter. + // + // template + // constexpr typename ParseContext::iterator + // parse(ParseContext& ctx); + // [format.tuple]/7 + // ... if e.set_debug_format() is a valid expression, calls + // e.set_debug_format(). + // So when there is no range-underlying-spec, there is no need to call parse + // thus the char element is not escaped. + // TODO FMT P2733 addresses this issue. + check(SV("[(1, a), (42, *)]"), SV("{}"), input); + + // ***** underlying has no format-spec + + // *** align-fill & width *** + check(SV("[(1, a), (42, *)] "), SV("{:22}"), input); + check(SV("[(1, a), (42, *)]*****"), SV("{:*<22}"), input); + check(SV("__[(1, a), (42, *)]___"), SV("{:_^22}"), input); + check(SV("#####[(1, a), (42, *)]"), SV("{:#>22}"), input); + + check(SV("[(1, a), (42, *)] "), SV("{:{}}"), input, 22); + check(SV("[(1, a), (42, *)]*****"), SV("{:*<{}}"), input, 22); + check(SV("__[(1, a), (42, *)]___"), SV("{:_^{}}"), input, 22); + check(SV("#####[(1, a), (42, *)]"), SV("{:#>{}}"), input, 22); + + check_exception("The format-spec range-fill field contains an invalid character", SV("{:}<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{:{<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{::<}"), input); + + // *** sign *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:-}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{:+}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{: }"), input); + + // *** alternate form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:#}"), input); + + 
// *** zero-padding *** + check_exception("A format-spec width field shouldn't have a leading zero", SV("{:0}"), input); + + // *** precision *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:.}"), input); + + // *** locale-specific form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:L}"), input); + + // *** n + check(SV("__(1, a), (42, *)___"), SV("{:_^20n}"), input); + check(SV("__(1, a), (42, *)___"), SV("{:_^20nm}"), input); // m should have no effect + + // *** type *** + check(SV("__{(1, a), (42, *)}___"), SV("{:_^22m}"), input); + check_exception("The range-format-spec type s requires formatting a character type", SV("{:s}"), input); + check_exception("The range-format-spec type ?s requires formatting a character type", SV("{:?s}"), input); + for (std::basic_string_view fmt : fmt_invalid_types("s")) + check_exception("The format-spec should consume the input or end with a '}'", fmt, input); + + // ***** Only underlying has a format-spec + check(SV("[(1, 'a') , (42, '*') ]"), SV("{::11}"), input); + check(SV("[(1, 'a')***, (42, '*')**]"), SV("{::*<11}"), input); + check(SV("[_(1, 'a')__, _(42, '*')_]"), SV("{::_^11}"), input); + check(SV("[###(1, 'a'), ##(42, '*')]"), SV("{::#>11}"), input); + + check(SV("[(1, 'a') , (42, '*') ]"), SV("{::{}}"), input, 11); + check(SV("[(1, 'a')***, (42, '*')**]"), SV("{::*<{}}"), input, 11); + check(SV("[_(1, 'a')__, _(42, '*')_]"), SV("{::_^{}}"), input, 11); + check(SV("[###(1, 'a'), ##(42, '*')]"), SV("{::#>{}}"), input, 11); + + check_exception("The format-spec range-fill field contains an invalid character", SV("{:::<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{::}<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{::{<}"), input); + + // *** sign *** + check_exception("The format-spec should consume the input or end with a '}'", 
SV("{::-}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{::+}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{:: }"), input); + + // *** alternate form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{::#}"), input); + + // *** zero-padding *** + check_exception("A format-spec width field shouldn't have a leading zero", SV("{::05}"), input); + + // *** precision *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{::.}"), input); + + // *** locale-specific form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{::L}"), input); + + // *** type *** + check(SV("[1: 'a', 42: '*']"), SV("{::m}"), input); + check(SV("[1, 'a', 42, '*']"), SV("{::n}"), input); + for (std::basic_string_view fmt : fmt_invalid_nested_types("")) + check_exception("The format-spec should consume the input or end with a '}'", fmt, input); + + // ***** Both have a format-spec + check(SV("^^[###(1, 'a'), ##(42, '*')]^^^"), SV("{:^^31:#>11}"), input); + check(SV("^^[###(1, 'a'), ##(42, '*')]^^^"), SV("{:^^31:#>11}"), input); + check(SV("^^[###(1, 'a'), ##(42, '*')]^^^"), SV("{:^^{}:#>11}"), input, 31); + check(SV("^^[###(1, 'a'), ##(42, '*')]^^^"), SV("{:^^{}:#>{}}"), input, 31, 11); + + check_exception("Argument index out of bounds", SV("{:^^{}:#>5}"), input); + check_exception("Argument index out of bounds", SV("{:^^{}:#>{}}"), input, 31); + + check(SV("1: 'a', 42: '*'"), SV("{:n:m}"), input); + check(SV("1, 'a', 42, '*'"), SV("{:n:n}"), input); + check(SV("{1: 'a', 42: '*'}"), SV("{:m:m}"), input); + check(SV("{1, 'a', 42, '*'}"), SV("{:m:n}"), input); +} + +template +void test_pair_tuple(TestFunction check, ExceptionTest check_exception) { + test_pair_tuple( + check, check_exception, std::array{std::make_pair(1, CharT('a')), std::make_pair(42, CharT('*'))}); + test_pair_tuple( + check, 
check_exception, std::array{std::make_tuple(1, CharT('a')), std::make_tuple(42, CharT('*'))}); +} + +// +// Tuple 1 +// + +template +void test_tuple_int(TestFunction check, ExceptionTest check_exception) { + std::array input{std::make_tuple(42), std::make_tuple(99)}; + + check(SV("[(42), (99)]"), SV("{}"), input); + + // ***** underlying has no format-spec + + // *** align-fill & width *** + check(SV("[(42), (99)] "), SV("{:17}"), input); + check(SV("[(42), (99)]*****"), SV("{:*<17}"), input); + check(SV("__[(42), (99)]___"), SV("{:_^17}"), input); + check(SV("#####[(42), (99)]"), SV("{:#>17}"), input); + + check(SV("[(42), (99)] "), SV("{:{}}"), input, 17); + check(SV("[(42), (99)]*****"), SV("{:*<{}}"), input, 17); + check(SV("__[(42), (99)]___"), SV("{:_^{}}"), input, 17); + check(SV("#####[(42), (99)]"), SV("{:#>{}}"), input, 17); + + check_exception("The format-spec range-fill field contains an invalid character", SV("{:}<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{:{<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{::<}"), input); + + // *** sign *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:-}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{:+}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{: }"), input); + + // *** alternate form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:#}"), input); + + // *** zero-padding *** + check_exception("A format-spec width field shouldn't have a leading zero", SV("{:0}"), input); + + // *** precision *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:.}"), input); + + // *** locale-specific form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:L}"), input); 
+ + // *** n + check(SV("__(42), (99)___"), SV("{:_^15n}"), input); + + // *** type *** + check(SV("__{(42), (99)}___"), SV("{:_^17m}"), input); + check_exception("The range-format-spec type s requires formatting a character type", SV("{:s}"), input); + check_exception("The range-format-spec type ?s requires formatting a character type", SV("{:?s}"), input); + for (std::basic_string_view fmt : fmt_invalid_types("s")) + check_exception("The format-spec should consume the input or end with a '}'", fmt, input); + + // ***** Only underlying has a format-spec + check(SV("[(42) , (99) ]"), SV("{::7}"), input); + check(SV("[(42)***, (99)***]"), SV("{::*<7}"), input); + check(SV("[_(42)__, _(99)__]"), SV("{::_^7}"), input); + check(SV("[###(42), ###(99)]"), SV("{::#>7}"), input); + + check(SV("[(42) , (99) ]"), SV("{::{}}"), input, 7); + check(SV("[(42)***, (99)***]"), SV("{::*<{}}"), input, 7); + check(SV("[_(42)__, _(99)__]"), SV("{::_^{}}"), input, 7); + check(SV("[###(42), ###(99)]"), SV("{::#>{}}"), input, 7); + + check_exception("The format-spec range-fill field contains an invalid character", SV("{:::<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{::}<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{::{<}"), input); + + // *** sign *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{::-}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{::+}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{:: }"), input); + + // *** alternate form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{::#}"), input); + + // *** zero-padding *** + check_exception("A format-spec width field shouldn't have a leading zero", SV("{::05}"), input); + + // *** precision *** + check_exception("The format-spec should consume 
the input or end with a '}'", SV("{::.}"), input); + + // *** locale-specific form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{::L}"), input); + + // *** type *** + check(SV("[42, 99]"), SV("{::n}"), input); + for (std::basic_string_view fmt : fmt_invalid_nested_types("")) + check_exception("The format-spec should consume the input or end with a '}'", fmt, input); + + // ***** Both have a format-spec + check(SV("^^[###(42), ###(99)]^^^"), SV("{:^^23:#>7}"), input); + check(SV("^^[###(42), ###(99)]^^^"), SV("{:^^23:#>7}"), input); + check(SV("^^[###(42), ###(99)]^^^"), SV("{:^^{}:#>7}"), input, 23); + check(SV("^^[###(42), ###(99)]^^^"), SV("{:^^{}:#>{}}"), input, 23, 7); + + check_exception("Argument index out of bounds", SV("{:^^{}:#>5}"), input); + check_exception("Argument index out of bounds", SV("{:^^{}:#>{}}"), input, 23); +} + +// +// Tuple 3 +// + +template +void test_tuple_int_int_int(TestFunction check, ExceptionTest check_exception) { + std::array input{std::make_tuple(42, 99, 0), std::make_tuple(1, 10, 100)}; + + check(SV("[(42, 99, 0), (1, 10, 100)]"), SV("{}"), input); + + // ***** underlying has no format-spec + + // *** align-fill & width *** + check(SV("[(42, 99, 0), (1, 10, 100)] "), SV("{:32}"), input); + check(SV("[(42, 99, 0), (1, 10, 100)]*****"), SV("{:*<32}"), input); + check(SV("__[(42, 99, 0), (1, 10, 100)]___"), SV("{:_^32}"), input); + check(SV("#####[(42, 99, 0), (1, 10, 100)]"), SV("{:#>32}"), input); + + check(SV("[(42, 99, 0), (1, 10, 100)] "), SV("{:{}}"), input, 32); + check(SV("[(42, 99, 0), (1, 10, 100)]*****"), SV("{:*<{}}"), input, 32); + check(SV("__[(42, 99, 0), (1, 10, 100)]___"), SV("{:_^{}}"), input, 32); + check(SV("#####[(42, 99, 0), (1, 10, 100)]"), SV("{:#>{}}"), input, 32); + + check_exception("The format-spec range-fill field contains an invalid character", SV("{:}<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{:{<}"), 
input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{::<}"), input); + + // *** sign *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:-}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{:+}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{: }"), input); + + // *** alternate form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:#}"), input); + + // *** zero-padding *** + check_exception("A format-spec width field shouldn't have a leading zero", SV("{:0}"), input); + + // *** precision *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:.}"), input); + + // *** locale-specific form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:L}"), input); + + // *** n + check(SV("__(42, 99, 0), (1, 10, 100)___"), SV("{:_^30n}"), input); + + // *** type *** + check(SV("__{(42, 99, 0), (1, 10, 100)}___"), SV("{:_^32m}"), input); + check_exception("The range-format-spec type s requires formatting a character type", SV("{:s}"), input); + check_exception("The range-format-spec type ?s requires formatting a character type", SV("{:?s}"), input); + for (std::basic_string_view fmt : fmt_invalid_types("s")) + check_exception("The format-spec should consume the input or end with a '}'", fmt, input); + + // ***** Only underlying has a format-spec + check(SV("[(42, 99, 0) , (1, 10, 100) ]"), SV("{::14}"), input); + check(SV("[(42, 99, 0)***, (1, 10, 100)**]"), SV("{::*<14}"), input); + check(SV("[_(42, 99, 0)__, _(1, 10, 100)_]"), SV("{::_^14}"), input); + check(SV("[###(42, 99, 0), ##(1, 10, 100)]"), SV("{::#>14}"), input); + + check(SV("[(42, 99, 0) , (1, 10, 100) ]"), SV("{::{}}"), input, 14); + check(SV("[(42, 99, 0)***, (1, 10, 100)**]"), SV("{::*<{}}"), input, 14); + check(SV("[_(42, 99, 
0)__, _(1, 10, 100)_]"), SV("{::_^{}}"), input, 14); + check(SV("[###(42, 99, 0), ##(1, 10, 100)]"), SV("{::#>{}}"), input, 14); + + check_exception("The format-spec range-fill field contains an invalid character", SV("{:::<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{::}<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{::{<}"), input); + + // *** sign *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{::-}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{::+}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{:: }"), input); + + // *** alternate form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{::#}"), input); + + // *** zero-padding *** + check_exception("A format-spec width field shouldn't have a leading zero", SV("{::05}"), input); + + // *** precision *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{::.}"), input); + + // *** locale-specific form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{::L}"), input); + + // *** type *** + check(SV("[42, 99, 0, 1, 10, 100]"), SV("{::n}"), input); + for (std::basic_string_view fmt : fmt_invalid_nested_types("s")) + check_exception("The format-spec should consume the input or end with a '}'", fmt, input); + + // ***** Both have a format-spec + check(SV("^^[###(42, 99, 0), ##(1, 10, 100)]^^^"), SV("{:^^37:#>14}"), input); + check(SV("^^[###(42, 99, 0), ##(1, 10, 100)]^^^"), SV("{:^^37:#>14}"), input); + check(SV("^^[###(42, 99, 0), ##(1, 10, 100)]^^^"), SV("{:^^{}:#>14}"), input, 37); + check(SV("^^[###(42, 99, 0), ##(1, 10, 100)]^^^"), SV("{:^^{}:#>{}}"), input, 37, 14); + + check_exception("Argument index out of bounds", SV("{:^^{}:#>5}"), input); + 
check_exception("Argument index out of bounds", SV("{:^^{}:#>{}}"), input, 37); +} + +// +// Ranges +// + +template +void test_with_ranges(TestFunction check, ExceptionTest check_exception, auto&& iter) { + std::ranges::subrange range{std::move(iter), std::default_sentinel}; + test_int(check, check_exception, std::move(range)); +} + +template +void test_with_ranges(TestFunction check, ExceptionTest check_exception) { + std::array input{1, 2, 42, -42}; + test_with_ranges( + check, check_exception, std::counted_iterator{cpp20_input_iterator(input.data()), input.size()}); + test_with_ranges( + check, check_exception, std::counted_iterator{forward_iterator(input.data()), input.size()}); + test_with_ranges( + check, check_exception, std::counted_iterator{bidirectional_iterator(input.data()), input.size()}); + test_with_ranges( + check, check_exception, std::counted_iterator{random_access_iterator(input.data()), input.size()}); + test_with_ranges( + check, check_exception, std::counted_iterator{contiguous_iterator(input.data()), input.size()}); +} + +// +// Driver +// + +template +void format_tests(TestFunction check, ExceptionTest check_exception) { + test_char(check, check_exception); +#ifndef _LIBCPP_HAS_NO_WIDE_CHARACTERS + if (std::same_as) // avoid testing twice + test_char_to_wchar(check, check_exception); +#endif + test_bool(check, check_exception); + test_int(check, check_exception); + test_floating_point(check, check_exception); + test_pointer(check, check_exception); + test_string(check, check_exception); + + test_status(check, check_exception); // Has its own handler with its own parser + + test_pair_tuple(check, check_exception); + test_tuple_int(check, check_exception); + test_tuple_int_int_int(check, check_exception); + + test_with_ranges(check, check_exception); +} + +#endif // TEST_STD_UTILITIES_FORMAT_FORMAT_RANGE_FORMAT_RANGE_FORMATTER_FORMAT_FUNCTIONS_TESTS_H diff --git 
a/libcxx/test/std/utilities/format/format.range/format.range.formatter/format.functions.vformat.pass.cpp b/libcxx/test/std/utilities/format/format.range/format.range.formatter/format.functions.vformat.pass.cpp new file mode 100644 index 00000000000000..c4738936c81aee --- /dev/null +++ b/libcxx/test/std/utilities/format/format.range/format.range.formatter/format.functions.vformat.pass.cpp @@ -0,0 +1,71 @@ +//===----------------------------------------------------------------------===// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// UNSUPPORTED: c++03, c++11, c++14, c++17, c++20 +// UNSUPPORTED: libcpp-has-no-incomplete-format + +// TODO FMT Fix this test using GCC, it currently times out. +// UNSUPPORTED: gcc-12 + +// This test requires the dylib support introduced in D92214. +// XFAIL: use_system_cxx_lib && target={{.+}}-apple-macosx10.{{.+}} +// XFAIL: use_system_cxx_lib && target={{.+}}-apple-macosx11.{{.+}} + +// + +// template +// requires same_as, T> && formattable +// class range_formatter; + +// string vformat(string_view fmt, format_args args); +// wstring vformat(wstring_view fmt, wformat_args args); + +#include +#include + +#include "format.functions.tests.h" +#include "test_macros.h" +#include "assert_macros.h" + +auto test = []( + std::basic_string_view expected, std::basic_string_view fmt, Args&&... args) { + std::basic_string out = std::vformat(fmt, std::make_format_args>(args...)); + TEST_REQUIRE( + out == expected, + test_concat_message("\nFormat string ", fmt, "\nExpected output ", expected, "\nActual output ", out, '\n')); +}; + +auto test_exception = + []( + [[maybe_unused]] std::string_view what, + [[maybe_unused]] std::basic_string_view fmt, + [[maybe_unused]] Args&&... 
args) { +#ifndef TEST_HAS_NO_EXCEPTIONS + try { + TEST_IGNORE_NODISCARD std::vformat(fmt, std::make_format_args>(args...)); + TEST_FAIL(test_concat_message("\nFormat string ", fmt, "\nDidn't throw an exception.\n")); + } catch (const std::format_error& e) { + TEST_LIBCPP_REQUIRE( + e.what() == what, + test_concat_message( + "\nFormat string ", fmt, "\nExpected exception ", what, "\nActual exception ", e.what(), '\n')); + + return; + } + assert(false); +#endif + }; + +int main(int, char**) { + format_tests(test, test_exception); + +#ifndef TEST_HAS_NO_WIDE_CHARACTERS + format_tests(test, test_exception); +#endif + + return 0; +} diff --git a/libcxx/test/std/utilities/format/format.range/format.range.formatter/format.pass.cpp b/libcxx/test/std/utilities/format/format.range/format.range.formatter/format.pass.cpp new file mode 100644 index 00000000000000..acfe8cc9ac4134 --- /dev/null +++ b/libcxx/test/std/utilities/format/format.range/format.range.formatter/format.pass.cpp @@ -0,0 +1,75 @@ +//===----------------------------------------------------------------------===// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// UNSUPPORTED: c++03, c++11, c++14, c++17, c++20 +// UNSUPPORTED: libcpp-has-no-incomplete-format + +// TODO FMT Fix this test using GCC, it currently times out. +// UNSUPPORTED: gcc-12 + +// This test requires the dylib support introduced in D92214. 
+// XFAIL: use_system_cxx_lib && target={{.+}}-apple-macosx10.{{.+}} +// XFAIL: use_system_cxx_lib && target={{.+}}-apple-macosx11.{{.+}} + +// + +// template +// requires same_as, T> && formattable +// class range_formatter + +// template +// typename FormatContext::iterator +// format(const T& ref, FormatContext& ctx) const; + +// Note this tests the basics of this function. It's tested in more detail in +// the format functions test. + +#include +#include +#include +#include + +#include "test_format_context.h" +#include "test_macros.h" +#include "make_string.h" + +#define SV(S) MAKE_STRING_VIEW(CharT, S) + +template +void test_format(StringViewT expected, std::vector arg) { + using CharT = typename StringViewT::value_type; + using String = std::basic_string; + using OutIt = std::back_insert_iterator; + using FormatCtxT = std::basic_format_context; + + const std::range_formatter formatter; + + String result; + OutIt out = std::back_inserter(result); + FormatCtxT format_ctx = test_format_context_create(out, std::make_format_args(arg)); + formatter.format(arg, format_ctx); + assert(result == expected); +} + +template +void test_fmt() { + test_format(SV("[1]"), std::vector{1}); + test_format(SV("[0]"), std::vector{0}); +} + +void test() { + test_fmt(); +#ifndef TEST_HAS_NO_WIDE_CHARACTERS + test_fmt(); +#endif +} + +int main(int, char**) { + test(); + + return 0; +} diff --git a/libcxx/test/std/utilities/format/format.range/format.range.formatter/parse.pass.cpp b/libcxx/test/std/utilities/format/format.range/format.range.formatter/parse.pass.cpp new file mode 100644 index 00000000000000..ce1c0c93130b6e --- /dev/null +++ b/libcxx/test/std/utilities/format/format.range/format.range.formatter/parse.pass.cpp @@ -0,0 +1,75 @@ +//===----------------------------------------------------------------------===// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// UNSUPPORTED: c++03, c++11, c++14, c++17, c++20 +// UNSUPPORTED: libcpp-has-no-incomplete-format + +// TODO FMT Fix this test using GCC, it currently times out. +// UNSUPPORTED: gcc-12 + +// This test requires the dylib support introduced in D92214. +// XFAIL: use_system_cxx_lib && target={{.+}}-apple-macosx10.{{.+}} +// XFAIL: use_system_cxx_lib && target={{.+}}-apple-macosx11.{{.+}} + +// + +// template +// requires same_as, T> && formattable +// class range_formatter + +// template +// constexpr typename ParseContext::iterator +// parse(ParseContext& ctx); + +// Note this tests the basics of this function. It's tested in more detail in +// the format functions test. + +#include +#include +#include + +#include "test_format_context.h" +#include "test_macros.h" +#include "make_string.h" + +#define SV(S) MAKE_STRING_VIEW(CharT, S) + +template +constexpr void test_parse(StringViewT fmt) { + using CharT = typename StringViewT::value_type; + auto parse_ctx = std::basic_format_parse_context(fmt); + std::range_formatter formatter; + static_assert(std::semiregular); + + std::same_as auto it = formatter.parse(parse_ctx); + assert(it == fmt.end() - (!fmt.empty() && fmt.back() == '}')); +} + +template +constexpr void test_fmt() { + test_parse(SV("")); + test_parse(SV(":d")); + + test_parse(SV("}")); + test_parse(SV(":d}")); +} + +constexpr bool test() { + test_fmt(); +#ifndef TEST_HAS_NO_WIDE_CHARACTERS + test_fmt(); +#endif + + return true; +} + +int main(int, char**) { + test(); + static_assert(test()); + + return 0; +} diff --git a/libcxx/test/std/utilities/format/format.range/format.range.formatter/set_brackets.pass.cpp b/libcxx/test/std/utilities/format/format.range/format.range.formatter/set_brackets.pass.cpp new file mode 100644 index 00000000000000..fe2d72cb249ba7 --- /dev/null +++ 
b/libcxx/test/std/utilities/format/format.range/format.range.formatter/set_brackets.pass.cpp @@ -0,0 +1,73 @@ +//===----------------------------------------------------------------------===// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// UNSUPPORTED: c++03, c++11, c++14, c++17, c++20 +// UNSUPPORTED: libcpp-has-no-incomplete-format + +// TODO FMT Fix this test using GCC, it currently times out. +// UNSUPPORTED: gcc-12 + +// This test requires the dylib support introduced in D92214. +// XFAIL: use_system_cxx_lib && target={{.+}}-apple-macosx10.{{.+}} +// XFAIL: use_system_cxx_lib && target={{.+}}-apple-macosx11.{{.+}} + +// + +// template +// requires same_as, T> && formattable +// class range_formatter + +// constexpr void constexpr void set_brackets(basic_string_view opening, +// basic_string_view closing); + +// Note this tests the basics of this function. It's tested in more detail in +// the format functions test. + +#include +#include +#include +#include + +#include "make_string.h" +#include "test_format_context.h" + +#define SV(S) MAKE_STRING_VIEW(CharT, S) + +template +constexpr void test_setter() { + std::range_formatter formatter; + formatter.set_brackets(SV("open"), SV("close")); + + // Note there is no direct way to validate this function modified the object. 
+ if (!std::is_constant_evaluated()) { + using String = std::basic_string; + using OutIt = std::back_insert_iterator; + using FormatCtxT = std::basic_format_context; + + String result; + OutIt out = std::back_inserter(result); + FormatCtxT format_ctx = test_format_context_create(out, std::make_format_args()); + formatter.format(std::vector{0, 42, 99}, format_ctx); + assert(result == SV("open0, 42, 99close")); + } +} + +constexpr bool test() { + test_setter(); +#ifndef TEST_HAS_NO_WIDE_CHARACTERS + test_setter(); +#endif + + return true; +} + +int main(int, char**) { + test(); + static_assert(test()); + + return 0; +} diff --git a/libcxx/test/std/utilities/format/format.range/format.range.formatter/set_separator.pass.cpp b/libcxx/test/std/utilities/format/format.range/format.range.formatter/set_separator.pass.cpp new file mode 100644 index 00000000000000..413f5b4669a49b --- /dev/null +++ b/libcxx/test/std/utilities/format/format.range/format.range.formatter/set_separator.pass.cpp @@ -0,0 +1,72 @@ +//===----------------------------------------------------------------------===// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// UNSUPPORTED: c++03, c++11, c++14, c++17, c++20 +// UNSUPPORTED: libcpp-has-no-incomplete-format + +// TODO FMT Fix this test using GCC, it currently times out. +// UNSUPPORTED: gcc-12 + +// This test requires the dylib support introduced in D92214. +// XFAIL: use_system_cxx_lib && target={{.+}}-apple-macosx10.{{.+}} +// XFAIL: use_system_cxx_lib && target={{.+}}-apple-macosx11.{{.+}} + +// + +// template +// requires same_as, T> && formattable +// class range_formatter + +// constexpr void set_separator(basic_string_view sep); + +// Note this tests the basics of this function. 
It's tested in more detail in +// the format functions test. + +#include +#include +#include +#include + +#include "make_string.h" +#include "test_format_context.h" + +#define SV(S) MAKE_STRING_VIEW(CharT, S) + +template +constexpr void test_setter() { + std::range_formatter formatter; + formatter.set_separator(SV("sep")); + + // Note there is no direct way to validate this function modified the object. + if (!std::is_constant_evaluated()) { + using String = std::basic_string; + using OutIt = std::back_insert_iterator; + using FormatCtxT = std::basic_format_context; + + String result; + OutIt out = std::back_inserter(result); + FormatCtxT format_ctx = test_format_context_create(out, std::make_format_args()); + formatter.format(std::vector{0, 42, 99}, format_ctx); + assert(result == SV("[0sep42sep99]")); + } +} + +constexpr bool test() { + test_setter(); +#ifndef TEST_HAS_NO_WIDE_CHARACTERS + test_setter(); +#endif + + return true; +} + +int main(int, char**) { + test(); + static_assert(test()); + + return 0; +} diff --git a/libcxx/test/std/utilities/format/format.range/format.range.formatter/underlying.pass.cpp b/libcxx/test/std/utilities/format/format.range/format.range.formatter/underlying.pass.cpp new file mode 100644 index 00000000000000..7d819d4ec6c371 --- /dev/null +++ b/libcxx/test/std/utilities/format/format.range/format.range.formatter/underlying.pass.cpp @@ -0,0 +1,58 @@ +//===----------------------------------------------------------------------===// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// UNSUPPORTED: c++03, c++11, c++14, c++17, c++20 +// UNSUPPORTED: libcpp-has-no-incomplete-format + +// TODO FMT Fix this test using GCC, it currently times out. 
+// UNSUPPORTED: gcc-12 + +// This test requires the dylib support introduced in D92214. +// XFAIL: use_system_cxx_lib && target={{.+}}-apple-macosx10.{{.+}} +// XFAIL: use_system_cxx_lib && target={{.+}}-apple-macosx11.{{.+}} + +// + +// template +// requires same_as, T> && formattable +// class range_formatter + +// constexpr formatter& underlying(); +// constexpr const formatter& underlying() const; + +#include +#include + +#include "test_macros.h" + +template +constexpr void test_underlying() { + { + std::range_formatter formatter; + [[maybe_unused]] std::same_as&> decltype(auto) underlying = formatter.underlying(); + } + { + const std::range_formatter formatter; + [[maybe_unused]] std::same_as&> decltype(auto) underlying = formatter.underlying(); + } +} + +constexpr bool test() { + test_underlying(); +#ifndef TEST_HAS_NO_WIDE_CHARACTERS + test_underlying(); +#endif + + return true; +} + +int main(int, char**) { + test(); + static_assert(test()); + + return 0; +} diff --git a/libcxx/test/std/utilities/format/format.tuple/format.functions.format.pass.cpp b/libcxx/test/std/utilities/format/format.tuple/format.functions.format.pass.cpp index d6f9bd376849c5..6f64016cfd57b2 100644 --- a/libcxx/test/std/utilities/format/format.tuple/format.functions.format.pass.cpp +++ b/libcxx/test/std/utilities/format/format.tuple/format.functions.format.pass.cpp @@ -8,6 +8,9 @@ // UNSUPPORTED: c++03, c++11, c++14, c++17, c++20 // UNSUPPORTED: libcpp-has-no-incomplete-format +// TODO FMT Fix this test using GCC, it currently times out. +// UNSUPPORTED: gcc-12 + // This test requires the dylib support introduced in D92214. 
// XFAIL: use_system_cxx_lib && target={{.+}}-apple-macosx10.{{.+}} // XFAIL: use_system_cxx_lib && target={{.+}}-apple-macosx11.{{.+}} diff --git a/libcxx/test/std/utilities/format/format.tuple/format.functions.tests.h b/libcxx/test/std/utilities/format/format.tuple/format.functions.tests.h index 2ad2d879d4e077..3f7a1fa0322013 100644 --- a/libcxx/test/std/utilities/format/format.tuple/format.functions.tests.h +++ b/libcxx/test/std/utilities/format/format.tuple/format.functions.tests.h @@ -279,7 +279,7 @@ template void test_nested(TestFunction check, ExceptionTest check_exception, Nested&& input) { // [format.formatter.spec]/2 // A debug-enabled specialization of formatter additionally provides a - // public, constexpr, non-static member function set_­debug_­format() + // public, constexpr, non-static member function set_debug_format() // which modifies the state of the formatter to be as if the type of the // std-format-spec parsed by the last call to parse were ?. // pair and tuple are not debug-enabled specializations to the diff --git a/libcxx/test/support/format.functions.common.h b/libcxx/test/support/format.functions.common.h index 14faf3acd0d5de..9796c87f56b864 100644 --- a/libcxx/test/support/format.functions.common.h +++ b/libcxx/test/support/format.functions.common.h @@ -13,6 +13,7 @@ #include #include #include +#include #include "make_string.h" @@ -129,4 +130,57 @@ struct std::formatter { } }; +// Creates format string for the invalid types. +// +// valid contains a list of types that are valid. +// - The type ?s is the only type requiring 2 characters, use S for that type. +// - Whether n is a type or not depends on the context, is is always used. +// +// The return value is a collection of basic_strings, instead of +// basic_string_views since the values are temporaries. 
+namespace detail { +template +std::basic_string get_colons() { + static std::basic_string result(N, CharT(':')); + return result; +} + +constexpr std::string_view get_format_types() { + return "aAbBcdeEfFgGopsxX" +#if TEST_STD_VER > 20 + "?" +#endif + ; +} + +template +std::vector> fmt_invalid_types(std::string_view valid) { + // std::ranges::to is not available in C++20. + std::vector> result; + std::ranges::copy( + get_format_types() | std::views::filter([&](char type) { return valid.find(type) == std::string_view::npos; }) | + std::views::transform([&](char type) { return std::format(SV("{{{}{}}}"), get_colons(), type); }), + std::back_inserter(result)); + return result; +} + +} // namespace detail + +// Creates format string for the invalid types. +// +// valid contains a list of types that are valid. +// +// The return value is a collection of basic_strings, instead of +// basic_string_views since the values are temporaries. +template +std::vector> fmt_invalid_types(std::string_view valid) { + return detail::fmt_invalid_types(valid); +} + +// Like fmt_invalid_types but when the format spec is for an underlying formatter. 
+template +std::vector> fmt_invalid_nested_types(std::string_view valid) { + return detail::fmt_invalid_types(valid); +} + #endif // TEST_SUPPORT_FORMAT_FUNCTIONS_COMMON_H From c85cecf36bf8b64b541becb0ec3ec11379d062e7 Mon Sep 17 00:00:00 2001 From: LLVM GN Syncbot Date: Thu, 19 Jan 2023 16:20:20 +0000 Subject: [PATCH 15/61] [gn build] Port 22e8525dfdd7 --- llvm/utils/gn/secondary/libcxx/include/BUILD.gn | 1 + 1 file changed, 1 insertion(+) diff --git a/llvm/utils/gn/secondary/libcxx/include/BUILD.gn b/llvm/utils/gn/secondary/libcxx/include/BUILD.gn index 73eabf546aba42..ecf9bff5646f40 100644 --- a/llvm/utils/gn/secondary/libcxx/include/BUILD.gn +++ b/llvm/utils/gn/secondary/libcxx/include/BUILD.gn @@ -402,6 +402,7 @@ if (current_toolchain == default_toolchain) { "__format/formatter_tuple.h", "__format/parser_std_format_spec.h", "__format/range_default_formatter.h", + "__format/range_formatter.h", "__format/unicode.h", "__functional/binary_function.h", "__functional/binary_negate.h", From a459a2485b54fbd9e1f8a48061e79cdcd12b12a5 Mon Sep 17 00:00:00 2001 From: Valentin Clement Date: Thu, 19 Jan 2023 17:32:02 +0100 Subject: [PATCH 16/61] [flang] Fix SELECT TYPE lowering when CLASS DEFAULT is not the last type guard CLASS DEFAULT needs to be the last attribute when fir.select_type op is created. It needs to be at its actual position in the Fortran code when the TypeGuardStmt are processed. The current lowering was crashing when CLASS DEFAULT was not at the last position. This patch fixes the issue by tracking the actual position of the CLASS DEFAULT type guard and set it at the correct position after the fir.select_type op is created. 
Reviewed By: jeanPerier, PeteSteinfeld Differential Revision: https://reviews.llvm.org/D142091 --- flang/lib/Lower/Bridge.cpp | 19 ++++++++++++++++++ flang/test/Lower/select-type.f90 | 34 ++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 71e06ed347cdc4..9e5898e9db0eef 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -2133,6 +2133,7 @@ class FirConverter : public Fortran::lower::AbstractConverter { llvm::SmallVector attrList; llvm::SmallVector blockList; unsigned typeGuardIdx = 0; + std::size_t defaultAttrPos = std::numeric_limits::max(); bool hasLocalScope = false; for (Fortran::lower::pft::Evaluation &eval : @@ -2162,6 +2163,9 @@ class FirConverter : public Fortran::lower::AbstractConverter { // CLASS DEFAULT if (std::holds_alternative(guard.u)) { defaultBlock = e->block; + // Keep track of the actual position of the CLASS DEFAULT type guard + // in the SELECT TYPE construct. + defaultAttrPos = attrList.size(); continue; } @@ -2197,6 +2201,21 @@ class FirConverter : public Fortran::lower::AbstractConverter { blockList.push_back(defaultBlock); builder->create(loc, fir::getBase(selector), attrList, blockList); + + // If the actual position of CLASS DEFAULT type guard is not the last + // one, it needs to be put back at its correct position for the rest of + // the processing. TypeGuardStmt are processed in the same order they + // appear in the Fortran code. 
+ if (defaultAttrPos < attrList.size() - 1) { + auto attrIt = attrList.begin(); + attrIt = attrIt + defaultAttrPos; + auto blockIt = blockList.begin(); + blockIt = blockIt + defaultAttrPos; + attrList.insert(attrIt, mlir::UnitAttr::get(context)); + blockList.insert(blockIt, defaultBlock); + attrList.pop_back(); + blockList.pop_back(); + } } else if (auto *typeGuardStmt = eval.getIf()) { // Map the type guard local symbol for the selector to a more precise diff --git a/flang/test/Lower/select-type.f90 b/flang/test/Lower/select-type.f90 index 23f60cd7cc36f7..846758281be844 100644 --- a/flang/test/Lower/select-type.f90 +++ b/flang/test/Lower/select-type.f90 @@ -722,6 +722,40 @@ subroutine select_type12(a) ! CHECK: ^bb{{.*}}: // pred: ^bb0 ! CHECK: %[[EXACT_BOX:.*]] = fir.convert %[[SELECTOR]] : (!fir.class>>) -> !fir.box>> + + ! Test correct lowering when CLASS DEFAULT is not at the last position in the + ! SELECT TYPE construct. + subroutine select_type13(a) + class(p1), pointer :: a(:) + select type (a) + class default + print*, 'default' + class is (p1) + print*, 'class' + end select + + select type (a) + type is (p1) + print*, 'type' + class default + print*, 'default' + class is (p1) + print*, 'class' + end select + + end subroutine + +! CHECK-LABEL: func.func @_QMselect_type_lower_testPselect_type13 +! CHECK: fir.select_type %{{.*}} : !fir.class>> [#fir.class_is>, ^bb2, unit, ^bb1] +! CHECK: ^bb1: +! CHECK: ^bb2: +! CHECK: ^bb3: +! CHECK: fir.select_type %{{.*}} : !fir.class>> [#fir.type_is>, ^bb4, #fir.class_is>, ^bb6, unit, ^bb5] +! CHECK: ^bb4: +! CHECK: ^bb5: +! CHECK: ^bb6: +! CHECK: ^bb7: + end module program test_select_type From e315b872b291ad268bf0ae6e276a07de6f29992b Mon Sep 17 00:00:00 2001 From: Valentin Clement Date: Thu, 19 Jan 2023 17:33:08 +0100 Subject: [PATCH 17/61] [flang] Support polymorphic source in RESHAPE intrinsic Result must carry the polymorphic type information from the source. 
Reviewed By: jeanPerier, PeteSteinfeld Differential Revision: https://reviews.llvm.org/D142095 --- flang/lib/Lower/IntrinsicCall.cpp | 5 +++-- flang/test/Lower/polymorphic-temp.f90 | 20 ++++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/flang/lib/Lower/IntrinsicCall.cpp b/flang/lib/Lower/IntrinsicCall.cpp index daf3aa6c992bf0..7d3d4d113bd35e 100644 --- a/flang/lib/Lower/IntrinsicCall.cpp +++ b/flang/lib/Lower/IntrinsicCall.cpp @@ -4222,8 +4222,9 @@ IntrinsicLibrary::genReshape(mlir::Type resultType, // Create mutable fir.box to be passed to the runtime for the result. mlir::Type type = builder.getVarLenSeqTy(resultType, resultRank); - fir::MutableBoxValue resultMutableBox = - fir::factory::createTempMutableBox(builder, loc, type); + fir::MutableBoxValue resultMutableBox = fir::factory::createTempMutableBox( + builder, loc, type, {}, + fir::isPolymorphicType(source.getType()) ? source : mlir::Value{}); mlir::Value resultIrBox = fir::factory::getMutableIRBox(builder, loc, resultMutableBox); diff --git a/flang/test/Lower/polymorphic-temp.f90 b/flang/test/Lower/polymorphic-temp.f90 index 6cd715581a7d2d..02245f2d1bf8dd 100644 --- a/flang/test/Lower/polymorphic-temp.f90 +++ b/flang/test/Lower/polymorphic-temp.f90 @@ -47,4 +47,24 @@ subroutine test_temp_from_intrinsic_spread() ! CHECK: fir.call @_QMpoly_tmpPpass_unlimited_poly_1d(%[[REBOX]]) {{.*}} : (!fir.class>) -> () ! CHECK: fir.freemem %[[RES_ADDR]] : !fir.heap> + subroutine test_temp_from_intrinsic_reshape(i) + class(*), allocatable :: a(:,:) + class(*), intent(in) :: i(20,20) + allocate(a(10,10), source=reshape(i,[10,10])) + end subroutine + +! CHECK-LABEL: func.func @_QMpoly_tmpPtest_temp_from_intrinsic_reshape( +! CHECK-SAME: %[[I:.*]]: !fir.class> {fir.bindc_name = "i"}) { +! CHECK: %[[TMP_RES:.*]] = fir.alloca !fir.class>> +! CHECK: %[[A:.*]] = fir.alloca !fir.class>> {bindc_name = "a", uniq_name = "_QMpoly_tmpFtest_temp_from_intrinsic_reshapeEa"} +! 
CHECK: %[[EMBOX_WITH_SOURCE:.*]] = fir.embox %{{.*}}(%{{.*}}) source_box %[[I]] : (!fir.heap>, !fir.shape<2>, !fir.class>) -> !fir.class>> +! CHECK: fir.store %[[EMBOX_WITH_SOURCE]] to %[[TMP_RES]] : !fir.ref>>> +! CHECK: %[[RES_BOX_NONE:.*]] = fir.convert %[[TMP_RES]] : (!fir.ref>>>) -> !fir.ref> +! CHECK: %[[I_BOX_NONE:.*]] = fir.convert %[[I]] : (!fir.class>) -> !fir.box +! CHECK: %{{.*}} = fir.call @_FortranAReshape(%[[RES_BOX_NONE]], %[[I_BOX_NONE]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (!fir.ref>, !fir.box, !fir.box, !fir.box, !fir.box, !fir.ref, i32) -> none +! CHECK: %[[LOAD_RES:.*]] = fir.load %[[TMP_RES]] : !fir.ref>>> +! CHECK: %[[A_BOX_NONE:.*]] = fir.convert %[[A]] : (!fir.ref>>>) -> !fir.ref> +! CHECK: %[[RES_BOX_NONE:.*]] = fir.convert %[[LOAD_RES]] : (!fir.class>>) -> !fir.box +! CHECK: %{{.*}} = fir.call @_FortranAAllocatableApplyMold(%[[A_BOX_NONE]], %[[RES_BOX_NONE]]) {{.*}} : (!fir.ref>, !fir.box) -> none + end module From fd26eee486ed81a2477e28d069fb68072fd7b2a2 Mon Sep 17 00:00:00 2001 From: Valentin Clement Date: Thu, 19 Jan 2023 17:34:01 +0100 Subject: [PATCH 18/61] [flang] Allow pointer association between derived-type pointer and unlimited polymorphic target Pointer association to unlimited polymorphic target is allowed for unlimited polymorphic pointer and non-extensible derived-type. This is checked by the semantic and this patch allows it in the fir.rebox operation. Reviewed By: jeanPerier, PeteSteinfeld Differential Revision: https://reviews.llvm.org/D142104 --- flang/lib/Optimizer/Dialect/FIROps.cpp | 1 + flang/test/Lower/polymorphic.f90 | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index b093f693b08f38..dcedff6aae1cb2 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -2322,6 +2322,7 @@ mlir::LogicalResult fir::ReboxOp::verify() { // character type. 
const bool typeCanMismatch = inputEleTy.isa() || outEleTy.isa() || + (inputEleTy.isa() && outEleTy.isa()) || (getSlice() && inputEleTy.isa()) || areCompatibleCharacterTypes(inputEleTy, outEleTy); if (!typeCanMismatch) diff --git a/flang/test/Lower/polymorphic.f90 b/flang/test/Lower/polymorphic.f90 index 8caa6ce9603f67..da843a2ca59ac5 100644 --- a/flang/test/Lower/polymorphic.f90 +++ b/flang/test/Lower/polymorphic.f90 @@ -44,6 +44,11 @@ module polymorphic_test type(p1) :: inner end type + type non_extensible + sequence + integer :: d + end type + contains elemental subroutine assign_p1_int(lhs, rhs) @@ -749,6 +754,19 @@ subroutine test_polymorphic_intentout(a) ! CHECK: %[[BOX_NONE:.*]] = fir.convert %[[ARG0]] : (!fir.class>) -> !fir.box ! CHECK: %{{.*}} = fir.call @_FortranAInitialize(%[[BOX_NONE]], %{{.*}}, %{{.*}}) {{.*}} : (!fir.box, !fir.ref, i32) -> none + subroutine rebox_up_to_record_type(p) + class(*), allocatable, target :: p(:,:) + type(non_extensible), pointer :: t(:,:) + t => p + end subroutine + +! CHECK-LABEL: func.func @_QMpolymorphic_testPrebox_up_to_record_type( +! CHECK-SAME: %[[P:.*]]: !fir.ref>>> {fir.bindc_name = "p", fir.target}) { +! CHECK: %[[T:.*]] = fir.alloca !fir.box>>> {bindc_name = "t", uniq_name = "_QMpolymorphic_testFrebox_up_to_record_typeEt"} +! CHECK: %[[LOAD_P:.*]] = fir.load %[[P]] : !fir.ref>>> +! CHECK: %[[REBOX:.*]] = fir.rebox %[[LOAD_P]](%{{.*}}) : (!fir.class>>, !fir.shift<2>) -> !fir.box>>> +! CHECK: fir.store %[[REBOX]] to %[[T]] : !fir.ref>>>> + end module program test From a0f8bdbb18a82ec150515d24f6eefb3519d4459a Mon Sep 17 00:00:00 2001 From: Paul Robinson Date: Wed, 4 Jan 2023 11:35:59 -0800 Subject: [PATCH 19/61] [lit] Stop supporting triple substrings in UNSUPPORTED and XFAIL AFAICT all in-tree lit tests have been converted to use `target=...` and so there is no longer any need for triples being special. 
Some project config files still define their own features based on the triple, but those are normal feature words (although now are redundant with target= checks). Downstream tests that use triple substrings will need to convert. For example: UNSUPPORTED: -aix XFAIL: arm becomes UNSUPPORTED: target={{.*}}-aix{{.*}} XFAIL: target=arm{{.*}} You can do git log --grep "special handling for triples" to find many examples of updates to the upstream tests. https://discourse.llvm.org/t/rfc-lits-requires-and-triples/66041 Differential Revision: https://reviews.llvm.org/D141007 --- llvm/docs/ReleaseNotes.rst | 7 +++++++ llvm/utils/lit/lit/BooleanExpression.py | 24 ++++-------------------- llvm/utils/lit/lit/Test.py | 24 ++++++++++-------------- 3 files changed, 21 insertions(+), 34 deletions(-) diff --git a/llvm/docs/ReleaseNotes.rst b/llvm/docs/ReleaseNotes.rst index 6dc2dad13ad93d..783e8f5617fe07 100644 --- a/llvm/docs/ReleaseNotes.rst +++ b/llvm/docs/ReleaseNotes.rst @@ -309,6 +309,13 @@ Changes to Sanitizers Other Changes ------------- +* lit no longer supports using substrings of the default target triple as + feature names in ``UNSUPPORTED:`` and ``XFAIL:`` directives. These have been + replaced by the ``target=`` feature, and tests can use regex + matching to achieve the same effect. For example, ``UNSUPPORTED: arm`` + would now be ``UNSUPPORTED: target=arm{{.*}}`` and ``XFAIL: windows`` + would now be ``XFAIL: target={{.*}}-windows{{.*}}``. + External Open Source Projects Using LLVM 15 =========================================== diff --git a/llvm/utils/lit/lit/BooleanExpression.py b/llvm/utils/lit/lit/BooleanExpression.py index ff5352778e99cc..ba8453d60e98a3 100644 --- a/llvm/utils/lit/lit/BooleanExpression.py +++ b/llvm/utils/lit/lit/BooleanExpression.py @@ -22,24 +22,22 @@ class BooleanExpression: # # Variables in `variables` are true. # Regexes that match any variable in `variables` are true. - # Substrings of `triple` are true. # 'true' is true. 
# All other identifiers are false. @staticmethod - def evaluate(string, variables, triple=""): + def evaluate(string, variables): try: - parser = BooleanExpression(string, set(variables), triple) + parser = BooleanExpression(string, set(variables)) return parser.parseAll() except ValueError as e: raise ValueError(str(e) + ('\nin expression: %r' % string)) ##### - def __init__(self, string, variables, triple=""): + def __init__(self, string, variables): self.tokens = BooleanExpression.tokenize(string) self.variables = variables self.variables.add('true') - self.triple = triple self.value = None self.token = None @@ -101,7 +99,7 @@ def parseMATCH(self): else: regex += re.escape(part) regex = re.compile(regex) - self.value = self.token in self.triple or any(regex.fullmatch(var) for var in self.variables) + self.value = any(regex.fullmatch(var) for var in self.variables) self.token = next(self.tokens) def parseNOT(self): @@ -174,20 +172,6 @@ def test_variables(self): self.assertFalse(BooleanExpression.evaluate('tru', variables)) self.assertFalse(BooleanExpression.evaluate('{{its-true.+}}', variables)) - def test_triple(self): - triple = 'arch-vendor-os' - self.assertTrue(BooleanExpression.evaluate('arch-', {}, triple)) - self.assertTrue(BooleanExpression.evaluate('ar', {}, triple)) - self.assertTrue(BooleanExpression.evaluate('ch-vend', {}, triple)) - self.assertTrue(BooleanExpression.evaluate('-vendor-', {}, triple)) - self.assertTrue(BooleanExpression.evaluate('-os', {}, triple)) - self.assertFalse(BooleanExpression.evaluate('arch-os', {}, triple)) - - # When matching against the triple, a regex is treated as an identifier and checked - # for a literal match. This preserves existing behavior before regexes were introduced. 
- self.assertFalse(BooleanExpression.evaluate('arch-{{vendor}}-os', {}, triple)) - self.assertTrue(BooleanExpression.evaluate('arch-{{vendor}}-os', {}, 'arch-{{vendor}}-os')) - def test_matching(self): expr1 = 'linux && (target={{aarch64-.+}} || target={{x86_64-.+}})' self.assertTrue(BooleanExpression.evaluate(expr1, {'linux', 'target=x86_64-unknown-linux-gnu'})) diff --git a/llvm/utils/lit/lit/Test.py b/llvm/utils/lit/lit/Test.py index dc1c66e896c54f..6c72359440b9cf 100644 --- a/llvm/utils/lit/lit/Test.py +++ b/llvm/utils/lit/lit/Test.py @@ -227,9 +227,9 @@ def __init__(self, suite, path_in_suite, config, file_path = None, gtest_json_fi self.gtest_json_file = gtest_json_file # A list of conditions under which this test is expected to fail. - # Each condition is a boolean expression of features and target - # triple parts. These can optionally be provided by test format - # handlers, and will be honored when the test result is supplied. + # Each condition is a boolean expression of features, or '*'. + # These can optionally be provided by test format handlers, + # and will be honored when the test result is supplied. self.xfails = [] # If true, ignore all items in self.xfails. @@ -238,12 +238,11 @@ def __init__(self, suite, path_in_suite, config, file_path = None, gtest_json_fi # A list of conditions that must be satisfied before running the test. # Each condition is a boolean expression of features. All of them # must be True for the test to run. - # FIXME should target triple parts count here too? self.requires = [] # A list of conditions that prevent execution of the test. - # Each condition is a boolean expression of features and target - # triple parts. All of them must be False for the test to run. + # Each condition is a boolean expression of features. All of them + # must be False for the test to run. self.unsupported = [] # An optional number of retries allowed before the test finally succeeds. 
@@ -317,18 +316,16 @@ def isExpectedToFail(self): return False features = self.config.available_features - triple = getattr(self.suite.config, 'target_triple', "") - # Check if any of the xfails match an available feature or the target. + # Check if any of the xfails match an available feature. for item in self.xfails: # If this is the wildcard, it always fails. if item == '*': return True - # If this is a True expression of features and target triple parts, - # it fails. + # If this is a True expression of features, it fails. try: - if BooleanExpression.evaluate(item, features, triple): + if BooleanExpression.evaluate(item, features): return True except ValueError as e: raise ValueError('Error in XFAIL list:\n%s' % str(e)) @@ -385,16 +382,15 @@ def getUnsupportedFeatures(self): getUnsupportedFeatures() -> list of strings Returns a list of features from UNSUPPORTED that are present - in the test configuration's features or target triple. + in the test configuration's features. Throws ValueError if an UNSUPPORTED line has a syntax error. 
""" features = self.config.available_features - triple = getattr(self.suite.config, 'target_triple', "") try: return [item for item in self.unsupported - if BooleanExpression.evaluate(item, features, triple)] + if BooleanExpression.evaluate(item, features)] except ValueError as e: raise ValueError('Error in UNSUPPORTED list:\n%s' % str(e)) From ac0938709c6d3d1bd7a0620ae1b79792b55a5da9 Mon Sep 17 00:00:00 2001 From: Jordan Rupprecht Date: Thu, 19 Jan 2023 08:39:18 -0800 Subject: [PATCH 20/61] [NFC][bazel] Add _tensor_ops_ext.py to SparseTensorOpsPyFiles This corresponds to the cmake change in 81ca5aa452400843235e058bc9c83fe71eccd593 --- utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel | 1 + 1 file changed, 1 insertion(+) diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel index 22f7dfbbbbb9b3..e4012808440318 100644 --- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel @@ -764,6 +764,7 @@ gentbl_filegroup( filegroup( name = "SparseTensorOpsPyFiles", srcs = [ + "mlir/dialects/_tensor_ops_ext.py", "mlir/dialects/sparse_tensor.py", ":SparseTensorOpsPyGen", ], From afca08a567e3803aeb0c32385e62c84bec58ef46 Mon Sep 17 00:00:00 2001 From: Jordan Rupprecht Date: Thu, 19 Jan 2023 08:43:24 -0800 Subject: [PATCH 21/61] [NFC][bazel] Move _tensor_ops_ext.py to the correct filegroup --- utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel index e4012808440318..5c41fc5fc7ed6d 100644 --- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel @@ -764,7 +764,6 @@ gentbl_filegroup( filegroup( name = "SparseTensorOpsPyFiles", srcs = [ - "mlir/dialects/_tensor_ops_ext.py", 
"mlir/dialects/sparse_tensor.py", ":SparseTensorOpsPyGen", ], @@ -806,6 +805,7 @@ gentbl_filegroup( filegroup( name = "TensorOpsPyFiles", srcs = [ + "mlir/dialects/_tensor_ops_ext.py", "mlir/dialects/tensor.py", ":TensorOpsPyGen", ], From a9c5a98f810d380919d25c3073c61725d73d3c41 Mon Sep 17 00:00:00 2001 From: Jonas Paulsson Date: Wed, 18 Jan 2023 13:29:37 -0600 Subject: [PATCH 22/61] [SystemZ] Improvement in tryRxSBG(). Only allow replacements of nodes that have a single user. This is better as simple instructions (e.g. XGRK) are one cycle faster, and it helps in cases where both inputs share a common node. Review: Ulrich Weigand --- .../Target/SystemZ/SystemZISelDAGToDAG.cpp | 11 +++++---- llvm/test/CodeGen/SystemZ/flt-rounds.ll | 23 ++++++++----------- llvm/test/CodeGen/SystemZ/rnsbg-01.ll | 5 ++-- .../SystemZ/store_nonbytesized_vecs.ll | 23 ++++++++++--------- 4 files changed, 32 insertions(+), 30 deletions(-) diff --git a/llvm/lib/Target/SystemZ/SystemZISelDAGToDAG.cpp b/llvm/lib/Target/SystemZ/SystemZISelDAGToDAG.cpp index 59d4639c01e760..250edf64cb6c56 100644 --- a/llvm/lib/Target/SystemZ/SystemZISelDAGToDAG.cpp +++ b/llvm/lib/Target/SystemZ/SystemZISelDAGToDAG.cpp @@ -1071,10 +1071,13 @@ bool SystemZDAGToDAGISel::tryRxSBG(SDNode *N, unsigned Opcode) { }; unsigned Count[] = { 0, 0 }; for (unsigned I = 0; I < 2; ++I) - while (expandRxSBG(RxSBG[I])) - // The widening or narrowing is expected to be free. - // Counting widening or narrowing as a saved operation will result in - // preferring an R*SBG over a simple shift/logical instruction. + while (RxSBG[I].Input->hasOneUse() && expandRxSBG(RxSBG[I])) + // In cases of multiple users it seems better to keep the simple + // instruction as they are one cycle faster, and it also helps in cases + // where both inputs share a common node. + // The widening or narrowing is expected to be free. 
Counting widening + // or narrowing as a saved operation will result in preferring an R*SBG + // over a simple shift/logical instruction. if (RxSBG[I].Input.getOpcode() != ISD::ANY_EXTEND && RxSBG[I].Input.getOpcode() != ISD::TRUNCATE) Count[I] += 1; diff --git a/llvm/test/CodeGen/SystemZ/flt-rounds.ll b/llvm/test/CodeGen/SystemZ/flt-rounds.ll index 550001049b781d..37a7806345e111 100644 --- a/llvm/test/CodeGen/SystemZ/flt-rounds.ll +++ b/llvm/test/CodeGen/SystemZ/flt-rounds.ll @@ -8,11 +8,10 @@ define dso_local signext i32 @test_flt_rounds() nounwind { ; CHECK-LABEL: test_flt_rounds: ; CHECK: # %bb.0: ; CHECK-NEXT: efpc %r0 -; CHECK-NEXT: lr %r1, %r0 -; CHECK-NEXT: nilf %r1, 3 -; CHECK-NEXT: rxsbg %r1, %r0, 63, 63, 63 -; CHECK-NEXT: xilf %r1, 1 -; CHECK-NEXT: llgfr %r2, %r1 +; CHECK-NEXT: nilf %r0, 3 +; CHECK-NEXT: rxsbg %r0, %r0, 33, 63, 63 +; CHECK-NEXT: xilf %r0, 1 +; CHECK-NEXT: llgfr %r2, %r0 ; CHECK-NEXT: br %r14 %1 = call i32 @llvm.get.rounding() ret i32 %1 @@ -25,18 +24,16 @@ define dso_local signext i32 @test_order(i32 noundef signext %0) nounwind { ; CHECK: # %bb.0: ; CHECK-NEXT: stmg %r13, %r15, 104(%r15) ; CHECK-NEXT: aghi %r15, -160 -; CHECK-NEXT: efpc %r0 -; CHECK-NEXT: lr %r13, %r0 +; CHECK-NEXT: efpc %r13 ; CHECK-NEXT: nilf %r13, 3 -; CHECK-NEXT: rxsbg %r13, %r0, 63, 63, 63 +; CHECK-NEXT: rxsbg %r13, %r13, 33, 63, 63 ; CHECK-NEXT: xilf %r13, 1 ; CHECK-NEXT: brasl %r14, fesetround@PLT ; CHECK-NEXT: efpc %r0 -; CHECK-NEXT: lr %r1, %r0 -; CHECK-NEXT: nilf %r1, 3 -; CHECK-NEXT: rxsbg %r1, %r0, 63, 63, 63 -; CHECK-NEXT: xilf %r1, 1 -; CHECK-NEXT: crje %r13, %r1, .LBB1_2 +; CHECK-NEXT: nilf %r0, 3 +; CHECK-NEXT: rxsbg %r0, %r0, 33, 63, 63 +; CHECK-NEXT: xilf %r0, 1 +; CHECK-NEXT: crje %r13, %r0, .LBB1_2 ; CHECK-NEXT: # %bb.1: ; CHECK-NEXT: lhi %r0, 1 ; CHECK-NEXT: strl %r0, changed diff --git a/llvm/test/CodeGen/SystemZ/rnsbg-01.ll b/llvm/test/CodeGen/SystemZ/rnsbg-01.ll index 428a05c1111501..0595364ffc071b 100644 --- 
a/llvm/test/CodeGen/SystemZ/rnsbg-01.ll +++ b/llvm/test/CodeGen/SystemZ/rnsbg-01.ll @@ -219,7 +219,8 @@ define i64 @f19(i64 %a, i64 %b, ptr %dest) { ; Test a combination involving an ASHR in which the sign bits don't matter. define i32 @f20(i32 %a, i32 %b, ptr %dest) { ; CHECK-LABEL: f20: -; CHECK: rnsbg %r2, %r3, 48, 62, 48 +; CHECK-NOT: lr +; CHECK: rnsbg %r2, %r3, 48, 62, 1 ; CHECK: br %r14 %ashrb = ashr i32 %b, 17 store i32 %ashrb, ptr %dest @@ -232,7 +233,7 @@ define i32 @f20(i32 %a, i32 %b, ptr %dest) { ; ...and again with i64. define i64 @f21(i64 %a, i64 %b, ptr %dest) { ; CHECK-LABEL: f21: -; CHECK: rnsbg %r2, %r3, 48, 62, 16 +; CHECK: rnsbg %r2, %r0, 48, 62, 1 ; CHECK: br %r14 %ashrb = ashr i64 %b, 49 store i64 %ashrb, ptr %dest diff --git a/llvm/test/CodeGen/SystemZ/store_nonbytesized_vecs.ll b/llvm/test/CodeGen/SystemZ/store_nonbytesized_vecs.ll index 26878570934147..1507f2c3581b2a 100644 --- a/llvm/test/CodeGen/SystemZ/store_nonbytesized_vecs.ll +++ b/llvm/test/CodeGen/SystemZ/store_nonbytesized_vecs.ll @@ -77,36 +77,37 @@ define void @fun2(<8 x i32> %src, ptr %p) ; CHECK-NEXT: vlgvf %r5, %v24, 0 ; CHECK-NEXT: vlgvf %r3, %v24, 1 ; CHECK-NEXT: srlk %r0, %r1, 8 -; CHECK-NEXT: sllg %r5, %r5, 33 ; CHECK-NEXT: sth %r0, 28(%r2) -; CHECK-NEXT: rosbg %r5, %r3, 31, 55, 2 ; CHECK-NEXT: vlgvf %r0, %v24, 2 +; CHECK-NEXT: sllg %r5, %r5, 33 ; CHECK-NEXT: sllg %r4, %r3, 58 +; CHECK-NEXT: risbgn %r0, %r0, 6, 164, 27 +; CHECK-NEXT: rosbg %r5, %r3, 31, 55, 2 ; CHECK-NEXT: vlgvf %r3, %v26, 2 ; CHECK-NEXT: stc %r1, 30(%r2) -; CHECK-NEXT: rosbg %r4, %r0, 6, 36, 27 +; CHECK-NEXT: ogr %r4, %r0 ; CHECK-NEXT: risbgn %r1, %r1, 33, 167, 0 -; CHECK-NEXT: rosbg %r1, %r3, 2, 32, 31 -; CHECK-NEXT: srlg %r1, %r1, 24 ; CHECK-NEXT: rosbg %r5, %r4, 56, 63, 8 +; CHECK-NEXT: risbgn %r3, %r3, 2, 160, 31 +; CHECK-NEXT: ogr %r1, %r3 ; CHECK-NEXT: vlgvf %r4, %v24, 3 +; CHECK-NEXT: srlg %r1, %r1, 24 +; CHECK-NEXT: rosbg %r0, %r4, 37, 63, 60 ; CHECK-NEXT: st %r1, 24(%r2) ; CHECK-NEXT: vlgvf %r1, 
%v26, 0 -; CHECK-NEXT: risbgn %r0, %r0, 6, 164, 27 -; CHECK-NEXT: rosbg %r0, %r4, 37, 63, 60 ; CHECK-NEXT: stg %r5, 0(%r2) +; CHECK-NEXT: risbgn %r1, %r1, 4, 162, 29 ; CHECK-NEXT: sllg %r5, %r4, 60 +; CHECK-NEXT: ogr %r5, %r1 ; CHECK-NEXT: sllg %r0, %r0, 8 -; CHECK-NEXT: rosbg %r5, %r1, 4, 34, 29 -; CHECK-NEXT: risbgn %r1, %r1, 4, 162, 29 ; CHECK-NEXT: rosbg %r0, %r5, 56, 63, 8 ; CHECK-NEXT: stg %r0, 8(%r2) ; CHECK-NEXT: vlgvf %r0, %v26, 1 ; CHECK-NEXT: sllg %r4, %r0, 62 +; CHECK-NEXT: ogr %r3, %r4 ; CHECK-NEXT: rosbg %r1, %r0, 35, 63, 62 ; CHECK-NEXT: sllg %r0, %r1, 8 -; CHECK-NEXT: rosbg %r4, %r3, 2, 32, 31 -; CHECK-NEXT: rosbg %r0, %r4, 56, 63, 8 +; CHECK-NEXT: rosbg %r0, %r3, 56, 63, 8 ; CHECK-NEXT: stg %r0, 16(%r2) ; CHECK-NEXT: br %r14 { From 04d4f4b3d4e4fd608a3bc3fe387006435f04b61d Mon Sep 17 00:00:00 2001 From: Mark de Wever Date: Thu, 5 May 2022 18:57:32 +0200 Subject: [PATCH 23/61] [libc++][format] Adds container adaptor formatters. Implements parts of - P2286R8 Formatting Ranges Depends on D140653 Reviewed By: ldionne, #libc Differential Revision: https://reviews.llvm.org/D141290 --- libcxx/docs/Status/FormatPaper.csv | 2 +- libcxx/include/CMakeLists.txt | 1 + libcxx/include/__format/container_adaptor.h | 70 ++ .../__format/range_default_formatter.h | 2 - libcxx/include/format | 1 + libcxx/include/module.modulemap.in | 1 + libcxx/include/queue | 4 + libcxx/include/stack | 2 + libcxx/test/libcxx/private_headers.verify.cpp | 1 + .../test/libcxx/transitive_includes/cxx03.csv | 2 + .../test/libcxx/transitive_includes/cxx11.csv | 2 + .../test/libcxx/transitive_includes/cxx14.csv | 2 + .../test/libcxx/transitive_includes/cxx17.csv | 2 + .../test/libcxx/transitive_includes/cxx20.csv | 2 + .../test/libcxx/transitive_includes/cxx2b.csv | 2 + .../format.functions.format.pass.cpp | 60 ++ .../format.functions.tests.h | 947 ++++++++++++++++++ .../format.functions.vformat.pass.cpp | 73 ++ .../container.adaptors.format/format.pass.cpp | 78 ++ 
.../container.adaptors.format/parse.pass.cpp | 83 ++ .../concept.formattable.compile.pass.cpp | 9 +- 21 files changed, 1337 insertions(+), 9 deletions(-) create mode 100644 libcxx/include/__format/container_adaptor.h create mode 100644 libcxx/test/std/containers/container.adaptors/container.adaptors.format/format.functions.format.pass.cpp create mode 100644 libcxx/test/std/containers/container.adaptors/container.adaptors.format/format.functions.tests.h create mode 100644 libcxx/test/std/containers/container.adaptors/container.adaptors.format/format.functions.vformat.pass.cpp create mode 100644 libcxx/test/std/containers/container.adaptors/container.adaptors.format/format.pass.cpp create mode 100644 libcxx/test/std/containers/container.adaptors/container.adaptors.format/parse.pass.cpp diff --git a/libcxx/docs/Status/FormatPaper.csv b/libcxx/docs/Status/FormatPaper.csv index 9fb7cd09dc4995..10463c44ef8fed 100644 --- a/libcxx/docs/Status/FormatPaper.csv +++ b/libcxx/docs/Status/FormatPaper.csv @@ -32,6 +32,6 @@ Section,Description,Dependencies,Assignee,Status,First released version `[format.string.std] `_,"std-format-spec ``type`` debug",,Mark de Wever,|Complete|,Clang 16 `[format.range] `_,"Formatting for ranges: sequences",,Mark de Wever,|Complete|,Clang 16 `[format.range] `_,"Formatting for ranges: associative",,Mark de Wever,, -`[format.range] `_,"Formatting for ranges: container adaptors",,Mark de Wever,, +`[format.range] `_,"Formatting for ranges: container adaptors",,Mark de Wever,|Complete|,Clang 16 `[format.range] `_,"Formatting for ranges: ``pair`` and ``tuple``",,Mark de Wever,|Complete|,Clang 16 `[format.range] `_,"Formatting for ranges: ``vector``",,Mark de Wever,, diff --git a/libcxx/include/CMakeLists.txt b/libcxx/include/CMakeLists.txt index 81d488ffc4b20e..2d202444894f3d 100644 --- a/libcxx/include/CMakeLists.txt +++ b/libcxx/include/CMakeLists.txt @@ -302,6 +302,7 @@ set(files __filesystem/u8path.h __format/buffer.h __format/concepts.h + 
__format/container_adaptor.h __format/enable_insertable.h __format/escaped_output_table.h __format/extended_grapheme_cluster_table.h diff --git a/libcxx/include/__format/container_adaptor.h b/libcxx/include/__format/container_adaptor.h new file mode 100644 index 00000000000000..62b698186ea135 --- /dev/null +++ b/libcxx/include/__format/container_adaptor.h @@ -0,0 +1,70 @@ +// -*- C++ -*- +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef _LIBCPP___FORMAT_CONTAINER_ADAPTOR_H +#define _LIBCPP___FORMAT_CONTAINER_ADAPTOR_H + +#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) +# pragma GCC system_header +#endif + +#include <__availability> +#include <__config> +#include <__format/concepts.h> +#include <__format/formatter.h> +#include <__format/range_default_formatter.h> +#include <queue> +#include <stack> + +_LIBCPP_BEGIN_NAMESPACE_STD + +#if _LIBCPP_STD_VER > 20 + +// [container.adaptors.format] only specifies the library should provide the +// formatter specializations, not which header should provide them. +// Since <format> includes a lot of headers, add these headers here instead of +// adding more dependencies like, locale, optional, string, tuple, etc. to the +// adaptor headers. To use the format functions users already include <format>. 
+ +template +struct _LIBCPP_TEMPLATE_VIS _LIBCPP_AVAILABILITY_FORMAT __formatter_container_adaptor { +private: + using __maybe_const_adaptor = __fmt_maybe_const<_Adaptor, _CharT>; + formatter __underlying_; + +public: + template + _LIBCPP_HIDE_FROM_ABI constexpr typename _ParseContext::iterator parse(_ParseContext& __ctx) { + return __underlying_.parse(__ctx); + } + + template + _LIBCPP_HIDE_FROM_ABI typename _FormatContext::iterator + format(__maybe_const_adaptor& __adaptor, _FormatContext& __ctx) const { + return __underlying_.format(__adaptor.__get_container(), __ctx); + } +}; + +template _Container> +struct _LIBCPP_TEMPLATE_VIS _LIBCPP_AVAILABILITY_FORMAT formatter, _CharT> + : public __formatter_container_adaptor, _CharT> {}; + +template +struct _LIBCPP_TEMPLATE_VIS _LIBCPP_AVAILABILITY_FORMAT formatter, _CharT> + : public __formatter_container_adaptor, _CharT> {}; + +template _Container> +struct _LIBCPP_TEMPLATE_VIS _LIBCPP_AVAILABILITY_FORMAT formatter, _CharT> + : public __formatter_container_adaptor, _CharT> {}; + +#endif //_LIBCPP_STD_VER > 20 + +_LIBCPP_END_NAMESPACE_STD + +#endif // _LIBCPP___FORMAT_CONTAINER_ADAPTOR_H diff --git a/libcxx/include/__format/range_default_formatter.h b/libcxx/include/__format/range_default_formatter.h index ee1dc52e37526e..652c021277664a 100644 --- a/libcxx/include/__format/range_default_formatter.h +++ b/libcxx/include/__format/range_default_formatter.h @@ -146,8 +146,6 @@ struct _LIBCPP_TEMPLATE_VIS _LIBCPP_AVAILABILITY_FORMAT __range_default_formatte __range_default_formatter() = delete; // TODO FMT Implement }; -// Dispatcher to select the specialization based on the type of the range. 
- template requires(format_kind<_Rp> != range_format::disabled && formattable, _CharT>) struct _LIBCPP_TEMPLATE_VIS _LIBCPP_AVAILABILITY_FORMAT formatter<_Rp, _CharT> diff --git a/libcxx/include/format b/libcxx/include/format index aaac95b39c4c9b..2c583c1065ba65 100644 --- a/libcxx/include/format +++ b/libcxx/include/format @@ -178,6 +178,7 @@ namespace std { #include <__config> #include <__format/buffer.h> #include <__format/concepts.h> +#include <__format/container_adaptor.h> #include <__format/enable_insertable.h> #include <__format/format_arg.h> #include <__format/format_arg_store.h> diff --git a/libcxx/include/module.modulemap.in b/libcxx/include/module.modulemap.in index 81f503e29c6a0c..aa4da4dadc42ff 100644 --- a/libcxx/include/module.modulemap.in +++ b/libcxx/include/module.modulemap.in @@ -839,6 +839,7 @@ module std [system] { module __format { module buffer { private header "__format/buffer.h" } module concepts { private header "__format/concepts.h" } + module container_adaptor { private header "__format/container_adaptor.h" } module enable_insertable { private header "__format/enable_insertable.h" } module escaped_output_table { private header "__format/escaped_output_table.h" } module extended_grapheme_cluster_table { private header "__format/extended_grapheme_cluster_table.h" } diff --git a/libcxx/include/queue b/libcxx/include/queue index c58da5ec6ee3c8..6c1b892efadc3d 100644 --- a/libcxx/include/queue +++ b/libcxx/include/queue @@ -382,6 +382,8 @@ public: swap(c, __q.c); } + _LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI const _Container& __get_container() const { return c; } + template friend _LIBCPP_INLINE_VISIBILITY @@ -633,6 +635,8 @@ public: void swap(priority_queue& __q) _NOEXCEPT_(__is_nothrow_swappable::value && __is_nothrow_swappable::value); + + _LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI const _Container& __get_container() const { return c; } }; #if _LIBCPP_STD_VER >= 17 diff --git a/libcxx/include/stack b/libcxx/include/stack index 
2abbcd025c4aa9..d653d1bc7e49c2 100644 --- a/libcxx/include/stack +++ b/libcxx/include/stack @@ -255,6 +255,8 @@ public: swap(c, __s.c); } + _LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI const _Container& __get_container() const { return c; } + template friend bool diff --git a/libcxx/test/libcxx/private_headers.verify.cpp b/libcxx/test/libcxx/private_headers.verify.cpp index 27b1c90ff5bfa5..e61eeddd2b4661 100644 --- a/libcxx/test/libcxx/private_headers.verify.cpp +++ b/libcxx/test/libcxx/private_headers.verify.cpp @@ -334,6 +334,7 @@ END-SCRIPT #include <__filesystem/u8path.h> // expected-error@*:* {{use of private header from outside its module: '__filesystem/u8path.h'}} #include <__format/buffer.h> // expected-error@*:* {{use of private header from outside its module: '__format/buffer.h'}} #include <__format/concepts.h> // expected-error@*:* {{use of private header from outside its module: '__format/concepts.h'}} +#include <__format/container_adaptor.h> // expected-error@*:* {{use of private header from outside its module: '__format/container_adaptor.h'}} #include <__format/enable_insertable.h> // expected-error@*:* {{use of private header from outside its module: '__format/enable_insertable.h'}} #include <__format/escaped_output_table.h> // expected-error@*:* {{use of private header from outside its module: '__format/escaped_output_table.h'}} #include <__format/extended_grapheme_cluster_table.h> // expected-error@*:* {{use of private header from outside its module: '__format/extended_grapheme_cluster_table.h'}} diff --git a/libcxx/test/libcxx/transitive_includes/cxx03.csv b/libcxx/test/libcxx/transitive_includes/cxx03.csv index c245d11b5a4e06..58fe49cf2844e6 100644 --- a/libcxx/test/libcxx/transitive_includes/cxx03.csv +++ b/libcxx/test/libcxx/transitive_includes/cxx03.csv @@ -354,6 +354,8 @@ format initializer_list format limits format locale format optional +format queue +format stack format stdexcept format string format string_view diff --git 
a/libcxx/test/libcxx/transitive_includes/cxx11.csv b/libcxx/test/libcxx/transitive_includes/cxx11.csv index 74bf87284072bc..2f9ef7da30ca67 100644 --- a/libcxx/test/libcxx/transitive_includes/cxx11.csv +++ b/libcxx/test/libcxx/transitive_includes/cxx11.csv @@ -354,6 +354,8 @@ format initializer_list format limits format locale format optional +format queue +format stack format stdexcept format string format string_view diff --git a/libcxx/test/libcxx/transitive_includes/cxx14.csv b/libcxx/test/libcxx/transitive_includes/cxx14.csv index 7ee16527026eba..0d10e239d8f2d0 100644 --- a/libcxx/test/libcxx/transitive_includes/cxx14.csv +++ b/libcxx/test/libcxx/transitive_includes/cxx14.csv @@ -356,6 +356,8 @@ format initializer_list format limits format locale format optional +format queue +format stack format stdexcept format string format string_view diff --git a/libcxx/test/libcxx/transitive_includes/cxx17.csv b/libcxx/test/libcxx/transitive_includes/cxx17.csv index 7ee16527026eba..0d10e239d8f2d0 100644 --- a/libcxx/test/libcxx/transitive_includes/cxx17.csv +++ b/libcxx/test/libcxx/transitive_includes/cxx17.csv @@ -356,6 +356,8 @@ format initializer_list format limits format locale format optional +format queue +format stack format stdexcept format string format string_view diff --git a/libcxx/test/libcxx/transitive_includes/cxx20.csv b/libcxx/test/libcxx/transitive_includes/cxx20.csv index f48e46896b7ea3..8f47b6120f7a80 100644 --- a/libcxx/test/libcxx/transitive_includes/cxx20.csv +++ b/libcxx/test/libcxx/transitive_includes/cxx20.csv @@ -365,6 +365,8 @@ format initializer_list format limits format locale format optional +format queue +format stack format stdexcept format string format string_view diff --git a/libcxx/test/libcxx/transitive_includes/cxx2b.csv b/libcxx/test/libcxx/transitive_includes/cxx2b.csv index 7401de433dbcac..186d0b4f5d1213 100644 --- a/libcxx/test/libcxx/transitive_includes/cxx2b.csv +++ b/libcxx/test/libcxx/transitive_includes/cxx2b.csv @@ -273,6 
+273,8 @@ format initializer_list format limits format locale format optional +format queue +format stack format stdexcept format string format string_view diff --git a/libcxx/test/std/containers/container.adaptors/container.adaptors.format/format.functions.format.pass.cpp b/libcxx/test/std/containers/container.adaptors/container.adaptors.format/format.functions.format.pass.cpp new file mode 100644 index 00000000000000..7806576baae5e3 --- /dev/null +++ b/libcxx/test/std/containers/container.adaptors/container.adaptors.format/format.functions.format.pass.cpp @@ -0,0 +1,60 @@ +//===----------------------------------------------------------------------===// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// UNSUPPORTED: c++03, c++11, c++14, c++17, c++20 +// UNSUPPORTED: libcpp-has-no-incomplete-format + +// TODO FMT Fix this test using GCC, it currently times out. +// UNSUPPORTED: gcc-12 + +// This test requires the dylib support introduced in D92214. +// XFAIL: use_system_cxx_lib && target={{.+}}-apple-macosx10.{{.+}} +// XFAIL: use_system_cxx_lib && target={{.+}}-apple-macosx11.{{.+}} + +// [container.adaptors.format] +// For each of queue, priority_queue, and stack, the library provides the +// following formatter specialization where adaptor-type is the name of the +// template: +// +// template Container, class... U> +// struct formatter, charT> + +// template +// string format(format_string fmt, Args&&... args); +// template +// wstring format(wformat_string fmt, Args&&... args); + +#include +#include + +#include "format.functions.tests.h" +#include "test_format_string.h" +#include "test_macros.h" +#include "assert_macros.h" + +auto test = []( + std::basic_string_view expected, test_format_string fmt, Args&&... 
args) { + std::basic_string out = std::format(fmt, std::forward(args)...); + TEST_REQUIRE( + out == expected, + test_concat_message("\nFormat string ", fmt, "\nExpected output ", expected, "\nActual output ", out, '\n')); +}; + +auto test_exception = [](std::string_view, std::basic_string_view, Args&&...) { + // After P2216 most exceptions thrown by std::format become ill-formed. + // Therefore this test does nothing. +}; + +int main(int, char**) { + format_tests(test, test_exception); + +#ifndef TEST_HAS_NO_WIDE_CHARACTERS + format_tests(test, test_exception); +#endif + + return 0; +} diff --git a/libcxx/test/std/containers/container.adaptors/container.adaptors.format/format.functions.tests.h b/libcxx/test/std/containers/container.adaptors/container.adaptors.format/format.functions.tests.h new file mode 100644 index 00000000000000..cae566d9aecf57 --- /dev/null +++ b/libcxx/test/std/containers/container.adaptors/container.adaptors.format/format.functions.tests.h @@ -0,0 +1,947 @@ +//===----------------------------------------------------------------------===// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TEST_STD_CONTAINERS_CONTAINER_ADAPTORS_CONTAINER_ADAPTORS_FORMAT_FORMAT_FUNCTIONS_TESTS_H +#define TEST_STD_CONTAINERS_CONTAINER_ADAPTORS_CONTAINER_ADAPTORS_FORMAT_FORMAT_FUNCTIONS_TESTS_H + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "format.functions.common.h" +#include "make_string.h" +#include "platform_support.h" // locale name macros +#include "test_macros.h" + +// +// Char +// + +template +void test_char_default(TestFunction check, ExceptionTest check_exception, auto&& input) { + // Note when no range-underlying-spec is present the char is escaped, + check(SV("['H', 'e', 'l', 'l', 'o']"), SV("{}"), input); + + // when one is present there is no escaping, + check(SV("[H, e, l, l, o]"), SV("{::}"), input); + // unless forced by the type specifier. + check(SV("['H', 'e', 'l', 'l', 'o']"), SV("{::?}"), input); + + // ***** underlying has no format-spec + + // *** align-fill & width *** + check(SV("['H', 'e', 'l', 'l', 'o'] "), SV("{:30}"), input); + check(SV("['H', 'e', 'l', 'l', 'o']*****"), SV("{:*<30}"), input); + check(SV("__['H', 'e', 'l', 'l', 'o']___"), SV("{:_^30}"), input); + check(SV("#####['H', 'e', 'l', 'l', 'o']"), SV("{:#>30}"), input); + + check(SV("['H', 'e', 'l', 'l', 'o'] "), SV("{:{}}"), input, 30); + check(SV("['H', 'e', 'l', 'l', 'o']*****"), SV("{:*<{}}"), input, 30); + check(SV("__['H', 'e', 'l', 'l', 'o']___"), SV("{:_^{}}"), input, 30); + check(SV("#####['H', 'e', 'l', 'l', 'o']"), SV("{:#>{}}"), input, 30); + + check_exception("The format-spec range-fill field contains an invalid character", SV("{:}<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{:{<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{::<}"), input); + + // *** sign 
*** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:-}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{:+}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{: }"), input); + + // *** alternate form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:#}"), input); + + // *** zero-padding *** + check_exception("A format-spec width field shouldn't have a leading zero", SV("{:0}"), input); + + // *** precision *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:.}"), input); + + // *** locale-specific form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:L}"), input); + + // *** n + check(SV("__'H', 'e', 'l', 'l', 'o'___"), SV("{:_^28n}"), input); + + // *** type *** + check_exception("The range-format-spec type m requires two elements for a pair or tuple", SV("{:m}"), input); + for (std::basic_string_view fmt : fmt_invalid_types("s")) + check_exception("The format-spec should consume the input or end with a '}'", fmt, input); + + // ***** Only underlying has a format-spec + check(SV("[H , e , l , l , o ]"), SV("{::4}"), input); + check(SV("[H***, e***, l***, l***, o***]"), SV("{::*<4}"), input); + check(SV("[_H__, _e__, _l__, _l__, _o__]"), SV("{::_^4}"), input); + check(SV("[:::H, :::e, :::l, :::l, :::o]"), SV("{:::>4}"), input); + + check(SV("[H , e , l , l , o ]"), SV("{::{}}"), input, 4); + check(SV("[H***, e***, l***, l***, o***]"), SV("{::*<{}}"), input, 4); + check(SV("[_H__, _e__, _l__, _l__, _o__]"), SV("{::_^{}}"), input, 4); + check(SV("[:::H, :::e, :::l, :::l, :::o]"), SV("{:::>{}}"), input, 4); + + check_exception("The format-spec fill field contains an invalid character", SV("{::}<}"), input); + check_exception("The format-spec fill field contains an invalid character", SV("{::{<}"), input); + + // *** sign 
*** + check_exception("A sign field isn't allowed in this format-spec", SV("{::-}"), input); + check_exception("A sign field isn't allowed in this format-spec", SV("{::+}"), input); + check_exception("A sign field isn't allowed in this format-spec", SV("{:: }"), input); + + check(SV("[72, 101, 108, 108, 111]"), SV("{::-d}"), input); + check(SV("[+72, +101, +108, +108, +111]"), SV("{::+d}"), input); + check(SV("[ 72, 101, 108, 108, 111]"), SV("{:: d}"), input); + + // *** alternate form *** + check_exception("An alternate form field isn't allowed in this format-spec", SV("{::#}"), input); + + check(SV("[0x48, 0x65, 0x6c, 0x6c, 0x6f]"), SV("{::#x}"), input); + + // *** zero-padding *** + check_exception("A zero-padding field isn't allowed in this format-spec", SV("{::05}"), input); + + check(SV("[00110, 00145, 00154, 00154, 00157]"), SV("{::05o}"), input); + + // *** precision *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{::.}"), input); + + // *** locale-specific form *** + check(SV("[H, e, l, l, o]"), SV("{::L}"), input); + + // *** type *** + for (std::basic_string_view fmt : fmt_invalid_nested_types("bBcdoxX?")) + check_exception("The format-spec type has a type not supported for a char argument", fmt, input); + + // ***** Both have a format-spec + check(SV("^^[:H, :e, :l, :l, :o]^^^"), SV("{:^^25::>2}"), input); + check(SV("^^[:H, :e, :l, :l, :o]^^^"), SV("{:^^{}::>2}"), input, 25); + check(SV("^^[:H, :e, :l, :l, :o]^^^"), SV("{:^^{}::>{}}"), input, 25, 2); + + check_exception("Argument index out of bounds", SV("{:^^{}::>2}"), input); + check_exception("Argument index out of bounds", SV("{:^^{}::>{}}"), input, 25); +} + +template +void test_char_string(TestFunction check, ExceptionTest check_exception, auto&& input) { + check(SV("Hello"), SV("{:s}"), input); + + // ***** underlying has no format-spec + + // *** align-fill & width *** + check(SV("Hello "), SV("{:8s}"), input); + check(SV("Hello***"), SV("{:*<8s}"), 
input); + check(SV("_Hello__"), SV("{:_^8s}"), input); + check(SV("###Hello"), SV("{:#>8s}"), input); + + check(SV("Hello "), SV("{:{}s}"), input, 8); + check(SV("Hello***"), SV("{:*<{}s}"), input, 8); + check(SV("_Hello__"), SV("{:_^{}s}"), input, 8); + check(SV("###Hello"), SV("{:#>{}s}"), input, 8); + + check_exception("The format-spec range-fill field contains an invalid character", SV("{:} fmt : fmt_invalid_nested_types("bBcdoxX?")) + check_exception("The format-spec type has a type not supported for a char argument", fmt, input); + + // ***** Both have a format-spec + check_exception("Type s and an underlying format specification can't be used together", SV("{:5s:5}"), input); +} + +template +void test_char_escaped_string(TestFunction check, ExceptionTest check_exception, auto&& input) { + check(SV(R"("Hello")"), SV("{:?s}"), input); + + // ***** underlying has no format-spec + + // *** align-fill & width *** + check(SV(R"("Hello" )"), SV("{:10?s}"), input); + check(SV(R"("Hello"***)"), SV("{:*<10?s}"), input); + check(SV(R"(_"Hello"__)"), SV("{:_^10?s}"), input); + check(SV(R"(###"Hello")"), SV("{:#>10?s}"), input); + + check(SV(R"("Hello" )"), SV("{:{}?s}"), input, 10); + check(SV(R"("Hello"***)"), SV("{:*<{}?s}"), input, 10); + check(SV(R"(_"Hello"__)"), SV("{:_^{}?s}"), input, 10); + check(SV(R"(###"Hello")"), SV("{:#>{}?s}"), input, 10); + + check_exception("The format-spec range-fill field contains an invalid character", SV("{:} +void test_char(TestFunction check, ExceptionTest check_exception) { + // These values are in numeric order when using ASCII, which is used by the priority_queue. 
+ std::array input{CharT('H'), CharT('e'), CharT('l'), CharT('l'), CharT('o')}; + test_char_default(check, check_exception, std::queue{input.begin(), input.end()}); + test_char_default(check, check_exception, std::priority_queue{input.begin(), input.end(), std::greater{}}); + test_char_default(check, check_exception, std::stack{input.begin(), input.end()}); + + test_char_string(check, check_exception, std::queue{input.begin(), input.end()}); + test_char_string(check, check_exception, std::priority_queue{input.begin(), input.end(), std::greater{}}); + test_char_string(check, check_exception, std::stack{input.begin(), input.end()}); + + test_char_escaped_string(check, check_exception, std::queue{input.begin(), input.end()}); + test_char_escaped_string( + check, check_exception, std::priority_queue{input.begin(), input.end(), std::greater{}}); + test_char_escaped_string(check, check_exception, std::stack{input.begin(), input.end()}); +} + +// +// char -> wchar_t +// + +#ifndef _LIBCPP_HAS_NO_WIDE_CHARACTERS +template +void test_char_to_wchar(TestFunction check, ExceptionTest check_exception) { + std::array input{'H', 'e', 'l', 'l', 'o'}; + test_char_default(check, check_exception, std::queue{input.begin(), input.end()}); + test_char_default(check, check_exception, std::priority_queue{input.begin(), input.end(), std::greater{}}); + test_char_default(check, check_exception, std::stack{input.begin(), input.end()}); + + // The types s and ?s may only be used when using range_formatter + // where the types T and charT are the same. This means this can't be used for + // range_formatter even when formatter has a + // debug-enabled specialization. 
+ + using CharT = wchar_t; + check_exception("The range-format-spec type s requires formatting a character type", + SV("{:s}"), + std::queue{input.begin(), input.end()}); + check_exception("The range-format-spec type s requires formatting a character type", + SV("{:s}"), + std::priority_queue{input.begin(), input.end()}); + check_exception("The range-format-spec type s requires formatting a character type", + SV("{:s}"), + std::stack{input.begin(), input.end()}); + check_exception("The range-format-spec type ?s requires formatting a character type", + SV("{:?s}"), + std::queue{input.begin(), input.end()}); + check_exception("The range-format-spec type ?s requires formatting a character type", + SV("{:?s}"), + std::priority_queue{input.begin(), input.end()}); + check_exception("The range-format-spec type ?s requires formatting a character type", + SV("{:?s}"), + std::stack{input.begin(), input.end()}); +} +#endif // _LIBCPP_HAS_NO_WIDE_CHARACTERS + +// +// Bool +// + +template +void test_bool(TestFunction check, ExceptionTest check_exception, auto&& input) { + check(SV("[true, true, false]"), SV("{}"), input); + + // ***** underlying has no format-spec + + // *** align-fill & width *** + check(SV("[true, true, false] "), SV("{:24}"), input); + check(SV("[true, true, false]*****"), SV("{:*<24}"), input); + check(SV("__[true, true, false]___"), SV("{:_^24}"), input); + check(SV("#####[true, true, false]"), SV("{:#>24}"), input); + + check(SV("[true, true, false] "), SV("{:{}}"), input, 24); + check(SV("[true, true, false]*****"), SV("{:*<{}}"), input, 24); + check(SV("__[true, true, false]___"), SV("{:_^{}}"), input, 24); + check(SV("#####[true, true, false]"), SV("{:#>{}}"), input, 24); + + check_exception("The format-spec range-fill field contains an invalid character", SV("{:}<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{:{<}"), input); + check_exception("The format-spec range-fill field contains an invalid 
character", SV("{::<}"), input); + + // *** sign *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:-}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{:+}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{: }"), input); + + // *** alternate form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:#}"), input); + + // *** zero-padding *** + check_exception("A format-spec width field shouldn't have a leading zero", SV("{:0}"), input); + + // *** precision *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:.}"), input); + + // *** locale-specific form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:L}"), input); + + // *** n + check(SV("__true, true, false___"), SV("{:_^22n}"), input); + + // *** type *** + check_exception("The range-format-spec type m requires two elements for a pair or tuple", SV("{:m}"), input); + check_exception("The range-format-spec type s requires formatting a character type", SV("{:s}"), input); + check_exception("The range-format-spec type ?s requires formatting a character type", SV("{:?s}"), input); + for (std::basic_string_view fmt : fmt_invalid_types("s")) + check_exception("The format-spec should consume the input or end with a '}'", fmt, input); + + // ***** Only underlying has a format-spec + check(SV("[true , true , false ]"), SV("{::7}"), input); + check(SV("[true***, true***, false**]"), SV("{::*<7}"), input); + check(SV("[_true__, _true__, _false_]"), SV("{::_^7}"), input); + check(SV("[:::true, :::true, ::false]"), SV("{:::>7}"), input); + + check(SV("[true , true , false ]"), SV("{::{}}"), input, 7); + check(SV("[true***, true***, false**]"), SV("{::*<{}}"), input, 7); + check(SV("[_true__, _true__, _false_]"), SV("{::_^{}}"), input, 7); + check(SV("[:::true, :::true, ::false]"), 
SV("{:::>{}}"), input, 7); + + check_exception("The format-spec fill field contains an invalid character", SV("{::}<}"), input); + check_exception("The format-spec fill field contains an invalid character", SV("{::{<}"), input); + + // *** sign *** + check_exception("A sign field isn't allowed in this format-spec", SV("{::-}"), input); + check_exception("A sign field isn't allowed in this format-spec", SV("{::+}"), input); + check_exception("A sign field isn't allowed in this format-spec", SV("{:: }"), input); + + check(SV("[1, 1, 0]"), SV("{::-d}"), input); + check(SV("[+1, +1, +0]"), SV("{::+d}"), input); + check(SV("[ 1, 1, 0]"), SV("{:: d}"), input); + + // *** alternate form *** + check_exception("An alternate form field isn't allowed in this format-spec", SV("{::#}"), input); + + check(SV("[0x1, 0x1, 0x0]"), SV("{::#x}"), input); + + // *** zero-padding *** + check_exception("A zero-padding field isn't allowed in this format-spec", SV("{::05}"), input); + + check(SV("[00001, 00001, 00000]"), SV("{::05o}"), input); + + // *** precision *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{::.}"), input); + + // *** locale-specific form *** + check(SV("[true, true, false]"), SV("{::L}"), input); + + // *** type *** + for (std::basic_string_view fmt : fmt_invalid_nested_types("bBdosxX")) + check_exception("The format-spec type has a type not supported for a bool argument", fmt, input); + + // ***** Both have a format-spec + check(SV("^^[:::true, :::true, ::false]^^^"), SV("{:^^32::>7}"), input); + check(SV("^^[:::true, :::true, ::false]^^^"), SV("{:^^{}::>7}"), input, 32); + check(SV("^^[:::true, :::true, ::false]^^^"), SV("{:^^{}::>{}}"), input, 32, 7); + + check_exception("Argument index out of bounds", SV("{:^^{}::>5}"), input); + check_exception("Argument index out of bounds", SV("{:^^{}::>{}}"), input, 32); +} + +template +void test_bool(TestFunction check, ExceptionTest check_exception) { + std::array input{true, true, 
false}; + test_bool(check, check_exception, std::queue{input.begin(), input.end()}); + // TODO FMT Use std::vector after it has been implemented. + test_bool(check, check_exception, std::priority_queue>{input.begin(), input.end()}); + test_bool(check, check_exception, std::stack{input.begin(), input.end()}); +} + +// +// Integral +// + +template +void test_int(TestFunction check, ExceptionTest check_exception, auto&& input) { + check(SV("[-42, 1, 2, 42]"), SV("{}"), input); + + // ***** underlying has no format-spec + + // *** align-fill & width *** + check(SV("[-42, 1, 2, 42] "), SV("{:20}"), input); + check(SV("[-42, 1, 2, 42]*****"), SV("{:*<20}"), input); + check(SV("__[-42, 1, 2, 42]___"), SV("{:_^20}"), input); + check(SV("#####[-42, 1, 2, 42]"), SV("{:#>20}"), input); + + check(SV("[-42, 1, 2, 42] "), SV("{:{}}"), input, 20); + check(SV("[-42, 1, 2, 42]*****"), SV("{:*<{}}"), input, 20); + check(SV("__[-42, 1, 2, 42]___"), SV("{:_^{}}"), input, 20); + check(SV("#####[-42, 1, 2, 42]"), SV("{:#>{}}"), input, 20); + + check_exception("The format-spec range-fill field contains an invalid character", SV("{:}<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{:{<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{::<}"), input); + + // *** sign *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:-}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{:+}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{: }"), input); + + // *** alternate form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:#}"), input); + + // *** zero-padding *** + check_exception("A format-spec width field shouldn't have a leading zero", SV("{:0}"), input); + + // *** precision *** + check_exception("The format-spec should 
consume the input or end with a '}'", SV("{:.}"), input); + + // *** locale-specific form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:L}"), input); + + // *** n + check(SV("__-42, 1, 2, 42___"), SV("{:_^18n}"), input); + + // *** type *** + check_exception("The range-format-spec type m requires two elements for a pair or tuple", SV("{:m}"), input); + check_exception("The range-format-spec type s requires formatting a character type", SV("{:s}"), input); + check_exception("The range-format-spec type ?s requires formatting a character type", SV("{:?s}"), input); + for (std::basic_string_view fmt : fmt_invalid_types("s")) + check_exception("The format-spec should consume the input or end with a '}'", fmt, input); + + // ***** Only underlying has a format-spec + check(SV("[ -42, 1, 2, 42]"), SV("{::5}"), input); + check(SV("[-42**, 1****, 2****, 42***]"), SV("{::*<5}"), input); + check(SV("[_-42_, __1__, __2__, _42__]"), SV("{::_^5}"), input); + check(SV("[::-42, ::::1, ::::2, :::42]"), SV("{:::>5}"), input); + + check(SV("[ -42, 1, 2, 42]"), SV("{::{}}"), input, 5); + check(SV("[-42**, 1****, 2****, 42***]"), SV("{::*<{}}"), input, 5); + check(SV("[_-42_, __1__, __2__, _42__]"), SV("{::_^{}}"), input, 5); + check(SV("[::-42, ::::1, ::::2, :::42]"), SV("{:::>{}}"), input, 5); + + check_exception("The format-spec fill field contains an invalid character", SV("{::}<}"), input); + check_exception("The format-spec fill field contains an invalid character", SV("{::{<}"), input); + + // *** sign *** + check(SV("[-42, 1, 2, 42]"), SV("{::-}"), input); + check(SV("[-42, +1, +2, +42]"), SV("{::+}"), input); + check(SV("[-42, 1, 2, 42]"), SV("{:: }"), input); + + // *** alternate form *** + check(SV("[-0x2a, 0x1, 0x2, 0x2a]"), SV("{::#x}"), input); + + // *** zero-padding *** + check(SV("[-0042, 00001, 00002, 00042]"), SV("{::05}"), input); + check(SV("[-002a, 00001, 00002, 0002a]"), SV("{::05x}"), input); + check(SV("[-0x2a, 0x001, 
0x002, 0x02a]"), SV("{::#05x}"), input); + + // *** precision *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{::.}"), input); + + // *** locale-specific form *** + check(SV("[-42, 1, 2, 42]"), SV("{::L}"), input); // does nothing in this test, but is accepted. + + // *** type *** + for (std::basic_string_view fmt : fmt_invalid_nested_types("bBcdoxX")) + check_exception("The format-spec type has a type not supported for an integer argument", fmt, input); + + // ***** Both have a format-spec + check(SV("^^[::-42, ::::1, ::::2, :::42]^^^"), SV("{:^^33::>5}"), input); + check(SV("^^[::-42, ::::1, ::::2, :::42]^^^"), SV("{:^^{}::>5}"), input, 33); + check(SV("^^[::-42, ::::1, ::::2, :::42]^^^"), SV("{:^^{}::>{}}"), input, 33, 5); + + check_exception("Argument index out of bounds", SV("{:^^{}::>5}"), input); + check_exception("Argument index out of bounds", SV("{:^^{}::>{}}"), input, 33); +} + +template +void test_int(TestFunction check, ExceptionTest check_exception) { + std::array input{-42, 1, 2, 42}; + test_int(check, check_exception, std::queue{input.begin(), input.end()}); + test_int(check, check_exception, std::priority_queue{input.begin(), input.end(), std::greater{}}); + test_int(check, check_exception, std::stack{input.begin(), input.end()}); +} + +// +// Floating point +// + +template +void test_floating_point(TestFunction check, ExceptionTest check_exception, auto&& input) { + check(SV("[-42.5, 0, 1.25, 42.5]"), SV("{}"), input); + + // ***** underlying has no format-spec + + // *** align-fill & width *** + check(SV("[-42.5, 0, 1.25, 42.5] "), SV("{:27}"), input); + check(SV("[-42.5, 0, 1.25, 42.5]*****"), SV("{:*<27}"), input); + check(SV("__[-42.5, 0, 1.25, 42.5]___"), SV("{:_^27}"), input); + check(SV("#####[-42.5, 0, 1.25, 42.5]"), SV("{:#>27}"), input); + + check(SV("[-42.5, 0, 1.25, 42.5] "), SV("{:{}}"), input, 27); + check(SV("[-42.5, 0, 1.25, 42.5]*****"), SV("{:*<{}}"), input, 27); + check(SV("__[-42.5, 0, 
1.25, 42.5]___"), SV("{:_^{}}"), input, 27); + check(SV("#####[-42.5, 0, 1.25, 42.5]"), SV("{:#>{}}"), input, 27); + + check_exception("The format-spec range-fill field contains an invalid character", SV("{:}<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{:{<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{::<}"), input); + + // *** sign *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:-}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{:+}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{: }"), input); + + // *** alternate form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:#}"), input); + + // *** zero-padding *** + check_exception("A format-spec width field shouldn't have a leading zero", SV("{:0}"), input); + + // *** precision *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:.}"), input); + + // *** locale-specific form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:L}"), input); + + // *** n + check(SV("__-42.5, 0, 1.25, 42.5___"), SV("{:_^25n}"), input); + + // *** type *** + check_exception("The range-format-spec type m requires two elements for a pair or tuple", SV("{:m}"), input); + check_exception("The range-format-spec type s requires formatting a character type", SV("{:s}"), input); + check_exception("The range-format-spec type ?s requires formatting a character type", SV("{:?s}"), input); + for (std::basic_string_view fmt : fmt_invalid_types("s")) + check_exception("The format-spec should consume the input or end with a '}'", fmt, input); + + // ***** Only underlying has a format-spec + check(SV("[-42.5, 0, 1.25, 42.5]"), SV("{::5}"), input); + check(SV("[-42.5, 0****, 1.25*, 42.5*]"), 
SV("{::*<5}"), input); + check(SV("[-42.5, __0__, 1.25_, 42.5_]"), SV("{::_^5}"), input); + check(SV("[-42.5, ::::0, :1.25, :42.5]"), SV("{:::>5}"), input); + + check(SV("[-42.5, 0, 1.25, 42.5]"), SV("{::{}}"), input, 5); + check(SV("[-42.5, 0****, 1.25*, 42.5*]"), SV("{::*<{}}"), input, 5); + check(SV("[-42.5, __0__, 1.25_, 42.5_]"), SV("{::_^{}}"), input, 5); + check(SV("[-42.5, ::::0, :1.25, :42.5]"), SV("{:::>{}}"), input, 5); + + check_exception("The format-spec fill field contains an invalid character", SV("{::}<}"), input); + check_exception("The format-spec fill field contains an invalid character", SV("{::{<}"), input); + + // *** sign *** + check(SV("[-42.5, 0, 1.25, 42.5]"), SV("{::-}"), input); + check(SV("[-42.5, +0, +1.25, +42.5]"), SV("{::+}"), input); + check(SV("[-42.5, 0, 1.25, 42.5]"), SV("{:: }"), input); + + // *** alternate form *** + check(SV("[-42.5, 0., 1.25, 42.5]"), SV("{::#}"), input); + + // *** zero-padding *** + check(SV("[-42.5, 00000, 01.25, 042.5]"), SV("{::05}"), input); + check(SV("[-42.5, 0000., 01.25, 042.5]"), SV("{::#05}"), input); + + // *** precision *** + check(SV("[-42, 0, 1.2, 42]"), SV("{::.2}"), input); + check(SV("[-42.500, 0.000, 1.250, 42.500]"), SV("{::.3f}"), input); + + check(SV("[-42, 0, 1.2, 42]"), SV("{::.{}}"), input, 2); + check(SV("[-42.500, 0.000, 1.250, 42.500]"), SV("{::.{}f}"), input, 3); + + check_exception("The format-spec precision field doesn't contain a value or arg-id", SV("{::.}"), input); + + // *** locale-specific form *** + check(SV("[-42.5, 0, 1.25, 42.5]"), SV("{::L}"), input); // does not require locales present +#ifndef TEST_HAS_NO_LOCALIZATION + std::locale::global(std::locale(LOCALE_fr_FR_UTF_8)); + check(SV("[-42,5, 0, 1,25, 42,5]"), SV("{::L}"), input); + + std::locale::global(std::locale(LOCALE_en_US_UTF_8)); + check(SV("[-42.5, 0, 1.25, 42.5]"), SV("{::L}"), input); + + std::locale::global(std::locale::classic()); +#endif // TEST_HAS_NO_LOCALIZATION + + // *** type *** + for 
(std::basic_string_view fmt : fmt_invalid_nested_types("aAeEfFgG")) + check_exception("The format-spec type has a type not supported for a floating-point argument", fmt, input); + + // ***** Both have a format-spec + check(SV("^^[-42.5, ::::0, :1.25, :42.5]^^^"), SV("{:^^33::>5}"), input); + check(SV("^^[-42.5, ::::0, :1.25, :42.5]^^^"), SV("{:^^{}::>5}"), input, 33); + check(SV("^^[-42.5, ::::0, :1.25, :42.5]^^^"), SV("{:^^{}::>{}}"), input, 33, 5); + + check(SV("^^[::-42, ::::0, ::1.2, :::42]^^^"), SV("{:^^33::>5.2}"), input); + check(SV("^^[::-42, ::::0, ::1.2, :::42]^^^"), SV("{:^^{}::>5.2}"), input, 33); + check(SV("^^[::-42, ::::0, ::1.2, :::42]^^^"), SV("{:^^{}::>{}.2}"), input, 33, 5); + check(SV("^^[::-42, ::::0, ::1.2, :::42]^^^"), SV("{:^^{}::>{}.{}}"), input, 33, 5, 2); + + check_exception("Argument index out of bounds", SV("{:^^{}::>5.2}"), input); + check_exception("Argument index out of bounds", SV("{:^^{}::>{}.2}"), input, 33); + check_exception("Argument index out of bounds", SV("{:^^{}::>{}.{}}"), input, 33, 5); +} + +template +void test_floating_point(TestFunction check, ExceptionTest check_exception) { + std::array input{-42.5l, 0.0l, 1.25l, 42.5l}; + test_floating_point(check, check_exception, std::queue{input.begin(), input.end()}); + test_floating_point(check, check_exception, std::priority_queue{input.begin(), input.end(), std::greater{}}); + test_floating_point(check, check_exception, std::stack{input.begin(), input.end()}); +} + +// +// Pointer +// + +template +void test_pointer(TestFunction check, ExceptionTest check_exception, auto&& input) { + check(SV("[0x0]"), SV("{}"), input); + + // ***** underlying has no format-spec + + // *** align-fill & width *** + check(SV("[0x0] "), SV("{:10}"), input); + check(SV("[0x0]*****"), SV("{:*<10}"), input); + check(SV("__[0x0]___"), SV("{:_^10}"), input); + check(SV("#####[0x0]"), SV("{:#>10}"), input); + + check(SV("[0x0] "), SV("{:{}}"), input, 10); + check(SV("[0x0]*****"), SV("{:*<{}}"), input, 
10); + check(SV("__[0x0]___"), SV("{:_^{}}"), input, 10); + check(SV("#####[0x0]"), SV("{:#>{}}"), input, 10); + + check_exception("The format-spec range-fill field contains an invalid character", SV("{:}<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{:{<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{::<}"), input); + + // *** sign *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:#}"), input); + + // *** alternate form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:#}"), input); + + // *** zero-padding *** + check_exception("A format-spec width field shouldn't have a leading zero", SV("{:0}"), input); + + // *** precision *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:.}"), input); + + // *** locale-specific form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:L}"), input); + + // *** n + check(SV("_0x0_"), SV("{:_^5n}"), input); + + // *** type *** + check_exception("The range-format-spec type m requires two elements for a pair or tuple", SV("{:m}"), input); + check_exception("The range-format-spec type s requires formatting a character type", SV("{:s}"), input); + check_exception("The range-format-spec type ?s requires formatting a character type", SV("{:?s}"), input); + for (std::basic_string_view fmt : fmt_invalid_types("s")) + check_exception("The format-spec should consume the input or end with a '}'", fmt, input); + + // ***** Only underlying has a format-spec + check(SV("[ 0x0]"), SV("{::5}"), input); + check(SV("[0x0**]"), SV("{::*<5}"), input); + check(SV("[_0x0_]"), SV("{::_^5}"), input); + check(SV("[::0x0]"), SV("{:::>5}"), input); + + check(SV("[ 0x0]"), SV("{::{}}"), input, 5); + check(SV("[0x0**]"), SV("{::*<{}}"), input, 5); + check(SV("[_0x0_]"), SV("{::_^{}}"), input, 5); 
+ check(SV("[::0x0]"), SV("{:::>{}}"), input, 5); + + check_exception("The format-spec fill field contains an invalid character", SV("{::}<}"), input); + check_exception("The format-spec fill field contains an invalid character", SV("{::{<}"), input); + + // *** sign *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{::-}"), input); + + // *** alternate form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{::#}"), input); + + // *** zero-padding *** + check_exception("A format-spec width field shouldn't have a leading zero", SV("{::05}"), input); + + // *** precision *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{::.}"), input); + + // *** locale-specific form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{::L}"), input); + + // *** type *** + for (std::basic_string_view fmt : fmt_invalid_nested_types("p")) + check_exception("The format-spec type has a type not supported for a pointer argument", fmt, input); + + // ***** Both have a format-spec + check(SV("^^[::0x0]^^^"), SV("{:^^12::>5}"), input); + check(SV("^^[::0x0]^^^"), SV("{:^^{}::>5}"), input, 12); + check(SV("^^[::0x0]^^^"), SV("{:^^{}::>{}}"), input, 12, 5); + + check(SV("^^[::0x0]^^^"), SV("{:^^12::>5}"), input); + check(SV("^^[::0x0]^^^"), SV("{:^^{}::>5}"), input, 12); + check(SV("^^[::0x0]^^^"), SV("{:^^{}::>{}}"), input, 12, 5); + + check_exception("Argument index out of bounds", SV("{:^^{}::>5}"), input); + check_exception("Argument index out of bounds", SV("{:^^{}::>{}}"), input, 12); +} + +template +void test_pointer(TestFunction check, ExceptionTest check_exception) { + std::array input{static_cast(0)}; + test_pointer(check, check_exception, std::queue{input.begin(), input.end()}); + test_pointer(check, check_exception, std::priority_queue{input.begin(), input.end()}); + test_pointer(check, check_exception, std::stack{input.begin(), 
input.end()}); +} + +// +// String +// + +template +void test_string(TestFunction check, ExceptionTest check_exception, auto&& input) { + check(SV(R"(["Hello", "world"])"), SV("{}"), input); + + // ***** underlying has no format-spec + + // *** align-fill & width *** + check(SV(R"(["Hello", "world"] )"), SV("{:23}"), input); + check(SV(R"(["Hello", "world"]*****)"), SV("{:*<23}"), input); + check(SV(R"(__["Hello", "world"]___)"), SV("{:_^23}"), input); + check(SV(R"(#####["Hello", "world"])"), SV("{:#>23}"), input); + + check(SV(R"(["Hello", "world"] )"), SV("{:{}}"), input, 23); + check(SV(R"(["Hello", "world"]*****)"), SV("{:*<{}}"), input, 23); + check(SV(R"(__["Hello", "world"]___)"), SV("{:_^{}}"), input, 23); + check(SV(R"(#####["Hello", "world"])"), SV("{:#>{}}"), input, 23); + + check_exception("The format-spec range-fill field contains an invalid character", SV("{:}<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{:{<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{::<}"), input); + + // *** sign *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:#}"), input); + + // *** alternate form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:#}"), input); + + // *** zero-padding *** + check_exception("A format-spec width field shouldn't have a leading zero", SV("{:0}"), input); + + // *** precision *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:.}"), input); + + // *** locale-specific form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:L}"), input); + + // *** n + check(SV(R"(_"Hello", "world"_)"), SV("{:_^18n}"), input); + + // *** type *** + check_exception("The range-format-spec type m requires two elements for a pair or tuple", SV("{:m}"), input); + check_exception("The range-format-spec type 
s requires formatting a character type", SV("{:s}"), input); + check_exception("The range-format-spec type ?s requires formatting a character type", SV("{:?s}"), input); + for (std::basic_string_view fmt : fmt_invalid_types("s")) + check_exception("The format-spec should consume the input or end with a '}'", fmt, input); + + // ***** Only underlying has a format-spec + check(SV(R"([Hello , world ])"), SV("{::8}"), input); + check(SV(R"([Hello***, world***])"), SV("{::*<8}"), input); + check(SV(R"([_Hello__, _world__])"), SV("{::_^8}"), input); + check(SV(R"([:::Hello, :::world])"), SV("{:::>8}"), input); + + check(SV(R"([Hello , world ])"), SV("{::{}}"), input, 8); + check(SV(R"([Hello***, world***])"), SV("{::*<{}}"), input, 8); + check(SV(R"([_Hello__, _world__])"), SV("{::_^{}}"), input, 8); + check(SV(R"([:::Hello, :::world])"), SV("{:::>{}}"), input, 8); + + check_exception("The format-spec fill field contains an invalid character", SV("{::}<}"), input); + check_exception("The format-spec fill field contains an invalid character", SV("{::{<}"), input); + + // *** sign *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{::-}"), input); + + // *** alternate form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{::#}"), input); + + // *** zero-padding *** + check_exception("A format-spec width field shouldn't have a leading zero", SV("{::05}"), input); + + // *** precision *** + check(SV(R"([Hel, wor])"), SV("{::.3}"), input); + + check(SV(R"([Hel, wor])"), SV("{::.{}}"), input, 3); + + check_exception("The format-spec precision field doesn't contain a value or arg-id", SV("{::.}"), input); + + // *** locale-specific form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{::L}"), input); + + // *** type *** + for (std::basic_string_view fmt : fmt_invalid_nested_types("s?")) + check_exception("The format-spec type has a type not supported for a string 
argument", fmt, input); + + // ***** Both have a format-spec + check(SV(R"(^^[:::Hello, :::world]^^^)"), SV("{:^^25::>8}"), input); + check(SV(R"(^^[:::Hello, :::world]^^^)"), SV("{:^^{}::>8}"), input, 25); + check(SV(R"(^^[:::Hello, :::world]^^^)"), SV("{:^^{}::>{}}"), input, 25, 8); + + check(SV(R"(^^[:::Hello, :::world]^^^)"), SV("{:^^25::>8}"), input); + check(SV(R"(^^[:::Hello, :::world]^^^)"), SV("{:^^{}::>8}"), input, 25); + check(SV(R"(^^[:::Hello, :::world]^^^)"), SV("{:^^{}::>{}}"), input, 25, 8); + + check_exception("Argument index out of bounds", SV("{:^^{}::>8}"), input); + check_exception("Argument index out of bounds", SV("{:^^{}::>{}}"), input, 25); +} + +template +void test_string(TestFunction check, ExceptionTest check_exception) { + std::array input{STR("Hello"), STR("world")}; + test_string(check, check_exception, std::queue{input.begin(), input.end()}); + test_string(check, check_exception, std::priority_queue{input.begin(), input.end(), std::greater{}}); + test_string(check, check_exception, std::stack{input.begin(), input.end()}); +} + +// +// Handle +// + +template +void test_status(TestFunction check, ExceptionTest check_exception, auto&& input) { + check(SV("[0xaaaa, 0x5555, 0xaa55]"), SV("{}"), input); + + // ***** underlying has no format-spec + + // *** align-fill & width *** + check(SV("[0xaaaa, 0x5555, 0xaa55] "), SV("{:29}"), input); + check(SV("[0xaaaa, 0x5555, 0xaa55]*****"), SV("{:*<29}"), input); + check(SV("__[0xaaaa, 0x5555, 0xaa55]___"), SV("{:_^29}"), input); + check(SV("#####[0xaaaa, 0x5555, 0xaa55]"), SV("{:#>29}"), input); + + check(SV("[0xaaaa, 0x5555, 0xaa55] "), SV("{:{}}"), input, 29); + check(SV("[0xaaaa, 0x5555, 0xaa55]*****"), SV("{:*<{}}"), input, 29); + check(SV("__[0xaaaa, 0x5555, 0xaa55]___"), SV("{:_^{}}"), input, 29); + check(SV("#####[0xaaaa, 0x5555, 0xaa55]"), SV("{:#>{}}"), input, 29); + + check_exception("The format-spec range-fill field contains an invalid character", SV("{:}<}"), input); + 
check_exception("The format-spec range-fill field contains an invalid character", SV("{:{<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{::<}"), input); + + // *** sign *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:-}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{:+}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{: }"), input); + + // *** alternate form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:#}"), input); + + // *** zero-padding *** + check_exception("A format-spec width field shouldn't have a leading zero", SV("{:0}"), input); + + // *** precision *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:.}"), input); + + // *** locale-specific form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:L}"), input); + + // *** n + check(SV("__0xaaaa, 0x5555, 0xaa55___"), SV("{:_^27n}"), input); + + // *** type *** + check_exception("The range-format-spec type m requires two elements for a pair or tuple", SV("{:m}"), input); + check_exception("The range-format-spec type s requires formatting a character type", SV("{:s}"), input); + check_exception("The range-format-spec type ?s requires formatting a character type", SV("{:?s}"), input); + for (std::basic_string_view fmt : fmt_invalid_types("s")) + check_exception("The format-spec should consume the input or end with a '}'", fmt, input); + + // ***** Only underlying has a format-spec + check_exception("The format-spec type has a type not supported for a status argument", SV("{::*<7}"), input); + for (std::basic_string_view fmt : fmt_invalid_nested_types("sxX")) + check_exception("The format-spec type has a type not supported for a status argument", fmt, input); + + check(SV("[0xaaaa, 0x5555, 0xaa55]"), 
SV("{::x}"), input); + check(SV("[0XAAAA, 0X5555, 0XAA55]"), SV("{::X}"), input); + check(SV("[foo, bar, foobar]"), SV("{::s}"), input); + + // ***** Both have a format-spec + check(SV("^^[0XAAAA, 0X5555, 0XAA55]^^^"), SV("{:^^29:X}"), input); + check(SV("^^[0XAAAA, 0X5555, 0XAA55]^^^"), SV("{:^^{}:X}"), input, 29); + + check_exception("Argument index out of bounds", SV("{:^^{}:X}"), input); +} + +template +void test_status(TestFunction check, ExceptionTest check_exception) { + std::array input{status::foo, status::bar, status::foobar}; + test_status(check, check_exception, std::queue{input.begin(), input.end()}); + test_status(check, check_exception, std::priority_queue{input.begin(), input.end(), std::less{}}); + test_status(check, check_exception, std::stack{input.begin(), input.end()}); +} + +// +// Driver +// + +template +void format_tests(TestFunction check, ExceptionTest check_exception) { + test_char(check, check_exception); +#ifndef _LIBCPP_HAS_NO_WIDE_CHARACTERS + if (std::same_as) // avoid testing twice + test_char_to_wchar(check, check_exception); +#endif + test_bool(check, check_exception); + test_int(check, check_exception); + test_floating_point(check, check_exception); + test_pointer(check, check_exception); + test_string(check, check_exception); + + test_status(check, check_exception); // Has its own handler with its own parser +} + +#endif // TEST_STD_CONTAINERS_CONTAINER_ADAPTORS_CONTAINER_ADAPTORS_FORMAT_FORMAT_FUNCTIONS_TESTS_H diff --git a/libcxx/test/std/containers/container.adaptors/container.adaptors.format/format.functions.vformat.pass.cpp b/libcxx/test/std/containers/container.adaptors/container.adaptors.format/format.functions.vformat.pass.cpp new file mode 100644 index 00000000000000..27c52fa2b63683 --- /dev/null +++ b/libcxx/test/std/containers/container.adaptors/container.adaptors.format/format.functions.vformat.pass.cpp @@ -0,0 +1,73 @@ +//===----------------------------------------------------------------------===// +// Part of the 
LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// UNSUPPORTED: c++03, c++11, c++14, c++17, c++20 +// UNSUPPORTED: libcpp-has-no-incomplete-format + +// TODO FMT Fix this test using GCC, it currently times out. +// UNSUPPORTED: gcc-12 + +// This test requires the dylib support introduced in D92214. +// XFAIL: use_system_cxx_lib && target={{.+}}-apple-macosx10.{{.+}} +// XFAIL: use_system_cxx_lib && target={{.+}}-apple-macosx11.{{.+}} + +// [container.adaptors.format] +// For each of queue, priority_queue, and stack, the library provides the +// following formatter specialization where adaptor-type is the name of the +// template: +// +// template Container, class... U> +// struct formatter, charT> + +// string vformat(string_view fmt, format_args args); +// wstring vformat(wstring_view fmt, wformat_args args); + +#include +#include + +#include "format.functions.tests.h" +#include "test_macros.h" +#include "assert_macros.h" + +auto test = []( + std::basic_string_view expected, std::basic_string_view fmt, Args&&... args) { + std::basic_string out = std::vformat(fmt, std::make_format_args>(args...)); + TEST_REQUIRE( + out == expected, + test_concat_message("\nFormat string ", fmt, "\nExpected output ", expected, "\nActual output ", out, '\n')); +}; + +auto test_exception = + []( + [[maybe_unused]] std::string_view what, + [[maybe_unused]] std::basic_string_view fmt, + [[maybe_unused]] Args&&... 
args) { +#ifndef TEST_HAS_NO_EXCEPTIONS + try { + TEST_IGNORE_NODISCARD std::vformat(fmt, std::make_format_args>(args...)); + TEST_FAIL(test_concat_message("\nFormat string ", fmt, "\nDidn't throw an exception.\n")); + } catch (const std::format_error& e) { + TEST_LIBCPP_REQUIRE( + e.what() == what, + test_concat_message( + "\nFormat string ", fmt, "\nExpected exception ", what, "\nActual exception ", e.what(), '\n')); + + return; + } + assert(false); +#endif + }; + +int main(int, char**) { + format_tests(test, test_exception); + +#ifndef TEST_HAS_NO_WIDE_CHARACTERS + format_tests(test, test_exception); +#endif + + return 0; +} diff --git a/libcxx/test/std/containers/container.adaptors/container.adaptors.format/format.pass.cpp b/libcxx/test/std/containers/container.adaptors/container.adaptors.format/format.pass.cpp new file mode 100644 index 00000000000000..901c63a0084869 --- /dev/null +++ b/libcxx/test/std/containers/container.adaptors/container.adaptors.format/format.pass.cpp @@ -0,0 +1,78 @@ +//===----------------------------------------------------------------------===// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// UNSUPPORTED: c++03, c++11, c++14, c++17, c++20 +// UNSUPPORTED: libcpp-has-no-incomplete-format + +// This test requires the dylib support introduced in D92214. +// XFAIL: use_system_cxx_lib && target={{.+}}-apple-macosx10.{{.+}} +// XFAIL: use_system_cxx_lib && target={{.+}}-apple-macosx11.{{.+}} + +// [container.adaptors.format] +// For each of queue, priority_queue, and stack, the library provides the +// following formatter specialization where adaptor-type is the name of the +// template: +// +// template Container, class... 
U> +// struct formatter, charT> + +// template +// typename FormatContext::iterator +// format(maybe-const-adaptor& r, FormatContext& ctx) const; + +// Note this tests the basics of this function. It's tested in more detail in +// the format functions test. + +#include +#include +#include +#include +#include +#include + +#include "test_format_context.h" +#include "test_macros.h" +#include "make_string.h" + +#define SV(S) MAKE_STRING_VIEW(CharT, S) + +template +void test_format(StringViewT expected, Arg arg) { + using CharT = typename StringViewT::value_type; + using String = std::basic_string; + using OutIt = std::back_insert_iterator; + using FormatCtxT = std::basic_format_context; + + const std::formatter formatter; + + String result; + OutIt out = std::back_inserter(result); + FormatCtxT format_ctx = test_format_context_create(out, std::make_format_args(arg)); + formatter.format(arg, format_ctx); + assert(result == expected); +} + +template +void test_fmt() { + std::array input{1, 42, 99, 0}; + test_format(SV("[1, 42, 99, 0]"), std::queue{input.begin(), input.end()}); + test_format(SV("[99, 42, 1, 0]"), std::priority_queue{input.begin(), input.end()}); + test_format(SV("[1, 42, 99, 0]"), std::stack{input.begin(), input.end()}); +} + +void test() { + test_fmt(); +#ifndef TEST_HAS_NO_WIDE_CHARACTERS + test_fmt(); +#endif +} + +int main(int, char**) { + test(); + + return 0; +} diff --git a/libcxx/test/std/containers/container.adaptors/container.adaptors.format/parse.pass.cpp b/libcxx/test/std/containers/container.adaptors/container.adaptors.format/parse.pass.cpp new file mode 100644 index 00000000000000..11f08371b5ead4 --- /dev/null +++ b/libcxx/test/std/containers/container.adaptors/container.adaptors.format/parse.pass.cpp @@ -0,0 +1,83 @@ +//===----------------------------------------------------------------------===// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// UNSUPPORTED: c++03, c++11, c++14, c++17, c++20 +// UNSUPPORTED: libcpp-has-no-incomplete-format + +// This test requires the dylib support introduced in D92214. +// XFAIL: use_system_cxx_lib && target={{.+}}-apple-macosx10.{{.+}} +// XFAIL: use_system_cxx_lib && target={{.+}}-apple-macosx11.{{.+}} + +// [container.adaptors.format] +// For each of queue, priority_queue, and stack, the library provides the +// following formatter specialization where adaptor-type is the name of the +// template: +// +// template Container, class... U> +// struct formatter, charT> + +// template +// constexpr typename ParseContext::iterator +// parse(ParseContext& ctx); + +// Note this tests the basics of this function. It's tested in more detail in +// the format functions test. + +#include +#include +#include +#include +#include + +#include "test_format_context.h" +#include "test_macros.h" +#include "make_string.h" + +#define SV(S) MAKE_STRING_VIEW(CharT, S) + +template +constexpr void test_parse(StringViewT fmt) { + using CharT = typename StringViewT::value_type; + auto parse_ctx = std::basic_format_parse_context(fmt); + std::formatter formatter; + static_assert(std::semiregular); + + std::same_as auto it = formatter.parse(parse_ctx); + assert(it == fmt.end() - (!fmt.empty() && fmt.back() == '}')); +} + +template +constexpr void test_parse(StringViewT fmt) { + test_parse>(fmt); + test_parse>(fmt); + test_parse>(fmt); +} + +template +constexpr void test_fmt() { + test_parse(SV("")); + test_parse(SV(":d")); + + test_parse(SV("}")); + test_parse(SV(":d}")); +} + +constexpr bool test() { + test_fmt(); +#ifndef TEST_HAS_NO_WIDE_CHARACTERS + test_fmt(); +#endif + + return true; +} + +int main(int, char**) { + test(); + static_assert(test()); + + return 0; +} diff --git 
a/libcxx/test/std/utilities/format/format.formattable/concept.formattable.compile.pass.cpp b/libcxx/test/std/utilities/format/format.formattable/concept.formattable.compile.pass.cpp index e17a6d0da2d47f..06a29502a97245 100644 --- a/libcxx/test/std/utilities/format/format.formattable/concept.formattable.compile.pass.cpp +++ b/libcxx/test/std/utilities/format/format.formattable/concept.formattable.compile.pass.cpp @@ -195,9 +195,6 @@ void test_P1636() { } // Tests for P2286 Formatting ranges -// -// The paper hasn't been voted in so currently all formatters are disabled. -// TODO validate whether the test is correct after the paper has been accepted. template void test_P2286() { assert_is_formattable, CharT>(); @@ -216,9 +213,9 @@ void test_P2286() { assert_is_not_formattable, CharT>(); assert_is_not_formattable, CharT>(); - assert_is_not_formattable, CharT>(); - assert_is_not_formattable, CharT>(); - assert_is_not_formattable, CharT>(); + assert_is_formattable, CharT>(); + assert_is_formattable, CharT>(); + assert_is_formattable, CharT>(); assert_is_formattable, CharT>(); From 3e6afa77b4ec8b5fb5e0b6349f08e4476892c0f8 Mon Sep 17 00:00:00 2001 From: Sjoerd Meijer Date: Wed, 11 Jan 2023 14:41:10 +0000 Subject: [PATCH 24/61] [CostModel][AArch64] Precommit tests for LD1 single-element to lane. NFC. --- .../CostModel/AArch64/insert-extract.ll | 70 +++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/llvm/test/Analysis/CostModel/AArch64/insert-extract.ll b/llvm/test/Analysis/CostModel/AArch64/insert-extract.ll index 794fa61e216032..844b23dc52ba90 100644 --- a/llvm/test/Analysis/CostModel/AArch64/insert-extract.ll +++ b/llvm/test/Analysis/CostModel/AArch64/insert-extract.ll @@ -102,3 +102,73 @@ define void @vectorInstrCost() { %t160 = insertelement <2 x double> zeroinitializer, double 0.000000e+00, i64 1 ret void } + +;; LD1: Load one single-element structure to one lane of one register. 
+ +define <8 x i8> @LD1_B(<8 x i8> %vec, ptr noundef %i) { +; KRYO-LABEL: 'LD1_B' +; KRYO-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %v1 = load i8, ptr %i, align 1 +; KRYO-NEXT: Cost Model: Found an estimated cost of 2 for instruction: %v2 = insertelement <8 x i8> %vec, i8 %v1, i32 1 +; KRYO-NEXT: Cost Model: Found an estimated cost of 0 for instruction: ret <8 x i8> %v2 +; +; NEO-LABEL: 'LD1_B' +; NEO-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %v1 = load i8, ptr %i, align 1 +; NEO-NEXT: Cost Model: Found an estimated cost of 3 for instruction: %v2 = insertelement <8 x i8> %vec, i8 %v1, i32 1 +; NEO-NEXT: Cost Model: Found an estimated cost of 0 for instruction: ret <8 x i8> %v2 +; +entry: + %v1 = load i8, ptr %i, align 1 + %v2 = insertelement <8 x i8> %vec, i8 %v1, i32 1 + ret <8x i8> %v2 +} + +define <4 x i16> @LD1_H(<4 x i16> %vec, ptr noundef %i) { +; KRYO-LABEL: 'LD1_H' +; KRYO-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %v1 = load i16, ptr %i, align 2 +; KRYO-NEXT: Cost Model: Found an estimated cost of 2 for instruction: %v2 = insertelement <4 x i16> %vec, i16 %v1, i32 2 +; KRYO-NEXT: Cost Model: Found an estimated cost of 0 for instruction: ret <4 x i16> %v2 +; +; NEO-LABEL: 'LD1_H' +; NEO-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %v1 = load i16, ptr %i, align 2 +; NEO-NEXT: Cost Model: Found an estimated cost of 3 for instruction: %v2 = insertelement <4 x i16> %vec, i16 %v1, i32 2 +; NEO-NEXT: Cost Model: Found an estimated cost of 0 for instruction: ret <4 x i16> %v2 +; +entry: + %v1 = load i16, ptr %i, align 2 + %v2 = insertelement <4 x i16> %vec, i16 %v1, i32 2 + ret <4 x i16> %v2 +} + +define <4 x i32> @LD1_W(<4 x i32> %vec, ptr noundef %i) { +; KRYO-LABEL: 'LD1_W' +; KRYO-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %v1 = load i32, ptr %i, align 4 +; KRYO-NEXT: Cost Model: Found an estimated cost of 2 for instruction: %v2 = insertelement <4 x i32> 
%vec, i32 %v1, i32 3 +; KRYO-NEXT: Cost Model: Found an estimated cost of 0 for instruction: ret <4 x i32> %v2 +; +; NEO-LABEL: 'LD1_W' +; NEO-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %v1 = load i32, ptr %i, align 4 +; NEO-NEXT: Cost Model: Found an estimated cost of 3 for instruction: %v2 = insertelement <4 x i32> %vec, i32 %v1, i32 3 +; NEO-NEXT: Cost Model: Found an estimated cost of 0 for instruction: ret <4 x i32> %v2 +; +entry: + %v1 = load i32, ptr %i, align 4 + %v2 = insertelement <4 x i32> %vec, i32 %v1, i32 3 + ret <4 x i32> %v2 +} + +define <2 x i64> @LD1_X(<2 x i64> %vec, ptr noundef %i) { +; KRYO-LABEL: 'LD1_X' +; KRYO-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %v1 = load i64, ptr %i, align 8 +; KRYO-NEXT: Cost Model: Found an estimated cost of 2 for instruction: %v2 = insertelement <2 x i64> %vec, i64 %v1, i32 0 +; KRYO-NEXT: Cost Model: Found an estimated cost of 0 for instruction: ret <2 x i64> %v2 +; +; NEO-LABEL: 'LD1_X' +; NEO-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %v1 = load i64, ptr %i, align 8 +; NEO-NEXT: Cost Model: Found an estimated cost of 3 for instruction: %v2 = insertelement <2 x i64> %vec, i64 %v1, i32 0 +; NEO-NEXT: Cost Model: Found an estimated cost of 0 for instruction: ret <2 x i64> %v2 +; +entry: + %v1 = load i64, ptr %i, align 8 + %v2 = insertelement <2 x i64> %vec, i64 %v1, i32 0 + ret <2 x i64> %v2 +} From 68f45796edbdc4331a6be0b6e9d58f01d8b1fba0 Mon Sep 17 00:00:00 2001 From: Zino Benaissa Date: Thu, 19 Jan 2023 16:44:23 +0000 Subject: [PATCH 25/61] [AARCH64][SVE] Do not optimize vector conversions shuffle_vector instructions are serialized when targeting SVE fixed vectors, see https://reviews.llvm.org/D139111. This patch disables optimizeExtendOrTruncateConversion peepholes that generate shuffle_vector.
Differential Revision: https://reviews.llvm.org/D141439 --- .../Target/AArch64/AArch64ISelLowering.cpp | 5 ++ .../CodeGen/AArch64/sve-fixed-vector-zext.ll | 59 +++++++++++++++++++ 2 files changed, 64 insertions(+) create mode 100644 llvm/test/CodeGen/AArch64/sve-fixed-vector-zext.ll diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 324d1cf0d007cd..eaf467d7831a10 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -14177,6 +14177,11 @@ static void createTblForTrunc(TruncInst *TI, bool IsLittleEndian) { bool AArch64TargetLowering::optimizeExtendOrTruncateConversion(Instruction *I, Loop *L) const { + // shuffle_vector instructions are serialized when targeting SVE, + // see LowerSPLAT_VECTOR. This peephole is not beneficial. + if (Subtarget->useSVEForFixedLengthVectors()) + return false; + // Try to optimize conversions using tbl. This requires materializing constant // index vectors, which can increase code size and add loads. 
Skip the // transform unless the conversion is in a loop block guaranteed to execute diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-vector-zext.ll b/llvm/test/CodeGen/AArch64/sve-fixed-vector-zext.ll new file mode 100644 index 00000000000000..844afc27907c1f --- /dev/null +++ b/llvm/test/CodeGen/AArch64/sve-fixed-vector-zext.ll @@ -0,0 +1,59 @@ + +; RUN: llc < %s -mtriple=aarch64-none-linux-gnu -mcpu=neoverse-v1 -O3 -opaque-pointers -aarch64-sve-vector-bits-min=256 -verify-machineinstrs | FileCheck %s --check-prefixes=SVE256 +; RUN: llc < %s -mtriple=aarch64-none-linux-gnu -mcpu=neoverse-v1 -O3 -opaque-pointers -aarch64-sve-vector-bits-min=128 -verify-machineinstrs | FileCheck %s --check-prefixes=NEON +; RUN: llc < %s -mtriple=aarch64-none-linux-gnu -mcpu=neoverse-n1 -O3 -opaque-pointers -verify-machineinstrs | FileCheck %s --check-prefixes=NEON +; RUN: llc < %s -mtriple=aarch64-none-linux-gnu -mcpu=neoverse-v2 -O3 -opaque-pointers -verify-machineinstrs | FileCheck %s --check-prefixes=NEON + +define internal i32 @test(ptr nocapture readonly %p1, i32 %i1, ptr nocapture readonly %p2, i32 %i2) { +; SVE256-LABEL: test: +; SVE256: ld1b { z0.h }, p0/z, +; SVE256: ld1b { z1.h }, p0/z, +; SVE256: sub z0.h, z0.h, z1.h +; SVE256-NEXT: sunpklo z1.s, z0.h +; SVE256-NEXT: ext z0.b, z0.b, z0.b, #16 +; SVE256-NEXT: sunpklo z0.s, z0.h +; SVE256-NEXT: add z0.s, z1.s, z0.s +; SVE256-NEXT: uaddv d0, p1, z0.s + +; NEON-LABEL: test: +; NEON: tbl +; NEON-NEXT: tbl +; NEON-NEXT: tbl +; NEON-NEXT: tbl +; NEON-NEXT: tbl +; NEON-NEXT: tbl +; NEON-NEXT: tbl +; NEON-NEXT: tbl +; NEON: addv + + +L.entry: + br label %L1 + +L1: ; preds = %L1, %L.entry + %a = phi i32 [ 16, %L.entry ], [ %14, %L1 ] + %b = phi i32 [ 0, %L.entry ], [ %13, %L1 ] + %i = phi i32 [ 0, %L.entry ], [ %12, %L1 ] + %0 = mul i32 %b, %i1 + %1 = sext i32 %0 to i64 + %2 = getelementptr i8, ptr %p1, i64 %1 + %3 = mul i32 %b, %i2 + %4 = sext i32 %3 to i64 + %5 = getelementptr i8, ptr %p2, i64 %4 + %6 = load <16 x i8>, ptr %2, 
align 1 + %7 = zext <16 x i8> %6 to <16 x i32> + %8 = load <16 x i8>, ptr %5, align 1 + %9 = zext <16 x i8> %8 to <16 x i32> + %10 = sub nsw <16 x i32> %7, %9 + %11 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %10) + %12 = add i32 %11, %i + %13 = add nuw nsw i32 %b, 1 + %14 = add nsw i32 %a, -1 + %.not = icmp eq i32 %14, 0 + br i1 %.not, label %L2, label %L1 + +L2: ; preds = %L1 + ret i32 %12 +} + +declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>) From 6ef990c5b6135ffe24f8697caf6630401669efeb Mon Sep 17 00:00:00 2001 From: LLVM GN Syncbot Date: Thu, 19 Jan 2023 16:51:17 +0000 Subject: [PATCH 26/61] [gn build] Port 04d4f4b3d4e4 --- llvm/utils/gn/secondary/libcxx/include/BUILD.gn | 1 + 1 file changed, 1 insertion(+) diff --git a/llvm/utils/gn/secondary/libcxx/include/BUILD.gn b/llvm/utils/gn/secondary/libcxx/include/BUILD.gn index ecf9bff5646f40..c56ae58192957e 100644 --- a/llvm/utils/gn/secondary/libcxx/include/BUILD.gn +++ b/llvm/utils/gn/secondary/libcxx/include/BUILD.gn @@ -377,6 +377,7 @@ if (current_toolchain == default_toolchain) { "__filesystem/u8path.h", "__format/buffer.h", "__format/concepts.h", + "__format/container_adaptor.h", "__format/enable_insertable.h", "__format/escaped_output_table.h", "__format/extended_grapheme_cluster_table.h", From a09b1dc1f2c534cc8b870b5b7c9bc7f7682a1ed5 Mon Sep 17 00:00:00 2001 From: Mark de Wever Date: Tue, 29 Nov 2022 07:57:30 +0100 Subject: [PATCH 27/61] [libc++][format] Adds formatter std::vector. 
Implements parts of - P2286R8 Formatting Ranges Depends on D140653 Reviewed By: ldionne, #libc Differential Revision: https://reviews.llvm.org/D141761 --- libcxx/docs/Status/FormatPaper.csv | 2 +- libcxx/include/__bit_reference | 2 + libcxx/include/vector | 30 +++++ .../format.functions.format.pass.cpp | 54 ++++++++ .../vector.bool.fmt/format.functions.tests.h | 116 ++++++++++++++++++ .../format.functions.vformat.pass.cpp | 71 +++++++++++ .../vector.bool.fmt/format.pass.cpp | 72 +++++++++++ .../vector.bool.fmt/parse.pass.cpp | 73 +++++++++++ .../concept.formattable.compile.pass.cpp | 11 ++ 9 files changed, 430 insertions(+), 1 deletion(-) create mode 100644 libcxx/test/std/containers/sequences/vector.bool/vector.bool.fmt/format.functions.format.pass.cpp create mode 100644 libcxx/test/std/containers/sequences/vector.bool/vector.bool.fmt/format.functions.tests.h create mode 100644 libcxx/test/std/containers/sequences/vector.bool/vector.bool.fmt/format.functions.vformat.pass.cpp create mode 100644 libcxx/test/std/containers/sequences/vector.bool/vector.bool.fmt/format.pass.cpp create mode 100644 libcxx/test/std/containers/sequences/vector.bool/vector.bool.fmt/parse.pass.cpp diff --git a/libcxx/docs/Status/FormatPaper.csv b/libcxx/docs/Status/FormatPaper.csv index 10463c44ef8fed..d80346e9d68c57 100644 --- a/libcxx/docs/Status/FormatPaper.csv +++ b/libcxx/docs/Status/FormatPaper.csv @@ -34,4 +34,4 @@ Section,Description,Dependencies,Assignee,Status,First released version `[format.range] `_,"Formatting for ranges: associative",,Mark de Wever,, `[format.range] `_,"Formatting for ranges: container adaptors",,Mark de Wever,|Complete|,Clang 16 `[format.range] `_,"Formatting for ranges: ``pair`` and ``tuple``",,Mark de Wever,|Complete|,Clang 16 -`[format.range] `_,"Formatting for ranges: ``vector``",,Mark de Wever,, +`[format.range] `_,"Formatting for ranges: ``vector``",,Mark de Wever,|Complete|,Clang 16 diff --git a/libcxx/include/__bit_reference 
b/libcxx/include/__bit_reference index b954c104937784..26657491393ddb 100644 --- a/libcxx/include/__bit_reference +++ b/libcxx/include/__bit_reference @@ -55,6 +55,8 @@ class __bit_reference friend class __bit_const_reference<_Cp>; friend class __bit_iterator<_Cp, false>; public: + using __container = typename _Cp::__self; + _LIBCPP_INLINE_VISIBILITY _LIBCPP_CONSTEXPR_SINCE_CXX20 __bit_reference(const __bit_reference&) = default; diff --git a/libcxx/include/vector b/libcxx/include/vector index e9d1b72171e7f8..4c939806dc12b1 100644 --- a/libcxx/include/vector +++ b/libcxx/include/vector @@ -267,6 +267,13 @@ template typename vector::size_type erase_if(vector& c, Predicate pred); // C++20 + +template + inline constexpr bool is-vector-bool-reference = see below; // exposition only, since C++23 + +template requires is-vector-bool-reference // Since C++23 + struct formatter; + } // std */ @@ -281,9 +288,11 @@ erase_if(vector& c, Predicate pred); // C++20 #include <__algorithm/unwrap_iter.h> #include <__assert> // all public C++ headers provide the assertion handler #include <__bit_reference> +#include <__concepts/same_as.h> #include <__config> #include <__debug> #include <__format/enable_insertable.h> +#include <__format/formatter.h> #include <__functional/hash.h> #include <__functional/unary_function.h> #include <__iterator/advance.h> @@ -3312,6 +3321,27 @@ inline constexpr bool __format::__enable_insertable> = true; #endif // _LIBCPP_STD_VER > 17 +#if _LIBCPP_STD_VER > 20 +template +// Since is-vector-bool-reference is only used once it's inlined here. 
+ requires same_as> +struct _LIBCPP_TEMPLATE_VIS _LIBCPP_AVAILABILITY_FORMAT formatter<_Tp, CharT> { +private: + formatter __underlying_; + +public: + template + _LIBCPP_HIDE_FROM_ABI constexpr typename _ParseContext::iterator parse(_ParseContext& __ctx) { + return __underlying_.parse(__ctx); + } + + template + _LIBCPP_HIDE_FROM_ABI typename _FormatContext::iterator format(const _Tp& __ref, _FormatContext& __ctx) const { + return __underlying_.format(__ref, __ctx); + } +}; +#endif // _LIBCPP_STD_VER > 20 + _LIBCPP_END_NAMESPACE_STD #if _LIBCPP_STD_VER > 14 diff --git a/libcxx/test/std/containers/sequences/vector.bool/vector.bool.fmt/format.functions.format.pass.cpp b/libcxx/test/std/containers/sequences/vector.bool/vector.bool.fmt/format.functions.format.pass.cpp new file mode 100644 index 00000000000000..61cbfea76de7f0 --- /dev/null +++ b/libcxx/test/std/containers/sequences/vector.bool/vector.bool.fmt/format.functions.format.pass.cpp @@ -0,0 +1,54 @@ +//===----------------------------------------------------------------------===// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// UNSUPPORTED: c++03, c++11, c++14, c++17, c++20 +// UNSUPPORTED: libcpp-has-no-incomplete-format + +// TODO FMT Fix this test using GCC, it currently times out. +// UNSUPPORTED: gcc-12 + +// + +// template +// requires is-vector-bool-reference +// struct formatter + +// template +// string format(format_string fmt, Args&&... args); +// template +// wstring format(wformat_string fmt, Args&&... args); + +#include +#include + +#include "format.functions.tests.h" +#include "test_format_string.h" +#include "test_macros.h" +#include "assert_macros.h" + +auto test = []( + std::basic_string_view expected, test_format_string fmt, Args&&... 
args) { + std::basic_string out = std::format(fmt, std::forward(args)...); + TEST_REQUIRE( + out == expected, + test_concat_message("\nFormat string ", fmt, "\nExpected output ", expected, "\nActual output ", out, '\n')); +}; + +auto test_exception = [](std::string_view, std::basic_string_view, Args&&...) { + // After P2216 most exceptions thrown by std::format become ill-formed. + // Therefore this test does nothing. +}; + +int main(int, char**) { + format_tests(test, test_exception); + +#ifndef TEST_HAS_NO_WIDE_CHARACTERS + format_tests(test, test_exception); +#endif + + return 0; +} diff --git a/libcxx/test/std/containers/sequences/vector.bool/vector.bool.fmt/format.functions.tests.h b/libcxx/test/std/containers/sequences/vector.bool/vector.bool.fmt/format.functions.tests.h new file mode 100644 index 00000000000000..32b68b85f85ab5 --- /dev/null +++ b/libcxx/test/std/containers/sequences/vector.bool/vector.bool.fmt/format.functions.tests.h @@ -0,0 +1,116 @@ +//===----------------------------------------------------------------------===// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TEST_STD_CONTAINERS_SEQUENCES_VECTOR_BOOL_VECTOR_BOOL_FMT_FORMAT_FUNCTIONS_TESTS_H +#define TEST_STD_CONTAINERS_SEQUENCES_VECTOR_BOOL_VECTOR_BOOL_FMT_FORMAT_FUNCTIONS_TESTS_H + +#include + +#include "format.functions.common.h" +#include "test_macros.h" + +template +void format_tests(TestFunction check, ExceptionTest check_exception) { + std::vector input{true, true, false}; + + check(SV("[true, true, false]"), SV("{}"), input); + + // ***** underlying has no format-spec + + // *** align-fill & width *** + check(SV("[true, true, false] "), SV("{:24}"), input); + check(SV("[true, true, false]*****"), SV("{:*<24}"), input); + check(SV("__[true, true, false]___"), SV("{:_^24}"), input); + check(SV("#####[true, true, false]"), SV("{:#>24}"), input); + + check(SV("[true, true, false] "), SV("{:{}}"), input, 24); + check(SV("[true, true, false]*****"), SV("{:*<{}}"), input, 24); + check(SV("__[true, true, false]___"), SV("{:_^{}}"), input, 24); + check(SV("#####[true, true, false]"), SV("{:#>{}}"), input, 24); + + check_exception("The format-spec range-fill field contains an invalid character", SV("{:}<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{:{<}"), input); + check_exception("The format-spec range-fill field contains an invalid character", SV("{::<}"), input); + + // *** sign *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:-}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{:+}"), input); + check_exception("The format-spec should consume the input or end with a '}'", SV("{: }"), input); + + // *** alternate form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:#}"), input); + + // *** zero-padding *** + check_exception("A 
format-spec width field shouldn't have a leading zero", SV("{:0}"), input); + + // *** precision *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:.}"), input); + + // *** locale-specific form *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{:L}"), input); + + // *** n + check(SV("__true, true, false___"), SV("{:_^22n}"), input); + + // *** type *** + check_exception("The range-format-spec type m requires two elements for a pair or tuple", SV("{:m}"), input); + check_exception("The range-format-spec type s requires formatting a character type", SV("{:s}"), input); + check_exception("The range-format-spec type ?s requires formatting a character type", SV("{:?s}"), input); + + for (std::basic_string_view fmt : fmt_invalid_types("s")) + check_exception("The format-spec should consume the input or end with a '}'", fmt, input); + + // ***** Only underlying has a format-spec + check(SV("[true , true , false ]"), SV("{::7}"), input); + check(SV("[true***, true***, false**]"), SV("{::*<7}"), input); + check(SV("[_true__, _true__, _false_]"), SV("{::_^7}"), input); + check(SV("[:::true, :::true, ::false]"), SV("{:::>7}"), input); + + check(SV("[true , true , false ]"), SV("{::{}}"), input, 7); + check(SV("[true***, true***, false**]"), SV("{::*<{}}"), input, 7); + check(SV("[_true__, _true__, _false_]"), SV("{::_^{}}"), input, 7); + check(SV("[:::true, :::true, ::false]"), SV("{:::>{}}"), input, 7); + + check_exception("The format-spec fill field contains an invalid character", SV("{::}<}"), input); + check_exception("The format-spec fill field contains an invalid character", SV("{::{<}"), input); + + // *** sign *** + check_exception("A sign field isn't allowed in this format-spec", SV("{::-}"), input); + check_exception("A sign field isn't allowed in this format-spec", SV("{::+}"), input); + check_exception("A sign field isn't allowed in this format-spec", SV("{:: }"), input); + + 
check(SV("[1, 1, 0]"), SV("{::-d}"), input); + check(SV("[+1, +1, +0]"), SV("{::+d}"), input); + check(SV("[ 1, 1, 0]"), SV("{:: d}"), input); + + // *** alternate form *** + check_exception("An alternate form field isn't allowed in this format-spec", SV("{::#}"), input); + + check(SV("[0x1, 0x1, 0x0]"), SV("{::#x}"), input); + + // *** zero-padding *** + check_exception("A zero-padding field isn't allowed in this format-spec", SV("{::05}"), input); + + check(SV("[00001, 00001, 00000]"), SV("{::05o}"), input); + + // *** precision *** + check_exception("The format-spec should consume the input or end with a '}'", SV("{::.}"), input); + + // *** type *** + for (std::basic_string_view fmt : fmt_invalid_nested_types("bBdosxX")) + check_exception("The format-spec type has a type not supported for a bool argument", fmt, input); + + // ***** Both have a format-spec + check(SV("^^[:::true, :::true, ::false]^^^"), SV("{:^^32::>7}"), input); + check(SV("^^[:::true, :::true, ::false]^^^"), SV("{:^^{}::>7}"), input, 32); + check(SV("^^[:::true, :::true, ::false]^^^"), SV("{:^^{}::>{}}"), input, 32, 7); + + check_exception("Argument index out of bounds", SV("{:^^{}::>5}"), input); + check_exception("Argument index out of bounds", SV("{:^^{}::>{}}"), input, 32); +} + +#endif // TEST_STD_CONTAINERS_SEQUENCES_VECTOR_BOOL_VECTOR_BOOL_FMT_FORMAT_FUNCTIONS_TESTS_H diff --git a/libcxx/test/std/containers/sequences/vector.bool/vector.bool.fmt/format.functions.vformat.pass.cpp b/libcxx/test/std/containers/sequences/vector.bool/vector.bool.fmt/format.functions.vformat.pass.cpp new file mode 100644 index 00000000000000..9e6db80cf9890c --- /dev/null +++ b/libcxx/test/std/containers/sequences/vector.bool/vector.bool.fmt/format.functions.vformat.pass.cpp @@ -0,0 +1,71 @@ +//===----------------------------------------------------------------------===// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// UNSUPPORTED: c++03, c++11, c++14, c++17, c++20 +// UNSUPPORTED: libcpp-has-no-incomplete-format + +// TODO FMT Fix this test using GCC, it currently times out. +// UNSUPPORTED: gcc-12 + +// This test requires the dylib support introduced in D92214. +// XFAIL: use_system_cxx_lib && target={{.+}}-apple-macosx10.{{.+}} +// XFAIL: use_system_cxx_lib && target={{.+}}-apple-macosx11.{{.+}} + +// + +// template +// requires is-vector-bool-reference +// struct formatter + +// string vformat(string_view fmt, format_args args); +// wstring vformat(wstring_view fmt, wformat_args args); + +#include +#include + +#include "format.functions.tests.h" +#include "test_macros.h" +#include "assert_macros.h" + +auto test = []( + std::basic_string_view expected, std::basic_string_view fmt, Args&&... args) { + std::basic_string out = std::vformat(fmt, std::make_format_args>(args...)); + TEST_REQUIRE( + out == expected, + test_concat_message("\nFormat string ", fmt, "\nExpected output ", expected, "\nActual output ", out, '\n')); +}; + +auto test_exception = + []( + [[maybe_unused]] std::string_view what, + [[maybe_unused]] std::basic_string_view fmt, + [[maybe_unused]] Args&&... 
args) { +#ifndef TEST_HAS_NO_EXCEPTIONS + try { + TEST_IGNORE_NODISCARD std::vformat(fmt, std::make_format_args>(args...)); + TEST_FAIL(test_concat_message("\nFormat string ", fmt, "\nDidn't throw an exception.\n")); + } catch (const std::format_error& e) { + TEST_LIBCPP_REQUIRE( + e.what() == what, + test_concat_message( + "\nFormat string ", fmt, "\nExpected exception ", what, "\nActual exception ", e.what(), '\n')); + + return; + } + assert(false); +#endif + }; + +int main(int, char**) { + format_tests(test, test_exception); + +#ifndef TEST_HAS_NO_WIDE_CHARACTERS + format_tests(test, test_exception); +#endif + + return 0; +} diff --git a/libcxx/test/std/containers/sequences/vector.bool/vector.bool.fmt/format.pass.cpp b/libcxx/test/std/containers/sequences/vector.bool/vector.bool.fmt/format.pass.cpp new file mode 100644 index 00000000000000..bb29081dfb898e --- /dev/null +++ b/libcxx/test/std/containers/sequences/vector.bool/vector.bool.fmt/format.pass.cpp @@ -0,0 +1,72 @@ +//===----------------------------------------------------------------------===// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// UNSUPPORTED: c++03, c++11, c++14, c++17, c++20 +// UNSUPPORTED: libcpp-has-no-incomplete-format + +// This test requires the dylib support introduced in D92214. +// XFAIL: use_system_cxx_lib && target={{.+}}-apple-macosx10.{{.+}} +// XFAIL: use_system_cxx_lib && target={{.+}}-apple-macosx11.{{.+}} + +// + +// template +// requires is-vector-bool-reference +// struct formatter + +// template +// typename FormatContext::iterator +// format(const T& r, FormatContext& ctx) const; + +// Note this tests the basics of this function. It's tested in more detail in +// the format functions test. 
+ +#include +#include +#include +#include + +#include "test_format_context.h" +#include "test_macros.h" +#include "make_string.h" + +#define SV(S) MAKE_STRING_VIEW(CharT, S) + +template +void test_format(StringViewT expected, std::vector::reference arg) { + using CharT = typename StringViewT::value_type; + using String = std::basic_string; + using OutIt = std::back_insert_iterator; + using FormatCtxT = std::basic_format_context; + + const std::formatter::reference, CharT> formatter; + + String result; + OutIt out = std::back_inserter(result); + FormatCtxT format_ctx = test_format_context_create(out, std::make_format_args(arg)); + formatter.format(arg, format_ctx); + assert(result == expected); +} + +template +void test_fmt() { + test_format(SV("true"), std::vector{true}[0]); + test_format(SV("false"), std::vector{false}[0]); +} + +void test() { + test_fmt(); +#ifndef TEST_HAS_NO_WIDE_CHARACTERS + test_fmt(); +#endif +} + +int main(int, char**) { + test(); + + return 0; +} diff --git a/libcxx/test/std/containers/sequences/vector.bool/vector.bool.fmt/parse.pass.cpp b/libcxx/test/std/containers/sequences/vector.bool/vector.bool.fmt/parse.pass.cpp new file mode 100644 index 00000000000000..c711e25d1cd065 --- /dev/null +++ b/libcxx/test/std/containers/sequences/vector.bool/vector.bool.fmt/parse.pass.cpp @@ -0,0 +1,73 @@ +//===----------------------------------------------------------------------===// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// UNSUPPORTED: c++03, c++11, c++14, c++17, c++20 +// UNSUPPORTED: libcpp-has-no-incomplete-format + +// This test requires the dylib support introduced in D92214. 
+// XFAIL: use_system_cxx_lib && target={{.+}}-apple-macosx10.{{.+}} +// XFAIL: use_system_cxx_lib && target={{.+}}-apple-macosx11.{{.+}} + +// + +// template +// requires is-vector-bool-reference +// struct formatter + +// template +// constexpr typename ParseContext::iterator +// parse(ParseContext& ctx); + +// Note this tests the basics of this function. It's tested in more detail in +// the format functions test. + +#include +#include +#include +#include + +#include "test_format_context.h" +#include "test_macros.h" +#include "make_string.h" + +#define SV(S) MAKE_STRING_VIEW(CharT, S) + +template +constexpr void test_parse(StringViewT fmt) { + using CharT = typename StringViewT::value_type; + auto parse_ctx = std::basic_format_parse_context(fmt); + std::formatter>::reference, CharT> formatter; + static_assert(std::semiregular); + + std::same_as auto it = formatter.parse(parse_ctx); + assert(it == fmt.end() - (!fmt.empty() && fmt.back() == '}')); +} + +template +constexpr void test_fmt() { + test_parse(SV("")); + test_parse(SV("b")); + + test_parse(SV("}")); + test_parse(SV("b}")); +} + +constexpr bool test() { + test_fmt(); +#ifndef TEST_HAS_NO_WIDE_CHARACTERS + test_fmt(); +#endif + + return true; +} + +int main(int, char**) { + test(); + static_assert(test()); + + return 0; +} diff --git a/libcxx/test/std/utilities/format/format.formattable/concept.formattable.compile.pass.cpp b/libcxx/test/std/utilities/format/format.formattable/concept.formattable.compile.pass.cpp index 06a29502a97245..cef22792f3aa60 100644 --- a/libcxx/test/std/utilities/format/format.formattable/concept.formattable.compile.pass.cpp +++ b/libcxx/test/std/utilities/format/format.formattable/concept.formattable.compile.pass.cpp @@ -38,6 +38,7 @@ #include #include "test_macros.h" +#include "min_allocator.h" #ifndef TEST_HAS_NO_FILESYSTEM_LIBRARY # include @@ -194,6 +195,12 @@ void test_P1636() { assert_is_not_formattable, CharT>(); } +template +void test_P2286_vector_bool() { + 
assert_is_formattable(); + assert_is_formattable(); +} + // Tests for P2286 Formatting ranges template void test_P2286() { @@ -223,6 +230,10 @@ void test_P2286() { assert_is_formattable, CharT>(); assert_is_formattable, CharT>(); + + test_P2286_vector_bool>(); + test_P2286_vector_bool>>(); + test_P2286_vector_bool>>(); } class c { From eef0210706bc16ffde17d5b574799d8aae705790 Mon Sep 17 00:00:00 2001 From: Kiran Chandramohan Date: Thu, 19 Jan 2023 16:49:26 +0000 Subject: [PATCH 28/61] [Flang][Debug] Modifications for getting pathname -> Use file pathname from the Flang frontend. It is the frontend that is in-charge of finding the files and is hence the canonical source for paths. -> Convert pathname to absolute pathname while creating the moduleOp. Co-authored-by: Peter Klausler Reviewed By: PeteSteinfeld, vzakhari, jeanPerier, awarzynski Differential Revision: https://reviews.llvm.org/D141674 --- flang/include/flang/Lower/Bridge.h | 8 ++--- flang/lib/Frontend/FrontendActions.cpp | 7 ++--- flang/lib/Lower/Bridge.cpp | 29 ++++++++++++++++--- .../Lower/module-debug-file-loc-linux.f90 | 14 +++++++++ flang/tools/bbc/bbc.cpp | 2 +- 5 files changed, 46 insertions(+), 14 deletions(-) create mode 100644 flang/test/Lower/module-debug-file-loc-linux.f90 diff --git a/flang/include/flang/Lower/Bridge.h b/flang/include/flang/Lower/Bridge.h index 38109a37396c42..dabbe72fc3764f 100644 --- a/flang/include/flang/Lower/Bridge.h +++ b/flang/include/flang/Lower/Bridge.h @@ -57,11 +57,10 @@ class LoweringBridge { const Fortran::parser::AllCookedSources &allCooked, llvm::StringRef triple, fir::KindMapping &kindMap, const Fortran::lower::LoweringOptions &loweringOptions, - const std::vector &envDefaults, - llvm::StringRef filePath) { + const std::vector &envDefaults) { return LoweringBridge(ctx, semanticsContext, defaultKinds, intrinsics, targetCharacteristics, allCooked, triple, kindMap, - loweringOptions, envDefaults, filePath); + loweringOptions, envDefaults); } 
//===--------------------------------------------------------------------===// @@ -130,8 +129,7 @@ class LoweringBridge { const Fortran::parser::AllCookedSources &cooked, llvm::StringRef triple, fir::KindMapping &kindMap, const Fortran::lower::LoweringOptions &loweringOptions, - const std::vector &envDefaults, - llvm::StringRef filePath); + const std::vector &envDefaults); LoweringBridge() = delete; LoweringBridge(const LoweringBridge &) = delete; diff --git a/flang/lib/Frontend/FrontendActions.cpp b/flang/lib/Frontend/FrontendActions.cpp index 16badc769d90a1..d46d281f94da6b 100644 --- a/flang/lib/Frontend/FrontendActions.cpp +++ b/flang/lib/Frontend/FrontendActions.cpp @@ -171,8 +171,7 @@ bool CodeGenAction::beginSourceFileAction() { ci.getInvocation().getSemanticsContext().targetCharacteristics(), ci.getParsing().allCooked(), ci.getInvocation().getTargetOpts().triple, kindMap, ci.getInvocation().getLoweringOpts(), - ci.getInvocation().getFrontendOpts().envDefaults, - getCurrentFileOrBufferName()); + ci.getInvocation().getFrontendOpts().envDefaults); // Fetch module from lb, so we can set mlirModule = std::make_unique(lb.getModule()); @@ -698,8 +697,8 @@ void CodeGenAction::runOptimizationPipeline(llvm::raw_pwrite_stream &os) { llvm::PassInstrumentationCallbacks pic; llvm::PipelineTuningOptions pto; std::optional pgoOpt; - llvm::StandardInstrumentations si( - llvmModule->getContext(), opts.DebugPassManager); + llvm::StandardInstrumentations si(llvmModule->getContext(), + opts.DebugPassManager); si.registerCallbacks(pic, &fam); llvm::PassBuilder pb(tm.get(), pto, pgoOpt, &pic); diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 9e5898e9db0eef..783f77c10659d6 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -58,6 +58,8 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/Path.h" #include #define 
DEBUG_TYPE "flang-lower-bridge" @@ -3785,8 +3787,7 @@ Fortran::lower::LoweringBridge::LoweringBridge( const Fortran::parser::AllCookedSources &cooked, llvm::StringRef triple, fir::KindMapping &kindMap, const Fortran::lower::LoweringOptions &loweringOptions, - const std::vector &envDefaults, - llvm::StringRef filePath) + const std::vector &envDefaults) : semanticsContext{semanticsContext}, defaultKinds{defaultKinds}, intrinsics{intrinsics}, targetCharacteristics{targetCharacteristics}, cooked{&cooked}, context{context}, kindMap{kindMap}, @@ -3814,10 +3815,30 @@ Fortran::lower::LoweringBridge::LoweringBridge( return mlir::success(); }); + auto getPathLocation = [&semanticsContext, &context]() -> mlir::Location { + std::optional path; + const auto &allSources{semanticsContext.allCookedSources().allSources()}; + if (auto initial{allSources.GetFirstFileProvenance()}; + initial && !initial->empty()) { + if (const auto *sourceFile{allSources.GetSourceFile(initial->start())}) { + path = sourceFile->path(); + } + } + + if (path.has_value()) { + llvm::SmallString<256> curPath(*path); + llvm::sys::fs::make_absolute(curPath); + llvm::sys::path::remove_dots(curPath); + return mlir::FileLineColLoc::get(&context, curPath.str(), /*line=*/0, + /*col=*/0); + } else { + return mlir::UnknownLoc::get(&context); + } + }; + // Create the module and attach the attributes. module = std::make_unique( - mlir::ModuleOp::create(mlir::FileLineColLoc::get( - &getMLIRContext(), filePath, /*line=*/0, /*col=*/0))); + mlir::ModuleOp::create(getPathLocation())); assert(module.get() && "module was not created"); fir::setTargetTriple(*module.get(), triple); fir::setKindMapping(*module.get(), kindMap); diff --git a/flang/test/Lower/module-debug-file-loc-linux.f90 b/flang/test/Lower/module-debug-file-loc-linux.f90 new file mode 100644 index 00000000000000..4f773017096d05 --- /dev/null +++ b/flang/test/Lower/module-debug-file-loc-linux.f90 @@ -0,0 +1,14 @@ +! 
Test that the module has the location information + +! RUN: %flang_fc1 -mmlir --mlir-print-debuginfo -emit-fir -o - %s | FileCheck %s + +! REQUIRES: linux + +subroutine sb1() +end subroutine + +! CHECK: module attributes +! CHECK: func.func @_QPsb1() { +! CHECK: } +! CHECK: } loc(#[[MODULE_LOC:.*]]) +! CHECK: #[[MODULE_LOC]] = loc("/{{.*}}flang/test/Lower/module-debug-file-loc-linux.f90":0:0) diff --git a/flang/tools/bbc/bbc.cpp b/flang/tools/bbc/bbc.cpp index 289c5ad675a3fb..796e7faa288077 100644 --- a/flang/tools/bbc/bbc.cpp +++ b/flang/tools/bbc/bbc.cpp @@ -235,7 +235,7 @@ static mlir::LogicalResult convertFortranSourceToMLIR( auto burnside = Fortran::lower::LoweringBridge::create( ctx, semanticsContext, defKinds, semanticsContext.intrinsics(), semanticsContext.targetCharacteristics(), parsing.allCooked(), "", - kindMap, loweringOptions, {}, inputFilename); + kindMap, loweringOptions, {}); burnside.lower(parseTree, semanticsContext); mlir::ModuleOp mlirModule = burnside.getModule(); std::error_code ec; From 176c853d1da5c344194952659eeb74f063e13010 Mon Sep 17 00:00:00 2001 From: Michael Jones Date: Fri, 6 Jan 2023 15:00:51 -0800 Subject: [PATCH 29/61] [libc] add internal string class The scanf implementation needs a dynamically resizing string class. This patch adds a minimal version of that class along with tests to check the current functionality. 
Reviewed By: sivachandra Differential Revision: https://reviews.llvm.org/D141162 --- libc/src/__support/CMakeLists.txt | 6 ++ libc/src/__support/char_vector.h | 76 ++++++++++++++++++ libc/test/src/__support/CMakeLists.txt | 10 +++ .../src/__support/CPP/stringview_test.cpp | 3 +- libc/test/src/__support/char_vector_test.cpp | 78 +++++++++++++++++++ 5 files changed, 171 insertions(+), 2 deletions(-) create mode 100644 libc/src/__support/char_vector.h create mode 100644 libc/test/src/__support/char_vector_test.cpp diff --git a/libc/src/__support/CMakeLists.txt b/libc/src/__support/CMakeLists.txt index fc80c5389ddfcc..8b702af082fc2f 100644 --- a/libc/src/__support/CMakeLists.txt +++ b/libc/src/__support/CMakeLists.txt @@ -145,6 +145,12 @@ add_header_library( libc.src.__support.CPP.array ) +add_header_library( + char_vector + HDRS + char_vector.h +) + add_header_library( number_pair HDRS diff --git a/libc/src/__support/char_vector.h b/libc/src/__support/char_vector.h new file mode 100644 index 00000000000000..0a9a2d7e4cf148 --- /dev/null +++ b/libc/src/__support/char_vector.h @@ -0,0 +1,76 @@ +//===-- Standalone implementation of a char vector --------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIBC_SRC_SUPPORT_CHARVECTOR_H +#define LLVM_LIBC_SRC_SUPPORT_CHARVECTOR_H + +#include +#include // For allocation. + +namespace __llvm_libc { + +// This is a very simple alternative to the std::string class. There is no +// bounds check performed in any of the methods. The callers are expected to +// do the checks before invoking the methods. +// +// This class will be extended as needed in future.
+ +class CharVector { + static constexpr size_t INIT_BUFF_SIZE = 64; + char local_buffer[INIT_BUFF_SIZE]; + char *cur_str = local_buffer; + size_t cur_buff_size = INIT_BUFF_SIZE; + size_t index = 0; + +public: + CharVector() = default; + ~CharVector() { + if (cur_str != local_buffer) + free(cur_str); + } + + // append returns true on success and false on allocation failure. + bool append(char new_char) { + // Subtract 1 for index starting at 0 and another for the null terminator. + if (index >= cur_buff_size - 2) { + // If the new character would cause the string to be longer than the + // buffer's size, attempt to allocate a new buffer. + cur_buff_size = cur_buff_size * 2; + if (cur_str == local_buffer) { + char *new_str; + new_str = reinterpret_cast(malloc(cur_buff_size)); + if (new_str == NULL) { + return false; + } + // TODO: replace with inline memcpy + for (size_t i = 0; i < index; ++i) + new_str[i] = cur_str[i]; + cur_str = new_str; + } else { + cur_str = reinterpret_cast(realloc(cur_str, cur_buff_size)); + if (cur_str == NULL) { + return false; + } + } + } + cur_str[index] = new_char; + ++index; + return true; + } + + char *c_str() { + cur_str[index] = '\0'; + return cur_str; + } + + size_t length() { return index; } +}; + +} // namespace __llvm_libc + +#endif // LLVM_LIBC_SRC_SUPPORT_CHARVECTOR_H diff --git a/libc/test/src/__support/CMakeLists.txt b/libc/test/src/__support/CMakeLists.txt index b835812f6e1b86..fa35aa45712ff7 100644 --- a/libc/test/src/__support/CMakeLists.txt +++ b/libc/test/src/__support/CMakeLists.txt @@ -84,6 +84,16 @@ add_libc_unittest( libc.src.__support.fixedvector ) +add_libc_unittest( + char_vector_test + SUITE + libc_support_unittests + SRCS + char_vector_test.cpp + DEPENDS + libc.src.__support.char_vector +) + add_executable( libc_str_to_float_comparison_test str_to_float_comparison_test.cpp diff --git a/libc/test/src/__support/CPP/stringview_test.cpp b/libc/test/src/__support/CPP/stringview_test.cpp index 
bd6b614218bb44..240b590e69976f 100644 --- a/libc/test/src/__support/CPP/stringview_test.cpp +++ b/libc/test/src/__support/CPP/stringview_test.cpp @@ -1,5 +1,4 @@ -//===-- Unittests for string_view -//------------------------------------------===// +//===-- Unittests for string_view -----------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/libc/test/src/__support/char_vector_test.cpp b/libc/test/src/__support/char_vector_test.cpp new file mode 100644 index 00000000000000..15e269bed18574 --- /dev/null +++ b/libc/test/src/__support/char_vector_test.cpp @@ -0,0 +1,78 @@ +//===-- Unittests for char_vector ---------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "src/__support/char_vector.h" +#include "utils/UnitTest/Test.h" + +using __llvm_libc::CharVector; + +TEST(LlvmLibcCharVectorTest, InitializeCheck) { + CharVector v; + ASSERT_EQ(v.length(), size_t(0)); +} + +TEST(LlvmLibcCharVectorTest, AppendShort) { + CharVector v; + ASSERT_EQ(v.length(), size_t(0)); + + constexpr char test_str[] = "1234567890"; + for (size_t i = 0; test_str[i] != '\0'; ++i) { + v.append(test_str[i]); + } + ASSERT_STREQ(v.c_str(), test_str); +} + +TEST(LlvmLibcCharVectorTest, AppendMedium) { + CharVector v; + ASSERT_EQ(v.length(), size_t(0)); + + // 100 characters (each row is 50) + constexpr char test_str[] = + "12345678901234567890123456789012345678901234567890" + "ABCDEFGHIJKLMNOPQRSTUVWXYabcdefghijklmnopqrstuvwxy"; + for (size_t i = 0; test_str[i] != '\0'; ++i) { + ASSERT_EQ(v.length(), i); + v.append(test_str[i]); + } + ASSERT_STREQ(v.c_str(), test_str); + 
ASSERT_EQ(v.length(), size_t(100)); +} + +TEST(LlvmLibcCharVectorTest, AppendLong) { + CharVector v; + ASSERT_EQ(v.length(), size_t(0)); + + // 1000 characters + constexpr char test_str[] = + "12345678901234567890123456789012345678901234567890" + "ABCDEFGHIJKLMNOPQRSTUVWXYabcdefghijklmnopqrstuvwxy" + "12345678901234567890123456789012345678901234567890" + "ABCDEFGHIJKLMNOPQRSTUVWXYabcdefghijklmnopqrstuvwxy" + "12345678901234567890123456789012345678901234567890" + "ABCDEFGHIJKLMNOPQRSTUVWXYabcdefghijklmnopqrstuvwxy" + "12345678901234567890123456789012345678901234567890" + "ABCDEFGHIJKLMNOPQRSTUVWXYabcdefghijklmnopqrstuvwxy" + "12345678901234567890123456789012345678901234567890" + "ABCDEFGHIJKLMNOPQRSTUVWXYabcdefghijklmnopqrstuvwxy" + "12345678901234567890123456789012345678901234567890" + "ABCDEFGHIJKLMNOPQRSTUVWXYabcdefghijklmnopqrstuvwxy" + "12345678901234567890123456789012345678901234567890" + "ABCDEFGHIJKLMNOPQRSTUVWXYabcdefghijklmnopqrstuvwxy" + "12345678901234567890123456789012345678901234567890" + "ABCDEFGHIJKLMNOPQRSTUVWXYabcdefghijklmnopqrstuvwxy" + "12345678901234567890123456789012345678901234567890" + "ABCDEFGHIJKLMNOPQRSTUVWXYabcdefghijklmnopqrstuvwxy" + "12345678901234567890123456789012345678901234567890" + "ABCDEFGHIJKLMNOPQRSTUVWXYabcdefghijklmnopqrstuvwxy"; + for (size_t i = 0; test_str[i] != '\0'; ++i) { + ASSERT_EQ(v.length(), i); + v.append(test_str[i]); + } + ASSERT_STREQ(v.c_str(), test_str); + ASSERT_EQ(v.length(), size_t(1000)); +} From 9a32e53941ccb6ccd55d5ecd305ecd16b90bfd58 Mon Sep 17 00:00:00 2001 From: Michael Jones Date: Wed, 21 Dec 2022 14:29:43 -0800 Subject: [PATCH 30/61] [libc] add scanf float converter This patch adds the %f/F/e/E/g/G/a/A conversions for scanf, as well as accompanying tests. This implementation matches the definition set forth in the standard, which may conflict with some other implementations. 
Reviewed By: sivachandra Differential Revision: https://reviews.llvm.org/D141091 --- libc/src/stdio/scanf_core/CMakeLists.txt | 6 + libc/src/stdio/scanf_core/converter.cpp | 33 +- libc/src/stdio/scanf_core/core_structs.h | 1 + libc/src/stdio/scanf_core/float_converter.cpp | 255 ++++++++++++ libc/src/stdio/scanf_core/float_converter.h | 25 ++ libc/src/stdio/scanf_core/int_converter.cpp | 2 +- libc/test/src/stdio/CMakeLists.txt | 2 + libc/test/src/stdio/sscanf_test.cpp | 363 ++++++++++++++++++ 8 files changed, 670 insertions(+), 17 deletions(-) create mode 100644 libc/src/stdio/scanf_core/float_converter.cpp create mode 100644 libc/src/stdio/scanf_core/float_converter.h diff --git a/libc/src/stdio/scanf_core/CMakeLists.txt b/libc/src/stdio/scanf_core/CMakeLists.txt index ab0530d6d38615..54fa40d5363b4e 100644 --- a/libc/src/stdio/scanf_core/CMakeLists.txt +++ b/libc/src/stdio/scanf_core/CMakeLists.txt @@ -79,9 +79,12 @@ add_object_library( converter.cpp string_converter.cpp int_converter.cpp + float_converter.cpp HDRS converter.h int_converter.h + string_converter.h + float_converter.h DEPENDS .reader .core_structs @@ -89,6 +92,9 @@ add_object_library( libc.src.__support.CPP.bitset libc.src.__support.CPP.string_view libc.src.__support.CPP.limits + libc.src.__support.char_vector + libc.include.errno + libc.src.errno.errno ) add_object_library( diff --git a/libc/src/stdio/scanf_core/converter.cpp b/libc/src/stdio/scanf_core/converter.cpp index fbb2c1bcd5ee58..c5a2932fb18f8f 100644 --- a/libc/src/stdio/scanf_core/converter.cpp +++ b/libc/src/stdio/scanf_core/converter.cpp @@ -12,6 +12,9 @@ #include "src/stdio/scanf_core/core_structs.h" #include "src/stdio/scanf_core/reader.h" +#ifndef LLVM_LIBC_SCANF_DISABLE_FLOAT +#include "src/stdio/scanf_core/float_converter.h" +#endif // LLVM_LIBC_SCANF_DISABLE_FLOAT #include "src/stdio/scanf_core/int_converter.h" #include "src/stdio/scanf_core/string_converter.h" @@ -43,24 +46,22 @@ int convert(Reader *reader, const FormatSection 
&to_conv) { if (ret_val != READ_OK) return ret_val; return convert_int(reader, to_conv); - // #ifndef LLVM_LIBC_SCANF_DISABLE_FLOAT - // case 'f': - // case 'F': - // case 'e': - // case 'E': - // case 'a': - // case 'A': - // case 'g': - // case 'G': - // ret_val = raw_match(reader, " "); - // if (ret_val != READ_OK) - // return ret_val; - // return convert_float(reader, to_conv); - // #endif // LLVM_LIBC_SCANF_DISABLE_FLOAT - // #ifndef LLVM_LIBC_SCANF_DISABLE_WRITE_INT +#ifndef LLVM_LIBC_SCANF_DISABLE_FLOAT + case 'f': + case 'F': + case 'e': + case 'E': + case 'a': + case 'A': + case 'g': + case 'G': + ret_val = raw_match(reader, " "); + if (ret_val != READ_OK) + return ret_val; + return convert_float(reader, to_conv); +#endif // LLVM_LIBC_SCANF_DISABLE_FLOAT // case 'n': // return convert_write_int(reader, to_conv); - // #endif // LLVM_LIBC_SCANF_DISABLE_WRITE_INT // case 'p': // ret_val = raw_match(reader, " "); // if (ret_val != READ_OK) diff --git a/libc/src/stdio/scanf_core/core_structs.h b/libc/src/stdio/scanf_core/core_structs.h index 7f331db3620236..4555595ab3bff6 100644 --- a/libc/src/stdio/scanf_core/core_structs.h +++ b/libc/src/stdio/scanf_core/core_structs.h @@ -84,6 +84,7 @@ enum ErrorCodes : int { FILE_READ_ERROR = -1, FILE_STATUS_ERROR = -2, MATCHING_FAILURE = -3, + ALLOCATION_FAILURE = -4, }; } // namespace scanf_core } // namespace __llvm_libc diff --git a/libc/src/stdio/scanf_core/float_converter.cpp b/libc/src/stdio/scanf_core/float_converter.cpp new file mode 100644 index 00000000000000..1a0ce42863f350 --- /dev/null +++ b/libc/src/stdio/scanf_core/float_converter.cpp @@ -0,0 +1,255 @@ +//===-- Float type specifier converters for scanf ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "src/stdio/scanf_core/float_converter.h" + +#include "src/__support/CPP/limits.h" +#include "src/__support/char_vector.h" +#include "src/__support/ctype_utils.h" +#include "src/__support/str_to_float.h" +#include "src/stdio/scanf_core/core_structs.h" +#include "src/stdio/scanf_core/reader.h" + +#include + +namespace __llvm_libc { +namespace scanf_core { + +constexpr char inline to_lower(char a) { return a | 32; } + +void write_with_length(char *str, const FormatSection &to_conv) { + if ((to_conv.flags & NO_WRITE) != 0) { + return; + } + + void *output_ptr = to_conv.output_ptr; + + LengthModifier lm = to_conv.length_modifier; + switch (lm) { + case (LengthModifier::l): { + auto value = internal::strtofloatingpoint(str, nullptr); + *reinterpret_cast(output_ptr) = value; + break; + } + case (LengthModifier::L): { + auto value = internal::strtofloatingpoint(str, nullptr); + *reinterpret_cast(output_ptr) = value; + break; + } + default: { + auto value = internal::strtofloatingpoint(str, nullptr); + *reinterpret_cast(output_ptr) = value; + break; + } + } +} + +// All of the floating point conversions are the same for scanf, every name will +// accept every style. +int convert_float(Reader *reader, const FormatSection &to_conv) { + // %a/A/e/E/f/F/g/G "Matches an optionally signed floating-point number, + // infinity, or NaN, whose format is the same as expected for the subject + // sequence of the strtod function. The corresponding argument shall be a + // pointer to floating." + + CharVector out_str = CharVector(); + bool is_number = false; + + size_t max_width = cpp::numeric_limits::max(); + if (to_conv.max_width > 0) { + max_width = to_conv.max_width; + } + + char cur_char = reader->getc(); + // Handle the sign. 
+ if (cur_char == '+' || cur_char == '-') { + if (!out_str.append(cur_char)) { + return ALLOCATION_FAILURE; + } + if (out_str.length() == max_width) { + return MATCHING_FAILURE; + } else { + cur_char = reader->getc(); + } + } + + static constexpr char DECIMAL_POINT = '.'; + static const char inf_string[] = "infinity"; + + // Handle inf + + if (to_lower(cur_char) == inf_string[0]) { + size_t inf_index = 0; + + for (; to_lower(cur_char) == inf_string[inf_index] && + inf_index < sizeof(inf_string) && out_str.length() < max_width; + ++inf_index) { + if (!out_str.append(cur_char)) { + return ALLOCATION_FAILURE; + } + cur_char = reader->getc(); + } + + if (inf_index == 3 || inf_index == sizeof(inf_string) - 1) { + write_with_length(out_str.c_str(), to_conv); + return READ_OK; + } else { + return MATCHING_FAILURE; + } + } + + static const char nan_string[] = "nan"; + + // Handle nan + if (to_lower(cur_char) == nan_string[0]) { + size_t nan_index = 0; + + for (; to_lower(cur_char) == nan_string[nan_index] && + nan_index < sizeof(nan_string) && out_str.length() < max_width; + ++nan_index) { + if (!out_str.append(cur_char)) { + return ALLOCATION_FAILURE; + } + cur_char = reader->getc(); + } + + if (nan_index == sizeof(nan_string) - 1) { + write_with_length(out_str.c_str(), to_conv); + return READ_OK; + } else { + return MATCHING_FAILURE; + } + } + + // Assume base of 10 by default but check if it is actually base 16. + int base = 10; + + // If the string starts with 0 it might be in hex. + if (cur_char == '0') { + is_number = true; + // Read the next character to check. + if (!out_str.append(cur_char)) { + return ALLOCATION_FAILURE; + } + // If we've hit the end, then this is "0", which is valid. + if (out_str.length() == max_width) { + write_with_length(out_str.c_str(), to_conv); + return READ_OK; + } else { + cur_char = reader->getc(); + } + + // If that next character is an 'x' then this is a hexadecimal number. 
+ if (to_lower(cur_char) == 'x') { + base = 16; + + if (!out_str.append(cur_char)) { + return ALLOCATION_FAILURE; + } + // If we've hit the end here, we have "0x" which is a valid prefix to a + // floating point number, and will be evaluated to 0. + if (out_str.length() == max_width) { + write_with_length(out_str.c_str(), to_conv); + return READ_OK; + } else { + cur_char = reader->getc(); + } + } + } + + const char exponent_mark = ((base == 10) ? 'e' : 'p'); + bool after_decimal = false; + + // The format for the remaining characters at this point is DD.DDe+/-DD for + // base 10 and XX.XXp+/-DD for base 16 + + // This handles the digits before and after the decimal point, but not the + // exponent. + while (out_str.length() < max_width) { + if (internal::isalnum(cur_char) && + internal::b36_char_to_int(cur_char) < base) { + is_number = true; + if (!out_str.append(cur_char)) { + return ALLOCATION_FAILURE; + } + cur_char = reader->getc(); + } else if (cur_char == DECIMAL_POINT && !after_decimal) { + after_decimal = true; + if (!out_str.append(cur_char)) { + return ALLOCATION_FAILURE; + } + cur_char = reader->getc(); + } else { + break; + } + } + + // Handle the exponent, which has an exponent mark, an optional sign, and + // decimal digits. + if (to_lower(cur_char) == exponent_mark) { + if (!out_str.append(cur_char)) { + return ALLOCATION_FAILURE; + } + if (out_str.length() == max_width) { + // This is laid out in the standard as being a matching error (100e is not + // a valid float) but may conflict with existing implementations. 
+ return MATCHING_FAILURE; + } else { + cur_char = reader->getc(); + } + + if (cur_char == '+' || cur_char == '-') { + if (!out_str.append(cur_char)) { + return ALLOCATION_FAILURE; + } + if (out_str.length() == max_width) { + return MATCHING_FAILURE; + } else { + cur_char = reader->getc(); + } + } + + // It is specified by the standard that "100er" is a matching failure since + // the longest prefix of a possibly valid floating-point number (which is + // "100e") is not a valid floating-point number. If there is an exponent + // mark then there must be a digit after it else the number is not valid. + // Some implementations will roll back two characters (to just "100") and + // accept that since the prefix is not valid, and some will interpret an + // exponent mark followed by no digits as an additional exponent of 0 + // (accepting "100e" and returning 100.0). Both of these behaviors are wrong + // by the standard, but they may be used in real code, see Hyrum's law. This + // code follows the standard, but may be incompatible due to code expecting + // these bugs. + if (!internal::isdigit(cur_char)) { + return MATCHING_FAILURE; + } + + while (internal::isdigit(cur_char) && out_str.length() < max_width) { + if (!out_str.append(cur_char)) { + return ALLOCATION_FAILURE; + } + cur_char = reader->getc(); + } + } + + // We always read one more character than will be used, so we have to put the + // last one back. 
+ reader->ungetc(cur_char); + + // If we haven't actually found any digits, this is a matching failure (this + // catches cases like "+.") + if (!is_number) { + return MATCHING_FAILURE; + } + write_with_length(out_str.c_str(), to_conv); + + return READ_OK; +} + +} // namespace scanf_core +} // namespace __llvm_libc diff --git a/libc/src/stdio/scanf_core/float_converter.h b/libc/src/stdio/scanf_core/float_converter.h new file mode 100644 index 00000000000000..e8abe8d20bcea5 --- /dev/null +++ b/libc/src/stdio/scanf_core/float_converter.h @@ -0,0 +1,25 @@ +//===-- Float type specifier converter for scanf ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIBC_SRC_STDIO_SCANF_CORE_FLOAT_CONVERTER_H +#define LLVM_LIBC_SRC_STDIO_SCANF_CORE_FLOAT_CONVERTER_H + +#include "src/stdio/scanf_core/core_structs.h" +#include "src/stdio/scanf_core/reader.h" + +#include + +namespace __llvm_libc { +namespace scanf_core { + +int convert_float(Reader *reader, const FormatSection &to_conv); + +} // namespace scanf_core +} // namespace __llvm_libc + +#endif // LLVM_LIBC_SRC_STDIO_SCANF_CORE_FLOAT_CONVERTER_H diff --git a/libc/src/stdio/scanf_core/int_converter.cpp b/libc/src/stdio/scanf_core/int_converter.cpp index 1087166742b6cf..be88a01f942d71 100644 --- a/libc/src/stdio/scanf_core/int_converter.cpp +++ b/libc/src/stdio/scanf_core/int_converter.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "src/stdio/scanf_core/string_converter.h" +#include "src/stdio/scanf_core/int_converter.h" #include "src/__support/CPP/limits.h" #include "src/__support/ctype_utils.h" diff --git a/libc/test/src/stdio/CMakeLists.txt 
b/libc/test/src/stdio/CMakeLists.txt index f74fa46a5e77d1..e94432d069f2e2 100644 --- a/libc/test/src/stdio/CMakeLists.txt +++ b/libc/test/src/stdio/CMakeLists.txt @@ -182,6 +182,8 @@ add_libc_unittest( sscanf_test.cpp DEPENDS libc.src.stdio.sscanf + LINK_LIBRARIES + LibcFPTestHelpers ) add_libc_unittest( diff --git a/libc/test/src/stdio/sscanf_test.cpp b/libc/test/src/stdio/sscanf_test.cpp index b3e146f1d64be8..fc67593f57f0b7 100644 --- a/libc/test/src/stdio/sscanf_test.cpp +++ b/libc/test/src/stdio/sscanf_test.cpp @@ -7,10 +7,14 @@ //===----------------------------------------------------------------------===// #include "src/__support/CPP/limits.h" +#include "src/__support/FPUtil/FPBits.h" +#include "src/__support/FPUtil/PlatformDefs.h" + #include "src/stdio/sscanf.h" #include // For EOF +#include "utils/UnitTest/FPMatcher.h" #include "utils/UnitTest/Test.h" TEST(LlvmLibcSScanfTest, SimpleStringConv) { @@ -209,6 +213,365 @@ TEST(LlvmLibcSScanfTest, IntConvNoWriteTests) { EXPECT_EQ(result, 0); } +TEST(LlvmLibcSScanfTest, FloatConvSimple) { + int ret_val; + float result = 0; + + float inf = __llvm_libc::fputil::FPBits::inf().get_val(); + float nan = __llvm_libc::fputil::FPBits::build_nan(1); + + ret_val = __llvm_libc::sscanf("123", "%f", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 123.0); + + ret_val = __llvm_libc::sscanf("456.1", "%a", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 456.1); + + ret_val = __llvm_libc::sscanf("0x789.ap0", "%e", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 0x789.ap0); + + ret_val = __llvm_libc::sscanf("0x.8", "%e", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 0x0.8p0); + + ret_val = __llvm_libc::sscanf("0x8.", "%e", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 0x8.0p0); + + ret_val = __llvm_libc::sscanf("+12.0e1", "%g", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 12.0e1); + + ret_val = __llvm_libc::sscanf("inf", "%F", &result); + EXPECT_EQ(ret_val, 1); + 
EXPECT_FP_EQ(result, inf); + + ret_val = __llvm_libc::sscanf("NaN", "%A", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, nan); + + ret_val = __llvm_libc::sscanf("-InFiNiTy", "%E", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, -inf); + + ret_val = __llvm_libc::sscanf("1e10", "%G", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 1e10); + + ret_val = __llvm_libc::sscanf(".1", "%G", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 0.1); + + ret_val = __llvm_libc::sscanf("1.", "%G", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 1.0); + + ret_val = __llvm_libc::sscanf("0", "%f", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 0.0); + + ret_val = __llvm_libc::sscanf("Not a float", "%f", &result); + EXPECT_EQ(ret_val, 0); +} + +TEST(LlvmLibcSScanfTest, FloatConvLengthModifier) { + int ret_val; + double d_result = 0; + long double ld_result = 0; + + double d_inf = __llvm_libc::fputil::FPBits::inf().get_val(); + long double ld_nan = __llvm_libc::fputil::FPBits::build_nan(1); + + ret_val = __llvm_libc::sscanf("123", "%lf", &d_result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(d_result, 123.0); + + ret_val = __llvm_libc::sscanf("456.1", "%La", &ld_result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(ld_result, 456.1L); + + ret_val = __llvm_libc::sscanf("inf", "%le", &d_result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(d_result, d_inf); + + ret_val = __llvm_libc::sscanf("nan", "%Lg", &ld_result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(ld_result, ld_nan); + + ret_val = __llvm_libc::sscanf("1e-300", "%lF", &d_result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(d_result, 1e-300); + + ret_val = __llvm_libc::sscanf("1.0e600", "%LA", &ld_result); + EXPECT_EQ(ret_val, 1); +// 1e600 may be larger than the maximum long double (if long double is double). +// In that case both of these should be evaluated as inf. 
+#ifdef LONG_DOUBLE_IS_DOUBLE + EXPECT_FP_EQ(ld_result, d_inf); +#else + EXPECT_FP_EQ(ld_result, 1.0e600L); +#endif +} + +TEST(LlvmLibcSScanfTest, FloatConvLongNumber) { + int ret_val; + float result = 0; + double d_result = 0; + + // 32 characters + ret_val = + __llvm_libc::sscanf("123456789012345678901234567890.0", "%f", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 123456789012345678901234567890.0f); + + // 64 characters + ret_val = __llvm_libc::sscanf( + "123456789012345678901234567890123456789012345678901234567890.000", "%la", + &d_result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ( + d_result, + 123456789012345678901234567890123456789012345678901234567890.000); + + // 128 characters + ret_val = __llvm_libc::sscanf( + "123456789012345678901234567890123456789012345678901234567890" + "123456789012345678901234567890123456789012345678901234567890.0000000", + "%le", &d_result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ( + d_result, + 123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890.0000000); + + // 256 characters + ret_val = __llvm_libc::sscanf("10000000000000000000000000000000" + "00000000000000000000000000000000" + "00000000000000000000000000000000" + "00000000000000000000000000000000" + "00000000000000000000000000000000" + "00000000000000000000000000000000" + "00000000000000000000000000000000" + "00000000000000000000000000000000", + "%lf", &d_result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(d_result, 1e255); + + // 288 characters + ret_val = __llvm_libc::sscanf("10000000000000000000000000000000" + "00000000000000000000000000000000" + "00000000000000000000000000000000" + "00000000000000000000000000000000" + "00000000000000000000000000000000" + "00000000000000000000000000000000" + "00000000000000000000000000000000" + "00000000000000000000000000000000" + "00000000000000000000000000000000", + "%lf", &d_result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(d_result, 1e287); +} + 
+TEST(LlvmLibcSScanfTest, FloatConvComplexParsing) { + int ret_val; + float result = 0; + + float inf = __llvm_libc::fputil::FPBits::inf().get_val(); + float nan = __llvm_libc::fputil::FPBits::build_nan(1); + + ret_val = __llvm_libc::sscanf("0x1.0e3", "%f", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 0x1.0e3p0); + + ret_val = __llvm_libc::sscanf("", "%a", &result); + EXPECT_EQ(ret_val, 0); + + ret_val = __llvm_libc::sscanf("+", "%a", &result); + EXPECT_EQ(ret_val, 0); + + ret_val = __llvm_libc::sscanf("-", "%a", &result); + EXPECT_EQ(ret_val, 0); + + ret_val = __llvm_libc::sscanf("+.", "%a", &result); + EXPECT_EQ(ret_val, 0); + + ret_val = __llvm_libc::sscanf("-.e+10", "%a", &result); + EXPECT_EQ(ret_val, 0); + + // This is a specific example from the standard. Its behavior diverges from + // other implementations that accept "100e" as being the same as "100e0" + ret_val = __llvm_libc::sscanf("100er", "%a", &result); + EXPECT_EQ(ret_val, 0); + + ret_val = __llvm_libc::sscanf("nah", "%a", &result); + EXPECT_EQ(ret_val, 0); + + ret_val = __llvm_libc::sscanf("indirection", "%a", &result); + EXPECT_EQ(ret_val, 0); + + ret_val = __llvm_libc::sscanf("infnan", "%a", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, inf); + + ret_val = __llvm_libc::sscanf("naninf", "%a", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, nan); + + ret_val = __llvm_libc::sscanf("infinityinfinity", "%a", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, inf); + + // For %f to accept a string as representing it has to be either "inf" or + // "infinity" when it stops. It only stops when it encounters a character that + // isn't the next one in the string, so it accepts "infi" as the the longest + // prefix of a possibly valid floating-point number, but determines that it is + // not valid and returns a matching failure. 
This is because it can only unget + // one character so when it finds that the character after the second 'i' is + // not the next character in "infinity" it can't rewind to the point where it + // had just "inf". + ret_val = __llvm_libc::sscanf("infi", "%a", &result); + EXPECT_EQ(ret_val, 0); + + ret_val = __llvm_libc::sscanf("infinite", "%a", &result); + EXPECT_EQ(ret_val, 0); + + ret_val = __llvm_libc::sscanf("-.1e1", "%f", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, -.1e1); + + ret_val = __llvm_libc::sscanf("1.2.e1", "%f", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 1.2); +} + +/* +TODO: + Max width tests +*/ + +TEST(LlvmLibcSScanfTest, FloatConvMaxWidth) { + int ret_val; + float result = 0; + + float inf = __llvm_libc::fputil::FPBits::inf().get_val(); + + ret_val = __llvm_libc::sscanf("123", "%3f", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 123.0); + + ret_val = __llvm_libc::sscanf("123", "%5f", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 123.0); + + ret_val = __llvm_libc::sscanf("456", "%1f", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 4.0); + + ret_val = __llvm_libc::sscanf("-789", "%1f", &result); + EXPECT_EQ(ret_val, 0); + + ret_val = __llvm_libc::sscanf("-123", "%2f", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, -1.0); + + ret_val = __llvm_libc::sscanf("inf", "%2f", &result); + EXPECT_EQ(ret_val, 0); + + ret_val = __llvm_libc::sscanf("nan", "%1f", &result); + EXPECT_EQ(ret_val, 0); + + ret_val = __llvm_libc::sscanf("-inf", "%3f", &result); + EXPECT_EQ(ret_val, 0); + + ret_val = __llvm_libc::sscanf("-nan", "%3f", &result); + EXPECT_EQ(ret_val, 0); + + // If the max length were not here this would fail as discussed above, but + // since the max length limits it to the 3 it succeeds. 
+ ret_val = __llvm_libc::sscanf("infinite", "%3f", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, inf); + + ret_val = __llvm_libc::sscanf("-infinite", "%4f", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, -inf); + + ret_val = __llvm_libc::sscanf("01", "%1f", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 0.0); + + ret_val = __llvm_libc::sscanf("0x1", "%2f", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 0.0); + + ret_val = __llvm_libc::sscanf("100e", "%4f", &result); + EXPECT_EQ(ret_val, 0); + + ret_val = __llvm_libc::sscanf("100e+10", "%5f", &result); + EXPECT_EQ(ret_val, 0); + + ret_val = __llvm_libc::sscanf("100e10", "%5f", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 100e1); +} + +TEST(LlvmLibcSScanfTest, FloatConvNoWrite) { + int ret_val; + float result = 0; + + ret_val = __llvm_libc::sscanf("123", "%*f", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 0.0); + + ret_val = __llvm_libc::sscanf("456.1", "%*a", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 0.0); + + ret_val = __llvm_libc::sscanf("0x789.ap0", "%*e", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 0.0); + + ret_val = __llvm_libc::sscanf("+12.0e1", "%*g", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 0.0); + + ret_val = __llvm_libc::sscanf("inf", "%*F", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 0.0); + + ret_val = __llvm_libc::sscanf("NaN", "%*A", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 0.0); + + ret_val = __llvm_libc::sscanf("-InFiNiTy", "%*E", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 0.0); + + ret_val = __llvm_libc::sscanf("1e10", "%*G", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 0.0); + + ret_val = __llvm_libc::sscanf(".1", "%*G", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 0.0); + + ret_val = __llvm_libc::sscanf("123", "%*3f", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 0.0); + + ret_val = 
__llvm_libc::sscanf("123", "%*5f", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 0.0); + + ret_val = __llvm_libc::sscanf("456", "%*1f", &result); + EXPECT_EQ(ret_val, 1); + EXPECT_FP_EQ(result, 0.0); + + ret_val = __llvm_libc::sscanf("Not a float", "%*f", &result); + EXPECT_EQ(ret_val, 0); +} + TEST(LlvmLibcSScanfTest, CombinedConv) { int ret_val; int result = 0; From bc94a9b2eceba7effa032a5ce48419da2e69c8e9 Mon Sep 17 00:00:00 2001 From: Aaron Ballman Date: Thu, 19 Jan 2023 12:25:14 -0500 Subject: [PATCH 31/61] Silence some sign comparison warnings; NFC The std::optional implementation in MSVC causes this code to produce a sign comparison warning. This ensures the types are the same sign. --- clang/lib/Basic/Targets/AArch64.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/clang/lib/Basic/Targets/AArch64.cpp b/clang/lib/Basic/Targets/AArch64.cpp index 3a6e8f151a2d73..015deeb84524ef 100644 --- a/clang/lib/Basic/Targets/AArch64.cpp +++ b/clang/lib/Basic/Targets/AArch64.cpp @@ -60,14 +60,14 @@ void AArch64TargetInfo::setArchFeatures() { HasLSE = true; HasRDM = true; } else if (ArchInfo->Version.getMajor() == 8) { - if (ArchInfo->Version.getMinor() >= 7) { + if (ArchInfo->Version.getMinor() >= 7u) { HasWFxT = true; } - if (ArchInfo->Version.getMinor() >= 6) { + if (ArchInfo->Version.getMinor() >= 6u) { HasBFloat16 = true; HasMatMul = true; } - if (ArchInfo->Version.getMinor() >= 5) { + if (ArchInfo->Version.getMinor() >= 5u) { HasAlternativeNZCV = true; HasFRInt3264 = true; HasSSBS = true; @@ -75,28 +75,28 @@ void AArch64TargetInfo::setArchFeatures() { HasPredRes = true; HasBTI = true; } - if (ArchInfo->Version.getMinor() >= 4) { + if (ArchInfo->Version.getMinor() >= 4u) { HasDotProd = true; HasDIT = true; HasFlagM = true; } - if (ArchInfo->Version.getMinor() >= 3) { + if (ArchInfo->Version.getMinor() >= 3u) { HasRCPC = true; FPU |= NeonMode; } - if (ArchInfo->Version.getMinor() >= 2) { + if 
(ArchInfo->Version.getMinor() >= 2u) { HasCCPP = true; } - if (ArchInfo->Version.getMinor() >= 1) { + if (ArchInfo->Version.getMinor() >= 1u) { HasCRC = true; HasLSE = true; HasRDM = true; } } else if (ArchInfo->Version.getMajor() == 9) { - if (ArchInfo->Version.getMinor() >= 2) { + if (ArchInfo->Version.getMinor() >= 2u) { HasWFxT = true; } - if (ArchInfo->Version.getMinor() >= 1) { + if (ArchInfo->Version.getMinor() >= 1u) { HasBFloat16 = true; HasMatMul = true; } From b14763652f8519e9ec6b10667a17e32bc3dce913 Mon Sep 17 00:00:00 2001 From: Slava Zakharin Date: Thu, 19 Jan 2023 08:46:52 -0800 Subject: [PATCH 32/61] [NFC][flang] Added debug option to bisect TBAA tag attachments. Reviewed By: jeanPerier, PeteSteinfeld Differential Revision: https://reviews.llvm.org/D142070 --- flang/lib/Optimizer/CodeGen/TBAABuilder.cpp | 23 +++++++++++++++++++++ flang/lib/Optimizer/CodeGen/TBAABuilder.h | 3 +++ 2 files changed, 26 insertions(+) diff --git a/flang/lib/Optimizer/CodeGen/TBAABuilder.cpp b/flang/lib/Optimizer/CodeGen/TBAABuilder.cpp index 1bee674f748fe6..2d206ed2dcbf53 100644 --- a/flang/lib/Optimizer/CodeGen/TBAABuilder.cpp +++ b/flang/lib/Optimizer/CodeGen/TBAABuilder.cpp @@ -13,6 +13,9 @@ #include "TBAABuilder.h" #include "flang/Optimizer/Dialect/FIRType.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "flang-tbaa-builder" using namespace mlir; using namespace mlir::LLVM; @@ -23,6 +26,15 @@ static llvm::cl::opt disableTBAA( "to override default Flang behavior"), llvm::cl::init(false)); +// tagAttachmentLimit is a debugging option that allows limiting +// the number of TBAA access tag attributes attached to operations. +// It is set to kTagAttachmentUnlimited by default denoting "no limit". 
+static constexpr unsigned kTagAttachmentUnlimited = + std::numeric_limits::max(); +static llvm::cl::opt + tagAttachmentLimit("tbaa-attach-tag-max", llvm::cl::desc(""), + llvm::cl::init(kTagAttachmentUnlimited)); + namespace fir { std::string TBAABuilder::getNewTBAANodeName(llvm::StringRef basename) { return (llvm::Twine(basename) + llvm::Twine('_') + @@ -50,6 +62,9 @@ TBAABuilder::TBAABuilder(mlir::ModuleOp module, bool applyTBAA) return; } + LLVM_DEBUG(llvm::dbgs() << "Creating TBAA MetadataOp for module '" + << module.getName().value_or("") << "'\n"); + // Create TBAA MetadataOp with the root and basic type descriptors. Location loc = module.getLoc(); MLIRContext *context = module.getContext(); @@ -130,6 +145,14 @@ void TBAABuilder::attachTBAATag(Operation *op, Type baseFIRType, if (!enableTBAA) return; + ++tagAttachmentCounter; + if (tagAttachmentLimit != kTagAttachmentUnlimited && + tagAttachmentCounter > tagAttachmentLimit) + return; + + LLVM_DEBUG(llvm::dbgs() << "Attaching TBAA tag #" << tagAttachmentCounter + << "\n"); + SymbolRefAttr tbaaTagSym; if (baseFIRType.isa()) tbaaTagSym = getBoxAccessTag(baseFIRType, accessFIRType, gep); diff --git a/flang/lib/Optimizer/CodeGen/TBAABuilder.h b/flang/lib/Optimizer/CodeGen/TBAABuilder.h index dff924bb2a2024..f0dfc0ea48b92c 100644 --- a/flang/lib/Optimizer/CodeGen/TBAABuilder.h +++ b/flang/lib/Optimizer/CodeGen/TBAABuilder.h @@ -251,6 +251,9 @@ class TBAABuilder { // Counter for unique naming of TBAA operations' symbols. unsigned tbaaNodeCounter = 0; + // Number of attached TBAA tags (used for debugging). + unsigned tagAttachmentCounter = 0; + // Mapping from a FIR type to the symbol defined by the corresponding // TBAATypeDescriptorOp. It must be populated during the type conversion. // Currently unused. 
From 2f8b920f95aa1e308193cf5803df7912025e8400 Mon Sep 17 00:00:00 2001 From: Paul Robinson Date: Thu, 19 Jan 2023 09:26:33 -0800 Subject: [PATCH 33/61] Revert "[lit] Stop supporting triple substrings in UNSUPPORTED and XFAIL" This reverts commit a0f8bdbb18a82ec150515d24f6eefb3519d4459a. Several bots are failing in shtest-format.py, likely because of this. --- llvm/docs/ReleaseNotes.rst | 7 ------- llvm/utils/lit/lit/BooleanExpression.py | 24 ++++++++++++++++++++---- llvm/utils/lit/lit/Test.py | 24 ++++++++++++++---------- 3 files changed, 34 insertions(+), 21 deletions(-) diff --git a/llvm/docs/ReleaseNotes.rst b/llvm/docs/ReleaseNotes.rst index 783e8f5617fe07..6dc2dad13ad93d 100644 --- a/llvm/docs/ReleaseNotes.rst +++ b/llvm/docs/ReleaseNotes.rst @@ -309,13 +309,6 @@ Changes to Sanitizers Other Changes ------------- -* lit no longer supports using substrings of the default target triple as - feature names in ``UNSUPPORTED:`` and ``XFAIL:`` directives. These have been - replaced by the ``target=`` feature, and tests can use regex - matching to achieve the same effect. For example, ``UNSUPPORTED: arm`` - would now be ``UNSUPPORTED: target=arm{{.*}}`` and ``XFAIL: windows`` - would now be ``XFAIL: target={{.*}}-windows{{.*}}``. - External Open Source Projects Using LLVM 15 =========================================== diff --git a/llvm/utils/lit/lit/BooleanExpression.py b/llvm/utils/lit/lit/BooleanExpression.py index ba8453d60e98a3..ff5352778e99cc 100644 --- a/llvm/utils/lit/lit/BooleanExpression.py +++ b/llvm/utils/lit/lit/BooleanExpression.py @@ -22,22 +22,24 @@ class BooleanExpression: # # Variables in `variables` are true. # Regexes that match any variable in `variables` are true. + # Substrings of `triple` are true. # 'true' is true. # All other identifiers are false. 
@staticmethod - def evaluate(string, variables): + def evaluate(string, variables, triple=""): try: - parser = BooleanExpression(string, set(variables)) + parser = BooleanExpression(string, set(variables), triple) return parser.parseAll() except ValueError as e: raise ValueError(str(e) + ('\nin expression: %r' % string)) ##### - def __init__(self, string, variables): + def __init__(self, string, variables, triple=""): self.tokens = BooleanExpression.tokenize(string) self.variables = variables self.variables.add('true') + self.triple = triple self.value = None self.token = None @@ -99,7 +101,7 @@ def parseMATCH(self): else: regex += re.escape(part) regex = re.compile(regex) - self.value = any(regex.fullmatch(var) for var in self.variables) + self.value = self.token in self.triple or any(regex.fullmatch(var) for var in self.variables) self.token = next(self.tokens) def parseNOT(self): @@ -172,6 +174,20 @@ def test_variables(self): self.assertFalse(BooleanExpression.evaluate('tru', variables)) self.assertFalse(BooleanExpression.evaluate('{{its-true.+}}', variables)) + def test_triple(self): + triple = 'arch-vendor-os' + self.assertTrue(BooleanExpression.evaluate('arch-', {}, triple)) + self.assertTrue(BooleanExpression.evaluate('ar', {}, triple)) + self.assertTrue(BooleanExpression.evaluate('ch-vend', {}, triple)) + self.assertTrue(BooleanExpression.evaluate('-vendor-', {}, triple)) + self.assertTrue(BooleanExpression.evaluate('-os', {}, triple)) + self.assertFalse(BooleanExpression.evaluate('arch-os', {}, triple)) + + # When matching against the triple, a regex is treated as an identifier and checked + # for a literal match. This preserves existing behavior before regexes were introduced. 
+ self.assertFalse(BooleanExpression.evaluate('arch-{{vendor}}-os', {}, triple)) + self.assertTrue(BooleanExpression.evaluate('arch-{{vendor}}-os', {}, 'arch-{{vendor}}-os')) + def test_matching(self): expr1 = 'linux && (target={{aarch64-.+}} || target={{x86_64-.+}})' self.assertTrue(BooleanExpression.evaluate(expr1, {'linux', 'target=x86_64-unknown-linux-gnu'})) diff --git a/llvm/utils/lit/lit/Test.py b/llvm/utils/lit/lit/Test.py index 6c72359440b9cf..dc1c66e896c54f 100644 --- a/llvm/utils/lit/lit/Test.py +++ b/llvm/utils/lit/lit/Test.py @@ -227,9 +227,9 @@ def __init__(self, suite, path_in_suite, config, file_path = None, gtest_json_fi self.gtest_json_file = gtest_json_file # A list of conditions under which this test is expected to fail. - # Each condition is a boolean expression of features, or '*'. - # These can optionally be provided by test format handlers, - # and will be honored when the test result is supplied. + # Each condition is a boolean expression of features and target + # triple parts. These can optionally be provided by test format + # handlers, and will be honored when the test result is supplied. self.xfails = [] # If true, ignore all items in self.xfails. @@ -238,11 +238,12 @@ def __init__(self, suite, path_in_suite, config, file_path = None, gtest_json_fi # A list of conditions that must be satisfied before running the test. # Each condition is a boolean expression of features. All of them # must be True for the test to run. + # FIXME should target triple parts count here too? self.requires = [] # A list of conditions that prevent execution of the test. - # Each condition is a boolean expression of features. All of them - # must be False for the test to run. + # Each condition is a boolean expression of features and target + # triple parts. All of them must be False for the test to run. self.unsupported = [] # An optional number of retries allowed before the test finally succeeds. 
@@ -316,16 +317,18 @@ def isExpectedToFail(self): return False features = self.config.available_features + triple = getattr(self.suite.config, 'target_triple', "") - # Check if any of the xfails match an available feature. + # Check if any of the xfails match an available feature or the target. for item in self.xfails: # If this is the wildcard, it always fails. if item == '*': return True - # If this is a True expression of features, it fails. + # If this is a True expression of features and target triple parts, + # it fails. try: - if BooleanExpression.evaluate(item, features): + if BooleanExpression.evaluate(item, features, triple): return True except ValueError as e: raise ValueError('Error in XFAIL list:\n%s' % str(e)) @@ -382,15 +385,16 @@ def getUnsupportedFeatures(self): getUnsupportedFeatures() -> list of strings Returns a list of features from UNSUPPORTED that are present - in the test configuration's features. + in the test configuration's features or target triple. Throws ValueError if an UNSUPPORTED line has a syntax error. 
""" features = self.config.available_features + triple = getattr(self.suite.config, 'target_triple', "") try: return [item for item in self.unsupported - if BooleanExpression.evaluate(item, features)] + if BooleanExpression.evaluate(item, features, triple)] except ValueError as e: raise ValueError('Error in UNSUPPORTED list:\n%s' % str(e)) From b3faa1a87ac37e3825a67368dfb8dcfef95f4c53 Mon Sep 17 00:00:00 2001 From: v1nh1shungry Date: Thu, 19 Jan 2023 12:35:09 -0500 Subject: [PATCH 34/61] Fix zero-initialization fix-it for variable template Current version there is a fix-it for template constexpr int x = 0; template <> constexpr int x; // fix-it here but it will cause template <> constexpr int x = 0; Differential Revision: https://reviews.llvm.org/D139705 --- clang/docs/ReleaseNotes.rst | 2 ++ clang/include/clang/AST/DeclTemplate.h | 16 +++++++++++++ clang/test/FixIt/fixit-const-var-init.cpp | 28 +++++++++++++++++++++++ 3 files changed, 46 insertions(+) create mode 100644 clang/test/FixIt/fixit-const-var-init.cpp diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst index 21960ab69ceabc..8de179cb896174 100644 --- a/clang/docs/ReleaseNotes.rst +++ b/clang/docs/ReleaseNotes.rst @@ -350,6 +350,8 @@ Bug Fixes This fixes `Issue 59765 `_ - Reject in-class defaulting of previosly declared comparison operators. Fixes `Issue 51227 `_. +- Fix the bug of inserting the ``ZeroInitializationFixit`` before the template + argument list of ``VarTemplateSpecializationDecl``. Improvements to Clang's diagnostics ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/clang/include/clang/AST/DeclTemplate.h b/clang/include/clang/AST/DeclTemplate.h index ead51d168fffae..ae2542f4f23116 100644 --- a/clang/include/clang/AST/DeclTemplate.h +++ b/clang/include/clang/AST/DeclTemplate.h @@ -2926,6 +2926,14 @@ class VarTemplateSpecializationDecl : public VarDecl, return ExplicitInfo ? 
ExplicitInfo->TemplateKeywordLoc : SourceLocation(); } + SourceRange getSourceRange() const override LLVM_READONLY { + if (isExplicitSpecialization()) { + if (const ASTTemplateArgumentListInfo *Info = getTemplateArgsInfo()) + return SourceRange(getOuterLocStart(), Info->getRAngleLoc()); + } + return VarDecl::getSourceRange(); + } + void Profile(llvm::FoldingSetNodeID &ID) const { Profile(ID, TemplateArgs->asArray(), getASTContext()); } @@ -3083,6 +3091,14 @@ class VarTemplatePartialSpecializationDecl return First->InstantiatedFromMember.setInt(true); } + SourceRange getSourceRange() const override LLVM_READONLY { + if (isExplicitSpecialization()) { + if (const ASTTemplateArgumentListInfo *Info = getTemplateArgsAsWritten()) + return SourceRange(getOuterLocStart(), Info->getRAngleLoc()); + } + return VarDecl::getSourceRange(); + } + void Profile(llvm::FoldingSetNodeID &ID) const { Profile(ID, getTemplateArgs().asArray(), getTemplateParameters(), getASTContext()); diff --git a/clang/test/FixIt/fixit-const-var-init.cpp b/clang/test/FixIt/fixit-const-var-init.cpp new file mode 100644 index 00000000000000..832490831b88ab --- /dev/null +++ b/clang/test/FixIt/fixit-const-var-init.cpp @@ -0,0 +1,28 @@ +// RUN: not %clang_cc1 -fdiagnostics-parseable-fixits -x c++ -std=c++14 %s 2>&1 | FileCheck %s + +const int a; // expected-error {{default initialization of an object of const type}} +// CHECK: fix-it:"{{.*}}":{3:12-3:12}:" = 0" + +template const int b; // expected-error {{default initialization of an object of const type}} +// CHECK: fix-it:"{{.*}}":{6:36-6:36}:" = 0" + +template const int b; // expected-error {{default initialization of an object of const type}} +// CHECK: fix-it:"{{.*}}":{9:39-9:39}:" = 0" + +template <> const int b; // expected-error {{default initialization of an object of const type}} +// CHECK: fix-it:"{{.*}}":{12:36-12:36}:" = 0" + +constexpr float c; // expected-error {{must be initialized by a constant expression}} +// CHECK: 
fix-it:"{{.*}}":{15:18-15:18}:" = 0.0" + +template constexpr float d; // expected-error {{must be initialized by a constant expression}} +// CHECK: fix-it:"{{.*}}":{18:42-18:42}:" = 0.0" + +template constexpr float d; // expected-error {{must be initialized by a constant expression}} +// CHECK: fix-it:"{{.*}}":{21:45-21:45}:" = 0.0" + +template <> constexpr float d; // expected-error {{must be initialized by a constant expression}} +// CHECK: fix-it:"{{.*}}":{24:42-24:42}:" = 0.0" + +void (* const func)(int, int); // expected-error {{default initialization of an object of const type}} +// CHECK: fix-it:"{{.*}}":{27:30-27:30}:" = nullptr" From b05fa4faf64d6c6811bbecead13dc7c23ac43aca Mon Sep 17 00:00:00 2001 From: Xing Xue Date: Thu, 19 Jan 2023 12:46:22 -0500 Subject: [PATCH 35/61] [clang][driver][AIX] Add OpenMP runtime if -fopenmp specified Summary: This patch adds OpenMP runtime to the linker command line if -fopenmp is specified for AIX. Reviewed by: daltenty Differential Revision: https://reviews.llvm.org/D141862 --- clang/lib/Driver/ToolChains/AIX.cpp | 19 +++++++++ clang/test/Driver/aix-ld.c | 65 +++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+) diff --git a/clang/lib/Driver/ToolChains/AIX.cpp b/clang/lib/Driver/ToolChains/AIX.cpp index 5f893ba6e89182..abbd3ef6c68f3e 100644 --- a/clang/lib/Driver/ToolChains/AIX.cpp +++ b/clang/lib/Driver/ToolChains/AIX.cpp @@ -243,6 +243,25 @@ void aix::Linker::ConstructJob(Compilation &C, const JobAction &JA, if (!Args.hasArg(options::OPT_nostdlib, options::OPT_nodefaultlibs)) { AddRunTimeLibs(ToolChain, D, CmdArgs, Args); + // Add OpenMP runtime if -fopenmp is specified.
+ if (Args.hasFlag(options::OPT_fopenmp, options::OPT_fopenmp_EQ, + options::OPT_fno_openmp, false)) { + switch (ToolChain.getDriver().getOpenMPRuntime(Args)) { + case Driver::OMPRT_OMP: + CmdArgs.push_back("-lomp"); + break; + case Driver::OMPRT_IOMP5: + CmdArgs.push_back("-liomp5"); + break; + case Driver::OMPRT_GOMP: + CmdArgs.push_back("-lgomp"); + break; + case Driver::OMPRT_Unknown: + // Already diagnosed. + break; + } + } + // Support POSIX threads if "-pthreads" or "-pthread" is present. if (Args.hasArg(options::OPT_pthreads, options::OPT_pthread)) CmdArgs.push_back("-lpthreads"); diff --git a/clang/test/Driver/aix-ld.c b/clang/test/Driver/aix-ld.c index 42631f4c59eee8..ddf3ae7a2ecb4a 100644 --- a/clang/test/Driver/aix-ld.c +++ b/clang/test/Driver/aix-ld.c @@ -1016,3 +1016,68 @@ // CHECK-LD64-SHARED-EXPFULL: "[[RESOURCE_DIR]]{{/|\\\\}}lib{{/|\\\\}}aix{{/|\\\\}}libclang_rt.builtins-powerpc64.a" // CHECK-LD64-SHARED-EXPFULL: "-lm" // CHECK-LD64-SHARED-EXPFULL: "-lc" + +// Check powerpc-ibm-aix7.1.0.0. -fopenmp to use default OpenMP runtime libomp. 
+// RUN: %clang %s -### 2>&1 \ +// RUN: -resource-dir=%S/Inputs/resource_dir \ +// RUN: --target=powerpc-ibm-aix7.1.0.0 \ +// RUN: --sysroot %S/Inputs/aix_ppc_tree \ +// RUN: --unwindlib=libunwind \ +// RUN: -fopenmp \ +// RUN: | FileCheck --check-prefixes=CHECK-FOPENMP,CHECK-FOPENMP-OMP %s +// CHECK-FOPENMP-NOT: warning: +// CHECK-FOPENMP: "-cc1" "-triple" "powerpc-ibm-aix7.1.0.0" +// CHECK-FOPENMP: "-resource-dir" "[[RESOURCE_DIR:[^"]+]]" +// CHECK-FOPENMP: "-isysroot" "[[SYSROOT:[^"]+]]" +// CHECK-FOPENMP: "{{.*}}ld{{(.exe)?}}" +// CHECK-FOPENMP-NOT: "-bnso" +// CHECK-FOPENMP: "-b32" +// CHECK-FOPENMP: "-bpT:0x10000000" "-bpD:0x20000000" +// CHECK-FOPENMP: "[[SYSROOT]]/usr/lib{{/|\\\\}}crt0.o" +// CHECK-FOPENMP: "[[SYSROOT]]/usr/lib{{/|\\\\}}crti.o" +// CHECK-FOPENMP-NOT: "-lc++" +// CHECK-FOPENMP-NOT: "-lc++abi" +// CHECK-FOPENMP: "[[RESOURCE_DIR]]{{/|\\\\}}lib{{/|\\\\}}aix{{/|\\\\}}libclang_rt.builtins-powerpc.a" +// CHECK-FOPENMP-NOT: "--as-needed" +// CHECK-FOPENMP: "-lunwind" +// CHECK-FOPENMP-NOT: "--no-as-needed" +// CHECK-FOPENMP-NOT: "-lm" +// CHECK-FOPENMP-OMP: "-lomp" +// CHECK-FOPENMP-IOMP5: "-liomp5" +// CHECK-FOPENMP-GOMP: "-lgomp" +// CHECK-FOPENMP: "-lc" + +// Check powerpc-ibm-aix7.1.0.0. -fopenmp=libomp to specify libomp explicitly. +// RUN: %clang %s -### 2>&1 \ +// RUN: -resource-dir=%S/Inputs/resource_dir \ +// RUN: --target=powerpc-ibm-aix7.1.0.0 \ +// RUN: --sysroot %S/Inputs/aix_ppc_tree \ +// RUN: --unwindlib=libunwind \ +// RUN: -fopenmp=libomp \ +// RUN: | FileCheck --check-prefixes=CHECK-FOPENMP,CHECK-FOPENMP-OMP %s + +// Check powerpc-ibm-aix7.1.0.0. -fopenmp=libiomp5 to specify libgomp explicitly. +// RUN: %clang %s -### 2>&1 \ +// RUN: -resource-dir=%S/Inputs/resource_dir \ +// RUN: --target=powerpc-ibm-aix7.1.0.0 \ +// RUN: --sysroot %S/Inputs/aix_ppc_tree \ +// RUN: --unwindlib=libunwind \ +// RUN: -fopenmp=libiomp5 \ +// RUN: | FileCheck --check-prefixes=CHECK-FOPENMP,CHECK-FOPENMP-IOMP5 %s + +// Check powerpc-ibm-aix7.1.0.0. 
-fopenmp=libgomp to specify libgomp explicitly. +// RUN: %clang %s -### 2>&1 \ +// RUN: -resource-dir=%S/Inputs/resource_dir \ +// RUN: --target=powerpc-ibm-aix7.1.0.0 \ +// RUN: --sysroot %S/Inputs/aix_ppc_tree \ +// RUN: --unwindlib=libunwind \ +// RUN: -fopenmp=libgomp \ +// RUN: | FileCheck --check-prefixes=CHECK-FOPENMP,CHECK-FOPENMP-GOMP %s + +// Check powerpc-ibm-aix7.1.0.0, 32-bit. -fopenmp=libfoo results an error. +// RUN: %clang %s 2>&1 -### \ +// RUN: --target=powerpc-ibm-aix7.1.0.0 \ +// RUN: --sysroot %S/Inputs/aix_ppc_tree \ +// RUN: -fopenmp=libfoo \ +// RUN: | FileCheck --check-prefixes=CHECK-FOPENMP-FOO %s +// CHECK-FOPENMP-FOO: error: unsupported argument 'libfoo' to option '-fopenmp=' From 455305624884cf9237143e2ba0635fcc5ba5206a Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Mon, 9 Jan 2023 17:09:17 +0000 Subject: [PATCH 36/61] [mlir][Index] Implement InferIntRangeInterface Implement InferIntRangeInterface for all operations in the Index dialect. The inference implementation, unlike the one for Arith, accounts for the fact that Index can be either 64 or 32 bits long by evaluating both cases. Bounds are stored as if index were i64, but when inferring new bounds, we compute both f(...) and f(trunc(...)). We then compare trunc(f(...)) to f(trunc(...)). If they are equal in the relevant range components, we use the 64-bit range computation, otherwise we give the range ext(f(trunc(...))) union f(...). Note that this can cause surprising behavior as seen in the tests, where, for example, the order of min and max operations impacts the behavior of the inference. The inference could perhaps be made more precise in the future (ex. by tracking 32 and 64-bit results separately and having them influence each other somehow) but, since my project targets an index=i32 platform and doesn't see index-valued values > uint32_max, I'm not too concerned about it.
Depends on https://reviews.llvm.org/D141299 Depends on https://reviews.llvm.org/D141296 Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D140899 --- mlir/include/mlir/Dialect/Index/IR/IndexOps.h | 1 + .../include/mlir/Dialect/Index/IR/IndexOps.td | 4 +- .../Interfaces/Utils/InferIntRangeCommon.h | 126 ++++ mlir/lib/Dialect/Arith/IR/CMakeLists.txt | 1 + .../Arith/IR/InferIntRangeInterfaceImpls.cpp | 481 ++----------- mlir/lib/Dialect/Index/IR/CMakeLists.txt | 2 + .../Index/IR/InferIntRangeInterfaceImpls.cpp | 252 +++++++ mlir/lib/Interfaces/CMakeLists.txt | 2 + mlir/lib/Interfaces/Utils/CMakeLists.txt | 13 + .../Interfaces/Utils/InferIntRangeCommon.cpp | 663 ++++++++++++++++++ .../Dialect/Index/int-range-inference.mlir | 66 ++ 11 files changed, 1170 insertions(+), 441 deletions(-) create mode 100644 mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h create mode 100644 mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp create mode 100644 mlir/lib/Interfaces/Utils/CMakeLists.txt create mode 100644 mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp create mode 100644 mlir/test/Dialect/Index/int-range-inference.mlir diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.h b/mlir/include/mlir/Dialect/Index/IR/IndexOps.h index 85a0549edd4dd5..d8debfb731323c 100644 --- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.h +++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.h @@ -13,6 +13,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/CastInterfaces.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td index 76008a17364f93..8fbccc4ba94fca 100644 --- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td +++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td @@ -12,6 +12,7 @@ include
"mlir/Dialect/Index/IR/IndexDialect.td" include "mlir/Dialect/Index/IR/IndexEnums.td" include "mlir/Interfaces/CastInterfaces.td" +include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpAsmInterface.td" @@ -23,7 +24,8 @@ include "mlir/IR/OpBase.td" /// Base class for Index dialect operations. class IndexOp traits = []> - : Op; + : Op] # traits>; //===----------------------------------------------------------------------===// // IndexBinaryOp diff --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h new file mode 100644 index 00000000000000..7ee059cf342ce1 --- /dev/null +++ b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h @@ -0,0 +1,126 @@ +//===- InferIntRangeCommon.h - Inference for common ops ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares implementations of range inference for operations that are +// common to both the `arith` and `index` dialects to facilitate reuse. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H +#define MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H + +#include "mlir/Interfaces/InferIntRangeInterface.h" +#include "llvm/ADT/ArrayRef.h" + +namespace mlir { +namespace intrange { +/// Function that performs inference on an array of `ConstantIntRanges`, +/// abstracted away here to permit writing the function that handles both +/// 64- and 32-bit index types.
+using InferRangeFn = + function_ref)>; + +static constexpr unsigned indexMinWidth = 32; +static constexpr unsigned indexMaxWidth = 64; + +enum class CmpMode : uint32_t { Both, Signed, Unsigned }; + +/// Compute `inferFn` on `argRanges`, whose width should be the index storage +/// bitwidth. Then, compute the function on `argRanges` again after truncating +/// the ranges to 32 bits. Finally, if the truncation of the 64-bit result is +/// equal to the 32-bit result, use it (to preserve compatibility with folders +/// and inference precision), and take the union of the results otherwise. +/// +/// The `mode` argument specifies if the unsigned, signed, or both results of +/// the inference computation should be used when comparing the results. +ConstantIntRanges inferIndexOp(InferRangeFn inferFn, + ArrayRef argRanges, + CmpMode mode); + +/// Independently zero-extend the unsigned values and sign-extend the signed +/// values in `range` to `destWidth` bits, returning the resulting range. +ConstantIntRanges extRange(const ConstantIntRanges &range, unsigned destWidth); + +/// Use the unsigned values in `range` to zero-extend it to `destWidth`. +ConstantIntRanges extUIRange(const ConstantIntRanges &range, + unsigned destWidth); + +/// Use the signed values in `range` to sign-extend it to `destWidth`. +ConstantIntRanges extSIRange(const ConstantIntRanges &range, + unsigned destWidth); + +/// Truncate `range` to `destWidth` bits, taking care to handle cases such as +/// the truncation of [255, 256] to i8 not being a uniform range.
+ConstantIntRanges truncRange(const ConstantIntRanges &range, + unsigned destWidth); + +ConstantIntRanges inferAdd(ArrayRef argRanges); + +ConstantIntRanges inferSub(ArrayRef argRanges); + +ConstantIntRanges inferMul(ArrayRef argRanges); + +ConstantIntRanges inferDivS(ArrayRef argRanges); + +ConstantIntRanges inferDivU(ArrayRef argRanges); + +ConstantIntRanges inferCeilDivS(ArrayRef argRanges); + +ConstantIntRanges inferCeilDivU(ArrayRef argRanges); + +ConstantIntRanges inferFloorDivS(ArrayRef argRanges); + +ConstantIntRanges inferRemS(ArrayRef argRanges); + +ConstantIntRanges inferRemU(ArrayRef argRanges); + +ConstantIntRanges inferMaxS(ArrayRef argRanges); + +ConstantIntRanges inferMaxU(ArrayRef argRanges); + +ConstantIntRanges inferMinS(ArrayRef argRanges); + +ConstantIntRanges inferMinU(ArrayRef argRanges); + +ConstantIntRanges inferAnd(ArrayRef argRanges); + +ConstantIntRanges inferOr(ArrayRef argRanges); + +ConstantIntRanges inferXor(ArrayRef argRanges); + +ConstantIntRanges inferShl(ArrayRef argRanges); + +ConstantIntRanges inferShrS(ArrayRef argRanges); + +ConstantIntRanges inferShrU(ArrayRef argRanges); + +/// Copy of the enum from `arith` and `index` to allow the common integer range +/// infrastructure to not depend on either dialect. +enum class CmpPredicate : uint64_t { + eq, + ne, + slt, + sle, + sgt, + sge, + ult, + ule, + ugt, + uge, +}; + +/// Returns a boolean value if `pred` is statically true or false for +/// any possible inputs falling within `lhs` and `rhs`, and std::nullopt if the +/// value of the predicate cannot be determined.
+Optional evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, + const ConstantIntRanges &rhs); + +} // namespace intrange +} // namespace mlir + +#endif // MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H diff --git a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt index 0de17bbfbd12a2..ffbe80105911ed 100644 --- a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt @@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRArithDialect LINK_LIBS PUBLIC MLIRDialect + MLIRInferIntRangeCommon MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRIR diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp index 10d6ef29756c6f..971477fa94cb9f 100644 --- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp +++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Interfaces/InferIntRangeInterface.h" +#include "mlir/Interfaces/Utils/InferIntRangeCommon.h" #include "llvm/Support/Debug.h" #include @@ -16,48 +17,7 @@ using namespace mlir; using namespace mlir::arith; - -/// Function that evaluates the result of doing something on arithmetic -/// constants and returns std::nullopt on overflow. -using ConstArithFn = - function_ref(const APInt &, const APInt &)>; - -/// Return the maxmially wide signed or unsigned range for a given bitwidth. - -/// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible, -/// If either computation overflows, make the result unbounded. 
-static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft, - const APInt &minRight, - const APInt &maxLeft, - const APInt &maxRight, bool isSigned) { - std::optional maybeMin = op(minLeft, minRight); - std::optional maybeMax = op(maxLeft, maxRight); - if (maybeMin && maybeMax) - return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned); - return ConstantIntRanges::maxRange(minLeft.getBitWidth()); -} - -/// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`, -/// ignoring unbounded values. Returns the maximal range if `op` overflows. -static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef lhs, - ArrayRef rhs, bool isSigned) { - unsigned width = lhs[0].getBitWidth(); - APInt min = - isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width); - APInt max = - isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width); - for (const APInt &left : lhs) { - for (const APInt &right : rhs) { - std::optional maybeThisResult = op(left, right); - if (!maybeThisResult) - return ConstantIntRanges::maxRange(width); - APInt result = std::move(*maybeThisResult); - min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min; - max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max; - } - } - return ConstantIntRanges::range(min, max, isSigned); -} +using namespace mlir::intrange; //===----------------------------------------------------------------------===// // ConstantOp @@ -78,25 +38,7 @@ void arith::ConstantOp::inferResultRanges(ArrayRef argRanges, void arith::AddIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - ConstArithFn uadd = [](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.uadd_ov(b, overflowed); - return overflowed ? 
std::optional() : result; - }; - ConstArithFn sadd = [](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.sadd_ov(b, overflowed); - return overflowed ? std::optional() : result; - }; - - ConstantIntRanges urange = computeBoundsBy( - uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false); - ConstantIntRanges srange = computeBoundsBy( - sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true); - setResultRange(getResult(), urange.intersection(srange)); + setResultRange(getResult(), inferAdd(argRanges)); } //===----------------------------------------------------------------------===// @@ -105,25 +47,7 @@ void arith::AddIOp::inferResultRanges(ArrayRef argRanges, void arith::SubIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - ConstArithFn usub = [](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.usub_ov(b, overflowed); - return overflowed ? std::optional() : result; - }; - ConstArithFn ssub = [](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.ssub_ov(b, overflowed); - return overflowed ? 
std::optional() : result; - }; - ConstantIntRanges urange = computeBoundsBy( - usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false); - ConstantIntRanges srange = computeBoundsBy( - ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true); - setResultRange(getResult(), urange.intersection(srange)); + setResultRange(getResult(), inferSub(argRanges)); } //===----------------------------------------------------------------------===// @@ -132,96 +56,25 @@ void arith::SubIOp::inferResultRanges(ArrayRef argRanges, void arith::MulIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - ConstArithFn umul = [](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.umul_ov(b, overflowed); - return overflowed ? std::optional() : result; - }; - ConstArithFn smul = [](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.smul_ov(b, overflowed); - return overflowed ? std::optional() : result; - }; - - ConstantIntRanges urange = - minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, - /*isSigned=*/false); - ConstantIntRanges srange = - minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()}, - /*isSigned=*/true); - - setResultRange(getResult(), urange.intersection(srange)); + setResultRange(getResult(), inferMul(argRanges)); } //===----------------------------------------------------------------------===// // DivUIOp //===----------------------------------------------------------------------===// -/// Fix up division results (ex. 
for ceiling and floor), returning an APInt -/// if there has been no overflow -using DivisionFixupFn = function_ref( - const APInt &lhs, const APInt &rhs, const APInt &result)>; - -static ConstantIntRanges inferDivUIRange(const ConstantIntRanges &lhs, - const ConstantIntRanges &rhs, - DivisionFixupFn fixup) { - const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(), - &rhsMax = rhs.umax(); - - if (!rhsMin.isZero()) { - auto udiv = [&fixup](const APInt &a, - const APInt &b) -> std::optional { - return fixup(a, b, a.udiv(b)); - }; - return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, - /*isSigned=*/false); - } - // Otherwise, it's possible we might divide by 0. - return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); -} - void arith::DivUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferDivUIRange(argRanges[0], argRanges[1], - [](const APInt &lhs, const APInt &rhs, - const APInt &result) { return result; })); + setResultRange(getResult(), inferDivU(argRanges)); } //===----------------------------------------------------------------------===// // DivSIOp //===----------------------------------------------------------------------===// -static ConstantIntRanges inferDivSIRange(const ConstantIntRanges &lhs, - const ConstantIntRanges &rhs, - DivisionFixupFn fixup) { - const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), - &rhsMax = rhs.smax(); - bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative(); - - if (canDivide) { - auto sdiv = [&fixup](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.sdiv_ov(b, overflowed); - return overflowed ? 
std::optional() : fixup(a, b, result); - }; - return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, - /*isSigned=*/true); - } - return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); -} - void arith::DivSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferDivSIRange(argRanges[0], argRanges[1], - [](const APInt &lhs, const APInt &rhs, - const APInt &result) { return result; })); + setResultRange(getResult(), inferDivS(argRanges)); } //===----------------------------------------------------------------------===// @@ -230,20 +83,7 @@ void arith::DivSIOp::inferResultRanges(ArrayRef argRanges, void arith::CeilDivUIOp::inferResultRanges( ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - DivisionFixupFn ceilDivUIFix = - [](const APInt &lhs, const APInt &rhs, - const APInt &result) -> std::optional { - if (!lhs.urem(rhs).isZero()) { - bool overflowed = false; - APInt corrected = - result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed); - return overflowed ? std::optional() : corrected; - } - return result; - }; - setResultRange(getResult(), inferDivUIRange(lhs, rhs, ceilDivUIFix)); + setResultRange(getResult(), inferCeilDivU(argRanges)); } //===----------------------------------------------------------------------===// @@ -252,20 +92,7 @@ void arith::CeilDivUIOp::inferResultRanges( void arith::CeilDivSIOp::inferResultRanges( ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - DivisionFixupFn ceilDivSIFix = - [](const APInt &lhs, const APInt &rhs, - const APInt &result) -> std::optional { - if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) { - bool overflowed = false; - APInt corrected = - result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed); - return overflowed ? 
std::optional() : corrected; - } - return result; - }; - setResultRange(getResult(), inferDivSIRange(lhs, rhs, ceilDivSIFix)); + setResultRange(getResult(), inferCeilDivS(argRanges)); } //===----------------------------------------------------------------------===// @@ -274,20 +101,7 @@ void arith::CeilDivSIOp::inferResultRanges( void arith::FloorDivSIOp::inferResultRanges( ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - DivisionFixupFn floorDivSIFix = - [](const APInt &lhs, const APInt &rhs, - const APInt &result) -> std::optional { - if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) { - bool overflowed = false; - APInt corrected = - result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed); - return overflowed ? std::optional() : corrected; - } - return result; - }; - setResultRange(getResult(), inferDivSIRange(lhs, rhs, floorDivSIFix)); + return setResultRange(getResult(), inferFloorDivS(argRanges)); } //===----------------------------------------------------------------------===// @@ -296,29 +110,7 @@ void arith::FloorDivSIOp::inferResultRanges( void arith::RemUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax(); - - unsigned width = rhsMin.getBitWidth(); - APInt umin = APInt::getZero(width); - APInt umax = APInt::getMaxValue(width); - - if (!rhsMin.isZero()) { - umax = rhsMax - 1; - // Special case: sweeping out a contiguous range in N/[modulus] - if (rhsMin == rhsMax) { - const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(); - if ((lhsMax - lhsMin).ult(rhsMax)) { - APInt minRem = lhsMin.urem(rhsMax); - APInt maxRem = lhsMax.urem(rhsMax); - if (minRem.ule(maxRem)) { - umin = minRem; - umax = maxRem; - } - } - } - } - setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax)); + 
setResultRange(getResult(), inferRemU(argRanges)); } //===----------------------------------------------------------------------===// @@ -327,67 +119,16 @@ void arith::RemUIOp::inferResultRanges(ArrayRef argRanges, void arith::RemSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), - &rhsMax = rhs.smax(); - - unsigned width = rhsMax.getBitWidth(); - APInt smin = APInt::getSignedMinValue(width); - APInt smax = APInt::getSignedMaxValue(width); - // No bounds if zero could be a divisor. - bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative()); - if (canBound) { - APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs(); - bool canNegativeDividend = lhsMin.isNegative(); - bool canPositiveDividend = lhsMax.isStrictlyPositive(); - APInt zero = APInt::getZero(maxDivisor.getBitWidth()); - APInt maxPositiveResult = maxDivisor - 1; - APInt minNegativeResult = -maxPositiveResult; - smin = canNegativeDividend ? minNegativeResult : zero; - smax = canPositiveDividend ? maxPositiveResult : zero; - // Special case: sweeping out a contiguous range in N/[modulus]. - if (rhsMin == rhsMax) { - if ((lhsMax - lhsMin).ult(maxDivisor)) { - APInt minRem = lhsMin.srem(maxDivisor); - APInt maxRem = lhsMax.srem(maxDivisor); - if (minRem.sle(maxRem)) { - smin = minRem; - smax = maxRem; - } - } - } - } - setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax)); + setResultRange(getResult(), inferRemS(argRanges)); } //===----------------------------------------------------------------------===// // AndIOp //===----------------------------------------------------------------------===// -/// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???, -/// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits -/// that both bonuds have in common. 
This gives us a consertive approximation -/// for what values can be passed to bitwise operations. -static std::tuple -widenBitwiseBounds(const ConstantIntRanges &bound) { - APInt leftVal = bound.umin(), rightVal = bound.umax(); - unsigned bitwidth = leftVal.getBitWidth(); - unsigned differingBits = bitwidth - (leftVal ^ rightVal).countLeadingZeros(); - leftVal.clearLowBits(differingBits); - rightVal.setLowBits(differingBits); - return std::make_tuple(std::move(leftVal), std::move(rightVal)); -} - void arith::AndIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); - auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); - auto andi = [](const APInt &a, const APInt &b) -> std::optional { - return a & b; - }; - setResultRange(getResult(), - minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, - /*isSigned=*/false)); + setResultRange(getResult(), inferAnd(argRanges)); } //===----------------------------------------------------------------------===// @@ -396,14 +137,7 @@ void arith::AndIOp::inferResultRanges(ArrayRef argRanges, void arith::OrIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); - auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); - auto ori = [](const APInt &a, const APInt &b) -> std::optional { - return a | b; - }; - setResultRange(getResult(), - minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, - /*isSigned=*/false)); + setResultRange(getResult(), inferOr(argRanges)); } //===----------------------------------------------------------------------===// @@ -412,14 +146,7 @@ void arith::OrIOp::inferResultRanges(ArrayRef argRanges, void arith::XOrIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); - auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); - auto xori = 
[](const APInt &a, const APInt &b) -> std::optional { - return a ^ b; - }; - setResultRange(getResult(), - minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, - /*isSigned=*/false)); + setResultRange(getResult(), inferXor(argRanges)); } //===----------------------------------------------------------------------===// @@ -428,11 +155,7 @@ void arith::XOrIOp::inferResultRanges(ArrayRef argRanges, void arith::MaxSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin(); - const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax(); - setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax)); + setResultRange(getResult(), inferMaxS(argRanges)); } //===----------------------------------------------------------------------===// @@ -441,11 +164,7 @@ void arith::MaxSIOp::inferResultRanges(ArrayRef argRanges, void arith::MaxUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin(); - const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax(); - setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax)); + setResultRange(getResult(), inferMaxU(argRanges)); } //===----------------------------------------------------------------------===// @@ -454,11 +173,7 @@ void arith::MaxUIOp::inferResultRanges(ArrayRef argRanges, void arith::MinSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin(); - const APInt &smax = lhs.smax().slt(rhs.smax()) ? 
lhs.smax() : rhs.smax(); - setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax)); + setResultRange(getResult(), inferMinS(argRanges)); } //===----------------------------------------------------------------------===// @@ -467,94 +182,40 @@ void arith::MinSIOp::inferResultRanges(ArrayRef argRanges, void arith::MinUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin(); - const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax(); - setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax)); + setResultRange(getResult(), inferMinU(argRanges)); } //===----------------------------------------------------------------------===// // ExtUIOp //===----------------------------------------------------------------------===// -static ConstantIntRanges extUIRange(const ConstantIntRanges &range, - Type destType) { - unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); - APInt umin = range.umin().zext(destWidth); - APInt umax = range.umax().zext(destWidth); - return ConstantIntRanges::fromUnsigned(umin, umax); -} - void arith::ExtUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - Type destType = getResult().getType(); - setResultRange(getResult(), extUIRange(argRanges[0], destType)); + unsigned destWidth = + ConstantIntRanges::getStorageBitwidth(getResult().getType()); + setResultRange(getResult(), extUIRange(argRanges[0], destWidth)); } //===----------------------------------------------------------------------===// // ExtSIOp //===----------------------------------------------------------------------===// -static ConstantIntRanges extSIRange(const ConstantIntRanges &range, - Type destType) { - unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); - APInt smin = range.smin().sext(destWidth); - APInt smax = 
range.smax().sext(destWidth); - return ConstantIntRanges::fromSigned(smin, smax); -} - void arith::ExtSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - Type destType = getResult().getType(); - setResultRange(getResult(), extSIRange(argRanges[0], destType)); + unsigned destWidth = + ConstantIntRanges::getStorageBitwidth(getResult().getType()); + setResultRange(getResult(), extSIRange(argRanges[0], destWidth)); } //===----------------------------------------------------------------------===// // TruncIOp //===----------------------------------------------------------------------===// -static ConstantIntRanges truncIRange(const ConstantIntRanges &range, - Type destType) { - unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); - // If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb], - // the range of the resulting value is not contiguous ind includes 0. - // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2], - // but you can't truncate [255, 257] similarly. - bool hasUnsignedRollover = - range.umin().lshr(destWidth) != range.umax().lshr(destWidth); - APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth) - : range.umin().trunc(destWidth); - APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth) - : range.umax().trunc(destWidth); - - // Signed post-truncation rollover will not occur when either: - // - The high parts of the min and max, plus the sign bit, are the same - // - The high halves + sign bit of the min and max are either all 1s or all 0s - // and you won't create a [positive, negative] range by truncating. - // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8 - // but not [255, 257]_i16 to a range of i8s. You can also truncate - // [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16. 
- // You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e) - // will truncate to 0x7e, which is greater than 0 - APInt sminHighPart = range.smin().ashr(destWidth - 1); - APInt smaxHighPart = range.smax().ashr(destWidth - 1); - bool hasSignedOverflow = - (sminHighPart != smaxHighPart) && - !(sminHighPart.isAllOnes() && - (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) && - !(sminHighPart.isZero() && smaxHighPart.isZero()); - APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth) - : range.smin().trunc(destWidth); - APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth) - : range.smax().trunc(destWidth); - return {umin, umax, smin, smax}; -} - void arith::TruncIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - Type destType = getResult().getType(); - setResultRange(getResult(), truncIRange(argRanges[0], destType)); + unsigned destWidth = + ConstantIntRanges::getStorageBitwidth(getResult().getType()); + setResultRange(getResult(), truncRange(argRanges[0], destWidth)); } //===----------------------------------------------------------------------===// @@ -569,9 +230,9 @@ void arith::IndexCastOp::inferResultRanges( unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); if (srcWidth < destWidth) - setResultRange(getResult(), extSIRange(argRanges[0], destType)); + setResultRange(getResult(), extSIRange(argRanges[0], destWidth)); else if (srcWidth > destWidth) - setResultRange(getResult(), truncIRange(argRanges[0], destType)); + setResultRange(getResult(), truncRange(argRanges[0], destWidth)); else setResultRange(getResult(), argRanges[0]); } @@ -588,9 +249,9 @@ void arith::IndexCastUIOp::inferResultRanges( unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); if (srcWidth < destWidth) - setResultRange(getResult(), extUIRange(argRanges[0], destType)); + setResultRange(getResult(), extUIRange(argRanges[0], destWidth)); else if (srcWidth > destWidth) - 
setResultRange(getResult(), truncIRange(argRanges[0], destType)); + setResultRange(getResult(), truncRange(argRanges[0], destWidth)); else setResultRange(getResult(), argRanges[0]); } @@ -599,51 +260,19 @@ void arith::IndexCastUIOp::inferResultRanges( // CmpIOp //===----------------------------------------------------------------------===// -bool isStaticallyTrue(arith::CmpIPredicate pred, const ConstantIntRanges &lhs, - const ConstantIntRanges &rhs) { - switch (pred) { - case arith::CmpIPredicate::sle: - case arith::CmpIPredicate::slt: - return (applyCmpPredicate(pred, lhs.smax(), rhs.smin())); - case arith::CmpIPredicate::ule: - case arith::CmpIPredicate::ult: - return applyCmpPredicate(pred, lhs.umax(), rhs.umin()); - case arith::CmpIPredicate::sge: - case arith::CmpIPredicate::sgt: - return applyCmpPredicate(pred, lhs.smin(), rhs.smax()); - case arith::CmpIPredicate::uge: - case arith::CmpIPredicate::ugt: - return applyCmpPredicate(pred, lhs.umin(), rhs.umax()); - case arith::CmpIPredicate::eq: { - std::optional lhsConst = lhs.getConstantValue(); - std::optional rhsConst = rhs.getConstantValue(); - return lhsConst && rhsConst && lhsConst == rhsConst; - } - case arith::CmpIPredicate::ne: { - // While equality requires that there is an interpration of the preceeding - // computations that produces equal constants, whether that be signed or - // unsigned, statically determining inequality requires that neither - // interpretation produce potentially overlapping ranges. 
- bool sne = isStaticallyTrue(CmpIPredicate::slt, lhs, rhs) || - isStaticallyTrue(CmpIPredicate::sgt, lhs, rhs); - bool une = isStaticallyTrue(CmpIPredicate::ult, lhs, rhs) || - isStaticallyTrue(CmpIPredicate::ugt, lhs, rhs); - return sne && une; - } - } - return false; -} - void arith::CmpIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - arith::CmpIPredicate pred = getPredicate(); + arith::CmpIPredicate arithPred = getPredicate(); + intrange::CmpPredicate pred = static_cast(arithPred); const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; APInt min = APInt::getZero(1); APInt max = APInt::getAllOnesValue(1); - if (isStaticallyTrue(pred, lhs, rhs)) + + Optional truthValue = intrange::evaluatePred(pred, lhs, rhs); + if (truthValue.has_value() && *truthValue) min = max; - else if (isStaticallyTrue(invertPredicate(pred), lhs, rhs)) + else if (truthValue.has_value() && !(*truthValue)) max = min; setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); @@ -673,18 +302,7 @@ void arith::SelectOp::inferResultRanges(ArrayRef argRanges, void arith::ShLIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - ConstArithFn shl = [](const APInt &l, - const APInt &r) -> std::optional { - return r.uge(r.getBitWidth()) ? 
std::optional() : l.shl(r); - }; - ConstantIntRanges urange = - minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, - /*isSigned=*/false); - ConstantIntRanges srange = - minMaxBy(shl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()}, - /*isSigned=*/true); - setResultRange(getResult(), urange.intersection(srange)); + setResultRange(getResult(), inferShl(argRanges)); } //===----------------------------------------------------------------------===// @@ -693,15 +311,7 @@ void arith::ShLIOp::inferResultRanges(ArrayRef argRanges, void arith::ShRUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - ConstArithFn lshr = [](const APInt &l, - const APInt &r) -> std::optional { - return r.uge(r.getBitWidth()) ? std::optional() : l.lshr(r); - }; - setResultRange(getResult(), minMaxBy(lshr, {lhs.umin(), lhs.umax()}, - {rhs.umin(), rhs.umax()}, - /*isSigned=*/false)); + setResultRange(getResult(), inferShrU(argRanges)); } //===----------------------------------------------------------------------===// @@ -710,14 +320,5 @@ void arith::ShRUIOp::inferResultRanges(ArrayRef argRanges, void arith::ShRSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - ConstArithFn ashr = [](const APInt &l, - const APInt &r) -> std::optional { - return r.uge(r.getBitWidth()) ? 
std::optional() : l.ashr(r); - }; - - setResultRange(getResult(), - minMaxBy(ashr, {lhs.smin(), lhs.smax()}, - {rhs.umin(), rhs.umax()}, /*isSigned=*/true)); + setResultRange(getResult(), inferShrS(argRanges)); } diff --git a/mlir/lib/Dialect/Index/IR/CMakeLists.txt b/mlir/lib/Dialect/Index/IR/CMakeLists.txt index 53321f1ea3f25c..e820eececa483c 100644 --- a/mlir/lib/Dialect/Index/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Index/IR/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRIndexDialect IndexAttrs.cpp IndexDialect.cpp IndexOps.cpp + InferIntRangeInterfaceImpls.cpp DEPENDS MLIRIndexOpsIncGen @@ -10,6 +11,7 @@ add_mlir_dialect_library(MLIRIndexDialect MLIRDialect MLIRIR MLIRCastInterfaces + MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRSideEffectInterfaces ) diff --git a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp new file mode 100644 index 00000000000000..6daa7640b017e4 --- /dev/null +++ b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp @@ -0,0 +1,252 @@ +//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for index -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" +#include "mlir/Interfaces/Utils/InferIntRangeCommon.h" + +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "int-range-analysis" + +using namespace mlir; +using namespace mlir::index; +using namespace mlir::intrange; + +//===----------------------------------------------------------------------===// +// Constants +//===----------------------------------------------------------------------===// + +void ConstantOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + const APInt &value = getValue(); + setResultRange(getResult(), ConstantIntRanges::constant(value)); +} + +void BoolConstantOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + bool value = getValue(); + APInt asInt(/*numBits=*/1, value); + setResultRange(getResult(), ConstantIntRanges::constant(asInt)); +} + +//===----------------------------------------------------------------------===// +// Arithmetic operations. All of these operations will have their results inferred +// using both the 64-bit values and truncated 32-bit values of their inputs, +// with the results being the union of those inferences, except where the +// truncation of the 64-bit result is equal to the 32-bit result (at which time +// we take the 64-bit result).
+//===----------------------------------------------------------------------===// + +void AddOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), inferIndexOp(inferAdd, argRanges, CmpMode::Both)); +} + +void SubOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), inferIndexOp(inferSub, argRanges, CmpMode::Both)); +} + +void MulOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), inferIndexOp(inferMul, argRanges, CmpMode::Both)); +} + +void DivUOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferDivU, argRanges, CmpMode::Unsigned)); +} + +void DivSOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferDivS, argRanges, CmpMode::Signed)); +} + +void CeilDivUOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferCeilDivU, argRanges, CmpMode::Unsigned)); +} + +void CeilDivSOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferCeilDivS, argRanges, CmpMode::Signed)); +} + +void FloorDivSOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + return setResultRange( + getResult(), inferIndexOp(inferFloorDivS, argRanges, CmpMode::Signed)); +} + +void RemSOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferRemS, argRanges, CmpMode::Signed)); +} + +void RemUOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferRemU, argRanges, CmpMode::Unsigned)); +} + +void MaxSOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + 
inferIndexOp(inferMaxS, argRanges, CmpMode::Signed)); +} + +void MaxUOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferMaxU, argRanges, CmpMode::Unsigned)); +} + +void MinSOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferMinS, argRanges, CmpMode::Signed)); +} + +void MinUOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferMinU, argRanges, CmpMode::Unsigned)); +} + +void ShlOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), inferIndexOp(inferShl, argRanges, CmpMode::Both)); +} + +void ShrSOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferShrS, argRanges, CmpMode::Signed)); +} + +void ShrUOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferShrU, argRanges, CmpMode::Unsigned)); +} + +void AndOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferAnd, argRanges, CmpMode::Unsigned)); +} + +void OrOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferOr, argRanges, CmpMode::Unsigned)); +} + +void XOrOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferXor, argRanges, CmpMode::Unsigned)); +} + +//===----------------------------------------------------------------------===// +// Casts +//===----------------------------------------------------------------------===// + +static ConstantIntRanges makeLikeDest(const ConstantIntRanges &range, + unsigned srcWidth, unsigned destWidth, + bool isSigned) { + if (srcWidth < destWidth) + return 
isSigned ? extSIRange(range, destWidth) + : extUIRange(range, destWidth); + if (srcWidth > destWidth) + return truncRange(range, destWidth); + return range; +} + +// When casting to `index`, we will take the union of the possible fixed-width +// casts. +static ConstantIntRanges inferIndexCast(const ConstantIntRanges &range, + Type sourceType, Type destType, + bool isSigned) { + unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType); + unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); + if (sourceType.isIndex()) + return makeLikeDest(range, srcWidth, destWidth, isSigned); + // We are casting to index, so use the union of the 32-bit and 64-bit casts + ConstantIntRanges storageRange = + makeLikeDest(range, srcWidth, destWidth, isSigned); + ConstantIntRanges minWidthRange = + makeLikeDest(range, srcWidth, indexMinWidth, isSigned); + ConstantIntRanges minWidthExt = extRange(minWidthRange, destWidth); + ConstantIntRanges ret = storageRange.rangeUnion(minWidthExt); + return ret; +} + +void CastSOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + Type sourceType = getOperand().getType(); + Type destType = getResult().getType(); + setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType, + /*isSigned=*/true)); +} + +void CastUOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + Type sourceType = getOperand().getType(); + Type destType = getResult().getType(); + setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType, + /*isSigned=*/false)); +} + +//===----------------------------------------------------------------------===// +// CmpOp +//===----------------------------------------------------------------------===// + +void CmpOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + index::IndexCmpPredicate indexPred = getPred(); + intrange::CmpPredicate pred = static_cast(indexPred); + const ConstantIntRanges &lhs =
argRanges[0], &rhs = argRanges[1]; + + APInt min = APInt::getZero(1); + APInt max = APInt::getAllOnesValue(1); + + Optional truthValue64 = intrange::evaluatePred(pred, lhs, rhs); + + ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth), + rhsTrunc = truncRange(rhs, indexMinWidth); + Optional truthValue32 = + intrange::evaluatePred(pred, lhsTrunc, rhsTrunc); + + if (truthValue64 == truthValue32) { + if (truthValue64.has_value() && *truthValue64) + min = max; + else if (truthValue64.has_value() && !(*truthValue64)) + max = min; + } + setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); +} + +//===----------------------------------------------------------------------===// +// SizeOf, which is bounded between the two supported bitwidth (32 and 64). +//===----------------------------------------------------------------------===// + +void SizeOfOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + unsigned storageWidth = + ConstantIntRanges::getStorageBitwidth(getResult().getType()); + APInt min(/*numBits=*/storageWidth, indexMinWidth); + APInt max(/*numBits=*/storageWidth, indexMaxWidth); + setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); +} diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt index a7cdbb5b3a6fec..38ad0e4a2231c7 100644 --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -51,3 +51,5 @@ add_mlir_interface_library(SideEffectInterfaces) add_mlir_interface_library(TilingInterface) add_mlir_interface_library(VectorInterfaces) add_mlir_interface_library(ViewLikeInterface) + +add_subdirectory(Utils) diff --git a/mlir/lib/Interfaces/Utils/CMakeLists.txt b/mlir/lib/Interfaces/Utils/CMakeLists.txt new file mode 100644 index 00000000000000..ece6c8e46ffea9 --- /dev/null +++ b/mlir/lib/Interfaces/Utils/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_library(MLIRInferIntRangeCommon + InferIntRangeCommon.cpp + + ADDITIONAL_HEADER_DIRS + 
${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces/Utils + + DEPENDS + MLIRInferIntRangeInterfaceIncGen + + LINK_LIBS PUBLIC + MLIRInferIntRangeInterface + MLIRIR +) diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp new file mode 100644 index 00000000000000..c81f004ecf5f92 --- /dev/null +++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp @@ -0,0 +1,663 @@ +//===- InferIntRangeCommon.cpp - Inference for common ops ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains implementations of range inference for operations that are +// common to both the `arith` and `index` dialects to facilitate reuse. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/Utils/InferIntRangeCommon.h" + +#include "mlir/Interfaces/InferIntRangeInterface.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" + +#include "llvm/Support/Debug.h" + +#include +#include + +using namespace mlir; + +#define DEBUG_TYPE "int-range-analysis" + +//===----------------------------------------------------------------------===// +// General utilities +//===----------------------------------------------------------------------===// + +/// Function that evaluates the result of doing something on arithmetic +/// constants and returns std::nullopt on overflow. +using ConstArithFn = + function_ref(const APInt &, const APInt &)>; + +/// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible, +/// If either computation overflows, make the result unbounded. 
+static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft, + const APInt &minRight, + const APInt &maxLeft, + const APInt &maxRight, bool isSigned) { + std::optional maybeMin = op(minLeft, minRight); + std::optional maybeMax = op(maxLeft, maxRight); + if (maybeMin && maybeMax) + return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned); + return ConstantIntRanges::maxRange(minLeft.getBitWidth()); +} + +/// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`, +/// ignoring unbounded values. Returns the maximal range if `op` overflows. +static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef lhs, + ArrayRef rhs, bool isSigned) { + unsigned width = lhs[0].getBitWidth(); + APInt min = + isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width); + APInt max = + isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width); + for (const APInt &left : lhs) { + for (const APInt &right : rhs) { + std::optional maybeThisResult = op(left, right); + if (!maybeThisResult) + return ConstantIntRanges::maxRange(width); + APInt result = std::move(*maybeThisResult); + min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min; + max = (isSigned ? result.sgt(max) : result.ugt(max)) ? 
result : max; + } + } + return ConstantIntRanges::range(min, max, isSigned); +} + +//===----------------------------------------------------------------------===// +// Ext, trunc, index op handling +//===----------------------------------------------------------------------===// + +ConstantIntRanges +mlir::intrange::inferIndexOp(InferRangeFn inferFn, + ArrayRef argRanges, + intrange::CmpMode mode) { + ConstantIntRanges sixtyFour = inferFn(argRanges); + SmallVector truncated; + llvm::transform(argRanges, std::back_inserter(truncated), + [](const ConstantIntRanges &range) { + return truncRange(range, /*destWidth=*/indexMinWidth); + }); + ConstantIntRanges thirtyTwo = inferFn(truncated); + ConstantIntRanges thirtyTwoAsSixtyFour = + extRange(thirtyTwo, /*destWidth=*/indexMaxWidth); + ConstantIntRanges sixtyFourAsThirtyTwo = + truncRange(sixtyFour, /*destWidth=*/indexMinWidth); + + LLVM_DEBUG(llvm::dbgs() << "Index handling: 64-bit result = " << sixtyFour + << " 32-bit = " << thirtyTwo << "\n"); + bool truncEqual = false; + switch (mode) { + case intrange::CmpMode::Both: + truncEqual = (thirtyTwo == sixtyFourAsThirtyTwo); + break; + case intrange::CmpMode::Signed: + truncEqual = (thirtyTwo.smin() == sixtyFourAsThirtyTwo.smin() && + thirtyTwo.smax() == sixtyFourAsThirtyTwo.smax()); + break; + case intrange::CmpMode::Unsigned: + truncEqual = (thirtyTwo.umin() == sixtyFourAsThirtyTwo.umin() && + thirtyTwo.umax() == sixtyFourAsThirtyTwo.umax()); + break; + } + if (truncEqual) + // Returning the 64-bit result preserves more information.
+ return sixtyFour; + ConstantIntRanges merged = sixtyFour.rangeUnion(thirtyTwoAsSixtyFour); + return merged; +} + +ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range, + unsigned int destWidth) { + APInt umin = range.umin().zext(destWidth); + APInt umax = range.umax().zext(destWidth); + APInt smin = range.smin().sext(destWidth); + APInt smax = range.smax().sext(destWidth); + return {umin, umax, smin, smax}; +} + +ConstantIntRanges mlir::intrange::extUIRange(const ConstantIntRanges &range, + unsigned destWidth) { + APInt umin = range.umin().zext(destWidth); + APInt umax = range.umax().zext(destWidth); + return ConstantIntRanges::fromUnsigned(umin, umax); +} + +ConstantIntRanges mlir::intrange::extSIRange(const ConstantIntRanges &range, + unsigned destWidth) { + APInt smin = range.smin().sext(destWidth); + APInt smax = range.smax().sext(destWidth); + return ConstantIntRanges::fromSigned(smin, smax); +} + +ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range, + unsigned int destWidth) { + // If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb], + // the range of the resulting value is not contiguous and includes 0. + // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2], + // but you can't truncate [255, 257] similarly. + bool hasUnsignedRollover = + range.umin().lshr(destWidth) != range.umax().lshr(destWidth); + APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth) + : range.umin().trunc(destWidth); + APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth) + : range.umax().trunc(destWidth); + + // Signed post-truncation rollover will not occur when either: + // - The high parts of the min and max, plus the sign bit, are the same + // - The high halves + sign bit of the min and max are either all 1s or all 0s + // and you won't create a [positive, negative] range by truncating.
+ // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8 + // but not [255, 257]_i16 to a range of i8s. You can also truncate + // [-258, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16. + // You cannot truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e) + // will truncate to 0x7e, which is greater than 0 + APInt sminHighPart = range.smin().ashr(destWidth - 1); + APInt smaxHighPart = range.smax().ashr(destWidth - 1); + bool hasSignedOverflow = + (sminHighPart != smaxHighPart) && + !(sminHighPart.isAllOnes() && + (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) && + !(sminHighPart.isZero() && smaxHighPart.isZero()); + APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth) + : range.smin().trunc(destWidth); + APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth) + : range.smax().trunc(destWidth); + return {umin, umax, smin, smax}; +} + +//===----------------------------------------------------------------------===// +// Addition +//===----------------------------------------------------------------------===// + +ConstantIntRanges +mlir::intrange::inferAdd(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + ConstArithFn uadd = [](const APInt &a, + const APInt &b) -> std::optional { + bool overflowed = false; + APInt result = a.uadd_ov(b, overflowed); + return overflowed ? std::optional() : result; + }; + ConstArithFn sadd = [](const APInt &a, + const APInt &b) -> std::optional { + bool overflowed = false; + APInt result = a.sadd_ov(b, overflowed); + return overflowed ?
std::optional() : result; + }; + + ConstantIntRanges urange = computeBoundsBy( + uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false); + ConstantIntRanges srange = computeBoundsBy( + sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true); + return urange.intersection(srange); +} + +//===----------------------------------------------------------------------===// +// Subtraction +//===----------------------------------------------------------------------===// + +ConstantIntRanges +mlir::intrange::inferSub(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + ConstArithFn usub = [](const APInt &a, + const APInt &b) -> std::optional { + bool overflowed = false; + APInt result = a.usub_ov(b, overflowed); + return overflowed ? std::optional() : result; + }; + ConstArithFn ssub = [](const APInt &a, + const APInt &b) -> std::optional { + bool overflowed = false; + APInt result = a.ssub_ov(b, overflowed); + return overflowed ? std::optional() : result; + }; + ConstantIntRanges urange = computeBoundsBy( + usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false); + ConstantIntRanges srange = computeBoundsBy( + ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true); + return urange.intersection(srange); +} + +//===----------------------------------------------------------------------===// +// Multiplication +//===----------------------------------------------------------------------===// + +ConstantIntRanges +mlir::intrange::inferMul(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + ConstArithFn umul = [](const APInt &a, + const APInt &b) -> std::optional { + bool overflowed = false; + APInt result = a.umul_ov(b, overflowed); + return overflowed ? 
std::optional() : result; + }; + ConstArithFn smul = [](const APInt &a, + const APInt &b) -> std::optional { + bool overflowed = false; + APInt result = a.smul_ov(b, overflowed); + return overflowed ? std::optional() : result; + }; + + ConstantIntRanges urange = + minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, + /*isSigned=*/false); + ConstantIntRanges srange = + minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()}, + /*isSigned=*/true); + return urange.intersection(srange); +} + +//===----------------------------------------------------------------------===// +// DivU, CeilDivU (Unsigned division) +//===----------------------------------------------------------------------===// + +/// Fix up division results (ex. for ceiling and floor), returning an APInt +/// if there has been no overflow +using DivisionFixupFn = function_ref( + const APInt &lhs, const APInt &rhs, const APInt &result)>; + +static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs, + const ConstantIntRanges &rhs, + DivisionFixupFn fixup) { + const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(), + &rhsMax = rhs.umax(); + + if (!rhsMin.isZero()) { + auto udiv = [&fixup](const APInt &a, + const APInt &b) -> std::optional { + return fixup(a, b, a.udiv(b)); + }; + return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, + /*isSigned=*/false); + } + // Otherwise, it's possible we might divide by 0. 
+ return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); +} + +ConstantIntRanges +mlir::intrange::inferDivU(ArrayRef argRanges) { + return inferDivURange(argRanges[0], argRanges[1], + [](const APInt &lhs, const APInt &rhs, + const APInt &result) { return result; }); +} + +ConstantIntRanges +mlir::intrange::inferCeilDivU(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + DivisionFixupFn ceilDivUIFix = + [](const APInt &lhs, const APInt &rhs, + const APInt &result) -> std::optional { + if (!lhs.urem(rhs).isZero()) { + bool overflowed = false; + APInt corrected = + result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed); + return overflowed ? std::optional() : corrected; + } + return result; + }; + return inferDivURange(lhs, rhs, ceilDivUIFix); +} + +//===----------------------------------------------------------------------===// +// DivS, CeilDivS, FloorDivS (Signed division) +//===----------------------------------------------------------------------===// + +static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs, + const ConstantIntRanges &rhs, + DivisionFixupFn fixup) { + const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), + &rhsMax = rhs.smax(); + bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative(); + + if (canDivide) { + auto sdiv = [&fixup](const APInt &a, + const APInt &b) -> std::optional { + bool overflowed = false; + APInt result = a.sdiv_ov(b, overflowed); + return overflowed ? 
std::optional() : fixup(a, b, result); + }; + return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, + /*isSigned=*/true); + } + return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); +} + +ConstantIntRanges +mlir::intrange::inferDivS(ArrayRef argRanges) { + return inferDivSRange(argRanges[0], argRanges[1], + [](const APInt &lhs, const APInt &rhs, + const APInt &result) { return result; }); +} + +ConstantIntRanges +mlir::intrange::inferCeilDivS(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + DivisionFixupFn ceilDivSIFix = + [](const APInt &lhs, const APInt &rhs, + const APInt &result) -> std::optional { + if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) { + bool overflowed = false; + APInt corrected = + result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed); + return overflowed ? std::optional() : corrected; + } + return result; + }; + return inferDivSRange(lhs, rhs, ceilDivSIFix); +} + +ConstantIntRanges +mlir::intrange::inferFloorDivS(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + DivisionFixupFn floorDivSIFix = + [](const APInt &lhs, const APInt &rhs, + const APInt &result) -> std::optional { + if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) { + bool overflowed = false; + APInt corrected = + result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed); + return overflowed ? 
std::optional() : corrected; + } + return result; + }; + return inferDivSRange(lhs, rhs, floorDivSIFix); +} + +//===----------------------------------------------------------------------===// +// Signed remainder (RemS) +//===----------------------------------------------------------------------===// + +ConstantIntRanges +mlir::intrange::inferRemS(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), + &rhsMax = rhs.smax(); + + unsigned width = rhsMax.getBitWidth(); + APInt smin = APInt::getSignedMinValue(width); + APInt smax = APInt::getSignedMaxValue(width); + // No bounds if zero could be a divisor. + bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative()); + if (canBound) { + APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs(); + bool canNegativeDividend = lhsMin.isNegative(); + bool canPositiveDividend = lhsMax.isStrictlyPositive(); + APInt zero = APInt::getZero(maxDivisor.getBitWidth()); + APInt maxPositiveResult = maxDivisor - 1; + APInt minNegativeResult = -maxPositiveResult; + smin = canNegativeDividend ? minNegativeResult : zero; + smax = canPositiveDividend ? maxPositiveResult : zero; + // Special case: sweeping out a contiguous range in N/[modulus]. 
+ if (rhsMin == rhsMax) { + if ((lhsMax - lhsMin).ult(maxDivisor)) { + APInt minRem = lhsMin.srem(maxDivisor); + APInt maxRem = lhsMax.srem(maxDivisor); + if (minRem.sle(maxRem)) { + smin = minRem; + smax = maxRem; + } + } + } + } + return ConstantIntRanges::fromSigned(smin, smax); +} + +//===----------------------------------------------------------------------===// +// Unsigned remainder (RemU) +//===----------------------------------------------------------------------===// + +ConstantIntRanges +mlir::intrange::inferRemU(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax(); + + unsigned width = rhsMin.getBitWidth(); + APInt umin = APInt::getZero(width); + APInt umax = APInt::getMaxValue(width); + + if (!rhsMin.isZero()) { + umax = rhsMax - 1; + // Special case: sweeping out a contiguous range in N/[modulus] + if (rhsMin == rhsMax) { + const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(); + if ((lhsMax - lhsMin).ult(rhsMax)) { + APInt minRem = lhsMin.urem(rhsMax); + APInt maxRem = lhsMax.urem(rhsMax); + if (minRem.ule(maxRem)) { + umin = minRem; + umax = maxRem; + } + } + } + } + return ConstantIntRanges::fromUnsigned(umin, umax); +} + +//===----------------------------------------------------------------------===// +// Max and min (MaxS, MaxU, MinS, MinU) +//===----------------------------------------------------------------------===// + +ConstantIntRanges +mlir::intrange::inferMaxS(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin(); + const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax(); + return ConstantIntRanges::fromSigned(smin, smax); +} + +ConstantIntRanges +mlir::intrange::inferMaxU(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + const APInt &umin = lhs.umin().ugt(rhs.umin()) ? 
lhs.umin() : rhs.umin(); + const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax(); + return ConstantIntRanges::fromUnsigned(umin, umax); +} + +ConstantIntRanges +mlir::intrange::inferMinS(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin(); + const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax(); + return ConstantIntRanges::fromSigned(smin, smax); +} + +ConstantIntRanges +mlir::intrange::inferMinU(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin(); + const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax(); + return ConstantIntRanges::fromUnsigned(umin, umax); +} + +//===----------------------------------------------------------------------===// +// Bitwise operators (And, Or, Xor) +//===----------------------------------------------------------------------===// + +/// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???, +/// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits +/// that both bounds have in common. This gives us a conservative approximation +/// for what values can be passed to bitwise operations.
+static std::tuple +widenBitwiseBounds(const ConstantIntRanges &bound) { + APInt leftVal = bound.umin(), rightVal = bound.umax(); + unsigned bitwidth = leftVal.getBitWidth(); + unsigned differingBits = bitwidth - (leftVal ^ rightVal).countLeadingZeros(); + leftVal.clearLowBits(differingBits); + rightVal.setLowBits(differingBits); + return std::make_tuple(std::move(leftVal), std::move(rightVal)); +} + +ConstantIntRanges +mlir::intrange::inferAnd(ArrayRef argRanges) { + auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); + auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); + auto andi = [](const APInt &a, const APInt &b) -> std::optional { + return a & b; + }; + return minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, + /*isSigned=*/false); +} + +ConstantIntRanges +mlir::intrange::inferOr(ArrayRef argRanges) { + auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); + auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); + auto ori = [](const APInt &a, const APInt &b) -> std::optional { + return a | b; + }; + return minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, + /*isSigned=*/false); +} + +ConstantIntRanges +mlir::intrange::inferXor(ArrayRef argRanges) { + auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); + auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); + auto xori = [](const APInt &a, const APInt &b) -> std::optional { + return a ^ b; + }; + return minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, + /*isSigned=*/false); +} + +//===----------------------------------------------------------------------===// +// Shifts (Shl, ShrS, ShrU) +//===----------------------------------------------------------------------===// + +ConstantIntRanges +mlir::intrange::inferShl(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + ConstArithFn shl = [](const APInt &l, + const APInt &r) -> std::optional { + return r.uge(r.getBitWidth()) ? 
std::optional() : l.shl(r); + }; + ConstantIntRanges urange = + minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, + /*isSigned=*/false); + ConstantIntRanges srange = + minMaxBy(shl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()}, + /*isSigned=*/true); + return urange.intersection(srange); +} + +ConstantIntRanges +mlir::intrange::inferShrS(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + ConstArithFn ashr = [](const APInt &l, + const APInt &r) -> std::optional { + return r.uge(r.getBitWidth()) ? std::optional() : l.ashr(r); + }; + + return minMaxBy(ashr, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()}, + /*isSigned=*/true); +} + +ConstantIntRanges +mlir::intrange::inferShrU(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + ConstArithFn lshr = [](const APInt &l, + const APInt &r) -> std::optional { + return r.uge(r.getBitWidth()) ? std::optional() : l.lshr(r); + }; + return minMaxBy(lshr, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, + /*isSigned=*/false); +} + +//===----------------------------------------------------------------------===// +// Comparisons (Cmp) +//===----------------------------------------------------------------------===// + +static intrange::CmpPredicate invertPredicate(intrange::CmpPredicate pred) { + switch (pred) { + case intrange::CmpPredicate::eq: + return intrange::CmpPredicate::ne; + case intrange::CmpPredicate::ne: + return intrange::CmpPredicate::eq; + case intrange::CmpPredicate::slt: + return intrange::CmpPredicate::sge; + case intrange::CmpPredicate::sle: + return intrange::CmpPredicate::sgt; + case intrange::CmpPredicate::sgt: + return intrange::CmpPredicate::sle; + case intrange::CmpPredicate::sge: + return intrange::CmpPredicate::slt; + case intrange::CmpPredicate::ult: + return intrange::CmpPredicate::uge; + case intrange::CmpPredicate::ule: + return intrange::CmpPredicate::ugt; + case intrange::CmpPredicate::ugt: 
+ return intrange::CmpPredicate::ule; + case intrange::CmpPredicate::uge: + return intrange::CmpPredicate::ult; + } + llvm_unreachable("unknown cmp predicate value"); +} + +static bool isStaticallyTrue(intrange::CmpPredicate pred, + const ConstantIntRanges &lhs, + const ConstantIntRanges &rhs) { + switch (pred) { + case intrange::CmpPredicate::sle: + return lhs.smax().sle(rhs.smin()); + case intrange::CmpPredicate::slt: + return lhs.smax().slt(rhs.smin()); + case intrange::CmpPredicate::ule: + return lhs.umax().ule(rhs.umin()); + case intrange::CmpPredicate::ult: + return lhs.umax().ult(rhs.umin()); + case intrange::CmpPredicate::sge: + return lhs.smin().sge(rhs.smax()); + case intrange::CmpPredicate::sgt: + return lhs.smin().sgt(rhs.smax()); + case intrange::CmpPredicate::uge: + return lhs.umin().uge(rhs.umax()); + case intrange::CmpPredicate::ugt: + return lhs.umin().ugt(rhs.umax()); + case intrange::CmpPredicate::eq: { + std::optional lhsConst = lhs.getConstantValue(); + std::optional rhsConst = rhs.getConstantValue(); + return lhsConst && rhsConst && lhsConst == rhsConst; + } + case intrange::CmpPredicate::ne: { + // While equality requires that there is an interpretation of the preceding + // computations that produces equal constants, whether that be signed or + // unsigned, statically determining inequality requires that neither + // interpretation produce potentially overlapping ranges. 
+ bool sne = isStaticallyTrue(intrange::CmpPredicate::slt, lhs, rhs) || + isStaticallyTrue(intrange::CmpPredicate::sgt, lhs, rhs); + bool une = isStaticallyTrue(intrange::CmpPredicate::ult, lhs, rhs) || + isStaticallyTrue(intrange::CmpPredicate::ugt, lhs, rhs); + return sne && une; + } + } + return false; +} + +std::optional mlir::intrange::evaluatePred(CmpPredicate pred, + const ConstantIntRanges &lhs, + const ConstantIntRanges &rhs) { + if (isStaticallyTrue(pred, lhs, rhs)) + return true; + if (isStaticallyTrue(invertPredicate(pred), lhs, rhs)) + return false; + return std::nullopt; +} diff --git a/mlir/test/Dialect/Index/int-range-inference.mlir b/mlir/test/Dialect/Index/int-range-inference.mlir new file mode 100644 index 00000000000000..2784d5fd5cf701 --- /dev/null +++ b/mlir/test/Dialect/Index/int-range-inference.mlir @@ -0,0 +1,66 @@ +// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s + +// Most operations are covered by the `arith` tests, which use the same code +// Here, we add a few tests to ensure the "index can be 32- or 64-bit" handling +// code is operating as expected. 
+ +// CHECK-LABEL: func @add_same_for_both +// CHECK: %[[true:.*]] = index.bool.constant true +// CHECK: return %[[true]] +func.func @add_same_for_both(%arg0 : index) -> i1 { + %c1 = index.constant 1 + %calmostBig = index.constant 0xfffffffe + %0 = index.minu %arg0, %calmostBig + %1 = index.add %0, %c1 + %2 = index.cmp uge(%1, %c1) + func.return %2 : i1 +} + +// CHECK-LABEL: func @add_unsigned_ov +// CHECK: %[[uge:.*]] = index.cmp uge +// CHECK: return %[[uge]] +func.func @add_unsigned_ov(%arg0 : index) -> i1 { + %c1 = index.constant 1 + %cu32_max = index.constant 0xffffffff + %0 = index.minu %arg0, %cu32_max + %1 = index.add %0, %c1 + // On 32-bit, the add could wrap, so the result doesn't have to be >= 1 + %2 = index.cmp uge(%1, %c1) + func.return %2 : i1 +} + +// CHECK-LABEL: func @add_signed_ov +// CHECK: %[[sge:.*]] = index.cmp sge +// CHECK: return %[[sge]] +func.func @add_signed_ov(%arg0 : index) -> i1 { + %c0 = index.constant 0 + %c1 = index.constant 1 + %ci32_max = index.constant 0x7fffffff + %0 = index.minu %arg0, %ci32_max + %1 = index.add %0, %c1 + // On 32-bit, the add could wrap, so the result doesn't have to be positive + %2 = index.cmp sge(%1, %c0) + func.return %2 : i1 +} + +// CHECK-LABEL: func @add_big +// CHECK: %[[true:.*]] = index.bool.constant true +// CHECK: return %[[true]] +func.func @add_big(%arg0 : index) -> i1 { + %c1 = index.constant 1 + %cmin = index.constant 0x300000000 + %cmax = index.constant 0x30000ffff + // Note: the order of the clamps matters. + // If you go max, then min, you infer the ranges [0x300...0, 0xff..ff] + // and then [0x30...0000, 0x30...ffff] + // If you switch the order of the below operations, you instead first infer + // the range [0,0x3...ffff]. Then, the min inference can't constraint + // this intermediate, since in the 32-bit case we could have, for example + // trunc(%arg0 = 0x2ffffffff) = 0xffffffff > trunc(0x30000ffff) = 0x0000ffff + // which means we can't do any inference. 
+ %0 = index.maxu %arg0, %cmin + %1 = index.minu %0, %cmax + %2 = index.add %1, %c1 + %3 = index.cmp uge(%1, %cmin) + func.return %3 : i1 +} From af05e818fb0519f027120e160ca036a3ab678584 Mon Sep 17 00:00:00 2001 From: David Carlier Date: Thu, 19 Jan 2023 18:01:03 +0000 Subject: [PATCH 37/61] [Sanitizers] GetMemoryProfile implementation for FreeBSD. Reviewers: dvyukov Reviewed-By: dvyukov Differental Revision: https://reviews.llvm.org/D140688 --- .../sanitizer_common/sanitizer_procmaps_bsd.cpp | 16 ++++++++++++++++ .../sanitizer_procmaps_common.cpp | 2 ++ 2 files changed, 18 insertions(+) diff --git a/compiler-rt/lib/sanitizer_common/sanitizer_procmaps_bsd.cpp b/compiler-rt/lib/sanitizer_common/sanitizer_procmaps_bsd.cpp index 1f489b71ad9988..36a82c4ac966ba 100644 --- a/compiler-rt/lib/sanitizer_common/sanitizer_procmaps_bsd.cpp +++ b/compiler-rt/lib/sanitizer_common/sanitizer_procmaps_bsd.cpp @@ -39,6 +39,22 @@ namespace __sanitizer { +#if SANITIZER_FREEBSD +void GetMemoryProfile(fill_profile_f cb, uptr *stats) { + const int Mib[] = { + CTL_KERN, + KERN_PROC, + KERN_PROC_PID, + getpid() + }; + + struct kinfo_proc InfoProc; + uptr Len = sizeof(InfoProc); + CHECK_EQ(internal_sysctl(Mib, ARRAY_SIZE(Mib), nullptr, (uptr *)&InfoProc, &Len, 0), 0); + cb(0, InfoProc.ki_rssize * GetPageSizeCached(), false, stats); +} +#endif + void ReadProcMaps(ProcSelfMapsBuff *proc_maps) { const int Mib[] = { #if SANITIZER_FREEBSD diff --git a/compiler-rt/lib/sanitizer_common/sanitizer_procmaps_common.cpp b/compiler-rt/lib/sanitizer_common/sanitizer_procmaps_common.cpp index eb351b0f06fd0b..a7805ad1b083b0 100644 --- a/compiler-rt/lib/sanitizer_common/sanitizer_procmaps_common.cpp +++ b/compiler-rt/lib/sanitizer_common/sanitizer_procmaps_common.cpp @@ -145,6 +145,7 @@ void MemoryMappingLayout::DumpListOfModules( } } +#if SANITIZER_LINUX || SANITIZER_ANDROID || SANITIZER_SOLARIS || SANITIZER_NETBSD void GetMemoryProfile(fill_profile_f cb, uptr *stats) { char *smaps = nullptr; uptr 
smaps_cap = 0; @@ -184,6 +185,7 @@ void ParseUnixMemoryProfile(fill_profile_f cb, uptr *stats, char *smaps, while (*pos++ != '\n') {} } } +#endif } // namespace __sanitizer From 1d98861a7896236895d467f7e7ab4eadf7dffd82 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Thu, 19 Jan 2023 19:01:22 +0100 Subject: [PATCH 38/61] [mlir][SCF] Fix crash in loop peeling Upper bound and step size should be symbols instead of dims. Differential Revision: https://reviews.llvm.org/D142136 --- .../SCF/Utils/AffineCanonicalizationUtils.cpp | 3 ++- mlir/test/Dialect/SCF/for-loop-peeling.mlir | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp index 0b09dc5374d1b6..4ee27e4d00343c 100644 --- a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp @@ -169,7 +169,8 @@ LogicalResult scf::rewritePeeledMinMaxOp(RewriterBase &rewriter, Operation *op, Value iv, Value ub, Value step, bool insideLoop) { FlatAffineValueConstraints constraints; - constraints.appendDimVar({iv, ub, step}); + constraints.appendDimVar({iv}); + constraints.appendSymbolVar({ub, step}); if (auto constUb = getConstantIntValue(ub)) constraints.addBound(IntegerPolyhedron::EQ, 1, *constUb); if (auto constStep = getConstantIntValue(step)) diff --git a/mlir/test/Dialect/SCF/for-loop-peeling.mlir b/mlir/test/Dialect/SCF/for-loop-peeling.mlir index b1cf1f2414de0c..9a6d1c8c0a14cd 100644 --- a/mlir/test/Dialect/SCF/for-loop-peeling.mlir +++ b/mlir/test/Dialect/SCF/for-loop-peeling.mlir @@ -275,3 +275,17 @@ func.func @nested_loops(%lb0: index, %lb1 : index, %ub0: index, %ub1: index, } return %r0 : i32 } + +// ----- + +// CHECK-LABEL: func @regression +func.func @regression(%arg0: memref, %arg1: index) { + %c0 = arith.constant 0 : index + %0 = affine.apply affine_map<()[s0] -> (s0 * s0)>()[%arg1] + scf.for 
%arg2 = %c0 to %0 step %arg1 { + %1 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + s0 * s0)>(%arg2)[%arg1] + %2 = arith.index_cast %0 : index to i64 + memref.store %2, %arg0[] : memref + } + return +} From d0907ce7ed9f159562ca3f4cfd8d87e89e93febe Mon Sep 17 00:00:00 2001 From: Florian Hahn Date: Thu, 19 Jan 2023 18:10:51 +0000 Subject: [PATCH 39/61] [LoopUnroll] Directly update DT instead of DTU. The scope of DT updates are very limited when unrolling loops: the DT should only need updating for * new blocks added * exiting blocks we simplified branches This can be done manually without too much extra work. MergeBlockIntoPredecessor also needs to be updated to support direct DT updates. This fixes excessive time spent in DTU for same cases. In an internal example, time spent in LoopUnroll with this patch goes from ~200s to 2s. It also is slightly positive for CTMark: * NewPM-O3: -0.13% * NewPM-ReleaseThinLTO: -0.11% * NewPM-ReleaseLTO-g: -0.13% Notable improvements are mafft (~ -0.50%) and lencod (~ -0.30%), with no workload regressed. https://llvm-compile-time-tracker.com/compare.php?from=78a9ee7834331fb4360457cc565fa36f5452f7e0&to=687e08d011b0dc6d3edd223612761e44225c7537&stat=instructions:u Reviewed By: kuhar Differential Revision: https://reviews.llvm.org/D141487 --- .../llvm/Transforms/Utils/BasicBlockUtils.h | 5 +- llvm/lib/Transforms/Utils/BasicBlockUtils.cpp | 20 ++++++- llvm/lib/Transforms/Utils/LoopUnroll.cpp | 53 +++++++++++++++---- 3 files changed, 66 insertions(+), 12 deletions(-) diff --git a/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h b/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h index 7db466ecdd8c6e..c97baaf4afc2f5 100644 --- a/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h +++ b/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h @@ -90,11 +90,14 @@ bool DeleteDeadPHIs(BasicBlock *BB, const TargetLibraryInfo *TLI = nullptr, /// if BB's Pred has a branch to BB and to AnotherBB, and BB has a single /// successor Sing. 
In this case the branch will be updated with Sing instead of /// BB, and BB will still be merged into its predecessor and removed. +/// If \p DT is not nullptr, update it directly; in that case, DTU must be +/// nullptr. bool MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU = nullptr, LoopInfo *LI = nullptr, MemorySSAUpdater *MSSAU = nullptr, MemoryDependenceResults *MemDep = nullptr, - bool PredecessorWithTwoSuccessors = false); + bool PredecessorWithTwoSuccessors = false, + DominatorTree *DT = nullptr); /// Merge block(s) sucessors, if possible. Return true if at least two /// of the blocks were merged together. diff --git a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp index d14e5b82245003..8e49edb4dac167 100644 --- a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp +++ b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp @@ -179,7 +179,8 @@ bool llvm::DeleteDeadPHIs(BasicBlock *BB, const TargetLibraryInfo *TLI, bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU, LoopInfo *LI, MemorySSAUpdater *MSSAU, MemoryDependenceResults *MemDep, - bool PredecessorWithTwoSuccessors) { + bool PredecessorWithTwoSuccessors, + DominatorTree *DT) { if (BB->hasAddressTaken()) return false; @@ -232,10 +233,21 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU, FoldSingleEntryPHINodes(BB, MemDep); } + if (DT) { + assert(!DTU && "cannot use both DT and DTU for updates"); + DomTreeNode *PredNode = DT->getNode(PredBB); + DomTreeNode *BBNode = DT->getNode(BB); + if (PredNode) { + assert(BBNode && "PredNode unreachable but BBNode reachable?"); + for (DomTreeNode *C : to_vector(BBNode->children())) + C->setIDom(PredNode); + } + } // DTU update: Collect all the edges that exit BB. // These dominator edges will be redirected from Pred. std::vector Updates; if (DTU) { + assert(!DT && "cannot use both DT and DTU for updates"); // To avoid processing the same predecessor more than once. 
SmallPtrSet SeenSuccs; SmallPtrSet SuccsOfPredBB(succ_begin(PredBB), @@ -311,6 +323,12 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU, if (DTU) DTU->applyUpdates(Updates); + if (DT) { + assert(succ_empty(BB) && + "successors should have been transferred to PredBB"); + DT->eraseNode(BB); + } + // Finally, erase the old block and update dominator info. DeleteDeadBlock(BB, DTU); diff --git a/llvm/lib/Transforms/Utils/LoopUnroll.cpp b/llvm/lib/Transforms/Utils/LoopUnroll.cpp index cc96d15d5e5966..f8251ee9d6458b 100644 --- a/llvm/lib/Transforms/Utils/LoopUnroll.cpp +++ b/llvm/lib/Transforms/Utils/LoopUnroll.cpp @@ -321,6 +321,7 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, unsigned TripMultiple; unsigned BreakoutTrip; bool ExitOnTrue; + BasicBlock *FirstExitingBlock = nullptr; SmallVector ExitingBlocks; }; DenseMap ExitInfos; @@ -680,8 +681,7 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, assert(!UnrollVerifyDomtree || DT->verify(DominatorTree::VerificationLevel::Fast)); - DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); - + SmallVector DTUpdates; auto SetDest = [&](BasicBlock *Src, bool WillExit, bool ExitOnTrue) { auto *Term = cast(Src->getTerminator()); const unsigned Idx = ExitOnTrue ^ WillExit; @@ -695,7 +695,7 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, BranchInst::Create(Dest, Term); Term->eraseFromParent(); - DTU.applyUpdates({{DominatorTree::Delete, Src, DeadSucc}}); + DTUpdates.emplace_back(DominatorTree::Delete, Src, DeadSucc); }; auto WillExit = [&](const ExitInfo &Info, unsigned i, unsigned j, @@ -733,28 +733,56 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, // Fold branches for iterations where we know that they will exit or not // exit. 
- for (const auto &Pair : ExitInfos) { - const ExitInfo &Info = Pair.second; + for (auto &Pair : ExitInfos) { + ExitInfo &Info = Pair.second; for (unsigned i = 0, e = Info.ExitingBlocks.size(); i != e; ++i) { // The branch destination. unsigned j = (i + 1) % e; bool IsLatch = Pair.first == LatchBlock; std::optional KnownWillExit = WillExit(Info, i, j, IsLatch); - if (!KnownWillExit) + if (!KnownWillExit) { + if (!Info.FirstExitingBlock) + Info.FirstExitingBlock = Info.ExitingBlocks[i]; continue; + } // We don't fold known-exiting branches for non-latch exits here, // because this ensures that both all loop blocks and all exit blocks // remain reachable in the CFG. // TODO: We could fold these branches, but it would require much more // sophisticated updates to LoopInfo. - if (*KnownWillExit && !IsLatch) + if (*KnownWillExit && !IsLatch) { + if (!Info.FirstExitingBlock) + Info.FirstExitingBlock = Info.ExitingBlocks[i]; continue; + } SetDest(Info.ExitingBlocks[i], *KnownWillExit, Info.ExitOnTrue); } } + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); + DomTreeUpdater *DTUToUse = &DTU; + if (ExitingBlocks.size() == 1) { + // Manually update the DT if there's a single exiting node. In that case + // there's a single exit node and it is sufficient to update the nodes + // immediately dominated by the original exiting block. They will become + // dominated by the first exiting block that leaves the loop after + // unrolling. Note that the CFG inside the loop does not change, so there's + // no need to update the DT inside the unrolled loop. 
+ DTUToUse = nullptr; + auto &[OriginalExit, Info] = *ExitInfos.begin(); + if (!Info.FirstExitingBlock) + Info.FirstExitingBlock = Info.ExitingBlocks.back(); + for (auto *C : to_vector(DT->getNode(OriginalExit)->children())) { + if (L->contains(C->getBlock())) + continue; + C->setIDom(DT->getNode(Info.FirstExitingBlock)); + } + } else { + DTU.applyUpdates(DTUpdates); + } + // When completely unrolling, the last latch becomes unreachable. if (!LatchIsExiting && CompletelyUnroll) { // There is no need to update the DT here, because there must be a unique @@ -774,16 +802,21 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, if (Term && Term->isUnconditional()) { BasicBlock *Dest = Term->getSuccessor(0); BasicBlock *Fold = Dest->getUniquePredecessor(); - if (MergeBlockIntoPredecessor(Dest, &DTU, LI)) { + if (MergeBlockIntoPredecessor(Dest, /*DTU=*/DTUToUse, LI, + /*MSSAU=*/nullptr, /*MemDep=*/nullptr, + /*PredecessorWithTwoSuccessors=*/false, + DTUToUse ? nullptr : DT)) { // Dest has been folded into Fold. Update our worklists accordingly. std::replace(Latches.begin(), Latches.end(), Dest, Fold); llvm::erase_value(UnrolledLoopBlocks, Dest); } } } - // Apply updates to the DomTree. - DT = &DTU.getDomTree(); + if (DTUToUse) { + // Apply updates to the DomTree. + DT = &DTU.getDomTree(); + } assert(!UnrollVerifyDomtree || DT->verify(DominatorTree::VerificationLevel::Fast)); From b5d9f00b2096290653fcb6e8de38d5c352af63a0 Mon Sep 17 00:00:00 2001 From: Erich Keane Date: Tue, 17 Jan 2023 11:29:04 -0800 Subject: [PATCH 40/61] Forbid implicit conversion of constraint expression to bool As reported in https://github.com/llvm/llvm-project/issues/54524, and later in https://github.com/llvm/llvm-project/issues/60038, we were not properly implmenting temp.constr.atomic P3. This patch stops implicitly converting constraints to bool, and ensures the Rvalue conversion takes place as needed. 
Differential Revision: https://reviews.llvm.org/D141954 --- clang/docs/ReleaseNotes.rst | 3 + clang/lib/Sema/SemaConcept.cpp | 25 ++++--- .../constrant-satisfaction-conversions.cpp | 67 +++++++++++++++++++ 3 files changed, 87 insertions(+), 8 deletions(-) create mode 100644 clang/test/CXX/temp/temp.constr/temp.constr.atomic/constrant-satisfaction-conversions.cpp diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst index 8de179cb896174..97ce8866d3ccf3 100644 --- a/clang/docs/ReleaseNotes.rst +++ b/clang/docs/ReleaseNotes.rst @@ -774,6 +774,9 @@ C++20 Feature Support and `P1975R0: `_, which allows parenthesized aggregate-initialization. +- Fixed an issue with concept requirement evaluation, where we incorrectly allowed implicit + conversions to bool for a requirement. This fixes `GH54524 `_. + C++2b Feature Support ^^^^^^^^^^^^^^^^^^^^^ diff --git a/clang/lib/Sema/SemaConcept.cpp b/clang/lib/Sema/SemaConcept.cpp index f20e4751f41a36..4d4b2482d046e1 100644 --- a/clang/lib/Sema/SemaConcept.cpp +++ b/clang/lib/Sema/SemaConcept.cpp @@ -329,14 +329,7 @@ static ExprResult calculateConstraintSatisfaction( Sema::SFINAETrap Trap(S); SubstitutedExpression = S.SubstConstraintExpr(const_cast(AtomicExpr), MLTAL); - // Substitution might have stripped off a contextual conversion to - // bool if this is the operand of an '&&' or '||'. For example, we - // might lose an lvalue-to-rvalue conversion here. If so, put it back - // before we try to evaluate. 
- if (SubstitutedExpression.isUsable() && - !SubstitutedExpression.isInvalid()) - SubstitutedExpression = - S.PerformContextuallyConvertToBool(SubstitutedExpression.get()); + if (SubstitutedExpression.isInvalid() || Trap.hasErrorOccurred()) { // C++2a [temp.constr.atomic]p1 // ...If substitution results in an invalid type or expression, the @@ -373,6 +366,22 @@ static ExprResult calculateConstraintSatisfaction( if (!S.CheckConstraintExpression(SubstitutedExpression.get())) return ExprError(); + // [temp.constr.atomic]p3: To determine if an atomic constraint is + // satisfied, the parameter mapping and template arguments are first + // substituted into its expression. If substitution results in an + // invalid type or expression, the constraint is not satisfied. + // Otherwise, the lvalue-to-rvalue conversion is performed if necessary, + // and E shall be a constant expression of type bool. + // + // Perform the L to R Value conversion if necessary. We do so for all + // non-PRValue categories, else we fail to extend the lifetime of + // temporaries, and that fails the constant expression check. 
+ if (!SubstitutedExpression.get()->isPRValue()) + SubstitutedExpression = ImplicitCastExpr::Create( + S.Context, SubstitutedExpression.get()->getType(), + CK_LValueToRValue, SubstitutedExpression.get(), + /*BasePath=*/nullptr, VK_PRValue, FPOptionsOverride()); + return SubstitutedExpression; }); } diff --git a/clang/test/CXX/temp/temp.constr/temp.constr.atomic/constrant-satisfaction-conversions.cpp b/clang/test/CXX/temp/temp.constr/temp.constr.atomic/constrant-satisfaction-conversions.cpp new file mode 100644 index 00000000000000..ba8e2dc372e984 --- /dev/null +++ b/clang/test/CXX/temp/temp.constr/temp.constr.atomic/constrant-satisfaction-conversions.cpp @@ -0,0 +1,67 @@ +// RUN: %clang_cc1 -std=c++20 -x c++ -Wno-constant-logical-operand -verify %s + +template concept C = +sizeof(T) == 4 && !true; // requires atomic constraints sizeof(T) == 4 and !true + +template concept C2 = sizeof(T); // expected-error{{atomic constraint must be of type 'bool' (found }} + +template struct S { + constexpr operator bool() const { return true; } +}; + +// expected-error@+3{{atomic constraint must be of type 'bool' (found 'S')}} +// expected-note@#FINST{{while checking constraint satisfaction}} +// expected-note@#FINST{{in instantiation of function template specialization}} +template requires (S{}) +void f(T); +void f(int); + +// Ensure this applies to operator && as well. 
+// expected-error@+3{{atomic constraint must be of type 'bool' (found 'S')}} +// expected-note@#F2INST{{while checking constraint satisfaction}} +// expected-note@#F2INST{{in instantiation of function template specialization}} +template requires (S{} && true) +void f2(T); +void f2(int); + +template requires requires { + requires S{}; + // expected-error@-1{{atomic constraint must be of type 'bool' (found 'S')}} + // expected-note@-2{{while checking the satisfaction}} + // expected-note@-3{{in instantiation of requirement}} + // expected-note@-4{{while checking the satisfaction}} + // expected-note@-6{{while substituting template arguments}} + // expected-note@#F3INST{{while checking constraint satisfaction}} + // expected-note@#F3INST{{in instantiation of function template specialization}} + // +} +void f3(T); +void f3(int); + +// Doesn't diagnose, since this is no longer a compound requirement. +template requires (bool(1 && 2)) +void f4(T); +void f4(int); + +void g() { + f(0); // #FINST + f2(0); // #F2INST + f3(0); // #F3INST + f4(0); +} + +template +auto Nullptr = nullptr; + +template concept NullTy = Nullptr; +// expected-error@-1{{atomic constraint must be of type 'bool' (found }} +// expected-note@+1{{while checking the satisfaction}} +static_assert(NullTy); + +template +auto Struct = S{}; + +template concept StructTy = Struct; +// expected-error@-1{{atomic constraint must be of type 'bool' (found 'S')}} +// expected-note@+1{{while checking the satisfaction}} +static_assert(StructTy); From 7f66e1833f62d6f7269adc60ac18bbaa820f64ae Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Thu, 19 Jan 2023 13:18:22 -0500 Subject: [PATCH 41/61] [MLIR] Add InferTypeOpInterface to scf.if op Differential Revision: https://reviews.llvm.org/D142049 --- mlir/include/mlir/Dialect/SCF/IR/SCF.h | 1 + mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 12 +++--- mlir/lib/Dialect/SCF/IR/SCF.cpp | 40 ++++++++++++++----- .../llvm-project-overlay/mlir/BUILD.bazel | 2 + 4 files changed, 40 
insertions(+), 15 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h index c1e7bc33b4ef4f..5453f3862e7449 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h +++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h @@ -18,6 +18,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/RegionKindInterface.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/ParallelCombiningOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index 9e1752b69174ea..05adc85434778d 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -17,6 +17,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/IR/RegionKindInterface.td" include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/ParallelCombiningOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" @@ -610,12 +611,11 @@ def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [ // IfOp //===----------------------------------------------------------------------===// -def IfOp : SCF_Op<"if", - [DeclareOpInterfaceMethods, - SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveMemoryEffects, - NoRegionArguments]> { +def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveMemoryEffects, + NoRegionArguments]> { let summary = "if-then-else operation"; let description = [{ The `scf.if` operation represents an if-then-else construct for diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 
8699f1d7b162db..af2adb994145a5 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -1467,6 +1467,23 @@ bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) { return false; } +LogicalResult +IfOp::inferReturnTypes(MLIRContext *ctx, std::optional loc, + ValueRange operands, DictionaryAttr attrs, + RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + if (regions.empty()) + return failure(); + Region *r = regions.front(); + assert(!r->empty()); + Block &b = r->front(); + auto yieldOp = llvm::dyn_cast(b.getTerminator()); + TypeRange types = yieldOp.getOperandTypes(); + inferredReturnTypes.insert(inferredReturnTypes.end(), types.begin(), + types.end()); + return success(); +} + void IfOp::build(OpBuilder &builder, OperationState &result, Value cond, bool withElseRegion) { build(builder, result, /*resultTypes=*/std::nullopt, cond, withElseRegion); @@ -1516,19 +1533,24 @@ void IfOp::build(OpBuilder &builder, OperationState &result, Value cond, // Build then region. OpBuilder::InsertionGuard guard(builder); Region *thenRegion = result.addRegion(); - Block *thenBlock = builder.createBlock(thenRegion); + builder.createBlock(thenRegion); thenBuilder(builder, result.location); - // Infer types if there are any. - if (auto yieldOp = llvm::dyn_cast(thenBlock->getTerminator())) - result.addTypes(yieldOp.getOperandTypes()); - // Build else region. Region *elseRegion = result.addRegion(); - if (!elseBuilder) - return; - builder.createBlock(elseRegion); - elseBuilder(builder, result.location); + if (elseBuilder) { + builder.createBlock(elseRegion); + elseBuilder(builder, result.location); + } + + // Infer result types. 
+ SmallVector inferredReturnTypes; + MLIRContext *ctx = builder.getContext(); + auto attrDict = DictionaryAttr::get(ctx, result.attributes); + if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict, + result.regions, inferredReturnTypes))) { + result.addTypes(inferredReturnTypes); + } } LogicalResult IfOp::verify() { diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 31e5bd21537f34..02588325708827 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -1900,6 +1900,7 @@ td_library( includes = ["include"], deps = [ ":ControlFlowInterfacesTdFiles", + ":InferTypeOpInterfaceTdFiles", ":LoopLikeInterfaceTdFiles", ":ParallelCombiningOpInterfaceTdFiles", ":SideEffectInterfacesTdFiles", @@ -2929,6 +2930,7 @@ cc_library( ":ControlFlowInterfaces", ":FuncDialect", ":IR", + ":InferTypeOpInterface", ":LoopLikeInterface", ":MemRefDialect", ":ParallelCombiningOpInterface", From 1f3f3c0ea724335e7eb518a8fe30990c9245875b Mon Sep 17 00:00:00 2001 From: Arthur Eubanks Date: Thu, 19 Jan 2023 10:19:27 -0800 Subject: [PATCH 42/61] Revert "Reland [pgo] Avoid introducing relocations by using private alias" This reverts commit da5a8d14b8cc6cea16ee0929413c0672b47c93d9. Causes more duplicate symbol errors, see https://bugs.chromium.org/p/chromium/issues/detail?id=1408161. 
--- .../profile/instrprof-discarded-comdat.cpp | 51 ----------- .../Instrumentation/InstrProfiling.cpp | 60 +------------ llvm/test/Transforms/PGOProfile/comdat.ll | 31 ------- .../PGOProfile/profdata_priv_alias.ll | 84 ------------------- 4 files changed, 3 insertions(+), 223 deletions(-) delete mode 100644 compiler-rt/test/profile/instrprof-discarded-comdat.cpp delete mode 100644 llvm/test/Transforms/PGOProfile/profdata_priv_alias.ll diff --git a/compiler-rt/test/profile/instrprof-discarded-comdat.cpp b/compiler-rt/test/profile/instrprof-discarded-comdat.cpp deleted file mode 100644 index feaa7c1083c14e..00000000000000 --- a/compiler-rt/test/profile/instrprof-discarded-comdat.cpp +++ /dev/null @@ -1,51 +0,0 @@ -// Check that instrprof does not introduce references to discarded sections when -// using comdats. -// -// Occasionally, it is possible that the same function can be compiled in -// different TUs with slightly different linkages, e.g., due to different -// compiler options. However, if these are comdat functions, a single -// implementation will be chosen at link time. we want to ensure that the -// profiling data does not contain a reference to the discarded section. 
- -// UNSUPPORTED: target={{.*windows.*}} - -// RUN: mkdir -p %t.d -// RUN: %clangxx_pgogen -O2 -fPIC -ffunction-sections -fdata-sections -c %s -o %t.d/a1.o -DOBJECT_1 -mllvm -disable-preinline -// RUN: %clangxx_pgogen -O2 -fPIC -ffunction-sections -fdata-sections -c %s -o %t.d/a2.o -// RUN: %clangxx_pgogen -fPIC -shared -o %t.d/liba.so %t.d/a1.o %t.d/a2.o 2>&1 | FileCheck %s --allow-empty - -// Ensure that we don't get an error when linking -// CHECK-NOT: relocation refers to a discarded section: .text._ZN1CIiE1fEi - -template struct C { - void f(T x); - int g(T x) { - f(x); - return v; - } - int v; -}; - -template -#ifdef OBJECT_1 -__attribute__((weak)) -#else -__attribute__((noinline)) -#endif -void C::f(T x) { - v += x; -} - -#ifdef OBJECT_1 -int foo() { - C c; - c.f(1); - return c.g(2); -} -#else -int bar() { - C c; - c.f(3); - return c.g(4); -} -#endif diff --git a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp index e4a7edad8c724d..c0409206216e52 100644 --- a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp +++ b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp @@ -823,62 +823,6 @@ static inline bool shouldRecordFunctionAddr(Function *F) { return F->hasAddressTaken() || F->hasLinkOnceLinkage(); } -static inline bool shouldUsePublicSymbol(Function *Fn) { - // It isn't legal to make an alias of this function at all - if (Fn->isDeclarationForLinker()) - return true; - - // Symbols with local linkage can just use the symbol directly without - // introducing relocations - if (Fn->hasLocalLinkage()) - return true; - - // For comdat functions, an alias would need the same linkage as the original - // function and hidden visibility, and there is not point in adding an alias - // with identical linkage an visibility to avoid introducing relocations. 
- // This caused duplicate symbols to be introduced under the combination of - // PGO + ThinLTO + CFI, due to some unfavorable interaction between the new - // alias, and alias related transforms in GlobalOpt and LowerTypeTests. - if (Fn->hasComdat() && - (Fn->getVisibility() == GlobalValue::VisibilityTypes::HiddenVisibility)) - return true; - - // its OK to use an alias - return false; -} - -static inline Constant *getFuncAddrForProfData(Function *Fn) { - auto *Int8PtrTy = Type::getInt8PtrTy(Fn->getContext()); - // Store a nullptr in __llvm_profd, if we shouldn't use a real address - if (!shouldRecordFunctionAddr(Fn)) - return ConstantPointerNull::get(Int8PtrTy); - - // If we can't use an alias, we must use the public symbol, even though this - // may require a symbolic relocation. - if (shouldUsePublicSymbol(Fn)) - return ConstantExpr::getBitCast(Fn, Int8PtrTy); - - // When possible use a private alias to avoid symbolic relocations. - auto *GA = GlobalAlias::create(GlobalValue::LinkageTypes::PrivateLinkage, - Fn->getName() + ".local", Fn); - - // When the instrumented function is a COMDAT function, we cannot use a - // private alias. If we did, we would create reference to a local label in - // this function's section. If this version of the function isn't selected by - // the linker, then the metadata would introduce a reference to a discarded - // section. So, for COMDAT functions, we need to adjust the linkage of the - // alias. Using hidden visibility avoids a dynamic relocation and an entry in - // the dynamic symbol table. - // - // Note that this handles COMDAT functions with visibility other than Hidden, - // since that case is covered in shouldUsePublicSymbol() - if (Fn->hasComdat()) { - GA->setLinkage(Fn->getLinkage()); - GA->setVisibility(GlobalValue::VisibilityTypes::HiddenVisibility); - } - return ConstantExpr::getBitCast(GA, Int8PtrTy); -} - static bool needsRuntimeRegistrationOfSectionRange(const Triple &TT) { // Don't do this for Darwin. 
compiler-rt uses linker magic. if (TT.isOSDarwin()) @@ -1070,7 +1014,9 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfInstBase *Inc) { }; auto *DataTy = StructType::get(Ctx, ArrayRef(DataTypes)); - Constant *FunctionAddr = getFuncAddrForProfData(Fn); + Constant *FunctionAddr = shouldRecordFunctionAddr(Fn) + ? ConstantExpr::getBitCast(Fn, Int8PtrTy) + : ConstantPointerNull::get(Int8PtrTy); Constant *Int16ArrayVals[IPVK_Last + 1]; for (uint32_t Kind = IPVK_First; Kind <= IPVK_Last; ++Kind) diff --git a/llvm/test/Transforms/PGOProfile/comdat.ll b/llvm/test/Transforms/PGOProfile/comdat.ll index 765a77538a9b1e..9f5c0ee848ca52 100644 --- a/llvm/test/Transforms/PGOProfile/comdat.ll +++ b/llvm/test/Transforms/PGOProfile/comdat.ll @@ -4,8 +4,6 @@ $linkonceodr = comdat any $weakodr = comdat any -$weak = comdat any -$linkonce = comdat any ;; profc/profd have hash suffixes. This definition doesn't have value profiling, ;; so definitions with the same name in other modules must have the same CFG and @@ -29,32 +27,3 @@ define linkonce_odr void @linkonceodr() comdat { define weak_odr void @weakodr() comdat { ret void } - -;; weak in a comdat is not renamed. There is no guarantee that definitions in -;; other modules don't have value profiling. profd should be conservatively -;; non-private to prevent a caller from referencing a non-prevailing profd, -;; causing a linker error. -; ELF: @__profc_weak = weak hidden global {{.*}} comdat, align 8 -; ELF: @__profd_weak = weak hidden global {{.*}} comdat($__profc_weak), align 8 -; COFF: @__profc_weak = weak hidden global {{.*}} comdat, align 8 -; COFF: @__profd_weak = weak hidden global {{.*}} comdat, align 8 -define weak void @weak() comdat { - ret void -} - -;; profc/profd have hash suffixes. This definition doesn't have value profiling, -;; so definitions with the same name in other modules must have the same CFG and -;; cannot have value profiling, either. profd can be made private for ELF. 
-; ELF: @__profc_linkonce.[[#]] = linkonce hidden global {{.*}} comdat, align 8 -; ELF: @__profd_linkonce.[[#]] = private global {{.*}} comdat($__profc_linkonce.[[#]]), align 8 -; COFF: @__profc_linkonce.[[#]] = linkonce hidden global {{.*}} comdat, align 8 -; COFF: @__profd_linkonce.[[#]] = linkonce hidden global {{.*}} comdat, align 8 -define linkonce void @linkonce() comdat { - ret void -} - -; Check that comdat aliases are hidden for all linkage types -; ELF: @linkonceodr.local = linkonce_odr hidden alias void (), ptr @linkonceodr -; ELF: @weakodr.local = weak_odr hidden alias void (), ptr @weakodr -; ELF: @weak.local = weak hidden alias void (), ptr @weak -; ELF: @linkonce.local = linkonce hidden alias void (), ptr @linkonce diff --git a/llvm/test/Transforms/PGOProfile/profdata_priv_alias.ll b/llvm/test/Transforms/PGOProfile/profdata_priv_alias.ll deleted file mode 100644 index 806a2ab9b82d3c..00000000000000 --- a/llvm/test/Transforms/PGOProfile/profdata_priv_alias.ll +++ /dev/null @@ -1,84 +0,0 @@ -; RUN: opt -S -passes=pgo-instr-gen,instrprof < %s | FileCheck %s - -target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" -target triple = "x86_64-unknown-linux-gnu" - -;; Test that we use private aliases to reference function addresses inside profile data -; CHECK: @__profd_foo = private global {{.*}} ptr @foo.local -; CHECK-NOT: @__profd_foo = private global {{.*}} ptr @foo - -; CHECK: @__profd_weak = private global {{.*}} ptr @weak.local -; CHECK: @__profd_linkonce = private global {{.*}} ptr @linkonce.local -; CHECK: @__profd_weakodr = private global {{.*}} ptr @weakodr.local -; CHECK: @__profd_linkonceodr = private global {{.*}} ptr @linkonceodr.local - -; available_externally shouldn't have an alias, so make sure it doesn't appear here -; CHECK: @__profc_available_externally.[[HASH:[#0-9]+]] -; CHECK-NOT: @__profd_available_externally.[[HASH]] = {{.*}}ptr @available_externally.[[HASH]].local - -;; Ensure when not 
instrumenting a non-comdat function, then if we generate an -;; alias, then it is private. We check comdat versions in comdat.ll -; CHECK: @foo.local = private alias i32 (i32), ptr @foo -; CHECK: @weak.local = private alias void (), ptr @weak -; CHECK: @linkonce.local = private alias void (), ptr @linkonce -; CHECK: @weakodr.local = private alias void (), ptr @weakodr -; CHECK: @linkonceodr.local = private alias void (), ptr @linkonceodr - -;; We should never generate an alias for available_externally functions -; CHECK-NOT: @available_externally{{.*}} = private alias void (), ptr @available_externally - -define i32 @foo(i32 %0) { -; CHECK-LABEL: @foo( -; CHECK-NEXT: entry: -; CHECK-NEXT: [[PGOCOUNT:%.*]] = load i64, ptr @__profc_foo, align 8 -; CHECK-NEXT: [[TMP1:%.*]] = add i64 [[PGOCOUNT]], 1 -; CHECK-NEXT: store i64 [[TMP1]], ptr @__profc_foo, align 8 -; CHECK-NEXT: ret i32 0 -entry: - ret i32 0 -} - -define weak void @weak() { -; CHECK-LABEL: @weak( -; CHECK-NEXT: [[PGOCOUNT:%.*]] = load i64, ptr @__profc_weak, align 8 -; CHECK-NEXT: [[TMP1:%.*]] = add i64 [[PGOCOUNT]], 1 -; CHECK-NEXT: store i64 [[TMP1]], ptr @__profc_weak, align 8 -; CHECK-NEXT: ret void - ret void -} - -define linkonce void @linkonce() { -; CHECK-LABEL: @linkonce( -; CHECK-NEXT: [[PGOCOUNT:%.*]] = load i64, ptr @__profc_linkonce, align 8 -; CHECK-NEXT: [[TMP1:%.*]] = add i64 [[PGOCOUNT]], 1 -; CHECK-NEXT: store i64 [[TMP1]], ptr @__profc_linkonce, align 8 -; CHECK-NEXT: ret void - ret void -} - -define weak_odr void @weakodr() { -; CHECK-LABEL: @weakodr( -; CHECK-NEXT: [[PGOCOUNT:%.*]] = load i64, ptr @__profc_weakodr, align 8 -; CHECK-NEXT: [[TMP1:%.*]] = add i64 [[PGOCOUNT]], 1 -; CHECK-NEXT: store i64 [[TMP1]], ptr @__profc_weakodr, align 8 -; CHECK-NEXT: ret void - ret void -} - -define linkonce_odr void @linkonceodr() { -; CHECK-LABEL: @linkonceodr( -; CHECK-NEXT: [[PGOCOUNT:%.*]] = load i64, ptr @__profc_linkonceodr, align 8 -; CHECK-NEXT: [[TMP1:%.*]] = add i64 [[PGOCOUNT]], 1 -; 
CHECK-NEXT: store i64 [[TMP1]], ptr @__profc_linkonceodr, align 8 -; CHECK-NEXT: ret void - ret void -} - -define available_externally void @available_externally(){ -; CHECK-LABEL: @available_externally( -; CHECK-NEXT: [[PGOCOUNT:%.*]] = load i64, ptr @__profc_available_externally.[[HASH]], align 8 -; CHECK-NEXT: [[TMP1:%.*]] = add i64 [[PGOCOUNT]], 1 -; CHECK-NEXT: store i64 [[TMP1]], ptr @__profc_available_externally.[[HASH]], align 8 -; CHECK-NEXT: ret void - ret void -} From 117ecdd5789e6efd28505a8b84680020fd935cd1 Mon Sep 17 00:00:00 2001 From: Gulfem Savrun Yeniceri Date: Sat, 14 Jan 2023 00:48:32 +0000 Subject: [PATCH 43/61] [IRLinker] Replace CallInstr with CallBase This patch replaces CallInstr with CallBase to cover InvokeInstr besides CallInstr while removing nocallback attribute on a call site. It also extends drop-attribute.ll test to include a case for an invoke instruction. Differential Revision: https://reviews.llvm.org/D141740 --- llvm/lib/Linker/IRMover.cpp | 2 +- llvm/test/Linker/drop-attribute.ll | 14 +++++++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/llvm/lib/Linker/IRMover.cpp b/llvm/lib/Linker/IRMover.cpp index 322afb929e6538..517e2dc8ebe030 100644 --- a/llvm/lib/Linker/IRMover.cpp +++ b/llvm/lib/Linker/IRMover.cpp @@ -1555,7 +1555,7 @@ void IRLinker::updateAttributes(GlobalValue &GV) { // Remove nocallback attribute when it is on a call-site. for (BasicBlock &BB : *F) for (Instruction &I : BB) - if (CallInst *CI = dyn_cast(&I)) + if (CallBase *CI = dyn_cast(&I)) CI->removeFnAttr(Attribute::NoCallback); } } diff --git a/llvm/test/Linker/drop-attribute.ll b/llvm/test/Linker/drop-attribute.ll index 8e74252e524599..9be95a89109b4e 100644 --- a/llvm/test/Linker/drop-attribute.ll +++ b/llvm/test/Linker/drop-attribute.ll @@ -5,18 +5,26 @@ ; CHECK: define i32 @main() ; CHECK-NEXT: entry: ; CHECK-NEXT: call void @test_nocallback_definition() -; Test that checks that nocallback attribute on a call-site is dropped. 
+; Test that checks that nocallback attribute on a call-site in a call instruction is dropped. ; CHECK-NEXT: call void @test_nocallback_call_site(){{$}} ; CHECK-NEXT: %0 = call float @llvm.sqrt.f32(float undef) ; CHECK-NEXT: call void @test_nocallback_declaration_definition_not_linked_in() ; CHECK-NEXT: call void @test_nocallback_declaration_definition_linked_in() -define i32 @main() { +; Test that checks that nocallback attribute on a call-site in an invoke instruction is dropped. +; CHECK-NEXT: invoke void @test_nocallback_call_site(){{$}} +define i32 @main() personality i8 0 { entry: call void @test_nocallback_definition() call void @test_nocallback_call_site() nocallback call float @llvm.sqrt.f32(float undef) call void @test_nocallback_declaration_definition_not_linked_in() call void @test_nocallback_declaration_definition_linked_in() + invoke void @test_nocallback_call_site() nocallback + to label %ret unwind label %unw +unw: + %tmp = landingpad i8 cleanup + br label %ret +ret: ret i32 0 } @@ -26,7 +34,7 @@ define void @test_nocallback_definition() nocallback { ret void } -; Test that checks that nocallback attribute on a declaration when a definition is linked in is dropped. +; Test that checks that nocallback attribute on a call site is dropped. ; CHECK: declare void @test_nocallback_call_site(){{$}} declare void @test_nocallback_call_site() From 9f59affa244c1b15d9980c9abcaff433514a9d85 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Thu, 19 Jan 2023 18:41:00 +0000 Subject: [PATCH 44/61] Revert "[mlir][Index] Implement InferIntRangeInterface" This reverts commit 455305624884cf9237143e2ba0635fcc5ba5206a. Linker error, unbreak build while I work out how to fix it. 
Differential Revision: https://reviews.llvm.org/D142142 --- mlir/include/mlir/Dialect/Index/IR/IndexOps.h | 1 - .../include/mlir/Dialect/Index/IR/IndexOps.td | 4 +- .../Interfaces/Utils/InferIntRangeCommon.h | 126 ---- mlir/lib/Dialect/Arith/IR/CMakeLists.txt | 1 - .../Arith/IR/InferIntRangeInterfaceImpls.cpp | 481 +++++++++++-- mlir/lib/Dialect/Index/IR/CMakeLists.txt | 2 - .../Index/IR/InferIntRangeInterfaceImpls.cpp | 252 ------- mlir/lib/Interfaces/CMakeLists.txt | 2 - mlir/lib/Interfaces/Utils/CMakeLists.txt | 13 - .../Interfaces/Utils/InferIntRangeCommon.cpp | 663 ------------------ .../Dialect/Index/int-range-inference.mlir | 66 -- 11 files changed, 441 insertions(+), 1170 deletions(-) delete mode 100644 mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h delete mode 100644 mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp delete mode 100644 mlir/lib/Interfaces/Utils/CMakeLists.txt delete mode 100644 mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp delete mode 100644 mlir/test/Dialect/Index/int-range-inference.mlir diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.h b/mlir/include/mlir/Dialect/Index/IR/IndexOps.h index d8debfb731323c..85a0549edd4dd5 100644 --- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.h +++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.h @@ -13,7 +13,6 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/CastInterfaces.h" -#include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td index 8fbccc4ba94fca..76008a17364f93 100644 --- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td +++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td @@ -12,7 +12,6 @@ include "mlir/Dialect/Index/IR/IndexDialect.td" include "mlir/Dialect/Index/IR/IndexEnums.td" include 
"mlir/Interfaces/CastInterfaces.td" -include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpAsmInterface.td" @@ -24,8 +23,7 @@ include "mlir/IR/OpBase.td" /// Base class for Index dialect operations. class IndexOp traits = []> - : Op] # traits>; + : Op; //===----------------------------------------------------------------------===// // IndexBinaryOp diff --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h deleted file mode 100644 index 7ee059cf342ce1..00000000000000 --- a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h +++ /dev/null @@ -1,126 +0,0 @@ -//===- InferIntRangeCommon.cpp - Inference for common ops --*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file declares implementations of range inference for operations that are -// common to both the `arith` and `index` dialects to facilitate reuse. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H -#define MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H - -#include "mlir/Interfaces/InferIntRangeInterface.h" -#include "llvm/ADT/ArrayRef.h" - -namespace mlir { -namespace intrange { -/// Function that performs inference on an array of `ConstantIntRanges`, -/// abstracted away here to permit writing the function that handles both -/// 64- and 32-bit index types. 
-using InferRangeFn = - function_ref)>; - -static constexpr unsigned indexMinWidth = 32; -static constexpr unsigned indexMaxWidth = 64; - -enum class CmpMode : uint32_t { Both, Signed, Unsigned }; - -/// Compute `inferFn` on `ranges`, whose size should be the index storage -/// bitwidth. Then, compute the function on `argRanges` again after truncating -/// the ranges to 32 bits. Finally, if the truncation of the 64-bit result is -/// equal to the 32-bit result, use it (to preserve compatibility with folders -/// and inference precision), and take the union of the results otherwise. -/// -/// The `mode` argument specifies if the unsigned, signed, or both results of -/// the inference computation should be used when comparing the results. -ConstantIntRanges inferIndexOp(InferRangeFn inferFn, - ArrayRef argRanges, - CmpMode mode); - -/// Independently zero-extend the unsigned values and sign-extend the signed -/// values in `range` to `destWidth` bits, returning the resulting range. -ConstantIntRanges extRange(const ConstantIntRanges &range, unsigned destWidth); - -/// Use the unsigned values in `range` to zero-extend it to `destWidth`. -ConstantIntRanges extUIRange(const ConstantIntRanges &range, - unsigned destWidth); - -/// Use the signed values in `range` to sign-extend it to `destWidth`. -ConstantIntRanges extSIRange(const ConstantIntRanges &range, - unsigned destWidth); - -/// Truncate `range` to `destWidth` bits, taking care to handle cases such as -/// the truncation of [255, 256] to i8 not being a uniform range. 
-ConstantIntRanges truncRange(const ConstantIntRanges &range, - unsigned destWidth); - -ConstantIntRanges inferAdd(ArrayRef argRanges); - -ConstantIntRanges inferSub(ArrayRef argRanges); - -ConstantIntRanges inferMul(ArrayRef argRanges); - -ConstantIntRanges inferDivS(ArrayRef argRanges); - -ConstantIntRanges inferDivU(ArrayRef argRanges); - -ConstantIntRanges inferCeilDivS(ArrayRef argRanges); - -ConstantIntRanges inferCeilDivU(ArrayRef argRanges); - -ConstantIntRanges inferFloorDivS(ArrayRef argRanges); - -ConstantIntRanges inferRemS(ArrayRef argRanges); - -ConstantIntRanges inferRemU(ArrayRef argRanges); - -ConstantIntRanges inferMaxS(ArrayRef argRanges); - -ConstantIntRanges inferMaxU(ArrayRef argRanges); - -ConstantIntRanges inferMinS(ArrayRef argRanges); - -ConstantIntRanges inferMinU(ArrayRef argRanges); - -ConstantIntRanges inferAnd(ArrayRef argRanges); - -ConstantIntRanges inferOr(ArrayRef argRanges); - -ConstantIntRanges inferXor(ArrayRef argRanges); - -ConstantIntRanges inferShl(ArrayRef argRanges); - -ConstantIntRanges inferShrS(ArrayRef argRanges); - -ConstantIntRanges inferShrU(ArrayRef argRanges); - -/// Copy of the enum from `arith` and `index` to allow the common integer range -/// infrastructure to not depend on either dialect. -enum class CmpPredicate : uint64_t { - eq, - ne, - slt, - sle, - sgt, - sge, - ult, - ule, - ugt, - uge, -}; - -/// Returns a boolean value if `pred` is statically true or false for -/// anypossible inputs falling within `lhs` and `rhs`, and std::nullopt if the -/// value of the predicate cannot be determined. 
-Optional evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, - const ConstantIntRanges &rhs); - -} // namespace intrange -} // namespace mlir - -#endif // MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H diff --git a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt index ffbe80105911ed..0de17bbfbd12a2 100644 --- a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt @@ -16,7 +16,6 @@ add_mlir_dialect_library(MLIRArithDialect LINK_LIBS PUBLIC MLIRDialect - MLIRInferIntRangeCommon MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRIR diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp index 971477fa94cb9f..10d6ef29756c6f 100644 --- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp +++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp @@ -8,7 +8,6 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Interfaces/InferIntRangeInterface.h" -#include "mlir/Interfaces/Utils/InferIntRangeCommon.h" #include "llvm/Support/Debug.h" #include @@ -17,7 +16,48 @@ using namespace mlir; using namespace mlir::arith; -using namespace mlir::intrange; + +/// Function that evaluates the result of doing something on arithmetic +/// constants and returns std::nullopt on overflow. +using ConstArithFn = + function_ref(const APInt &, const APInt &)>; + +/// Return the maximally wide signed or unsigned range for a given bitwidth. + +/// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible. +/// If either computation overflows, make the result unbounded.
+static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft, + const APInt &minRight, + const APInt &maxLeft, + const APInt &maxRight, bool isSigned) { + std::optional maybeMin = op(minLeft, minRight); + std::optional maybeMax = op(maxLeft, maxRight); + if (maybeMin && maybeMax) + return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned); + return ConstantIntRanges::maxRange(minLeft.getBitWidth()); +} + +/// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`, +/// ignoring unbounded values. Returns the maximal range if `op` overflows. +static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef lhs, + ArrayRef rhs, bool isSigned) { + unsigned width = lhs[0].getBitWidth(); + APInt min = + isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width); + APInt max = + isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width); + for (const APInt &left : lhs) { + for (const APInt &right : rhs) { + std::optional maybeThisResult = op(left, right); + if (!maybeThisResult) + return ConstantIntRanges::maxRange(width); + APInt result = std::move(*maybeThisResult); + min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min; + max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max; + } + } + return ConstantIntRanges::range(min, max, isSigned); +} //===----------------------------------------------------------------------===// // ConstantOp @@ -38,7 +78,25 @@ void arith::ConstantOp::inferResultRanges(ArrayRef argRanges, void arith::AddIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferAdd(argRanges)); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + ConstArithFn uadd = [](const APInt &a, + const APInt &b) -> std::optional { + bool overflowed = false; + APInt result = a.uadd_ov(b, overflowed); + return overflowed ? 
std::optional() : result; + }; + ConstArithFn sadd = [](const APInt &a, + const APInt &b) -> std::optional { + bool overflowed = false; + APInt result = a.sadd_ov(b, overflowed); + return overflowed ? std::optional() : result; + }; + + ConstantIntRanges urange = computeBoundsBy( + uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false); + ConstantIntRanges srange = computeBoundsBy( + sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true); + setResultRange(getResult(), urange.intersection(srange)); } //===----------------------------------------------------------------------===// @@ -47,7 +105,25 @@ void arith::AddIOp::inferResultRanges(ArrayRef argRanges, void arith::SubIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferSub(argRanges)); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + ConstArithFn usub = [](const APInt &a, + const APInt &b) -> std::optional { + bool overflowed = false; + APInt result = a.usub_ov(b, overflowed); + return overflowed ? std::optional() : result; + }; + ConstArithFn ssub = [](const APInt &a, + const APInt &b) -> std::optional { + bool overflowed = false; + APInt result = a.ssub_ov(b, overflowed); + return overflowed ? 
std::optional() : result; + }; + ConstantIntRanges urange = computeBoundsBy( + usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false); + ConstantIntRanges srange = computeBoundsBy( + ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true); + setResultRange(getResult(), urange.intersection(srange)); } //===----------------------------------------------------------------------===// @@ -56,25 +132,96 @@ void arith::SubIOp::inferResultRanges(ArrayRef argRanges, void arith::MulIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferMul(argRanges)); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + ConstArithFn umul = [](const APInt &a, + const APInt &b) -> std::optional { + bool overflowed = false; + APInt result = a.umul_ov(b, overflowed); + return overflowed ? std::optional() : result; + }; + ConstArithFn smul = [](const APInt &a, + const APInt &b) -> std::optional { + bool overflowed = false; + APInt result = a.smul_ov(b, overflowed); + return overflowed ? std::optional() : result; + }; + + ConstantIntRanges urange = + minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, + /*isSigned=*/false); + ConstantIntRanges srange = + minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()}, + /*isSigned=*/true); + + setResultRange(getResult(), urange.intersection(srange)); } //===----------------------------------------------------------------------===// // DivUIOp //===----------------------------------------------------------------------===// +/// Fix up division results (ex. 
for ceiling and floor), returning an APInt +/// if there has been no overflow +using DivisionFixupFn = function_ref( + const APInt &lhs, const APInt &rhs, const APInt &result)>; + +static ConstantIntRanges inferDivUIRange(const ConstantIntRanges &lhs, + const ConstantIntRanges &rhs, + DivisionFixupFn fixup) { + const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(), + &rhsMax = rhs.umax(); + + if (!rhsMin.isZero()) { + auto udiv = [&fixup](const APInt &a, + const APInt &b) -> std::optional { + return fixup(a, b, a.udiv(b)); + }; + return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, + /*isSigned=*/false); + } + // Otherwise, it's possible we might divide by 0. + return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); +} + void arith::DivUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferDivU(argRanges)); + setResultRange(getResult(), + inferDivUIRange(argRanges[0], argRanges[1], + [](const APInt &lhs, const APInt &rhs, + const APInt &result) { return result; })); } //===----------------------------------------------------------------------===// // DivSIOp //===----------------------------------------------------------------------===// +static ConstantIntRanges inferDivSIRange(const ConstantIntRanges &lhs, + const ConstantIntRanges &rhs, + DivisionFixupFn fixup) { + const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), + &rhsMax = rhs.smax(); + bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative(); + + if (canDivide) { + auto sdiv = [&fixup](const APInt &a, + const APInt &b) -> std::optional { + bool overflowed = false; + APInt result = a.sdiv_ov(b, overflowed); + return overflowed ? 
std::optional() : fixup(a, b, result); + }; + return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, + /*isSigned=*/true); + } + return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); +} + void arith::DivSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferDivS(argRanges)); + setResultRange(getResult(), + inferDivSIRange(argRanges[0], argRanges[1], + [](const APInt &lhs, const APInt &rhs, + const APInt &result) { return result; })); } //===----------------------------------------------------------------------===// @@ -83,7 +230,20 @@ void arith::DivSIOp::inferResultRanges(ArrayRef argRanges, void arith::CeilDivUIOp::inferResultRanges( ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferCeilDivU(argRanges)); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + DivisionFixupFn ceilDivUIFix = + [](const APInt &lhs, const APInt &rhs, + const APInt &result) -> std::optional { + if (!lhs.urem(rhs).isZero()) { + bool overflowed = false; + APInt corrected = + result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed); + return overflowed ? std::optional() : corrected; + } + return result; + }; + setResultRange(getResult(), inferDivUIRange(lhs, rhs, ceilDivUIFix)); } //===----------------------------------------------------------------------===// @@ -92,7 +252,20 @@ void arith::CeilDivUIOp::inferResultRanges( void arith::CeilDivSIOp::inferResultRanges( ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferCeilDivS(argRanges)); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + DivisionFixupFn ceilDivSIFix = + [](const APInt &lhs, const APInt &rhs, + const APInt &result) -> std::optional { + if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) { + bool overflowed = false; + APInt corrected = + result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed); + return overflowed ? 
std::optional() : corrected; + } + return result; + }; + setResultRange(getResult(), inferDivSIRange(lhs, rhs, ceilDivSIFix)); } //===----------------------------------------------------------------------===// @@ -101,7 +274,20 @@ void arith::CeilDivSIOp::inferResultRanges( void arith::FloorDivSIOp::inferResultRanges( ArrayRef argRanges, SetIntRangeFn setResultRange) { - return setResultRange(getResult(), inferFloorDivS(argRanges)); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + DivisionFixupFn floorDivSIFix = + [](const APInt &lhs, const APInt &rhs, + const APInt &result) -> std::optional { + if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) { + bool overflowed = false; + APInt corrected = + result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed); + return overflowed ? std::optional() : corrected; + } + return result; + }; + setResultRange(getResult(), inferDivSIRange(lhs, rhs, floorDivSIFix)); } //===----------------------------------------------------------------------===// @@ -110,7 +296,29 @@ void arith::FloorDivSIOp::inferResultRanges( void arith::RemUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferRemU(argRanges)); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax(); + + unsigned width = rhsMin.getBitWidth(); + APInt umin = APInt::getZero(width); + APInt umax = APInt::getMaxValue(width); + + if (!rhsMin.isZero()) { + umax = rhsMax - 1; + // Special case: sweeping out a contiguous range in N/[modulus] + if (rhsMin == rhsMax) { + const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(); + if ((lhsMax - lhsMin).ult(rhsMax)) { + APInt minRem = lhsMin.urem(rhsMax); + APInt maxRem = lhsMax.urem(rhsMax); + if (minRem.ule(maxRem)) { + umin = minRem; + umax = maxRem; + } + } + } + } + setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax)); } 
//===----------------------------------------------------------------------===// @@ -119,16 +327,67 @@ void arith::RemUIOp::inferResultRanges(ArrayRef argRanges, void arith::RemSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferRemS(argRanges)); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), + &rhsMax = rhs.smax(); + + unsigned width = rhsMax.getBitWidth(); + APInt smin = APInt::getSignedMinValue(width); + APInt smax = APInt::getSignedMaxValue(width); + // No bounds if zero could be a divisor. + bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative()); + if (canBound) { + APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs(); + bool canNegativeDividend = lhsMin.isNegative(); + bool canPositiveDividend = lhsMax.isStrictlyPositive(); + APInt zero = APInt::getZero(maxDivisor.getBitWidth()); + APInt maxPositiveResult = maxDivisor - 1; + APInt minNegativeResult = -maxPositiveResult; + smin = canNegativeDividend ? minNegativeResult : zero; + smax = canPositiveDividend ? maxPositiveResult : zero; + // Special case: sweeping out a contiguous range in N/[modulus]. + if (rhsMin == rhsMax) { + if ((lhsMax - lhsMin).ult(maxDivisor)) { + APInt minRem = lhsMin.srem(maxDivisor); + APInt maxRem = lhsMax.srem(maxDivisor); + if (minRem.sle(maxRem)) { + smin = minRem; + smax = maxRem; + } + } + } + } + setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax)); } //===----------------------------------------------------------------------===// // AndIOp //===----------------------------------------------------------------------===// +/// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???, +/// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits +/// that both bonuds have in common. 
This gives us a consertive approximation +/// for what values can be passed to bitwise operations. +static std::tuple +widenBitwiseBounds(const ConstantIntRanges &bound) { + APInt leftVal = bound.umin(), rightVal = bound.umax(); + unsigned bitwidth = leftVal.getBitWidth(); + unsigned differingBits = bitwidth - (leftVal ^ rightVal).countLeadingZeros(); + leftVal.clearLowBits(differingBits); + rightVal.setLowBits(differingBits); + return std::make_tuple(std::move(leftVal), std::move(rightVal)); +} + void arith::AndIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferAnd(argRanges)); + auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); + auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); + auto andi = [](const APInt &a, const APInt &b) -> std::optional { + return a & b; + }; + setResultRange(getResult(), + minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, + /*isSigned=*/false)); } //===----------------------------------------------------------------------===// @@ -137,7 +396,14 @@ void arith::AndIOp::inferResultRanges(ArrayRef argRanges, void arith::OrIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferOr(argRanges)); + auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); + auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); + auto ori = [](const APInt &a, const APInt &b) -> std::optional { + return a | b; + }; + setResultRange(getResult(), + minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, + /*isSigned=*/false)); } //===----------------------------------------------------------------------===// @@ -146,7 +412,14 @@ void arith::OrIOp::inferResultRanges(ArrayRef argRanges, void arith::XOrIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferXor(argRanges)); + auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); + auto [rhsZeros, rhsOnes] = 
widenBitwiseBounds(argRanges[1]); + auto xori = [](const APInt &a, const APInt &b) -> std::optional { + return a ^ b; + }; + setResultRange(getResult(), + minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, + /*isSigned=*/false)); } //===----------------------------------------------------------------------===// @@ -155,7 +428,11 @@ void arith::XOrIOp::inferResultRanges(ArrayRef argRanges, void arith::MaxSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferMaxS(argRanges)); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin(); + const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax(); + setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax)); } //===----------------------------------------------------------------------===// @@ -164,7 +441,11 @@ void arith::MaxSIOp::inferResultRanges(ArrayRef argRanges, void arith::MaxUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferMaxU(argRanges)); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin(); + const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax(); + setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax)); } //===----------------------------------------------------------------------===// @@ -173,7 +454,11 @@ void arith::MaxUIOp::inferResultRanges(ArrayRef argRanges, void arith::MinSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferMinS(argRanges)); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin(); + const APInt &smax = lhs.smax().slt(rhs.smax()) ? 
lhs.smax() : rhs.smax(); + setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax)); } //===----------------------------------------------------------------------===// @@ -182,40 +467,94 @@ void arith::MinSIOp::inferResultRanges(ArrayRef argRanges, void arith::MinUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferMinU(argRanges)); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin(); + const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax(); + setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax)); } //===----------------------------------------------------------------------===// // ExtUIOp //===----------------------------------------------------------------------===// +static ConstantIntRanges extUIRange(const ConstantIntRanges &range, + Type destType) { + unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); + APInt umin = range.umin().zext(destWidth); + APInt umax = range.umax().zext(destWidth); + return ConstantIntRanges::fromUnsigned(umin, umax); +} + void arith::ExtUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - unsigned destWidth = - ConstantIntRanges::getStorageBitwidth(getResult().getType()); - setResultRange(getResult(), extUIRange(argRanges[0], destWidth)); + Type destType = getResult().getType(); + setResultRange(getResult(), extUIRange(argRanges[0], destType)); } //===----------------------------------------------------------------------===// // ExtSIOp //===----------------------------------------------------------------------===// +static ConstantIntRanges extSIRange(const ConstantIntRanges &range, + Type destType) { + unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); + APInt smin = range.smin().sext(destWidth); + APInt smax = range.smax().sext(destWidth); + return 
ConstantIntRanges::fromSigned(smin, smax); +} + void arith::ExtSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - unsigned destWidth = - ConstantIntRanges::getStorageBitwidth(getResult().getType()); - setResultRange(getResult(), extSIRange(argRanges[0], destWidth)); + Type destType = getResult().getType(); + setResultRange(getResult(), extSIRange(argRanges[0], destType)); } //===----------------------------------------------------------------------===// // TruncIOp //===----------------------------------------------------------------------===// +static ConstantIntRanges truncIRange(const ConstantIntRanges &range, + Type destType) { + unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); + // If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb], + // the range of the resulting value is not contiguous ind includes 0. + // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2], + // but you can't truncate [255, 257] similarly. + bool hasUnsignedRollover = + range.umin().lshr(destWidth) != range.umax().lshr(destWidth); + APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth) + : range.umin().trunc(destWidth); + APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth) + : range.umax().trunc(destWidth); + + // Signed post-truncation rollover will not occur when either: + // - The high parts of the min and max, plus the sign bit, are the same + // - The high halves + sign bit of the min and max are either all 1s or all 0s + // and you won't create a [positive, negative] range by truncating. + // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8 + // but not [255, 257]_i16 to a range of i8s. You can also truncate + // [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16. 
+ // You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e) + // will truncate to 0x7e, which is greater than 0 + APInt sminHighPart = range.smin().ashr(destWidth - 1); + APInt smaxHighPart = range.smax().ashr(destWidth - 1); + bool hasSignedOverflow = + (sminHighPart != smaxHighPart) && + !(sminHighPart.isAllOnes() && + (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) && + !(sminHighPart.isZero() && smaxHighPart.isZero()); + APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth) + : range.smin().trunc(destWidth); + APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth) + : range.smax().trunc(destWidth); + return {umin, umax, smin, smax}; +} + void arith::TruncIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - unsigned destWidth = - ConstantIntRanges::getStorageBitwidth(getResult().getType()); - setResultRange(getResult(), truncRange(argRanges[0], destWidth)); + Type destType = getResult().getType(); + setResultRange(getResult(), truncIRange(argRanges[0], destType)); } //===----------------------------------------------------------------------===// @@ -230,9 +569,9 @@ void arith::IndexCastOp::inferResultRanges( unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); if (srcWidth < destWidth) - setResultRange(getResult(), extSIRange(argRanges[0], destWidth)); + setResultRange(getResult(), extSIRange(argRanges[0], destType)); else if (srcWidth > destWidth) - setResultRange(getResult(), truncRange(argRanges[0], destWidth)); + setResultRange(getResult(), truncIRange(argRanges[0], destType)); else setResultRange(getResult(), argRanges[0]); } @@ -249,9 +588,9 @@ void arith::IndexCastUIOp::inferResultRanges( unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); if (srcWidth < destWidth) - setResultRange(getResult(), extUIRange(argRanges[0], destWidth)); + setResultRange(getResult(), extUIRange(argRanges[0], destType)); else if (srcWidth > destWidth) - 
setResultRange(getResult(), truncRange(argRanges[0], destWidth)); + setResultRange(getResult(), truncIRange(argRanges[0], destType)); else setResultRange(getResult(), argRanges[0]); } @@ -260,19 +599,51 @@ void arith::IndexCastUIOp::inferResultRanges( // CmpIOp //===----------------------------------------------------------------------===// +bool isStaticallyTrue(arith::CmpIPredicate pred, const ConstantIntRanges &lhs, + const ConstantIntRanges &rhs) { + switch (pred) { + case arith::CmpIPredicate::sle: + case arith::CmpIPredicate::slt: + return (applyCmpPredicate(pred, lhs.smax(), rhs.smin())); + case arith::CmpIPredicate::ule: + case arith::CmpIPredicate::ult: + return applyCmpPredicate(pred, lhs.umax(), rhs.umin()); + case arith::CmpIPredicate::sge: + case arith::CmpIPredicate::sgt: + return applyCmpPredicate(pred, lhs.smin(), rhs.smax()); + case arith::CmpIPredicate::uge: + case arith::CmpIPredicate::ugt: + return applyCmpPredicate(pred, lhs.umin(), rhs.umax()); + case arith::CmpIPredicate::eq: { + std::optional lhsConst = lhs.getConstantValue(); + std::optional rhsConst = rhs.getConstantValue(); + return lhsConst && rhsConst && lhsConst == rhsConst; + } + case arith::CmpIPredicate::ne: { + // While equality requires that there is an interpration of the preceeding + // computations that produces equal constants, whether that be signed or + // unsigned, statically determining inequality requires that neither + // interpretation produce potentially overlapping ranges. 
+ bool sne = isStaticallyTrue(CmpIPredicate::slt, lhs, rhs) || + isStaticallyTrue(CmpIPredicate::sgt, lhs, rhs); + bool une = isStaticallyTrue(CmpIPredicate::ult, lhs, rhs) || + isStaticallyTrue(CmpIPredicate::ugt, lhs, rhs); + return sne && une; + } + } + return false; +} + void arith::CmpIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - arith::CmpIPredicate arithPred = getPredicate(); - intrange::CmpPredicate pred = static_cast(arithPred); + arith::CmpIPredicate pred = getPredicate(); const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; APInt min = APInt::getZero(1); APInt max = APInt::getAllOnesValue(1); - - Optional truthValue = intrange::evaluatePred(pred, lhs, rhs); - if (truthValue.has_value() && *truthValue) + if (isStaticallyTrue(pred, lhs, rhs)) min = max; - else if (truthValue.has_value() && !(*truthValue)) + else if (isStaticallyTrue(invertPredicate(pred), lhs, rhs)) max = min; setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); @@ -302,7 +673,18 @@ void arith::SelectOp::inferResultRanges(ArrayRef argRanges, void arith::ShLIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferShl(argRanges)); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + ConstArithFn shl = [](const APInt &l, + const APInt &r) -> std::optional { + return r.uge(r.getBitWidth()) ? 
std::optional() : l.shl(r); + }; + ConstantIntRanges urange = + minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, + /*isSigned=*/false); + ConstantIntRanges srange = + minMaxBy(shl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()}, + /*isSigned=*/true); + setResultRange(getResult(), urange.intersection(srange)); } //===----------------------------------------------------------------------===// @@ -311,7 +693,15 @@ void arith::ShLIOp::inferResultRanges(ArrayRef argRanges, void arith::ShRUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferShrU(argRanges)); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + ConstArithFn lshr = [](const APInt &l, + const APInt &r) -> std::optional { + return r.uge(r.getBitWidth()) ? std::optional() : l.lshr(r); + }; + setResultRange(getResult(), minMaxBy(lshr, {lhs.umin(), lhs.umax()}, + {rhs.umin(), rhs.umax()}, + /*isSigned=*/false)); } //===----------------------------------------------------------------------===// @@ -320,5 +710,14 @@ void arith::ShRUIOp::inferResultRanges(ArrayRef argRanges, void arith::ShRSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferShrS(argRanges)); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + ConstArithFn ashr = [](const APInt &l, + const APInt &r) -> std::optional { + return r.uge(r.getBitWidth()) ? 
std::optional() : l.ashr(r); + }; + + setResultRange(getResult(), + minMaxBy(ashr, {lhs.smin(), lhs.smax()}, + {rhs.umin(), rhs.umax()}, /*isSigned=*/true)); } diff --git a/mlir/lib/Dialect/Index/IR/CMakeLists.txt b/mlir/lib/Dialect/Index/IR/CMakeLists.txt index e820eececa483c..53321f1ea3f25c 100644 --- a/mlir/lib/Dialect/Index/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Index/IR/CMakeLists.txt @@ -2,7 +2,6 @@ add_mlir_dialect_library(MLIRIndexDialect IndexAttrs.cpp IndexDialect.cpp IndexOps.cpp - InferIntRangeInterfaceImpls.cpp DEPENDS MLIRIndexOpsIncGen @@ -11,7 +10,6 @@ add_mlir_dialect_library(MLIRIndexDialect MLIRDialect MLIRIR MLIRCastInterfaces - MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRSideEffectInterfaces ) diff --git a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp deleted file mode 100644 index 6daa7640b017e4..00000000000000 --- a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp +++ /dev/null @@ -1,252 +0,0 @@ -//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Index/IR/IndexOps.h" -#include "mlir/Interfaces/InferIntRangeInterface.h" -#include "mlir/Interfaces/Utils/InferIntRangeCommon.h" - -#include "llvm/Support/Debug.h" - -#define DEBUG_TYPE "int-range-analysis" - -using namespace mlir; -using namespace mlir::index; -using namespace mlir::intrange; - -//===----------------------------------------------------------------------===// -// Constants -//===----------------------------------------------------------------------===// - -void ConstantOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - const APInt &value = getValue(); - setResultRange(getResult(), ConstantIntRanges::constant(value)); -} - -void BoolConstantOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - bool value = getValue(); - APInt asInt(/*numBits=*/1, value); - setResultRange(getResult(), ConstantIntRanges::constant(asInt)); -} - -//===----------------------------------------------------------------------===// -// Arithmec operations. All of these operations will have their results inferred -// using both the 64-bit values and truncated 32-bit values of their inputs, -// with the results being the union of those inferences, except where the -// truncation of the 64-bit result is equal to the 32-bit result (at which time -// we take the 64-bit result). 
-//===----------------------------------------------------------------------===// - -void AddOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferIndexOp(inferAdd, argRanges, CmpMode::Both)); -} - -void SubOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferIndexOp(inferSub, argRanges, CmpMode::Both)); -} - -void MulOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferIndexOp(inferMul, argRanges, CmpMode::Both)); -} - -void DivUOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferIndexOp(inferDivU, argRanges, CmpMode::Unsigned)); -} - -void DivSOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferIndexOp(inferDivS, argRanges, CmpMode::Signed)); -} - -void CeilDivUOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferIndexOp(inferCeilDivU, argRanges, CmpMode::Unsigned)); -} - -void CeilDivSOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferIndexOp(inferCeilDivS, argRanges, CmpMode::Signed)); -} - -void FloorDivSOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - return setResultRange( - getResult(), inferIndexOp(inferFloorDivS, argRanges, CmpMode::Signed)); -} - -void RemSOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferIndexOp(inferRemS, argRanges, CmpMode::Signed)); -} - -void RemUOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferIndexOp(inferRemU, argRanges, CmpMode::Unsigned)); -} - -void MaxSOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - 
inferIndexOp(inferMaxS, argRanges, CmpMode::Signed)); -} - -void MaxUOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferIndexOp(inferMaxU, argRanges, CmpMode::Unsigned)); -} - -void MinSOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferIndexOp(inferMinS, argRanges, CmpMode::Signed)); -} - -void MinUOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferIndexOp(inferMinU, argRanges, CmpMode::Unsigned)); -} - -void ShlOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferIndexOp(inferShl, argRanges, CmpMode::Both)); -} - -void ShrSOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferIndexOp(inferShrS, argRanges, CmpMode::Signed)); -} - -void ShrUOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferIndexOp(inferShrU, argRanges, CmpMode::Unsigned)); -} - -void AndOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferIndexOp(inferAnd, argRanges, CmpMode::Unsigned)); -} - -void OrOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferIndexOp(inferOr, argRanges, CmpMode::Unsigned)); -} - -void XOrOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferIndexOp(inferXor, argRanges, CmpMode::Unsigned)); -} - -//===----------------------------------------------------------------------===// -// Casts -//===----------------------------------------------------------------------===// - -static ConstantIntRanges makeLikeDest(const ConstantIntRanges &range, - unsigned srcWidth, unsigned destWidth, - bool isSigned) { - if (srcWidth < destWidth) - return 
isSigned ? extSIRange(range, destWidth) - : extUIRange(range, destWidth); - if (srcWidth > destWidth) - return truncRange(range, destWidth); - return range; -} - -// When casting to `index`, we will take the union of the possible fixed-width -// casts. -static ConstantIntRanges inferIndexCast(const ConstantIntRanges &range, - Type sourceType, Type destType, - bool isSigned) { - unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType); - unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); - if (sourceType.isIndex()) - return makeLikeDest(range, srcWidth, destWidth, isSigned); - // We are casting to indexs, so use the union of the 32-bit and 64-bit casts - ConstantIntRanges storageRange = - makeLikeDest(range, srcWidth, destWidth, isSigned); - ConstantIntRanges minWidthRange = - makeLikeDest(range, srcWidth, indexMinWidth, isSigned); - ConstantIntRanges minWidthExt = extRange(minWidthRange, destWidth); - ConstantIntRanges ret = storageRange.rangeUnion(minWidthExt); - return ret; -} - -void CastSOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - Type sourceType = getOperand().getType(); - Type destType = getResult().getType(); - setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType, - /*isSigned=*/true)); -} - -void CastUOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - Type sourceType = getOperand().getType(); - Type destType = getResult().getType(); - setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType, - /*isSigned=*/false)); -} - -//===----------------------------------------------------------------------===// -// CmpOp -//===----------------------------------------------------------------------===// - -void CmpOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - index::IndexCmpPredicate indexPred = getPred(); - intrange::CmpPredicate pred = static_cast(indexPred); - const ConstantIntRanges &lhs = 
argRanges[0], &rhs = argRanges[1]; - - APInt min = APInt::getZero(1); - APInt max = APInt::getAllOnesValue(1); - - Optional truthValue64 = intrange::evaluatePred(pred, lhs, rhs); - - ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth), - rhsTrunc = truncRange(rhs, indexMinWidth); - Optional truthValue32 = - intrange::evaluatePred(pred, lhsTrunc, rhsTrunc); - - if (truthValue64 == truthValue32) { - if (truthValue64.has_value() && *truthValue64) - min = max; - else if (truthValue64.has_value() && !(*truthValue64)) - max = min; - } - setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); -} - -//===----------------------------------------------------------------------===// -// SizeOf, which is bounded between the two supported bitwidth (32 and 64). -//===----------------------------------------------------------------------===// - -void SizeOfOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - unsigned storageWidth = - ConstantIntRanges::getStorageBitwidth(getResult().getType()); - APInt min(/*numBits=*/storageWidth, indexMinWidth); - APInt max(/*numBits=*/storageWidth, indexMaxWidth); - setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); -} diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt index 38ad0e4a2231c7..a7cdbb5b3a6fec 100644 --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -51,5 +51,3 @@ add_mlir_interface_library(SideEffectInterfaces) add_mlir_interface_library(TilingInterface) add_mlir_interface_library(VectorInterfaces) add_mlir_interface_library(ViewLikeInterface) - -add_subdirectory(Utils) diff --git a/mlir/lib/Interfaces/Utils/CMakeLists.txt b/mlir/lib/Interfaces/Utils/CMakeLists.txt deleted file mode 100644 index ece6c8e46ffea9..00000000000000 --- a/mlir/lib/Interfaces/Utils/CMakeLists.txt +++ /dev/null @@ -1,13 +0,0 @@ -add_mlir_library(MLIRInferIntRangeCommon - InferIntRangeCommon.cpp - - ADDITIONAL_HEADER_DIRS 
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces/Utils - - DEPENDS - MLIRInferIntRangeInterfaceIncGen - - LINK_LIBS PUBLIC - MLIRInferIntRangeInterface - MLIRIR -) diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp deleted file mode 100644 index c81f004ecf5f92..00000000000000 --- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp +++ /dev/null @@ -1,663 +0,0 @@ -//===- InferIntRangeCommon.cpp - Inference for common ops ------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file contains implementations of range inference for operations that are -// common to both the `arith` and `index` dialects to facilitate reuse. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Interfaces/Utils/InferIntRangeCommon.h" - -#include "mlir/Interfaces/InferIntRangeInterface.h" - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" - -#include "llvm/Support/Debug.h" - -#include -#include - -using namespace mlir; - -#define DEBUG_TYPE "int-range-analysis" - -//===----------------------------------------------------------------------===// -// General utilities -//===----------------------------------------------------------------------===// - -/// Function that evaluates the result of doing something on arithmetic -/// constants and returns std::nullopt on overflow. -using ConstArithFn = - function_ref(const APInt &, const APInt &)>; - -/// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible, -/// If either computation overflows, make the result unbounded. 
-static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft, - const APInt &minRight, - const APInt &maxLeft, - const APInt &maxRight, bool isSigned) { - std::optional maybeMin = op(minLeft, minRight); - std::optional maybeMax = op(maxLeft, maxRight); - if (maybeMin && maybeMax) - return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned); - return ConstantIntRanges::maxRange(minLeft.getBitWidth()); -} - -/// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`, -/// ignoring unbounded values. Returns the maximal range if `op` overflows. -static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef lhs, - ArrayRef rhs, bool isSigned) { - unsigned width = lhs[0].getBitWidth(); - APInt min = - isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width); - APInt max = - isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width); - for (const APInt &left : lhs) { - for (const APInt &right : rhs) { - std::optional maybeThisResult = op(left, right); - if (!maybeThisResult) - return ConstantIntRanges::maxRange(width); - APInt result = std::move(*maybeThisResult); - min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min; - max = (isSigned ? result.sgt(max) : result.ugt(max)) ? 
result : max; - } - } - return ConstantIntRanges::range(min, max, isSigned); -} - -//===----------------------------------------------------------------------===// -// Ext, trunc, index op handling -//===----------------------------------------------------------------------===// - -ConstantIntRanges -mlir::intrange::inferIndexOp(InferRangeFn inferFn, - ArrayRef argRanges, - intrange::CmpMode mode) { - ConstantIntRanges sixtyFour = inferFn(argRanges); - SmallVector truncated; - llvm::transform(argRanges, std::back_inserter(truncated), - [](const ConstantIntRanges &range) { - return truncRange(range, /*destWidth=*/indexMinWidth); - }); - ConstantIntRanges thirtyTwo = inferFn(truncated); - ConstantIntRanges thirtyTwoAsSixtyFour = - extRange(thirtyTwo, /*destWidth=*/indexMaxWidth); - ConstantIntRanges sixtyFourAsThirtyTwo = - truncRange(sixtyFour, /*destWidth=*/indexMinWidth); - - LLVM_DEBUG(llvm::dbgs() << "Index handling: 64-bit result = " << sixtyFour - << " 32-bit = " << thirtyTwo << "\n"); - bool truncEqual = false; - switch (mode) { - case intrange::CmpMode::Both: - truncEqual = (thirtyTwo == sixtyFourAsThirtyTwo); - break; - case intrange::CmpMode::Signed: - truncEqual = (thirtyTwo.smin() == sixtyFourAsThirtyTwo.smin() && - thirtyTwo.smax() == sixtyFourAsThirtyTwo.smax()); - break; - case intrange::CmpMode::Unsigned: - truncEqual = (thirtyTwo.umin() == sixtyFourAsThirtyTwo.umin() && - thirtyTwo.umax() == sixtyFourAsThirtyTwo.umax()); - break; - } - if (truncEqual) - // Returing the 64-bit result preserves more information. 
- return sixtyFour; - ConstantIntRanges merged = sixtyFour.rangeUnion(thirtyTwoAsSixtyFour); - return merged; -} - -ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range, - unsigned int destWidth) { - APInt umin = range.umin().zext(destWidth); - APInt umax = range.umax().zext(destWidth); - APInt smin = range.smin().sext(destWidth); - APInt smax = range.smax().sext(destWidth); - return {umin, umax, smin, smax}; -} - -ConstantIntRanges mlir::intrange::extUIRange(const ConstantIntRanges &range, - unsigned destWidth) { - APInt umin = range.umin().zext(destWidth); - APInt umax = range.umax().zext(destWidth); - return ConstantIntRanges::fromUnsigned(umin, umax); -} - -ConstantIntRanges mlir::intrange::extSIRange(const ConstantIntRanges &range, - unsigned destWidth) { - APInt smin = range.smin().sext(destWidth); - APInt smax = range.smax().sext(destWidth); - return ConstantIntRanges::fromSigned(smin, smax); -} - -ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range, - unsigned int destWidth) { - // If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb], - // the range of the resulting value is not contiguous ind includes 0. - // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2], - // but you can't truncate [255, 257] similarly. - bool hasUnsignedRollover = - range.umin().lshr(destWidth) != range.umax().lshr(destWidth); - APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth) - : range.umin().trunc(destWidth); - APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth) - : range.umax().trunc(destWidth); - - // Signed post-truncation rollover will not occur when either: - // - The high parts of the min and max, plus the sign bit, are the same - // - The high halves + sign bit of the min and max are either all 1s or all 0s - // and you won't create a [positive, negative] range by truncating. 
- // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8 - // but not [255, 257]_i16 to a range of i8s. You can also truncate - // [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16. - // You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e) - // will truncate to 0x7e, which is greater than 0 - APInt sminHighPart = range.smin().ashr(destWidth - 1); - APInt smaxHighPart = range.smax().ashr(destWidth - 1); - bool hasSignedOverflow = - (sminHighPart != smaxHighPart) && - !(sminHighPart.isAllOnes() && - (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) && - !(sminHighPart.isZero() && smaxHighPart.isZero()); - APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth) - : range.smin().trunc(destWidth); - APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth) - : range.smax().trunc(destWidth); - return {umin, umax, smin, smax}; -} - -//===----------------------------------------------------------------------===// -// Addition -//===----------------------------------------------------------------------===// - -ConstantIntRanges -mlir::intrange::inferAdd(ArrayRef argRanges) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - ConstArithFn uadd = [](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.uadd_ov(b, overflowed); - return overflowed ? std::optional() : result; - }; - ConstArithFn sadd = [](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.sadd_ov(b, overflowed); - return overflowed ? 
std::optional() : result; - }; - - ConstantIntRanges urange = computeBoundsBy( - uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false); - ConstantIntRanges srange = computeBoundsBy( - sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true); - return urange.intersection(srange); -} - -//===----------------------------------------------------------------------===// -// Subtraction -//===----------------------------------------------------------------------===// - -ConstantIntRanges -mlir::intrange::inferSub(ArrayRef argRanges) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - ConstArithFn usub = [](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.usub_ov(b, overflowed); - return overflowed ? std::optional() : result; - }; - ConstArithFn ssub = [](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.ssub_ov(b, overflowed); - return overflowed ? std::optional() : result; - }; - ConstantIntRanges urange = computeBoundsBy( - usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false); - ConstantIntRanges srange = computeBoundsBy( - ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true); - return urange.intersection(srange); -} - -//===----------------------------------------------------------------------===// -// Multiplication -//===----------------------------------------------------------------------===// - -ConstantIntRanges -mlir::intrange::inferMul(ArrayRef argRanges) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - ConstArithFn umul = [](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.umul_ov(b, overflowed); - return overflowed ? 
std::optional() : result; - }; - ConstArithFn smul = [](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.smul_ov(b, overflowed); - return overflowed ? std::optional() : result; - }; - - ConstantIntRanges urange = - minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, - /*isSigned=*/false); - ConstantIntRanges srange = - minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()}, - /*isSigned=*/true); - return urange.intersection(srange); -} - -//===----------------------------------------------------------------------===// -// DivU, CeilDivU (Unsigned division) -//===----------------------------------------------------------------------===// - -/// Fix up division results (ex. for ceiling and floor), returning an APInt -/// if there has been no overflow -using DivisionFixupFn = function_ref( - const APInt &lhs, const APInt &rhs, const APInt &result)>; - -static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs, - const ConstantIntRanges &rhs, - DivisionFixupFn fixup) { - const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(), - &rhsMax = rhs.umax(); - - if (!rhsMin.isZero()) { - auto udiv = [&fixup](const APInt &a, - const APInt &b) -> std::optional { - return fixup(a, b, a.udiv(b)); - }; - return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, - /*isSigned=*/false); - } - // Otherwise, it's possible we might divide by 0. 
- return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); -} - -ConstantIntRanges -mlir::intrange::inferDivU(ArrayRef argRanges) { - return inferDivURange(argRanges[0], argRanges[1], - [](const APInt &lhs, const APInt &rhs, - const APInt &result) { return result; }); -} - -ConstantIntRanges -mlir::intrange::inferCeilDivU(ArrayRef argRanges) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - DivisionFixupFn ceilDivUIFix = - [](const APInt &lhs, const APInt &rhs, - const APInt &result) -> std::optional { - if (!lhs.urem(rhs).isZero()) { - bool overflowed = false; - APInt corrected = - result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed); - return overflowed ? std::optional() : corrected; - } - return result; - }; - return inferDivURange(lhs, rhs, ceilDivUIFix); -} - -//===----------------------------------------------------------------------===// -// DivS, CeilDivS, FloorDivS (Signed division) -//===----------------------------------------------------------------------===// - -static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs, - const ConstantIntRanges &rhs, - DivisionFixupFn fixup) { - const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), - &rhsMax = rhs.smax(); - bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative(); - - if (canDivide) { - auto sdiv = [&fixup](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.sdiv_ov(b, overflowed); - return overflowed ? 
std::optional() : fixup(a, b, result); - }; - return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, - /*isSigned=*/true); - } - return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); -} - -ConstantIntRanges -mlir::intrange::inferDivS(ArrayRef argRanges) { - return inferDivSRange(argRanges[0], argRanges[1], - [](const APInt &lhs, const APInt &rhs, - const APInt &result) { return result; }); -} - -ConstantIntRanges -mlir::intrange::inferCeilDivS(ArrayRef argRanges) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - DivisionFixupFn ceilDivSIFix = - [](const APInt &lhs, const APInt &rhs, - const APInt &result) -> std::optional { - if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) { - bool overflowed = false; - APInt corrected = - result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed); - return overflowed ? std::optional() : corrected; - } - return result; - }; - return inferDivSRange(lhs, rhs, ceilDivSIFix); -} - -ConstantIntRanges -mlir::intrange::inferFloorDivS(ArrayRef argRanges) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - DivisionFixupFn floorDivSIFix = - [](const APInt &lhs, const APInt &rhs, - const APInt &result) -> std::optional { - if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) { - bool overflowed = false; - APInt corrected = - result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed); - return overflowed ? 
std::optional() : corrected; - } - return result; - }; - return inferDivSRange(lhs, rhs, floorDivSIFix); -} - -//===----------------------------------------------------------------------===// -// Signed remainder (RemS) -//===----------------------------------------------------------------------===// - -ConstantIntRanges -mlir::intrange::inferRemS(ArrayRef argRanges) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), - &rhsMax = rhs.smax(); - - unsigned width = rhsMax.getBitWidth(); - APInt smin = APInt::getSignedMinValue(width); - APInt smax = APInt::getSignedMaxValue(width); - // No bounds if zero could be a divisor. - bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative()); - if (canBound) { - APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs(); - bool canNegativeDividend = lhsMin.isNegative(); - bool canPositiveDividend = lhsMax.isStrictlyPositive(); - APInt zero = APInt::getZero(maxDivisor.getBitWidth()); - APInt maxPositiveResult = maxDivisor - 1; - APInt minNegativeResult = -maxPositiveResult; - smin = canNegativeDividend ? minNegativeResult : zero; - smax = canPositiveDividend ? maxPositiveResult : zero; - // Special case: sweeping out a contiguous range in N/[modulus]. 
- if (rhsMin == rhsMax) { - if ((lhsMax - lhsMin).ult(maxDivisor)) { - APInt minRem = lhsMin.srem(maxDivisor); - APInt maxRem = lhsMax.srem(maxDivisor); - if (minRem.sle(maxRem)) { - smin = minRem; - smax = maxRem; - } - } - } - } - return ConstantIntRanges::fromSigned(smin, smax); -} - -//===----------------------------------------------------------------------===// -// Unsigned remainder (RemU) -//===----------------------------------------------------------------------===// - -ConstantIntRanges -mlir::intrange::inferRemU(ArrayRef argRanges) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax(); - - unsigned width = rhsMin.getBitWidth(); - APInt umin = APInt::getZero(width); - APInt umax = APInt::getMaxValue(width); - - if (!rhsMin.isZero()) { - umax = rhsMax - 1; - // Special case: sweeping out a contiguous range in N/[modulus] - if (rhsMin == rhsMax) { - const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(); - if ((lhsMax - lhsMin).ult(rhsMax)) { - APInt minRem = lhsMin.urem(rhsMax); - APInt maxRem = lhsMax.urem(rhsMax); - if (minRem.ule(maxRem)) { - umin = minRem; - umax = maxRem; - } - } - } - } - return ConstantIntRanges::fromUnsigned(umin, umax); -} - -//===----------------------------------------------------------------------===// -// Max and min (MaxS, MaxU, MinS, MinU) -//===----------------------------------------------------------------------===// - -ConstantIntRanges -mlir::intrange::inferMaxS(ArrayRef argRanges) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin(); - const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax(); - return ConstantIntRanges::fromSigned(smin, smax); -} - -ConstantIntRanges -mlir::intrange::inferMaxU(ArrayRef argRanges) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - const APInt &umin = lhs.umin().ugt(rhs.umin()) ? 
lhs.umin() : rhs.umin(); - const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax(); - return ConstantIntRanges::fromUnsigned(umin, umax); -} - -ConstantIntRanges -mlir::intrange::inferMinS(ArrayRef argRanges) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin(); - const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax(); - return ConstantIntRanges::fromSigned(smin, smax); -} - -ConstantIntRanges -mlir::intrange::inferMinU(ArrayRef argRanges) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin(); - const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax(); - return ConstantIntRanges::fromUnsigned(umin, umax); -} - -//===----------------------------------------------------------------------===// -// Bitwise operators (And, Or, Xor) -//===----------------------------------------------------------------------===// - -/// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???, -/// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits -/// that both bonuds have in common. This gives us a consertive approximation -/// for what values can be passed to bitwise operations. 
-static std::tuple -widenBitwiseBounds(const ConstantIntRanges &bound) { - APInt leftVal = bound.umin(), rightVal = bound.umax(); - unsigned bitwidth = leftVal.getBitWidth(); - unsigned differingBits = bitwidth - (leftVal ^ rightVal).countLeadingZeros(); - leftVal.clearLowBits(differingBits); - rightVal.setLowBits(differingBits); - return std::make_tuple(std::move(leftVal), std::move(rightVal)); -} - -ConstantIntRanges -mlir::intrange::inferAnd(ArrayRef argRanges) { - auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); - auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); - auto andi = [](const APInt &a, const APInt &b) -> std::optional { - return a & b; - }; - return minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, - /*isSigned=*/false); -} - -ConstantIntRanges -mlir::intrange::inferOr(ArrayRef argRanges) { - auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); - auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); - auto ori = [](const APInt &a, const APInt &b) -> std::optional { - return a | b; - }; - return minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, - /*isSigned=*/false); -} - -ConstantIntRanges -mlir::intrange::inferXor(ArrayRef argRanges) { - auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); - auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); - auto xori = [](const APInt &a, const APInt &b) -> std::optional { - return a ^ b; - }; - return minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, - /*isSigned=*/false); -} - -//===----------------------------------------------------------------------===// -// Shifts (Shl, ShrS, ShrU) -//===----------------------------------------------------------------------===// - -ConstantIntRanges -mlir::intrange::inferShl(ArrayRef argRanges) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - ConstArithFn shl = [](const APInt &l, - const APInt &r) -> std::optional { - return r.uge(r.getBitWidth()) ? 
std::optional() : l.shl(r); - }; - ConstantIntRanges urange = - minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, - /*isSigned=*/false); - ConstantIntRanges srange = - minMaxBy(shl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()}, - /*isSigned=*/true); - return urange.intersection(srange); -} - -ConstantIntRanges -mlir::intrange::inferShrS(ArrayRef argRanges) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - ConstArithFn ashr = [](const APInt &l, - const APInt &r) -> std::optional { - return r.uge(r.getBitWidth()) ? std::optional() : l.ashr(r); - }; - - return minMaxBy(ashr, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()}, - /*isSigned=*/true); -} - -ConstantIntRanges -mlir::intrange::inferShrU(ArrayRef argRanges) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - ConstArithFn lshr = [](const APInt &l, - const APInt &r) -> std::optional { - return r.uge(r.getBitWidth()) ? std::optional() : l.lshr(r); - }; - return minMaxBy(lshr, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, - /*isSigned=*/false); -} - -//===----------------------------------------------------------------------===// -// Comparisons (Cmp) -//===----------------------------------------------------------------------===// - -static intrange::CmpPredicate invertPredicate(intrange::CmpPredicate pred) { - switch (pred) { - case intrange::CmpPredicate::eq: - return intrange::CmpPredicate::ne; - case intrange::CmpPredicate::ne: - return intrange::CmpPredicate::eq; - case intrange::CmpPredicate::slt: - return intrange::CmpPredicate::sge; - case intrange::CmpPredicate::sle: - return intrange::CmpPredicate::sgt; - case intrange::CmpPredicate::sgt: - return intrange::CmpPredicate::sle; - case intrange::CmpPredicate::sge: - return intrange::CmpPredicate::slt; - case intrange::CmpPredicate::ult: - return intrange::CmpPredicate::uge; - case intrange::CmpPredicate::ule: - return intrange::CmpPredicate::ugt; - case intrange::CmpPredicate::ugt: 
- return intrange::CmpPredicate::ule; - case intrange::CmpPredicate::uge: - return intrange::CmpPredicate::ult; - } - llvm_unreachable("unknown cmp predicate value"); -} - -static bool isStaticallyTrue(intrange::CmpPredicate pred, - const ConstantIntRanges &lhs, - const ConstantIntRanges &rhs) { - switch (pred) { - case intrange::CmpPredicate::sle: - return lhs.smax().sle(rhs.smin()); - case intrange::CmpPredicate::slt: - return lhs.smax().slt(rhs.smin()); - case intrange::CmpPredicate::ule: - return lhs.umax().ule(rhs.umin()); - case intrange::CmpPredicate::ult: - return lhs.umax().ult(rhs.umin()); - case intrange::CmpPredicate::sge: - return lhs.smin().sge(rhs.smax()); - case intrange::CmpPredicate::sgt: - return lhs.smin().sgt(rhs.smax()); - case intrange::CmpPredicate::uge: - return lhs.umin().uge(rhs.umax()); - case intrange::CmpPredicate::ugt: - return lhs.umin().ugt(rhs.umax()); - case intrange::CmpPredicate::eq: { - std::optional lhsConst = lhs.getConstantValue(); - std::optional rhsConst = rhs.getConstantValue(); - return lhsConst && rhsConst && lhsConst == rhsConst; - } - case intrange::CmpPredicate::ne: { - // While equality requires that there is an interpration of the preceeding - // computations that produces equal constants, whether that be signed or - // unsigned, statically determining inequality requires that neither - // interpretation produce potentially overlapping ranges. 
- bool sne = isStaticallyTrue(intrange::CmpPredicate::slt, lhs, rhs) || - isStaticallyTrue(intrange::CmpPredicate::sgt, lhs, rhs); - bool une = isStaticallyTrue(intrange::CmpPredicate::ult, lhs, rhs) || - isStaticallyTrue(intrange::CmpPredicate::ugt, lhs, rhs); - return sne && une; - } - } - return false; -} - -std::optional mlir::intrange::evaluatePred(CmpPredicate pred, - const ConstantIntRanges &lhs, - const ConstantIntRanges &rhs) { - if (isStaticallyTrue(pred, lhs, rhs)) - return true; - if (isStaticallyTrue(invertPredicate(pred), lhs, rhs)) - return false; - return std::nullopt; -} diff --git a/mlir/test/Dialect/Index/int-range-inference.mlir b/mlir/test/Dialect/Index/int-range-inference.mlir deleted file mode 100644 index 2784d5fd5cf701..00000000000000 --- a/mlir/test/Dialect/Index/int-range-inference.mlir +++ /dev/null @@ -1,66 +0,0 @@ -// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s - -// Most operations are covered by the `arith` tests, which use the same code -// Here, we add a few tests to ensure the "index can be 32- or 64-bit" handling -// code is operating as expected. 
- -// CHECK-LABEL: func @add_same_for_both -// CHECK: %[[true:.*]] = index.bool.constant true -// CHECK: return %[[true]] -func.func @add_same_for_both(%arg0 : index) -> i1 { - %c1 = index.constant 1 - %calmostBig = index.constant 0xfffffffe - %0 = index.minu %arg0, %calmostBig - %1 = index.add %0, %c1 - %2 = index.cmp uge(%1, %c1) - func.return %2 : i1 -} - -// CHECK-LABEL: func @add_unsigned_ov -// CHECK: %[[uge:.*]] = index.cmp uge -// CHECK: return %[[uge]] -func.func @add_unsigned_ov(%arg0 : index) -> i1 { - %c1 = index.constant 1 - %cu32_max = index.constant 0xffffffff - %0 = index.minu %arg0, %cu32_max - %1 = index.add %0, %c1 - // On 32-bit, the add could wrap, so the result doesn't have to be >= 1 - %2 = index.cmp uge(%1, %c1) - func.return %2 : i1 -} - -// CHECK-LABEL: func @add_signed_ov -// CHECK: %[[sge:.*]] = index.cmp sge -// CHECK: return %[[sge]] -func.func @add_signed_ov(%arg0 : index) -> i1 { - %c0 = index.constant 0 - %c1 = index.constant 1 - %ci32_max = index.constant 0x7fffffff - %0 = index.minu %arg0, %ci32_max - %1 = index.add %0, %c1 - // On 32-bit, the add could wrap, so the result doesn't have to be positive - %2 = index.cmp sge(%1, %c0) - func.return %2 : i1 -} - -// CHECK-LABEL: func @add_big -// CHECK: %[[true:.*]] = index.bool.constant true -// CHECK: return %[[true]] -func.func @add_big(%arg0 : index) -> i1 { - %c1 = index.constant 1 - %cmin = index.constant 0x300000000 - %cmax = index.constant 0x30000ffff - // Note: the order of the clamps matters. - // If you go max, then min, you infer the ranges [0x300...0, 0xff..ff] - // and then [0x30...0000, 0x30...ffff] - // If you switch the order of the below operations, you instead first infer - // the range [0,0x3...ffff]. Then, the min inference can't constraint - // this intermediate, since in the 32-bit case we could have, for example - // trunc(%arg0 = 0x2ffffffff) = 0xffffffff > trunc(0x30000ffff) = 0x0000ffff - // which means we can't do any inference. 
- %0 = index.maxu %arg0, %cmin - %1 = index.minu %0, %cmax - %2 = index.add %1, %c1 - %3 = index.cmp uge(%1, %cmin) - func.return %3 : i1 -} From bec49b1d803ced901e14aca154ffc106939100e5 Mon Sep 17 00:00:00 2001 From: Joseph Huber Date: Thu, 19 Jan 2023 11:20:07 -0600 Subject: [PATCH 45/61] [LinkerWrapper] Use `clang` to perform the device linking Right now in the linker wrapper we manually invoke a lot of the toolchain programs. This reproduces a lot of logic that is already handled in clang. Since D140158 we can now target all supported toolchains directly via cross-compilation. This patch changes the linker wrapper to consolidate all the alternate linking and assembler steps into a generic call to `clang` and let clang handle the argument handling. This heavily simplifies the interface. Reviewed By: tra, JonChesterfield Differential Revision: https://reviews.llvm.org/D142133 --- clang/test/Driver/linker-wrapper.c | 76 ++--- .../ClangLinkerWrapper.cpp | 272 +++--------------- 2 files changed, 58 insertions(+), 290 deletions(-) diff --git a/clang/test/Driver/linker-wrapper.c b/clang/test/Driver/linker-wrapper.c index b052c077b92cdc..d515e316da0f40 100644 --- a/clang/test/Driver/linker-wrapper.c +++ b/clang/test/Driver/linker-wrapper.c @@ -12,17 +12,11 @@ // RUN: %clang -cc1 %s -triple x86_64-unknown-linux-gnu -emit-obj -o %t.o -fembed-offload-object=%t.out // RUN: clang-linker-wrapper --host-triple=x86_64-unknown-linux-gnu --dry-run \ // RUN: --linker-path=/usr/bin/ld -- %t.o -o a.out 2>&1 | FileCheck %s --check-prefix=NVPTX-LINK - -// NVPTX-LINK: nvlink{{.*}}-m64 -o {{.*}}.img -arch sm_70 {{.*}}.o {{.*}}.o - -// RUN: clang-offload-packager -o %t.out \ -// RUN: --image=file=%t.elf.o,kind=openmp,triple=nvptx64-nvidia-cuda,arch=sm_70 \ -// RUN: --image=file=%t.elf.o,kind=openmp,triple=nvptx64-nvidia-cuda,arch=sm_70 // RUN: %clang -cc1 %s -triple x86_64-unknown-linux-gnu -emit-llvm-bc -o %t.o -fembed-offload-object=%t.out // RUN: clang-linker-wrapper 
--host-triple=x86_64-unknown-linux-gnu --dry-run \ -// RUN: --linker-path=/usr/bin/ld -- %t.o -o a.out 2>&1 | FileCheck %s --check-prefix=HOST-BC +// RUN: --linker-path=/usr/bin/ld -- %t.o -o a.out 2>&1 | FileCheck %s --check-prefix=NVPTX-LINK -// HOST-BC: nvlink{{.*}}-m64 -o {{.*}}.img -arch sm_70 {{.*}}.o {{.*}}.o +// NVPTX-LINK: clang{{.*}} -o {{.*}}.img --target=nvptx64-nvidia-cuda -march=sm_70 -O2 -Wl,--no-undefined {{.*}}.o {{.*}}.o // RUN: clang-offload-packager -o %t.out \ // RUN: --image=file=%t.elf.o,kind=openmp,triple=nvptx64-nvidia-cuda,arch=sm_70 \ @@ -31,16 +25,7 @@ // RUN: clang-linker-wrapper --host-triple=x86_64-unknown-linux-gnu --dry-run --device-debug -O0 \ // RUN: --linker-path=/usr/bin/ld -- %t.o -o a.out 2>&1 | FileCheck %s --check-prefix=NVPTX-LINK-DEBUG -// NVPTX-LINK-DEBUG: nvlink{{.*}}-m64 -g -o {{.*}}.img -arch sm_70 {{.*}}.o {{.*}}.o - -// RUN: clang-offload-packager -o %t.out \ -// RUN: --image=file=%t.nvptx.bc,kind=openmp,triple=nvptx64-nvidia-cuda,arch=sm_70 \ -// RUN: --image=file=%t.nvptx.bc,kind=openmp,triple=nvptx64-nvidia-cuda,arch=sm_70 -// RUN: %clang -cc1 %s -triple x86_64-unknown-linux-gnu -emit-obj -o %t.o -fembed-offload-object=%t.out -// RUN: clang-linker-wrapper --host-triple=x86_64-unknown-linux-gnu --dry-run --device-debug -O2 \ -// RUN: --linker-path=/usr/bin/ld -- %t.o -o a.out 2>&1 | FileCheck %s --check-prefix=NVPTX-LINK-DEBUG-LTO - -// NVPTX-LINK-DEBUG-LTO: ptxas{{.*}}-m64 -o {{.*}}.cubin -O2 --gpu-name sm_70 -lineinfo {{.*}}.s +// NVPTX-LINK-DEBUG: clang{{.*}} -o {{.*}}.img --target=nvptx64-nvidia-cuda -march=sm_70 -O2 -Wl,--no-undefined -g {{.*}}.o {{.*}}.o // RUN: clang-offload-packager -o %t.out \ // RUN: --image=file=%t.elf.o,kind=openmp,triple=amdgcn-amd-amdhsa,arch=gfx908 \ @@ -49,7 +34,7 @@ // RUN: clang-linker-wrapper --host-triple=x86_64-unknown-linux-gnu --dry-run \ // RUN: --linker-path=/usr/bin/ld -- %t.o -o a.out 2>&1 | FileCheck %s --check-prefix=AMDGPU-LINK -// AMDGPU-LINK: lld{{.*}}-flavor gnu 
--no-undefined -shared -plugin-opt=-amdgpu-internalize-symbols -plugin-opt=mcpu=gfx908 -o {{.*}}.img {{.*}}.o {{.*}}.o +// AMDGPU-LINK: clang{{.*}} -o {{.*}}.img --target=amdgcn-amd-amdhsa -mcpu=gfx908 -O2 -Wl,--no-undefined {{.*}}.o {{.*}}.o // RUN: clang-offload-packager -o %t.out \ // RUN: --image=file=%t.amdgpu.bc,kind=openmp,triple=amdgcn-amd-amdhsa,arch=gfx1030 \ @@ -58,17 +43,7 @@ // RUN: clang-linker-wrapper --host-triple=x86_64-unknown-linux-gnu --dry-run --save-temps -O2 \ // RUN: --linker-path=/usr/bin/ld -- %t.o -o a.out 2>&1 | FileCheck %s --check-prefix=AMDGPU-LTO-TEMPS -// AMDGPU-LTO-TEMPS: clang{{.*}}-o [[OBJ:.+]] -fPIC -c --target=amdgcn-amd-amdhsa -O2 -mcpu=gfx1030 {{.*}}.s -// AMDGPU-LTO-TEMPS: lld{{.*}}-flavor gnu --no-undefined -shared -plugin-opt=-amdgpu-internalize-symbols -plugin-opt=mcpu=gfx1030 -o {{.*}}.img {{.*}}.o - -// RUN: clang-offload-packager -o %t.out \ -// RUN: --image=file=%t.amdgpu.bc,kind=openmp,triple=amdgcn-amd-amdhsa,arch=gfx908 \ -// RUN: --image=file=%t.amdgpu.bc,kind=openmp,triple=amdgcn-amd-amdhsa,arch=gfx908 -// RUN: %clang -cc1 %s -triple x86_64-unknown-linux-gnu -emit-obj -o %t.o -fembed-offload-object=%t.out -// RUN: clang-linker-wrapper --host-triple=x86_64-unknown-linux-gnu --dry-run \ -// RUN: --linker-path=/usr/bin/ld -- %t.o -o a.out 2>&1 | FileCheck %s --check-prefix=AMDGPU-LINK-LTO - -// AMDGPU-LINK-LTO: lld{{.*}}-flavor gnu --no-undefined -shared -plugin-opt=-amdgpu-internalize-symbols -plugin-opt=mcpu=gfx908 -o {{.*}}.img {{.*}}.o +// AMDGPU-LTO-TEMPS: clang{{.*}} -o {{.*}}.img --target=amdgcn-amd-amdhsa -mcpu=gfx1030 -O2 -Wl,--no-undefined -save-temps {{.*}}.s // RUN: clang-offload-packager -o %t.out \ // RUN: --image=file=%t.elf.o,kind=openmp,triple=x86_64-unknown-linux-gnu \ @@ -77,7 +52,7 @@ // RUN: clang-linker-wrapper --host-triple=x86_64-unknown-linux-gnu --dry-run \ // RUN: --linker-path=/usr/bin/ld.lld -- %t.o -o a.out 2>&1 | FileCheck %s --check-prefix=CPU-LINK -// CPU-LINK: ld.lld{{.*}}-m 
elf_x86_64 -shared -Bsymbolic -o {{.*}}.img {{.*}}.o {{.*}}.o +// CPU-LINK: clang{{.*}} -o {{.*}}.img --target=x86_64-unknown-linux-gnu -march=native -O2 -Wl,--no-undefined -Bsymbolic -shared {{.*}}.o {{.*}}.o // RUN: %clang -cc1 %s -triple x86_64-unknown-linux-gnu -emit-obj -o %t.o // RUN: clang-linker-wrapper --dry-run --host-triple=x86_64-unknown-linux-gnu -mllvm -openmp-opt-disable \ @@ -86,25 +61,6 @@ // HOST-LINK: ld.lld{{.*}}-a -b -c {{.*}}.o -o a.out // HOST-LINK-NOT: ld.lld{{.*}}-abc -// RUN: clang-offload-packager -o %t.out \ -// RUN: --image=file=%t.nvptx.bc,kind=openmp,triple=nvptx64-nvidia-cuda,arch=sm_70 \ -// RUN: --image=file=%t.nvptx.bc,kind=openmp,triple=nvptx64-nvidia-cuda,arch=sm_70 -// RUN: %clang -cc1 %s -triple x86_64-unknown-linux-gnu -emit-obj -o %t.o -fembed-offload-object=%t.out -// RUN: clang-linker-wrapper --host-triple=x86_64-unknown-linux-gnu --dry-run \ -// RUN: --linker-path=/usr/bin/ld -- %t.o -o a.out 2>&1 | FileCheck %s --check-prefix=LTO - -// LTO: ptxas{{.*}}-m64 -o {{.*}}.cubin -O2 --gpu-name sm_70 {{.*}}.s -// LTO-NOT: nvlink - -// RUN: clang-offload-packager -o %t.out \ -// RUN: --image=file=%t.elf.o,kind=openmp,triple=nvptx64-nvidia-cuda,arch=sm_70 \ -// RUN: --image=file=%t.elf.o,kind=cuda,triple=nvptx64-nvidia-cuda,arch=sm_70 -// RUN: %clang -cc1 %s -triple x86_64-unknown-linux-gnu -emit-obj -o %t.o -fembed-offload-object=%t.out -// RUN: clang-linker-wrapper --host-triple=x86_64-unknown-linux-gnu --dry-run \ -// RUN: --linker-path=/usr/bin/ld -- %t.o -o a.out 2>&1 | FileCheck %s --check-prefix=CUDA-OMP-LINK - -// CUDA-OMP-LINK: nvlink{{.*}}-m64 -o {{.*}}.img -arch sm_70 {{.*}}.o {{.*}}.o - // RUN: clang-offload-packager -o %t-lib.out \ // RUN: --image=file=%t.elf.o,kind=openmp,triple=nvptx64-nvidia-cuda,arch=sm_70 \ // RUN: --image=file=%t.elf.o,kind=cuda,triple=nvptx64-nvidia-cuda,arch=sm_52 @@ -116,8 +72,8 @@ // RUN: clang-linker-wrapper --host-triple=x86_64-unknown-linux-gnu --dry-run \ // RUN: 
--linker-path=/usr/bin/ld -- %t.a %t-obj.o -o a.out 2>&1 | FileCheck %s --check-prefix=STATIC-LIBRARY -// STATIC-LIBRARY: nvlink{{.*}} -arch sm_70 -// STATIC-LIBRARY-NOT: nvlink{{.*}} -arch sm_50 +// STATIC-LIBRARY: clang{{.*}} -march=sm_70 +// STATIC-LIBRARY-NOT: clang{{.*}} -march=sm_50 // RUN: clang-offload-packager -o %t.out \ // RUN: --image=file=%t.elf.o,kind=cuda,triple=nvptx64-nvidia-cuda,arch=sm_70 \ @@ -128,9 +84,9 @@ // RUN: clang-linker-wrapper --dry-run --host-triple=x86_64-unknown-linux-gnu \ // RUN: --linker-path=/usr/bin/ld -- %t.o -o a.out 2>&1 | FileCheck %s --check-prefix=CUDA -// CUDA: nvlink{{.*}}-m64 -o {{.*}}.img -arch sm_52 {{.*}}.o -// CUDA: nvlink{{.*}}-m64 -o {{.*}}.img -arch sm_70 {{.*}}.o {{.*}}.o -// CUDA: fatbinary{{.*}}-64 --create {{.*}}.fatbin --image=profile=sm_70,file={{.*}}.img --image=profile=sm_52,file={{.*}}.img +// CUDA: clang{{.*}} -o [[IMG_SM52:.+]] --target=nvptx64-nvidia-cuda -march=sm_52 +// CUDA: clang{{.*}} -o [[IMG_SM70:.+]] --target=nvptx64-nvidia-cuda -march=sm_70 +// CUDA: fatbinary{{.*}}-64 --create {{.*}}.fatbin --image=profile=sm_70,file=[[IMG_SM70]] --image=profile=sm_52,file=[[IMG_SM52]] // RUN: clang-offload-packager -o %t.out \ // RUN: --image=file=%t.elf.o,kind=cuda,triple=nvptx64-nvidia-cuda,arch=sm_80 \ @@ -153,9 +109,9 @@ // RUN: clang-linker-wrapper --dry-run --host-triple=x86_64-unknown-linux-gnu \ // RUN: --linker-path=/usr/bin/ld -- %t.o -o a.out 2>&1 | FileCheck %s --check-prefix=HIP -// HIP: lld{{.*}}-flavor gnu --no-undefined -shared -plugin-opt=-amdgpu-internalize-symbols -plugin-opt=mcpu=gfx908 -o {{.*}}.img {{.*}}.o -// HIP: lld{{.*}}-flavor gnu --no-undefined -shared -plugin-opt=-amdgpu-internalize-symbols -plugin-opt=mcpu=gfx90a -o {{.*}}.img {{.*}}.o -// HIP: clang-offload-bundler{{.*}}-type=o -bundle-align=4096 -targets=host-x86_64-unknown-linux,hipv4-amdgcn-amd-amdhsa--gfx90a,hipv4-amdgcn-amd-amdhsa--gfx908 -input=/dev/null -input={{.*}}.img -input={{.*}}.img -output={{.*}}.hipfb +// HIP: 
clang{{.*}} -o [[IMG_GFX908:.+]] --target=amdgcn-amd-amdhsa -mcpu=gfx908 +// HIP: clang{{.*}} -o [[IMG_GFX90A:.+]] --target=amdgcn-amd-amdhsa -mcpu=gfx90a +// HIP: clang-offload-bundler{{.*}}-type=o -bundle-align=4096 -targets=host-x86_64-unknown-linux,hipv4-amdgcn-amd-amdhsa--gfx90a,hipv4-amdgcn-amd-amdhsa--gfx908 -input=/dev/null -input=[[IMG_GFX90A]] -input=[[IMG_GFX908]] -output={{.*}}.hipfb // RUN: clang-offload-packager -o %t.out \ // RUN: --image=file=%t.elf.o,kind=openmp,triple=amdgcn-amd-amdhsa,arch=gfx908 \ @@ -166,8 +122,8 @@ // RUN: --linker-path=/usr/bin/ld --device-linker=a --device-linker=nvptx64-nvidia-cuda=b -- \ // RUN: %t.o -o a.out 2>&1 | FileCheck %s --check-prefix=LINKER-ARGS -// LINKER-ARGS: lld{{.*}}-flavor gnu --no-undefined -shared -plugin-opt=-amdgpu-internalize-symbols -plugin-opt=mcpu=gfx908 -o {{.*}}.img {{.*}}.o a -// LINKER-ARGS: nvlink{{.*}}-m64 -o {{.*}}.img -arch sm_70 {{.*}}.o a b +// LINKER-ARGS: clang{{.*}}--target=amdgcn-amd-amdhsa{{.*}}-Wl,a +// LINKER-ARGS: clang{{.*}}--target=nvptx64-nvidia-cuda{{.*}}-Wl,a -Wl,b // RUN: not clang-linker-wrapper --dry-run --host-triple=x86_64-unknown-linux-gnu -ldummy \ // RUN: --linker-path=/usr/bin/ld --device-linker=a --device-linker=nvptx64-nvidia-cuda=b -- \ diff --git a/clang/tools/clang-linker-wrapper/ClangLinkerWrapper.cpp b/clang/tools/clang-linker-wrapper/ClangLinkerWrapper.cpp index 1daa58f20fd5b7..5cae73ef543eb2 100644 --- a/clang/tools/clang-linker-wrapper/ClangLinkerWrapper.cpp +++ b/clang/tools/clang-linker-wrapper/ClangLinkerWrapper.cpp @@ -271,99 +271,13 @@ void printVersion(raw_ostream &OS) { } namespace nvptx { -Expected assemble(StringRef InputFile, const ArgList &Args, - bool RDC = true) { - llvm::TimeTraceScope TimeScope("NVPTX Assembler"); - // NVPTX uses the ptxas binary to create device object files. 
- Expected PtxasPath = findProgram("ptxas", {CudaBinaryPath}); - if (!PtxasPath) - return PtxasPath.takeError(); - - const llvm::Triple Triple(Args.getLastArgValue(OPT_triple_EQ)); - StringRef Arch = Args.getLastArgValue(OPT_arch_EQ); - // Create a new file to write the linked device image to. Assume that the - // input filename already has the device and architecture. - auto TempFileOrErr = createOutputFile(sys::path::stem(InputFile), "cubin"); - if (!TempFileOrErr) - return TempFileOrErr.takeError(); - - SmallVector CmdArgs; - StringRef OptLevel = Args.getLastArgValue(OPT_opt_level, "O2"); - CmdArgs.push_back(*PtxasPath); - CmdArgs.push_back(Triple.isArch64Bit() ? "-m64" : "-m32"); - if (Verbose) - CmdArgs.push_back("-v"); - for (StringRef Arg : Args.getAllArgValues(OPT_ptxas_arg)) - CmdArgs.push_back(Args.MakeArgString(Arg)); - CmdArgs.push_back("-o"); - CmdArgs.push_back(*TempFileOrErr); - CmdArgs.push_back(Args.MakeArgString("-" + OptLevel)); - CmdArgs.push_back("--gpu-name"); - CmdArgs.push_back(Arch); - if (Args.hasArg(OPT_debug) && OptLevel[1] == '0') - CmdArgs.push_back("-g"); - else if (Args.hasArg(OPT_debug)) - CmdArgs.push_back("-lineinfo"); - if (RDC) - CmdArgs.push_back("-c"); - - CmdArgs.push_back(InputFile); - - if (Error Err = executeCommands(*PtxasPath, CmdArgs)) - return std::move(Err); - - return *TempFileOrErr; -} - -Expected link(ArrayRef InputFiles, const ArgList &Args) { - llvm::TimeTraceScope TimeScope("NVPTX linker"); - // NVPTX uses the nvlink binary to link device object files. - Expected NvlinkPath = findProgram("nvlink", {CudaBinaryPath}); - if (!NvlinkPath) - return NvlinkPath.takeError(); - - const llvm::Triple Triple(Args.getLastArgValue(OPT_triple_EQ)); - StringRef Arch = Args.getLastArgValue(OPT_arch_EQ); - - // Create a new file to write the linked device image to. - auto TempFileOrErr = - createOutputFile(sys::path::filename(ExecutableName) + "." + - Triple.getArchName() + "." 
+ Arch, - "img"); - if (!TempFileOrErr) - return TempFileOrErr.takeError(); - - SmallVector CmdArgs; - CmdArgs.push_back(*NvlinkPath); - CmdArgs.push_back(Triple.isArch64Bit() ? "-m64" : "-m32"); - if (Args.hasArg(OPT_debug)) - CmdArgs.push_back("-g"); - if (Verbose) - CmdArgs.push_back("-v"); - CmdArgs.push_back("-o"); - CmdArgs.push_back(*TempFileOrErr); - CmdArgs.push_back("-arch"); - CmdArgs.push_back(Arch); - - // Add extracted input files. - for (StringRef Input : InputFiles) - CmdArgs.push_back(Input); - - for (StringRef Arg : Args.getAllArgValues(OPT_linker_arg_EQ)) - CmdArgs.push_back(Args.MakeArgString(Arg)); - if (Error Err = executeCommands(*NvlinkPath, CmdArgs)) - return std::move(Err); - - return *TempFileOrErr; -} - Expected fatbinary(ArrayRef> InputFiles, const ArgList &Args) { llvm::TimeTraceScope TimeScope("NVPTX fatbinary"); // NVPTX uses the fatbinary program to bundle the linked images. Expected FatBinaryPath = - findProgram("fatbinary", {CudaBinaryPath}); + findProgram("fatbinary", {CudaBinaryPath + "/bin"}); if (!FatBinaryPath) return FatBinaryPath.takeError(); @@ -393,49 +307,6 @@ fatbinary(ArrayRef> InputFiles, } // namespace nvptx namespace amdgcn { -Expected link(ArrayRef InputFiles, const ArgList &Args) { - llvm::TimeTraceScope TimeScope("AMDGPU linker"); - // AMDGPU uses lld to link device object files. - Expected LLDPath = - findProgram("lld", {getMainExecutable("lld")}); - if (!LLDPath) - return LLDPath.takeError(); - - const llvm::Triple Triple(Args.getLastArgValue(OPT_triple_EQ)); - StringRef Arch = Args.getLastArgValue(OPT_arch_EQ); - - // Create a new file to write the linked device image to. - auto TempFileOrErr = - createOutputFile(sys::path::filename(ExecutableName) + "." + - Triple.getArchName() + "." 
+ Arch, - "img"); - if (!TempFileOrErr) - return TempFileOrErr.takeError(); - std::string ArchArg = ("-plugin-opt=mcpu=" + Arch).str(); - - SmallVector CmdArgs; - CmdArgs.push_back(*LLDPath); - CmdArgs.push_back("-flavor"); - CmdArgs.push_back("gnu"); - CmdArgs.push_back("--no-undefined"); - CmdArgs.push_back("-shared"); - CmdArgs.push_back("-plugin-opt=-amdgpu-internalize-symbols"); - CmdArgs.push_back(ArchArg); - CmdArgs.push_back("-o"); - CmdArgs.push_back(*TempFileOrErr); - - // Add extracted input files. - for (StringRef Input : InputFiles) - CmdArgs.push_back(Input); - - for (StringRef Arg : Args.getAllArgValues(OPT_linker_arg_EQ)) - CmdArgs.push_back(Args.MakeArgString(Arg)); - if (Error Err = executeCommands(*LLDPath, CmdArgs)) - return std::move(Err); - - return *TempFileOrErr; -} - Expected fatbinary(ArrayRef> InputFiles, const ArgList &Args) { @@ -483,35 +354,9 @@ fatbinary(ArrayRef> InputFiles, } // namespace amdgcn namespace generic { - -const char *getLDMOption(const llvm::Triple &T) { - switch (T.getArch()) { - case llvm::Triple::x86: - if (T.isOSIAMCU()) - return "elf_iamcu"; - return "elf_i386"; - case llvm::Triple::aarch64: - return "aarch64linux"; - case llvm::Triple::aarch64_be: - return "aarch64linuxb"; - case llvm::Triple::ppc64: - return "elf64ppc"; - case llvm::Triple::ppc64le: - return "elf64lppc"; - case llvm::Triple::x86_64: - if (T.isX32()) - return "elf32_x86_64"; - return "elf_x86_64"; - case llvm::Triple::ve: - return "elf64ve"; - default: - return nullptr; - } -} - -Expected assemble(StringRef InputFile, const ArgList &Args) { - llvm::TimeTraceScope TimeScope("Clang Assembler"); - // Use `clang` to invoke the generic assembler. +Expected clang(ArrayRef InputFiles, const ArgList &Args) { + llvm::TimeTraceScope TimeScope("Clang"); + // Use `clang` to invoke the appropriate device tools. 
Expected ClangPath = findProgram("clang", {getMainExecutable("clang")}); if (!ClangPath) @@ -519,9 +364,14 @@ Expected assemble(StringRef InputFile, const ArgList &Args) { const llvm::Triple Triple(Args.getLastArgValue(OPT_triple_EQ)); StringRef Arch = Args.getLastArgValue(OPT_arch_EQ); + if (Arch.empty()) + Arch = "native"; // Create a new file to write the linked device image to. Assume that the // input filename already has the device and architecture. - auto TempFileOrErr = createOutputFile(sys::path::stem(InputFile), "o"); + auto TempFileOrErr = + createOutputFile(sys::path::filename(ExecutableName) + "." + + Triple.getArchName() + "." + Arch, + "img"); if (!TempFileOrErr) return TempFileOrErr.takeError(); @@ -530,65 +380,47 @@ Expected assemble(StringRef InputFile, const ArgList &Args) { *ClangPath, "-o", *TempFileOrErr, - "-fPIC", - "-c", Args.MakeArgString("--target=" + Triple.getTriple()), - Args.MakeArgString("-" + OptLevel), Triple.isAMDGPU() ? Args.MakeArgString("-mcpu=" + Arch) : Args.MakeArgString("-march=" + Arch), - InputFile, + Args.MakeArgString("-" + OptLevel), + "-Wl,--no-undefined", }; - if (Error Err = executeCommands(*ClangPath, CmdArgs)) - return std::move(Err); + // If this is CPU offloading we copy the input libraries. 
+ if (!Triple.isAMDGPU() && !Triple.isNVPTX()) { + CmdArgs.push_back("-Bsymbolic"); + CmdArgs.push_back("-shared"); + ArgStringList LinkerArgs; + for (const opt::Arg *Arg : + Args.filtered(OPT_library, OPT_rpath, OPT_library_path)) + Arg->render(Args, LinkerArgs); + llvm::copy(LinkerArgs, std::back_inserter(CmdArgs)); + } - return *TempFileOrErr; -} + if (Args.hasArg(OPT_debug)) + CmdArgs.push_back("-g"); -Expected link(ArrayRef InputFiles, const ArgList &Args) { - llvm::TimeTraceScope TimeScope("Generic linker"); - const llvm::Triple Triple(Args.getLastArgValue(OPT_triple_EQ)); - StringRef Arch = Args.getLastArgValue(OPT_arch_EQ); + if (SaveTemps) + CmdArgs.push_back("-save-temps"); - // Create a new file to write the linked device image to. - auto TempFileOrErr = - createOutputFile(sys::path::filename(ExecutableName) + "." + - Triple.getArchName() + "." + Arch, - "img"); - if (!TempFileOrErr) - return TempFileOrErr.takeError(); + if (Verbose) + CmdArgs.push_back("-v"); - // Use the host linker to perform generic offloading. Use the same libraries - // and paths as the host application does. 
- SmallVector CmdArgs; - CmdArgs.push_back(Args.getLastArgValue(OPT_linker_path_EQ)); - CmdArgs.push_back("-m"); - CmdArgs.push_back(getLDMOption(Triple)); - CmdArgs.push_back("-shared"); + if (!CudaBinaryPath.empty()) + CmdArgs.push_back(Args.MakeArgString("--cuda-path=" + CudaBinaryPath)); - ArgStringList LinkerArgs; - for (const opt::Arg *Arg : Args) { - auto Op = Arg->getOption(); - if (Op.matches(OPT_library) || Op.matches(OPT_library_path) || - Op.matches(OPT_as_needed) || Op.matches(OPT_no_as_needed) || - Op.matches(OPT_rpath) || Op.matches(OPT_dynamic_linker)) - Arg->render(Args, LinkerArgs); - } - for (StringRef Arg : LinkerArgs) - CmdArgs.push_back(Arg); + for (StringRef Arg : Args.getAllArgValues(OPT_ptxas_arg)) + llvm::copy(SmallVector({"-Xcuda-ptxas", Arg}), + std::back_inserter(CmdArgs)); - CmdArgs.push_back("-Bsymbolic"); - CmdArgs.push_back("-o"); - CmdArgs.push_back(*TempFileOrErr); + for (StringRef Arg : Args.getAllArgValues(OPT_linker_arg_EQ)) + CmdArgs.push_back(Args.MakeArgString("-Wl," + Arg)); - // Add extracted input files. 
- for (StringRef Input : InputFiles) - CmdArgs.push_back(Input); + for (StringRef InputFile : InputFiles) + CmdArgs.push_back(InputFile); - for (StringRef Arg : Args.getAllArgValues(OPT_linker_arg_EQ)) - CmdArgs.push_back(Args.MakeArgString(Arg)); - if (Error Err = - executeCommands(Args.getLastArgValue(OPT_linker_path_EQ), CmdArgs)) + if (Error Err = executeCommands(*ClangPath, CmdArgs)) return std::move(Err); return *TempFileOrErr; @@ -601,16 +433,14 @@ Expected linkDevice(ArrayRef InputFiles, switch (Triple.getArch()) { case Triple::nvptx: case Triple::nvptx64: - return nvptx::link(InputFiles, Args); case Triple::amdgcn: - return amdgcn::link(InputFiles, Args); case Triple::x86: case Triple::x86_64: case Triple::aarch64: case Triple::aarch64_be: case Triple::ppc64: case Triple::ppc64le: - return generic::link(InputFiles, Args); + return generic::clang(InputFiles, Args); default: return createStringError(inconvertibleErrorCode(), Triple.getArchName() + @@ -933,19 +763,6 @@ Error linkBitcodeFiles(SmallVectorImpl &InputFiles, return Error::success(); } - // Is we are compiling for NVPTX we need to run the assembler first. - if (Triple.isNVPTX() || SaveTemps) { - for (StringRef &File : Files) { - - auto FileOrErr = Triple.isNVPTX() - ? nvptx::assemble(File, Args, !SingleOutput) - : generic::assemble(File, Args); - if (!FileOrErr) - return FileOrErr.takeError(); - File = *FileOrErr; - } - } - // Append the new inputs to the device linker input. for (StringRef File : Files) OutputFiles.push_back(File); @@ -1226,12 +1043,9 @@ linkAndWrapDeviceFiles(SmallVectorImpl &LinkerInputFiles, } // Link the remaining device files using the device linker. - llvm::Triple Triple(LinkerArgs.getLastArgValue(OPT_triple_EQ)); - bool RequiresLinking = - !Args.hasArg(OPT_embed_bitcode) && - !(Input.empty() && InputFiles.size() == 1 && Triple.isNVPTX()); - auto OutputOrErr = RequiresLinking ? 
linkDevice(InputFiles, LinkerArgs) - : InputFiles.front(); + auto OutputOrErr = !Args.hasArg(OPT_embed_bitcode) + ? linkDevice(InputFiles, LinkerArgs) + : InputFiles.front(); if (!OutputOrErr) return OutputOrErr.takeError(); @@ -1457,8 +1271,6 @@ int main(int Argc, char **Argv) { SaveTemps = Args.hasArg(OPT_save_temps); ExecutableName = Args.getLastArgValue(OPT_o, "a.out"); CudaBinaryPath = Args.getLastArgValue(OPT_cuda_path_EQ).str(); - if (!CudaBinaryPath.empty()) - CudaBinaryPath = CudaBinaryPath + "/bin"; parallel::strategy = hardware_concurrency(1); if (auto *Arg = Args.getLastArg(OPT_wrapper_jobs)) { From 679c352ff42d45de6dcaca7022a7e664cd03323c Mon Sep 17 00:00:00 2001 From: Joseph Huber Date: Thu, 19 Jan 2023 12:48:28 -0600 Subject: [PATCH 46/61] [Clang][NFC] Tweak error message for GPU architecture tools Summary: There shouldn't be an extra newline in these messages. --- clang/tools/amdgpu-arch/AMDGPUArch.cpp | 2 +- clang/tools/nvptx-arch/NVPTXArch.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/clang/tools/amdgpu-arch/AMDGPUArch.cpp b/clang/tools/amdgpu-arch/AMDGPUArch.cpp index 308bdeca243aa1..2fdd398c9c673c 100644 --- a/clang/tools/amdgpu-arch/AMDGPUArch.cpp +++ b/clang/tools/amdgpu-arch/AMDGPUArch.cpp @@ -50,7 +50,7 @@ llvm::Error loadHSA() { llvm::sys::DynamicLibrary::getPermanentLibrary(DynamicHSAPath, &ErrMsg)); if (!DynlibHandle->isValid()) { return llvm::createStringError(llvm::inconvertibleErrorCode(), - "Failed to 'dlopen' %s\n", DynamicHSAPath); + "Failed to 'dlopen' %s", DynamicHSAPath); } #define DYNAMIC_INIT(SYMBOL) \ { \ diff --git a/clang/tools/nvptx-arch/NVPTXArch.cpp b/clang/tools/nvptx-arch/NVPTXArch.cpp index 6509e42fee4e6a..91723324c28e2d 100644 --- a/clang/tools/nvptx-arch/NVPTXArch.cpp +++ b/clang/tools/nvptx-arch/NVPTXArch.cpp @@ -44,7 +44,7 @@ llvm::Error loadCUDA() { llvm::sys::DynamicLibrary::getPermanentLibrary(DynamicCudaPath, &ErrMsg)); if (!DynlibHandle->isValid()) { return 
llvm::createStringError(llvm::inconvertibleErrorCode(), - "Failed to 'dlopen' %s\n", DynamicCudaPath); + "Failed to 'dlopen' %s", DynamicCudaPath); } #define DYNAMIC_INIT(SYMBOL) \ { \ From e7f080b3598d7d73456954554d173146b2744953 Mon Sep 17 00:00:00 2001 From: Stanislav Mekhanoshin Date: Wed, 18 Jan 2023 11:58:50 -0800 Subject: [PATCH 47/61] [AMDGPU] Introduce separate register limit bias in scheduler Current implementation abuses ErrorMargin to apply an additional bias to VGPR and SGPR limits under a high register pressure. The ErrorMargin exists to account for inaccuracies of the RP tracker and not to tackle an excess pressure. Introduce separate bias for this purpose and also make it different for SGPRs and VGPRs as we may want to use different values in the future. This is supposed to be NFC, however there is a subtle difference when subtracting a margin overflows the limit. Doing two subtractions makes it less probable, although manifests only in mir tests with an artificially small register budget. Differential Revision: https://reviews.llvm.org/D142051 --- llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp | 13 ++++++++++--- llvm/lib/Target/AMDGPU/GCNSchedStrategy.h | 12 +++++++++--- .../partial-regcopy-and-spill-missed-at-regalloc.ll | 4 ++-- llvm/test/CodeGen/AMDGPU/spill-vector-superclass.ll | 2 +- 4 files changed, 22 insertions(+), 9 deletions(-) diff --git a/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp b/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp index 5e55186eaada95..ae119e0b17025e 100644 --- a/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp +++ b/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp @@ -73,12 +73,18 @@ void GCNSchedStrategy::initialize(ScheduleDAGMI *DAG) { VGPRCriticalLimit = std::min(ST.getMaxNumVGPRs(TargetOccupancy), VGPRExcessLimit); - // Subtract error margin from register limits and avoid overflow. + // Subtract error margin and bias from register limits and avoid overflow. 
+ SGPRCriticalLimit = + std::min(SGPRCriticalLimit - SGPRLimitBias, SGPRCriticalLimit); SGPRCriticalLimit = std::min(SGPRCriticalLimit - ErrorMargin, SGPRCriticalLimit); + VGPRCriticalLimit = + std::min(VGPRCriticalLimit - VGPRLimitBias, VGPRCriticalLimit); VGPRCriticalLimit = std::min(VGPRCriticalLimit - ErrorMargin, VGPRCriticalLimit); + SGPRExcessLimit = std::min(SGPRExcessLimit - SGPRLimitBias, SGPRExcessLimit); SGPRExcessLimit = std::min(SGPRExcessLimit - ErrorMargin, SGPRExcessLimit); + VGPRExcessLimit = std::min(VGPRExcessLimit - VGPRLimitBias, VGPRExcessLimit); VGPRExcessLimit = std::min(VGPRExcessLimit - ErrorMargin, VGPRExcessLimit); } @@ -670,7 +676,8 @@ bool UnclusteredHighRPStage::initGCNSchedStage() { InitialOccupancy = DAG.MinOccupancy; // Aggressivly try to reduce register pressure in the unclustered high RP // stage. Temporarily increase occupancy target in the region. - S.ErrorMargin = S.HighRPErrorMargin; + S.SGPRLimitBias = S.HighRPSGPRBias; + S.VGPRLimitBias = S.HighRPVGPRBias; if (MFI.getMaxWavesPerEU() > DAG.MinOccupancy) MFI.increaseOccupancy(MF, ++DAG.MinOccupancy); @@ -735,7 +742,7 @@ void GCNSchedStage::finalizeGCNSchedStage() { void UnclusteredHighRPStage::finalizeGCNSchedStage() { SavedMutations.swap(DAG.Mutations); - S.ErrorMargin = S.DefaultErrorMargin; + S.SGPRLimitBias = S.VGPRLimitBias = 0; if (DAG.MinOccupancy > InitialOccupancy) { for (unsigned IDX = 0; IDX < DAG.Pressure.size(); ++IDX) DAG.RegionsWithMinOcc[IDX] = diff --git a/llvm/lib/Target/AMDGPU/GCNSchedStrategy.h b/llvm/lib/Target/AMDGPU/GCNSchedStrategy.h index d7b161ccbf07c7..30a62f61553758 100644 --- a/llvm/lib/Target/AMDGPU/GCNSchedStrategy.h +++ b/llvm/lib/Target/AMDGPU/GCNSchedStrategy.h @@ -78,16 +78,22 @@ class GCNSchedStrategy : public GenericScheduler { // An error margin is necessary because of poor performance of the generic RP // tracker and can be adjusted up for tuning heuristics to try and more // aggressively reduce register pressure. 
- const unsigned DefaultErrorMargin = 3; + unsigned ErrorMargin = 3; - const unsigned HighRPErrorMargin = 10; + // Bias for SGPR limits under a high register pressure. + const unsigned HighRPSGPRBias = 7; - unsigned ErrorMargin = DefaultErrorMargin; + // Bias for VGPR limits under a high register pressure. + const unsigned HighRPVGPRBias = 7; unsigned SGPRCriticalLimit; unsigned VGPRCriticalLimit; + unsigned SGPRLimitBias = 0; + + unsigned VGPRLimitBias = 0; + GCNSchedStrategy(const MachineSchedContext *C); SUnit *pickNode(bool &IsTopNode) override; diff --git a/llvm/test/CodeGen/AMDGPU/partial-regcopy-and-spill-missed-at-regalloc.ll b/llvm/test/CodeGen/AMDGPU/partial-regcopy-and-spill-missed-at-regalloc.ll index 29876a2d307d6a..785f8ce151e82f 100644 --- a/llvm/test/CodeGen/AMDGPU/partial-regcopy-and-spill-missed-at-regalloc.ll +++ b/llvm/test/CodeGen/AMDGPU/partial-regcopy-and-spill-missed-at-regalloc.ll @@ -18,8 +18,8 @@ define amdgpu_kernel void @partial_copy(<4 x i32> %arg) #0 { ; REGALLOC-GFX908-NEXT: [[COPY1:%[0-9]+]]:vreg_128 = COPY [[COPY]] ; REGALLOC-GFX908-NEXT: GLOBAL_STORE_DWORDX4 undef %14:vreg_64, [[COPY1]], 0, 0, implicit $exec :: (volatile store (s128) into `ptr addrspace(1) undef`, addrspace 1) ; REGALLOC-GFX908-NEXT: renamable $sgpr0_sgpr1_sgpr2_sgpr3 = S_LOAD_DWORDX4_IMM killed renamable $sgpr4_sgpr5, 0, 0 :: (dereferenceable invariant load (s128) from %ir.arg.kernarg.offset1, addrspace 4) - ; REGALLOC-GFX908-NEXT: [[V_MOV_B32_e32_:%[0-9]+]]:vgpr_32 = V_MOV_B32_e32 1, implicit $exec ; REGALLOC-GFX908-NEXT: [[COPY2:%[0-9]+]]:areg_128 = COPY killed renamable $sgpr0_sgpr1_sgpr2_sgpr3 + ; REGALLOC-GFX908-NEXT: [[V_MOV_B32_e32_:%[0-9]+]]:vgpr_32 = V_MOV_B32_e32 1, implicit $exec ; REGALLOC-GFX908-NEXT: [[V_MOV_B32_e32_1:%[0-9]+]]:vgpr_32 = V_MOV_B32_e32 2, implicit $exec ; REGALLOC-GFX908-NEXT: [[V_MFMA_I32_4X4X4I8_e64_:%[0-9]+]]:areg_128 = V_MFMA_I32_4X4X4I8_e64 [[V_MOV_B32_e32_]], [[V_MOV_B32_e32_1]], [[COPY2]], 0, 0, 0, implicit $mode, implicit 
$exec ; REGALLOC-GFX908-NEXT: [[SI_SPILL_V64_RESTORE:%[0-9]+]]:vreg_64 = SI_SPILL_V64_RESTORE %stack.0, $sgpr32, 0, implicit $exec :: (load (s64) from %stack.0, align 4, addrspace 5) @@ -43,8 +43,8 @@ define amdgpu_kernel void @partial_copy(<4 x i32> %arg) #0 { ; PEI-GFX908-NEXT: renamable $vgpr0_vgpr1_vgpr2_vgpr3 = COPY killed renamable $agpr0_agpr1_agpr2_agpr3, implicit $exec ; PEI-GFX908-NEXT: GLOBAL_STORE_DWORDX4 undef renamable $vgpr0_vgpr1, killed renamable $vgpr0_vgpr1_vgpr2_vgpr3, 0, 0, implicit $exec :: (volatile store (s128) into `ptr addrspace(1) undef`, addrspace 1) ; PEI-GFX908-NEXT: renamable $sgpr0_sgpr1_sgpr2_sgpr3 = S_LOAD_DWORDX4_IMM killed renamable $sgpr4_sgpr5, 0, 0 :: (dereferenceable invariant load (s128) from %ir.arg.kernarg.offset1, addrspace 4) - ; PEI-GFX908-NEXT: renamable $vgpr0 = V_MOV_B32_e32 1, implicit $exec ; PEI-GFX908-NEXT: renamable $agpr0_agpr1_agpr2_agpr3 = COPY killed renamable $sgpr0_sgpr1_sgpr2_sgpr3, implicit $exec + ; PEI-GFX908-NEXT: renamable $vgpr0 = V_MOV_B32_e32 1, implicit $exec ; PEI-GFX908-NEXT: renamable $vgpr1 = V_MOV_B32_e32 2, implicit $exec ; PEI-GFX908-NEXT: renamable $agpr0_agpr1_agpr2_agpr3 = V_MFMA_I32_4X4X4I8_e64 killed $vgpr0, killed $vgpr1, killed $agpr0_agpr1_agpr2_agpr3, 0, 0, 0, implicit $mode, implicit $exec ; PEI-GFX908-NEXT: $vgpr0 = BUFFER_LOAD_DWORD_OFFSET $sgpr8_sgpr9_sgpr10_sgpr11, 0, 4, 0, 0, implicit $exec, implicit-def $vgpr0_vgpr1 :: (load (s32) from %stack.0, addrspace 5) diff --git a/llvm/test/CodeGen/AMDGPU/spill-vector-superclass.ll b/llvm/test/CodeGen/AMDGPU/spill-vector-superclass.ll index e0e8dc88bdac58..7939798dab5580 100644 --- a/llvm/test/CodeGen/AMDGPU/spill-vector-superclass.ll +++ b/llvm/test/CodeGen/AMDGPU/spill-vector-superclass.ll @@ -8,9 +8,9 @@ define amdgpu_kernel void @test_spill_av_class(<4 x i32> %arg) #0 { ; GCN-NEXT: liveins: $sgpr4_sgpr5 ; GCN-NEXT: {{ $}} ; GCN-NEXT: renamable $sgpr0_sgpr1_sgpr2_sgpr3 = S_LOAD_DWORDX4_IMM killed renamable $sgpr4_sgpr5, 0, 0 :: 
(dereferenceable invariant load (s128) from %ir.arg.kernarg.offset1, addrspace 4) + ; GCN-NEXT: [[COPY:%[0-9]+]]:areg_128 = COPY killed renamable $sgpr0_sgpr1_sgpr2_sgpr3 ; GCN-NEXT: [[V_MOV_B32_e32_:%[0-9]+]]:vgpr_32 = V_MOV_B32_e32 1, implicit $exec ; GCN-NEXT: [[V_MOV_B32_e32_1:%[0-9]+]]:vgpr_32 = V_MOV_B32_e32 2, implicit $exec - ; GCN-NEXT: [[COPY:%[0-9]+]]:areg_128 = COPY killed renamable $sgpr0_sgpr1_sgpr2_sgpr3 ; GCN-NEXT: [[V_MFMA_I32_4X4X4I8_e64_:%[0-9]+]]:areg_128 = V_MFMA_I32_4X4X4I8_e64 [[V_MOV_B32_e32_]], [[V_MOV_B32_e32_1]], [[COPY]], 0, 0, 0, implicit $mode, implicit $exec ; GCN-NEXT: INLINEASM &"; def $0", 1 /* sideeffect attdialect */, 1769482 /* regdef:VGPR_32 */, def undef %22.sub0 ; GCN-NEXT: undef %24.sub0:av_64 = COPY %22.sub0 From 63e7e9c8756aeaa6dccd4620cba710c04e215934 Mon Sep 17 00:00:00 2001 From: Stanislav Mekhanoshin Date: Wed, 18 Jan 2023 15:20:36 -0800 Subject: [PATCH 48/61] [AMDGPU] Treat WMMA the same as MFMA for sched_barrier MFMA and WMMA are essentially the same thing, but appear on different ASICs. Differential Revision: https://reviews.llvm.org/D142062 --- llvm/include/llvm/IR/IntrinsicsAMDGPU.td | 2 +- llvm/lib/Target/AMDGPU/AMDGPUIGroupLP.cpp | 8 +- llvm/lib/Target/AMDGPU/SIInstrInfo.h | 4 + .../llvm.amdgcn.sched.group.barrier.gfx11.ll | 403 ++++++++++++++++++ 4 files changed, 412 insertions(+), 5 deletions(-) create mode 100644 llvm/test/CodeGen/AMDGPU/llvm.amdgcn.sched.group.barrier.gfx11.ll diff --git a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td index e621bcd4c8419c..365e51c1bd224d 100644 --- a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td +++ b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td @@ -232,7 +232,7 @@ def int_amdgcn_wave_barrier : ClangBuiltin<"__builtin_amdgcn_wave_barrier">, // scheduled across SCHED_BARRIER, i.e. allow ALU instructions to pass. // MASK = 0x0000 0002: VALU instructions may be scheduled across SCHED_BARRIER. 
// MASK = 0x0000 0004: SALU instructions may be scheduled across SCHED_BARRIER. -// MASK = 0x0000 0008: MFMA instructions may be scheduled across SCHED_BARRIER. +// MASK = 0x0000 0008: MFMA/WMMA instructions may be scheduled across SCHED_BARRIER. // MASK = 0x0000 0010: ALL VMEM instructions may be scheduled across SCHED_BARRIER. // MASK = 0x0000 0020: VMEM read instructions may be scheduled across SCHED_BARRIER. // MASK = 0x0000 0040: VMEM write instructions may be scheduled across SCHED_BARRIER. diff --git a/llvm/lib/Target/AMDGPU/AMDGPUIGroupLP.cpp b/llvm/lib/Target/AMDGPU/AMDGPUIGroupLP.cpp index 806e96441b055e..fc0df61952e48a 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUIGroupLP.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUIGroupLP.cpp @@ -765,7 +765,7 @@ void MFMASmallGemmOpt::applyIGLPStrategy( // Count the number of MFMA instructions. unsigned MFMACount = 0; for (const MachineInstr &I : *DAG) - if (TII->isMFMA(I)) + if (TII->isMFMAorWMMA(I)) ++MFMACount; const unsigned PipelineSyncID = 0; @@ -848,11 +848,11 @@ bool SchedGroup::canAddMI(const MachineInstr &MI) const { Result = false; else if (((SGMask & SchedGroupMask::ALU) != SchedGroupMask::NONE) && - (TII->isVALU(MI) || TII->isMFMA(MI) || TII->isSALU(MI))) + (TII->isVALU(MI) || TII->isMFMAorWMMA(MI) || TII->isSALU(MI))) Result = true; else if (((SGMask & SchedGroupMask::VALU) != SchedGroupMask::NONE) && - TII->isVALU(MI) && !TII->isMFMA(MI)) + TII->isVALU(MI) && !TII->isMFMAorWMMA(MI)) Result = true; else if (((SGMask & SchedGroupMask::SALU) != SchedGroupMask::NONE) && @@ -860,7 +860,7 @@ bool SchedGroup::canAddMI(const MachineInstr &MI) const { Result = true; else if (((SGMask & SchedGroupMask::MFMA) != SchedGroupMask::NONE) && - TII->isMFMA(MI)) + TII->isMFMAorWMMA(MI)) Result = true; else if (((SGMask & SchedGroupMask::VMEM) != SchedGroupMask::NONE) && diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.h b/llvm/lib/Target/AMDGPU/SIInstrInfo.h index 48d21d719747a9..6cbc02ab1dbc57 100644 --- 
a/llvm/lib/Target/AMDGPU/SIInstrInfo.h +++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.h @@ -683,6 +683,10 @@ class SIInstrInfo final : public AMDGPUGenInstrInfo { return get(Opcode).TSFlags & SIInstrFlags::IsWMMA; } + static bool isMFMAorWMMA(const MachineInstr &MI) { + return isMFMA(MI) || isWMMA(MI); + } + bool isDOT(uint16_t Opcode) const { return get(Opcode).TSFlags & SIInstrFlags::IsDOT; } diff --git a/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.sched.group.barrier.gfx11.ll b/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.sched.group.barrier.gfx11.ll new file mode 100644 index 00000000000000..990a4e92170f01 --- /dev/null +++ b/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.sched.group.barrier.gfx11.ll @@ -0,0 +1,403 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc -march=amdgcn -mcpu=gfx1100 -verify-machineinstrs -misched-cluster=0 < %s | FileCheck -check-prefix=GCN %s +; RUN: llc -march=amdgcn -mcpu=gfx1100 -verify-machineinstrs -misched-cluster=0 -amdgpu-igrouplp-exact-solver-max-branches=250000 < %s | FileCheck -check-prefix=EXACTCUTOFF %s + +define amdgpu_kernel void @test_sched_group_barrier_pipeline_WMMA_cluster(ptr addrspace(3) noalias %in, ptr addrspace(3) noalias %out) #0 { +; GCN-LABEL: test_sched_group_barrier_pipeline_WMMA_cluster: +; GCN: ; %bb.0: ; %entry +; GCN-NEXT: s_load_b64 s[0:1], s[0:1], 0x24 +; GCN-NEXT: v_lshlrev_b32_e32 v40, 5, v0 +; GCN-NEXT: s_waitcnt lgkmcnt(0) +; GCN-NEXT: s_delay_alu instid0(VALU_DEP_1) +; GCN-NEXT: v_add_nc_u32_e32 v32, s0, v40 +; GCN-NEXT: v_dual_mov_b32 v81, s1 :: v_dual_add_nc_u32 v80, s1, v40 +; GCN-NEXT: ds_load_b128 v[4:7], v32 offset:16 +; GCN-NEXT: ds_load_b128 v[12:15], v32 offset:2064 +; GCN-NEXT: ds_load_b128 v[20:23], v32 offset:6160 +; GCN-NEXT: ds_load_b128 v[28:31], v32 offset:12304 +; GCN-NEXT: ds_load_b128 v[36:39], v32 offset:20496 +; GCN-NEXT: ds_load_b128 v[0:3], v32 +; GCN-NEXT: ds_load_b128 v[8:11], v32 offset:2048 +; GCN-NEXT: ds_load_b128 v[16:19], v32 offset:6144 +; GCN-NEXT: 
ds_load_b128 v[24:27], v32 offset:12288 +; GCN-NEXT: ds_load_b128 v[32:35], v32 offset:20480 +; GCN-NEXT: ; sched_group_barrier mask(0x00000100) size(10) SyncID(0) +; GCN-NEXT: s_waitcnt lgkmcnt(4) +; GCN-NEXT: v_mov_b32_e32 v47, v7 +; GCN-NEXT: s_waitcnt lgkmcnt(3) +; GCN-NEXT: v_mov_b32_e32 v55, v15 +; GCN-NEXT: s_waitcnt lgkmcnt(2) +; GCN-NEXT: v_mov_b32_e32 v63, v23 +; GCN-NEXT: s_waitcnt lgkmcnt(1) +; GCN-NEXT: v_mov_b32_e32 v71, v31 +; GCN-NEXT: s_waitcnt lgkmcnt(0) +; GCN-NEXT: v_dual_mov_b32 v79, v39 :: v_dual_mov_b32 v46, v6 +; GCN-NEXT: v_dual_mov_b32 v45, v5 :: v_dual_mov_b32 v44, v4 +; GCN-NEXT: v_dual_mov_b32 v43, v3 :: v_dual_mov_b32 v42, v2 +; GCN-NEXT: v_dual_mov_b32 v41, v1 :: v_dual_mov_b32 v40, v0 +; GCN-NEXT: v_dual_mov_b32 v54, v14 :: v_dual_mov_b32 v53, v13 +; GCN-NEXT: v_dual_mov_b32 v52, v12 :: v_dual_mov_b32 v51, v11 +; GCN-NEXT: v_dual_mov_b32 v50, v10 :: v_dual_mov_b32 v49, v9 +; GCN-NEXT: v_mov_b32_e32 v48, v8 +; GCN-NEXT: v_dual_mov_b32 v62, v22 :: v_dual_mov_b32 v61, v21 +; GCN-NEXT: v_dual_mov_b32 v60, v20 :: v_dual_mov_b32 v59, v19 +; GCN-NEXT: v_dual_mov_b32 v58, v18 :: v_dual_mov_b32 v57, v17 +; GCN-NEXT: v_mov_b32_e32 v56, v16 +; GCN-NEXT: v_dual_mov_b32 v70, v30 :: v_dual_mov_b32 v69, v29 +; GCN-NEXT: v_dual_mov_b32 v68, v28 :: v_dual_mov_b32 v67, v27 +; GCN-NEXT: v_dual_mov_b32 v66, v26 :: v_dual_mov_b32 v65, v25 +; GCN-NEXT: v_mov_b32_e32 v64, v24 +; GCN-NEXT: v_dual_mov_b32 v78, v38 :: v_dual_mov_b32 v77, v37 +; GCN-NEXT: v_dual_mov_b32 v76, v36 :: v_dual_mov_b32 v75, v35 +; GCN-NEXT: v_dual_mov_b32 v74, v34 :: v_dual_mov_b32 v73, v33 +; GCN-NEXT: v_mov_b32_e32 v72, v32 +; GCN-NEXT: v_wmma_f16_16x16x16_f16 v[40:47], v[0:7], v[0:7], v[40:47] +; GCN-NEXT: v_wmma_f16_16x16x16_f16 v[48:55], v[8:15], v[8:15], v[48:55] +; GCN-NEXT: v_wmma_f16_16x16x16_f16 v[56:63], v[16:23], v[16:23], v[56:63] +; GCN-NEXT: v_wmma_f16_16x16x16_f16 v[64:71], v[24:31], v[24:31], v[64:71] +; GCN-NEXT: v_wmma_f16_16x16x16_f16 v[72:79], v[32:39], 
v[32:39], v[72:79] +; GCN-NEXT: ds_store_b128 v80, v[44:47] offset:16 +; GCN-NEXT: ds_store_b128 v80, v[40:43] +; GCN-NEXT: ds_store_b128 v81, v[52:55] offset:2064 +; GCN-NEXT: ds_store_b128 v81, v[48:51] offset:2048 +; GCN-NEXT: ds_store_b128 v81, v[60:63] offset:4112 +; GCN-NEXT: ds_store_b128 v81, v[56:59] offset:4096 +; GCN-NEXT: ds_store_b128 v81, v[68:71] offset:6160 +; GCN-NEXT: ds_store_b128 v81, v[64:67] offset:6144 +; GCN-NEXT: ds_store_b128 v81, v[76:79] offset:8208 +; GCN-NEXT: ds_store_b128 v81, v[72:75] offset:8192 +; GCN-NEXT: ; sched_group_barrier mask(0x00000008) size(5) SyncID(0) +; GCN-NEXT: ; sched_group_barrier mask(0x00000200) size(10) SyncID(0) +; GCN-NEXT: s_endpgm +; +; EXACTCUTOFF-LABEL: test_sched_group_barrier_pipeline_WMMA_cluster: +; EXACTCUTOFF: ; %bb.0: ; %entry +; EXACTCUTOFF-NEXT: s_load_b64 s[0:1], s[0:1], 0x24 +; EXACTCUTOFF-NEXT: v_lshlrev_b32_e32 v40, 5, v0 +; EXACTCUTOFF-NEXT: s_waitcnt lgkmcnt(0) +; EXACTCUTOFF-NEXT: s_delay_alu instid0(VALU_DEP_1) +; EXACTCUTOFF-NEXT: v_add_nc_u32_e32 v32, s0, v40 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v81, s1 :: v_dual_add_nc_u32 v80, s1, v40 +; EXACTCUTOFF-NEXT: ds_load_b128 v[4:7], v32 offset:16 +; EXACTCUTOFF-NEXT: ds_load_b128 v[12:15], v32 offset:2064 +; EXACTCUTOFF-NEXT: ds_load_b128 v[20:23], v32 offset:6160 +; EXACTCUTOFF-NEXT: ds_load_b128 v[28:31], v32 offset:12304 +; EXACTCUTOFF-NEXT: ds_load_b128 v[36:39], v32 offset:20496 +; EXACTCUTOFF-NEXT: ds_load_b128 v[0:3], v32 +; EXACTCUTOFF-NEXT: ds_load_b128 v[8:11], v32 offset:2048 +; EXACTCUTOFF-NEXT: ds_load_b128 v[16:19], v32 offset:6144 +; EXACTCUTOFF-NEXT: ds_load_b128 v[24:27], v32 offset:12288 +; EXACTCUTOFF-NEXT: ds_load_b128 v[32:35], v32 offset:20480 +; EXACTCUTOFF-NEXT: ; sched_group_barrier mask(0x00000100) size(10) SyncID(0) +; EXACTCUTOFF-NEXT: s_waitcnt lgkmcnt(4) +; EXACTCUTOFF-NEXT: v_mov_b32_e32 v47, v7 +; EXACTCUTOFF-NEXT: s_waitcnt lgkmcnt(3) +; EXACTCUTOFF-NEXT: v_mov_b32_e32 v55, v15 +; EXACTCUTOFF-NEXT: s_waitcnt 
lgkmcnt(2) +; EXACTCUTOFF-NEXT: v_mov_b32_e32 v63, v23 +; EXACTCUTOFF-NEXT: s_waitcnt lgkmcnt(1) +; EXACTCUTOFF-NEXT: v_mov_b32_e32 v71, v31 +; EXACTCUTOFF-NEXT: s_waitcnt lgkmcnt(0) +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v79, v39 :: v_dual_mov_b32 v46, v6 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v45, v5 :: v_dual_mov_b32 v44, v4 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v43, v3 :: v_dual_mov_b32 v42, v2 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v41, v1 :: v_dual_mov_b32 v40, v0 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v54, v14 :: v_dual_mov_b32 v53, v13 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v52, v12 :: v_dual_mov_b32 v51, v11 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v50, v10 :: v_dual_mov_b32 v49, v9 +; EXACTCUTOFF-NEXT: v_mov_b32_e32 v48, v8 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v62, v22 :: v_dual_mov_b32 v61, v21 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v60, v20 :: v_dual_mov_b32 v59, v19 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v58, v18 :: v_dual_mov_b32 v57, v17 +; EXACTCUTOFF-NEXT: v_mov_b32_e32 v56, v16 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v70, v30 :: v_dual_mov_b32 v69, v29 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v68, v28 :: v_dual_mov_b32 v67, v27 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v66, v26 :: v_dual_mov_b32 v65, v25 +; EXACTCUTOFF-NEXT: v_mov_b32_e32 v64, v24 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v78, v38 :: v_dual_mov_b32 v77, v37 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v76, v36 :: v_dual_mov_b32 v75, v35 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v74, v34 :: v_dual_mov_b32 v73, v33 +; EXACTCUTOFF-NEXT: v_mov_b32_e32 v72, v32 +; EXACTCUTOFF-NEXT: v_wmma_f16_16x16x16_f16 v[40:47], v[0:7], v[0:7], v[40:47] +; EXACTCUTOFF-NEXT: v_wmma_f16_16x16x16_f16 v[48:55], v[8:15], v[8:15], v[48:55] +; EXACTCUTOFF-NEXT: v_wmma_f16_16x16x16_f16 v[56:63], v[16:23], v[16:23], v[56:63] +; EXACTCUTOFF-NEXT: v_wmma_f16_16x16x16_f16 v[64:71], v[24:31], v[24:31], v[64:71] +; EXACTCUTOFF-NEXT: v_wmma_f16_16x16x16_f16 v[72:79], v[32:39], v[32:39], v[72:79] +; EXACTCUTOFF-NEXT: ds_store_b128 v80, v[44:47] offset:16 +; EXACTCUTOFF-NEXT: 
ds_store_b128 v80, v[40:43] +; EXACTCUTOFF-NEXT: ds_store_b128 v81, v[52:55] offset:2064 +; EXACTCUTOFF-NEXT: ds_store_b128 v81, v[48:51] offset:2048 +; EXACTCUTOFF-NEXT: ds_store_b128 v81, v[60:63] offset:4112 +; EXACTCUTOFF-NEXT: ds_store_b128 v81, v[56:59] offset:4096 +; EXACTCUTOFF-NEXT: ds_store_b128 v81, v[68:71] offset:6160 +; EXACTCUTOFF-NEXT: ds_store_b128 v81, v[64:67] offset:6144 +; EXACTCUTOFF-NEXT: ds_store_b128 v81, v[76:79] offset:8208 +; EXACTCUTOFF-NEXT: ds_store_b128 v81, v[72:75] offset:8192 +; EXACTCUTOFF-NEXT: ; sched_group_barrier mask(0x00000008) size(5) SyncID(0) +; EXACTCUTOFF-NEXT: ; sched_group_barrier mask(0x00000200) size(10) SyncID(0) +; EXACTCUTOFF-NEXT: s_endpgm +entry: + %idx = call i32 @llvm.amdgcn.workitem.id.x() + %load.0.addr = getelementptr <16 x half>, ptr addrspace(3) %in, i32 %idx + %load.0 = load <16 x half>, ptr addrspace(3) %load.0.addr + %load.1.addr = getelementptr <16 x half>, ptr addrspace(3) %load.0.addr, i32 64 + %load.1 = load <16 x half>, ptr addrspace(3) %load.1.addr + %load.2.addr = getelementptr <16 x half>, ptr addrspace(3) %load.1.addr, i32 128 + %load.2 = load <16 x half>, ptr addrspace(3) %load.2.addr + %load.3.addr = getelementptr <16 x half>, ptr addrspace(3) %load.2.addr, i32 192 + %load.3 = load <16 x half>, ptr addrspace(3) %load.3.addr + %load.4.addr = getelementptr <16 x half>, ptr addrspace(3) %load.3.addr, i32 256 + %load.4 = load <16 x half>, ptr addrspace(3) %load.4.addr + %mai.0 = call <16 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16(<16 x half> %load.0, <16 x half> %load.0, <16 x half> %load.0, i1 0) + %mai.1 = call <16 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16(<16 x half> %load.1, <16 x half> %load.1, <16 x half> %load.1, i1 0) + %mai.2 = call <16 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16(<16 x half> %load.2, <16 x half> %load.2, <16 x half> %load.2, i1 0) + %mai.3 = call <16 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16(<16 x half> %load.3, <16 x half> %load.3, <16 x half> %load.3, i1 0) + 
%mai.4 = call <16 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16(<16 x half> %load.4, <16 x half> %load.4, <16 x half> %load.4, i1 0) + %store.0.addr = getelementptr <16 x half>, ptr addrspace(3) %out, i32 %idx + store <16 x half> %mai.0, ptr addrspace(3) %store.0.addr + %store.1.addr = getelementptr <16 x half>, ptr addrspace(3) %out, i32 64 + store <16 x half> %mai.1, ptr addrspace(3) %store.1.addr + %store.2.addr = getelementptr <16 x half>, ptr addrspace(3) %out, i32 128 + store <16 x half> %mai.2, ptr addrspace(3) %store.2.addr + %store.3.addr = getelementptr <16 x half>, ptr addrspace(3) %out, i32 192 + store <16 x half> %mai.3, ptr addrspace(3) %store.3.addr + %store.4.addr = getelementptr <16 x half>, ptr addrspace(3) %out, i32 256 + store <16 x half> %mai.4, ptr addrspace(3) %store.4.addr + ; 10 DS read + call void @llvm.amdgcn.sched.group.barrier(i32 256, i32 10, i32 0) + ; 5 WMMA + call void @llvm.amdgcn.sched.group.barrier(i32 8, i32 5, i32 0) + ; 10 DS write + call void @llvm.amdgcn.sched.group.barrier(i32 512, i32 10, i32 0) + ret void +} + +define amdgpu_kernel void @test_sched_group_barrier_pipeline_WMMA_interleave(ptr addrspace(3) noalias %in, ptr addrspace(3) noalias %out) #0 { +; GCN-LABEL: test_sched_group_barrier_pipeline_WMMA_interleave: +; GCN: ; %bb.0: ; %entry +; GCN-NEXT: s_load_b64 s[0:1], s[0:1], 0x24 +; GCN-NEXT: v_lshlrev_b32_e32 v16, 5, v0 +; GCN-NEXT: s_waitcnt lgkmcnt(0) +; GCN-NEXT: s_delay_alu instid0(VALU_DEP_1) +; GCN-NEXT: v_add_nc_u32_e32 v17, s0, v16 +; GCN-NEXT: v_add_nc_u32_e32 v16, s1, v16 +; GCN-NEXT: ds_load_b128 v[4:7], v17 offset:16 +; GCN-NEXT: ds_load_b128 v[0:3], v17 +; GCN-NEXT: ; sched_group_barrier mask(0x00000100) size(2) SyncID(0) +; GCN-NEXT: s_waitcnt lgkmcnt(0) +; GCN-NEXT: v_dual_mov_b32 v15, v7 :: v_dual_mov_b32 v14, v6 +; GCN-NEXT: v_dual_mov_b32 v13, v5 :: v_dual_mov_b32 v12, v4 +; GCN-NEXT: v_dual_mov_b32 v11, v3 :: v_dual_mov_b32 v10, v2 +; GCN-NEXT: v_dual_mov_b32 v9, v1 :: v_dual_mov_b32 v8, v0 +; 
GCN-NEXT: s_delay_alu instid0(VALU_DEP_1) +; GCN-NEXT: v_wmma_f16_16x16x16_f16 v[8:15], v[0:7], v[0:7], v[8:15] +; GCN-NEXT: ; sched_group_barrier mask(0x00000008) size(1) SyncID(0) +; GCN-NEXT: ds_store_b128 v16, v[12:15] offset:16 +; GCN-NEXT: ds_store_b128 v16, v[8:11] +; GCN-NEXT: ds_load_b128 v[4:7], v17 offset:2064 +; GCN-NEXT: ds_load_b128 v[0:3], v17 offset:2048 +; GCN-NEXT: v_mov_b32_e32 v16, s1 +; GCN-NEXT: ; sched_group_barrier mask(0x00000200) size(2) SyncID(0) +; GCN-NEXT: ; sched_group_barrier mask(0x00000100) size(2) SyncID(0) +; GCN-NEXT: s_waitcnt lgkmcnt(0) +; GCN-NEXT: v_dual_mov_b32 v15, v7 :: v_dual_mov_b32 v14, v6 +; GCN-NEXT: v_dual_mov_b32 v13, v5 :: v_dual_mov_b32 v12, v4 +; GCN-NEXT: v_dual_mov_b32 v11, v3 :: v_dual_mov_b32 v10, v2 +; GCN-NEXT: v_dual_mov_b32 v9, v1 :: v_dual_mov_b32 v8, v0 +; GCN-NEXT: s_delay_alu instid0(VALU_DEP_1) +; GCN-NEXT: v_wmma_f16_16x16x16_f16 v[8:15], v[0:7], v[0:7], v[8:15] +; GCN-NEXT: ; sched_group_barrier mask(0x00000008) size(1) SyncID(0) +; GCN-NEXT: ds_store_b128 v16, v[12:15] offset:2064 +; GCN-NEXT: ds_store_b128 v16, v[8:11] offset:2048 +; GCN-NEXT: ds_load_b128 v[4:7], v17 offset:6160 +; GCN-NEXT: ds_load_b128 v[0:3], v17 offset:6144 +; GCN-NEXT: ; sched_group_barrier mask(0x00000200) size(2) SyncID(0) +; GCN-NEXT: ; sched_group_barrier mask(0x00000100) size(2) SyncID(0) +; GCN-NEXT: s_waitcnt lgkmcnt(0) +; GCN-NEXT: v_dual_mov_b32 v15, v7 :: v_dual_mov_b32 v14, v6 +; GCN-NEXT: v_dual_mov_b32 v13, v5 :: v_dual_mov_b32 v12, v4 +; GCN-NEXT: v_dual_mov_b32 v11, v3 :: v_dual_mov_b32 v10, v2 +; GCN-NEXT: v_dual_mov_b32 v9, v1 :: v_dual_mov_b32 v8, v0 +; GCN-NEXT: s_delay_alu instid0(VALU_DEP_1) +; GCN-NEXT: v_wmma_f16_16x16x16_f16 v[8:15], v[0:7], v[0:7], v[8:15] +; GCN-NEXT: ; sched_group_barrier mask(0x00000008) size(1) SyncID(0) +; GCN-NEXT: ds_store_b128 v16, v[12:15] offset:4112 +; GCN-NEXT: ds_store_b128 v16, v[8:11] offset:4096 +; GCN-NEXT: ds_load_b128 v[4:7], v17 offset:12304 +; GCN-NEXT: 
ds_load_b128 v[0:3], v17 offset:12288 +; GCN-NEXT: ; sched_group_barrier mask(0x00000200) size(2) SyncID(0) +; GCN-NEXT: ; sched_group_barrier mask(0x00000100) size(2) SyncID(0) +; GCN-NEXT: s_waitcnt lgkmcnt(0) +; GCN-NEXT: v_dual_mov_b32 v15, v7 :: v_dual_mov_b32 v14, v6 +; GCN-NEXT: v_dual_mov_b32 v13, v5 :: v_dual_mov_b32 v12, v4 +; GCN-NEXT: v_dual_mov_b32 v11, v3 :: v_dual_mov_b32 v10, v2 +; GCN-NEXT: v_dual_mov_b32 v9, v1 :: v_dual_mov_b32 v8, v0 +; GCN-NEXT: s_delay_alu instid0(VALU_DEP_1) +; GCN-NEXT: v_wmma_f16_16x16x16_f16 v[8:15], v[0:7], v[0:7], v[8:15] +; GCN-NEXT: ; sched_group_barrier mask(0x00000008) size(1) SyncID(0) +; GCN-NEXT: ds_store_b128 v16, v[12:15] offset:6160 +; GCN-NEXT: ds_store_b128 v16, v[8:11] offset:6144 +; GCN-NEXT: ds_load_b128 v[4:7], v17 offset:20496 +; GCN-NEXT: ds_load_b128 v[0:3], v17 offset:20480 +; GCN-NEXT: ; sched_group_barrier mask(0x00000200) size(2) SyncID(0) +; GCN-NEXT: ; sched_group_barrier mask(0x00000100) size(2) SyncID(0) +; GCN-NEXT: s_waitcnt lgkmcnt(0) +; GCN-NEXT: v_dual_mov_b32 v15, v7 :: v_dual_mov_b32 v14, v6 +; GCN-NEXT: v_dual_mov_b32 v13, v5 :: v_dual_mov_b32 v12, v4 +; GCN-NEXT: v_dual_mov_b32 v11, v3 :: v_dual_mov_b32 v10, v2 +; GCN-NEXT: v_dual_mov_b32 v9, v1 :: v_dual_mov_b32 v8, v0 +; GCN-NEXT: s_delay_alu instid0(VALU_DEP_1) +; GCN-NEXT: v_wmma_f16_16x16x16_f16 v[8:15], v[0:7], v[0:7], v[8:15] +; GCN-NEXT: ; sched_group_barrier mask(0x00000008) size(1) SyncID(0) +; GCN-NEXT: ds_store_b128 v16, v[12:15] offset:8208 +; GCN-NEXT: ds_store_b128 v16, v[8:11] offset:8192 +; GCN-NEXT: ; sched_group_barrier mask(0x00000200) size(2) SyncID(0) +; GCN-NEXT: s_endpgm +; +; EXACTCUTOFF-LABEL: test_sched_group_barrier_pipeline_WMMA_interleave: +; EXACTCUTOFF: ; %bb.0: ; %entry +; EXACTCUTOFF-NEXT: s_load_b64 s[0:1], s[0:1], 0x24 +; EXACTCUTOFF-NEXT: v_lshlrev_b32_e32 v16, 5, v0 +; EXACTCUTOFF-NEXT: s_waitcnt lgkmcnt(0) +; EXACTCUTOFF-NEXT: s_delay_alu instid0(VALU_DEP_1) +; EXACTCUTOFF-NEXT: v_add_nc_u32_e32 
v17, s0, v16 +; EXACTCUTOFF-NEXT: v_add_nc_u32_e32 v16, s1, v16 +; EXACTCUTOFF-NEXT: ds_load_b128 v[4:7], v17 offset:16 +; EXACTCUTOFF-NEXT: ds_load_b128 v[0:3], v17 +; EXACTCUTOFF-NEXT: ; sched_group_barrier mask(0x00000100) size(2) SyncID(0) +; EXACTCUTOFF-NEXT: s_waitcnt lgkmcnt(0) +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v15, v7 :: v_dual_mov_b32 v14, v6 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v13, v5 :: v_dual_mov_b32 v12, v4 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v11, v3 :: v_dual_mov_b32 v10, v2 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v9, v1 :: v_dual_mov_b32 v8, v0 +; EXACTCUTOFF-NEXT: s_delay_alu instid0(VALU_DEP_1) +; EXACTCUTOFF-NEXT: v_wmma_f16_16x16x16_f16 v[8:15], v[0:7], v[0:7], v[8:15] +; EXACTCUTOFF-NEXT: ; sched_group_barrier mask(0x00000008) size(1) SyncID(0) +; EXACTCUTOFF-NEXT: ds_store_b128 v16, v[12:15] offset:16 +; EXACTCUTOFF-NEXT: ds_store_b128 v16, v[8:11] +; EXACTCUTOFF-NEXT: ds_load_b128 v[4:7], v17 offset:2064 +; EXACTCUTOFF-NEXT: ds_load_b128 v[0:3], v17 offset:2048 +; EXACTCUTOFF-NEXT: v_mov_b32_e32 v16, s1 +; EXACTCUTOFF-NEXT: ; sched_group_barrier mask(0x00000200) size(2) SyncID(0) +; EXACTCUTOFF-NEXT: ; sched_group_barrier mask(0x00000100) size(2) SyncID(0) +; EXACTCUTOFF-NEXT: s_waitcnt lgkmcnt(0) +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v15, v7 :: v_dual_mov_b32 v14, v6 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v13, v5 :: v_dual_mov_b32 v12, v4 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v11, v3 :: v_dual_mov_b32 v10, v2 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v9, v1 :: v_dual_mov_b32 v8, v0 +; EXACTCUTOFF-NEXT: s_delay_alu instid0(VALU_DEP_1) +; EXACTCUTOFF-NEXT: v_wmma_f16_16x16x16_f16 v[8:15], v[0:7], v[0:7], v[8:15] +; EXACTCUTOFF-NEXT: ; sched_group_barrier mask(0x00000008) size(1) SyncID(0) +; EXACTCUTOFF-NEXT: ds_store_b128 v16, v[12:15] offset:2064 +; EXACTCUTOFF-NEXT: ds_store_b128 v16, v[8:11] offset:2048 +; EXACTCUTOFF-NEXT: ds_load_b128 v[4:7], v17 offset:6160 +; EXACTCUTOFF-NEXT: ds_load_b128 v[0:3], v17 offset:6144 +; EXACTCUTOFF-NEXT: ; 
sched_group_barrier mask(0x00000200) size(2) SyncID(0) +; EXACTCUTOFF-NEXT: ; sched_group_barrier mask(0x00000100) size(2) SyncID(0) +; EXACTCUTOFF-NEXT: s_waitcnt lgkmcnt(0) +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v15, v7 :: v_dual_mov_b32 v14, v6 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v13, v5 :: v_dual_mov_b32 v12, v4 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v11, v3 :: v_dual_mov_b32 v10, v2 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v9, v1 :: v_dual_mov_b32 v8, v0 +; EXACTCUTOFF-NEXT: s_delay_alu instid0(VALU_DEP_1) +; EXACTCUTOFF-NEXT: v_wmma_f16_16x16x16_f16 v[8:15], v[0:7], v[0:7], v[8:15] +; EXACTCUTOFF-NEXT: ; sched_group_barrier mask(0x00000008) size(1) SyncID(0) +; EXACTCUTOFF-NEXT: ds_store_b128 v16, v[12:15] offset:4112 +; EXACTCUTOFF-NEXT: ds_store_b128 v16, v[8:11] offset:4096 +; EXACTCUTOFF-NEXT: ds_load_b128 v[4:7], v17 offset:12304 +; EXACTCUTOFF-NEXT: ds_load_b128 v[0:3], v17 offset:12288 +; EXACTCUTOFF-NEXT: ; sched_group_barrier mask(0x00000200) size(2) SyncID(0) +; EXACTCUTOFF-NEXT: ; sched_group_barrier mask(0x00000100) size(2) SyncID(0) +; EXACTCUTOFF-NEXT: s_waitcnt lgkmcnt(0) +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v15, v7 :: v_dual_mov_b32 v14, v6 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v13, v5 :: v_dual_mov_b32 v12, v4 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v11, v3 :: v_dual_mov_b32 v10, v2 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v9, v1 :: v_dual_mov_b32 v8, v0 +; EXACTCUTOFF-NEXT: s_delay_alu instid0(VALU_DEP_1) +; EXACTCUTOFF-NEXT: v_wmma_f16_16x16x16_f16 v[8:15], v[0:7], v[0:7], v[8:15] +; EXACTCUTOFF-NEXT: ; sched_group_barrier mask(0x00000008) size(1) SyncID(0) +; EXACTCUTOFF-NEXT: ds_store_b128 v16, v[12:15] offset:6160 +; EXACTCUTOFF-NEXT: ds_store_b128 v16, v[8:11] offset:6144 +; EXACTCUTOFF-NEXT: ds_load_b128 v[4:7], v17 offset:20496 +; EXACTCUTOFF-NEXT: ds_load_b128 v[0:3], v17 offset:20480 +; EXACTCUTOFF-NEXT: ; sched_group_barrier mask(0x00000200) size(2) SyncID(0) +; EXACTCUTOFF-NEXT: ; sched_group_barrier mask(0x00000100) size(2) SyncID(0) +; 
EXACTCUTOFF-NEXT: s_waitcnt lgkmcnt(0) +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v15, v7 :: v_dual_mov_b32 v14, v6 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v13, v5 :: v_dual_mov_b32 v12, v4 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v11, v3 :: v_dual_mov_b32 v10, v2 +; EXACTCUTOFF-NEXT: v_dual_mov_b32 v9, v1 :: v_dual_mov_b32 v8, v0 +; EXACTCUTOFF-NEXT: s_delay_alu instid0(VALU_DEP_1) +; EXACTCUTOFF-NEXT: v_wmma_f16_16x16x16_f16 v[8:15], v[0:7], v[0:7], v[8:15] +; EXACTCUTOFF-NEXT: ; sched_group_barrier mask(0x00000008) size(1) SyncID(0) +; EXACTCUTOFF-NEXT: ds_store_b128 v16, v[12:15] offset:8208 +; EXACTCUTOFF-NEXT: ds_store_b128 v16, v[8:11] offset:8192 +; EXACTCUTOFF-NEXT: ; sched_group_barrier mask(0x00000200) size(2) SyncID(0) +; EXACTCUTOFF-NEXT: s_endpgm +entry: + %idx = call i32 @llvm.amdgcn.workitem.id.x() + %load.0.addr = getelementptr <16 x half>, ptr addrspace(3) %in, i32 %idx + %load.0 = load <16 x half>, ptr addrspace(3) %load.0.addr + %load.1.addr = getelementptr <16 x half>, ptr addrspace(3) %load.0.addr, i32 64 + %load.1 = load <16 x half>, ptr addrspace(3) %load.1.addr + %load.2.addr = getelementptr <16 x half>, ptr addrspace(3) %load.1.addr, i32 128 + %load.2 = load <16 x half>, ptr addrspace(3) %load.2.addr + %load.3.addr = getelementptr <16 x half>, ptr addrspace(3) %load.2.addr, i32 192 + %load.3 = load <16 x half>, ptr addrspace(3) %load.3.addr + %load.4.addr = getelementptr <16 x half>, ptr addrspace(3) %load.3.addr, i32 256 + %load.4 = load <16 x half>, ptr addrspace(3) %load.4.addr + %mai.0 = call <16 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16(<16 x half> %load.0, <16 x half> %load.0, <16 x half> %load.0, i1 0) + %mai.1 = call <16 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16(<16 x half> %load.1, <16 x half> %load.1, <16 x half> %load.1, i1 0) + %mai.2 = call <16 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16(<16 x half> %load.2, <16 x half> %load.2, <16 x half> %load.2, i1 0) + %mai.3 = call <16 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16(<16 x half> %load.3, <16 
x half> %load.3, <16 x half> %load.3, i1 0) + %mai.4 = call <16 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16(<16 x half> %load.4, <16 x half> %load.4, <16 x half> %load.4, i1 0) + %store.0.addr = getelementptr <16 x half>, ptr addrspace(3) %out, i32 %idx + store <16 x half> %mai.0, ptr addrspace(3) %store.0.addr + %store.1.addr = getelementptr <16 x half>, ptr addrspace(3) %out, i32 64 + store <16 x half> %mai.1, ptr addrspace(3) %store.1.addr + %store.2.addr = getelementptr <16 x half>, ptr addrspace(3) %out, i32 128 + store <16 x half> %mai.2, ptr addrspace(3) %store.2.addr + %store.3.addr = getelementptr <16 x half>, ptr addrspace(3) %out, i32 192 + store <16 x half> %mai.3, ptr addrspace(3) %store.3.addr + %store.4.addr = getelementptr <16 x half>, ptr addrspace(3) %out, i32 256 + store <16 x half> %mai.4, ptr addrspace(3) %store.4.addr + ; 2 DS read + call void @llvm.amdgcn.sched.group.barrier(i32 256, i32 2, i32 0) + ; 1 WMMA + call void @llvm.amdgcn.sched.group.barrier(i32 8, i32 1, i32 0) + ; 2 DS write + call void @llvm.amdgcn.sched.group.barrier(i32 512, i32 2, i32 0) + ; 2 DS read + call void @llvm.amdgcn.sched.group.barrier(i32 256, i32 2, i32 0) + ; 1 WMMA + call void @llvm.amdgcn.sched.group.barrier(i32 8, i32 1, i32 0) + ; 2 DS write + call void @llvm.amdgcn.sched.group.barrier(i32 512, i32 2, i32 0) + ; 2 DS read + call void @llvm.amdgcn.sched.group.barrier(i32 256, i32 2, i32 0) + ; 1 WMMA + call void @llvm.amdgcn.sched.group.barrier(i32 8, i32 1, i32 0) + ; 2 DS write + call void @llvm.amdgcn.sched.group.barrier(i32 512, i32 2, i32 0) + ; 2 DS read + call void @llvm.amdgcn.sched.group.barrier(i32 256, i32 2, i32 0) + ; 1 WMMA + call void @llvm.amdgcn.sched.group.barrier(i32 8, i32 1, i32 0) + ; 2 DS write + call void @llvm.amdgcn.sched.group.barrier(i32 512, i32 2, i32 0) + ; 2 DS read + call void @llvm.amdgcn.sched.group.barrier(i32 256, i32 2, i32 0) + ; 1 WMMA + call void @llvm.amdgcn.sched.group.barrier(i32 8, i32 1, i32 0) + ; 2 DS write + call 
void @llvm.amdgcn.sched.group.barrier(i32 512, i32 2, i32 0) + ret void +} + +declare i32 @llvm.amdgcn.workitem.id.x() #2 +declare void @llvm.amdgcn.sched.group.barrier(i32, i32, i32) #1 +declare <16 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16(<16 x half>, <16 x half> , <16 x half>, i1 immarg) #1 + +attributes #0 = { nounwind "amdgpu-flat-work-group-size"="1,32" } +attributes #1 = { nounwind } +attributes #2 = { nounwind readnone speculatable } From c90801457f7cbbaee97821a06a893f4146ab1b2e Mon Sep 17 00:00:00 2001 From: Nikolas Klauser Date: Sun, 20 Nov 2022 23:16:20 +0100 Subject: [PATCH 49/61] [libc++] Refactor deque::iterator algorithm optimizations This has multiple benefits: - The optimizations are also performed for the `ranges::` versions of the algorithms - Code duplication is reduced - it is simpler to add this optimization for other segmented iterators, like `ranges::join_view::iterator` - Algorithm code is removed from `` Reviewed By: ldionne, huixie90, #libc Spies: mstorsjo, sstefan1, EricWF, libcxx-commits, mgorny Differential Revision: https://reviews.llvm.org/D132505 --- libcxx/benchmarks/CMakeLists.txt | 1 + libcxx/benchmarks/deque_iterator.bench.cpp | 232 +++++++ libcxx/docs/ReleaseNotes.rst | 2 + libcxx/include/CMakeLists.txt | 1 + libcxx/include/__algorithm/copy.h | 71 ++- libcxx/include/__algorithm/copy_backward.h | 73 ++- libcxx/include/__algorithm/move.h | 73 ++- libcxx/include/__algorithm/move_backward.h | 83 ++- libcxx/include/__iterator/reverse_iterator.h | 1 + .../include/__iterator/segmented_iterator.h | 79 +++ libcxx/include/deque | 572 +----------------- libcxx/include/module.modulemap.in | 1 + .../nothrow_forward_iterator.compile.pass.cpp | 14 +- .../nothrow_forward_range.compile.pass.cpp | 12 - libcxx/test/libcxx/private_headers.verify.cpp | 1 + .../alg.copy/ranges.copy.pass.cpp | 50 +- .../alg.copy/ranges.copy.segmented.pass.cpp | 51 ++ .../alg.copy/ranges.copy_backward.pass.cpp | 91 ++- .../alg.move/ranges.move.pass.cpp | 119 +++- 
.../alg.move/ranges.move_backward.pass.cpp | 93 ++- libcxx/test/support/test_iterators.h | 15 + 21 files changed, 953 insertions(+), 682 deletions(-) create mode 100644 libcxx/benchmarks/deque_iterator.bench.cpp create mode 100644 libcxx/include/__iterator/segmented_iterator.h create mode 100644 libcxx/test/std/algorithms/alg.modifying.operations/alg.copy/ranges.copy.segmented.pass.cpp diff --git a/libcxx/benchmarks/CMakeLists.txt b/libcxx/benchmarks/CMakeLists.txt index a7922302e189b2..fbc5144db3219b 100644 --- a/libcxx/benchmarks/CMakeLists.txt +++ b/libcxx/benchmarks/CMakeLists.txt @@ -176,6 +176,7 @@ set(BENCHMARK_TESTS algorithms/stable_sort.bench.cpp allocation.bench.cpp deque.bench.cpp + deque_iterator.bench.cpp filesystem.bench.cpp format_to_n.bench.cpp format_to.bench.cpp diff --git a/libcxx/benchmarks/deque_iterator.bench.cpp b/libcxx/benchmarks/deque_iterator.bench.cpp new file mode 100644 index 00000000000000..0eb23f2f3bea55 --- /dev/null +++ b/libcxx/benchmarks/deque_iterator.bench.cpp @@ -0,0 +1,232 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include +#include + +#include "benchmark/benchmark.h" + +namespace { +void run_sizes(auto benchmark) { + benchmark->Arg(0) + ->Arg(1) + ->Arg(2) + ->Arg(64) + ->Arg(512) + ->Arg(1024) + ->Arg(4000) + ->Arg(4096) + ->Arg(5500) + ->Arg(64000) + ->Arg(65536) + ->Arg(70000); +} + +template +void benchmark_containers(benchmark::State& state, FromContainer& d, ToContainer& v, Func&& func) { + for (auto _ : state) { + benchmark::DoNotOptimize(v); + benchmark::DoNotOptimize(d); + func(d.begin(), d.end(), v.begin()); + } +} + +template +void benchmark_deque_vector(benchmark::State& state, Func&& func) { + auto size = state.range(0); + std::deque d; + d.resize(size); + std::ranges::fill(d, 10); + std::vector v; + v.resize(size); + benchmark_containers(state, d, v, func); +} + +template +void benchmark_deque_deque(benchmark::State& state, Func&& func) { + auto size = state.range(0); + std::deque d; + d.resize(size); + std::ranges::fill(d, 10); + std::deque v; + v.resize(size); + benchmark_containers(state, d, v, func); +} + +template +void benchmark_vector_deque(benchmark::State& state, Func&& func) { + auto size = state.range(0); + std::vector d; + d.resize(size); + std::ranges::fill(d, 10); + std::deque v; + v.resize(size); + benchmark_containers(state, d, v, func); +} + +template +void benchmark_containers_backward(benchmark::State& state, FromContainer& d, ToContainer& v, Func&& func) { + for (auto _ : state) { + benchmark::DoNotOptimize(v); + benchmark::DoNotOptimize(d); + func(d.begin(), d.end(), v.end()); + } +} + +template +void benchmark_deque_vector_backward(benchmark::State& state, Func&& func) { + auto size = state.range(0); + std::deque d; + d.resize(size); + std::ranges::fill(d, 10); + std::vector v; + v.resize(size); + benchmark_containers_backward(state, d, v, func); +} + +template +void 
benchmark_deque_deque_backward(benchmark::State& state, Func&& func) { + auto size = state.range(0); + std::deque d; + d.resize(size); + std::ranges::fill(d, 10); + std::deque v; + v.resize(size); + benchmark_containers_backward(state, d, v, func); +} + +template +void benchmark_vector_deque_backward(benchmark::State& state, Func&& func) { + auto size = state.range(0); + std::vector d; + d.resize(size); + std::ranges::fill(d, 10); + std::deque v; + v.resize(size); + benchmark_containers_backward(state, d, v, func); +} + +struct CopyFunctor { + template + auto operator()(Args... args) const { + std::copy(std::forward(args)...); + } +} copy; + +struct MoveFunctor { + template + auto operator()(Args... args) const { + std::move(std::forward(args)...); + } +} move; + +struct CopyBackwardFunctor { + template + auto operator()(Args... args) const { + std::copy_backward(std::forward(args)...); + } +} copy_backward; + +struct MoveBackwardFunctor { + template + auto operator()(Args... args) const { + std::move_backward(std::forward(args)...); + } +} move_backward; + +// copy +void BM_deque_vector_copy(benchmark::State& state) { benchmark_deque_vector(state, copy); } +BENCHMARK(BM_deque_vector_copy)->Apply(run_sizes); + +void BM_deque_vector_ranges_copy(benchmark::State& state) { benchmark_deque_vector(state, std::ranges::copy); } +BENCHMARK(BM_deque_vector_ranges_copy)->Apply(run_sizes); + +void BM_deque_deque_copy(benchmark::State& state) { benchmark_deque_deque(state, copy); } +BENCHMARK(BM_deque_deque_copy)->Apply(run_sizes); + +void BM_deque_deque_ranges_copy(benchmark::State& state) { benchmark_deque_deque(state, std::ranges::copy); } +BENCHMARK(BM_deque_deque_ranges_copy)->Apply(run_sizes); + +void BM_vector_deque_copy(benchmark::State& state) { benchmark_vector_deque(state, copy); } +BENCHMARK(BM_vector_deque_copy)->Apply(run_sizes); + +void BM_vector_deque_ranges_copy(benchmark::State& state) { benchmark_vector_deque(state, std::ranges::copy); } 
+BENCHMARK(BM_vector_deque_ranges_copy)->Apply(run_sizes); + +// move +void BM_deque_vector_move(benchmark::State& state) { benchmark_deque_vector(state, move); } +BENCHMARK(BM_deque_vector_move)->Apply(run_sizes); + +void BM_deque_vector_ranges_move(benchmark::State& state) { benchmark_deque_vector(state, std::ranges::move); } +BENCHMARK(BM_deque_vector_ranges_move)->Apply(run_sizes); + +void BM_deque_deque_move(benchmark::State& state) { benchmark_deque_deque(state, move); } +BENCHMARK(BM_deque_deque_move)->Apply(run_sizes); + +void BM_deque_deque_ranges_move(benchmark::State& state) { benchmark_deque_deque(state, std::ranges::move); } +BENCHMARK(BM_deque_deque_ranges_move)->Apply(run_sizes); + +void BM_vector_deque_move(benchmark::State& state) { benchmark_vector_deque(state, move); } +BENCHMARK(BM_vector_deque_move)->Apply(run_sizes); + +void BM_vector_deque_ranges_move(benchmark::State& state) { benchmark_vector_deque(state, std::ranges::move); } +BENCHMARK(BM_vector_deque_ranges_move)->Apply(run_sizes); + +// copy_backward +void BM_deque_vector_copy_backward(benchmark::State& state) { benchmark_deque_vector_backward(state, copy_backward); } +BENCHMARK(BM_deque_vector_copy_backward)->Apply(run_sizes); + +void BM_deque_vector_ranges_copy_backward(benchmark::State& state) { + benchmark_deque_vector_backward(state, std::ranges::copy_backward); +} +BENCHMARK(BM_deque_vector_ranges_copy_backward)->Apply(run_sizes); + +void BM_deque_deque_copy_backward(benchmark::State& state) { benchmark_deque_deque_backward(state, copy_backward); } +BENCHMARK(BM_deque_deque_copy_backward)->Apply(run_sizes); + +void BM_deque_deque_ranges_copy_backward(benchmark::State& state) { + benchmark_deque_deque_backward(state, std::ranges::copy_backward); +} +BENCHMARK(BM_deque_deque_ranges_copy_backward)->Apply(run_sizes); + +void BM_vector_deque_copy_backward(benchmark::State& state) { benchmark_vector_deque_backward(state, copy_backward); } 
+BENCHMARK(BM_vector_deque_copy_backward)->Apply(run_sizes); + +void BM_vector_deque_ranges_copy_backward(benchmark::State& state) { + benchmark_vector_deque_backward(state, std::ranges::copy_backward); +} +BENCHMARK(BM_vector_deque_ranges_copy_backward)->Apply(run_sizes); + +// move_backward +void BM_deque_vector_move_backward(benchmark::State& state) { benchmark_deque_vector_backward(state, move_backward); } +BENCHMARK(BM_deque_vector_move_backward)->Apply(run_sizes); + +void BM_deque_vector_ranges_move_backward(benchmark::State& state) { + benchmark_deque_vector_backward(state, std::ranges::move_backward); +} +BENCHMARK(BM_deque_vector_ranges_move_backward)->Apply(run_sizes); + +void BM_deque_deque_move_backward(benchmark::State& state) { benchmark_deque_deque_backward(state, move_backward); } +BENCHMARK(BM_deque_deque_move_backward)->Apply(run_sizes); + +void BM_deque_deque_ranges_move_backward(benchmark::State& state) { + benchmark_deque_deque_backward(state, std::ranges::move_backward); +} +BENCHMARK(BM_deque_deque_ranges_move_backward)->Apply(run_sizes); + +void BM_vector_deque_move_backward(benchmark::State& state) { benchmark_vector_deque_backward(state, move_backward); } +BENCHMARK(BM_vector_deque_move_backward)->Apply(run_sizes); + +void BM_vector_deque_ranges_move_backward(benchmark::State& state) { + benchmark_vector_deque_backward(state, std::ranges::move_backward); +} +BENCHMARK(BM_vector_deque_ranges_move_backward)->Apply(run_sizes); + +} // namespace + +BENCHMARK_MAIN(); diff --git a/libcxx/docs/ReleaseNotes.rst b/libcxx/docs/ReleaseNotes.rst index e4679590d7ffe9..eeacec14289fdb 100644 --- a/libcxx/docs/ReleaseNotes.rst +++ b/libcxx/docs/ReleaseNotes.rst @@ -76,6 +76,8 @@ Improvements and New Features the C library. 
- Implemented ```` header from C++17 - `D122780 `_ Improved the performance of std::sort +- The ``ranges`` versions of ``copy``, ``move``, ``copy_backward`` and ``move_backward`` are now also optimized for + ``std::deque<>::iterator``, which can lead to up to 20x performance improvements on certain algorithms. Deprecations and Removals ------------------------- diff --git a/libcxx/include/CMakeLists.txt b/libcxx/include/CMakeLists.txt index 2d202444894f3d..8c372a3b9b353a 100644 --- a/libcxx/include/CMakeLists.txt +++ b/libcxx/include/CMakeLists.txt @@ -404,6 +404,7 @@ set(files __iterator/readable_traits.h __iterator/reverse_access.h __iterator/reverse_iterator.h + __iterator/segmented_iterator.h __iterator/size.h __iterator/sortable.h __iterator/unreachable_sentinel.h diff --git a/libcxx/include/__algorithm/copy.h b/libcxx/include/__algorithm/copy.h index f33d7feda067ff..193a6df3165640 100644 --- a/libcxx/include/__algorithm/copy.h +++ b/libcxx/include/__algorithm/copy.h @@ -11,7 +11,10 @@ #include <__algorithm/copy_move_common.h> #include <__algorithm/iterator_operations.h> +#include <__algorithm/min.h> #include <__config> +#include <__iterator/segmented_iterator.h> +#include <__type_traits/common_type.h> #include <__utility/move.h> #include <__utility/pair.h> @@ -19,8 +22,15 @@ # pragma GCC system_header #endif +_LIBCPP_PUSH_MACROS +#include <__undef_macros> + _LIBCPP_BEGIN_NAMESPACE_STD +template +inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 pair<_InIter, _OutIter> __copy(_InIter, _Sent, _OutIter); + +template struct __copy_loop { template _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 pair<_InIter, _OutIter> @@ -33,6 +43,57 @@ struct __copy_loop { return std::make_pair(std::move(__first), std::move(__result)); } + + template ::value, int> = 0> + _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 pair<_InIter, _OutIter> + operator()(_InIter __first, _InIter __last, _OutIter __result) const { + using _Traits = 
__segmented_iterator_traits<_InIter>; + auto __sfirst = _Traits::__segment(__first); + auto __slast = _Traits::__segment(__last); + if (__sfirst == __slast) { + auto __iters = std::__copy<_AlgPolicy>(_Traits::__local(__first), _Traits::__local(__last), std::move(__result)); + return std::make_pair(__last, std::move(__iters.second)); + } + + __result = std::__copy<_AlgPolicy>(_Traits::__local(__first), _Traits::__end(__sfirst), std::move(__result)).second; + ++__sfirst; + while (__sfirst != __slast) { + __result = + std::__copy<_AlgPolicy>(_Traits::__begin(__sfirst), _Traits::__end(__sfirst), std::move(__result)).second; + ++__sfirst; + } + __result = + std::__copy<_AlgPolicy>(_Traits::__begin(__sfirst), _Traits::__local(__last), std::move(__result)).second; + return std::make_pair(__last, std::move(__result)); + } + + template ::value && + !__is_segmented_iterator<_InIter>::value && __is_segmented_iterator<_OutIter>::value, + int> = 0> + _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 pair<_InIter, _OutIter> + operator()(_InIter __first, _InIter __last, _OutIter __result) { + using _Traits = __segmented_iterator_traits<_OutIter>; + using _DiffT = typename common_type<__iter_diff_t<_InIter>, __iter_diff_t<_OutIter> >::type; + + if (__first == __last) + return std::make_pair(std::move(__first), std::move(__result)); + + auto __local_first = _Traits::__local(__result); + auto __segment_iterator = _Traits::__segment(__result); + while (true) { + auto __local_last = _Traits::__end(__segment_iterator); + auto __size = std::min<_DiffT>(__local_last - __local_first, __last - __first); + auto __iters = std::__copy<_AlgPolicy>(__first, __first + __size, __local_first); + __first = std::move(__iters.first); + + if (__first == __last) + return std::make_pair(std::move(__first), _Traits::__compose(__segment_iterator, std::move(__iters.second))); + + __local_first = _Traits::__begin(++__segment_iterator); + } + } }; struct __copy_trivial { @@ -46,20 +107,20 @@ struct 
__copy_trivial { }; template -pair<_InIter, _OutIter> -inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 +pair<_InIter, _OutIter> inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 __copy(_InIter __first, _Sent __last, _OutIter __result) { - return std::__dispatch_copy_or_move<_AlgPolicy, __copy_loop, __copy_trivial>( + return std::__dispatch_copy_or_move<_AlgPolicy, __copy_loop<_AlgPolicy>, __copy_trivial>( std::move(__first), std::move(__last), std::move(__result)); } template -inline _LIBCPP_INLINE_VISIBILITY _LIBCPP_CONSTEXPR_SINCE_CXX20 -_OutputIterator +inline _LIBCPP_INLINE_VISIBILITY _LIBCPP_CONSTEXPR_SINCE_CXX20 _OutputIterator copy(_InputIterator __first, _InputIterator __last, _OutputIterator __result) { return std::__copy<_ClassicAlgPolicy>(__first, __last, __result).second; } _LIBCPP_END_NAMESPACE_STD +_LIBCPP_POP_MACROS + #endif // _LIBCPP___ALGORITHM_COPY_H diff --git a/libcxx/include/__algorithm/copy_backward.h b/libcxx/include/__algorithm/copy_backward.h index be8c1ae9a4ba87..bb2a4328781cb0 100644 --- a/libcxx/include/__algorithm/copy_backward.h +++ b/libcxx/include/__algorithm/copy_backward.h @@ -11,7 +11,10 @@ #include <__algorithm/copy_move_common.h> #include <__algorithm/iterator_operations.h> +#include <__algorithm/min.h> #include <__config> +#include <__iterator/segmented_iterator.h> +#include <__type_traits/common_type.h> #include <__type_traits/is_copy_constructible.h> #include <__utility/move.h> #include <__utility/pair.h> @@ -20,8 +23,15 @@ # pragma GCC system_header #endif +_LIBCPP_PUSH_MACROS +#include <__undef_macros> + _LIBCPP_BEGIN_NAMESPACE_STD +template +_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_InIter, _OutIter> +__copy_backward(_InIter __first, _Sent __last, _OutIter __result); + template struct __copy_backward_loop { template @@ -36,6 +46,64 @@ struct __copy_backward_loop { return std::make_pair(std::move(__original_last_iter), std::move(__result)); } + + template ::value, int> = 0> + 
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 pair<_InIter, _OutIter> + operator()(_InIter __first, _InIter __last, _OutIter __result) const { + using _Traits = __segmented_iterator_traits<_InIter>; + auto __sfirst = _Traits::__segment(__first); + auto __slast = _Traits::__segment(__last); + if (__sfirst == __slast) { + auto __iters = + std::__copy_backward<_AlgPolicy>(_Traits::__local(__first), _Traits::__local(__last), std::move(__result)); + return std::make_pair(__last, __iters.second); + } + + __result = + std::__copy_backward<_AlgPolicy>(_Traits::__begin(__slast), _Traits::__local(__last), std::move(__result)) + .second; + --__slast; + while (__sfirst != __slast) { + __result = + std::__copy_backward<_AlgPolicy>(_Traits::__begin(__slast), _Traits::__end(__slast), std::move(__result)) + .second; + --__slast; + } + __result = std::__copy_backward<_AlgPolicy>(_Traits::__local(__first), _Traits::__end(__slast), std::move(__result)) + .second; + return std::make_pair(__last, std::move(__result)); + } + + template ::value && + !__is_segmented_iterator<_InIter>::value && __is_segmented_iterator<_OutIter>::value, + int> = 0> + _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 pair<_InIter, _OutIter> + operator()(_InIter __first, _InIter __last, _OutIter __result) { + using _Traits = __segmented_iterator_traits<_OutIter>; + auto __orig_last = __last; + auto __segment_iterator = _Traits::__segment(__result); + + // When the range contains no elements, __result might not be a valid iterator + if (__first == __last) + return std::make_pair(__first, __result); + + auto __local_last = _Traits::__local(__result); + while (true) { + using _DiffT = typename common_type<__iter_diff_t<_InIter>, __iter_diff_t<_OutIter> >::type; + + auto __local_first = _Traits::__begin(__segment_iterator); + auto __size = std::min<_DiffT>(__local_last - __local_first, __last - __first); + auto __iter = std::__copy_backward<_AlgPolicy>(__last - __size, __last, __local_last).second; + 
__last -= __size; + + if (__first == __last) + return std::make_pair(std::move(__orig_last), _Traits::__compose(__segment_iterator, std::move(__iter))); + --__segment_iterator; + __local_last = _Traits::__end(__segment_iterator); + } + } }; struct __copy_backward_trivial { @@ -49,8 +117,7 @@ struct __copy_backward_trivial { }; template -_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 -pair<_BidirectionalIterator1, _BidirectionalIterator2> +_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_BidirectionalIterator1, _BidirectionalIterator2> __copy_backward(_BidirectionalIterator1 __first, _Sentinel __last, _BidirectionalIterator2 __result) { return std::__dispatch_copy_or_move<_AlgPolicy, __copy_backward_loop<_AlgPolicy>, __copy_backward_trivial>( std::move(__first), std::move(__last), std::move(__result)); @@ -71,4 +138,6 @@ copy_backward(_BidirectionalIterator1 __first, _BidirectionalIterator1 __last, _LIBCPP_END_NAMESPACE_STD +_LIBCPP_POP_MACROS + #endif // _LIBCPP___ALGORITHM_COPY_BACKWARD_H diff --git a/libcxx/include/__algorithm/move.h b/libcxx/include/__algorithm/move.h index 2581a417b55810..ac95bda7b61f77 100644 --- a/libcxx/include/__algorithm/move.h +++ b/libcxx/include/__algorithm/move.h @@ -11,7 +11,10 @@ #include <__algorithm/copy_move_common.h> #include <__algorithm/iterator_operations.h> +#include <__algorithm/min.h> #include <__config> +#include <__iterator/segmented_iterator.h> +#include <__type_traits/common_type.h> #include <__type_traits/is_copy_constructible.h> #include <__utility/move.h> #include <__utility/pair.h> @@ -20,8 +23,15 @@ # pragma GCC system_header #endif +_LIBCPP_PUSH_MACROS +#include <__undef_macros> + _LIBCPP_BEGIN_NAMESPACE_STD +template +inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 pair<_InIter, _OutIter> +__move(_InIter __first, _Sent __last, _OutIter __result); + template struct __move_loop { template @@ -34,6 +44,57 @@ struct __move_loop { } return std::make_pair(std::move(__first), 
std::move(__result)); } + + template ::value, int> = 0> + _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 pair<_InIter, _OutIter> + operator()(_InIter __first, _InIter __last, _OutIter __result) const { + using _Traits = __segmented_iterator_traits<_InIter>; + auto __sfirst = _Traits::__segment(__first); + auto __slast = _Traits::__segment(__last); + if (__sfirst == __slast) { + auto __iters = std::__move<_AlgPolicy>(_Traits::__local(__first), _Traits::__local(__last), std::move(__result)); + return std::make_pair(__last, std::move(__iters.second)); + } + + __result = std::__move<_AlgPolicy>(_Traits::__local(__first), _Traits::__end(__sfirst), std::move(__result)).second; + ++__sfirst; + while (__sfirst != __slast) { + __result = + std::__move<_AlgPolicy>(_Traits::__begin(__sfirst), _Traits::__end(__sfirst), std::move(__result)).second; + ++__sfirst; + } + __result = + std::__move<_AlgPolicy>(_Traits::__begin(__sfirst), _Traits::__local(__last), std::move(__result)).second; + return std::make_pair(__last, std::move(__result)); + } + + template ::value && + !__is_segmented_iterator<_InIter>::value && __is_segmented_iterator<_OutIter>::value, + int> = 0> + _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 pair<_InIter, _OutIter> + operator()(_InIter __first, _InIter __last, _OutIter __result) { + using _Traits = __segmented_iterator_traits<_OutIter>; + using _DiffT = typename common_type<__iter_diff_t<_InIter>, __iter_diff_t<_OutIter> >::type; + + if (__first == __last) + return std::make_pair(std::move(__first), std::move(__result)); + + auto __local_first = _Traits::__local(__result); + auto __segment_iterator = _Traits::__segment(__result); + while (true) { + auto __local_last = _Traits::__end(__segment_iterator); + auto __size = std::min<_DiffT>(__local_last - __local_first, __last - __first); + auto __iters = std::__move<_AlgPolicy>(__first, __first + __size, __local_first); + __first = std::move(__iters.first); + + if (__first == __last) + return 
std::make_pair(std::move(__first), _Traits::__compose(__segment_iterator, std::move(__iters.second))); + + __local_first = _Traits::__begin(++__segment_iterator); + } + } }; struct __move_trivial { @@ -47,23 +108,23 @@ struct __move_trivial { }; template -inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 -pair<_InIter, _OutIter> +inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 pair<_InIter, _OutIter> __move(_InIter __first, _Sent __last, _OutIter __result) { return std::__dispatch_copy_or_move<_AlgPolicy, __move_loop<_AlgPolicy>, __move_trivial>( std::move(__first), std::move(__last), std::move(__result)); } template -inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 -_OutputIterator move(_InputIterator __first, _InputIterator __last, _OutputIterator __result) { +inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 _OutputIterator +move(_InputIterator __first, _InputIterator __last, _OutputIterator __result) { static_assert(is_copy_constructible<_InputIterator>::value, "Iterators has to be copy constructible."); static_assert(is_copy_constructible<_OutputIterator>::value, "The output iterator has to be copy constructible."); - return std::__move<_ClassicAlgPolicy>( - std::move(__first), std::move(__last), std::move(__result)).second; + return std::__move<_ClassicAlgPolicy>(std::move(__first), std::move(__last), std::move(__result)).second; } _LIBCPP_END_NAMESPACE_STD +_LIBCPP_POP_MACROS + #endif // _LIBCPP___ALGORITHM_MOVE_H diff --git a/libcxx/include/__algorithm/move_backward.h b/libcxx/include/__algorithm/move_backward.h index 6636ca6667cefa..d4f013be6807f8 100644 --- a/libcxx/include/__algorithm/move_backward.h +++ b/libcxx/include/__algorithm/move_backward.h @@ -11,7 +11,10 @@ #include <__algorithm/copy_move_common.h> #include <__algorithm/iterator_operations.h> +#include <__algorithm/min.h> #include <__config> +#include <__iterator/segmented_iterator.h> +#include <__type_traits/common_type.h> #include 
<__type_traits/is_copy_constructible.h> #include <__utility/move.h> #include <__utility/pair.h> @@ -20,8 +23,15 @@ # pragma GCC system_header #endif +_LIBCPP_PUSH_MACROS +#include <__undef_macros> + _LIBCPP_BEGIN_NAMESPACE_STD +template +_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_BidirectionalIterator1, _BidirectionalIterator2> +__move_backward(_BidirectionalIterator1 __first, _Sentinel __last, _BidirectionalIterator2 __result); + template struct __move_backward_loop { template @@ -36,6 +46,64 @@ struct __move_backward_loop { return std::make_pair(std::move(__original_last_iter), std::move(__result)); } + + template ::value, int> = 0> + _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 pair<_InIter, _OutIter> + operator()(_InIter __first, _InIter __last, _OutIter __result) const { + using _Traits = __segmented_iterator_traits<_InIter>; + auto __sfirst = _Traits::__segment(__first); + auto __slast = _Traits::__segment(__last); + if (__sfirst == __slast) { + auto __iters = + std::__move_backward<_AlgPolicy>(_Traits::__local(__first), _Traits::__local(__last), std::move(__result)); + return std::make_pair(__last, __iters.second); + } + + __result = + std::__move_backward<_AlgPolicy>(_Traits::__begin(__slast), _Traits::__local(__last), std::move(__result)) + .second; + --__slast; + while (__sfirst != __slast) { + __result = + std::__move_backward<_AlgPolicy>(_Traits::__begin(__slast), _Traits::__end(__slast), std::move(__result)) + .second; + --__slast; + } + __result = std::__move_backward<_AlgPolicy>(_Traits::__local(__first), _Traits::__end(__slast), std::move(__result)) + .second; + return std::make_pair(__last, std::move(__result)); + } + + template ::value && + !__is_segmented_iterator<_InIter>::value && __is_segmented_iterator<_OutIter>::value, + int> = 0> + _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 pair<_InIter, _OutIter> + operator()(_InIter __first, _InIter __last, _OutIter __result) { + using _Traits = 
__segmented_iterator_traits<_OutIter>; + using _DiffT = typename common_type<__iter_diff_t<_InIter>, __iter_diff_t<_OutIter> >::type; + + // When the range contains no elements, __result might not be a valid iterator + if (__first == __last) + return std::make_pair(__first, __result); + + auto __orig_last = __last; + + auto __local_last = _Traits::__local(__result); + auto __segment_iterator = _Traits::__segment(__result); + while (true) { + auto __local_first = _Traits::__begin(__segment_iterator); + auto __size = std::min<_DiffT>(__local_last - __local_first, __last - __first); + auto __iter = std::__move_backward<_AlgPolicy>(__last - __size, __last, __local_last).second; + __last -= __size; + + if (__first == __last) + return std::make_pair(std::move(__orig_last), _Traits::__compose(__segment_iterator, std::move(__iter))); + + __local_last = _Traits::__end(--__segment_iterator); + } + } }; struct __move_backward_trivial { @@ -49,8 +117,7 @@ struct __move_backward_trivial { }; template -_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 -pair<_BidirectionalIterator1, _BidirectionalIterator2> +_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_BidirectionalIterator1, _BidirectionalIterator2> __move_backward(_BidirectionalIterator1 __first, _Sentinel __last, _BidirectionalIterator2 __result) { static_assert(std::is_copy_constructible<_BidirectionalIterator1>::value && std::is_copy_constructible<_BidirectionalIterator1>::value, "Iterators must be copy constructible."); @@ -60,15 +127,13 @@ __move_backward(_BidirectionalIterator1 __first, _Sentinel __last, _Bidirectiona } template -inline _LIBCPP_INLINE_VISIBILITY _LIBCPP_CONSTEXPR_SINCE_CXX20 -_BidirectionalIterator2 -move_backward(_BidirectionalIterator1 __first, _BidirectionalIterator1 __last, - _BidirectionalIterator2 __result) -{ - return std::__move_backward<_ClassicAlgPolicy>( - std::move(__first), std::move(__last), std::move(__result)).second; +inline _LIBCPP_INLINE_VISIBILITY 
_LIBCPP_CONSTEXPR_SINCE_CXX20 _BidirectionalIterator2 +move_backward(_BidirectionalIterator1 __first, _BidirectionalIterator1 __last, _BidirectionalIterator2 __result) { + return std::__move_backward<_ClassicAlgPolicy>(std::move(__first), std::move(__last), std::move(__result)).second; } _LIBCPP_END_NAMESPACE_STD +_LIBCPP_POP_MACROS + #endif // _LIBCPP___ALGORITHM_MOVE_BACKWARD_H diff --git a/libcxx/include/__iterator/reverse_iterator.h b/libcxx/include/__iterator/reverse_iterator.h index 942235a580b845..f272e03c17a629 100644 --- a/libcxx/include/__iterator/reverse_iterator.h +++ b/libcxx/include/__iterator/reverse_iterator.h @@ -25,6 +25,7 @@ #include <__iterator/next.h> #include <__iterator/prev.h> #include <__iterator/readable_traits.h> +#include <__iterator/segmented_iterator.h> #include <__memory/addressof.h> #include <__ranges/access.h> #include <__ranges/concepts.h> diff --git a/libcxx/include/__iterator/segmented_iterator.h b/libcxx/include/__iterator/segmented_iterator.h new file mode 100644 index 00000000000000..f3cd1e5fa1f5da --- /dev/null +++ b/libcxx/include/__iterator/segmented_iterator.h @@ -0,0 +1,79 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef _LIBCPP___SEGMENTED_ITERATOR_H +#define _LIBCPP___SEGMENTED_ITERATOR_H + +// Segmented iterators are iterators over (not necessarily contiguous) sub-ranges. +// +// For example, std::deque stores its data into multiple blocks of contiguous memory, +// which are not stored contiguously themselves. The concept of segmented iterators +// allows algorithms to operate over these multi-level iterators natively, opening the +// door to various optimizations. 
See http://lafstern.org/matt/segmented.pdf for details. +// +// If __segmented_iterator_traits can be instantiated, the following functions and associated types must be provided: +// - Traits::__local_iterator +// The type of iterators used to iterate inside a segment. +// +// - Traits::__segment_iterator +// The type of iterators used to iterate over segments. +// Segment iterators can be forward iterators or bidirectional iterators, depending on the +// underlying data structure. +// +// - static __segment_iterator Traits::__segment(It __it) +// Returns an iterator to the segment that the provided iterator is in. +// +// - static __local_iterator Traits::__local(It __it) +// Returns the local iterator pointing to the element that the provided iterator points to. +// +// - static __local_iterator Traits::__begin(__segment_iterator __it) +// Returns the local iterator to the beginning of the segment that the provided iterator is pointing into. +// +// - static __local_iterator Traits::__end(__segment_iterator __it) +// Returns the one-past-the-end local iterator to the segment that the provided iterator is pointing into. +// +// - static It Traits::__compose(__segment_iterator, __local_iterator) +// Returns the iterator composed of the segment iterator and local iterator. 
+ +#include <__config> +#include <__type_traits/integral_constant.h> +#include + +#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) +# pragma GCC system_header +#endif + +_LIBCPP_BEGIN_NAMESPACE_STD + +template +struct __segmented_iterator_traits; +/* exposition-only: +{ + using __segment_iterator = ...; + using __local_iterator = ...; + + static __segment_iterator __segment(_Iterator); + static __local_iterator __local(_Iterator); + static __local_iterator __begin(__segment_iterator); + static __local_iterator __end(__segment_iterator); + static _Iterator __compose(__segment_iterator, __local_iterator); +}; +*/ + +template +struct __has_specialization : false_type {}; + +template +struct __has_specialization<_Tp, sizeof(_Tp) * 0> : true_type {}; + +template +using __is_segmented_iterator = __has_specialization<__segmented_iterator_traits<_Iterator> >; + +_LIBCPP_END_NAMESPACE_STD + +#endif // _LIBCPP___SEGMENTED_ITERATOR_H diff --git a/libcxx/include/deque b/libcxx/include/deque index 8445883397451b..f2b8076c0af10e 100644 --- a/libcxx/include/deque +++ b/libcxx/include/deque @@ -176,6 +176,7 @@ template #include <__iterator/next.h> #include <__iterator/prev.h> #include <__iterator/reverse_iterator.h> +#include <__iterator/segmented_iterator.h> #include <__memory/allocator_destructor.h> #include <__memory/pointer_traits.h> #include <__memory/temp_value.h> @@ -216,98 +217,6 @@ _LIBCPP_BEGIN_NAMESPACE_STD template > class _LIBCPP_TEMPLATE_VIS deque; -template -class _LIBCPP_TEMPLATE_VIS __deque_iterator; - -template -_LIBCPP_HIDE_FROM_ABI __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> -copy(_RAIter __f, - _RAIter __l, - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> __r, - typename enable_if<__is_cpp17_random_access_iterator<_RAIter>::value>::type* = 0); - -template -_LIBCPP_HIDE_FROM_ABI _OutputIterator -copy(__deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __f, - __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __l, - _OutputIterator __r); - -template 
-_LIBCPP_HIDE_FROM_ABI __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> -copy(__deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __f, - __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __l, - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> __r); - -template -_LIBCPP_HIDE_FROM_ABI __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> -copy_backward(_RAIter __f, - _RAIter __l, - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> __r, - typename enable_if<__is_cpp17_random_access_iterator<_RAIter>::value>::type* = 0); - -template -_LIBCPP_HIDE_FROM_ABI _OutputIterator -copy_backward(__deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __f, - __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __l, - _OutputIterator __r); - -template -_LIBCPP_HIDE_FROM_ABI __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> -copy_backward(__deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __f, - __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __l, - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> __r); - -template -_LIBCPP_HIDE_FROM_ABI __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> -move(_RAIter __f, - _RAIter __l, - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> __r, - typename enable_if<__is_cpp17_random_access_iterator<_RAIter>::value>::type* = 0); - -template -_LIBCPP_HIDE_FROM_ABI _OutputIterator -move(__deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __f, - __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __l, - _OutputIterator __r); - -template -_LIBCPP_HIDE_FROM_ABI __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> -move(__deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __f, - __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __l, - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> __r); - -template -_LIBCPP_HIDE_FROM_ABI __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> -move_backward(_RAIter __f, - _RAIter __l, - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> __r, - typename enable_if<__is_cpp17_random_access_iterator<_RAIter>::value>::type* = 0); - -template -_LIBCPP_HIDE_FROM_ABI _OutputIterator -move_backward(__deque_iterator<_V1, 
_P1, _R1, _M1, _D1, _B1> __f, - __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __l, - _OutputIterator __r); - -template -_LIBCPP_HIDE_FROM_ABI __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> -move_backward(__deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __f, - __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __l, - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> __r); - template struct __deque_block_size { static const _DiffType value = sizeof(_ValueType) < 256 ? 4096 / sizeof(_ValueType) : 16; @@ -478,464 +387,43 @@ private: template friend class _LIBCPP_TEMPLATE_VIS __deque_iterator; - template - friend - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> - copy(_RAIter __f, - _RAIter __l, - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> __r, - typename enable_if<__is_cpp17_random_access_iterator<_RAIter>::value>::type*); - - template - friend - _OutputIterator - copy(__deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __f, - __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __l, - _OutputIterator __r); - - template - friend - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> - copy(__deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __f, - __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __l, - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> __r); - - template - friend - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> - copy_backward(_RAIter __f, - _RAIter __l, - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> __r, - typename enable_if<__is_cpp17_random_access_iterator<_RAIter>::value>::type*); - - template - friend - _OutputIterator - copy_backward(__deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __f, - __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __l, - _OutputIterator __r); - - template - friend - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> - copy_backward(__deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __f, - __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __l, - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> __r); - - template - friend - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> - 
move(_RAIter __f, - _RAIter __l, - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> __r, - typename enable_if<__is_cpp17_random_access_iterator<_RAIter>::value>::type*); - - template - friend - _OutputIterator - move(__deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __f, - __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __l, - _OutputIterator __r); - - template - friend - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> - move(__deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __f, - __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __l, - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> __r); - - template - friend - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> - move_backward(_RAIter __f, - _RAIter __l, - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> __r, - typename enable_if<__is_cpp17_random_access_iterator<_RAIter>::value>::type*); - - template - friend - _OutputIterator - move_backward(__deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __f, - __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __l, - _OutputIterator __r); - - template - friend - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> - move_backward(__deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __f, - __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __l, - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> __r); + template + friend struct __segmented_iterator_traits; }; -template -const _DiffType __deque_iterator<_ValueType, _Pointer, _Reference, _MapPointer, - _DiffType, _BlockSize>::__block_size = - __deque_block_size<_ValueType, _DiffType>::value; - -// copy - -template -_LIBCPP_HIDE_FROM_ABI __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> -copy(_RAIter __f, - _RAIter __l, - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> __r, - typename enable_if<__is_cpp17_random_access_iterator<_RAIter>::value>::type*) -{ - typedef typename __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2>::difference_type difference_type; - typedef typename __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2>::pointer pointer; - const difference_type __block_size = 
__deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2>::__block_size; - while (__f != __l) - { - pointer __rb = __r.__ptr_; - pointer __re = *__r.__m_iter_ + __block_size; - difference_type __bs = __re - __rb; - difference_type __n = __l - __f; - _RAIter __m = __l; - if (__n > __bs) - { - __n = __bs; - __m = __f + __n; - } - _VSTD::copy(__f, __m, __rb); - __f = __m; - __r += __n; - } - return __r; -} - -template -_LIBCPP_HIDE_FROM_ABI _OutputIterator -copy(__deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __f, - __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __l, - _OutputIterator __r) -{ - typedef typename __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1>::difference_type difference_type; - typedef typename __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1>::pointer pointer; - const difference_type __block_size = __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1>::__block_size; - difference_type __n = __l - __f; - while (__n > 0) - { - pointer __fb = __f.__ptr_; - pointer __fe = *__f.__m_iter_ + __block_size; - difference_type __bs = __fe - __fb; - if (__bs > __n) - { - __bs = __n; - __fe = __fb + __bs; - } - __r = _VSTD::copy(__fb, __fe, __r); - __n -= __bs; - __f += __bs; - } - return __r; -} - -template -_LIBCPP_HIDE_FROM_ABI __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> -copy(__deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __f, - __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __l, - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> __r) -{ - typedef typename __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1>::difference_type difference_type; - typedef typename __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1>::pointer pointer; - const difference_type __block_size = __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1>::__block_size; - difference_type __n = __l - __f; - while (__n > 0) - { - pointer __fb = __f.__ptr_; - pointer __fe = *__f.__m_iter_ + __block_size; - difference_type __bs = __fe - __fb; - if (__bs > __n) - { - __bs = __n; - __fe = __fb + __bs; - } - __r = _VSTD::copy(__fb, __fe, 
__r); - __n -= __bs; - __f += __bs; - } - return __r; -} - -// copy_backward - -template -_LIBCPP_HIDE_FROM_ABI __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> -copy_backward(_RAIter __f, - _RAIter __l, - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> __r, - typename enable_if<__is_cpp17_random_access_iterator<_RAIter>::value>::type*) -{ - typedef typename __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2>::difference_type difference_type; - typedef typename __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2>::pointer pointer; - while (__f != __l) - { - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> __rp = _VSTD::prev(__r); - pointer __rb = *__rp.__m_iter_; - pointer __re = __rp.__ptr_ + 1; - difference_type __bs = __re - __rb; - difference_type __n = __l - __f; - _RAIter __m = __f; - if (__n > __bs) - { - __n = __bs; - __m = __l - __n; - } - _VSTD::copy_backward(__m, __l, __re); - __l = __m; - __r -= __n; - } - return __r; -} - -template -_LIBCPP_HIDE_FROM_ABI _OutputIterator -copy_backward(__deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __f, - __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __l, - _OutputIterator __r) -{ - typedef typename __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1>::difference_type difference_type; - typedef typename __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1>::pointer pointer; - difference_type __n = __l - __f; - while (__n > 0) - { - --__l; - pointer __lb = *__l.__m_iter_; - pointer __le = __l.__ptr_ + 1; - difference_type __bs = __le - __lb; - if (__bs > __n) - { - __bs = __n; - __lb = __le - __bs; - } - __r = _VSTD::copy_backward(__lb, __le, __r); - __n -= __bs; - __l -= __bs - 1; - } - return __r; -} - -template -_LIBCPP_HIDE_FROM_ABI __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> -copy_backward(__deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __f, - __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __l, - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> __r) -{ - typedef typename __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1>::difference_type 
difference_type; - typedef typename __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1>::pointer pointer; - difference_type __n = __l - __f; - while (__n > 0) - { - --__l; - pointer __lb = *__l.__m_iter_; - pointer __le = __l.__ptr_ + 1; - difference_type __bs = __le - __lb; - if (__bs > __n) - { - __bs = __n; - __lb = __le - __bs; - } - __r = _VSTD::copy_backward(__lb, __le, __r); - __n -= __bs; - __l -= __bs - 1; - } - return __r; -} - -// move - -template -_LIBCPP_HIDE_FROM_ABI __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> -move(_RAIter __f, - _RAIter __l, - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> __r, - typename enable_if<__is_cpp17_random_access_iterator<_RAIter>::value>::type*) -{ - typedef typename __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2>::difference_type difference_type; - typedef typename __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2>::pointer pointer; - const difference_type __block_size = __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2>::__block_size; - while (__f != __l) - { - pointer __rb = __r.__ptr_; - pointer __re = *__r.__m_iter_ + __block_size; - difference_type __bs = __re - __rb; - difference_type __n = __l - __f; - _RAIter __m = __l; - if (__n > __bs) - { - __n = __bs; - __m = __f + __n; - } - _VSTD::move(__f, __m, __rb); - __f = __m; - __r += __n; - } - return __r; -} - -template -_LIBCPP_HIDE_FROM_ABI _OutputIterator -move(__deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __f, - __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __l, - _OutputIterator __r) -{ - typedef typename __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1>::difference_type difference_type; - typedef typename __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1>::pointer pointer; - const difference_type __block_size = __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1>::__block_size; - difference_type __n = __l - __f; - while (__n > 0) - { - pointer __fb = __f.__ptr_; - pointer __fe = *__f.__m_iter_ + __block_size; - difference_type __bs = __fe - __fb; - if (__bs > __n) - { - __bs = __n; 
- __fe = __fb + __bs; - } - __r = _VSTD::move(__fb, __fe, __r); - __n -= __bs; - __f += __bs; - } - return __r; -} +template +struct __segmented_iterator_traits< + __deque_iterator<_ValueType, _Pointer, _Reference, _MapPointer, _DiffType, _BlockSize> > { +private: + using _Iterator = __deque_iterator<_ValueType, _Pointer, _Reference, _MapPointer, _DiffType, _BlockSize>; -template -_LIBCPP_HIDE_FROM_ABI __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> -move(__deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __f, - __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __l, - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> __r) -{ - typedef typename __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1>::difference_type difference_type; - typedef typename __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1>::pointer pointer; - const difference_type __block_size = __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1>::__block_size; - difference_type __n = __l - __f; - while (__n > 0) - { - pointer __fb = __f.__ptr_; - pointer __fe = *__f.__m_iter_ + __block_size; - difference_type __bs = __fe - __fb; - if (__bs > __n) - { - __bs = __n; - __fe = __fb + __bs; - } - __r = _VSTD::move(__fb, __fe, __r); - __n -= __bs; - __f += __bs; - } - return __r; -} +public: + using __is_segmented_iterator = true_type; + using __segment_iterator = _MapPointer; + using __local_iterator = _Pointer; -// move_backward + static _LIBCPP_HIDE_FROM_ABI __segment_iterator __segment(_Iterator __iter) { return __iter.__m_iter_; } + static _LIBCPP_HIDE_FROM_ABI __local_iterator __local(_Iterator __iter) { return __iter.__ptr_; } + static _LIBCPP_HIDE_FROM_ABI __local_iterator __begin(__segment_iterator __iter) { return *__iter; } -template -_LIBCPP_HIDE_FROM_ABI __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> -move_backward(_RAIter __f, - _RAIter __l, - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> __r, - typename enable_if<__is_cpp17_random_access_iterator<_RAIter>::value>::type*) -{ - typedef typename __deque_iterator<_V2, _P2, 
_R2, _M2, _D2, _B2>::difference_type difference_type; - typedef typename __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2>::pointer pointer; - while (__f != __l) - { - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> __rp = _VSTD::prev(__r); - pointer __rb = *__rp.__m_iter_; - pointer __re = __rp.__ptr_ + 1; - difference_type __bs = __re - __rb; - difference_type __n = __l - __f; - _RAIter __m = __f; - if (__n > __bs) - { - __n = __bs; - __m = __l - __n; - } - _VSTD::move_backward(__m, __l, __re); - __l = __m; - __r -= __n; - } - return __r; -} + static _LIBCPP_HIDE_FROM_ABI __local_iterator __end(__segment_iterator __iter) { + return *__iter + _Iterator::__block_size; + } -template -_LIBCPP_HIDE_FROM_ABI _OutputIterator -move_backward(__deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __f, - __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __l, - _OutputIterator __r) -{ - typedef typename __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1>::difference_type difference_type; - typedef typename __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1>::pointer pointer; - difference_type __n = __l - __f; - while (__n > 0) - { - --__l; - pointer __lb = *__l.__m_iter_; - pointer __le = __l.__ptr_ + 1; - difference_type __bs = __le - __lb; - if (__bs > __n) - { - __bs = __n; - __lb = __le - __bs; + static _LIBCPP_HIDE_FROM_ABI _Iterator __compose(__segment_iterator __segment, __local_iterator __local) { + if (__local == __end(__segment)) { + ++__segment; + return _Iterator(__segment, *__segment); } - __r = _VSTD::move_backward(__lb, __le, __r); - __n -= __bs; - __l -= __bs - 1; - } - return __r; -} + return _Iterator(__segment, __local); + } +}; -template -_LIBCPP_HIDE_FROM_ABI __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> -move_backward(__deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __f, - __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1> __l, - __deque_iterator<_V2, _P2, _R2, _M2, _D2, _B2> __r) -{ - typedef typename __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1>::difference_type difference_type; 
- typedef typename __deque_iterator<_V1, _P1, _R1, _M1, _D1, _B1>::pointer pointer; - difference_type __n = __l - __f; - while (__n > 0) - { - --__l; - pointer __lb = *__l.__m_iter_; - pointer __le = __l.__ptr_ + 1; - difference_type __bs = __le - __lb; - if (__bs > __n) - { - __bs = __n; - __lb = __le - __bs; - } - __r = _VSTD::move_backward(__lb, __le, __r); - __n -= __bs; - __l -= __bs - 1; - } - return __r; -} +template +const _DiffType __deque_iterator<_ValueType, _Pointer, _Reference, _MapPointer, + _DiffType, _BlockSize>::__block_size = + __deque_block_size<_ValueType, _DiffType>::value; template */> class _LIBCPP_TEMPLATE_VIS deque diff --git a/libcxx/include/module.modulemap.in b/libcxx/include/module.modulemap.in index aa4da4dadc42ff..a6521a937a04ee 100644 --- a/libcxx/include/module.modulemap.in +++ b/libcxx/include/module.modulemap.in @@ -1014,6 +1014,7 @@ module std [system] { module readable_traits { private header "__iterator/readable_traits.h" } module reverse_access { private header "__iterator/reverse_access.h" } module reverse_iterator { private header "__iterator/reverse_iterator.h" } + module segmented_iterator { private header "__iterator/segmented_iterator.h" } module size { private header "__iterator/size.h" } module sortable { private header "__iterator/sortable.h" diff --git a/libcxx/test/libcxx/algorithms/specialized.algorithms/special.mem.concepts/nothrow_forward_iterator.compile.pass.cpp b/libcxx/test/libcxx/algorithms/specialized.algorithms/special.mem.concepts/nothrow_forward_iterator.compile.pass.cpp index c3da6f07303e88..9c488c255465f2 100644 --- a/libcxx/test/libcxx/algorithms/specialized.algorithms/special.mem.concepts/nothrow_forward_iterator.compile.pass.cpp +++ b/libcxx/test/libcxx/algorithms/specialized.algorithms/special.mem.concepts/nothrow_forward_iterator.compile.pass.cpp @@ -15,19 +15,9 @@ #include "test_iterators.h" -struct ForwardProxyIterator { - using value_type = int; - using difference_type = int; - 
ForwardProxyIterator& operator++(); - ForwardProxyIterator operator++(int); - bool operator==(const ForwardProxyIterator&) const; - - int operator*() const; -}; - static_assert(std::ranges::__nothrow_forward_iterator>); -static_assert(std::forward_iterator); -static_assert(!std::ranges::__nothrow_forward_iterator); +static_assert(std::forward_iterator>); +static_assert(!std::ranges::__nothrow_forward_iterator>); constexpr bool forward_subsumes_input(std::ranges::__nothrow_forward_iterator auto) { return true; diff --git a/libcxx/test/libcxx/algorithms/specialized.algorithms/special.mem.concepts/nothrow_forward_range.compile.pass.cpp b/libcxx/test/libcxx/algorithms/specialized.algorithms/special.mem.concepts/nothrow_forward_range.compile.pass.cpp index 06af5370ddb25b..2ddfdf66362e1e 100644 --- a/libcxx/test/libcxx/algorithms/specialized.algorithms/special.mem.concepts/nothrow_forward_range.compile.pass.cpp +++ b/libcxx/test/libcxx/algorithms/specialized.algorithms/special.mem.concepts/nothrow_forward_range.compile.pass.cpp @@ -16,18 +16,6 @@ #include "test_iterators.h" #include "test_range.h" -// Has to be a template to work with `test_range`. 
-template -struct ForwardProxyIterator { - using value_type = int; - using difference_type = int; - ForwardProxyIterator& operator++(); - ForwardProxyIterator operator++(int); - bool operator==(const ForwardProxyIterator&) const; - - int operator*() const; -}; - static_assert(std::ranges::__nothrow_forward_range>); static_assert(!std::ranges::__nothrow_forward_range>); static_assert(std::ranges::forward_range>); diff --git a/libcxx/test/libcxx/private_headers.verify.cpp b/libcxx/test/libcxx/private_headers.verify.cpp index e61eeddd2b4661..339324ca561cfd 100644 --- a/libcxx/test/libcxx/private_headers.verify.cpp +++ b/libcxx/test/libcxx/private_headers.verify.cpp @@ -435,6 +435,7 @@ END-SCRIPT #include <__iterator/readable_traits.h> // expected-error@*:* {{use of private header from outside its module: '__iterator/readable_traits.h'}} #include <__iterator/reverse_access.h> // expected-error@*:* {{use of private header from outside its module: '__iterator/reverse_access.h'}} #include <__iterator/reverse_iterator.h> // expected-error@*:* {{use of private header from outside its module: '__iterator/reverse_iterator.h'}} +#include <__iterator/segmented_iterator.h> // expected-error@*:* {{use of private header from outside its module: '__iterator/segmented_iterator.h'}} #include <__iterator/size.h> // expected-error@*:* {{use of private header from outside its module: '__iterator/size.h'}} #include <__iterator/sortable.h> // expected-error@*:* {{use of private header from outside its module: '__iterator/sortable.h'}} #include <__iterator/unreachable_sentinel.h> // expected-error@*:* {{use of private header from outside its module: '__iterator/unreachable_sentinel.h'}} diff --git a/libcxx/test/std/algorithms/alg.modifying.operations/alg.copy/ranges.copy.pass.cpp b/libcxx/test/std/algorithms/alg.modifying.operations/alg.copy/ranges.copy.pass.cpp index c2410568b3e3b8..7a5c0a7531727c 100644 --- 
a/libcxx/test/std/algorithms/alg.modifying.operations/alg.copy/ranges.copy.pass.cpp +++ b/libcxx/test/std/algorithms/alg.modifying.operations/alg.copy/ranges.copy.pass.cpp @@ -20,7 +20,9 @@ #include #include #include +#include #include +#include #include "almost_satisfies_types.h" #include "test_iterators.h" @@ -53,20 +55,21 @@ static_assert(!HasCopyR); static_assert(std::is_same_v, std::ranges::in_out_result>); +// clang-format off template constexpr void test_iterators() { { // simple test { - std::array in {1, 2, 3, 4}; + std::array in{1, 2, 3, 4}; std::array out; std::same_as> auto ret = - std::ranges::copy(In(in.data()), Sent(In(in.data() + in.size())), Out(out.data())); + std::ranges::copy(In(in.data()), Sent(In(in.data() + in.size())), Out(out.data())); assert(in == out); assert(base(ret.in) == in.data() + in.size()); assert(base(ret.out) == out.data() + out.size()); } { - std::array in {1, 2, 3, 4}; + std::array in{1, 2, 3, 4}; std::array out; auto range = std::ranges::subrange(In(in.data()), Sent(In(in.data() + in.size()))); std::same_as> auto ret = std::ranges::copy(range, Out(out.data())); @@ -88,12 +91,13 @@ constexpr void test_iterators() { std::array in; std::array out; auto range = std::ranges::subrange(In(in.data()), Sent(In(in.data() + in.size()))); - auto ret = std::ranges::copy(range, Out(out.data())); + auto ret = std::ranges::copy(range, Out(out.data())); assert(base(ret.in) == in.data()); assert(base(ret.out) == out.data()); } } } +// clang-format on constexpr bool test() { meta::for_each(meta::forward_iterator_list{}, []() { @@ -122,7 +126,7 @@ constexpr bool test() { } { // check that an iterator is returned with a borrowing range - std::array in {1, 2, 3, 4}; + std::array in{1, 2, 3, 4}; std::array out; std::same_as> auto ret = std::ranges::copy(std::views::all(in), out.data()); assert(ret.in == in.data() + 4); @@ -132,8 +136,8 @@ constexpr bool test() { { // check that every element is copied exactly once struct CopyOnce { - bool copied = 
false; - constexpr CopyOnce() = default; + bool copied = false; + constexpr CopyOnce() = default; constexpr CopyOnce(const CopyOnce& other) = delete; constexpr CopyOnce& operator=(const CopyOnce& other) { assert(!other.copied); @@ -142,16 +146,16 @@ constexpr bool test() { } }; { - std::array in {}; - std::array out {}; + std::array in{}; + std::array out{}; auto ret = std::ranges::copy(in.begin(), in.end(), out.begin()); assert(ret.in == in.end()); assert(ret.out == out.end()); assert(std::all_of(out.begin(), out.end(), [](const auto& e) { return e.copied; })); } { - std::array in {}; - std::array out {}; + std::array in{}; + std::array out{}; auto ret = std::ranges::copy(in, out.begin()); assert(ret.in == in.end()); assert(ret.out == out.end()); @@ -162,8 +166,8 @@ constexpr bool test() { { // check that the range is copied forwards struct OnlyForwardsCopyable { OnlyForwardsCopyable* next = nullptr; - bool canCopy = false; - OnlyForwardsCopyable() = default; + bool canCopy = false; + OnlyForwardsCopyable() = default; constexpr OnlyForwardsCopyable& operator=(const OnlyForwardsCopyable&) { assert(canCopy); if (next != nullptr) @@ -172,12 +176,12 @@ constexpr bool test() { } }; { - std::array in {}; - std::array out {}; - out[0].next = &out[1]; - out[1].next = &out[2]; + std::array in{}; + std::array out{}; + out[0].next = &out[1]; + out[1].next = &out[2]; out[0].canCopy = true; - auto ret = std::ranges::copy(in.begin(), in.end(), out.begin()); + auto ret = std::ranges::copy(in.begin(), in.end(), out.begin()); assert(ret.in == in.end()); assert(ret.out == out.end()); assert(out[0].canCopy); @@ -185,12 +189,12 @@ constexpr bool test() { assert(out[2].canCopy); } { - std::array in {}; - std::array out {}; - out[0].next = &out[1]; - out[1].next = &out[2]; + std::array in{}; + std::array out{}; + out[0].next = &out[1]; + out[1].next = &out[2]; out[0].canCopy = true; - auto ret = std::ranges::copy(in, out.begin()); + auto ret = std::ranges::copy(in, out.begin()); 
assert(ret.in == in.end()); assert(ret.out == out.end()); assert(out[0].canCopy); diff --git a/libcxx/test/std/algorithms/alg.modifying.operations/alg.copy/ranges.copy.segmented.pass.cpp b/libcxx/test/std/algorithms/alg.modifying.operations/alg.copy/ranges.copy.segmented.pass.cpp new file mode 100644 index 00000000000000..e8f8301ac4d6a3 --- /dev/null +++ b/libcxx/test/std/algorithms/alg.modifying.operations/alg.copy/ranges.copy.segmented.pass.cpp @@ -0,0 +1,51 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// UNSUPPORTED: c++03, c++11, c++14, c++17 + +#include +#include +#include +#include +#include + +template +constexpr void test_containers() { + using InIter = typename InContainer::iterator; + using OutIter = typename OutContainer::iterator; + + { + InContainer in{1, 2, 3, 4}; + OutContainer out(4); + + std::same_as> auto ret = + std::ranges::copy(in.begin(), in.end(), out.begin()); + assert(std::ranges::equal(in, out)); + assert(ret.in == in.end()); + assert(ret.out == out.end()); + } + { + InContainer in{1, 2, 3, 4}; + OutContainer out(4); + std::same_as> auto ret = std::ranges::copy(in, out.begin()); + assert(std::ranges::equal(in, out)); + assert(ret.in == in.end()); + assert(ret.out == out.end()); + } +} + +int main(int, char**) { + if (!std::is_constant_evaluated()) { + test_containers, std::deque>(); + test_containers, std::vector>(); + test_containers, std::deque>(); + test_containers, std::vector>(); + } + + return 0; +} diff --git a/libcxx/test/std/algorithms/alg.modifying.operations/alg.copy/ranges.copy_backward.pass.cpp b/libcxx/test/std/algorithms/alg.modifying.operations/alg.copy/ranges.copy_backward.pass.cpp index 
a18ba9d6c344cf..3762948e4ed1f6 100644 --- a/libcxx/test/std/algorithms/alg.modifying.operations/alg.copy/ranges.copy_backward.pass.cpp +++ b/libcxx/test/std/algorithms/alg.modifying.operations/alg.copy/ranges.copy_backward.pass.cpp @@ -23,7 +23,9 @@ #include #include #include +#include #include +#include #include "almost_satisfies_types.h" #include "test_iterators.h" @@ -99,37 +101,84 @@ constexpr void test_iterators() { } } -template +template +constexpr void test_containers() { + { + InContainer in {1, 2, 3, 4}; + OutContainer out(4); + std::same_as> auto ret = + std::ranges::copy_backward(In(in.begin()), Sent(In(in.end())), Out(out.end())); + assert(std::ranges::equal(in, out)); + assert(base(ret.in) == in.end()); + assert(base(ret.out) == out.begin()); + } + { + InContainer in {1, 2, 3, 4}; + OutContainer out(4); + auto range = std::ranges::subrange(In(in.begin()), Sent(In(in.end()))); + std::same_as> auto ret = std::ranges::copy_backward(range, Out(out.end())); + assert(std::ranges::equal(in, out)); + assert(base(ret.in) == in.end()); + assert(base(ret.out) == out.begin()); + } +} + +template