Skip to content

Commit

Permalink
Regenerate MLIR Bindings (EnzymeAD#410)
Browse files Browse the repository at this point in the history
Co-authored-by: mofeing <15837247+mofeing@users.noreply.github.com>
  • Loading branch information
github-actions[bot] and mofeing authored Dec 21, 2024
1 parent 52668ee commit 0bfc722
Show file tree
Hide file tree
Showing 6 changed files with 303 additions and 21 deletions.
107 changes: 98 additions & 9 deletions src/mlir/Dialects/Affine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,19 +82,38 @@ In the above example, `%indices:3` conceptually holds the following:
%indices_1 = affine.apply #map1()[%linear_index]
%indices_2 = affine.apply #map2()[%linear_index]
```
The basis may either contain `N` or `N-1` elements, where `N` is the number of results.
If there are N basis elements, the first one will not be used during computations,
but may be used during analysis and canonicalization to eliminate terms from
the `affine.delinearize_index` or to enable conclusions about the total size of
`%linear_index`.
If the basis is fully provided, the delinearize_index operation is said to \"have
an outer bound\". The builders assume that an `affine.delinearize_index` has
an outer bound by default, as this is how the operation was initially defined.
That is, the example above could also have been written
```mlir
%0:3 = affine.delinearize_index %linear_index into (244, 244) : index, index
```
Note that, due to the constraints of affine maps, all the basis elements must
be strictly positive. A dynamic basis element being 0 or negative causes
undefined behavior.
"""
function delinearize_index(
linear_index::Value,
basis::Vector{Value};
multi_index=nothing::Union{Nothing,Vector{IR.Type}},
dynamic_basis::Vector{Value};
multi_index::Vector{IR.Type},
static_basis,
location=Location(),
)
op_ty_results = IR.Type[]
operands = Value[linear_index, basis...]
op_ty_results = IR.Type[multi_index...,]
operands = Value[linear_index, dynamic_basis...]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[]
!isnothing(multi_index) && push!(op_ty_results, multi_index...)
attributes = NamedAttribute[namedattribute("static_basis", static_basis),]

return create_operation(
"affine.delinearize_index",
Expand All @@ -103,8 +122,8 @@ function delinearize_index(
owned_regions,
successors,
attributes,
results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
result_inference=(length(op_ty_results) == 0 ? true : false),
results=op_ty_results,
result_inference=false,
)
end

Expand Down Expand Up @@ -327,6 +346,7 @@ func.func @pad_edges(%I : memref<10x10xf32>) -> (memref<12x12xf32) {
function if_(
operand_0::Vector{Value};
results::Vector{IR.Type},
condition,
thenRegion::Region,
elseRegion::Region,
location=Location(),
Expand All @@ -335,7 +355,7 @@ function if_(
operands = Value[operand_0...,]
owned_regions = Region[thenRegion, elseRegion]
successors = Block[]
attributes = NamedAttribute[]
attributes = NamedAttribute[namedattribute("condition", condition),]

return create_operation(
"affine.if",
Expand All @@ -349,6 +369,75 @@ function if_(
)
end

"""
`linearize_index`
The `affine.linearize_index` operation takes a sequence of index values and a
basis of the same length and linearizes the indices using that basis.
That is, for indices `%idx_0` to `%idx_{N-1}` and basis elements `b_0`
(or `b_1`) up to `b_{N-1}` it computes
```
sum(i = 0 to N-1) %idx_i * product(j = i + 1 to N-1) B_j
```
The basis may either have `N` or `N-1` elements, where `N` is the number of
inputs to linearize_index. If `N` inputs are provided, the first one is not used
in computation, but may be used during analysis or canonicalization as a bound
on `%idx_0`.
If all `N` basis elements are provided, the linearize_index operation is said to
\"have an outer bound\".
If the `disjoint` property is present, this is an optimization hint that,
for all `i`, `0 <= %idx_i < B_i` - that is, no index affects any other index,
except that `%idx_0` may be negative to make the index as a whole negative.
Note that the outputs of `affine.delinearize_index` are, by definition, `disjoint`.
# Example
```mlir
%linear_index = affine.linearize_index [%index_0, %index_1, %index_2] by (2, 3, 5) : index
// Same effect
%linear_index = affine.linearize_index [%index_0, %index_1, %index_2] by (3, 5) : index
```
In the above example, `%linear_index` conceptually holds the following:
```mlir
#map = affine_map<()[s0, s1, s2] -> (s0 * 15 + s1 * 5 + s2)>
%linear_index = affine.apply #map()[%index_0, %index_1, %index_2]
```
"""
function linearize_index(
multi_index::Vector{Value},
dynamic_basis::Vector{Value};
linear_index=nothing::Union{Nothing,IR.Type},
static_basis,
location=Location(),
)
op_ty_results = IR.Type[]
operands = Value[multi_index..., dynamic_basis...]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[namedattribute("static_basis", static_basis),]
push!(attributes, operandsegmentsizes([length(multi_index), length(dynamic_basis)]))
!isnothing(linear_index) && push!(op_ty_results, linear_index)

return create_operation(
"affine.linearize_index",
location;
operands,
owned_regions,
successors,
attributes,
results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
result_inference=(length(op_ty_results) == 0 ? true : false),
)
end

"""
`load`
Expand Down
84 changes: 84 additions & 0 deletions src/mlir/Dialects/CHLO.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1397,6 +1397,60 @@ function polygamma(
)
end

"""
`ragged_dot`
This operation takes three tensor args---lhs, rhs, and group_sizes---and
a \"ragged_dot_dimension_numbers\" attribute. Like dot_general, the lhs and
rhs are allowed arbitrary batch and contracting dimensions. Additionally,
the lhs is required to have one ragged dimension, and the rhs may have at
most one group dimension. The op has three modes, depending on the kind of
the lhs ragged dimension.
In mode 1, the shape-signature is `[b,m,k], [g,b,k,n], [g] -> [b,m,n]`.
Here the ragged dimension is an lhs non-contracting dimension (`m`). The
dimensions `b` and `k` represent batch and contracting dimensions
respectively. The rhs is required to have a group dimension (`g`).
In mode 2, the shape-signature is `[b,m,k], [b,k,n], [g] -> [g,b,m,n]`.
Here the ragged dimension is an lhs/rhs contracting dimension (`k`).
In mode 3, the shape-signature is `[b,m,k], [b,k,n], [g] -> [b,m,n]`. Here
the ragged dimension is an lhs/rhs batch dimension (`b`).
"""
function ragged_dot(
lhs::Value,
rhs::Value,
group_sizes::Value;
result=nothing::Union{Nothing,IR.Type},
ragged_dot_dimension_numbers,
precision_config=nothing,
location=Location(),
)
op_ty_results = IR.Type[]
operands = Value[lhs, rhs, group_sizes]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[namedattribute(
"ragged_dot_dimension_numbers", ragged_dot_dimension_numbers
),]
!isnothing(result) && push!(op_ty_results, result)
!isnothing(precision_config) &&
push!(attributes, namedattribute("precision_config", precision_config))

return create_operation(
"chlo.ragged_dot",
location;
operands,
owned_regions,
successors,
attributes,
results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
result_inference=(length(op_ty_results) == 0 ? true : false),
)
end

"""
`sinh`
Expand Down Expand Up @@ -1427,6 +1481,36 @@ function sinh(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=L
)
end

"""
`square`
Returns `Square(operand)` element-wise.
\$\$
\\square(x) = complex((x.real - x.imag) * (x.real + x.imag), x.real * x.imag * 2) if x is a complex number
= x * x otherwise
\$\$
"""
function square(operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location())
op_ty_results = IR.Type[]
operands = Value[operand,]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[]
!isnothing(result) && push!(op_ty_results, result)

return create_operation(
"chlo.square",
location;
operands,
owned_regions,
successors,
attributes,
results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
result_inference=(length(op_ty_results) == 0 ? true : false),
)
end

"""
`tan`
Expand Down
4 changes: 4 additions & 0 deletions src/mlir/Dialects/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ function autodiff(
fn,
activity,
ret_activity,
width=nothing,
location=Location(),
)
op_ty_results = IR.Type[outputs...,]
Expand All @@ -54,6 +55,7 @@ function autodiff(
namedattribute("activity", activity),
namedattribute("ret_activity", ret_activity),
]
!isnothing(width) && push!(attributes, namedattribute("width", width))

return create_operation(
"enzyme.autodiff",
Expand Down Expand Up @@ -96,6 +98,7 @@ function fwddiff(
fn,
activity,
ret_activity,
width=nothing,
location=Location(),
)
op_ty_results = IR.Type[outputs...,]
Expand All @@ -107,6 +110,7 @@ function fwddiff(
namedattribute("activity", activity),
namedattribute("ret_activity", ret_activity),
]
!isnothing(width) && push!(attributes, namedattribute("width", width))

return create_operation(
"enzyme.fwddiff",
Expand Down
59 changes: 59 additions & 0 deletions src/mlir/Dialects/EnzymeXLA.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
module enzymexla
using ...IR
import ...IR:
NamedAttribute,
Value,
Location,
Block,
Region,
Attribute,
create_operation,
context,
IndexType
import ..Dialects: namedattribute, operandsegmentsizes
import ...API

function kernel_call(
gridx::Value,
gridy::Value,
gridz::Value,
blockx::Value,
blocky::Value,
blockz::Value,
shmem::Value,
inputs::Vector{Value};
result_0::Vector{IR.Type},
fn,
backend_config=nothing,
operand_layouts=nothing,
result_layouts=nothing,
output_operand_aliases=nothing,
location=Location(),
)
op_ty_results = IR.Type[result_0...,]
operands = Value[gridx, gridy, gridz, blockx, blocky, blockz, shmem, inputs...]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[namedattribute("fn", fn),]
!isnothing(backend_config) &&
push!(attributes, namedattribute("backend_config", backend_config))
!isnothing(operand_layouts) &&
push!(attributes, namedattribute("operand_layouts", operand_layouts))
!isnothing(result_layouts) &&
push!(attributes, namedattribute("result_layouts", result_layouts))
!isnothing(output_operand_aliases) &&
push!(attributes, namedattribute("output_operand_aliases", output_operand_aliases))

return create_operation(
"enzymexla.kernel_call",
location;
operands,
owned_regions,
successors,
attributes,
results=op_ty_results,
result_inference=false,
)
end

end # enzymexla
9 changes: 8 additions & 1 deletion src/mlir/Dialects/Func.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,18 @@ symbol reference attribute named \"callee\".
```
"""
function call(
operands::Vector{Value}; result_0::Vector{IR.Type}, callee, location=Location()
operands::Vector{Value};
result_0::Vector{IR.Type},
callee,
no_inline=nothing,
location=Location(),
)
op_ty_results = IR.Type[result_0...,]
operands = Value[operands...,]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[namedattribute("callee", callee),]
!isnothing(no_inline) && push!(attributes, namedattribute("no_inline", no_inline))

return create_operation(
"func.call",
Expand Down Expand Up @@ -174,6 +179,7 @@ function func_(;
sym_visibility=nothing,
arg_attrs=nothing,
res_attrs=nothing,
no_inline=nothing,
body::Region,
location=Location(),
)
Expand All @@ -188,6 +194,7 @@ function func_(;
push!(attributes, namedattribute("sym_visibility", sym_visibility))
!isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs))
!isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs))
!isnothing(no_inline) && push!(attributes, namedattribute("no_inline", no_inline))

return create_operation(
"func.func",
Expand Down
Loading

0 comments on commit 0bfc722

Please sign in to comment.