Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Metal Task Shader payload #4238

Merged
merged 16 commits into from
Jun 2, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions slang.h
Original file line number Diff line number Diff line change
Expand Up @@ -2280,6 +2280,9 @@ extern "C"
// Metal [[attribute]] inputs.
SLANG_PARAMETER_CATEGORY_METAL_ATTRIBUTE,

// Metal [[payload]] inputs
SLANG_PARAMETER_CATEGORY_METAL_PAYLOAD,

//
SLANG_PARAMETER_CATEGORY_COUNT,

Expand Down Expand Up @@ -2854,6 +2857,7 @@ namespace slang
MetalTexture = SLANG_PARAMETER_CATEGORY_METAL_TEXTURE,
MetalArgumentBufferElement = SLANG_PARAMETER_CATEGORY_METAL_ARGUMENT_BUFFER_ELEMENT,
MetalAttribute = SLANG_PARAMETER_CATEGORY_METAL_ATTRIBUTE,
MetalPayload = SLANG_PARAMETER_CATEGORY_METAL_PAYLOAD,

// DEPRECATED:
VertexInput = SLANG_PARAMETER_CATEGORY_VERTEX_INPUT,
Expand Down
19 changes: 19 additions & 0 deletions source/slang/slang-emit-metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,9 @@ void MetalSourceEmitter::emitFuncParamLayoutImpl(IRInst* param)
case LayoutResourceKind::VaryingInput:
m_writer->emit(" [[stage_in]]");
break;
case LayoutResourceKind::MetalPayload:
m_writer->emit(" [[payload]]");
break;
}
}
if (auto sysSemanticAttr = layout->findSystemValueSemanticAttr())
Expand All @@ -191,6 +194,12 @@ void MetalSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPoi
case Stage::Compute:
m_writer->emit("[[kernel]] ");
break;
case Stage::Mesh:
m_writer->emit("[[mesh]] ");
break;
case Stage::Amplification:
m_writer->emit("[[object]] ");
break;
default:
SLANG_ABORT_COMPILATION("unsupported stage.");
}
Expand Down Expand Up @@ -618,6 +627,9 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type)
case AddressSpace::GroupShared:
m_writer->emit(" threadgroup");
break;
case AddressSpace::MetalObjectData:
m_writer->emit(" object_data");
break;
}
m_writer->emit("*");
return;
Expand All @@ -631,6 +643,11 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type)
m_writer->emit(">");
return;
}
case kIROp_MetalMeshGridPropertiesType:
{
m_writer->emit("mesh_grid_properties ");
return;
}
default:
break;
}
Expand Down Expand Up @@ -939,6 +956,8 @@ void MetalSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRI
case AddressSpace::ThreadLocal:
m_writer->emit("thread ");
break;
case AddressSpace::MetalObjectData:
m_writer->emit("object_data ");
default:
break;
}
Expand Down
3 changes: 3 additions & 0 deletions source/slang/slang-ir-inst-defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,9 @@ INST(Nop, nop, 0, 0)
INST(PrimitivesType, Primitives, 2, HOISTABLE)
INST_RANGE(MeshOutputType, VerticesType, PrimitivesType)

/* Metal Mesh Grid Properties */
INST(MetalMeshGridPropertiesType, mesh_grid_properties, 0, HOISTABLE)

/* HLSLStructuredBufferTypeBase */
INST(HLSLStructuredBufferType, StructuredBuffer, 0, HOISTABLE)
INST(HLSLRWStructuredBufferType, RWStructuredBuffer, 0, HOISTABLE)
Expand Down
5 changes: 5 additions & 0 deletions source/slang/slang-ir-insts.h
Original file line number Diff line number Diff line change
Expand Up @@ -3579,6 +3579,11 @@ struct IRBuilder
return getAttributedType(baseType, attributes.getCount(), attributes.getBuffer());
}

// Returns the IR type node for Metal mesh grid properties, obtained from
// getType() with the kIROp_MetalMeshGridPropertiesType opcode.
IRMetalMeshGridPropertiesType* getMetalMeshGridPropertiesType()
{
    auto* const gridPropertiesType = getType(kIROp_MetalMeshGridPropertiesType);
    return static_cast<IRMetalMeshGridPropertiesType*>(gridPropertiesType);
}

IRInst* emitDebugSource(UnownedStringSlice fileName, UnownedStringSlice source);
IRInst* emitDebugLine(IRInst* source, IRIntegerValue lineStart, IRIntegerValue lineEnd, IRIntegerValue colStart, IRIntegerValue colEnd);
IRInst* emitDebugVar(IRType* type, IRInst* source, IRInst* line, IRInst* col, IRInst* argIndex = nullptr);
Expand Down
77 changes: 77 additions & 0 deletions source/slang/slang-ir-metal-legalize.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "slang-ir-metal-legalize.h"

#include "slang-ir.h"
#include "slang-ir-insts.h"
#include "slang-ir-util.h"
#include "slang-ir-clone.h"
Expand Down Expand Up @@ -306,15 +307,89 @@ namespace Slang
fixUpFuncType(func, structType);
}

// Rewrites calls to the DispatchMesh builtin for the Metal target.
//
// The builtin is located by scanning the module globals for a function that
// carries an IRKnownBuiltinDecoration whose name equals DispatchMesh. When no
// such function exists the pass does nothing. For every call that uses it:
//   1. a ref-typed parameter in the MetalObjectData address space is added to
//      entryPoint.entryPointFunc and tagged with a MetalPayload layout so the
//      emitter can print [[payload]];
//   2. a second parameter of the mesh grid properties type is added;
//   3. the value loaded from the payload argument is stored into the new
//      payload parameter, right before the call;
//   4. the first three call arguments are packed into a uint3 and a new call
//      is emitted with that vector as its single argument.
//
// NOTE(review): steps 1 and 2 run once per call site, so several DispatchMesh
// calls produce duplicate parameters (also raised in the review thread that is
// embedded in this function body).
// NOTE(review): the parameters always go into entryPoint.entryPointFunc, even
// if the call sits in another function or another entry point of the module.
// NOTE(review): nothing visible here removes or replaces the original
// DispatchMesh call - confirm that this is handled elsewhere.
void legalizeDispatchMeshPayloadForMetal(EntryPointInfo entryPoint)
{
// Find out DispatchMesh function
IRGlobalValueWithCode* dispatchMeshFunc = nullptr;
for (const auto globalInst : entryPoint.entryPointFunc->getModule()->getGlobalInsts())
{
if (const auto func = as<IRGlobalValueWithCode>(globalInst))
{
if (const auto dec = func->findDecoration<IRKnownBuiltinDecoration>())
{
if (dec->getName() == "DispatchMesh")
{
// Only one builtin definition is expected per module.
SLANG_ASSERT(!dispatchMeshFunc && "Multiple DispatchMesh functions found");
dispatchMeshFunc = func;
}
}
}
}

// Nothing to legalize when the module never references the builtin.
if (!dispatchMeshFunc)
return;

IRBuilder builder{ entryPoint.entryPointFunc->getModule() };
// NOTE(review): this insert location is replaced inside the callback below
// before anything is emitted - confirm it is still needed.
builder.setInsertBefore(dispatchMeshFunc);

// We'll rewrite the call to use mesh_grid_properties.set_threadgroups_per_grid
traverseUses(dispatchMeshFunc, [&](const IRUse* use) {
if (const auto call = as<IRCall>(use->getUser()))
{
// Expected argument layout: three group counts followed by the payload.
SLANG_ASSERT(call->getArgCount() == 4);
const auto payload = call->getArg(3);

// The payload argument must be pointer-typed; its pointee is the payload struct.
const auto payloadPtrType = composeGetters<IRPtrType>(
payload,
&IRInst::getDataType
);
SLANG_ASSERT(payloadPtrType);
const auto payloadType = payloadPtrType->getValueType();
SLANG_ASSERT(payloadType);

// New parameters are emitted ahead of the first ordinary instruction of the
// first block of the entry point.
builder.setInsertBefore(entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst());
const auto annotatedPayloadType =
builder.getPtrType(
kIROp_RefType,
payloadPtrType->getValueType(),
AddressSpace::MetalObjectData
);
auto packedParam = builder.emitParam(annotatedPayloadType);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there are more than one DispatchMesh call, this will insert duplicate parameters into the entry point. We should probably make sure we insert at most once for the referencing entry point?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also there could be more than one entry points in a single module. I feel like instead of inserting into the entry point, we should try to find the parent function of the call site, assert that it is a object shader entry point, and insert a param into the EntryPoint if it is t already there.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

auto parentFunc = getParentFunc(call) should give you the user func, you can then builder.setInsertInto(parentFunc->getFirstBlock()->getFirstOrdinaryInst()) for the param.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we still assume that the DispatchMesh call is done directly from inside the entry point, but we make sure that the payload parameter is only generated once per entrypoint?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, let's do that for now. we can add a diagnostic if oarent func isnt an entrypoint.

// The var layout is built on top of an empty type layout; only the resource
// kind added below is recorded on it.
IRVarLayout::Builder varLayoutBuilder(&builder, IRTypeLayout::Builder{&builder}.build());

// Add the MetalPayload resource info, so we can emit [[payload]]
varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload);
auto paramVarLayout = varLayoutBuilder.build();
builder.addLayoutDecoration(packedParam, paramVarLayout);

// Now we replace the call to DispatchMesh with a call to the mesh grid properties
// But first we need to create the parameter
const auto meshGridPropertiesType = builder.getMetalMeshGridPropertiesType();
auto mgp = builder.emitParam(meshGridPropertiesType);

// Now we store whatever got passed as the payload to the parameter
builder.setInsertBefore(call);
builder.emitStore(packedParam, builder.emitLoad(payload));

// lastly we call the set_threadgroups_per_grid
// However it also doesnt take 3 separate arguments for the group sizes, but one uint3
const auto groupSizeType = builder.getVectorType(builder.getUIntType(), 3);
const auto groupSize = builder.emitMakeVector(groupSizeType, Slang::List(call->getArg(0), call->getArg(1), call->getArg(2)));
// NOTE(review): the first argument passed here is the type of mgp and the second
// is a function type rather than a callee value; neither mgp itself nor the
// set_threadgroups_per_grid name is referenced - TODO confirm that this emits
// the intended call.
builder.emitCallInst(mgp->getFullType(), builder.getFuncType(Slang::List<IRType*>(groupSizeType), builder.getVoidType()), 1, &groupSize);
}
});
}
// Runs the Metal-specific legalization steps on one entry point, in this order:
// hoistEntryPointParameterFromStruct, packStageInParameters,
// wrapReturnValueInStruct, legalizeDispatchMeshPayloadForMetal.
//
// `sink` is accepted but not used by this function (see SLANG_UNUSED).
void legalizeEntryPointForMetal(EntryPointInfo entryPoint, DiagnosticSink* sink)
{
SLANG_UNUSED(sink);

hoistEntryPointParameterFromStruct(entryPoint);
packStageInParameters(entryPoint);
wrapReturnValueInStruct(entryPoint);
// NOTE(review): this step scans the whole module for DispatchMesh calls while
// always adding parameters to `entryPoint`; with several entry points in one
// module it would presumably process the same calls again - confirm.
legalizeDispatchMeshPayloadForMetal(entryPoint);
}


void legalizeIRForMetal(IRModule* module, DiagnosticSink* sink)
{
List<EntryPointInfo> entryPoints;
Expand All @@ -337,4 +412,6 @@ namespace Slang

specializeAddressSpace(module);
}

}

4 changes: 4 additions & 0 deletions source/slang/slang-ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ enum class AddressSpace
Global = 2,
GroupShared = 3,
Uniform = 4,
// specific address space for payload data in metal
MetalObjectData = 5,
};

typedef unsigned int IROpFlags;
Expand Down Expand Up @@ -1549,6 +1551,8 @@ SIMPLE_IR_TYPE(VerticesType, MeshOutputType)
SIMPLE_IR_TYPE(IndicesType, MeshOutputType)
SIMPLE_IR_TYPE(PrimitivesType, MeshOutputType)

SIMPLE_IR_TYPE(MetalMeshGridPropertiesType, Type)

SIMPLE_IR_TYPE(GLSLInputAttachmentType, Type)
SIMPLE_IR_PARENT_TYPE(ParameterGroupType, PointerLikeType)
SIMPLE_IR_PARENT_TYPE(UniformParameterGroupType, ParameterGroupType)
Expand Down
102 changes: 102 additions & 0 deletions tests/metal/simple-task.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
//TEST:SIMPLE(filecheck=CHECK): -target metal
//TEST:SIMPLE(filecheck=CHECK-ASM): -target metallib
//TEST:REFLECTION(filecheck=REFLECT):-target metal -entry main_kernel -stage task

// TEST_INPUT: ubuffer(data=[0 0 0 0], stride=4):out,name outputBuffer

// Buffer that fragmentMain writes per-vertex values into, indexed by Vertex::index.
uniform RWStructuredBuffer<float> outputBuffer;

// Constant buffer declared by the test; modelViewProjection is not referenced
// by any of the shaders visible in this file.
cbuffer Uniforms
{
float4x4 modelViewProjection;
}

// CHECK-ASM: define void @taskMain
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are no //CHECK: lines in the file, so the test on line 1 will fail

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not entirely sure what those check comments do, copied them from the rasterization test file.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They are regular expressions used to match the result of the generated source file or buffer content.

A //CHECK line will try to match a line in the output with the regex following CHECK:. If such a match is found, the check succeeds; otherwise it fails the test.


//
// Task shader
//

// Payload passed from taskMain to meshMain through DispatchMesh.
struct MeshPayload
{
// Exponent that meshMain applies to the triangle index via pow().
int exponent;
};

// Task shader entry point: fills a MeshPayload with exponent 3 and passes it to
// DispatchMesh together with group counts (1, 1, 1).
// NOTE(review): the REFLECTION directive at the top of this file passes
// -entry main_kernel, but no function with that name is visible in this file -
// confirm whether it should name taskMain.
[numthreads(1,1,1)]
void taskMain()
{
MeshPayload p;
p.exponent = 3;
DispatchMesh(1, 1, 1, p);
}

//
// Mesh shader
//

// 2D corner positions of one triangle; meshMain selects an entry with tig % 3.
const static float2 positions[3] = {
float2(0.0, -0.5),
float2(0.5, 0.5),
float2(-0.5, 0.5)
};

// Per-corner colors, selected with the same tig % 3 index as positions.
const static float3 colors[3] = {
float3(1.0, 1.0, 0.0),
float3(0.0, 1.0, 1.0),
float3(1.0, 0.0, 1.0)
};

// Vertex record written by meshMain and consumed by fragmentMain.
struct Vertex
{
float4 pos : SV_Position;
float3 color : Color;
// Triangle number (tig / 3 in meshMain); used by fragmentMain as the
// outputBuffer slot.
int index : Index;
// pow(index, payload exponent) as computed in meshMain; stored by fragmentMain.
int value : Value;
};

// Capacities of the OutputVertices / OutputIndices parameters of meshMain.
const static uint MAX_VERTS = 12;
const static uint MAX_PRIMS = 4;

// Mesh shader entry point: declares 12 vertices and 4 triangles as its output
// counts. Thread tig (below 12) writes vertex tig, taking position and color
// from the tig % 3 table entries, setting index to tig / 3 and value to
// pow(index, meshPayload.exponent). Threads with tig below 4 each write the
// index triple of one triangle.
[outputtopology("triangle")]
[numthreads(12, 1, 1)]
void meshMain(
in uint tig: SV_GroupIndex,
in payload MeshPayload meshPayload,
// Check that we correctly generate the specific 'in payload' that HLSL
// requires:
// HLSL: , in payload MeshPayload
OutputVertices<Vertex, MAX_VERTS> verts,
OutputIndices<uint3, MAX_PRIMS> triangles)
{
const uint numVertices = 12;
const uint numPrimitives = 4;
SetMeshOutputCounts(numVertices, numPrimitives);

if (tig < numVertices)
{
// Three consecutive threads share one triangle number.
const int tri = tig / 3;
verts[tig] = { float4(positions[tig % 3], 0, 1), colors[tig % 3], tri, int(pow(tri, meshPayload.exponent)) };
}

// Triangle tig uses vertices 3*tig, 3*tig + 1 and 3*tig + 2.
if (tig < numPrimitives)
triangles[tig] = tig * 3 + uint3(0, 1, 2);
}

//
// Fragment Shader
//

// Output of fragmentMain: a single color bound to SV_Target.
struct Fragment
{
float4 color : SV_Target;
};

// Fragment shader entry point: records input.value in outputBuffer at slot
// input.index, then returns the input color with alpha set to 1.0.
Fragment fragmentMain(Vertex input)
{
outputBuffer[input.index] = input.value;

Fragment output;
output.color = float4(input.color, 1.0);
return output;
}

Loading