Skip to content

Commit

Permalink
Merge branch 'main' into emilien/effect-value
Browse files Browse the repository at this point in the history
  • Loading branch information
PapyChacal authored Aug 9, 2024
2 parents 9ed8888 + fab70a7 commit 3fca062
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 5 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ dev = [
"nbval<0.12",
"filecheck==1.0.0",
"lit<19.0.0",
"marimo==0.7.18",
"marimo==0.7.19",
"pre-commit==3.8.0",
"ruff==0.5.6",
"ruff==0.5.7",
"asv<0.7",
"isort==5.13.2",
"nbconvert>=7.7.2,<8.0.0",
Expand Down
31 changes: 30 additions & 1 deletion tests/filecheck/dialects/linalg/linalg_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@ linalg.matmul {id} ins(%2, %3 : memref<64x9216xf32>, memref<9216x4096xf32>) outs
%fill = linalg.fill ins(%0 : f32) outs(%t3 : tensor<4x16xf32>) -> tensor<4x16xf32>
linalg.fill ins(%0 : f32) outs(%m3 : memref<4x16xf32>)

%5, %6 = "test.op"() : () -> (tensor<64x9216xi8>, tensor<9216x4096xi8>)
%7 = arith.constant 0 : i32
%8 = arith.constant 0 : i32
%9 = "test.op"() : () -> (tensor<64x4096xi32>)

linalg.quantized_matmul ins(%5, %6, %7, %8 : tensor<64x9216xi8>, tensor<9216x4096xi8>, i32, i32) outs(%9 : tensor<64x4096xi32>) -> tensor<64x4096xi32>


// CHECK: module {
// CHECK-NEXT: %{{.*}}, %{{.*}} = "test.op"() : () -> (f32, memref<1x256xf32>)
// CHECK-NEXT: linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%{{.*}} : f32) outs(%{{.*}} : memref<1x256xf32>) {
Expand All @@ -60,7 +68,12 @@ linalg.fill ins(%0 : f32) outs(%m3 : memref<4x16xf32>)
// CHECK-NEXT: linalg.matmul {"id"} ins(%{{.*}}, %{{.*}} : memref<64x9216xf32>, memref<9216x4096xf32>) outs(%{{.*}} : memref<64x4096xf32>)
// CHECK-NEXT: %{{.*}} = linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<4x16xf32>) -> tensor<4x16xf32>
// CHECK-NEXT: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : memref<4x16xf32>)
// CHECK-NEXT: }
// CHECK-NEXT: %5, %6 = "test.op"() : () -> (tensor<64x9216xi8>, tensor<9216x4096xi8>)
// CHECK-NEXT: %7 = arith.constant 0 : i32
// CHECK-NEXT: %8 = arith.constant 0 : i32
// CHECK-NEXT: %9 = "test.op"() : () -> tensor<64x4096xi32>
// CHECK-NEXT: linalg.quantized_matmul ins(%5, %6, %7, %8 : tensor<64x9216xi8>, tensor<9216x4096xi8>, i32, i32) outs(%9 : tensor<64x4096xi32>) -> tensor<64x4096xi32>
// CHECK-NEXT: }

// CHECK-GENERIC: "linalg.generic"(%{{.*}} %{{.*}} <{"indexing_maps" = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>], "operandSegmentSizes" = array<i32: 1, 1>}> ({
// CHECK-GENERIC-NEXT: ^0(%{{.*}} : f32, %{{.*}} : f32):
Expand Down Expand Up @@ -122,3 +135,19 @@ linalg.fill ins(%0 : f32) outs(%m3 : memref<4x16xf32>)
// CHECK-GENERIC-NEXT: ^9(%{{.*}} : f32, %{{.*}} : f32):
// CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}}) : (f32) -> ()
// CHECK-GENERIC-NEXT: }) : (f32, memref<4x16xf32>) -> ()

// CHECK-GENERIC-NEXT: %{{.*}}, %{{.*}} = "test.op"() : () -> (tensor<64x9216xi8>, tensor<9216x4096xi8>)
// CHECK-GENERIC-NEXT: %{{.*}} = "arith.constant"() <{"value" = 0 : i32}> : () -> i32
// CHECK-GENERIC-NEXT: %{{.*}} = "arith.constant"() <{"value" = 0 : i32}> : () -> i32
// CHECK-GENERIC-NEXT: %{{.*}} = "test.op"() : () -> tensor<64x4096xi32>

// CHECK-GENERIC-NEXT: %{{.*}} = "linalg.quantized_matmul"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) <{"operandSegmentSizes" = array<i32: 4, 1>}> ({
// CHECK-GENERIC-NEXT: ^10(%36 : i8, %37 : i8, %38 : i32, %39 : i32, %40 : i32):
// CHECK-GENERIC-NEXT: %41 = "arith.extsi"(%36) : (i8) -> i32
// CHECK-GENERIC-NEXT: %42 = "arith.subi"(%41, %38) : (i32, i32) -> i32
// CHECK-GENERIC-NEXT: %43 = "arith.extsi"(%37) : (i8) -> i32
// CHECK-GENERIC-NEXT: %44 = "arith.subi"(%43, %39) : (i32, i32) -> i32
// CHECK-GENERIC-NEXT: %45 = "arith.muli"(%42, %44) : (i32, i32) -> i32
// CHECK-GENERIC-NEXT: %46 = "arith.addi"(%40, %45) : (i32, i32) -> i32
// CHECK-GENERIC-NEXT: "linalg.yield"(%46) : (i32) -> ()
// CHECK-GENERIC-NEXT: }) {"linalg.memoized_indexing_maps" = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]} : (tensor<64x9216xi8>, tensor<9216x4096xi8>, i32, i32, tensor<64x4096xi32>) -> tensor<64x4096xi32>
17 changes: 15 additions & 2 deletions tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ linalg.fill ins(%4 : f32) outs(%1 : memref<1x256xf32>)

linalg.matmul {id} ins(%18, %19 : memref<64x9216xf32>, memref<9216x4096xf32>) outs(%20 : memref<64x4096xf32>)


%21, %22 = "test.op"() : () -> (tensor<64x9216xi8>, tensor<9216x4096xi8>)
%23 = arith.constant 0 : i32
%24 = arith.constant 0 : i32
%25 = "test.op"() : () -> (tensor<64x4096xi32>)

%quant_mat_mul = linalg.quantized_matmul ins(%21, %22, %23, %24 : tensor<64x9216xi8>, tensor<9216x4096xi8>, i32, i32) outs(%25 : tensor<64x4096xi32>) -> tensor<64x4096xi32>

// CHECK-NEXT: #map = affine_map<(d0, d1) -> ()>
// CHECK-NEXT: #map1 = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-NEXT: module {
Expand Down Expand Up @@ -91,13 +99,18 @@ linalg.matmul {id} ins(%18, %19 : memref<64x9216xf32>, memref<9216x4096xf32>) ou
// CHECK-NEXT: %13:2 = "test.op"() : () -> (tensor<16xf32>, tensor<16x64xf32>)
// CHECK-NEXT: %broadcasted = linalg.broadcast ins(%13#0 : tensor<16xf32>) outs(%13#1 : tensor<16x64xf32>) dimensions = [1]
// CHECK-NEXT: %{{.*}} = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%1#0, %1#0 : tensor<2x3xf32>, tensor<2x3xf32>) outs(%1#0 : tensor<2x3xf32>) {
// CHECK-NEXT: ^bb0(%in: f32, %in_0: f32, %out: f32):
// CHECK-NEXT: %{{.*}} = arith.addf %in, %in_0 : f32
// CHECK-NEXT: ^bb0(%in: f32, %in_1: f32, %out: f32):
// CHECK-NEXT: %{{.*}} = arith.addf %in, %in_1 : f32
// CHECK-NEXT: linalg.yield %{{.*}} : f32
// CHECK-NEXT: } -> tensor<2x3xf32>
// CHECK-NEXT: %{{.*}} = linalg.sub ins(%{{.*}}, %{{.*}} : tensor<2x3xf32>, tensor<2x3xf32>) outs(%{{.*}} : tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK-NEXT: %16:2 = "test.op"() : () -> (memref<64x9216xf32>, memref<9216x4096xf32>)
// CHECK-NEXT: %17 = "test.op"() : () -> memref<64x4096xf32>
// CHECK-NEXT: linalg.matmul {id} ins(%16#0, %16#1 : memref<64x9216xf32>, memref<9216x4096xf32>) outs(%17 : memref<64x4096xf32>)
// CHECK-NEXT: %18:2 = "test.op"() : () -> (tensor<64x9216xi8>, tensor<9216x4096xi8>)
// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32
// CHECK-NEXT: %c0_i32_0 = arith.constant 0 : i32
// CHECK-NEXT: %19 = "test.op"() : () -> tensor<64x4096xi32>
// CHECK-NEXT: %20 = linalg.quantized_matmul ins(%18#0, %18#1, %c0_i32, %c0_i32_0 : tensor<64x9216xi8>, tensor<9216x4096xi8>, i32, i32) outs(%19 : tensor<64x4096xi32>) -> tensor<64x4096xi32>
// CHECK-NEXT: }

63 changes: 63 additions & 0 deletions xdsl/dialects/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,68 @@ def hidden_region(args: tuple[BlockArgument, ...]) -> None:
)


@irdl_op_definition
class QuantizedMatmulOp(NamedOpBase):
    """
    Performs a matrix multiplication of two 2D inputs.
    See https://mlir.llvm.org/docs/Dialects/Linalg/#linalgquantized_matmul-linalgquantizedmatmulop
    """

    name = "linalg.quantized_matmul"

    # Named-op convention: attributes print before the ins/outs clause.
    PRINT_ATTRS_IN_FRONT: ClassVar[bool] = True

    def __init__(
        self,
        inputs: Sequence[SSAValue],
        outputs: Sequence[SSAValue] = (),
        res: Sequence[Attribute] | None = None,
        attributes: dict[str, Attribute] | None = None,
    ):
        # When no explicit result types are given, every tensor-typed
        # output contributes one result (memref outputs produce none).
        if res is not None:
            result_types = res
        else:
            result_types = tuple(
                cast(AnyTensorType, out.type)
                for out in outputs
                if isinstance(out.type, TensorType)
            )

        arg_types = self.body_arg_types((*inputs, *outputs))

        # Scalar body: acc += (sext(a) - a_zero_point) * (sext(b) - b_zero_point)
        @Builder.implicit_region(arg_types)
        def hidden_region(args: tuple[BlockArgument, ...]) -> None:
            lhs_wide = arith.ExtSIOp(args[0], IntegerType(32))
            lhs_shifted = arith.Subi(lhs_wide, args[2])
            rhs_wide = arith.ExtSIOp(args[1], IntegerType(32))
            rhs_shifted = arith.Subi(rhs_wide, args[3])
            product = arith.Muli(lhs_shifted, rhs_shifted)
            accumulated = arith.Addi(args[4], product)
            YieldOp(accumulated)

        # Attach the canonical matmul indexing maps unless the caller
        # already supplied linalg.memoized_indexing_maps.
        attributes = attributes or {}
        if "linalg.memoized_indexing_maps" not in attributes:
            attributes["linalg.memoized_indexing_maps"] = ArrayAttr(
                [
                    AffineMapAttr(AffineMap.from_callable(lambda m, _, k: (m, k))),
                    AffineMapAttr(AffineMap.from_callable(lambda _, n, k: (k, n))),
                    AffineMapAttr(AffineMap.from_callable(lambda m, n, _: (m, n))),
                ]
            )

        super().__init__(
            ins=inputs,
            outs=outputs,
            result_types=result_types,
            attributes=attributes,
            hidden_region=hidden_region,
        )


class PoolingOpsBase(IRDLOperation, ABC):
"""Base class for linalg pooling operations."""

Expand Down Expand Up @@ -1111,6 +1173,7 @@ def parse(cls, parser: Parser) -> Self:
MulOp,
TransposeOp,
MatmulOp,
QuantizedMatmulOp,
PoolingNchwMaxOp,
Conv2DNchwFchwOp,
BroadcastOp,
Expand Down

0 comments on commit 3fca062

Please sign in to comment.