diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td index 2d9befe78001d..2016bea43fc8a 100644 --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td @@ -77,4 +77,11 @@ def SPIRVWebGPUPreparePass : Pass<"spirv-webgpu-prepare", "spirv::ModuleOp"> { "and replacing with supported ones"; } +def SPIRVReplicatedConstantCompositePass + : Pass<"spirv-promote-to-replicated-constants", "spirv::ModuleOp"> { + let summary = "Convert splat composite constants and spec constants to " + "corresponding replicated constant composite ops defined by " + "SPV_EXT_replicated_composites"; +} + #endif // MLIR_DIALECT_SPIRV_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt index 68e0206e30a59..b947447dad46a 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ set(LLVM_OPTIONAL_SOURCES CanonicalizeGLPass.cpp + ConvertToReplicatedConstantCompositePass.cpp DecorateCompositeTypeLayoutPass.cpp LowerABIAttributesPass.cpp RewriteInsertsPass.cpp @@ -30,6 +31,7 @@ add_mlir_dialect_library(MLIRSPIRVConversion add_mlir_dialect_library(MLIRSPIRVTransforms CanonicalizeGLPass.cpp + ConvertToReplicatedConstantCompositePass.cpp DecorateCompositeTypeLayoutPass.cpp LowerABIAttributesPass.cpp RewriteInsertsPass.cpp diff --git a/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp new file mode 100644 index 0000000000000..dbbe23aa08b3c --- /dev/null +++ b/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp @@ -0,0 +1,129 @@ +//===- ConvertToReplicatedConstantCompositePass.cpp -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to convert a splat composite spirv.Constant and +// spirv.SpecConstantComposite to spirv.EXT.ConstantCompositeReplicate and +// spirv.EXT.SpecConstantCompositeReplicate respectively. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" +#include "mlir/Dialect/SPIRV/Transforms/Passes.h" +#include "mlir/Transforms/WalkPatternRewriteDriver.h" + +namespace mlir::spirv { +#define GEN_PASS_DEF_SPIRVREPLICATEDCONSTANTCOMPOSITEPASS +#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc" + +namespace { + +static Type getArrayElemType(Attribute attr) { + if (auto typedAttr = dyn_cast(attr)) { + return typedAttr.getType(); + } + + if (auto arrayAttr = dyn_cast(attr)) { + return ArrayType::get(getArrayElemType(arrayAttr[0]), arrayAttr.size()); + } + + return nullptr; +} + +static std::pair +getSplatAttrAndNumElements(Attribute valueAttr, Type valueType) { + auto compositeType = dyn_cast_or_null(valueType); + if (!compositeType) + return {nullptr, 1}; + + if (auto splatAttr = dyn_cast(valueAttr)) { + return {splatAttr.getSplatValue(), splatAttr.size()}; + } + + if (auto arrayAttr = dyn_cast(valueAttr)) { + if (llvm::all_equal(arrayAttr)) { + Attribute attr = arrayAttr[0]; + uint32_t numElements = arrayAttr.size(); + + // Find the inner-most splat value for array of composites + auto [newAttr, newNumElements] = + getSplatAttrAndNumElements(attr, getArrayElemType(attr)); + if (newAttr) { + attr = newAttr; + numElements *= newNumElements; + } + return {attr, numElements}; + } + } + + return {nullptr, 1}; +} + +struct ConstantOpConversion final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(spirv::ConstantOp op, + PatternRewriter &rewriter) const override { + auto [attr, numElements] = + getSplatAttrAndNumElements(op.getValue(), op.getType()); + if (!attr) + return rewriter.notifyMatchFailure(op, "composite is not splat"); + + if (numElements == 1) + return rewriter.notifyMatchFailure(op, + "composite has only one constituent"); + + rewriter.replaceOpWithNewOp( + op, op.getType(), attr); + return success(); + } +}; + +struct SpecConstantCompositeOpConversion final + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(spirv::SpecConstantCompositeOp op, + PatternRewriter &rewriter) const override { + auto compositeType = dyn_cast_or_null(op.getType()); + if (!compositeType) + return rewriter.notifyMatchFailure(op, "not a composite constant"); + + ArrayAttr constituents = op.getConstituents(); + if (constituents.size() == 1) + return rewriter.notifyMatchFailure(op, + "composite has only one consituent"); + + if (!llvm::all_equal(constituents)) + return rewriter.notifyMatchFailure(op, "composite is not splat"); + + auto splatConstituent = dyn_cast(constituents[0]); + if (!splatConstituent) + return rewriter.notifyMatchFailure( + op, "expected flat symbol reference for splat constituent"); + + rewriter.replaceOpWithNewOp( + op, TypeAttr::get(op.getType()), op.getSymNameAttr(), splatConstituent); + + return success(); + } +}; + +struct ConvertToReplicatedConstantCompositePass final + : spirv::impl::SPIRVReplicatedConstantCompositePassBase< + ConvertToReplicatedConstantCompositePass> { + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + patterns.add( + context); + walkAndApplyPatterns(getOperation(), std::move(patterns)); + } +}; + +} // namespace +} // namespace mlir::spirv diff --git a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir new file mode 100644 index 0000000000000..56e26eee83ff9 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir @@ -0,0 +1,283 @@ +// RUN: mlir-opt --spirv-promote-to-replicated-constants --split-input-file %s | FileCheck %s + +spirv.module Logical GLSL450 requires #spirv.vce { + spirv.func @splat_vector_of_i32() -> (vector<3xi32>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : vector<3xi32> + %0 = spirv.Constant dense<2> : vector<3xi32> + spirv.ReturnValue %0 : vector<3xi32> + } + + spirv.func @splat_array_of_i32() -> (!spirv.array<3 x i32>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<3 x i32> + %0 = spirv.Constant [1 : i32, 1 : i32, 1 : i32] : !spirv.array<3 x i32> + spirv.ReturnValue %0 : !spirv.array<3 x i32> + } + + spirv.func @splat_array_of_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>> + %0 = spirv.Constant [[3 : i32, 3 : i32, 3 : i32], [3 : i32, 3 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>> + spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>> + } + + spirv.func @splat_array_of_non_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate {{\[}}[1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>> + %0 = spirv.Constant [[1 : i32, 2 : i32, 3 : i32], [1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>> + spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>> + } + + spirv.func @splat_array_of_vectors_of_i32() -> (!spirv.array<2xvector<2xi32>>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>> + %0 = spirv.Constant [dense<[1, 2]> : vector<2xi32>, dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>> + spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>> + } + + spirv.func @splat_array_of_splat_vectors_of_i32() -> (!spirv.array<2 x vector<2xi32>>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.array<2 x vector<2xi32>> + %0 = spirv.Constant [dense<2> : vector<2xi32>, dense<2> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>> + spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>> + } + + spirv.func @splat_tensor_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>> + %0 = spirv.Constant dense<3> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>> + spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>> + } + + spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.arm.tensor<2x3xi32> + %0 = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32> + spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32> + } + + spirv.func @array_of_splat_array_of_non_splat_vectors_of_i32() -> (!spirv.array<1 x !spirv.array<2 x vector<2xi32>>>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<1 x !spirv.array<2 x vector<2xi32>> + %0 = spirv.Constant [[dense<[1, 2]> : vector<2xi32>, dense<[1, 2]> : vector<2xi32>]] : !spirv.array<1 x !spirv.array<2 x vector<2xi32>>> + spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<2 x vector<2xi32>>> + } + + spirv.func @array_of_one_splat_array_of_vector_of_one_i32() -> !spirv.array<1 x !spirv.array<2 x vector<1xi32>>> "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<1> : vector<1xi32>] : !spirv.array<1 x !spirv.array<2 x vector<1xi32> + %cst = spirv.Constant [[dense<1> : vector<1xi32>], [dense<1> : vector<1xi32>]] : !spirv.array<1 x !spirv.array<2 x vector<1xi32>>> + spirv.ReturnValue %cst : !spirv.array<1 x !spirv.array<2 x vector<1xi32>>> + } + + spirv.func @splat_array_of_array_of_one_vector_of_one_i32() -> (!spirv.array<2 x !spirv.array<1 x vector<1xi32>>>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<1> : vector<1xi32>] : !spirv.array<2 x !spirv.array<1 x vector<1xi32>>> + %0 = spirv.Constant [[dense<1> : vector<1xi32>], [dense<1> : vector<1xi32>]] : !spirv.array<2 x !spirv.array<1 x vector<1xi32>>> + spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<1 x vector<1xi32>>> + } + + spirv.func @array_of_one_array_of_one_splat_vector_of_i32() -> (!spirv.array<1 x !spirv.array<1 x vector<2xi32>>>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>> + %0 = spirv.Constant [[dense<1> : vector<2xi32>]] : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>> + spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>> + } + + spirv.func @splat_array_of_splat_array_of_non_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate {{\[}}[1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>> + %0 = spirv.Constant [[[1 : i32, 2 : i32, 3 : i32], [1 : i32, 2 : i32, 3 : i32]], [[1 : i32, 2 : i32, 3 : i32], [1 : i32, 2 : i32, 3 : i32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>> + spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>> + } + + spirv.func @splat_vector_of_f32() -> (vector<3xf32>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : vector<3xf32> + %0 = spirv.Constant dense<2.0> : vector<3xf32> + spirv.ReturnValue %0 : vector<3xf32> + } + + spirv.func @splat_array_of_f32() -> (!spirv.array<3 x f32>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : !spirv.array<3 x f32> + %0 = spirv.Constant [1.0 : f32, 1.0 : f32, 1.0 : f32] : !spirv.array<3 x f32> + spirv.ReturnValue %0 : !spirv.array<3 x f32> + } + + spirv.func @splat_array_of_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>> + %0 = spirv.Constant [[3.0 : f32, 3.0 : f32, 3.0 : f32], [3.0 : f32, 3.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>> + spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>> + } + + spirv.func @splat_array_of_non_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate {{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>> + %0 = spirv.Constant [[1.0 : f32, 2.0 : f32, 3.0 : f32], [1.0 : f32, 2.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>> + spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>> + } + + spirv.func @splat_array_of_vectors_of_f32() -> (!spirv.array<2xvector<2xf32>>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1.000000e+00, 2.000000e+00]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>> + %0 = spirv.Constant [dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 2.0]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>> + spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>> + } + + spirv.func @splat_array_of_splat_vectors_of_f32() -> (!spirv.array<2 x vector<2xf32>>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.array<2 x vector<2xf32>> + %0 = spirv.Constant [dense<2.0> : vector<2xf32>, dense<2.0> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>> + spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>> + } + + spirv.func @splat_tensor_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>> + %0 = spirv.Constant dense<3.0> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>> + spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>> + } + + spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.arm.tensor<2x3xf32> + %0 = spirv.Constant dense<2.0> : !spirv.arm.tensor<2x3xf32> + spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32> + } + + spirv.func @array_of_splat_array_of_non_splat_vectors_of_f32() -> (!spirv.array<1 x !spirv.array<2 x vector<2xf32>>>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1.000000e+00, 2.000000e+00]> : vector<2xf32>] : !spirv.array<1 x !spirv.array<2 x vector<2xf32>> + %0 = spirv.Constant [[dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 2.0]> : vector<2xf32>]] : !spirv.array<1 x !spirv.array<2 x vector<2xf32>>> + spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<2 x vector<2xf32>>> + } + + spirv.func @array_of_one_splat_array_of_vector_of_one_f32() -> !spirv.array<1 x !spirv.array<2 x vector<1xf32>>> "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<1.000000e+00> : vector<1xf32>] : !spirv.array<1 x !spirv.array<2 x vector<1xf32> + %cst = spirv.Constant [[dense<1.0> : vector<1xf32>], [dense<1.0> : vector<1xf32>]] : !spirv.array<1 x !spirv.array<2 x vector<1xf32>>> + spirv.ReturnValue %cst : !spirv.array<1 x !spirv.array<2 x vector<1xf32>>> + } + + spirv.func @splat_array_of_array_of_one_vector_of_one_f32() -> (!spirv.array<2 x !spirv.array<1 x vector<1xf32>>>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<1.000000e+00> : vector<1xf32>] : !spirv.array<2 x !spirv.array<1 x vector<1xf32>>> + %0 = spirv.Constant [[dense<1.0> : vector<1xf32>], [dense<1.0> : vector<1xf32>]] : !spirv.array<2 x !spirv.array<1 x vector<1xf32>>> + spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<1 x vector<1xf32>>> + } + + spirv.func @array_of_one_array_of_one_splat_vector_of_f32() -> (!spirv.array<1 x !spirv.array<1 x vector<2xf32>>>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : !spirv.array<1 x !spirv.array<1 x vector<2xf32>>> + %0 = spirv.Constant [[dense<1.0> : vector<2xf32>]] : !spirv.array<1 x !spirv.array<1 x vector<2xf32>>> + spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<1 x vector<2xf32>>> + } + + spirv.func @splat_array_of_splat_array_of_non_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate {{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>> + %0 = spirv.Constant [[[1.0 : f32, 2.0 : f32, 3.0 : f32], [1.0 : f32, 2.0 : f32, 3.0 : f32]], [[1.0 : f32, 2.0 : f32, 3.0 : f32], [1.0 : f32, 2.0 : f32, 3.0 : f32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>> + spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>> + } + + spirv.func @array_of_one_i32() -> (!spirv.array<1 x i32>) "None" { + // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate + %0 = spirv.Constant [1 : i32] : !spirv.array<1 x i32> + spirv.ReturnValue %0 : !spirv.array<1 x i32> + } + + spirv.func @arm_tensor_of_one_i32() -> (!spirv.arm.tensor<1xi32>) "None" { + // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate + %0 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi32> + spirv.ReturnValue %0 : !spirv.arm.tensor<1xi32> + } + + spirv.func @non_splat_vector_of_i32() -> (vector<3xi32>) "None" { + // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate + %0 = spirv.Constant dense<[0, 1, 2]> : vector<3xi32> + spirv.ReturnValue %0 : vector<3xi32> + } + + spirv.func @non_splat_array_of_vectors_of_i32() -> (!spirv.array<2xvector<2xi32>>) "None" { + // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate + %0 = spirv.Constant [dense<[1, 2]> : vector<2xi32>, dense<[1, 3]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>> + spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>> + } + + spirv.func @array_of_one_f32() -> (!spirv.array<1 x f32>) "None" { + // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate + %0 = spirv.Constant [1.0 : f32] : !spirv.array<1 x f32> + spirv.ReturnValue %0 : !spirv.array<1 x f32> + } + + spirv.func @arm_tensor_of_one_f32() -> (!spirv.arm.tensor<1xf32>) "None" { + // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate + %0 = spirv.Constant dense<1.0> : !spirv.arm.tensor<1xf32> + spirv.ReturnValue %0 : !spirv.arm.tensor<1xf32> + } + + spirv.func @non_splat_vector_of_f32() -> (vector<3xf32>) "None" { + // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate + %0 = spirv.Constant dense<[0.0, 1.0, 2.0]> : vector<3xf32> + spirv.ReturnValue %0 : vector<3xf32> + } + + spirv.func @non_splat_array_of_vectors_of_f32() -> (!spirv.array<2xvector<2xf32>>) "None" { + // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate + %0 = spirv.Constant [dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 3.0]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>> + spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>> + } + + spirv.func @array_of_one_array_of_one_non_splat_vector_of_i32() -> (!spirv.array<1 x !spirv.array<1 x vector<2xi32>>>) "None" { + // CHECK-NOT spirv.EXT.ConstantCompositeReplicate + %0 = spirv.Constant [[dense<[1, 2]> : vector<2xi32>]] : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>> + spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>> + } + + spirv.func @array_of_one_array_of_one_vector_of_one_i32() -> (!spirv.array<1 x !spirv.array<1 x vector<1xi32>>>) "None" { + // CHECK-NOT spirv.EXT.ConstantCompositeReplicate + %0 = spirv.Constant [[dense<1> : vector<1xi32>]] : !spirv.array<1 x !spirv.array<1 x vector<1xi32>>> + spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<1 x vector<1xi32>>> + } +} + +// ----- + +spirv.module Logical GLSL450 requires #spirv.vce { + + spirv.SpecConstant @sc_i32_1 = 1 : i32 + + // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_i32 (@sc_i32_1) : !spirv.array<3 x i32> + spirv.SpecConstantComposite @scc_splat_array_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : !spirv.array<3 x i32> + + // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_i32 (@sc_i32_1) : !spirv.struct<(i32, i32, i32)> + spirv.SpecConstantComposite @scc_splat_struct_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : !spirv.struct<(i32, i32, i32)> + + // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_vector_of_i32 (@sc_i32_1) : vector<3xi32> + spirv.SpecConstantComposite @scc_splat_vector_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : vector<3 x i32> + + // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_arm_tensor_of_i32 (@sc_i32_1) : !spirv.arm.tensor<3xi32> + spirv.SpecConstantComposite @scc_splat_arm_tensor_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : !spirv.arm.tensor<3xi32> + + spirv.SpecConstant @sc_f32_1 = 1.0 : f32 + + // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_f32 (@sc_f32_1) : !spirv.array<3 x f32> + spirv.SpecConstantComposite @scc_splat_array_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.array<3 x f32> + + // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_f32 (@sc_f32_1) : !spirv.struct<(f32, f32, f32)> + spirv.SpecConstantComposite @scc_splat_struct_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.struct<(f32, f32, f32)> + + // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_vector_of_f32 (@sc_f32_1) : vector<3xf32> + spirv.SpecConstantComposite @scc_splat_vector_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : vector<3 x f32> + + // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_arm_tensor_of_f32 (@sc_f32_1) : !spirv.arm.tensor<3xf32> + spirv.SpecConstantComposite @scc_splat_arm_tensor_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.arm.tensor<3xf32> + + spirv.SpecConstant @sc_i32_2 = 2 : i32 + + // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate + spirv.SpecConstantComposite @scc_array_of_one_i32 (@sc_i32_1) : !spirv.array<1 x i32> + + // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate + spirv.SpecConstantComposite @scc_arm_tensor_of_one_i32 (@sc_i32_1) : !spirv.arm.tensor<1xi32> + + // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate + spirv.SpecConstantComposite @scc_non_splat_vector_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_2) : vector<3 x i32> + + // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate + spirv.SpecConstantComposite @scc_non_splat_arm_tensor_of_i32 (@sc_i32_2, @sc_i32_1, @sc_i32_1) : !spirv.arm.tensor<3xi32> + + spirv.SpecConstant @sc_f32_2 = 2.0 : f32 + + // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate + spirv.SpecConstantComposite @scc_array_of_one_f32 (@sc_f32_1) : !spirv.array<1 x f32> + + // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate + spirv.SpecConstantComposite @scc_arm_tensor_of_one_f32 (@sc_f32_1) : !spirv.arm.tensor<1xf32> + + // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate + spirv.SpecConstantComposite @scc_non_splat_vector_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_2) : vector<3 x f32> + + // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate + spirv.SpecConstantComposite @scc_non_splat_arm_tensor_of_f32 (@sc_f32_2, @sc_f32_1, @sc_f32_1) : !spirv.arm.tensor<3xf32> + + // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate + spirv.SpecConstantComposite @scc_struct_of_i32_and_f32 (@sc_i32_1, @sc_i32_1, @sc_f32_1) : !spirv.struct<(i32, i32, f32)> +}