Skip to content

[mlir][llvm dialect] Verify element type of nested types #148975

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 75 additions & 39 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Transforms/InliningUtils.h"

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
Expand Down Expand Up @@ -3187,6 +3188,18 @@ static int64_t getNumElements(Type t) {
return 1;
}

/// Determine the element type of `type`. Supported types are `VectorType`,
/// `TensorType`, and `LLVMArrayType`. Everything else is treated as a scalar.
static Type getElementType(Type type) {
while (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(type))
type = arrayType.getElementType();
if (auto vecType = dyn_cast<VectorType>(type))
return vecType.getElementType();
if (auto tenType = dyn_cast<TensorType>(type))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: the TensorType can be dropped. It is not a compatible LLVM type (see isCompatibleOuterType) and since ConstantOp returns an LLVM type there should be no tensors here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the following fail to verify?

  %0 = llvm.mlir.constant(dense<[1.0, 1.0]> : tensor<2xf64>) : vector<2xf64>

It currently roundtrips without issue.

Maybe what is confusing here is that I'm using this method to test both the attribute's element type as well the constant op's result's element type?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would have expected that we use:

%0 = llvm.mlir.constant(dense<[1.0, 1.0]> : vector<2xf64>) : vector<2xf64>

But if the tensor type shows up in the wild on the attribute then we can keep your version of the code.

return tenType.getElementType();
return type;
}

/// Check if the given type is a scalable vector type or a vector/array type
/// that contains a nested scalable vector type.
static bool hasScalableVectorType(Type t) {
Expand Down Expand Up @@ -3281,60 +3294,69 @@ LogicalResult LLVM::ConstantOp::verify() {
}
if (auto structType = dyn_cast<LLVMStructType>(getType())) {
auto arrayAttr = dyn_cast<ArrayAttr>(getValue());
if (!arrayAttr) {
return emitOpError() << "expected array attribute for a struct constant";
}
if (!arrayAttr)
return emitOpError() << "expected array attribute for struct type";

ArrayRef<Type> elementTypes = structType.getBody();
if (arrayAttr.size() != elementTypes.size()) {
return emitOpError() << "expected array attribute of size "
<< elementTypes.size();
}
for (auto elementTy : elementTypes) {
if (!isa<IntegerType, FloatType, LLVMPPCFP128Type>(elementTy)) {
for (auto [i, attr, type] : llvm::enumerate(arrayAttr, elementTypes)) {
if (!type.isSignlessIntOrIndexOrFloat()) {
return emitOpError() << "expected struct element types to be floating "
"point type or integer type";
}
}

for (size_t i = 0; i < elementTypes.size(); ++i) {
Attribute element = arrayAttr[i];
if (!isa<IntegerAttr, FloatAttr>(element)) {
return emitOpError()
<< "expected struct element attribute types to be floating "
"point type or integer type";
if (!isa<FloatAttr, IntegerAttr>(attr)) {
return emitOpError() << "expected element of array attribute to be "
"floating point or integer";
}
auto elementType = cast<TypedAttr>(element).getType();
if (elementType != elementTypes[i]) {
if (cast<TypedAttr>(attr).getType() != type)
return emitOpError()
<< "struct element at index " << i << " is of wrong type";
}
}

return success();
}
if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType()))
return emitOpError() << "does not support target extension type.";
}

// Check that an attribute whose element type has floating point semantics
// `attributeFloatSemantics` is compatible with a type whose element type
// is `constantElementType`.
//
// Requirement is that either
// 1) They have identical floating point types.
// 2) `constantElementType` is an integer type of the same width as the float
// attribute. This is to support builtin MLIR float types without LLVM
// equivalents, see comments in getLLVMConstant for more details.
auto verifyFloatSemantics =
[this](const llvm::fltSemantics &attributeFloatSemantics,
Type constantElementType) -> LogicalResult {
if (auto floatType = dyn_cast<FloatType>(constantElementType)) {
if (&floatType.getFloatSemantics() != &attributeFloatSemantics) {
return emitOpError()
<< "attribute and type have different float semantics";
}
return success();
}
unsigned floatWidth = APFloat::getSizeInBits(attributeFloatSemantics);
if (isa<IntegerType>(constantElementType)) {
if (!constantElementType.isInteger(floatWidth))
return emitOpError() << "expected integer type of width " << floatWidth;

return success();
}
return success();
};

// Verification of IntegerAttr, FloatAttr, ElementsAttr, ArrayAttr.
if (auto intAttr = dyn_cast<IntegerAttr>(getValue())) {
if (isa<IntegerAttr>(getValue())) {
if (!llvm::isa<IntegerType>(getType()))
return emitOpError() << "expected integer type";
} else if (auto floatAttr = dyn_cast<FloatAttr>(getValue())) {
const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
unsigned floatWidth = APFloat::getSizeInBits(sem);
if (auto floatTy = dyn_cast<FloatType>(getType())) {
if (floatTy.getWidth() != floatWidth) {
return emitOpError() << "expected float type of width " << floatWidth;
}
}
// See the comment for getLLVMConstant for more details about why 8-bit
// floats can be represented by integers.
if (isa<IntegerType>(getType()) && !getType().isInteger(floatWidth)) {
return emitOpError() << "expected integer type of width " << floatWidth;
}
} else if (isa<ElementsAttr>(getValue())) {
return verifyFloatSemantics(floatAttr.getValue().getSemantics(), getType());
} else if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue())) {
if (hasScalableVectorType(getType())) {
// The exact number of elements of a scalable vector is unknown, so we
// allow only splat attributes.
Expand All @@ -3346,18 +3368,32 @@ LogicalResult LLVM::ConstantOp::verify() {
}
if (!isa<VectorType, LLVM::LLVMArrayType>(getType()))
return emitOpError() << "expected vector or array type";

// The number of elements of the attribute and the type must match.
if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue())) {
int64_t attrNumElements = elementsAttr.getNumElements();
if (getNumElements(getType()) != attrNumElements)
return emitOpError()
<< "type and attribute have a different number of elements: "
<< getNumElements(getType()) << " vs. " << attrNumElements;
int64_t attrNumElements = elementsAttr.getNumElements();
if (getNumElements(getType()) != attrNumElements) {
return emitOpError()
<< "type and attribute have a different number of elements: "
<< getNumElements(getType()) << " vs. " << attrNumElements;
}

Type attrElmType = getElementType(elementsAttr.getType());
Type resultElmType = getElementType(getType());
if (auto floatType = dyn_cast<FloatType>(attrElmType))
return verifyFloatSemantics(floatType.getFloatSemantics(), resultElmType);

if (isa<IntegerType>(attrElmType) && !isa<IntegerType>(resultElmType)) {
return emitOpError(
"expected integer element type for integer elements attribute");
}
} else if (auto arrayAttr = dyn_cast<ArrayAttr>(getValue())) {

// The case where the constant is LLVMStructType has already been handled.
auto arrayType = dyn_cast<LLVM::LLVMArrayType>(getType());
if (!arrayType)
return emitOpError() << "expected array type";
return emitOpError()
<< "expected array or struct type for array attribute";

// When the attribute is an ArrayAttr, check that its nesting matches the
// corresponding ArrayType or VectorType nesting.
return verifyStructArrayConstant(*this, arrayType, arrayAttr, /*dim=*/0);
Expand Down
111 changes: 108 additions & 3 deletions mlir/test/Dialect/LLVMIR/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ llvm.func @array_attribute_two_different_types() -> !llvm.struct<(f64, f64)> {
// -----

llvm.func @struct_wrong_attribute_type() -> !llvm.struct<(f64, f64)> {
// expected-error @+1 {{expected array attribute}}
// expected-error @+1 {{expected array attribute for struct type}}
%0 = llvm.mlir.constant(1.0 : f64) : !llvm.struct<(f64, f64)>
llvm.return %0 : !llvm.struct<(f64, f64)>
}
Expand Down Expand Up @@ -439,6 +439,111 @@ llvm.func @scalable_vec_requires_splat() -> vector<[4]xf64> {
llvm.return %0 : vector<[4]xf64>
}


// -----

llvm.func @int_attr_requires_int_type() -> f32 {
// expected-error @below{{expected integer type}}
%0 = llvm.mlir.constant(1 : index) : f32
llvm.return %0 : f32
}

// -----

llvm.func @vector_int_attr_requires_int_type() -> vector<2xf32> {
// expected-error @below{{expected integer element type}}
%0 = llvm.mlir.constant(dense<[1, 2]> : vector<2xi32>) : vector<2xf32>
llvm.return %0 : vector<2xf32>
}

// -----

llvm.func @float_attr_and_type_required_same() -> f16 {
// expected-error @below{{attribute and type have different float semantics}}
%cst = llvm.mlir.constant(1.0 : bf16) : f16
llvm.return %cst : f16
}

// -----

llvm.func @vector_float_attr_and_type_required_same() -> vector<2xf16> {
// expected-error @below{{attribute and type have different float semantics}}
%cst = llvm.mlir.constant(dense<[1.0, 2.0]> : vector<2xbf16>) : vector<2xf16>
llvm.return %cst : vector<2xf16>
}

// -----

llvm.func @incompatible_integer_type_for_float_attr() -> i32 {
// expected-error @below{{expected integer type of width 16}}
%cst = llvm.mlir.constant(1.0 : f16) : i32
llvm.return %cst : i32
}

// -----

llvm.func @vector_incompatible_integer_type_for_float_attr() -> vector<2xi8> {
// expected-error @below{{expected integer type of width 16}}
%cst = llvm.mlir.constant(dense<[1.0, 2.0]> : vector<2xf16>) : vector<2xi8>
llvm.return %cst : vector<2xi8>
}

// -----

llvm.func @vector_with_non_vector_type() -> f32 {
// expected-error @below{{expected vector or array type}}
%cst = llvm.mlir.constant(dense<100.0> : vector<1xf64>) : f32
llvm.return %cst : f32
}

// -----

llvm.func @array_attr_with_invalid_type() -> i32 {
// expected-error @below{{expected array or struct type for array attribute}}
%0 = llvm.mlir.constant([1 : i32]) : i32
llvm.return %0 : i32
}

// -----

llvm.func @elements_attribute_incompatible_nested_array_struct1_type() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
// expected-error @below{{expected integer element type for integer elements attribute}}
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
}

// -----

llvm.func @elements_attribute_incompatible_nested_array_struct3_type() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
// expected-error @below{{expected integer element type for integer elements attribute}}
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
}

// -----

llvm.func @invalid_struct_element_type() -> !llvm.struct<(f64, array<2 x i32>)> {
// expected-error @below{{expected struct element types to be floating point type or integer type}}
%0 = llvm.mlir.constant([1.0 : f64, dense<[1, 2]> : tensor<2xi32>]) : !llvm.struct<(f64, array<2 x i32>)>
llvm.return %0 : !llvm.struct<(f64, array<2 x i32>)>
}

// -----

llvm.func @wrong_struct_element_attr_type() -> !llvm.struct<(f64, f64)> {
// expected-error @below{{expected element of array attribute to be floating point or integer}}
%0 = llvm.mlir.constant([dense<[1, 2]> : tensor<2xi32>, 2.0 : f64]) : !llvm.struct<(f64, f64)>
llvm.return %0 : !llvm.struct<(f64, f64)>
}

// -----

llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
// expected-error @below{{struct element at index 0 is of wrong type}}
%0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>
llvm.return %0 : !llvm.struct<(f64, f64)>
}

// -----
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// -----

nit: one split can be dropped


func.func @insertvalue_non_llvm_type(%a : i32, %b : i32) {
Expand Down Expand Up @@ -484,13 +589,13 @@ func.func @extractvalue_invalid_type(%a : !llvm.array<4 x vector<8xf32>>) -> !ll
return %b : !llvm.array<4 x vector<8xf32>>
}


// -----

func.func @extractvalue_non_llvm_type(%a : i32, %b : tensor<*xi32>) {
// expected-error@+2 {{expected LLVM IR Dialect type}}
llvm.extractvalue %b[0] : tensor<*xi32>
}

// -----

func.func @extractvalue_struct_out_of_bounds() {
Expand Down Expand Up @@ -659,6 +764,7 @@ func.func @atomicrmw_scalable_vector(%ptr : !llvm.ptr, %f32_vec : vector<[2]xf32
%0 = llvm.atomicrmw fadd %ptr, %f32_vec unordered : !llvm.ptr, vector<[2]xf32>
llvm.return
}

// -----

func.func @atomicrmw_vector_expected_float(%ptr : !llvm.ptr, %i32_vec : vector<3xi32>) {
Expand Down Expand Up @@ -1667,7 +1773,6 @@ func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !
return
}


// -----

func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
Expand Down
72 changes: 0 additions & 72 deletions mlir/test/Target/LLVMIR/llvmir-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -7,78 +7,6 @@ func.func @foo() {

// -----

llvm.func @vector_with_non_vector_type() -> f32 {
// expected-error @below{{expected vector or array type}}
%cst = llvm.mlir.constant(dense<100.0> : vector<1xf64>) : f32
llvm.return %cst : f32
}

// -----

llvm.func @non_array_attr_for_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
// expected-error @below{{expected an array attribute for a struct constant}}
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
}

// -----

llvm.func @non_array_attr_for_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
// expected-error @below{{expected an array attribute for a struct constant}}
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
}

// -----

llvm.func @invalid_struct_element_type() -> !llvm.struct<(f64, array<2 x i32>)> {
// expected-error @below{{expected struct element types to be floating point type or integer type}}
%0 = llvm.mlir.constant([1.0 : f64, dense<[1, 2]> : tensor<2xi32>]) : !llvm.struct<(f64, array<2 x i32>)>
llvm.return %0 : !llvm.struct<(f64, array<2 x i32>)>
}

// -----

llvm.func @wrong_struct_element_attr_type() -> !llvm.struct<(f64, f64)> {
// expected-error @below{{expected struct element attribute types to be floating point type or integer type}}
%0 = llvm.mlir.constant([dense<[1, 2]> : tensor<2xi32>, 2.0 : f64]) : !llvm.struct<(f64, f64)>
llvm.return %0 : !llvm.struct<(f64, f64)>
}

// -----

llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
// expected-error @below{{struct element at index 0 is of wrong type}}
%0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>
llvm.return %0 : !llvm.struct<(f64, f64)>
}

// -----

llvm.func @integer_with_float_type() -> f32 {
// expected-error @+1 {{expected integer type}}
%0 = llvm.mlir.constant(1 : index) : f32
llvm.return %0 : f32
}

// -----

llvm.func @incompatible_float_attribute_type() -> f32 {
// expected-error @below{{expected float type of width 64}}
%cst = llvm.mlir.constant(1.0 : f64) : f32
llvm.return %cst : f32
}

// -----

llvm.func @incompatible_integer_type_for_float_attr() -> i32 {
// expected-error @below{{expected integer type of width 16}}
%cst = llvm.mlir.constant(1.0 : f16) : i32
llvm.return %cst : i32
}

// -----

// expected-error @below{{LLVM attribute 'readonly' does not expect a value}}
llvm.func @passthrough_unexpected_value() attributes {passthrough = [["readonly", "42"]]}

Expand Down
Loading