From b98869748a367100fafca239b2c38d260ff02712 Mon Sep 17 00:00:00 2001 From: AmrDeveloper Date: Sun, 25 May 2025 16:38:21 +0200 Subject: [PATCH] [CIR] Backport Allow different Int types together in Vec ShiftOp --- clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 32 +++++++++++++++++++------ clang/test/CIR/CodeGen/vectype-ext.cpp | 11 +++++++++ clang/test/CIR/CodeGen/vectype.cpp | 11 +++++++++ clang/test/CIR/IR/invalid.cir | 18 +++++++++++++- 4 files changed, 64 insertions(+), 8 deletions(-) diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index a5f57afc0aa8..c221f25110be 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -3939,15 +3939,33 @@ LogicalResult cir::BinOp::verify() { //===----------------------------------------------------------------------===// LogicalResult cir::ShiftOp::verify() { mlir::Operation *op = getOperation(); - mlir::Type resType = getResult().getType(); - bool isOp0Vec = mlir::isa(op->getOperand(0).getType()); - bool isOp1Vec = mlir::isa(op->getOperand(1).getType()); - if (isOp0Vec != isOp1Vec) + auto op0VecTy = mlir::dyn_cast(op->getOperand(0).getType()); + auto op1VecTy = mlir::dyn_cast(op->getOperand(1).getType()); + if (!op0VecTy ^ !op1VecTy) + return emitOpError() << "input types cannot be one vector and one scalar"; - if (isOp1Vec && op->getOperand(1).getType() != resType) { - return emitOpError() << "shift amount must have the type of the result " - << "if it is vector shift"; + + if (op0VecTy) { + if (op0VecTy.getSize() != op1VecTy.getSize()) + return emitOpError() << "input vector types must have the same size"; + + auto opResultTy = mlir::dyn_cast(getResult().getType()); + if (!opResultTy) + return emitOpError() << "the type of the result must be a vector " + << "if it is vector shift"; + + auto op0VecEleTy = mlir::cast(op0VecTy.getElementType()); + auto op1VecEleTy = mlir::cast(op1VecTy.getElementType()); + if (op0VecEleTy.getWidth() != op1VecEleTy.getWidth()) + return emitOpError() + << "vector operands do not have the same elements sizes"; + + auto resVecEleTy = mlir::cast(opResultTy.getElementType()); + if (op0VecEleTy.getWidth() != resVecEleTy.getWidth()) + return emitOpError() << "vector operands and result type do not have the " + "same elements sizes"; } + return mlir::success(); } diff --git a/clang/test/CIR/CodeGen/vectype-ext.cpp b/clang/test/CIR/CodeGen/vectype-ext.cpp index 883562fbcc15..51aae6b33779 100644 --- a/clang/test/CIR/CodeGen/vectype-ext.cpp +++ b/clang/test/CIR/CodeGen/vectype-ext.cpp @@ -4,6 +4,7 @@ // RUN: FileCheck --input-file=%t.ll %s -check-prefix=LLVM typedef int vi4 __attribute__((ext_vector_type(4))); +typedef unsigned int uvi4 __attribute__((ext_vector_type(4))); typedef int vi3 __attribute__((ext_vector_type(3))); typedef int vi2 __attribute__((ext_vector_type(2))); typedef double vd2 __attribute__((ext_vector_type(2))); @@ -535,3 +536,13 @@ void test_vec3() { // LLVM-NEXT: %[[#RES:]] = add <3 x i32> %[[#V3]], splat (i32 1) } + +void vector_integers_shifts_test() { + vi4 a = {1, 2, 3, 4}; + uvi4 b = {5u, 6u, 7u, 8u}; + + vi4 shl = a << b; + // CHECK: %{{[0-9]+}} = cir.shift(left, %{{[0-9]+}} : !cir.vector, %{{[0-9]+}} : !cir.vector) -> !cir.vector + uvi4 shr = b >> a; + // CHECK: %{{[0-9]+}} = cir.shift(right, %{{[0-9]+}} : !cir.vector, %{{[0-9]+}} : !cir.vector) -> !cir.vector +} diff --git a/clang/test/CIR/CodeGen/vectype.cpp b/clang/test/CIR/CodeGen/vectype.cpp index ac3e0681c60e..e365b4bd7b12 100644 --- a/clang/test/CIR/CodeGen/vectype.cpp +++ b/clang/test/CIR/CodeGen/vectype.cpp @@ -1,6 +1,7 @@ // RUN: %clang_cc1 -std=c++17 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o - | FileCheck %s typedef int vi4 __attribute__((vector_size(16))); +typedef unsigned int uvi4 __attribute__((vector_size(16))); typedef double vd2 __attribute__((vector_size(16))); typedef long long vll2 __attribute__((vector_size(16))); typedef unsigned short vus2 __attribute__((vector_size(4))); @@ -198,3 +199,13 @@ void vector_double_test(int x, double y) { vus2 w = __builtin_convertvector(a, vus2); // CHECK: %{{[0-9]+}} = cir.cast(float_to_int, %{{[0-9]+}} : !cir.vector), !cir.vector } + +void vector_integers_shifts_test() { + vi4 a = {1, 2, 3, 4}; + uvi4 b = {5u, 6u, 7u, 8u}; + + vi4 shl = a << b; + // CHECK: %{{[0-9]+}} = cir.shift(left, %{{[0-9]+}} : !cir.vector, %{{[0-9]+}} : !cir.vector) -> !cir.vector + uvi4 shr = b >> a; + // CHECK: %{{[0-9]+}} = cir.shift(right, %{{[0-9]+}} : !cir.vector, %{{[0-9]+}} : !cir.vector) -> !cir.vector +} diff --git a/clang/test/CIR/IR/invalid.cir b/clang/test/CIR/IR/invalid.cir index 30fc93ecb979..16e2429c31ec 100644 --- a/clang/test/CIR/IR/invalid.cir +++ b/clang/test/CIR/IR/invalid.cir @@ -1434,11 +1434,27 @@ module { %0 = cir.alloca !cir.vector, !cir.ptr>, ["a", init] {alignment = 8 : i64} %1 = cir.load %0 : !cir.ptr>, !cir.vector %4 = cir.const #cir.const_vector<[#cir.int<12> : !s16i, #cir.int<12> : !s16i]> : !cir.vector - // expected-error@+1 {{'cir.shift' op shift amount must have the type of the result if it is vector shift}} + // expected-error@+1 {{'cir.shift' op vector operands do not have the same elements sizes}} %5 = cir.shift(left, %1 : !cir.vector, %4 : !cir.vector) -> !cir.vector cir.return } } + +// ----- + +!s32i = !cir.int +!s16i = !cir.int +module { + cir.func @test_shift_vec2() { + %0 = cir.alloca !cir.vector, !cir.ptr>, ["a", init] {alignment = 8 : i64} + %1 = cir.load %0 : !cir.ptr>, !cir.vector + %4 = cir.const #cir.const_vector<[#cir.int<12> : !s16i, #cir.int<12> : !s16i]> : !cir.vector + // expected-error@+1 {{'cir.shift' op vector operands do not have the same elements sizes}} + %5 = cir.shift(left, %4 : !cir.vector, %1 : !cir.vector) -> !cir.vector + cir.return + } +} + // ----- // Type of the attribute must be a CIR floating point type