Skip to content

Commit a60fda5

Browse files
committed
[mlir][OpenMP] Restrict types for omp.parallel args
This patch restricts the value of `if` clause expression to an I1 value. It also restricts the value of `num_threads` clause expression to an I32 value. Reviewed By: kiranchandramohan Differential Revision: https://reviews.llvm.org/D124142
1 parent c8603db commit a60fda5

File tree

4 files changed

+74
-20
lines changed

4 files changed

+74
-20
lines changed

Diff for: flang/lib/Lower/OpenMP.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -254,8 +254,10 @@ genOMP(Fortran::lower::AbstractConverter &converter,
254254
if (const auto &ifClause =
255255
std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
256256
auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
257-
ifClauseOperand = fir::getBase(
257+
mlir::Value ifVal = fir::getBase(
258258
converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx));
259+
ifClauseOperand = firOpBuilder.createConvert(
260+
currentLocation, firOpBuilder.getI1Type(), ifVal);
259261
} else if (const auto &numThreadsClause =
260262
std::get_if<Fortran::parser::OmpClause::NumThreads>(
261263
&clause.u)) {

Diff for: flang/test/Lower/OpenMP/parallel.f90

+41-1
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,13 @@ end subroutine parallel_simple
1515
!===============================================================================
1616

1717
!FIRDialect-LABEL: func @_QPparallel_if
18-
subroutine parallel_if(alpha)
18+
subroutine parallel_if(alpha, beta, gamma)
1919
integer, intent(in) :: alpha
20+
logical, intent(in) :: beta
21+
logical(1) :: logical1
22+
logical(2) :: logical2
23+
logical(4) :: logical4
24+
logical(8) :: logical8
2025

2126
!OMPDialect: omp.parallel if(%{{.*}} : i1) {
2227
!$omp parallel if(alpha .le. 0)
@@ -46,6 +51,41 @@ subroutine parallel_if(alpha)
4651
!OMPDialect: omp.terminator
4752
!$omp end parallel
4853

54+
!OMPDialect: omp.parallel if(%{{.*}} : i1) {
55+
!$omp parallel if(beta)
56+
!FIRDialect: fir.call
57+
call f1()
58+
!OMPDialect: omp.terminator
59+
!$omp end parallel
60+
61+
!OMPDialect: omp.parallel if(%{{.*}} : i1) {
62+
!$omp parallel if(logical1)
63+
!FIRDialect: fir.call
64+
call f1()
65+
!OMPDialect: omp.terminator
66+
!$omp end parallel
67+
68+
!OMPDialect: omp.parallel if(%{{.*}} : i1) {
69+
!$omp parallel if(logical2)
70+
!FIRDialect: fir.call
71+
call f1()
72+
!OMPDialect: omp.terminator
73+
!$omp end parallel
74+
75+
!OMPDialect: omp.parallel if(%{{.*}} : i1) {
76+
!$omp parallel if(logical4)
77+
!FIRDialect: fir.call
78+
call f1()
79+
!OMPDialect: omp.terminator
80+
!$omp end parallel
81+
82+
!OMPDialect: omp.parallel if(%{{.*}} : i1) {
83+
!$omp parallel if(logical8)
84+
!FIRDialect: fir.call
85+
call f1()
86+
!OMPDialect: omp.terminator
87+
!$omp end parallel
88+
4989
end subroutine parallel_if
5090

5191
!===============================================================================

Diff for: mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

+2-2
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@ def ParallelOp : OpenMP_Op<"parallel", [
9999
of the parallel region.
100100
}];
101101

102-
let arguments = (ins Optional<AnyType>:$if_expr_var,
103-
Optional<AnyType>:$num_threads_var,
102+
let arguments = (ins Optional<I1>:$if_expr_var,
103+
Optional<IntLikeType>:$num_threads_var,
104104
Variadic<AnyType>:$allocate_vars,
105105
Variadic<AnyType>:$allocators_vars,
106106
Variadic<OpenMP_PointerLikeType>:$reduction_vars,

Diff for: mlir/test/Dialect/OpenMP/ops.mlir

+28-16
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,15 @@ func.func @omp_terminator() -> () {
5151
omp.terminator
5252
}
5353

54-
func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32) -> () {
55-
// CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
54+
func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i32) -> () {
55+
// CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
5656
"omp.parallel" (%if_cond, %num_threads, %data_var, %data_var) ({
5757

5858
// test without if condition
59-
// CHECK: omp.parallel num_threads(%{{.*}} : si32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
59+
// CHECK: omp.parallel num_threads(%{{.*}} : i32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
6060
"omp.parallel"(%num_threads, %data_var, %data_var) ({
6161
omp.terminator
62-
}) {operand_segment_sizes = dense<[0,1,1,1,0]> : vector<5xi32>} : (si32, memref<i32>, memref<i32>) -> ()
62+
}) {operand_segment_sizes = dense<[0,1,1,1,0]> : vector<5xi32>} : (i32, memref<i32>, memref<i32>) -> ()
6363

6464
// CHECK: omp.barrier
6565
omp.barrier
@@ -71,13 +71,13 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : s
7171
}) {operand_segment_sizes = dense<[1,0,1,1,0]> : vector<5xi32>} : (i1, memref<i32>, memref<i32>) -> ()
7272

7373
// test without allocate
74-
// CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32)
74+
// CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32)
7575
"omp.parallel"(%if_cond, %num_threads) ({
7676
omp.terminator
77-
}) {operand_segment_sizes = dense<[1,1,0,0,0]> : vector<5xi32>} : (i1, si32) -> ()
77+
}) {operand_segment_sizes = dense<[1,1,0,0,0]> : vector<5xi32>} : (i1, i32) -> ()
7878

7979
omp.terminator
80-
}) {operand_segment_sizes = dense<[1,1,1,1,0]> : vector<5xi32>, proc_bind_val = #omp<"procbindkind spread">} : (i1, si32, memref<i32>, memref<i32>) -> ()
80+
}) {operand_segment_sizes = dense<[1,1,1,1,0]> : vector<5xi32>, proc_bind_val = #omp<"procbindkind spread">} : (i1, i32, memref<i32>, memref<i32>) -> ()
8181

8282
// test with multiple parameters for single variadic argument
8383
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
@@ -88,14 +88,26 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : s
8888
return
8989
}
9090

91-
func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32, %allocator : si32) -> () {
91+
func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_threads : i32, %allocator : si32) -> () {
9292
// CHECK: omp.parallel
9393
omp.parallel {
9494
omp.terminator
9595
}
9696

97-
// CHECK: omp.parallel num_threads(%{{.*}} : si32)
98-
omp.parallel num_threads(%num_threads : si32) {
97+
// CHECK: omp.parallel num_threads(%{{.*}} : i32)
98+
omp.parallel num_threads(%num_threads : i32) {
99+
omp.terminator
100+
}
101+
102+
%n_index = arith.constant 2 : index
103+
// CHECK: omp.parallel num_threads(%{{.*}} : index)
104+
omp.parallel num_threads(%n_index : index) {
105+
omp.terminator
106+
}
107+
108+
%n_i64 = arith.constant 4 : i64
109+
// CHECK: omp.parallel num_threads(%{{.*}} : i64)
110+
omp.parallel num_threads(%n_i64 : i64) {
99111
omp.terminator
100112
}
101113

@@ -113,8 +125,8 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_thre
113125
omp.terminator
114126
}
115127

116-
// CHECK omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32) private(%{{.*}} : memref<i32>) proc_bind(close)
117-
omp.parallel num_threads(%num_threads : si32) if(%if_cond: i1) proc_bind(close) {
128+
// CHECK omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32) private(%{{.*}} : memref<i32>) proc_bind(close)
129+
omp.parallel num_threads(%num_threads : i32) if(%if_cond: i1) proc_bind(close) {
118130
omp.terminator
119131
}
120132

@@ -347,14 +359,14 @@ func.func @omp_simdloop_pretty_multiple(%lb1 : index, %ub1 : index, %step1 : ind
347359
}
348360

349361
// CHECK-LABEL: omp_target
350-
func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : si32) -> () {
362+
func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32) -> () {
351363

352364
// Test with optional operands; if_expr, device, thread_limit, private, firstprivate and nowait.
353365
// CHECK: omp.target if({{.*}}) device({{.*}}) thread_limit({{.*}}) nowait
354366
"omp.target"(%if_cond, %device, %num_threads) ({
355367
// CHECK: omp.terminator
356368
omp.terminator
357-
}) {nowait, operand_segment_sizes = dense<[1,1,1]>: vector<3xi32>} : ( i1, si32, si32 ) -> ()
369+
}) {nowait, operand_segment_sizes = dense<[1,1,1]>: vector<3xi32>} : ( i1, si32, i32 ) -> ()
358370

359371
// CHECK: omp.barrier
360372
omp.barrier
@@ -363,14 +375,14 @@ func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : si32) -> ()
363375
}
364376

365377
// CHECK-LABEL: omp_target_pretty
366-
func.func @omp_target_pretty(%if_cond : i1, %device : si32, %num_threads : si32) -> () {
378+
func.func @omp_target_pretty(%if_cond : i1, %device : si32, %num_threads : i32) -> () {
367379
// CHECK: omp.target if({{.*}}) device({{.*}})
368380
omp.target if(%if_cond) device(%device : si32) {
369381
omp.terminator
370382
}
371383

372384
// CHECK: omp.target if({{.*}}) device({{.*}}) nowait
373-
omp.target if(%if_cond) device(%device : si32) thread_limit(%num_threads : si32) nowait {
385+
omp.target if(%if_cond) device(%device : si32) thread_limit(%num_threads : i32) nowait {
374386
omp.terminator
375387
}
376388

0 commit comments

Comments
 (0)