@chichunchen chichunchen commented Jan 6, 2026

Extend OpenMP device clause lowering for target data, target enter data,
target exit data, and target update to accept non-constant values.
Previously, only constant device IDs could be lowered to LLVM IR.
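
As a rough illustration of the difference, the following standalone C++ sketch models the two strategies. It is not the MLIR translation code: `DeviceOperand`, `constantOnlyDeviceId`, and `forwardedDeviceId` are hypothetical names, and the sentinel `kDeviceIdUndef = -1` is an assumption standing in for `llvm::omp::OMP_DEVICEID_UNDEF`. The sketch also ignores the separate `checkDevice` gate that the patch removes.

```cpp
#include <cstdint>
#include <optional>

// Assumption: stands in for llvm::omp::OMP_DEVICEID_UNDEF; -1 is assumed here.
constexpr std::int64_t kDeviceIdUndef = -1;

// Hypothetical model of the operand attached to a device clause.
struct DeviceOperand {
  bool isConstant;     // true if the operand is produced by a constant op
  std::int64_t value;  // the device number the operand evaluates to
};

// Constant-only strategy: use the operand only when it is a constant,
// otherwise keep the "undefined device" sentinel.
constexpr std::int64_t
constantOnlyDeviceId(const std::optional<DeviceOperand> &dev) {
  if (dev && dev->isConstant)
    return dev->value;
  return kDeviceIdUndef;
}

// Value-forwarding strategy: any operand that is present is used as-is;
// the sentinel is kept only when there is no device clause at all.
constexpr std::int64_t
forwardedDeviceId(const std::optional<DeviceOperand> &dev) {
  return dev ? dev->value : kDeviceIdUndef;
}

// A run-time operand is dropped by the first model and kept by the second.
static_assert(constantOnlyDeviceId(DeviceOperand{false, 3}) == kDeviceIdUndef);
static_assert(forwardedDeviceId(DeviceOperand{false, 3}) == 3);
```

In the real translation the forwarded operand is an `llvm::Value *` rather than a number known at compile time, which is why the patch changes the type of `deviceID` instead of reading an integer attribute.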

Add Flang tests to validate device clause handling and mark the feature
as supported in the OpenMPSupport documentation. New tests cover:

  • target teams
  • target teams distribute
  • target teams distribute parallel do
  • target teams distribute parallel do simd
  • target data

Tests for target update and target enter/exit data were
already present in Flang.

@llvmbot added the mlir:llvm, mlir, flang, mlir:openmp, and flang:openmp labels on Jan 6, 2026

llvmbot commented Jan 6, 2026

@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-mlir-llvm

Author: Chi-Chun, Chen (chichunchen)

Changes

Allow the OpenMP device clause on target data, target enter data, target exit data, and target update to be lowered from non-constant values rather than from constants only.
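
The patch widens every device operand to 64 bits with `CreateIntCast(..., /*isSigned=*/true)`, which the tests observe as `sext iN ... to i64`. The standalone sketch below shows what sign extension preserves compared with zero extension. The helper names `signExtendDeviceId` and `zeroExtendDeviceId` are hypothetical and exist only for this illustration.

```cpp
#include <cstdint>
#include <type_traits>

// Widen a signed device number to 64 bits while preserving its sign,
// the C++ analogue of `sext iN %x to i64`.
template <typename Int>
constexpr std::int64_t signExtendDeviceId(Int id) {
  static_assert(std::is_integral_v<Int> && std::is_signed_v<Int>,
                "expects a signed integer type");
  return static_cast<std::int64_t>(id);
}

// For contrast: zero extension reinterprets the bits as unsigned before
// widening, the analogue of `zext iN %x to i64`.
template <typename Int>
constexpr std::int64_t zeroExtendDeviceId(Int id) {
  static_assert(std::is_integral_v<Int> &&
                    sizeof(Int) < sizeof(std::int64_t),
                "expects an integer type narrower than 64 bits");
  return static_cast<std::int64_t>(
      static_cast<std::make_unsigned_t<Int>>(id));
}

// Non-negative device numbers are unaffected by the choice.
static_assert(signExtendDeviceId<std::int16_t>(2) == 2);
static_assert(zeroExtendDeviceId<std::int16_t>(2) == 2);

// Negative values differ: only sign extension keeps -1 as -1.
static_assert(signExtendDeviceId<std::int16_t>(-1) == -1);
static_assert(zeroExtendDeviceId<std::int16_t>(-1) == 65535);
```

Signed widening matters mainly if a negative value, such as an "undefined device" sentinel (assumed here to be -1), ever arrives in a narrower integer: it reaches the runtime call unchanged instead of turning into a large positive device number.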


Patch is 22.34 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/174665.diff

4 Files Affected:

  • (modified) flang/docs/OpenMPSupport.md (+2-2)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+21-27)
  • (modified) mlir/test/Target/LLVMIR/omptarget-device.mlir (+210-4)
  • (modified) mlir/test/Target/LLVMIR/omptarget-llvm.mlir (+4-2)
diff --git a/flang/docs/OpenMPSupport.md b/flang/docs/OpenMPSupport.md
index 21966c5489108..1b25fc4890847 100644
--- a/flang/docs/OpenMPSupport.md
+++ b/flang/docs/OpenMPSupport.md
@@ -37,9 +37,9 @@ Note : No distinction is made between the support in Parser/Semantics, MLIR, Low
 | simd construct                                             | P      | Implicit linearization is skipped if iv is a pointer or allocatable|
 | declare simd construct                                     | N      | |
 | do simd construct                                          | P      | linear clause is not supported |
-| target data construct                                      | P      | device clause not supported |
+| target data construct                                      | Y      | |
 | target construct                                           | Y      | |
-| target update construct                                    | P      | device clause not supported |
+| target update construct                                    | Y      | |
 | declare target directive                                   | Y      | |
 | teams construct                                            | Y      | |
 | distribute construct                                       | Y      | |
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 614f06017a324..2698a4b4e89db 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -456,11 +456,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
             omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
       .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp>(
           [&](auto op) { checkDepend(op, result); })
-      .Case<omp::TargetUpdateOp>([&](auto op) {
-        checkDepend(op, result);
-        checkDevice(op, result);
-      })
-      .Case<omp::TargetDataOp>([&](auto op) { checkDevice(op, result); })
+      .Case<omp::TargetUpdateOp>([&](auto op) { checkDepend(op, result); })
       .Case([&](omp::TargetOp op) {
         checkAllocate(op, result);
         checkBare(op, result);
@@ -5051,7 +5047,7 @@ static LogicalResult
 convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation) {
   llvm::Value *ifCond = nullptr;
-  int64_t deviceID = llvm::omp::OMP_DEVICEID_UNDEF;
+  llvm::Value *deviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
   SmallVector<Value> mapVars;
   SmallVector<Value> useDevicePtrVars;
   SmallVector<Value> useDeviceAddrVars;
@@ -5067,6 +5063,11 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
   bool isOffloadEntry =
       isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
 
+  auto getDeviceID = [&](mlir::Value dev) -> llvm::Value * {
+    llvm::Value *v = moduleTranslation.lookupValue(dev);
+    return builder.CreateIntCast(v, builder.getInt64Ty(), /*isSigned=*/true);
+  };
+
   LogicalResult result =
       llvm::TypeSwitch<Operation *, LogicalResult>(op)
           .Case([&](omp::TargetDataOp dataOp) {
@@ -5076,10 +5077,8 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
             if (auto ifVar = dataOp.getIfExpr())
               ifCond = moduleTranslation.lookupValue(ifVar);
 
-            if (auto devId = dataOp.getDevice())
-              if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
-                if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
-                  deviceID = intAttr.getInt();
+            if (mlir::Value devId = dataOp.getDevice())
+              deviceID = getDeviceID(devId);
 
             mapVars = dataOp.getMapVars();
             useDevicePtrVars = dataOp.getUseDevicePtrVars();
@@ -5093,10 +5092,9 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
             if (auto ifVar = enterDataOp.getIfExpr())
               ifCond = moduleTranslation.lookupValue(ifVar);
 
-            if (auto devId = enterDataOp.getDevice())
-              if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
-                if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
-                  deviceID = intAttr.getInt();
+            if (mlir::Value devId = enterDataOp.getDevice())
+              deviceID = getDeviceID(devId);
+
             RTLFn =
                 enterDataOp.getNowait()
                     ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
@@ -5112,10 +5110,8 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
             if (auto ifVar = exitDataOp.getIfExpr())
               ifCond = moduleTranslation.lookupValue(ifVar);
 
-            if (auto devId = exitDataOp.getDevice())
-              if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
-                if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
-                  deviceID = intAttr.getInt();
+            if (mlir::Value devId = exitDataOp.getDevice())
+              deviceID = getDeviceID(devId);
 
             RTLFn = exitDataOp.getNowait()
                         ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
@@ -5131,10 +5127,8 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
             if (auto ifVar = updateDataOp.getIfExpr())
               ifCond = moduleTranslation.lookupValue(ifVar);
 
-            if (auto devId = updateDataOp.getDevice())
-              if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
-                if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
-                  deviceID = intAttr.getInt();
+            if (mlir::Value devId = updateDataOp.getDevice())
+              deviceID = getDeviceID(devId);
 
             RTLFn =
                 updateDataOp.getNowait()
@@ -5287,13 +5281,13 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
     if (isa<omp::TargetDataOp>(op))
       return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
-                                          builder.getInt64(deviceID), ifCond,
-                                          info, genMapInfoCB, customMapperCB,
+                                          deviceID, ifCond, info, genMapInfoCB,
+                                          customMapperCB,
                                           /*MapperFunc=*/nullptr, bodyGenCB,
                                           /*DeviceAddrCB=*/nullptr);
-    return ompBuilder->createTargetData(
-        ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID), ifCond,
-        info, genMapInfoCB, customMapperCB, &RTLFn);
+    return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
+                                        deviceID, ifCond, info, genMapInfoCB,
+                                        customMapperCB, &RTLFn);
   }();
 
   if (failed(handleError(afterIP, *op)))
diff --git a/mlir/test/Target/LLVMIR/omptarget-device.mlir b/mlir/test/Target/LLVMIR/omptarget-device.mlir
index b4c9744cc0c87..64f488a7bb44a 100644
--- a/mlir/test/Target/LLVMIR/omptarget-device.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-device.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
 
 module attributes {omp.is_target_device = false, omp.target_triples = ["nvptx64-nvidia-cuda"]} {
-  llvm.func @foo(%d16 : i16, %d32 : i32, %d64 : i64) {
+  llvm.func @_QPopenmp_target(%d16 : i16, %d32 : i32, %d64 : i64) {
     %x  = llvm.mlir.constant(0 : i32) : i32
 
     // Constant i16 -> i64 in the runtime call.
@@ -47,7 +47,7 @@ module attributes {omp.is_target_device = false, omp.target_triples = ["nvptx64-
   }
 }
 
-// CHECK-LABEL: define void @foo(i16 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}) {
+// CHECK-LABEL: define void @_QPopenmp_target(i16 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}) {
 // CHECK: br label %entry
 // CHECK: entry:
 
@@ -65,4 +65,210 @@ module attributes {omp.is_target_device = false, omp.target_triples = ["nvptx64-
 // CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 %[[D32_I64]], i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
 
 // Variable i64
-// CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 %{{.*}}, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
\ No newline at end of file
+// CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 %{{.*}}, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// -----
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["nvptx64-nvidia-cuda"]} {
+  llvm.func @_QPopenmp_target_data(%d16 : i16, %d32 : i32, %d64 : i64) {
+    %one = llvm.mlir.constant(1 : i64) : i64
+    %buf = llvm.alloca %one x i32 : (i64) -> !llvm.ptr
+    %map = omp.map.info var_ptr(%buf : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
+
+    // Constant i16 -> i64 in the runtime call.
+    %c1_i16 = llvm.mlir.constant(1 : i16) : i16
+    omp.target_data device(%c1_i16 : i16) map_entries(%map : !llvm.ptr) {
+      omp.terminator
+    }
+
+    // Constant i32 -> i64 in the runtime call.
+    %c2_i32 = llvm.mlir.constant(2 : i32) : i32
+    omp.target_data device(%c2_i32 : i32) map_entries(%map : !llvm.ptr) {
+      omp.terminator
+    }
+
+    // Constant i64 stays i64 in the runtime call.
+    %c3_i64 = llvm.mlir.constant(3 : i64) : i64
+    omp.target_data device(%c3_i64 : i64) map_entries(%map : !llvm.ptr) {
+      omp.terminator
+    }
+
+    // Variable i16 -> cast to i64.
+    omp.target_data device(%d16 : i16) map_entries(%map : !llvm.ptr) {
+      omp.terminator
+    }
+
+    // Variable i32 -> cast to i64.
+    omp.target_data device(%d32 : i32) map_entries(%map : !llvm.ptr) {
+      omp.terminator
+    }
+
+    // Variable i64 stays i64.
+    omp.target_data device(%d64 : i64) map_entries(%map : !llvm.ptr) {
+      omp.terminator
+    }
+
+    llvm.return
+  }
+}
+
+// CHECK-LABEL: define void @_QPopenmp_target_data(i16 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}) {
+// CHECK: br label %entry
+// CHECK: entry:
+
+// ---- Constant cases (device id is 2nd argument) ----
+// CHECK-DAG: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 1, i32 1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr null)
+// CHECK-DAG: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 1, i32 1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr null)
+
+// CHECK-DAG: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 2, i32 1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr null)
+// CHECK-DAG: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 2, i32 1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr null)
+
+// CHECK-DAG: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 3, i32 1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr null)
+// CHECK-DAG: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 3, i32 1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr null)
+
+// Variable i16 -> i64
+// CHECK: %[[D16_I64:.*]] = sext i16 %{{.*}} to i64
+// CHECK: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 %[[D16_I64]], i32 1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr null)
+// CHECK: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 %[[D16_I64]], i32 1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr null)
+
+// Variable i32 -> i64
+// CHECK: %[[D32_I64:.*]] = sext i32 %{{.*}} to i64
+// CHECK: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 %[[D32_I64]], i32 1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr null)
+// CHECK: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 %[[D32_I64]], i32 1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr null)
+
+// Variable i64
+// CHECK: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 %{{.*}}, i32 1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr null)
+// CHECK: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 %{{.*}}, i32 1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr null)
+
+// -----
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["nvptx64-nvidia-cuda"]} {
+  llvm.func @_QPomp_target_enter_exit(%d16 : i16, %d32 : i32, %d64 : i64) {
+    %c1 = llvm.mlir.constant(1 : i64) : i64
+    %var = llvm.alloca %c1 x i32 : (i64) -> !llvm.ptr
+
+    %m_to = omp.map.info var_ptr(%var : !llvm.ptr, i32) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = "var"}
+    %m_from = omp.map.info var_ptr(%var : !llvm.ptr, i32) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = "var"}
+
+    // Constant i16 -> i64 in the runtime call.
+    %c1_i16 = llvm.mlir.constant(1 : i16) : i16
+    omp.target_enter_data device(%c1_i16 : i16) map_entries(%m_to : !llvm.ptr)
+
+    // Constant i32 -> i64 in the runtime call.
+    %c2_i32 = llvm.mlir.constant(2 : i32) : i32
+    omp.target_enter_data device(%c2_i32 : i32) map_entries(%m_to : !llvm.ptr)
+
+    // Constant i64 stays i64 in the runtime call.
+    %c3_i64 = llvm.mlir.constant(3 : i64) : i64
+    omp.target_enter_data device(%c3_i64 : i64) map_entries(%m_to : !llvm.ptr)
+
+    // ---- Variable cases (enter) ----
+    omp.target_enter_data device(%d16 : i16) map_entries(%m_to : !llvm.ptr)
+    omp.target_enter_data device(%d32 : i32) map_entries(%m_to : !llvm.ptr)
+    omp.target_enter_data device(%d64 : i64) map_entries(%m_to : !llvm.ptr)
+
+    // ---- Constant cases (exit) ----
+    omp.target_exit_data device(%c1_i16 : i16) map_entries(%m_from : !llvm.ptr)
+    omp.target_exit_data device(%c2_i32 : i32) map_entries(%m_from : !llvm.ptr)
+    omp.target_exit_data device(%c3_i64 : i64) map_entries(%m_from : !llvm.ptr)
+
+    // ---- Variable cases (exit) ----
+    omp.target_exit_data device(%d16 : i16) map_entries(%m_from : !llvm.ptr)
+    omp.target_exit_data device(%d32 : i32) map_entries(%m_from : !llvm.ptr)
+    omp.target_exit_data device(%d64 : i64) map_entries(%m_from : !llvm.ptr)
+
+    llvm.return
+  }
+}
+
+// CHECK-LABEL: define void @_QPomp_target_enter_exit(i16 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}) {
+// CHECK: br label %entry
+// CHECK: entry:
+
+// ---- Constant enter cases (device id is 2nd argument) ----
+// CHECK-DAG: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 1, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 2, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 3, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// ---- Variable enter cases ----
+// Variable i16 -> i64
+// CHECK: %[[D16_I64_BEGIN:.*]] = sext i16 %{{.*}} to i64
+// CHECK: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 %[[D16_I64_BEGIN]], i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i32 -> i64
+// CHECK: %[[D32_I64_BEGIN:.*]] = sext i32 %{{.*}} to i64
+// CHECK: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 %[[D32_I64_BEGIN]], i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i64 stays i64
+// CHECK: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 %{{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// ---- Constant exit cases (device id is 2nd argument) ----
+// CHECK-DAG: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 1, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 2, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 3, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// ---- Variable exit cases ----
+// Variable i16 -> i64
+// CHECK: %[[D16_I64_END:.*]] = sext i16 %{{.*}} to i64
+// CHECK: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 %[[D16_I64_END]], i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i32 -> i64
+// CHECK: %[[D32_I64_END:.*]] = sext i32 %{{.*}} to i64
+// CHECK: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 %[[D32_I64_END]], i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i64 stays i64
+// CHECK: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 %{{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// CHECK: ret void
+// CHECK: }
+
+// -----
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["nvptx64-nvidia-cuda"]} {
+  llvm.func @target_update_dev_clause(%d16 : i16, %d32 : i32, %d64 : i64) {
+    %c1 = llvm.mlir.constant(1 : i64) : i64
+    %var = llvm.alloca %c1 x i32 : (i64) -> !llvm.ptr
+    %m = omp.map.info var_ptr(%var : !llvm.ptr, i32) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = "var"}
+
+    // ---- Constant cases ----
+    %c1_i16 = llvm.mlir.constant(1 : i16) : i16
+    omp.target_update device(%c1_i16 : i16) map_entries(%m : !llvm.ptr)
+
+    %c2_i32 = llvm.mlir.constant(2 : i32) : i32
+    omp.target_update device(%c2_i32 : i32) map_entries(%m : !llvm.ptr)
+
+    %c3_i64 = llvm.mlir.constant(3 : i64) : i64
+    omp.target_update device(%c3_i64 : i64) map_entries(%m : !llvm.ptr)
+
+    // ---- Variable cases ----
+    omp.target_update device(%d16 : i16) map_entries(%m : !llvm.ptr)
+    omp.target_update device(%d32 : i32) map_entries(%m : !llvm.ptr)
+    omp.target_update device(%d64 : i64) map_entries(%m : !llvm.ptr)
+
+    llvm.return
+  }
+}
+
+// CHECK-LABEL: define void @target_update_dev_clause(i16 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}) {
+// CHECK: br label %entry
+// CHECK: entry:
+
+// ---- Constant cases (device id is 2nd argument) ----
+// CHECK-DAG: call void @__tgt_target_data_update_mapper(ptr {{.*}}, i64 1, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call void @__tgt_target_data_update_mapper(ptr {{.*}}, i64 2, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call void @__tgt_target_data_update_mapper(ptr {{.*}}, i64 3, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// ---- Variable cases ----
+// Variable i16 -> i64
+// CHECK: %[[D16_I64:.*]] = sext i16 %{{.*}} to i64
+// CHECK: call void @__tgt_target_data_update_mapper(ptr {{.*}}, i64 %[[D16_I64]], i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i32 -> i64
+// CHECK: %[[D32_I64:.*]] = sext i32 %{{.*}} to i64
+// CHECK: call void @__tgt_target_data_update_mapper(ptr {{.*}}, i64 %[[D32_I64]], i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i64 stays i64
+// CHECK: call void @__tgt_target_data_update_mapper(ptr {{.*}}, i64 %{{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// CHECK: ret void
+// CHECK: }
\ No newline at end of file
diff --git a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
index e289d5d013eaa..0b4d63125f82f 100644
--- a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
@@ -156,6 +156,7 @@ module attributes {omp.target_triples = ["amdgcn-amd-amdhsa"]} {
 // CHECK:         %[[VAL_8:.*]] = load i32, ptr %[[VAL_7]], align 4
 // CHECK:         %[[VAL_9:.*]] = icmp slt i32 %[[VAL_8]], 10
 // CHECK:         %[[VAL_10:.*]] = load i32, ptr %[[VAL_6]], align 4
+// CHECK:         %[[DEV_I64_BEGIN:.*]] = sext i32 %[[VAL_10:.*]] to i64
 // CHECK:         br label %[[VAL_11:.*]]
 // CHECK:       entry:                                            ; preds = %[[VAL_12:.*]]
 // CHECK:         br i1 %[[VAL_9]], label %[[VAL_13:.*]], label %[[VAL_14:.*]]
@@ -176,7 +177,7 @@ module attributes {omp.target_triples = ["amdgcn-amd-amdhsa"]} {
 // CHECK:  ...
[truncated]

llvmbot commented Jan 6, 2026

@llvm/pr-subscribers-mlir-openmp

Author: Chi-Chun, Chen (chichunchen)

Changes

Allow the OpenMP device clause on target data, target enter data, target exit data, and target update to be lowered from non-constant values rather than from constants only.


+// CHECK: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 %[[D16_I64]], i32 1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr null)
+
+// Variable i32 -> i64
+// CHECK: %[[D32_I64:.*]] = sext i32 %{{.*}} to i64
+// CHECK: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 %[[D32_I64]], i32 1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr null)
+// CHECK: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 %[[D32_I64]], i32 1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr null)
+
+// Variable i64
+// CHECK: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 %{{.*}}, i32 1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr null)
+// CHECK: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 %{{.*}}, i32 1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr null)
+
+// -----
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["nvptx64-nvidia-cuda"]} {
+  llvm.func @_QPomp_target_enter_exit(%d16 : i16, %d32 : i32, %d64 : i64) {
+    %c1 = llvm.mlir.constant(1 : i64) : i64
+    %var = llvm.alloca %c1 x i32 : (i64) -> !llvm.ptr
+
+    %m_to = omp.map.info var_ptr(%var : !llvm.ptr, i32) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = "var"}
+    %m_from = omp.map.info var_ptr(%var : !llvm.ptr, i32) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = "var"}
+
+    // Constant i16 -> i64 in the runtime call.
+    %c1_i16 = llvm.mlir.constant(1 : i16) : i16
+    omp.target_enter_data device(%c1_i16 : i16) map_entries(%m_to : !llvm.ptr)
+
+    // Constant i32 -> i64 in the runtime call.
+    %c2_i32 = llvm.mlir.constant(2 : i32) : i32
+    omp.target_enter_data device(%c2_i32 : i32) map_entries(%m_to : !llvm.ptr)
+
+    // Constant i64 stays i64 in the runtime call.
+    %c3_i64 = llvm.mlir.constant(3 : i64) : i64
+    omp.target_enter_data device(%c3_i64 : i64) map_entries(%m_to : !llvm.ptr)
+
+    // ---- Variable cases (enter) ----
+    omp.target_enter_data device(%d16 : i16) map_entries(%m_to : !llvm.ptr)
+    omp.target_enter_data device(%d32 : i32) map_entries(%m_to : !llvm.ptr)
+    omp.target_enter_data device(%d64 : i64) map_entries(%m_to : !llvm.ptr)
+
+    // ---- Constant cases (exit) ----
+    omp.target_exit_data device(%c1_i16 : i16) map_entries(%m_from : !llvm.ptr)
+    omp.target_exit_data device(%c2_i32 : i32) map_entries(%m_from : !llvm.ptr)
+    omp.target_exit_data device(%c3_i64 : i64) map_entries(%m_from : !llvm.ptr)
+
+    // ---- Variable cases (exit) ----
+    omp.target_exit_data device(%d16 : i16) map_entries(%m_from : !llvm.ptr)
+    omp.target_exit_data device(%d32 : i32) map_entries(%m_from : !llvm.ptr)
+    omp.target_exit_data device(%d64 : i64) map_entries(%m_from : !llvm.ptr)
+
+    llvm.return
+  }
+}
+
+// CHECK-LABEL: define void @_QPomp_target_enter_exit(i16 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}) {
+// CHECK: br label %entry
+// CHECK: entry:
+
+// ---- Constant enter cases (device id is 2nd argument) ----
+// CHECK-DAG: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 1, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 2, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 3, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// ---- Variable enter cases ----
+// Variable i16 -> i64
+// CHECK: %[[D16_I64_BEGIN:.*]] = sext i16 %{{.*}} to i64
+// CHECK: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 %[[D16_I64_BEGIN]], i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i32 -> i64
+// CHECK: %[[D32_I64_BEGIN:.*]] = sext i32 %{{.*}} to i64
+// CHECK: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 %[[D32_I64_BEGIN]], i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i64 stays i64
+// CHECK: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 %{{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// ---- Constant exit cases (device id is 2nd argument) ----
+// CHECK-DAG: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 1, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 2, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 3, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// ---- Variable exit cases ----
+// Variable i16 -> i64
+// CHECK: %[[D16_I64_END:.*]] = sext i16 %{{.*}} to i64
+// CHECK: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 %[[D16_I64_END]], i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i32 -> i64
+// CHECK: %[[D32_I64_END:.*]] = sext i32 %{{.*}} to i64
+// CHECK: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 %[[D32_I64_END]], i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i64 stays i64
+// CHECK: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 %{{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// CHECK: ret void
+// CHECK: }
+
+// -----
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["nvptx64-nvidia-cuda"]} {
+  llvm.func @target_update_dev_clause(%d16 : i16, %d32 : i32, %d64 : i64) {
+    %c1 = llvm.mlir.constant(1 : i64) : i64
+    %var = llvm.alloca %c1 x i32 : (i64) -> !llvm.ptr
+    %m = omp.map.info var_ptr(%var : !llvm.ptr, i32) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = "var"}
+
+    // ---- Constant cases ----
+    %c1_i16 = llvm.mlir.constant(1 : i16) : i16
+    omp.target_update device(%c1_i16 : i16) map_entries(%m : !llvm.ptr)
+
+    %c2_i32 = llvm.mlir.constant(2 : i32) : i32
+    omp.target_update device(%c2_i32 : i32) map_entries(%m : !llvm.ptr)
+
+    %c3_i64 = llvm.mlir.constant(3 : i64) : i64
+    omp.target_update device(%c3_i64 : i64) map_entries(%m : !llvm.ptr)
+
+    // ---- Variable cases ----
+    omp.target_update device(%d16 : i16) map_entries(%m : !llvm.ptr)
+    omp.target_update device(%d32 : i32) map_entries(%m : !llvm.ptr)
+    omp.target_update device(%d64 : i64) map_entries(%m : !llvm.ptr)
+
+    llvm.return
+  }
+}
+
+// CHECK-LABEL: define void @target_update_dev_clause(i16 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}) {
+// CHECK: br label %entry
+// CHECK: entry:
+
+// ---- Constant cases (device id is 2nd argument) ----
+// CHECK-DAG: call void @__tgt_target_data_update_mapper(ptr {{.*}}, i64 1, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call void @__tgt_target_data_update_mapper(ptr {{.*}}, i64 2, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call void @__tgt_target_data_update_mapper(ptr {{.*}}, i64 3, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// ---- Variable cases ----
+// Variable i16 -> i64
+// CHECK: %[[D16_I64:.*]] = sext i16 %{{.*}} to i64
+// CHECK: call void @__tgt_target_data_update_mapper(ptr {{.*}}, i64 %[[D16_I64]], i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i32 -> i64
+// CHECK: %[[D32_I64:.*]] = sext i32 %{{.*}} to i64
+// CHECK: call void @__tgt_target_data_update_mapper(ptr {{.*}}, i64 %[[D32_I64]], i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i64 stays i64
+// CHECK: call void @__tgt_target_data_update_mapper(ptr {{.*}}, i64 %{{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// CHECK: ret void
+// CHECK: }
\ No newline at end of file
diff --git a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
index e289d5d013eaa..0b4d63125f82f 100644
--- a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
@@ -156,6 +156,7 @@ module attributes {omp.target_triples = ["amdgcn-amd-amdhsa"]} {
 // CHECK:         %[[VAL_8:.*]] = load i32, ptr %[[VAL_7]], align 4
 // CHECK:         %[[VAL_9:.*]] = icmp slt i32 %[[VAL_8]], 10
 // CHECK:         %[[VAL_10:.*]] = load i32, ptr %[[VAL_6]], align 4
+// CHECK:         %[[DEV_I64_BEGIN:.*]] = sext i32 %[[VAL_10:.*]] to i64
 // CHECK:         br label %[[VAL_11:.*]]
 // CHECK:       entry:                                            ; preds = %[[VAL_12:.*]]
 // CHECK:         br i1 %[[VAL_9]], label %[[VAL_13:.*]], label %[[VAL_14:.*]]
@@ -176,7 +177,7 @@ module attributes {omp.target_triples = ["amdgcn-amd-amdhsa"]} {
 // CHECK:  ...
[truncated]
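
The `sext` checks in the tests above show that `i16` and `i32` device IDs are sign-extended to the 64-bit device argument of the runtime calls, while `i64` values pass through unchanged. As a rough standalone model in plain C++ (this is not the MLIR translation code, and `widen_device_id` is a hypothetical name chosen for illustration), the widening behaves like a signed integer conversion:

```cpp
#include <cstdint>
#include <type_traits>

// Hypothetical model of the device-ID widening: a signed conversion to
// 64 bits, which matches what an LLVM `sext` does for i16/i32 inputs and is
// a no-op for i64 inputs.
template <typename T>
constexpr std::int64_t widen_device_id(T dev) {
  static_assert(std::is_integral<T>::value && std::is_signed<T>::value,
                "this sketch models signed integer device IDs only");
  return static_cast<std::int64_t>(dev);
}

// Values are preserved across the widening, including the sign of negatives.
static_assert(widen_device_id<std::int16_t>(1) == 1, "i16 widens to i64");
static_assert(widen_device_id<std::int32_t>(2) == 2, "i32 widens to i64");
static_assert(widen_device_id<std::int64_t>(3) == 3, "i64 stays i64");
static_assert(widen_device_id<std::int16_t>(-1) == -1, "sign is kept");
```

Because the conversion is signed, a negative device ID keeps its sign after widening. A zero-extension would turn it into a large positive value instead.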

@llvmbot
Member

llvmbot commented Jan 6, 2026

@llvm/pr-subscribers-mlir

Author: Chi-Chun, Chen (chichunchen)

Changes

Allow the OpenMP device clause on target data/enter/exit/update to be lowered from variables instead of constants only.
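
For illustration only, here is a C++ analogue of the source-level pattern this change concerns. The PR itself is about Flang (Fortran) lowering, and the function name `scale_on_device` is hypothetical. The point is that the `device` clause argument is a run-time value rather than a compile-time constant. When built without OpenMP support, the pragmas are ignored and the loop runs on the host.

```cpp
#include <vector>

// Hypothetical helper: doubles every element of `data` inside a target data
// region whose device number is only known at run time.
void scale_on_device(std::vector<int> &data, int dev) {
  int *p = data.data();
  const int n = static_cast<int>(data.size());
  // `dev` is a variable, not a literal: the device ID must be passed to the
  // runtime as a value computed at run time.
#pragma omp target data map(tofrom : p[0:n]) device(dev)
  {
#pragma omp target teams distribute parallel for device(dev)
    for (int i = 0; i < n; ++i)
      p[i] *= 2;
  }
}
```

Calling `scale_on_device(v, 0)` on a vector holding 1, 2, 3 leaves it holding 2, 4, 6. Device number 0 is an assumption here and depends on the system the code runs on.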


Patch is 22.34 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/174665.diff

4 Files Affected:

  • (modified) flang/docs/OpenMPSupport.md (+2-2)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+21-27)
  • (modified) mlir/test/Target/LLVMIR/omptarget-device.mlir (+210-4)
  • (modified) mlir/test/Target/LLVMIR/omptarget-llvm.mlir (+4-2)
diff --git a/flang/docs/OpenMPSupport.md b/flang/docs/OpenMPSupport.md
index 21966c5489108..1b25fc4890847 100644
--- a/flang/docs/OpenMPSupport.md
+++ b/flang/docs/OpenMPSupport.md
@@ -37,9 +37,9 @@ Note : No distinction is made between the support in Parser/Semantics, MLIR, Low
 | simd construct                                             | P      | Implicit linearization is skipped if iv is a pointer or allocatable|
 | declare simd construct                                     | N      | |
 | do simd construct                                          | P      | linear clause is not supported |
-| target data construct                                      | P      | device clause not supported |
+| target data construct                                      | Y      | |
 | target construct                                           | Y      | |
-| target update construct                                    | P      | device clause not supported |
+| target update construct                                    | Y      | |
 | declare target directive                                   | Y      | |
 | teams construct                                            | Y      | |
 | distribute construct                                       | Y      | |
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 614f06017a324..2698a4b4e89db 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -456,11 +456,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
             omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
       .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp>(
           [&](auto op) { checkDepend(op, result); })
-      .Case<omp::TargetUpdateOp>([&](auto op) {
-        checkDepend(op, result);
-        checkDevice(op, result);
-      })
-      .Case<omp::TargetDataOp>([&](auto op) { checkDevice(op, result); })
+      .Case<omp::TargetUpdateOp>([&](auto op) { checkDepend(op, result); })
       .Case([&](omp::TargetOp op) {
         checkAllocate(op, result);
         checkBare(op, result);
@@ -5051,7 +5047,7 @@ static LogicalResult
 convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation) {
   llvm::Value *ifCond = nullptr;
-  int64_t deviceID = llvm::omp::OMP_DEVICEID_UNDEF;
+  llvm::Value *deviceID = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
   SmallVector<Value> mapVars;
   SmallVector<Value> useDevicePtrVars;
   SmallVector<Value> useDeviceAddrVars;
@@ -5067,6 +5063,11 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
   bool isOffloadEntry =
       isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
 
+  auto getDeviceID = [&](mlir::Value dev) -> llvm::Value * {
+    llvm::Value *v = moduleTranslation.lookupValue(dev);
+    return builder.CreateIntCast(v, builder.getInt64Ty(), /*isSigned=*/true);
+  };
+
   LogicalResult result =
       llvm::TypeSwitch<Operation *, LogicalResult>(op)
           .Case([&](omp::TargetDataOp dataOp) {
@@ -5076,10 +5077,8 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
             if (auto ifVar = dataOp.getIfExpr())
               ifCond = moduleTranslation.lookupValue(ifVar);
 
-            if (auto devId = dataOp.getDevice())
-              if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
-                if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
-                  deviceID = intAttr.getInt();
+            if (mlir::Value devId = dataOp.getDevice())
+              deviceID = getDeviceID(devId);
 
             mapVars = dataOp.getMapVars();
             useDevicePtrVars = dataOp.getUseDevicePtrVars();
@@ -5093,10 +5092,9 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
             if (auto ifVar = enterDataOp.getIfExpr())
               ifCond = moduleTranslation.lookupValue(ifVar);
 
-            if (auto devId = enterDataOp.getDevice())
-              if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
-                if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
-                  deviceID = intAttr.getInt();
+            if (mlir::Value devId = enterDataOp.getDevice())
+              deviceID = getDeviceID(devId);
+
             RTLFn =
                 enterDataOp.getNowait()
                     ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
@@ -5112,10 +5110,8 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
             if (auto ifVar = exitDataOp.getIfExpr())
               ifCond = moduleTranslation.lookupValue(ifVar);
 
-            if (auto devId = exitDataOp.getDevice())
-              if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
-                if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
-                  deviceID = intAttr.getInt();
+            if (mlir::Value devId = exitDataOp.getDevice())
+              deviceID = getDeviceID(devId);
 
             RTLFn = exitDataOp.getNowait()
                         ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
@@ -5131,10 +5127,8 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
             if (auto ifVar = updateDataOp.getIfExpr())
               ifCond = moduleTranslation.lookupValue(ifVar);
 
-            if (auto devId = updateDataOp.getDevice())
-              if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
-                if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
-                  deviceID = intAttr.getInt();
+            if (mlir::Value devId = updateDataOp.getDevice())
+              deviceID = getDeviceID(devId);
 
             RTLFn =
                 updateDataOp.getNowait()
@@ -5287,13 +5281,13 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
     if (isa<omp::TargetDataOp>(op))
       return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
-                                          builder.getInt64(deviceID), ifCond,
-                                          info, genMapInfoCB, customMapperCB,
+                                          deviceID, ifCond, info, genMapInfoCB,
+                                          customMapperCB,
                                           /*MapperFunc=*/nullptr, bodyGenCB,
                                           /*DeviceAddrCB=*/nullptr);
-    return ompBuilder->createTargetData(
-        ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID), ifCond,
-        info, genMapInfoCB, customMapperCB, &RTLFn);
+    return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
+                                        deviceID, ifCond, info, genMapInfoCB,
+                                        customMapperCB, &RTLFn);
   }();
 
   if (failed(handleError(afterIP, *op)))
diff --git a/mlir/test/Target/LLVMIR/omptarget-device.mlir b/mlir/test/Target/LLVMIR/omptarget-device.mlir
index b4c9744cc0c87..64f488a7bb44a 100644
--- a/mlir/test/Target/LLVMIR/omptarget-device.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-device.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
 
 module attributes {omp.is_target_device = false, omp.target_triples = ["nvptx64-nvidia-cuda"]} {
-  llvm.func @foo(%d16 : i16, %d32 : i32, %d64 : i64) {
+  llvm.func @_QPopenmp_target(%d16 : i16, %d32 : i32, %d64 : i64) {
     %x  = llvm.mlir.constant(0 : i32) : i32
 
     // Constant i16 -> i64 in the runtime call.
@@ -47,7 +47,7 @@ module attributes {omp.is_target_device = false, omp.target_triples = ["nvptx64-
   }
 }
 
-// CHECK-LABEL: define void @foo(i16 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}) {
+// CHECK-LABEL: define void @_QPopenmp_target(i16 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}) {
 // CHECK: br label %entry
 // CHECK: entry:
 
@@ -65,4 +65,210 @@ module attributes {omp.is_target_device = false, omp.target_triples = ["nvptx64-
 // CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 %[[D32_I64]], i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
 
 // Variable i64
-// CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 %{{.*}}, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
\ No newline at end of file
+// CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 %{{.*}}, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// -----
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["nvptx64-nvidia-cuda"]} {
+  llvm.func @_QPopenmp_target_data(%d16 : i16, %d32 : i32, %d64 : i64) {
+    %one = llvm.mlir.constant(1 : i64) : i64
+    %buf = llvm.alloca %one x i32 : (i64) -> !llvm.ptr
+    %map = omp.map.info var_ptr(%buf : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
+
+    // Constant i16 -> i64 in the runtime call.
+    %c1_i16 = llvm.mlir.constant(1 : i16) : i16
+    omp.target_data device(%c1_i16 : i16) map_entries(%map : !llvm.ptr) {
+      omp.terminator
+    }
+
+    // Constant i32 -> i64 in the runtime call.
+    %c2_i32 = llvm.mlir.constant(2 : i32) : i32
+    omp.target_data device(%c2_i32 : i32) map_entries(%map : !llvm.ptr) {
+      omp.terminator
+    }
+
+    // Constant i64 stays i64 in the runtime call.
+    %c3_i64 = llvm.mlir.constant(3 : i64) : i64
+    omp.target_data device(%c3_i64 : i64) map_entries(%map : !llvm.ptr) {
+      omp.terminator
+    }
+
+    // Variable i16 -> cast to i64.
+    omp.target_data device(%d16 : i16) map_entries(%map : !llvm.ptr) {
+      omp.terminator
+    }
+
+    // Variable i32 -> cast to i64.
+    omp.target_data device(%d32 : i32) map_entries(%map : !llvm.ptr) {
+      omp.terminator
+    }
+
+    // Variable i64 stays i64.
+    omp.target_data device(%d64 : i64) map_entries(%map : !llvm.ptr) {
+      omp.terminator
+    }
+
+    llvm.return
+  }
+}
+
+// CHECK-LABEL: define void @_QPopenmp_target_data(i16 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}) {
+// CHECK: br label %entry
+// CHECK: entry:
+
+// ---- Constant cases (device id is 2nd argument) ----
+// CHECK-DAG: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 1, i32 1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr null)
+// CHECK-DAG: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 1, i32 1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr null)
+
+// CHECK-DAG: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 2, i32 1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr null)
+// CHECK-DAG: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 2, i32 1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr null)
+
+// CHECK-DAG: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 3, i32 1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr null)
+// CHECK-DAG: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 3, i32 1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr null)
+
+// Variable i16 -> i64
+// CHECK: %[[D16_I64:.*]] = sext i16 %{{.*}} to i64
+// CHECK: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 %[[D16_I64]], i32 1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr null)
+// CHECK: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 %[[D16_I64]], i32 1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr null)
+
+// Variable i32 -> i64
+// CHECK: %[[D32_I64:.*]] = sext i32 %{{.*}} to i64
+// CHECK: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 %[[D32_I64]], i32 1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr null)
+// CHECK: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 %[[D32_I64]], i32 1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr null)
+
+// Variable i64
+// CHECK: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 %{{.*}}, i32 1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr null)
+// CHECK: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 %{{.*}}, i32 1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr null)
+
+// -----
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["nvptx64-nvidia-cuda"]} {
+  llvm.func @_QPomp_target_enter_exit(%d16 : i16, %d32 : i32, %d64 : i64) {
+    %c1 = llvm.mlir.constant(1 : i64) : i64
+    %var = llvm.alloca %c1 x i32 : (i64) -> !llvm.ptr
+
+    %m_to = omp.map.info var_ptr(%var : !llvm.ptr, i32) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = "var"}
+    %m_from = omp.map.info var_ptr(%var : !llvm.ptr, i32) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = "var"}
+
+    // Constant i16 -> i64 in the runtime call.
+    %c1_i16 = llvm.mlir.constant(1 : i16) : i16
+    omp.target_enter_data device(%c1_i16 : i16) map_entries(%m_to : !llvm.ptr)
+
+    // Constant i32 -> i64 in the runtime call.
+    %c2_i32 = llvm.mlir.constant(2 : i32) : i32
+    omp.target_enter_data device(%c2_i32 : i32) map_entries(%m_to : !llvm.ptr)
+
+    // Constant i64 stays i64 in the runtime call.
+    %c3_i64 = llvm.mlir.constant(3 : i64) : i64
+    omp.target_enter_data device(%c3_i64 : i64) map_entries(%m_to : !llvm.ptr)
+
+    // ---- Variable cases (enter) ----
+    omp.target_enter_data device(%d16 : i16) map_entries(%m_to : !llvm.ptr)
+    omp.target_enter_data device(%d32 : i32) map_entries(%m_to : !llvm.ptr)
+    omp.target_enter_data device(%d64 : i64) map_entries(%m_to : !llvm.ptr)
+
+    // ---- Constant cases (exit) ----
+    omp.target_exit_data device(%c1_i16 : i16) map_entries(%m_from : !llvm.ptr)
+    omp.target_exit_data device(%c2_i32 : i32) map_entries(%m_from : !llvm.ptr)
+    omp.target_exit_data device(%c3_i64 : i64) map_entries(%m_from : !llvm.ptr)
+
+    // ---- Variable cases (exit) ----
+    omp.target_exit_data device(%d16 : i16) map_entries(%m_from : !llvm.ptr)
+    omp.target_exit_data device(%d32 : i32) map_entries(%m_from : !llvm.ptr)
+    omp.target_exit_data device(%d64 : i64) map_entries(%m_from : !llvm.ptr)
+
+    llvm.return
+  }
+}
+
+// CHECK-LABEL: define void @_QPomp_target_enter_exit(i16 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}) {
+// CHECK: br label %entry
+// CHECK: entry:
+
+// ---- Constant enter cases (device id is 2nd argument) ----
+// CHECK-DAG: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 1, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 2, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 3, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// ---- Variable enter cases ----
+// Variable i16 -> i64
+// CHECK: %[[D16_I64_BEGIN:.*]] = sext i16 %{{.*}} to i64
+// CHECK: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 %[[D16_I64_BEGIN]], i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i32 -> i64
+// CHECK: %[[D32_I64_BEGIN:.*]] = sext i32 %{{.*}} to i64
+// CHECK: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 %[[D32_I64_BEGIN]], i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i64 stays i64
+// CHECK: call void @__tgt_target_data_begin_mapper(ptr {{.*}}, i64 %{{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// ---- Constant exit cases (device id is 2nd argument) ----
+// CHECK-DAG: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 1, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 2, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 3, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// ---- Variable exit cases ----
+// Variable i16 -> i64
+// CHECK: %[[D16_I64_END:.*]] = sext i16 %{{.*}} to i64
+// CHECK: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 %[[D16_I64_END]], i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i32 -> i64
+// CHECK: %[[D32_I64_END:.*]] = sext i32 %{{.*}} to i64
+// CHECK: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 %[[D32_I64_END]], i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i64 stays i64
+// CHECK: call void @__tgt_target_data_end_mapper(ptr {{.*}}, i64 %{{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// CHECK: ret void
+// CHECK: }
+
+// -----
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["nvptx64-nvidia-cuda"]} {
+  llvm.func @target_update_dev_clause(%d16 : i16, %d32 : i32, %d64 : i64) {
+    %c1 = llvm.mlir.constant(1 : i64) : i64
+    %var = llvm.alloca %c1 x i32 : (i64) -> !llvm.ptr
+    %m = omp.map.info var_ptr(%var : !llvm.ptr, i32) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = "var"}
+
+    // ---- Constant cases ----
+    %c1_i16 = llvm.mlir.constant(1 : i16) : i16
+    omp.target_update device(%c1_i16 : i16) map_entries(%m : !llvm.ptr)
+
+    %c2_i32 = llvm.mlir.constant(2 : i32) : i32
+    omp.target_update device(%c2_i32 : i32) map_entries(%m : !llvm.ptr)
+
+    %c3_i64 = llvm.mlir.constant(3 : i64) : i64
+    omp.target_update device(%c3_i64 : i64) map_entries(%m : !llvm.ptr)
+
+    // ---- Variable cases ----
+    omp.target_update device(%d16 : i16) map_entries(%m : !llvm.ptr)
+    omp.target_update device(%d32 : i32) map_entries(%m : !llvm.ptr)
+    omp.target_update device(%d64 : i64) map_entries(%m : !llvm.ptr)
+
+    llvm.return
+  }
+}
+
+// CHECK-LABEL: define void @target_update_dev_clause(i16 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}) {
+// CHECK: br label %entry
+// CHECK: entry:
+
+// ---- Constant cases (device id is 2nd argument) ----
+// CHECK-DAG: call void @__tgt_target_data_update_mapper(ptr {{.*}}, i64 1, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call void @__tgt_target_data_update_mapper(ptr {{.*}}, i64 2, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call void @__tgt_target_data_update_mapper(ptr {{.*}}, i64 3, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// ---- Variable cases ----
+// Variable i16 -> i64
+// CHECK: %[[D16_I64:.*]] = sext i16 %{{.*}} to i64
+// CHECK: call void @__tgt_target_data_update_mapper(ptr {{.*}}, i64 %[[D16_I64]], i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i32 -> i64
+// CHECK: %[[D32_I64:.*]] = sext i32 %{{.*}} to i64
+// CHECK: call void @__tgt_target_data_update_mapper(ptr {{.*}}, i64 %[[D32_I64]], i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i64 stays i64
+// CHECK: call void @__tgt_target_data_update_mapper(ptr {{.*}}, i64 %{{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// CHECK: ret void
+// CHECK: }
\ No newline at end of file
diff --git a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
index e289d5d013eaa..0b4d63125f82f 100644
--- a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
@@ -156,6 +156,7 @@ module attributes {omp.target_triples = ["amdgcn-amd-amdhsa"]} {
 // CHECK:         %[[VAL_8:.*]] = load i32, ptr %[[VAL_7]], align 4
 // CHECK:         %[[VAL_9:.*]] = icmp slt i32 %[[VAL_8]], 10
 // CHECK:         %[[VAL_10:.*]] = load i32, ptr %[[VAL_6]], align 4
+// CHECK:         %[[DEV_I64_BEGIN:.*]] = sext i32 %[[VAL_10:.*]] to i64
 // CHECK:         br label %[[VAL_11:.*]]
 // CHECK:       entry:                                            ; preds = %[[VAL_12:.*]]
 // CHECK:         br i1 %[[VAL_9]], label %[[VAL_13:.*]], label %[[VAL_14:.*]]
@@ -176,7 +177,7 @@ module attributes {omp.target_triples = ["amdgcn-amd-amdhsa"]} {
 // CHECK:  ...
[truncated]
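
The CHECK lines above expect a `sext` to `i64` for `i16` and `i32` device numbers and no cast for `i64`, which matches the patch's `getDeviceID` lambda calling `CreateIntCast` with `isSigned=true`. The following is a minimal standalone sketch of why the signed widening matters. It does not use LLVM; `widenDeviceId`, `zeroExtend`, and `kDeviceIdUndef` are hypothetical names, and the value `-1` for the "undefined device" sentinel is an assumption made for illustration.

```cpp
#include <cstdint>
#include <type_traits>

// Assumed sentinel meaning "no device clause given". The patch uses
// llvm::omp::OMP_DEVICEID_UNDEF; the value -1 here is an assumption.
constexpr std::int64_t kDeviceIdUndef = -1;

// Hypothetical stand-in for the patch's getDeviceID lambda: widen a signed
// integer device number to 64 bits while preserving its sign (like `sext`).
template <typename T>
constexpr std::int64_t widenDeviceId(T dev) {
  static_assert(std::is_integral<T>::value && std::is_signed<T>::value,
                "device numbers are modelled as signed integers");
  return static_cast<std::int64_t>(dev);
}

// For contrast only: zero-extension (like `zext`) loses the sign.
template <typename T>
constexpr std::int64_t zeroExtend(T dev) {
  return static_cast<std::int64_t>(
      static_cast<typename std::make_unsigned<T>::type>(dev));
}

// Non-negative device numbers are unchanged by either widening.
static_assert(widenDeviceId<std::int16_t>(3) == 3, "i16 -> i64");
static_assert(widenDeviceId<std::int32_t>(2) == 2, "i32 -> i64");
static_assert(widenDeviceId<std::int64_t>(7) == 7, "i64 stays i64");

// A negative 16-bit value keeps its meaning only under sign extension.
static_assert(widenDeviceId<std::int16_t>(-1) == kDeviceIdUndef, "sext");
static_assert(zeroExtend<std::int16_t>(-1) == 65535, "zext");
```

Under these assumptions, a narrow negative device number would turn into a large positive one if it were zero-extended, so a signed cast is the safer choice for a value that is passed to the runtime as `i64`.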

@llvmbot
Member

llvmbot commented Jan 6, 2026

@llvm/pr-subscribers-flang-openmp

Author: Chi-Chun, Chen (chichunchen)

Changes

Allow the OpenMP device clause on target data, target enter data, target exit data, and target update to be lowered from variables, not only from constants.



@github-actions

github-actions bot commented Jan 6, 2026

🐧 Linux x64 Test Results

  • 7548 tests passed
  • 598 tests skipped

✅ The build succeeded and all tests passed.

@chichunchen chichunchen force-pushed the cchen/flang/enter_exit_device branch from 7bb1214 to c844398 Compare January 7, 2026 02:40
Extend OpenMP device clause lowering for target data, target enter data,
target exit data, and target update to accept non-constant values.
Previously, only constant device IDs could be lowered to LLVM IR.

Add Flang tests to validate device clause handling and mark the feature
as supported in the OpenMPSupport documentation. New tests cover:
- target teams
- target teams distribute
- target teams distribute parallel do
- target teams distribute parallel do simd
- target data

Tests for target update and target enter/exit were
already present in Flang.
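
For readers more familiar with C/C++ than Fortran, the kind of source this change is about can be sketched as follows. This is only an analogue: it is not one of the Flang tests added by the PR, it would go through Clang rather than the MLIR lowering path touched here, and `scale_and_sum` with its parameters is a hypothetical helper. The point is that the `device` clause receives a value known only at run time rather than a literal.

```cpp
// Hypothetical helper: doubles `n` elements of `data` on device `dev`
// and returns their sum. `dev` is a run-time value, not a constant.
long scale_and_sum(int *data, int n, int dev) {
  // Structured data region on the device selected at run time.
#pragma omp target data device(dev) map(tofrom : data[0:n])
  {
    // Combined construct reusing the same non-constant device number.
    // The array section is already present, so this map only reuses it.
#pragma omp target teams distribute parallel for device(dev) \
    map(tofrom : data[0:n])
    for (int i = 0; i < n; ++i)
      data[i] *= 2;
  }

  // Sum on the host after the data region has copied the results back.
  long sum = 0;
  for (int i = 0; i < n; ++i)
    sum += data[i];
  return sum;
}
```

When compiled without OpenMP support the pragmas are ignored and the loop simply runs on the host, so the sketch stays usable as plain C++.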
@chichunchen chichunchen force-pushed the cchen/flang/enter_exit_device branch from c844398 to e02a664 Compare January 7, 2026 02:44
Contributor

@tblah tblah left a comment

LGTM, thanks

@chichunchen chichunchen merged commit 5fb4383 into llvm:main Jan 7, 2026
11 checks passed

3 participants