@@ -340,13 +340,29 @@ def NoneType : Type<CPred<"$_self.isa<::mlir::NoneType>()">, "none type",
340340// Any type from the given list
341341class AnyTypeOf<list<Type> allowedTypes, string summary = "",
342342 string cppClassName = "::mlir::Type"> : Type<
343- // Satisfy any of the allowed type's condition
343+ // Satisfy any of the allowed types' conditions.
344344 Or<!foreach(allowedtype, allowedTypes, allowedtype.predicate)>,
345345 !if(!eq(summary, ""),
346346 !interleave(!foreach(t, allowedTypes, t.summary), " or "),
347347 summary),
348348 cppClassName>;
349349
350+ // A type that satisfies the constraints of all given types.
351+ class AllOfType<list<Type> allowedTypes, string summary = "",
352+ string cppClassName = "::mlir::Type"> : Type<
353+ // Satisfy all of the allowedf types' conditions.
354+ And<!foreach(allowedType, allowedTypes, allowedType.predicate)>,
355+ !if(!eq(summary, ""),
356+ !interleave(!foreach(t, allowedTypes, t.summary), " and "),
357+ summary),
358+ cppClassName>;
359+
360+ // A type that satisfies additional predicates.
361+ class ConfinedType<Type type, list<Pred> predicates, string summary = "",
362+ string cppClassName = "::mlir::Type"> : Type<
363+ And<!listconcat([type.predicate], !foreach(pred, predicates, pred))>,
364+ summary, cppClassName>;
365+
350366// Integer types.
351367
352368// Any integer type irrespective of its width and signedness semantics.
@@ -475,22 +491,21 @@ def F128 : F<128>;
475491def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
476492 BuildableType<"$_builder.getBF16Type()">;
477493
494+ def AnyComplex : Type<CPred<"$_self.isa<::mlir::ComplexType>()">,
495+ "complex-type", "::mlir::ComplexType">;
496+
478497class Complex<Type type>
479- : Type<And<[
480- CPred<"$_self.isa<::mlir::ComplexType>()">,
498+ : ConfinedType<AnyComplex, [
481499 SubstLeaves<"$_self",
482500 "$_self.cast<::mlir::ComplexType>().getElementType()",
483- type.predicate>]> ,
501+ type.predicate>],
484502 "complex type with " # type.summary # " elements",
485503 "::mlir::ComplexType">,
486504 SameBuildabilityAs<type, "::mlir::ComplexType::get($_builder.get" # type #
487505 "Type())"> {
488506 Type elementType = type;
489507}
490508
491- def AnyComplex : Type<CPred<"$_self.isa<::mlir::ComplexType>()">,
492- "complex-type", "::mlir::ComplexType">;
493-
494509class OpaqueType<string dialect, string name, string summary>
495510 : Type<CPred<"isOpaqueTypeWithName($_self, \""#dialect#"\", \""#name#"\")">,
496511 summary, "::mlir::OpaqueType">,
@@ -572,9 +587,8 @@ class VectorOfRank<list<int> allowedRanks> : Type<
572587// Any vector where the rank is from the given `allowedRanks` list and the type
573588// is from the given `allowedTypes` list
574589class VectorOfRankAndType<list<int> allowedRanks,
575- list<Type> allowedTypes> : Type<
576- And<[VectorOf<allowedTypes>.predicate,
577- VectorOfRank<allowedRanks>.predicate]>,
590+ list<Type> allowedTypes> : AllOfType<
591+ [VectorOf<allowedTypes>, VectorOfRank<allowedRanks>],
578592 VectorOf<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
579593 "::mlir::VectorType">;
580594
@@ -630,28 +644,25 @@ class ScalableVectorOfLength<list<int> allowedLengths> : Type<
630644// `allowedLengths` list and the type is from the given `allowedTypes`
631645// list
632646class VectorOfLengthAndType<list<int> allowedLengths,
633- list<Type> allowedTypes> : Type<
634- And<[VectorOf<allowedTypes>.predicate,
635- VectorOfLength<allowedLengths>.predicate]>,
647+ list<Type> allowedTypes> : AllOfType<
648+ [VectorOf<allowedTypes>, VectorOfLength<allowedLengths>],
636649 VectorOf<allowedTypes>.summary # VectorOfLength<allowedLengths>.summary,
637650 "::mlir::VectorType">;
638651
639652// Any fixed-length vector where the number of elements is from the given
640653// `allowedLengths` list and the type is from the given `allowedTypes` list
641654class FixedVectorOfLengthAndType<list<int> allowedLengths,
642- list<Type> allowedTypes> : Type<
643- And<[FixedVectorOf<allowedTypes>.predicate,
644- FixedVectorOfLength<allowedLengths>.predicate]>,
655+ list<Type> allowedTypes> : AllOfType<
656+ [FixedVectorOf<allowedTypes>, FixedVectorOfLength<allowedLengths>],
645657 FixedVectorOf<allowedTypes>.summary #
646658 FixedVectorOfLength<allowedLengths>.summary,
647659 "::mlir::VectorType">;
648660
649661// Any scalable vector where the number of elements is from the given
650662// `allowedLengths` list and the type is from the given `allowedTypes` list
651663class ScalableVectorOfLengthAndType<list<int> allowedLengths,
652- list<Type> allowedTypes> : Type<
653- And<[ScalableVectorOf<allowedTypes>.predicate,
654- ScalableVectorOfLength<allowedLengths>.predicate]>,
664+ list<Type> allowedTypes> : AllOfType<
665+ [ScalableVectorOf<allowedTypes>, ScalableVectorOfLength<allowedLengths>],
655666 ScalableVectorOf<allowedTypes>.summary #
656667 ScalableVectorOfLength<allowedLengths>.summary,
657668 "::mlir::VectorType">;
@@ -768,34 +779,33 @@ def F64MemRef : MemRefOf<[F64]>;
768779
769780// TODO: Have an easy way to add another constraint to a type.
770781class MemRefRankOf<list<Type> allowedTypes, list<int> ranks> :
771- Type<And<[ MemRefOf<allowedTypes>.predicate, HasAnyRankOfPred<ranks>]> ,
782+ ConfinedType< MemRefOf<allowedTypes>, [ HasAnyRankOfPred<ranks>],
772783 !interleave(!foreach(rank, ranks, rank # "D"), "/") # " " #
773784 MemRefOf<allowedTypes>.summary,
774785 "::mlir::MemRefType">;
775786
776- class StaticShapeMemRefOf<list<Type> allowedTypes>
777- : Type<And<[ MemRefOf<allowedTypes>.predicate, HasStaticShapePred]> ,
778- "statically shaped " # MemRefOf<allowedTypes>.summary,
779- "::mlir::MemRefType">;
787+ class StaticShapeMemRefOf<list<Type> allowedTypes> :
788+ ConfinedType< MemRefOf<allowedTypes>, [ HasStaticShapePred],
789+ "statically shaped " # MemRefOf<allowedTypes>.summary,
790+ "::mlir::MemRefType">;
780791
781792def AnyStaticShapeMemRef : StaticShapeMemRefOf<[AnyType]>;
782793
783794// For a MemRefType, verify that it has strides.
784795def HasStridesPred : CPred<[{ isStrided($_self.cast<::mlir::MemRefType>()) }]>;
785796
786- class StridedMemRefOf<list<Type> allowedTypes>
787- : Type<And<[ MemRefOf<allowedTypes>.predicate, HasStridesPred]> ,
788- "strided " # MemRefOf<allowedTypes>.summary>;
797+ class StridedMemRefOf<list<Type> allowedTypes> :
798+ ConfinedType< MemRefOf<allowedTypes>, [ HasStridesPred],
799+ "strided " # MemRefOf<allowedTypes>.summary>;
789800
790801def AnyStridedMemRef : StridedMemRefOf<[AnyType]>;
791802
792803class AnyStridedMemRefOfRank<int rank> :
793- Type<And<[AnyStridedMemRef.predicate,
794- MemRefRankOf<[AnyType], [rank]>.predicate]>,
804+ AllOfType<[AnyStridedMemRef, MemRefRankOf<[AnyType], [rank]>],
795805 AnyStridedMemRef.summary # " of rank " # rank>;
796806
797807class StridedMemRefRankOf<list<Type> allowedTypes, list<int> ranks> :
798- Type<And<[ MemRefOf<allowedTypes>.predicate, HasAnyRankOfPred<ranks>]> ,
808+ ConfinedType< MemRefOf<allowedTypes>, [ HasAnyRankOfPred<ranks>],
799809 !interleave(!foreach(rank, ranks, rank # "D"), "/") # " " #
800810 MemRefOf<allowedTypes>.summary>;
801811
0 commit comments