Skip to content

Commit e660757

Browse files
committed
[CIR][Dialect] Verify bitcast does not contain address space conversion
1 parent c329de7 commit e660757

File tree

2 files changed

+28
-3
lines changed

2 files changed

+28
-3
lines changed

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -510,11 +510,21 @@ LogicalResult CastOp::verify() {
510510
if (isa<StructType>(srcType) || isa<StructType>(resType))
511511
return success();
512512

513+
// Handle the pointer types first.
514+
auto srcPtrTy = mlir::dyn_cast<mlir::cir::PointerType>(srcType);
515+
auto resPtrTy = mlir::dyn_cast<mlir::cir::PointerType>(resType);
516+
517+
if (srcPtrTy && resPtrTy) {
518+
if (srcPtrTy.getAddrSpace() != resPtrTy.getAddrSpace()) {
519+
return emitOpError() << "result type address space does not match the "
520+
"address space of the operand";
521+
}
522+
return success();
523+
}
524+
513525
// This is the only cast kind where we don't want vector types to decay
514526
// into the element type.
515-
if ((!mlir::isa<mlir::cir::PointerType>(getSrc().getType()) ||
516-
!mlir::isa<mlir::cir::PointerType>(getResult().getType())) &&
517-
(!mlir::isa<mlir::cir::VectorType>(getSrc().getType()) ||
527+
if ((!mlir::isa<mlir::cir::VectorType>(getSrc().getType()) ||
518528
!mlir::isa<mlir::cir::VectorType>(getResult().getType())))
519529
return emitOpError()
520530
<< "requires !cir.ptr or !cir.vector type for source and result";

clang/test/CIR/IR/invalid.cir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1298,3 +1298,18 @@ module {
12981298
cir.return
12991299
}
13001300
}
1301+
1302+
// -----
1303+
1304+
!s32i = !cir.int<s, 32>
1305+
1306+
module {
1307+
1308+
cir.func @test_bitcast_addrspace() {
1309+
%0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["tmp"] {alignment = 4 : i64}
1310+
// expected-error@+1 {{'cir.cast' op result type address space does not match the address space of the operand}}
1311+
%1 = cir.cast(bitcast, %0 : !cir.ptr<!s32i>), !cir.ptr<!s32i, addrspace(offload_local)>
1312+
}
1313+
1314+
}
1315+

0 commit comments

Comments
 (0)