Skip to content

Commit cb0effc

Browse files
authored
[flang][cuda] Using nvvm intrinsics for the syncthread and threadfence families of calls (#120020)
1 parent 9f3a611 commit cb0effc

File tree

4 files changed

+113
-21
lines changed

4 files changed

+113
-21
lines changed

flang/include/flang/Optimizer/Builder/IntrinsicCall.h

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -392,6 +392,10 @@ struct IntrinsicLibrary {
392392
fir::ExtendedValue genSum(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
393393
void genSignalSubroutine(llvm::ArrayRef<fir::ExtendedValue>);
394394
void genSleep(llvm::ArrayRef<fir::ExtendedValue>);
395+
void genSyncThreads(llvm::ArrayRef<fir::ExtendedValue>);
396+
mlir::Value genSyncThreadsAnd(mlir::Type, llvm::ArrayRef<mlir::Value>);
397+
mlir::Value genSyncThreadsCount(mlir::Type, llvm::ArrayRef<mlir::Value>);
398+
mlir::Value genSyncThreadsOr(mlir::Type, llvm::ArrayRef<mlir::Value>);
395399
fir::ExtendedValue genSystem(std::optional<mlir::Type>,
396400
mlir::ArrayRef<fir::ExtendedValue> args);
397401
void genSystemClock(llvm::ArrayRef<fir::ExtendedValue>);
@@ -401,6 +405,9 @@ struct IntrinsicLibrary {
401405
llvm::ArrayRef<fir::ExtendedValue>);
402406
fir::ExtendedValue genTranspose(mlir::Type,
403407
llvm::ArrayRef<fir::ExtendedValue>);
408+
void genThreadFence(llvm::ArrayRef<fir::ExtendedValue>);
409+
void genThreadFenceBlock(llvm::ArrayRef<fir::ExtendedValue>);
410+
void genThreadFenceSystem(llvm::ArrayRef<fir::ExtendedValue>);
404411
fir::ExtendedValue genTrim(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
405412
fir::ExtendedValue genUbound(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
406413
fir::ExtendedValue genUnpack(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 85 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -646,6 +646,10 @@ static constexpr IntrinsicHandler handlers[]{
646646
{"dim", asValue},
647647
{"mask", asBox, handleDynamicOptional}}},
648648
/*isElemental=*/false},
649+
{"syncthreads", &I::genSyncThreads, {}, /*isElemental=*/false},
650+
{"syncthreads_and", &I::genSyncThreadsAnd, {}, /*isElemental=*/false},
651+
{"syncthreads_count", &I::genSyncThreadsCount, {}, /*isElemental=*/false},
652+
{"syncthreads_or", &I::genSyncThreadsOr, {}, /*isElemental=*/false},
649653
{"system",
650654
&I::genSystem,
651655
{{{"command", asBox}, {"exitstat", asBox, handleDynamicOptional}}},
@@ -655,6 +659,9 @@ static constexpr IntrinsicHandler handlers[]{
655659
{{{"count", asAddr}, {"count_rate", asAddr}, {"count_max", asAddr}}},
656660
/*isElemental=*/false},
657661
{"tand", &I::genTand},
662+
{"threadfence", &I::genThreadFence, {}, /*isElemental=*/false},
663+
{"threadfence_block", &I::genThreadFenceBlock, {}, /*isElemental=*/false},
664+
{"threadfence_system", &I::genThreadFenceSystem, {}, /*isElemental=*/false},
658665
{"trailz", &I::genTrailz},
659666
{"transfer",
660667
&I::genTransfer,
@@ -7437,6 +7444,52 @@ IntrinsicLibrary::genSum(mlir::Type resultType,
74377444
resultType, args);
74387445
}
74397446

7447+
// SYNCTHREADS
7448+
void IntrinsicLibrary::genSyncThreads(llvm::ArrayRef<fir::ExtendedValue> args) {
7449+
constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0";
7450+
mlir::FunctionType funcType =
7451+
mlir::FunctionType::get(builder.getContext(), {}, {});
7452+
auto funcOp = builder.createFunction(loc, funcName, funcType);
7453+
llvm::SmallVector<mlir::Value> noArgs;
7454+
builder.create<fir::CallOp>(loc, funcOp, noArgs);
7455+
}
7456+
7457+
// SYNCTHREADS_AND
7458+
mlir::Value
7459+
IntrinsicLibrary::genSyncThreadsAnd(mlir::Type resultType,
7460+
llvm::ArrayRef<mlir::Value> args) {
7461+
constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0.and";
7462+
mlir::MLIRContext *context = builder.getContext();
7463+
mlir::FunctionType ftype =
7464+
mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
7465+
auto funcOp = builder.createFunction(loc, funcName, ftype);
7466+
return builder.create<fir::CallOp>(loc, funcOp, args).getResult(0);
7467+
}
7468+
7469+
// SYNCTHREADS_COUNT
7470+
mlir::Value
7471+
IntrinsicLibrary::genSyncThreadsCount(mlir::Type resultType,
7472+
llvm::ArrayRef<mlir::Value> args) {
7473+
constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0.popc";
7474+
mlir::MLIRContext *context = builder.getContext();
7475+
mlir::FunctionType ftype =
7476+
mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
7477+
auto funcOp = builder.createFunction(loc, funcName, ftype);
7478+
return builder.create<fir::CallOp>(loc, funcOp, args).getResult(0);
7479+
}
7480+
7481+
// SYNCTHREADS_OR
7482+
mlir::Value
7483+
IntrinsicLibrary::genSyncThreadsOr(mlir::Type resultType,
7484+
llvm::ArrayRef<mlir::Value> args) {
7485+
constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0.or";
7486+
mlir::MLIRContext *context = builder.getContext();
7487+
mlir::FunctionType ftype =
7488+
mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
7489+
auto funcOp = builder.createFunction(loc, funcName, ftype);
7490+
return builder.create<fir::CallOp>(loc, funcOp, args).getResult(0);
7491+
}
7492+
74407493
// SYSTEM
74417494
fir::ExtendedValue
74427495
IntrinsicLibrary::genSystem(std::optional<mlir::Type> resultType,
@@ -7567,6 +7620,38 @@ IntrinsicLibrary::genTranspose(mlir::Type resultType,
75677620
return readAndAddCleanUp(resultMutableBox, resultType, "TRANSPOSE");
75687621
}
75697622

7623+
// THREADFENCE
7624+
void IntrinsicLibrary::genThreadFence(llvm::ArrayRef<fir::ExtendedValue> args) {
7625+
constexpr llvm::StringLiteral funcName = "llvm.nvvm.membar.gl";
7626+
mlir::FunctionType funcType =
7627+
mlir::FunctionType::get(builder.getContext(), {}, {});
7628+
auto funcOp = builder.createFunction(loc, funcName, funcType);
7629+
llvm::SmallVector<mlir::Value> noArgs;
7630+
builder.create<fir::CallOp>(loc, funcOp, noArgs);
7631+
}
7632+
7633+
// THREADFENCE_BLOCK
7634+
void IntrinsicLibrary::genThreadFenceBlock(
7635+
llvm::ArrayRef<fir::ExtendedValue> args) {
7636+
constexpr llvm::StringLiteral funcName = "llvm.nvvm.membar.cta";
7637+
mlir::FunctionType funcType =
7638+
mlir::FunctionType::get(builder.getContext(), {}, {});
7639+
auto funcOp = builder.createFunction(loc, funcName, funcType);
7640+
llvm::SmallVector<mlir::Value> noArgs;
7641+
builder.create<fir::CallOp>(loc, funcOp, noArgs);
7642+
}
7643+
7644+
// THREADFENCE_SYSTEM
7645+
void IntrinsicLibrary::genThreadFenceSystem(
7646+
llvm::ArrayRef<fir::ExtendedValue> args) {
7647+
constexpr llvm::StringLiteral funcName = "llvm.nvvm.membar.sys";
7648+
mlir::FunctionType funcType =
7649+
mlir::FunctionType::get(builder.getContext(), {}, {});
7650+
auto funcOp = builder.createFunction(loc, funcName, funcType);
7651+
llvm::SmallVector<mlir::Value> noArgs;
7652+
builder.create<fir::CallOp>(loc, funcOp, noArgs);
7653+
}
7654+
75707655
// TRIM
75717656
fir::ExtendedValue
75727657
IntrinsicLibrary::genTrim(mlir::Type resultType,

flang/module/cudadevice.f90

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -18,27 +18,27 @@ module cudadevice
1818
! Synchronization Functions
1919

2020
interface
21-
attributes(device) subroutine syncthreads() bind(c, name='__syncthreads')
21+
attributes(device) subroutine syncthreads()
2222
end subroutine
2323
end interface
2424
public :: syncthreads
2525

2626
interface
27-
attributes(device) integer function syncthreads_and(value) bind(c, name='__syncthreads_and')
27+
attributes(device) integer function syncthreads_and(value)
2828
integer :: value
2929
end function
3030
end interface
3131
public :: syncthreads_and
3232

3333
interface
34-
attributes(device) integer function syncthreads_count(value) bind(c, name='__syncthreads_count')
34+
attributes(device) integer function syncthreads_count(value)
3535
integer :: value
3636
end function
3737
end interface
3838
public :: syncthreads_count
3939

4040
interface
41-
attributes(device) integer function syncthreads_or(value) bind(c, name='__syncthreads_or')
41+
attributes(device) integer function syncthreads_or(value)
4242
integer :: value
4343
end function
4444
end interface
@@ -54,19 +54,19 @@ attributes(device) subroutine syncwarp(mask) bind(c, name='__syncwarp')
5454
! Memory Fences
5555

5656
interface
57-
attributes(device) subroutine threadfence() bind(c, name='__threadfence')
57+
attributes(device) subroutine threadfence()
5858
end subroutine
5959
end interface
6060
public :: threadfence
6161

6262
interface
63-
attributes(device) subroutine threadfence_block() bind(c, name='__threadfence_block')
63+
attributes(device) subroutine threadfence_block()
6464
end subroutine
6565
end interface
6666
public :: threadfence_block
6767

6868
interface
69-
attributes(device) subroutine threadfence_system() bind(c, name='__threadfence_system')
69+
attributes(device) subroutine threadfence_system()
7070
end subroutine
7171
end interface
7272
public :: threadfence_system

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 14 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -17,20 +17,20 @@ attributes(global) subroutine devsub()
1717
end
1818

1919
! CHECK-LABEL: func.func @_QPdevsub() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
20-
! CHECK: fir.call @__syncthreads()
20+
! CHECK: fir.call @llvm.nvvm.barrier0() fastmath<contract> : () -> ()
2121
! CHECK: fir.call @__syncwarp(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i32>) -> ()
22-
! CHECK: fir.call @__threadfence()
23-
! CHECK: fir.call @__threadfence_block()
24-
! CHECK: fir.call @__threadfence_system()
25-
! CHECK: %{{.*}} = fir.call @__syncthreads_and(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i32>) -> i32
26-
! CHECK: %{{.*}} = fir.call @__syncthreads_count(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i32>) -> i32
27-
! CHECK: %{{.*}} = fir.call @__syncthreads_or(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i32>) -> i32
22+
! CHECK: fir.call @llvm.nvvm.membar.gl() fastmath<contract> : () -> ()
23+
! CHECK: fir.call @llvm.nvvm.membar.cta() fastmath<contract> : () -> ()
24+
! CHECK: fir.call @llvm.nvvm.membar.sys() fastmath<contract> : () -> ()
25+
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.and(%c1_i32_0) fastmath<contract> : (i32) -> i32
26+
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.popc(%c1_i32_1) fastmath<contract> : (i32) -> i32
27+
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.or(%c1_i32_2) fastmath<contract> : (i32) -> i32
2828

29-
! CHECK: func.func private @__syncthreads() attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncthreads", fir.proc_attrs = #fir.proc_attrs<bind_c>}
29+
! CHECK: func.func private @llvm.nvvm.barrier0()
3030
! CHECK: func.func private @__syncwarp(!fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>}) attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncwarp", fir.proc_attrs = #fir.proc_attrs<bind_c>}
31-
! CHECK: func.func private @__threadfence() attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__threadfence", fir.proc_attrs = #fir.proc_attrs<bind_c>}
32-
! CHECK: func.func private @__threadfence_block() attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__threadfence_block", fir.proc_attrs = #fir.proc_attrs<bind_c>}
33-
! CHECK: func.func private @__threadfence_system() attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__threadfence_system", fir.proc_attrs = #fir.proc_attrs<bind_c>}
34-
! CHECK: func.func private @__syncthreads_and(!fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>}) -> i32 attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncthreads_and", fir.proc_attrs = #fir.proc_attrs<bind_c>}
35-
! CHECK: func.func private @__syncthreads_count(!fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>}) -> i32 attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncthreads_count", fir.proc_attrs = #fir.proc_attrs<bind_c>}
36-
! CHECK: func.func private @__syncthreads_or(!fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>}) -> i32 attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncthreads_or", fir.proc_attrs = #fir.proc_attrs<bind_c>}
31+
! CHECK: func.func private @llvm.nvvm.membar.gl()
32+
! CHECK: func.func private @llvm.nvvm.membar.cta()
33+
! CHECK: func.func private @llvm.nvvm.membar.sys()
34+
! CHECK: func.func private @llvm.nvvm.barrier0.and(i32) -> i32
35+
! CHECK: func.func private @llvm.nvvm.barrier0.popc(i32) -> i32
36+
! CHECK: func.func private @llvm.nvvm.barrier0.or(i32) -> i32

0 commit comments

Comments
 (0)