diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp
index fc1ce453e7..5c0f0a85b3 100644
--- a/deps/ReactantExtra/API.cpp
+++ b/deps/ReactantExtra/API.cpp
@@ -2404,3 +2404,7 @@ extern "C" void dump_operation(Operation *op, const char *filename) {
 extern "C" bool pjrt_device_is_addressable(PjRtDevice *device) {
   return device->IsAddressable();
 }
+
+extern "C" mlir::Operation *mlirGetParentOfTypeFunctionOp(mlir::Operation *op) {
+  return op->getParentOfType<mlir::FunctionOpInterface>();
+}
diff --git a/src/Ops.jl b/src/Ops.jl
index 7c4ebbd2c6..0f13fa097a 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -126,6 +126,11 @@ end
             result_inference=false,
         )
 
+        parent_func_op = MLIR.IR.get_parent_of_type_function_op(cstop)
+        if parent_func_op == C_NULL
+            error("Constant must be created inside a Function Op.")
+        end
+
         res = MLIR.IR.result(cstop)
         tres = TracedRArray{T,N}((), res, size(x))
         constants[value] = tres
@@ -201,6 +206,12 @@ for (T, mlir_func) in (
 
             splatattr = MLIR.API.$mlir_func(tt, number)
             cst_op = stablehlo.constant(; output=tt, value=splatattr, location=location)
+
+            parent_func_op = MLIR.IR.get_parent_of_type_function_op(cst_op)
+            if parent_func_op == C_NULL
+                error("Constant must be created inside a Function Op.")
+            end
+
             cst = MLIR.IR.result(cst_op)
             ta = TracedRArray{$T,length(shape)}((), cst, shape)
             return ta
@@ -221,6 +232,12 @@ end
     tt = MLIR.IR.TensorType(shape, MLIR.IR.Type(T))
     splatattr = MLIR.API.mlirDenseElementsAttrSplatGet(tt, _fill_element_attr(element))
     cst_op = stablehlo.constant(; output=tt, value=splatattr, location=location)
+
+    parent_func_op = MLIR.IR.get_parent_of_type_function_op(cst_op)
+    if parent_func_op == C_NULL
+        error("Constant must be created inside a Function Op.")
+    end
+
     cst = MLIR.IR.result(cst_op)
     ta = TracedRArray{T,length(shape)}((), cst, shape)
     return ta
diff --git a/src/mlir/IR/Operation.jl b/src/mlir/IR/Operation.jl
index 32f42b6838..c1768b44a7 100644
--- a/src/mlir/IR/Operation.jl
+++ b/src/mlir/IR/Operation.jl
@@ -68,6 +68,12 @@ Gets the operation that owns this operation, returning null if the operation is
 parent_op(operation::Operation) =
     Operation(API.mlirOperationGetParentOperation(operation), false)
 
+"""
+    parent_region(op)
+Gets the region that owns this operation.
+"""
+parent_region(operation::Operation) = parent_region(block(operation))
+
 """
     rmfromparent!(op)
 
@@ -331,8 +337,21 @@ function create_operation_common(
     end
 end
 
+function create_operation_common_with_checks(args...; operands=nothing, kwargs...)
+    op = create_operation_common(args...; operands, kwargs...)
+    if !isnothing(operands)
+        parent_function_op = get_parent_of_type_function_op(op)
+        if parent_function_op != C_NULL
+            function_op_region = parent_region(parent_function_op)
+            operand_region = parent_region.(operands)
+            # TODO: add the checks
+        end
+    end
+    return op
+end
+
 function create_operation(args...; kwargs...)
-    res = create_operation_common(args...; kwargs...)
+    res = create_operation_common_with_checks(args...; kwargs...)
     if _has_block()
         push!(block(), res)
     end
@@ -340,7 +359,17 @@ function create_operation(args...; kwargs...)
 end
 
 function create_operation_at_front(args...; kwargs...)
-    res = create_operation_common(args...; kwargs...)
+    res = create_operation_common_with_checks(args...; kwargs...)
     Base.pushfirst!(block(), res)
     return res
 end
+
+function get_parent_of_type_function_op(op::Operation)
+    GC.@preserve op begin
+        funcop = @ccall API.mlir_c.mlirGetParentOfTypeFunctionOp(
+            op::API.MlirOperation
+        )::API.MlirOperation
+    end
+    funcop.ptr == C_NULL && return C_NULL
+    return Operation(funcop, false)
+end
diff --git a/src/mlir/IR/Value.jl b/src/mlir/IR/Value.jl
index a24632d934..38c877f763 100644
--- a/src/mlir/IR/Value.jl
+++ b/src/mlir/IR/Value.jl
@@ -121,3 +121,5 @@ function Base.show(io::IO, value::Value)
         API.mlirValuePrint(value, c_print_callback, ref)
     end
 end
+
+parent_region(value::Value) = parent_region(owner(value))