From 0bfc7226638e6fa2cd44d664facdc689e0a21d83 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Sat, 21 Dec 2024 14:20:59 -0500 Subject: [PATCH] Regenerate MLIR Bindings (#410) Co-authored-by: mofeing <15837247+mofeing@users.noreply.github.com> --- src/mlir/Dialects/Affine.jl | 107 ++++++++++++++++++++++++++++++--- src/mlir/Dialects/CHLO.jl | 84 ++++++++++++++++++++++++++ src/mlir/Dialects/Enzyme.jl | 4 ++ src/mlir/Dialects/EnzymeXLA.jl | 59 ++++++++++++++++++ src/mlir/Dialects/Func.jl | 9 ++- src/mlir/libMLIR_h.jl | 61 +++++++++++++++---- 6 files changed, 303 insertions(+), 21 deletions(-) create mode 100644 src/mlir/Dialects/EnzymeXLA.jl diff --git a/src/mlir/Dialects/Affine.jl b/src/mlir/Dialects/Affine.jl index 73332b437..c3774d385 100755 --- a/src/mlir/Dialects/Affine.jl +++ b/src/mlir/Dialects/Affine.jl @@ -82,19 +82,38 @@ In the above example, `%indices:3` conceptually holds the following: %indices_1 = affine.apply #map1()[%linear_index] %indices_2 = affine.apply #map2()[%linear_index] ``` + +The basis may either contain `N` or `N-1` elements, where `N` is the number of results. +If there are N basis elements, the first one will not be used during computations, +but may be used during analysis and canonicalization to eliminate terms from +the `affine.delinearize_index` or to enable conclusions about the total size of +`%linear_index`. + +If the basis is fully provided, the delinearize_index operation is said to \"have +an outer bound\". The builders assume that an `affine.delinearize_index` has +an outer bound by default, as this is how the operation was initially defined. + +That is, the example above could also have been written +```mlir +%0:3 = affine.delinearize_index %linear_index into (244, 244) : index, index +``` + +Note that, due to the constraints of affine maps, all the basis elements must +be strictly positive. A dynamic basis element being 0 or negative causes +undefined behavior. """ function delinearize_index( linear_index::Value, - basis::Vector{Value}; - multi_index=nothing::Union{Nothing,Vector{IR.Type}}, + dynamic_basis::Vector{Value}; + multi_index::Vector{IR.Type}, + static_basis, location=Location(), ) - op_ty_results = IR.Type[] - operands = Value[linear_index, basis...] + op_ty_results = IR.Type[multi_index...,] + operands = Value[linear_index, dynamic_basis...] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[] - !isnothing(multi_index) && push!(op_ty_results, multi_index...) + attributes = NamedAttribute[namedattribute("static_basis", static_basis),] return create_operation( "affine.delinearize_index", @@ -103,8 +122,8 @@ function delinearize_index( owned_regions, successors, attributes, - results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + results=op_ty_results, + result_inference=false, ) end @@ -327,6 +346,7 @@ func.func @pad_edges(%I : memref<10x10xf32>) -> (memref<12x12xf32) { function if_( operand_0::Vector{Value}; results::Vector{IR.Type}, + condition, thenRegion::Region, elseRegion::Region, location=Location(), @@ -335,7 +355,7 @@ function if_( operands = Value[operand_0...,] owned_regions = Region[thenRegion, elseRegion] successors = Block[] - attributes = NamedAttribute[] + attributes = NamedAttribute[namedattribute("condition", condition),] return create_operation( "affine.if", @@ -349,6 +369,75 @@ function if_( ) end +""" +`linearize_index` + +The `affine.linearize_index` operation takes a sequence of index values and a +basis of the same length and linearizes the indices using that basis. + +That is, for indices `%idx_0` to `%idx_{N-1}` and basis elements `b_0` +(or `b_1`) up to `b_{N-1}` it computes + +``` +sum(i = 0 to N-1) %idx_i * product(j = i + 1 to N-1) B_j +``` + +The basis may either have `N` or `N-1` elements, where `N` is the number of +inputs to linearize_index. If `N` inputs are provided, the first one is not used +in computation, but may be used during analysis or canonicalization as a bound +on `%idx_0`. + +If all `N` basis elements are provided, the linearize_index operation is said to +\"have an outer bound\". + +If the `disjoint` property is present, this is an optimization hint that, +for all `i`, `0 <= %idx_i < B_i` - that is, no index affects any other index, +except that `%idx_0` may be negative to make the index as a whole negative. + +Note that the outputs of `affine.delinearize_index` are, by definition, `disjoint`. + +# Example + +```mlir +%linear_index = affine.linearize_index [%index_0, %index_1, %index_2] by (2, 3, 5) : index +// Same effect +%linear_index = affine.linearize_index [%index_0, %index_1, %index_2] by (3, 5) : index +``` + +In the above example, `%linear_index` conceptually holds the following: + +```mlir +#map = affine_map<()[s0, s1, s2] -> (s0 * 15 + s1 * 5 + s2)> +%linear_index = affine.apply #map()[%index_0, %index_1, %index_2] +``` +""" +function linearize_index( + multi_index::Vector{Value}, + dynamic_basis::Vector{Value}; + linear_index=nothing::Union{Nothing,IR.Type}, + static_basis, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[multi_index..., dynamic_basis...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("static_basis", static_basis),] + push!(attributes, operandsegmentsizes([length(multi_index), length(dynamic_basis)])) + !isnothing(linear_index) && push!(op_ty_results, linear_index) + + return create_operation( + "affine.linearize_index", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + """ `load` diff --git a/src/mlir/Dialects/CHLO.jl b/src/mlir/Dialects/CHLO.jl index 61faed2e5..7696a6556 100755 --- a/src/mlir/Dialects/CHLO.jl +++ b/src/mlir/Dialects/CHLO.jl @@ -1397,6 +1397,60 @@ function polygamma( ) end +""" +`ragged_dot` + + +This operation takes three tensor args---lhs, rhs, and group_sizes---and +a \"ragged_dot_dimension_numbers\" attribute. Like dot_general, the lhs and +rhs are allowed arbitrary batch and contracting dimensions. Additionally, +the lhs is required to have one ragged dimension, and the rhs may have at +most one group dimension. The op has three modes, depending on the kind of +the lhs ragged dimension. + +In mode 1, the shape-signature is `[b,m,k], [g,b,k,n], [g] -> [b,m,n]`. +Here the ragged dimension is an lhs non-contracting dimension (`m`). The +dimensions `b` and `k` represent batch and contracting dimensions +respectively. The rhs is required to have a group dimension (`g`). + +In mode 2, the shape-signature is `[b,m,k], [b,k,n], [g] -> [g,b,m,n]`. +Here the ragged dimension is an lhs/rhs contracting dimension (`k`). + +In mode 3, the shape-signature is `[b,m,k], [b,k,n], [g] -> [b,m,n]`. Here +the ragged dimension is an lhs/rhs batch dimension (`b`). +""" +function ragged_dot( + lhs::Value, + rhs::Value, + group_sizes::Value; + result=nothing::Union{Nothing,IR.Type}, + ragged_dot_dimension_numbers, + precision_config=nothing, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[lhs, rhs, group_sizes] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute( + "ragged_dot_dimension_numbers", ragged_dot_dimension_numbers + ),] + !isnothing(result) && push!(op_ty_results, result) + !isnothing(precision_config) && + push!(attributes, namedattribute("precision_config", precision_config)) + + return create_operation( + "chlo.ragged_dot", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + """ `sinh` @@ -1427,6 +1481,36 @@ function sinh(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=L ) end +""" +`square` + +Returns `Square(operand)` element-wise. + +\$\$ +\\square(x) = complex((x.real - x.imag) * (x.real + x.imag), x.real * x.imag * 2) if x is a complex number + = x * x otherwise +\$\$ +""" +function square(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) + op_ty_results = IR.Type[] + operands = Value[operand,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(result) && push!(op_ty_results, result) + + return create_operation( + "chlo.square", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + """ `tan` diff --git a/src/mlir/Dialects/Enzyme.jl b/src/mlir/Dialects/Enzyme.jl index f705808c6..9ebd8211b 100755 --- a/src/mlir/Dialects/Enzyme.jl +++ b/src/mlir/Dialects/Enzyme.jl @@ -43,6 +43,7 @@ function autodiff( fn, activity, ret_activity, + width=nothing, location=Location(), ) op_ty_results = IR.Type[outputs...,] @@ -54,6 +55,7 @@ function autodiff( namedattribute("activity", activity), namedattribute("ret_activity", ret_activity), ] + !isnothing(width) && push!(attributes, namedattribute("width", width)) return create_operation( "enzyme.autodiff", @@ -96,6 +98,7 @@ function fwddiff( fn, activity, ret_activity, + width=nothing, location=Location(), ) op_ty_results = IR.Type[outputs...,] @@ -107,6 +110,7 @@ function fwddiff( namedattribute("activity", activity), namedattribute("ret_activity", ret_activity), ] + !isnothing(width) && push!(attributes, namedattribute("width", width)) return create_operation( "enzyme.fwddiff", diff --git a/src/mlir/Dialects/EnzymeXLA.jl b/src/mlir/Dialects/EnzymeXLA.jl new file mode 100644 index 000000000..0ee73ded4 --- /dev/null +++ b/src/mlir/Dialects/EnzymeXLA.jl @@ -0,0 +1,59 @@ +module enzymexla +using ...IR +import ...IR: + NamedAttribute, + Value, + Location, + Block, + Region, + Attribute, + create_operation, + context, + IndexType +import ..Dialects: namedattribute, operandsegmentsizes +import ...API + +function kernel_call( + gridx::Value, + gridy::Value, + gridz::Value, + blockx::Value, + blocky::Value, + blockz::Value, + shmem::Value, + inputs::Vector{Value}; + result_0::Vector{IR.Type}, + fn, + backend_config=nothing, + operand_layouts=nothing, + result_layouts=nothing, + output_operand_aliases=nothing, + location=Location(), +) + op_ty_results = IR.Type[result_0...,] + operands = Value[gridx, gridy, gridz, blockx, blocky, blockz, shmem, inputs...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("fn", fn),] + !isnothing(backend_config) && + push!(attributes, namedattribute("backend_config", backend_config)) + !isnothing(operand_layouts) && + push!(attributes, namedattribute("operand_layouts", operand_layouts)) + !isnothing(result_layouts) && + push!(attributes, namedattribute("result_layouts", result_layouts)) + !isnothing(output_operand_aliases) && + push!(attributes, namedattribute("output_operand_aliases", output_operand_aliases)) + + return create_operation( + "enzymexla.kernel_call", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +end # enzymexla diff --git a/src/mlir/Dialects/Func.jl b/src/mlir/Dialects/Func.jl index 82f5aefb3..6eb30523c 100755 --- a/src/mlir/Dialects/Func.jl +++ b/src/mlir/Dialects/Func.jl @@ -69,13 +69,18 @@ symbol reference attribute named \"callee\". ``` """ function call( - operands::Vector{Value}; result_0::Vector{IR.Type}, callee, location=Location() + operands::Vector{Value}; + result_0::Vector{IR.Type}, + callee, + no_inline=nothing, + location=Location(), ) op_ty_results = IR.Type[result_0...,] operands = Value[operands...,] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[namedattribute("callee", callee),] + !isnothing(no_inline) && push!(attributes, namedattribute("no_inline", no_inline)) return create_operation( "func.call", @@ -174,6 +179,7 @@ function func_(; sym_visibility=nothing, arg_attrs=nothing, res_attrs=nothing, + no_inline=nothing, body::Region, location=Location(), ) @@ -188,6 +194,7 @@ function func_(; push!(attributes, namedattribute("sym_visibility", sym_visibility)) !isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs)) !isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs)) + !isnothing(no_inline) && push!(attributes, namedattribute("no_inline", no_inline)) return create_operation( "func.func", diff --git a/src/mlir/libMLIR_h.jl b/src/mlir/libMLIR_h.jl index 59e6ed3f1..5f5c0feeb 100644 --- a/src/mlir/libMLIR_h.jl +++ b/src/mlir/libMLIR_h.jl @@ -1972,6 +1972,20 @@ function mlirValueReplaceAllUsesOfWith(of, with) @ccall mlir_c.mlirValueReplaceAllUsesOfWith(of::MlirValue, with::MlirValue)::Cvoid end +""" + mlirValueReplaceAllUsesExcept(of, with, numExceptions, exceptions) + +Replace all uses of 'of' value with 'with' value, updating anything in the IR that uses 'of' to use 'with' instead, except if the user is listed in 'exceptions'. The 'exceptions' parameter is an array of [`MlirOperation`](@ref) pointers with a length of 'numExceptions'. +""" +function mlirValueReplaceAllUsesExcept(of, with, numExceptions, exceptions) + @ccall mlir_c.mlirValueReplaceAllUsesExcept( + of::MlirValue, + with::MlirValue, + numExceptions::intptr_t, + exceptions::Ptr{MlirOperation}, + )::Cvoid +end + """ mlirOpOperandIsNull(opOperand) @@ -5856,9 +5870,9 @@ function mlirPassManagerRunOnOp(passManager, op) end """ - mlirPassManagerEnableIRPrinting(passManager, printBeforeAll, printAfterAll, printModuleScope, printAfterOnlyOnChange, printAfterOnlyOnFailure) + mlirPassManagerEnableIRPrinting(passManager, printBeforeAll, printAfterAll, printModuleScope, printAfterOnlyOnChange, printAfterOnlyOnFailure, flags, treePrintingPath) -Enable IR printing. +Enable IR printing. The treePrintingPath argument is an optional path to a directory where the dumps will be produced. If it isn't provided then dumps are produced to stderr. """ function mlirPassManagerEnableIRPrinting( passManager, @@ -5867,6 +5881,8 @@ function mlirPassManagerEnableIRPrinting( printModuleScope, printAfterOnlyOnChange, printAfterOnlyOnFailure, + flags, + treePrintingPath, ) @ccall mlir_c.mlirPassManagerEnableIRPrinting( passManager::MlirPassManager, @@ -5875,6 +5891,8 @@ function mlirPassManagerEnableIRPrinting( printModuleScope::Bool, printAfterOnlyOnChange::Bool, printAfterOnlyOnFailure::Bool, + flags::MlirOpPrintingFlags, + treePrintingPath::MlirStringRef, )::Cvoid end @@ -6331,6 +6349,14 @@ function mlirRegisterConversionConvertMemRefToSPIRV() @ccall mlir_c.mlirRegisterConversionConvertMemRefToSPIRV()::Cvoid end +function mlirCreateConversionConvertMeshToMPIPass() + @ccall mlir_c.mlirCreateConversionConvertMeshToMPIPass()::MlirPass +end + +function mlirRegisterConversionConvertMeshToMPIPass() + @ccall mlir_c.mlirRegisterConversionConvertMeshToMPIPass()::Cvoid +end + function mlirCreateConversionConvertNVGPUToNVVMPass() @ccall mlir_c.mlirCreateConversionConvertNVGPUToNVVMPass()::MlirPass end @@ -6873,6 +6899,10 @@ function mlirGetDialectHandle__cf__() @ccall mlir_c.mlirGetDialectHandle__cf__()::MlirDialectHandle end +function mlirGetDialectHandle__emitc__() + @ccall mlir_c.mlirGetDialectHandle__emitc__()::MlirDialectHandle +end + function mlirGetDialectHandle__func__() @ccall mlir_c.mlirGetDialectHandle__func__()::MlirDialectHandle end @@ -7103,6 +7133,15 @@ function mlirLLVMArrayTypeGet(elementType, numElements) @ccall mlir_c.mlirLLVMArrayTypeGet(elementType::MlirType, numElements::Cuint)::MlirType end +""" + mlirLLVMArrayTypeGetElementType(type) + +Returns the element type of the llvm.array type. +""" +function mlirLLVMArrayTypeGetElementType(type) + @ccall mlir_c.mlirLLVMArrayTypeGetElementType(type::MlirType)::MlirType +end + """ mlirLLVMFunctionTypeGet(resultType, nArgumentTypes, argumentTypes, isVarArg) @@ -7331,17 +7370,17 @@ function mlirLLVMComdatAttrGet(ctx, comdat) end @cenum MlirLLVMLinkage::UInt32 begin - MlirLLVMLinkagePrivate = 0x0000000000000000 - MlirLLVMLinkageInternal = 0x0000000000000001 - MlirLLVMLinkageAvailableExternally = 0x0000000000000002 - MlirLLVMLinkageLinkonce = 0x0000000000000003 + MlirLLVMLinkageExternal = 0x0000000000000000 + MlirLLVMLinkageAvailableExternally = 0x0000000000000001 + MlirLLVMLinkageLinkonce = 0x0000000000000002 + MlirLLVMLinkageLinkonceODR = 0x0000000000000003 MlirLLVMLinkageWeak = 0x0000000000000004 - MlirLLVMLinkageCommon = 0x0000000000000005 + MlirLLVMLinkageWeakODR = 0x0000000000000005 MlirLLVMLinkageAppending = 0x0000000000000006 - MlirLLVMLinkageExternWeak = 0x0000000000000007 - MlirLLVMLinkageLinkonceODR = 0x0000000000000008 - MlirLLVMLinkageWeakODR = 0x0000000000000009 - MlirLLVMLinkageExternal = 0x000000000000000a + MlirLLVMLinkageInternal = 0x0000000000000007 + MlirLLVMLinkagePrivate = 0x0000000000000008 + MlirLLVMLinkageExternWeak = 0x0000000000000009 + MlirLLVMLinkageCommon = 0x000000000000000a end """