From 16cce5957f2d1f1641d6734f36ca39d41698e5e6 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Fri, 10 Jan 2025 23:04:30 -0500 Subject: [PATCH 01/18] Refactor to reuse common for metal and wgsl entry point legalization --- .../slang-ir-legalize-varying-params.cpp | 2369 +++++++++++++++++ .../slang/slang-ir-legalize-varying-params.h | 22 +- source/slang/slang-ir-metal-legalize.cpp | 1945 +------------- source/slang/slang-ir-wgsl-legalize.cpp | 1637 +----------- 4 files changed, 2518 insertions(+), 3455 deletions(-) diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index 33f3944fd9..65d0ec28cc 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -1,11 +1,14 @@ // slang-ir-legalize-varying-params.cpp #include "slang-ir-legalize-varying-params.h" +#include "core/slang-common.h" #include "slang-ir-clone.h" #include "slang-ir-insts.h" #include "slang-ir-util.h" #include "slang-parameter-binding.h" +#include + namespace Slang { // Convert semantic name (ignores case) into equivlent `SystemValueSemanticName` @@ -1544,4 +1547,2370 @@ void depointerizeInputParams(IRFunc* entryPointFunc) } } + +struct LegalizeShaderEntryPointContext +{ + enum class LegalizeTarget + { + Metal, + WGSL, + }; + + struct SystemValueInfo + { + String systemValueName; + SystemValueSemanticName systemValueNameEnum; + + ShortList permittedTypes; + bool isUnsupported = false; + + // Only used by Metal. 
+ bool isSpecial = false; + + SystemValueInfo() + { + // most commonly need 2 + permittedTypes.reserveOverflowBuffer(2); + } + }; + + SystemValueInfo getSystemValueInfo( + String inSemanticName, + String* optionalSemanticIndex, + IRInst* parentVar) + { + if (isTargetMetal()) + { + return getMetalSystemValueInfo(inSemanticName, optionalSemanticIndex, parentVar); + } + else + { + SLANG_ASSERT(isTargetWGSL()); + return getWGSLSystemValueInfo(inSemanticName, optionalSemanticIndex, parentVar); + } + } + + IRModule* m_module; + DiagnosticSink* m_sink; + LegalizeTarget m_target; + HashSet semanticInfoToRemove; + + bool isTargetMetal() const { return m_target == LegalizeTarget::Metal; } + bool isTargetWGSL() const { return m_target == LegalizeTarget::WGSL; } + + LegalizeShaderEntryPointContext(IRModule* module, DiagnosticSink* sink, LegalizeTarget target) + : m_module(module), m_sink(sink), m_target(target) + { + } + + void removeSemanticLayoutsFromLegalizedStructs() + { + // Metal and WGSL does not allow duplicate attributes to appear in the same shader. + // If we emit our own struct with `[[color(0)]`, all existing uses of `[[color(0)]]` + // must be removed. + for (auto field : semanticInfoToRemove) + { + auto key = field->getKey(); + // Some decorations appear twice, destroy all found + for (;;) + { + if (auto semanticDecor = key->findDecoration()) + { + semanticDecor->removeAndDeallocate(); + continue; + } + else if (auto layoutDecor = key->findDecoration()) + { + layoutDecor->removeAndDeallocate(); + continue; + } + break; + } + } + } + + void hoistEntryPointParameterFromStruct(EntryPointInfo entryPoint) + { + // If an entry point has a input parameter with a struct type, we want to hoist out + // all the fields of the struct type to be individual parameters of the entry point. + // This will canonicalize the entry point signature, so we can handle all cases uniformly. 
+ + // For example, given an entry point: + // ``` + // struct VertexInput { float3 pos; float 2 uv; int vertexId : SV_VertexID}; + // void main(VertexInput vin) { ... } + // ``` + // We will transform it to: + // ``` + // void main(float3 pos, float2 uv, int vertexId : SV_VertexID) { + // VertexInput vin = {pos,uv,vertexId}; + // ... + // } + // ``` + + auto func = entryPoint.entryPointFunc; + List paramsToProcess; + for (auto param : func->getParams()) + { + if (as(param->getDataType())) + { + paramsToProcess.add(param); + } + } + + IRBuilder builder(func); + builder.setInsertBefore(func); + for (auto param : paramsToProcess) + { + auto structType = as(param->getDataType()); + builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); + auto varLayout = findVarLayout(param); + + // If `param` already has a semantic, we don't want to hoist its fields out. + if (varLayout->findSystemValueSemanticAttr() != nullptr || + param->findDecoration()) + continue; + + IRStructTypeLayout* structTypeLayout = nullptr; + if (varLayout) + structTypeLayout = as(varLayout->getTypeLayout()); + Index fieldIndex = 0; + List fieldParams; + for (auto field : structType->getFields()) + { + auto fieldParam = builder.emitParam(field->getFieldType()); + IRCloneEnv cloneEnv; + cloneInstDecorationsAndChildren( + &cloneEnv, + builder.getModule(), + field->getKey(), + fieldParam); + + IRVarLayout* fieldLayout = + structTypeLayout ? structTypeLayout->getFieldLayout(fieldIndex) : nullptr; + if (varLayout) + { + IRVarLayout::Builder varLayoutBuilder(&builder, fieldLayout->getTypeLayout()); + varLayoutBuilder.cloneEverythingButOffsetsFrom(fieldLayout); + for (auto offsetAttr : fieldLayout->getOffsetAttrs()) + { + auto parentOffsetAttr = + varLayout->findOffsetAttr(offsetAttr->getResourceKind()); + UInt parentOffset = parentOffsetAttr ? parentOffsetAttr->getOffset() : 0; + UInt parentSpace = parentOffsetAttr ? 
parentOffsetAttr->getSpace() : 0; + auto resInfo = + varLayoutBuilder.findOrAddResourceInfo(offsetAttr->getResourceKind()); + resInfo->offset = parentOffset + offsetAttr->getOffset(); + resInfo->space = parentSpace + offsetAttr->getSpace(); + } + builder.addLayoutDecoration(fieldParam, varLayoutBuilder.build()); + } + param->insertBefore(fieldParam); + fieldParams.add(fieldParam); + fieldIndex++; + } + builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); + auto reconstructedParam = + builder.emitMakeStruct(structType, fieldParams.getCount(), fieldParams.getBuffer()); + param->replaceUsesWith(reconstructedParam); + param->removeFromParent(); + } + fixUpFuncType(func); + } + + // Flattens all struct parameters of an entryPoint to ensure parameters are a flat struct + void flattenInputParameters(EntryPointInfo entryPoint) + { + // Goal is to ensure we have a flattened IRParam (0 nested IRStructType members). + /* + // Assume the following code + struct NestedFragment + { + float2 p3; + }; + struct Fragment + { + float4 p1; + float3 p2; + NestedFragment p3_nested; + }; + + // Fragment flattens into + struct Fragment + { + float4 p1; + float3 p2; + float2 p3; + }; + */ + + // This is important since Metal and WGSL does not allow semantic's on a struct + /* + // Assume the following code + struct NestedFragment1 + { + float2 p3; + }; + struct Fragment1 + { + float4 p1 : SV_TARGET0; + float3 p2 : SV_TARGET1; + NestedFragment p3_nested : SV_TARGET2; // error, semantic on struct + }; + + */ + + // Metal does allow semantics on members of a nested struct but we are avoiding this + // approach since there are senarios where legalization (and verification) is + // hard/expensive without creating a flat struct: + // 1. Entry points may share structs, semantics may be inconsistent across entry points + // 2. Multiple of the same struct may be used in a param list + // + // WGSL does NOT allow semantics on members of a nested struct. 
+ /* + // Assume the following code + struct NestedFragment + { + float2 p3; + }; + struct Fragment + { + float4 p1 : SV_TARGET0; + NestedFragment p2 : SV_TARGET1; + NestedFragment p3 : SV_TARGET2; + }; + + // Legalized without flattening -- abandoned + struct NestedFragment1 + { + float2 p3 : SV_TARGET1; + }; + struct NestedFragment2 + { + float2 p3 : SV_TARGET2; + }; + struct Fragment + { + float4 p1 : SV_TARGET0; + NestedFragment1 p2; + NestedFragment2 p3; + }; + + // Legalized with flattening -- current approach + struct Fragment + { + float4 p1 : SV_TARGET0; + float2 p2 : SV_TARGET1; + float2 p3 : SV_TARGET2; + }; + */ + + auto func = entryPoint.entryPointFunc; + bool modified = false; + for (auto param : func->getParams()) + { + auto layout = findVarLayout(param); + if (!layout) + continue; + if (!layout->findOffsetAttr(LayoutResourceKind::VaryingInput)) + continue; + if (param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration)) + continue; + // If we find a IRParam with a IRStructType member, we need to flatten the entire + // IRParam + if (auto structType = as(param->getDataType())) + { + IRBuilder builder(func); + MapStructToFlatStruct mapOldFieldToNewField; + + // Flatten struct if we have nested IRStructType + auto flattenedStruct = maybeFlattenNestedStructs( + builder, + structType, + mapOldFieldToNewField, + semanticInfoToRemove); + + // XXX TODO: Clean this up maybe? + if (isTargetWGSL()) + { + // Validate/rearange all semantics which overlap in our flat struct. 
+ fixFieldSemanticsOfFlatStruct(flattenedStruct); + ensureStructHasUserSemantic( + flattenedStruct, + layout); + } + if (flattenedStruct != structType) + { + if (isTargetMetal()) + { + // Validate/rearange all semantics which overlap in our flat struct + fixFieldSemanticsOfFlatStruct(flattenedStruct); + } + + // Replace the 'old IRParam type' with a 'new IRParam type' + param->setFullType(flattenedStruct); + + // Emit a new variable at EntryPoint of 'old IRParam type' + builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); + auto dstVal = builder.emitVar(structType); + auto dstLoad = builder.emitLoad(dstVal); + param->replaceUsesWith(dstLoad); + builder.setInsertBefore(dstLoad); + // Copy the 'new IRParam type' to our 'old IRParam type' + mapOldFieldToNewField + .emitCopy<(int)MapStructToFlatStruct::CopyOptions::FlatStructIntoStruct>( + builder, + dstVal, + param); + + modified = true; + } + } + } + if (modified) + fixUpFuncType(func); + } + + void packStageInParameters(EntryPointInfo entryPoint) + { + // If the entry point has any parameters whose layout contains VaryingInput, + // we need to pack those parameters into a single `struct` type, and decorate + // the fields with the appropriate `[[attribute]]` decorations. + // For other parameters that are not `VaryingInput`, we need to leave them as is. + // + // For example, given this code after `hoistEntryPointParameterFromStruct`: + // ``` + // void main(float3 pos, float2 uv, int vertexId : SV_VertexID) { + // VertexInput vin = {pos,uv,vertexId}; + // ... + // } + // ``` + // We are going to transform it into: + // ``` + // struct VertexInput { + // float3 pos [[attribute(0)]]; + // float2 uv [[attribute(1)]]; + // }; + // void main(VertexInput vin, int vertexId : SV_VertexID) { + // let pos = vin.pos; + // let uv = vin.uv; + // ... 
+ // } + + auto func = entryPoint.entryPointFunc; + + bool isGeometryStage = false; + switch (entryPoint.entryPointDecor->getProfile().getStage()) + { + case Stage::Vertex: + case Stage::Amplification: + case Stage::Mesh: + case Stage::Geometry: + case Stage::Domain: + case Stage::Hull: + isGeometryStage = true; + break; + } + + List paramsToPack; + for (auto param : func->getParams()) + { + auto layout = findVarLayout(param); + if (!layout) + continue; + if (!layout->findOffsetAttr(LayoutResourceKind::VaryingInput)) + continue; + if (param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration)) + continue; + paramsToPack.add(param); + } + + if (paramsToPack.getCount() == 0) + return; + + IRBuilder builder(func); + builder.setInsertBefore(func); + IRStructType* structType = builder.createStructType(); + auto stageText = getStageText(entryPoint.entryPointDecor->getProfile().getStage()); + builder.addNameHintDecoration( + structType, + (String(stageText) + toSlice("Input")).getUnownedSlice()); + List keys; + IRStructTypeLayout::Builder layoutBuilder(&builder); + for (auto param : paramsToPack) + { + auto paramVarLayout = findVarLayout(param); + auto key = builder.createStructKey(); + param->transferDecorationsTo(key); + builder.createStructField(structType, key, param->getDataType()); + if (auto varyingInOffsetAttr = + paramVarLayout->findOffsetAttr(LayoutResourceKind::VaryingInput)) + { + if (!key->findDecoration() && + !paramVarLayout->findAttr()) + { + // If the parameter doesn't have a semantic, we need to add one for semantic + // matching. + builder.addSemanticDecoration( + key, + toSlice("_slang_attr"), + (int)varyingInOffsetAttr->getOffset()); + } + } + + // For Metal geometric stages, we need to translate VaryingInput offsets to + // MetalAttribute offsets. 
+ if (isGeometryStage && isTargetMetal()) + { + IRVarLayout::Builder elementVarLayoutBuilder( + &builder, + paramVarLayout->getTypeLayout()); + elementVarLayoutBuilder.cloneEverythingButOffsetsFrom(paramVarLayout); + for (auto offsetAttr : paramVarLayout->getOffsetAttrs()) + { + auto resourceKind = offsetAttr->getResourceKind(); + if (resourceKind == LayoutResourceKind::VaryingInput) + { + resourceKind = LayoutResourceKind::MetalAttribute; + } + auto resInfo = elementVarLayoutBuilder.findOrAddResourceInfo(resourceKind); + resInfo->offset = offsetAttr->getOffset(); + resInfo->space = offsetAttr->getSpace(); + } + paramVarLayout = elementVarLayoutBuilder.build(); + } + + layoutBuilder.addField(key, paramVarLayout); + builder.addLayoutDecoration(key, paramVarLayout); + keys.add(key); + } + builder.setInsertInto(func->getFirstBlock()); + auto packedParam = builder.emitParamAtHead(structType); + auto typeLayout = layoutBuilder.build(); + IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout); + + // Add a VaryingInput resource info to the packed parameter layout, so that we can emit + // the needed `[[stage_in]]` attribute in Metal emitter. 
+ varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::VaryingInput); + auto paramVarLayout = varLayoutBuilder.build(); + builder.addLayoutDecoration(packedParam, paramVarLayout); + + // Replace the original parameters with the packed parameter + builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); + for (Index paramIndex = 0; paramIndex < paramsToPack.getCount(); paramIndex++) + { + auto param = paramsToPack[paramIndex]; + auto key = keys[paramIndex]; + auto paramField = builder.emitFieldExtract(param->getDataType(), packedParam, key); + param->replaceUsesWith(paramField); + param->removeFromParent(); + } + fixUpFuncType(func); + } + + + void reportUnsupportedSystemAttribute(IRInst* param, String semanticName) + { + m_sink->diagnose( + param->sourceLoc, + Diagnostics::systemValueAttributeNotSupported, + semanticName); + } + + template + void ensureStructHasUserSemantic(IRStructType* structType, IRVarLayout* varLayout) + { + // Ensure each field in an output struct type has either a system semantic or a user + // semantic, so that signature matching can happen correctly. 
+ auto typeLayout = as(varLayout->getTypeLayout()); + Index index = 0; + IRBuilder builder(structType); + for (auto field : structType->getFields()) + { + auto key = field->getKey(); + if (auto semanticDecor = key->findDecoration()) + { + if (semanticDecor->getSemanticName().startsWithCaseInsensitive(toSlice("sv_"))) + { + auto indexAsString = String(UInt(semanticDecor->getSemanticIndex())); + auto sysValInfo = + getSystemValueInfo(semanticDecor->getSemanticName(), &indexAsString, field); + if (sysValInfo.isUnsupported) + { + reportUnsupportedSystemAttribute(field, semanticDecor->getSemanticName()); + } + else + { + builder.addTargetSystemValueDecoration( + key, + sysValInfo.systemValueName.getUnownedSlice()); + semanticDecor->removeAndDeallocate(); + } + } + index++; + continue; + } + typeLayout->getFieldLayout(index); + auto fieldLayout = typeLayout->getFieldLayout(index); + if (auto offsetAttr = fieldLayout->findOffsetAttr(K)) + { + UInt varOffset = 0; + if (auto varOffsetAttr = varLayout->findOffsetAttr(K)) + varOffset = varOffsetAttr->getOffset(); + varOffset += offsetAttr->getOffset(); + builder.addSemanticDecoration(key, toSlice("_slang_attr"), (int)varOffset); + } + index++; + } + } + + // Stores a hicharchy of members and children which map 'oldStruct->member' to + // 'flatStruct->member' Note: this map assumes we map to FlatStruct since it is easier/faster to + // process + struct MapStructToFlatStruct + { + /* + We need a hicharchy map to resolve dependencies for mapping + oldStruct to newStruct efficently. 
Example: + + MyStruct + | + / | \ + / | \ + / | \ + M0 M1 M2 + | | | + A_0 A_0 B_0 + + Without storing hicharchy information, there will be no way to tell apart + `myStruct.M0.A0` from `myStruct.M1.A0` since IRStructKey/IRStructField + only has 1 instance of `A::A0` + */ + + enum CopyOptions : int + { + // Copy a flattened-struct into a struct + FlatStructIntoStruct = 0, + + // Copy a struct into a flattened-struct + StructIntoFlatStruct = 1, + }; + + private: + // Children of member if applicable. + Dictionary members; + + // Field correlating to MapStructToFlatStruct Node. + IRInst* node; + IRStructKey* getKey() + { + SLANG_ASSERT(as(node)); + return as(node)->getKey(); + } + IRInst* getNode() { return node; } + IRType* getFieldType() + { + SLANG_ASSERT(as(node)); + return as(node)->getFieldType(); + } + + // Whom node maps to inside target flatStruct + IRStructField* targetMapping; + + auto begin() { return members.begin(); } + auto end() { return members.end(); } + + // Copies members of oldStruct to/from newFlatStruct. 
Assumes members of val1 maps to + // members in val2 using `MapStructToFlatStruct` + // `copyOptions` picks the direction: with FlatStructIntoStruct each mapped field is read + // from val2 (the flat struct) and stored through val1; otherwise each field is read from + // val1 and stored through val2. Struct-typed fields of val1 are handled by recursion. + template + static void _emitCopy( + IRBuilder& builder, + IRInst* val1, + IRStructType* type1, + IRInst* val2, + IRStructType* type2, + MapStructToFlatStruct& node) + { + for (auto& field1Pair : node) + { + auto& field1 = field1Pair.second; + + // Get member of val1 + IRInst* fieldAddr1 = nullptr; + if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct) + { + fieldAddr1 = builder.emitFieldAddress(type1, val1, field1.getKey()); + } + else + { + if (as(val1)) + val1 = builder.emitLoad(val1); + fieldAddr1 = builder.emitFieldExtract(type1, val1, field1.getKey()); + } + + // If val1 is a struct, recurse + if (auto fieldAsStruct1 = as(field1.getFieldType())) + { + _emitCopy( + builder, + fieldAddr1, + fieldAsStruct1, + val2, + type2, + field1); + continue; + } + + // Get member of val2 which maps to val1.member + auto field2 = field1.getMapping(); + SLANG_ASSERT(field2); + IRInst* fieldAddr2 = nullptr; + if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct) + { + // Load val2 itself (not val1) when it passes the `as` check, so the extract + // below operates on a value of `type2`; this mirrors the val1 handling above. + if (as(val2)) + val2 = builder.emitLoad(val2); + fieldAddr2 = builder.emitFieldExtract(type2, val2, field2->getKey()); + } + else + { + fieldAddr2 = builder.emitFieldAddress(type2, val2, field2->getKey()); + } + + // Copy val2/val1 member into val1/val2 member + if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct) + { + builder.emitStore(fieldAddr1, fieldAddr2); + } + else + { + builder.emitStore(fieldAddr2, fieldAddr1); + } + } + } + + public: + void setNode(IRInst* newNode) { node = newNode; } + // Get 'MapStructToFlatStruct' that is a child of 'parent'. + // Make 'MapStructToFlatStruct' if no 'member' is currently mapped to 'parent'. 
+ MapStructToFlatStruct& getMember(IRStructField* member) { return members[member]; } + MapStructToFlatStruct& operator[](IRStructField* member) { return getMember(member); } + + void setMapping(IRStructField* newTargetMapping) { targetMapping = newTargetMapping; } + // Get 'MapStructToFlatStruct' that is a child of 'parent'. + // Return nullptr if no member is mapped to 'parent' + IRStructField* getMapping() { return targetMapping; } + + // Copies srcVal into dstVal using hicharchy map. + template + void emitCopy(IRBuilder& builder, IRInst* dstVal, IRInst* srcVal) + { + auto dstType = dstVal->getDataType(); + if (auto dstPtrType = as(dstType)) + dstType = dstPtrType->getValueType(); + auto dstStructType = as(dstType); + SLANG_ASSERT(dstStructType); + + auto srcType = srcVal->getDataType(); + if (auto srcPtrType = as(srcType)) + srcType = srcPtrType->getValueType(); + auto srcStructType = as(srcType); + SLANG_ASSERT(srcStructType); + + if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct) + { + // CopyOptions::FlatStructIntoStruct copy a flattened-struct (mapped member) into a + // struct + SLANG_ASSERT(node == dstStructType); + _emitCopy( + builder, + dstVal, + dstStructType, + srcVal, + srcStructType, + *this); + } + else + { + // CopyOptions::StructIntoFlatStruct copy a struct into a flattened-struct + SLANG_ASSERT(node == srcStructType); + _emitCopy( + builder, + srcVal, + srcStructType, + dstVal, + dstStructType, + *this); + } + } + }; + + IRStructType* _flattenNestedStructs( + IRBuilder& builder, + IRStructType* dst, + IRStructType* src, + IRSemanticDecoration* parentSemanticDecoration, + IRLayoutDecoration* parentLayout, + MapStructToFlatStruct& mapFieldToField, + HashSet& varsWithSemanticInfo) + { + // For all fields ('oldField') of a struct do the following: + // 1. Check for 'decorations which carry semantic info' (IRSemanticDecoration, + // IRLayoutDecoration), store these if found. 
+ // * Do not propagate semantic info if the current node has *any* form of semantic + // information. + // Update varsWithSemanticInfo. + // 2. If IRStructType: + // 2a. Recurse this function with 'decorations that carry semantic info' from parent. + // 3. If not IRStructType: + // 3a Metal. Emit 'newField' equal to 'oldField', add 'decorations which carry semantic + // info'. + // + // 3a WGSL. Emit 'newField' with 'newKey' equal to 'oldField' and 'oldKey', respectively, + // where 'oldKey' is the key corresponding to 'oldField'. + // Add 'decorations which carry semantic info' to 'newField', and move all decorations + // of 'oldKey' to 'newKey'. + // 3b. Store a mapping from 'oldField' to 'newField' in 'mapFieldToField'. This info is + // needed to copy between types. + for (auto oldField : src->getFields()) + { + auto& fieldMappingNode = mapFieldToField[oldField]; + fieldMappingNode.setNode(oldField); + + // step 1 + bool foundSemanticDecor = false; + auto oldKey = oldField->getKey(); + IRSemanticDecoration* fieldSemanticDecoration = parentSemanticDecoration; + if (auto oldSemanticDecoration = oldKey->findDecoration()) + { + foundSemanticDecor = true; + fieldSemanticDecoration = oldSemanticDecoration; + parentLayout = nullptr; + } + + IRLayoutDecoration* fieldLayout = parentLayout; + if (auto oldLayout = oldKey->findDecoration()) + { + fieldLayout = oldLayout; + if (!foundSemanticDecor) + fieldSemanticDecoration = nullptr; + } + if (fieldSemanticDecoration != parentSemanticDecoration || parentLayout != fieldLayout) + varsWithSemanticInfo.add(oldField); + + // step 2a + if (auto structFieldType = as(oldField->getFieldType())) + { + _flattenNestedStructs( + builder, + dst, + structFieldType, + fieldSemanticDecoration, + fieldLayout, + fieldMappingNode, + varsWithSemanticInfo); + continue; + } + + // step 3a + auto newKey = builder.createStructKey(); + if (isTargetMetal()) + { + copyNameHintAndDebugDecorations(newKey, oldKey); + } + else + { + 
SLANG_ASSERT(isTargetWGSL()); + oldKey->transferDecorationsTo(newKey); + } + + auto newField = builder.createStructField(dst, newKey, oldField->getFieldType()); + copyNameHintAndDebugDecorations(newField, oldField); + + if (fieldSemanticDecoration) + builder.addSemanticDecoration( + newKey, + fieldSemanticDecoration->getSemanticName(), + fieldSemanticDecoration->getSemanticIndex()); + + if (fieldLayout) + { + IRLayout* oldLayout = fieldLayout->getLayout(); + List instToCopy; + // Only copy certain decorations needed for resolving system semantics + for (UInt i = 0; i < oldLayout->getOperandCount(); i++) + { + auto operand = oldLayout->getOperand(i); + if (as(operand) || as(operand) || + as(operand) || as(operand)) + instToCopy.add(operand); + } + IRVarLayout* newLayout = builder.getVarLayout(instToCopy); + builder.addLayoutDecoration(newKey, newLayout); + } + // step 3b + fieldMappingNode.setMapping(newField); + } + + return dst; + } + + // Returns a `IRStructType*` without any `IRStructType*` members. `src` may be returned if there + // was no struct flattening. + // @param mapFieldToField Behavior maps all `IRStructField` of `src` to the new struct + // `IRStructFields`s + IRStructType* maybeFlattenNestedStructs( + IRBuilder& builder, + IRStructType* src, + MapStructToFlatStruct& mapFieldToField, + HashSet& varsWithSemanticInfo) + { + // Find all values inside struct that need flattening and legalization. + bool hasStructTypeMembers = false; + for (auto field : src->getFields()) + { + if (as(field->getFieldType())) + { + hasStructTypeMembers = true; + break; + } + } + if (!hasStructTypeMembers) + return src; + + // We need to: + // 1. Make new struct 1:1 with old struct but without nestested structs (flatten) + // 2. Ensure semantic attributes propegate. This will create overlapping semantics (can be + // handled later). + // 3. Store the mapping from old to new struct fields to allow copying a old-struct to + // new-struct. 
+ builder.setInsertAfter(src); + auto newStruct = builder.createStructType(); + copyNameHintAndDebugDecorations(newStruct, src); + mapFieldToField.setNode(src); + return _flattenNestedStructs( + builder, + newStruct, + src, + nullptr, + nullptr, + mapFieldToField, + varsWithSemanticInfo); + } + + // Replaces all 'IRReturn' by copying the current 'IRReturn' to a new var of type 'newType'. + // Copying logic from 'IRReturn' to 'newType' is controlled by 'copyLogicFunc' function. + template + void _replaceAllReturnInst( + IRBuilder& builder, + IRFunc* targetFunc, + IRStructType* newType, + CopyLogicFunc copyLogicFunc) + { + for (auto block : targetFunc->getBlocks()) + { + if (auto returnInst = as(block->getTerminator())) + { + builder.setInsertBefore(returnInst); + auto returnVal = returnInst->getVal(); + returnInst->setOperand(0, copyLogicFunc(builder, newType, returnVal)); + } + } + } + + UInt _returnNonOverlappingAttributeIndex(std::set& usedSemanticIndex) + { + // Find first unused semantic index of equal semantic type + // to fill any gaps in user set semantic bindings + UInt prev = 0; + for (auto i : usedSemanticIndex) + { + if (i > prev + 1) + { + break; + } + prev = i; + } + usedSemanticIndex.insert(prev + 1); + return prev + 1; + } + + template + struct AttributeParentPair + { + IRLayoutDecoration* layoutDecor; + T* attr; + }; + + IRLayoutDecoration* _replaceAttributeOfLayout( + IRBuilder& builder, + IRLayoutDecoration* parentLayoutDecor, + IRInst* instToReplace, + IRInst* instToReplaceWith) + { + // Replace `instToReplace` with a `instToReplaceWith` + + auto layout = parentLayoutDecor->getLayout(); + // Find the exact same decoration `instToReplace` in-case multiple of the same type exist + List opList; + opList.add(instToReplaceWith); + for (UInt i = 0; i < layout->getOperandCount(); i++) + { + if (layout->getOperand(i) != instToReplace) + opList.add(layout->getOperand(i)); + } + auto newLayoutDecor = builder.addLayoutDecoration( + 
parentLayoutDecor->getParent(), + builder.getVarLayout(opList)); + parentLayoutDecor->removeAndDeallocate(); + return newLayoutDecor; + } + + // Rewrites user semantic attrs whose name carries a trailing index into a lowercased name + // plus a numeric index, rebuilding the layout decoration on its parent when anything changed. + // Returns the layout decoration that is attached afterwards: the rebuilt one if any attr was + // rewritten, otherwise `layoutDecor` unchanged. + IRLayoutDecoration* _simplifyUserSemanticNames( + IRBuilder& builder, + IRLayoutDecoration* layoutDecor) + { + // Ensure all 'ExplicitIndex' semantics such as "SV_TARGET0" are simplified into + // ("SV_TARGET", 0) using 'IRUserSemanticAttr' This is done to ensure we can check semantic + // groups using 'IRUserSemanticAttr1->getName() == IRUserSemanticAttr2->getName()' + SLANG_ASSERT(layoutDecor); + auto layout = layoutDecor->getLayout(); + List layoutOps; + layoutOps.reserve(3); + bool changed = false; + for (auto attr : layout->getAllAttrs()) + { + if (auto userSemantic = as(attr)) + { + UnownedStringSlice outName; + UnownedStringSlice outIndex; + bool hasStringIndex = splitNameAndIndex(userSemantic->getName(), outName, outIndex); + if (hasStringIndex) + { + changed = true; + auto loweredName = String(outName).toLower(); + auto loweredNameSlice = loweredName.getUnownedSlice(); + auto newDecoration = + builder.getUserSemanticAttr(loweredNameSlice, stringToInt(outIndex)); + userSemantic->replaceUsesWith(newDecoration); + userSemantic->removeAndDeallocate(); + userSemantic = newDecoration; + } + layoutOps.add(userSemantic); + continue; + } + layoutOps.add(attr); + } + if (changed) + { + auto parent = layoutDecor->parent; + layoutDecor->removeAndDeallocate(); + // Hand back the replacement: the old decoration was just removed and deallocated, + // so callers must not keep using the pointer they passed in. + layoutDecor = builder.addLayoutDecoration(parent, builder.getVarLayout(layoutOps)); + } + return layoutDecor; + } + + // Find overlapping field semantics and legalize them + void fixFieldSemanticsOfFlatStruct(IRStructType* structType) + { + // Goal is to ensure we do not have overlapping semantics for the user defined semantics: + // Note that in WGSL, the semantics can be either `builtin` without index or `location` with + // index. 
+ /* + // Assume the following code + struct Fragment + { + float4 p0 : SV_POSITION; + float2 p1 : TEXCOORD0; + float2 p2 : TEXCOORD1; + float3 p3 : COLOR0; + float3 p4 : COLOR1; + }; + + // Translates into + struct Fragment + { + float4 p0 : BUILTIN_POSITION; + float2 p1 : LOCATION_0; + float2 p2 : LOCATION_1; + float3 p3 : LOCATION_2; + float3 p4 : LOCATION_3; + }; + */ + + // For Multi-Render-Target, the semantic index must be translated to `location` with + // the same index. Assume the following code + /* + struct Fragment + { + float4 p0 : SV_TARGET1; + float4 p1 : SV_TARGET0; + }; + + // Translates into + struct Fragment + { + float4 p0 : LOCATION_1; + float4 p1 : LOCATION_0; + }; + */ + + IRBuilder builder(this->m_module); + + List overlappingSemanticsDecor; + Dictionary>> + usedSemanticIndexSemanticDecor; + + List> overlappingVarOffset; + Dictionary>> usedSemanticIndexVarOffset; + + List> overlappingUserSemantic; + Dictionary>> + usedSemanticIndexUserSemantic; + + // We store a map from old `IRLayoutDecoration*` to new `IRLayoutDecoration*` since when + // legalizing we may destroy and remake a `IRLayoutDecoration*` + Dictionary oldLayoutDecorToNew; + + // Collect all "semantic info carrying decorations". Any collected decoration will + // fill up their respective 'Dictionary>' + // to keep track of in-use offsets for a semantic type. + // Example: IRSemanticDecoration with name of "SV_TARGET1". + // * This will have SEMANTIC_TYPE of "sv_target". + // * This will use up index '1' + // + // Now if a second equal semantic "SV_TARGET1" is found, we add this decoration to + // a list of 'overlapping semantic info decorations' so we can legalize this + // 'semantic info decoration' later. + // + // NOTE: this is a flat struct, all members are children of the initial + // IRStructType. 
+ for (auto field : structType->getFields()) + { + auto key = field->getKey(); + if (auto semanticDecoration = key->findDecoration()) + { + auto semanticName = semanticDecoration->getSemanticName(); + + // sv_target is treated as a user-semantic because it should be emitted with + // @location like how the user semantics are emitted. + // For fragment shader, only sv_target will user @location, and for non-fragment + // shaders, sv_target is not valid. + bool isUserSemantic = + (semanticName.startsWithCaseInsensitive(toSlice("sv_target")) || + !semanticName.startsWithCaseInsensitive(toSlice("sv_"))); + + // Ensure names are in a uniform lowercase format so we can bunch together simmilar + // semantics. + UnownedStringSlice outName; + UnownedStringSlice outIndex; + bool hasStringIndex = splitNameAndIndex(semanticName, outName, outIndex); + + if (isTargetMetal()) + { + if (hasStringIndex) + { + auto loweredName = String(outName).toLower(); + auto loweredNameSlice = loweredName.getUnownedSlice(); + auto newDecoration = builder.addSemanticDecoration( + key, + loweredNameSlice, + stringToInt(outIndex)); + semanticDecoration->replaceUsesWith(newDecoration); + semanticDecoration->removeAndDeallocate(); + semanticDecoration = newDecoration; + } + } + else + { + // user semantics gets all same semantic-name. + auto loweredName = String(outName).toLower(); + auto loweredNameSlice = isUserSemantic ? wgslContext.userSemanticName + : loweredName.getUnownedSlice(); + auto newDecoration = builder.addSemanticDecoration( + key, + loweredNameSlice, + // hasStringIndex ? stringToInt(outIndex) : 0); + hasStringIndex ? 
stringToInt(outIndex) + : semanticDecoration->getSemanticIndex()); + semanticDecoration->replaceUsesWith(newDecoration); + semanticDecoration->removeAndDeallocate(); + semanticDecoration = newDecoration; + } + + auto& semanticUse = + usedSemanticIndexSemanticDecor[semanticDecoration->getSemanticName()]; + if (semanticUse.find(semanticDecoration->getSemanticIndex()) != semanticUse.end()) + overlappingSemanticsDecor.add(semanticDecoration); + else + semanticUse.insert(semanticDecoration->getSemanticIndex()); + } + if (auto layoutDecor = key->findDecoration()) + { + // Ensure names are in a uniform lowercase format so we can bunch together simmilar + // semantics + layoutDecor = _simplifyUserSemanticNames(builder, layoutDecor); + oldLayoutDecorToNew[layoutDecor] = layoutDecor; + auto layout = layoutDecor->getLayout(); + for (auto attr : layout->getAllAttrs()) + { + if (auto offset = as(attr)) + { + auto& semanticUse = usedSemanticIndexVarOffset[offset->getResourceKind()]; + if (semanticUse.find(offset->getOffset()) != semanticUse.end()) + overlappingVarOffset.add({layoutDecor, offset}); + else + semanticUse.insert(offset->getOffset()); + } + else if (auto userSemantic = as(attr)) + { + auto& semanticUse = usedSemanticIndexUserSemantic[userSemantic->getName()]; + if (semanticUse.find(userSemantic->getIndex()) != semanticUse.end()) + overlappingUserSemantic.add({layoutDecor, userSemantic}); + else + semanticUse.insert(userSemantic->getIndex()); + } + } + } + } + + // Legalize all overlapping 'semantic info decorations' + for (auto decor : overlappingSemanticsDecor) + { + auto newOffset = _returnNonOverlappingAttributeIndex( + usedSemanticIndexSemanticDecor[decor->getSemanticName()]); + builder.addSemanticDecoration( + decor->getParent(), + decor->getSemanticName(), + (int)newOffset); + decor->removeAndDeallocate(); + } + for (auto& varOffset : overlappingVarOffset) + { + auto newOffset = _returnNonOverlappingAttributeIndex( + 
usedSemanticIndexVarOffset[varOffset.attr->getResourceKind()]); + auto newVarOffset = builder.getVarOffsetAttr( + varOffset.attr->getResourceKind(), + newOffset, + varOffset.attr->getSpace()); + oldLayoutDecorToNew[varOffset.layoutDecor] = _replaceAttributeOfLayout( + builder, + oldLayoutDecorToNew[varOffset.layoutDecor], + varOffset.attr, + newVarOffset); + } + for (auto& userSemantic : overlappingUserSemantic) + { + auto newOffset = _returnNonOverlappingAttributeIndex( + usedSemanticIndexUserSemantic[userSemantic.attr->getName()]); + auto newUserSemantic = + builder.getUserSemanticAttr(userSemantic.attr->getName(), newOffset); + oldLayoutDecorToNew[userSemantic.layoutDecor] = _replaceAttributeOfLayout( + builder, + oldLayoutDecorToNew[userSemantic.layoutDecor], + userSemantic.attr, + newUserSemantic); + } + } + + void wrapReturnValueInStruct(EntryPointInfo entryPoint) + { + // Wrap return value into a struct if it is not already a struct. + // For example, given this entry point: + // ``` + // float4 main() : SV_Target { return float3(1,2,3); } + // ``` + // We are going to transform it into: + // ``` + // struct Output { + // float4 value : SV_Target; + // }; + // Output main() { return {float3(1,2,3)}; } + + auto func = entryPoint.entryPointFunc; + + auto returnType = func->getResultType(); + if (as(returnType)) + return; + auto entryPointLayoutDecor = func->findDecoration(); + if (!entryPointLayoutDecor) + return; + auto entryPointLayout = as(entryPointLayoutDecor->getLayout()); + if (!entryPointLayout) + return; + auto resultLayout = entryPointLayout->getResultLayout(); + + // If return type is already a struct, just make sure every field has a semantic. 
+ if (auto returnStructType = as(returnType)) + { + IRBuilder builder(func); + MapStructToFlatStruct mapOldFieldToNewField; + // Flatten result struct type to ensure we do not have nested semantics + auto flattenedStruct = maybeFlattenNestedStructs( + builder, + returnStructType, + mapOldFieldToNewField, + semanticInfoToRemove); + if (returnStructType != flattenedStruct) + { + // Replace all return-values with the flattenedStruct we made. + _replaceAllReturnInst( + builder, + func, + flattenedStruct, + [&](IRBuilder& copyBuilder, IRStructType* dstType, IRInst* srcVal) -> IRInst* + { + auto srcStructType = as(srcVal->getDataType()); + SLANG_ASSERT(srcStructType); + auto dstVal = copyBuilder.emitVar(dstType); + mapOldFieldToNewField.emitCopy<( + int)MapStructToFlatStruct::CopyOptions::StructIntoFlatStruct>( + copyBuilder, + dstVal, + srcVal); + return builder.emitLoad(dstVal); + }); + fixUpFuncType(func, flattenedStruct); + } + // Ensure non-overlapping semantics + fixFieldSemanticsOfFlatStruct(flattenedStruct); + ensureStructHasUserSemantic( + flattenedStruct, + resultLayout); + return; + } + + IRBuilder builder(func); + builder.setInsertBefore(func); + IRStructType* structType = builder.createStructType(); + auto stageText = getStageText(entryPoint.entryPointDecor->getProfile().getStage()); + builder.addNameHintDecoration( + structType, + (String(stageText) + toSlice("Output")).getUnownedSlice()); + auto key = builder.createStructKey(); + builder.addNameHintDecoration(key, toSlice("output")); + builder.addLayoutDecoration(key, resultLayout); + builder.createStructField(structType, key, returnType); + IRStructTypeLayout::Builder structTypeLayoutBuilder(&builder); + structTypeLayoutBuilder.addField(key, resultLayout); + auto typeLayout = structTypeLayoutBuilder.build(); + IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout); + auto varLayout = varLayoutBuilder.build(); + ensureStructHasUserSemantic(structType, varLayout); + + _replaceAllReturnInst( + builder, 
+ func, + structType, + [](IRBuilder& copyBuilder, IRStructType* dstType, IRInst* srcVal) -> IRInst* + { return copyBuilder.emitMakeStruct(dstType, 1, &srcVal); }); + + // Assign an appropriate system value semantic for stage output + auto stage = entryPoint.entryPointDecor->getProfile().getStage(); + switch (stage) + { + case Stage::Compute: + case Stage::Fragment: + { + if (isTargetMetal()) + { + builder.addTargetSystemValueDecoration(key, toSlice("color(0)")); + } + else + { + IRInst* operands[] = { + builder.getStringValue(wgslContext.userSemanticName), + builder.getIntValue(builder.getIntType(), 0)}; + builder.addDecoration( + key, + kIROp_SemanticDecoration, + operands, + SLANG_COUNT_OF(operands)); + } + break; + } + case Stage::Vertex: + { + builder.addTargetSystemValueDecoration(key, toSlice("position")); + break; + } + default: + SLANG_ASSERT(false); + return; + } + + fixUpFuncType(func, structType); + } + + IRInst* tryConvertValue(IRBuilder& builder, IRInst* val, IRType* toType) + { + auto fromType = val->getFullType(); + if (auto fromVector = as(fromType)) + { + if (auto toVector = as(toType)) + { + if (fromVector->getElementCount() != toVector->getElementCount()) + { + fromType = builder.getVectorType( + fromVector->getElementType(), + toVector->getElementCount()); + val = builder.emitVectorReshape(fromType, val); + } + } + else if (as(toType)) + { + UInt index = 0; + val = builder.emitSwizzle(fromVector->getElementType(), val, 1, &index); + if (toType->getOp() == kIROp_VoidType) + return nullptr; + } + } + else if (auto fromBasicType = as(fromType)) + { + if (fromBasicType->getOp() == kIROp_VoidType) + return nullptr; + if (!as(toType)) + return nullptr; + if (toType->getOp() == kIROp_VoidType) + return nullptr; + } + else + { + return nullptr; + } + return builder.emitCast(toType, val); + } + + struct SystemValLegalizationWorkItem + { + IRInst* var; + + // Only valid for WGSL. 
+ IRType* varType; + + String attrName; + UInt attrIndex; + }; + + // varType is only valid for WGSL. + std::optional tryToMakeSystemValWorkItem( + IRInst* var, + IRType* varType) + { + if (auto semanticDecoration = var->findDecoration()) + { + if (semanticDecoration->getSemanticName().startsWithCaseInsensitive(toSlice("sv_"))) + { + return { + {var, + varType, + String(semanticDecoration->getSemanticName()).toLower(), + (UInt)semanticDecoration->getSemanticIndex()}}; + } + } + + auto layoutDecor = var->findDecoration(); + if (!layoutDecor) + return {}; + auto sysValAttr = layoutDecor->findAttr(); + if (!sysValAttr) + return {}; + auto semanticName = String(sysValAttr->getName()); + auto sysAttrIndex = sysValAttr->getIndex(); + + return {{var, varType, semanticName, sysAttrIndex}}; + } + + List collectSystemValFromEntryPoint(EntryPointInfo entryPoint) + { + List systemValWorkItems; + for (auto param : entryPoint.entryPointFunc->getParams()) + { + std::optional maybeWorkItem; + + if (isTargetMetal()) + { + maybeWorkItem = tryToMakeSystemValWorkItem(param, nullptr); + } + else + { + if (auto structType = as(param->getDataType())) + { + for (auto field : structType->getFields()) + { + // Nested struct-s are flattened already by flattenInputParameters(). 
+ SLANG_ASSERT(!as(field->getFieldType())); + + auto key = field->getKey(); + auto fieldType = field->getFieldType(); + auto maybeWorkItem = tryToMakeSystemValWorkItem(key, fieldType); + if (maybeWorkItem.has_value()) + systemValWorkItems.add(std::move(maybeWorkItem.value())); + } + continue; + } + maybeWorkItem = tryToMakeSystemValWorkItem(param, param->getFullType()); + } + + if (maybeWorkItem.has_value()) + systemValWorkItems.add(std::move(maybeWorkItem.value())); + } + return systemValWorkItems; + } + + void legalizeSystemValue(EntryPointInfo entryPoint, SystemValLegalizationWorkItem& workItem) + { + IRBuilder builder(entryPoint.entryPointFunc); + + auto var = workItem.var; + + auto varType = workItem.varType; + // XXX: can remove this by also passing this to Metal SV info? + if (isTargetMetal()) + { + varType = var->getFullType(); + } + + auto semanticName = workItem.attrName; + + auto indexAsString = String(workItem.attrIndex); + SystemValueInfo info = getSystemValueInfo(semanticName, &indexAsString, var); + if (info.isSpecial) + { + SLANG_ASSERT(isTargetMetal()); + if (info.systemValueNameEnum == SystemValueSemanticName::InnerCoverage) + { + // Metal does not support conservative rasterization, so this is always false. 
+ auto val = builder.getBoolValue(false); + var->replaceUsesWith(val); + var->removeAndDeallocate(); + } + else if (info.systemValueNameEnum == SystemValueSemanticName::GroupIndex) + { + // Ensure we have a cached "sv_groupthreadid" in our entry point + if (!metalContext.entryPointToGroupThreadId.containsKey(entryPoint.entryPointFunc)) + { + auto systemValWorkItems = collectSystemValFromEntryPoint(entryPoint); + for (auto i : systemValWorkItems) + { + auto indexAsStringGroupThreadId = String(i.attrIndex); + if (getSystemValueInfo(i.attrName, &indexAsStringGroupThreadId, i.var) + .systemValueNameEnum == SystemValueSemanticName::GroupThreadID) + { + metalContext.entryPointToGroupThreadId[entryPoint.entryPointFunc] = + i.var; + } + } + if (!metalContext.entryPointToGroupThreadId.containsKey( + entryPoint.entryPointFunc)) + { + // Add the missing groupthreadid needed to compute sv_groupindex + IRBuilder groupThreadIdBuilder(builder); + groupThreadIdBuilder.setInsertInto( + entryPoint.entryPointFunc->getFirstBlock()); + auto groupThreadId = groupThreadIdBuilder.emitParamAtHead( + getMetalGroupThreadIdType(groupThreadIdBuilder)); + metalContext.entryPointToGroupThreadId[entryPoint.entryPointFunc] = + groupThreadId; + groupThreadIdBuilder.addNameHintDecoration( + groupThreadId, + metalContext.groupThreadIDString); + + // Since "sv_groupindex" will be translated out to a global var and no + // longer be considered a system value we can reuse its layout and semantic + // info + Index foundRequiredDecorations = 0; + IRLayoutDecoration* layoutDecoration = nullptr; + UInt semanticIndex = 0; + for (auto decoration : var->getDecorations()) + { + if (auto layoutDecorationTmp = as(decoration)) + { + layoutDecoration = layoutDecorationTmp; + foundRequiredDecorations++; + } + else if (auto semanticDecoration = as(decoration)) + { + semanticIndex = semanticDecoration->getSemanticIndex(); + groupThreadIdBuilder.addSemanticDecoration( + groupThreadId, + 
metalContext.groupThreadIDString, + (int)semanticIndex); + foundRequiredDecorations++; + } + if (foundRequiredDecorations >= 2) + break; + } + SLANG_ASSERT(layoutDecoration); + layoutDecoration->removeFromParent(); + layoutDecoration->insertAtStart(groupThreadId); + SystemValLegalizationWorkItem newWorkItem = { + groupThreadId, + nullptr, + metalContext.groupThreadIDString, + semanticIndex}; + legalizeSystemValue(entryPoint, newWorkItem); + } + } + + IRBuilder svBuilder(builder.getModule()); + svBuilder.setInsertBefore(entryPoint.entryPointFunc->getFirstOrdinaryInst()); + auto computeExtent = emitCalcGroupExtents( + svBuilder, + entryPoint.entryPointFunc, + builder.getVectorType( + builder.getUIntType(), + builder.getIntValue(builder.getIntType(), 3))); + auto groupIndexCalc = emitCalcGroupIndex( + svBuilder, + metalContext.entryPointToGroupThreadId[entryPoint.entryPointFunc], + computeExtent); + svBuilder.addNameHintDecoration( + groupIndexCalc, + UnownedStringSlice("sv_groupindex")); + + var->replaceUsesWith(groupIndexCalc); + var->removeAndDeallocate(); + } + } + + if (info.isUnsupported) + { + reportUnsupportedSystemAttribute(var, semanticName); + return; + } + if (!info.permittedTypes.getCount()) + return; + + builder.addTargetSystemValueDecoration(var, info.systemValueName.getUnownedSlice()); + + bool varTypeIsPermitted = false; + for (auto& permittedType : info.permittedTypes) + { + varTypeIsPermitted = varTypeIsPermitted || permittedType == varType; + } + + if (!varTypeIsPermitted) + { + // Note: we do not currently prefer any conversion + // example: + // * allowed types for semantic: `float4`, `uint4`, `int4` + // * user used, `float2` + // * Slang will equally prefer `float4` to `uint4` to `int4`. + // This means the type may lose data if slang selects `uint4` or `int4`. 
+ bool foundAConversion = false; + for (auto permittedType : info.permittedTypes) + { + var->setFullType(permittedType); + builder.setInsertBefore( + entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst()); + + // get uses before we `tryConvertValue` since this creates a new use + List uses; + for (auto use = var->firstUse; use; use = use->nextUse) + uses.add(use); + + auto convertedValue = tryConvertValue(builder, var, varType); + if (convertedValue == nullptr) + continue; + + foundAConversion = true; + copyNameHintAndDebugDecorations(convertedValue, var); + + for (auto use : uses) + builder.replaceOperand(use, convertedValue); + } + if (!foundAConversion) + { + // If we can't convert the value, report an error. + for (auto permittedType : info.permittedTypes) + { + StringBuilder typeNameSB; + getTypeNameHint(typeNameSB, permittedType); + m_sink->diagnose( + var->sourceLoc, + Diagnostics::systemValueTypeIncompatible, + semanticName, + typeNameSB.produceString()); + } + } + } + } + + void legalizeSystemValueParameters(EntryPointInfo entryPoint) + { + List systemValWorkItems = + collectSystemValFromEntryPoint(entryPoint); + + for (auto index = 0; index < systemValWorkItems.getCount(); index++) + { + legalizeSystemValue(entryPoint, systemValWorkItems[index]); + } + fixUpFuncType(entryPoint.entryPointFunc); + } + + void legalizeEntryPoint(EntryPointInfo entryPoint) + { + // If the entrypoint is receiving varying inputs as a pointer, turn it into a value. 
+ depointerizeInputParams(entryPoint.entryPointFunc); + + // XXX: Enable these for WGSL + if (isTargetMetal()) + { + hoistEntryPointParameterFromStruct(entryPoint); + packStageInParameters(entryPoint); + } + + // Input Parameter Legalize + flattenInputParameters(entryPoint); + + // System Value Legalize + legalizeSystemValueParameters(entryPoint); + + // Output Value Legalize + wrapReturnValueInStruct(entryPoint); + + + // Other Legalize + switch (entryPoint.entryPointDecor->getProfile().getStage()) + { + case Stage::Amplification: + SLANG_ASSERT(isTargetMetal()); + legalizeMetalDispatchMeshPayload(entryPoint); + break; + case Stage::Mesh: + SLANG_ASSERT(isTargetMetal()); + legalizeMetalMeshEntryPoint(entryPoint); + break; + default: + break; + } + } + + void legalizeEntryPoints(List& entryPoints) + { + for (auto entryPoint : entryPoints) + legalizeEntryPoint(entryPoint); + removeSemanticLayoutsFromLegalizedStructs(); + } + + // ****************************************************************** + // Metal specific Legalization Logic + // ****************************************************************** + + struct MetalContext + { + ShortList permittedTypes_sv_target; + Dictionary entryPointToGroupThreadId; + const UnownedStringSlice groupThreadIDString = UnownedStringSlice("sv_groupthreadid"); + } metalContext; + + IRType* getMetalGroupThreadIdType(IRBuilder& builder) + { + SLANG_ASSERT(isTargetMetal()); + + return builder.getVectorType( + builder.getBasicType(BaseType::UInt), + builder.getIntValue(builder.getIntType(), 3)); + } + + // Get all permitted types of "sv_target" for Metal + ShortList& getMetalPermittedTypes_sv_target(IRBuilder& builder) + { + SLANG_ASSERT(isTargetMetal()); + + metalContext.permittedTypes_sv_target.reserveOverflowBuffer(5 * 4); + if (metalContext.permittedTypes_sv_target.getCount() == 0) + { + for (auto baseType : + {BaseType::Float, + BaseType::Half, + BaseType::Int, + BaseType::UInt, + BaseType::Int16, + BaseType::UInt16}) + { + for 
(IRIntegerValue i = 1; i <= 4; i++) + { + metalContext.permittedTypes_sv_target.add( + builder.getVectorType(builder.getBasicType(baseType), i)); + } + } + } + return metalContext.permittedTypes_sv_target; + } + + SystemValueInfo getMetalSystemValueInfo( + String inSemanticName, + String* optionalSemanticIndex, + IRInst* parentVar) + { + SLANG_ASSERT(isTargetMetal()); + + IRBuilder builder(m_module); + SystemValueInfo result = {}; + UnownedStringSlice semanticName; + UnownedStringSlice semanticIndex; + + auto hasExplicitIndex = + splitNameAndIndex(inSemanticName.getUnownedSlice(), semanticName, semanticIndex); + if (!hasExplicitIndex && optionalSemanticIndex) + semanticIndex = optionalSemanticIndex->getUnownedSlice(); + + result.systemValueNameEnum = convertSystemValueSemanticNameToEnum(semanticName); + + switch (result.systemValueNameEnum) + { + case SystemValueSemanticName::Position: + { + result.systemValueName = toSlice("position"); + result.permittedTypes.add(builder.getVectorType( + builder.getBasicType(BaseType::Float), + builder.getIntValue(builder.getIntType(), 4))); + break; + } + case SystemValueSemanticName::ClipDistance: + { + result.isSpecial = true; + break; + } + case SystemValueSemanticName::CullDistance: + { + result.isSpecial = true; + break; + } + case SystemValueSemanticName::Coverage: + { + result.systemValueName = toSlice("sample_mask"); + result.permittedTypes.add(builder.getBasicType(BaseType::UInt)); + break; + } + case SystemValueSemanticName::InnerCoverage: + { + result.isSpecial = true; + break; + } + case SystemValueSemanticName::Depth: + { + result.systemValueName = toSlice("depth(any)"); + result.permittedTypes.add(builder.getBasicType(BaseType::Float)); + break; + } + case SystemValueSemanticName::DepthGreaterEqual: + { + result.systemValueName = toSlice("depth(greater)"); + result.permittedTypes.add(builder.getBasicType(BaseType::Float)); + break; + } + case SystemValueSemanticName::DepthLessEqual: + { + result.systemValueName = 
toSlice("depth(less)"); + result.permittedTypes.add(builder.getBasicType(BaseType::Float)); + break; + } + case SystemValueSemanticName::DispatchThreadID: + { + result.systemValueName = toSlice("thread_position_in_grid"); + result.permittedTypes.add(builder.getVectorType( + builder.getBasicType(BaseType::UInt), + builder.getIntValue(builder.getIntType(), 3))); + break; + } + case SystemValueSemanticName::DomainLocation: + { + result.systemValueName = toSlice("position_in_patch"); + result.permittedTypes.add(builder.getVectorType( + builder.getBasicType(BaseType::Float), + builder.getIntValue(builder.getIntType(), 3))); + result.permittedTypes.add(builder.getVectorType( + builder.getBasicType(BaseType::Float), + builder.getIntValue(builder.getIntType(), 2))); + break; + } + case SystemValueSemanticName::GroupID: + { + result.systemValueName = toSlice("threadgroup_position_in_grid"); + result.permittedTypes.add(builder.getVectorType( + builder.getBasicType(BaseType::UInt), + builder.getIntValue(builder.getIntType(), 3))); + break; + } + case SystemValueSemanticName::GroupIndex: + { + result.isSpecial = true; + break; + } + case SystemValueSemanticName::GroupThreadID: + { + result.systemValueName = toSlice("thread_position_in_threadgroup"); + result.permittedTypes.add(getMetalGroupThreadIdType(builder)); + break; + } + case SystemValueSemanticName::GSInstanceID: + { + result.isUnsupported = true; + break; + } + case SystemValueSemanticName::InstanceID: + { + result.systemValueName = toSlice("instance_id"); + result.permittedTypes.add(builder.getBasicType(BaseType::UInt)); + break; + } + case SystemValueSemanticName::IsFrontFace: + { + result.systemValueName = toSlice("front_facing"); + result.permittedTypes.add(builder.getBasicType(BaseType::Bool)); + break; + } + case SystemValueSemanticName::OutputControlPointID: + { + // In metal, a hull shader is just a compute shader. + // This needs to be handled separately, by lowering into an ordinary buffer. 
+ break; + } + case SystemValueSemanticName::PointSize: + { + result.systemValueName = toSlice("point_size"); + result.permittedTypes.add(builder.getBasicType(BaseType::Float)); + break; + } + case SystemValueSemanticName::PrimitiveID: + { + result.systemValueName = toSlice("primitive_id"); + result.permittedTypes.add(builder.getBasicType(BaseType::UInt)); + result.permittedTypes.add(builder.getBasicType(BaseType::UInt16)); + break; + } + case SystemValueSemanticName::RenderTargetArrayIndex: + { + result.systemValueName = toSlice("render_target_array_index"); + result.permittedTypes.add(builder.getBasicType(BaseType::UInt)); + result.permittedTypes.add(builder.getBasicType(BaseType::UInt16)); + break; + } + case SystemValueSemanticName::SampleIndex: + { + result.systemValueName = toSlice("sample_id"); + result.permittedTypes.add(builder.getBasicType(BaseType::UInt)); + break; + } + case SystemValueSemanticName::StencilRef: + { + result.systemValueName = toSlice("stencil"); + result.permittedTypes.add(builder.getBasicType(BaseType::UInt)); + break; + } + case SystemValueSemanticName::TessFactor: + { + // Tessellation factor outputs should be lowered into a write into a normal buffer. + break; + } + case SystemValueSemanticName::VertexID: + { + result.systemValueName = toSlice("vertex_id"); + result.permittedTypes.add(builder.getBasicType(BaseType::UInt)); + break; + } + case SystemValueSemanticName::ViewID: + { + result.isUnsupported = true; + break; + } + case SystemValueSemanticName::ViewportArrayIndex: + { + result.systemValueName = toSlice("viewport_array_index"); + result.permittedTypes.add(builder.getBasicType(BaseType::UInt)); + result.permittedTypes.add(builder.getBasicType(BaseType::UInt16)); + break; + } + case SystemValueSemanticName::Target: + { + result.systemValueName = + (StringBuilder() + << "color(" << (semanticIndex.getLength() != 0 ? 
semanticIndex : toSlice("0")) + << ")") + .produceString(); + result.permittedTypes = getMetalPermittedTypes_sv_target(builder); + + break; + } + case SystemValueSemanticName::StartVertexLocation: + { + result.systemValueName = toSlice("base_vertex"); + result.permittedTypes.add(builder.getBasicType(BaseType::UInt)); + break; + } + case SystemValueSemanticName::StartInstanceLocation: + { + result.systemValueName = toSlice("base_instance"); + result.permittedTypes.add(builder.getBasicType(BaseType::UInt)); + break; + } + default: + m_sink->diagnose( + parentVar, + Diagnostics::unimplementedSystemValueSemantic, + semanticName); + return result; + } + return result; + } + + void legalizeMetalDispatchMeshPayload(EntryPointInfo entryPoint) + { + SLANG_ASSERT(isTargetMetal()); + + // Find out DispatchMesh function + IRGlobalValueWithCode* dispatchMeshFunc = nullptr; + for (const auto globalInst : entryPoint.entryPointFunc->getModule()->getGlobalInsts()) + { + if (const auto func = as(globalInst)) + { + if (const auto dec = func->findDecoration()) + { + if (dec->getName() == "DispatchMesh") + { + SLANG_ASSERT(!dispatchMeshFunc && "Multiple DispatchMesh functions found"); + dispatchMeshFunc = func; + } + } + } + } + + if (!dispatchMeshFunc) + return; + + IRBuilder builder{entryPoint.entryPointFunc->getModule()}; + + // We'll rewrite the call to use mesh_grid_properties.set_threadgroups_per_grid + traverseUses( + dispatchMeshFunc, + [&](const IRUse* use) + { + if (const auto call = as(use->getUser())) + { + SLANG_ASSERT(call->getArgCount() == 4); + const auto payload = call->getArg(3); + + const auto payloadPtrType = + composeGetters(payload, &IRInst::getDataType); + SLANG_ASSERT(payloadPtrType); + const auto payloadType = payloadPtrType->getValueType(); + SLANG_ASSERT(payloadType); + + builder.setInsertBefore( + entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst()); + const auto annotatedPayloadType = builder.getPtrType( + kIROp_RefType, + 
payloadPtrType->getValueType(), + AddressSpace::MetalObjectData); + auto packedParam = builder.emitParam(annotatedPayloadType); + builder.addExternCppDecoration(packedParam, toSlice("_slang_mesh_payload")); + IRVarLayout::Builder varLayoutBuilder( + &builder, + IRTypeLayout::Builder{&builder}.build()); + + // Add the MetalPayload resource info, so we can emit [[payload]] + varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload); + auto paramVarLayout = varLayoutBuilder.build(); + builder.addLayoutDecoration(packedParam, paramVarLayout); + + // Now we replace the call to DispatchMesh with a call to the mesh grid + // properties But first we need to create the parameter + const auto meshGridPropertiesType = builder.getMetalMeshGridPropertiesType(); + auto mgp = builder.emitParam(meshGridPropertiesType); + builder.addExternCppDecoration(mgp, toSlice("_slang_mgp")); + } + }); + } + + void legalizeMetalMeshEntryPoint(EntryPointInfo entryPoint) + { + SLANG_ASSERT(isTargetMetal()); + + auto func = entryPoint.entryPointFunc; + + IRBuilder builder{func->getModule()}; + for (auto param : func->getParams()) + { + if (param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration)) + { + IRVarLayout::Builder varLayoutBuilder( + &builder, + IRTypeLayout::Builder{&builder}.build()); + + varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload); + auto paramVarLayout = varLayoutBuilder.build(); + builder.addLayoutDecoration(param, paramVarLayout); + + IRPtrTypeBase* type = as(param->getDataType()); + + const auto annotatedPayloadType = builder.getPtrType( + kIROp_ConstRefType, + type->getValueType(), + AddressSpace::MetalObjectData); + + param->setFullType(annotatedPayloadType); + } + } + IROutputTopologyDecoration* outputDeco = + entryPoint.entryPointFunc->findDecoration(); + if (outputDeco == nullptr) + { + SLANG_UNEXPECTED("Mesh shader output decoration missing"); + return; + } + const auto topology = outputDeco->getTopology(); + const auto topStr = 
topology->getStringSlice(); + UInt topologyEnum = 0; + if (topStr.caseInsensitiveEquals(toSlice("point"))) + { + topologyEnum = 1; + } + else if (topStr.caseInsensitiveEquals(toSlice("line"))) + { + topologyEnum = 2; + } + else if (topStr.caseInsensitiveEquals(toSlice("triangle"))) + { + topologyEnum = 3; + } + else + { + SLANG_UNEXPECTED("unknown topology"); + return; + } + + IRInst* topologyConst = builder.getIntValue(builder.getIntType(), topologyEnum); + + IRType* vertexType = nullptr; + IRType* indicesType = nullptr; + IRType* primitiveType = nullptr; + + IRInst* maxVertices = nullptr; + IRInst* maxPrimitives = nullptr; + + IRInst* verticesParam = nullptr; + IRInst* indicesParam = nullptr; + IRInst* primitivesParam = nullptr; + for (auto param : func->getParams()) + { + if (param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration)) + { + IRVarLayout::Builder varLayoutBuilder( + &builder, + IRTypeLayout::Builder{&builder}.build()); + + varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload); + auto paramVarLayout = varLayoutBuilder.build(); + builder.addLayoutDecoration(param, paramVarLayout); + } + if (param->findDecorationImpl(kIROp_VerticesDecoration)) + { + auto vertexRefType = as(param->getDataType()); + auto vertexOutputType = as(vertexRefType->getValueType()); + vertexType = vertexOutputType->getElementType(); + maxVertices = vertexOutputType->getMaxElementCount(); + SLANG_ASSERT(vertexType); + + verticesParam = param; + auto vertStruct = as(vertexType); + for (auto field : vertStruct->getFields()) + { + auto key = field->getKey(); + if (auto deco = key->findDecoration()) + { + if (deco->getSemanticName().caseInsensitiveEquals(toSlice("sv_position"))) + { + builder.addTargetSystemValueDecoration(key, toSlice("position")); + } + } + } + } + if (param->findDecorationImpl(kIROp_IndicesDecoration)) + { + auto indicesRefType = (IRConstRefType*)param->getDataType(); + auto indicesOutputType = (IRIndicesType*)indicesRefType->getValueType(); + 
indicesType = indicesOutputType->getElementType(); + maxPrimitives = indicesOutputType->getMaxElementCount(); + SLANG_ASSERT(indicesType); + + indicesParam = param; + } + if (param->findDecorationImpl(kIROp_PrimitivesDecoration)) + { + auto primitivesRefType = (IRConstRefType*)param->getDataType(); + auto primitivesOutputType = (IRPrimitivesType*)primitivesRefType->getValueType(); + primitiveType = primitivesOutputType->getElementType(); + SLANG_ASSERT(primitiveType); + + primitivesParam = param; + auto primStruct = as(primitiveType); + for (auto field : primStruct->getFields()) + { + auto key = field->getKey(); + if (auto deco = key->findDecoration()) + { + if (deco->getSemanticName().caseInsensitiveEquals( + toSlice("sv_primitiveid"))) + { + builder.addTargetSystemValueDecoration(key, toSlice("primitive_id")); + } + } + } + } + } + if (primitiveType == nullptr) + { + primitiveType = builder.getVoidType(); + } + builder.setInsertBefore(entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst()); + + auto meshParam = builder.emitParam(builder.getMetalMeshType( + vertexType, + primitiveType, + maxVertices, + maxPrimitives, + topologyConst)); + builder.addExternCppDecoration(meshParam, toSlice("_slang_mesh")); + + + verticesParam->removeFromParent(); + verticesParam->removeAndDeallocate(); + + indicesParam->removeFromParent(); + indicesParam->removeAndDeallocate(); + + if (primitivesParam != nullptr) + { + primitivesParam->removeFromParent(); + primitivesParam->removeAndDeallocate(); + } + } + + // ****************************************************************** + // WGSL specific Legalization Logic + // ****************************************************************** + + struct WGSLContext + { + UnownedStringSlice userSemanticName = toSlice("user_semantic"); + } wgslContext; + + SystemValueInfo getWGSLSystemValueInfo( + String inSemanticName, + String* optionalSemanticIndex, + IRInst* parentVar) + { + SLANG_ASSERT(isTargetWGSL()); + + IRBuilder 
builder(m_module); + SystemValueInfo result = {}; + UnownedStringSlice semanticName; + UnownedStringSlice semanticIndex; + + auto hasExplicitIndex = + splitNameAndIndex(inSemanticName.getUnownedSlice(), semanticName, semanticIndex); + if (!hasExplicitIndex && optionalSemanticIndex) + semanticIndex = optionalSemanticIndex->getUnownedSlice(); + + result.systemValueNameEnum = convertSystemValueSemanticNameToEnum(semanticName); + + switch (result.systemValueNameEnum) + { + + case SystemValueSemanticName::CullDistance: + { + result.isUnsupported = true; + } + break; + + case SystemValueSemanticName::ClipDistance: + { + // TODO: Implement this based on the 'clip-distances' feature in WGSL + // https: // www.w3.org/TR/webgpu/#dom-gpufeaturename-clip-distances + result.isUnsupported = true; + } + break; + + case SystemValueSemanticName::Coverage: + { + result.systemValueName = toSlice("sample_mask"); + result.permittedTypes.add(builder.getUIntType()); + } + break; + + case SystemValueSemanticName::Depth: + { + result.systemValueName = toSlice("frag_depth"); + result.permittedTypes.add(builder.getBasicType(BaseType::Float)); + } + break; + + case SystemValueSemanticName::DepthGreaterEqual: + case SystemValueSemanticName::DepthLessEqual: + { + result.isUnsupported = true; + } + break; + + case SystemValueSemanticName::DispatchThreadID: + { + result.systemValueName = toSlice("global_invocation_id"); + result.permittedTypes.add(builder.getVectorType( + builder.getBasicType(BaseType::UInt), + builder.getIntValue(builder.getIntType(), 3))); + } + break; + + case SystemValueSemanticName::DomainLocation: + { + result.isUnsupported = true; + } + break; + + case SystemValueSemanticName::GroupID: + { + result.systemValueName = toSlice("workgroup_id"); + result.permittedTypes.add(builder.getVectorType( + builder.getBasicType(BaseType::UInt), + builder.getIntValue(builder.getIntType(), 3))); + } + break; + + case SystemValueSemanticName::GroupIndex: + { + result.systemValueName = 
toSlice("local_invocation_index"); + result.permittedTypes.add(builder.getUIntType()); + } + break; + + case SystemValueSemanticName::GroupThreadID: + { + result.systemValueName = toSlice("local_invocation_id"); + result.permittedTypes.add(builder.getVectorType( + builder.getBasicType(BaseType::UInt), + builder.getIntValue(builder.getIntType(), 3))); + } + break; + + case SystemValueSemanticName::GSInstanceID: + { + // No Geometry shaders in WGSL + result.isUnsupported = true; + } + break; + + case SystemValueSemanticName::InnerCoverage: + { + result.isUnsupported = true; + } + break; + + case SystemValueSemanticName::InstanceID: + { + result.systemValueName = toSlice("instance_index"); + result.permittedTypes.add(builder.getUIntType()); + } + break; + + case SystemValueSemanticName::IsFrontFace: + { + result.systemValueName = toSlice("front_facing"); + result.permittedTypes.add(builder.getBoolType()); + } + break; + + case SystemValueSemanticName::OutputControlPointID: + case SystemValueSemanticName::PointSize: + { + result.isUnsupported = true; + } + break; + + case SystemValueSemanticName::Position: + { + result.systemValueName = toSlice("position"); + result.permittedTypes.add(builder.getVectorType( + builder.getBasicType(BaseType::Float), + builder.getIntValue(builder.getIntType(), 4))); + break; + } + + case SystemValueSemanticName::PrimitiveID: + case SystemValueSemanticName::RenderTargetArrayIndex: + { + result.isUnsupported = true; + break; + } + + case SystemValueSemanticName::SampleIndex: + { + result.systemValueName = toSlice("sample_index"); + result.permittedTypes.add(builder.getUIntType()); + break; + } + + case SystemValueSemanticName::StencilRef: + case SystemValueSemanticName::Target: + case SystemValueSemanticName::TessFactor: + { + result.isUnsupported = true; + break; + } + + case SystemValueSemanticName::VertexID: + { + result.systemValueName = toSlice("vertex_index"); + result.permittedTypes.add(builder.getUIntType()); + break; + } + + case 
SystemValueSemanticName::ViewID: + case SystemValueSemanticName::ViewportArrayIndex: + case SystemValueSemanticName::StartVertexLocation: + case SystemValueSemanticName::StartInstanceLocation: + { + result.isUnsupported = true; + break; + } + + default: + { + m_sink->diagnose( + parentVar, + Diagnostics::unimplementedSystemValueSemantic, + semanticName); + return result; + } + } + + return result; + } +}; + +void legalizeEntryPointVaryingParamsForMetal( + IRModule* module, + DiagnosticSink* sink, + List& entryPoints) +{ + LegalizeShaderEntryPointContext context( + module, + sink, + LegalizeShaderEntryPointContext::LegalizeTarget::Metal); + context.legalizeEntryPoints(entryPoints); +} + +void legalizeEntryPointVaryingParamsForWGSL( + IRModule* module, + DiagnosticSink* sink, + List& entryPoints) +{ + LegalizeShaderEntryPointContext context( + module, + sink, + LegalizeShaderEntryPointContext::LegalizeTarget::WGSL); + context.legalizeEntryPoints(entryPoints); +} + } // namespace Slang diff --git a/source/slang/slang-ir-legalize-varying-params.h b/source/slang/slang-ir-legalize-varying-params.h index efd61e87cf..e742f30936 100644 --- a/source/slang/slang-ir-legalize-varying-params.h +++ b/source/slang/slang-ir-legalize-varying-params.h @@ -14,19 +14,27 @@ struct IRVectorType; struct IRBuilder; struct IREntryPointDecoration; +struct EntryPointInfo +{ + IRFunc* entryPointFunc; + IREntryPointDecoration* entryPointDecor; +}; + void legalizeEntryPointVaryingParamsForCPU(IRModule* module, DiagnosticSink* sink); void legalizeEntryPointVaryingParamsForCUDA(IRModule* module, DiagnosticSink* sink); -void depointerizeInputParams(IRFunc* entryPoint); - -// (#4375) Once `slang-ir-metal-legalize.cpp` is merged with -// `slang-ir-legalize-varying-params.cpp`, move the following -// below into `slang-ir-legalize-varying-params.cpp` as well +void legalizeEntryPointVaryingParamsForMetal( + IRModule* module, + DiagnosticSink* sink, + List& entryPoints); -IRInst* 
emitCalcGroupExtents(IRBuilder& builder, IRFunc* entryPoint, IRVectorType* type); +void legalizeEntryPointVaryingParamsForWGSL( + IRModule* module, + DiagnosticSink* sink, + List& entryPoints); -IRInst* emitCalcGroupIndex(IRBuilder& builder, IRInst* groupThreadID, IRInst* groupExtents); +void depointerizeInputParams(IRFunc* entryPoint); // SystemValueSemanticName member definition macro #define SYSTEM_VALUE_SEMANTIC_NAMES(M) \ diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp index 5bfa62e4af..0d58bdd14c 100644 --- a/source/slang/slang-ir-metal-legalize.cpp +++ b/source/slang/slang-ir-metal-legalize.cpp @@ -7,1950 +7,10 @@ #include "slang-ir-specialize-address-space.h" #include "slang-ir-util.h" #include "slang-ir.h" -#include "slang-parameter-binding.h" - -#include namespace Slang { -struct EntryPointInfo -{ - IRFunc* entryPointFunc; - IREntryPointDecoration* entryPointDecor; -}; - -const UnownedStringSlice groupThreadIDString = UnownedStringSlice("sv_groupthreadid"); -struct LegalizeMetalEntryPointContext -{ - ShortList permittedTypes_sv_target; - Dictionary entryPointToGroupThreadId; - HashSet semanticInfoToRemove; - - DiagnosticSink* m_sink; - IRModule* m_module; - - LegalizeMetalEntryPointContext(DiagnosticSink* sink, IRModule* module) - : m_sink(sink), m_module(module) - { - } - - void removeSemanticLayoutsFromLegalizedStructs() - { - // Metal does not allow duplicate attributes to appear in the same shader. - // If we emit our own struct with `[[color(0)]`, all existing uses of `[[color(0)]]` - // must be removed. 
- for (auto field : semanticInfoToRemove) - { - auto key = field->getKey(); - // Some decorations appear twice, destroy all found - for (;;) - { - if (auto semanticDecor = key->findDecoration()) - { - semanticDecor->removeAndDeallocate(); - continue; - } - else if (auto layoutDecor = key->findDecoration()) - { - layoutDecor->removeAndDeallocate(); - continue; - } - break; - } - } - } - - void hoistEntryPointParameterFromStruct(EntryPointInfo entryPoint) - { - // If an entry point has a input parameter with a struct type, we want to hoist out - // all the fields of the struct type to be individual parameters of the entry point. - // This will canonicalize the entry point signature, so we can handle all cases uniformly. - - // For example, given an entry point: - // ``` - // struct VertexInput { float3 pos; float 2 uv; int vertexId : SV_VertexID}; - // void main(VertexInput vin) { ... } - // ``` - // We will transform it to: - // ``` - // void main(float3 pos, float2 uv, int vertexId : SV_VertexID) { - // VertexInput vin = {pos,uv,vertexId}; - // ... - // } - // ``` - - auto func = entryPoint.entryPointFunc; - List paramsToProcess; - for (auto param : func->getParams()) - { - if (as(param->getDataType())) - { - paramsToProcess.add(param); - } - } - - IRBuilder builder(func); - builder.setInsertBefore(func); - for (auto param : paramsToProcess) - { - auto structType = as(param->getDataType()); - builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); - auto varLayout = findVarLayout(param); - - // If `param` already has a semantic, we don't want to hoist its fields out. 
- if (varLayout->findSystemValueSemanticAttr() != nullptr || - param->findDecoration()) - continue; - - IRStructTypeLayout* structTypeLayout = nullptr; - if (varLayout) - structTypeLayout = as(varLayout->getTypeLayout()); - Index fieldIndex = 0; - List fieldParams; - for (auto field : structType->getFields()) - { - auto fieldParam = builder.emitParam(field->getFieldType()); - IRCloneEnv cloneEnv; - cloneInstDecorationsAndChildren( - &cloneEnv, - builder.getModule(), - field->getKey(), - fieldParam); - - IRVarLayout* fieldLayout = - structTypeLayout ? structTypeLayout->getFieldLayout(fieldIndex) : nullptr; - if (varLayout) - { - IRVarLayout::Builder varLayoutBuilder(&builder, fieldLayout->getTypeLayout()); - varLayoutBuilder.cloneEverythingButOffsetsFrom(fieldLayout); - for (auto offsetAttr : fieldLayout->getOffsetAttrs()) - { - auto parentOffsetAttr = - varLayout->findOffsetAttr(offsetAttr->getResourceKind()); - UInt parentOffset = parentOffsetAttr ? parentOffsetAttr->getOffset() : 0; - UInt parentSpace = parentOffsetAttr ? parentOffsetAttr->getSpace() : 0; - auto resInfo = - varLayoutBuilder.findOrAddResourceInfo(offsetAttr->getResourceKind()); - resInfo->offset = parentOffset + offsetAttr->getOffset(); - resInfo->space = parentSpace + offsetAttr->getSpace(); - } - builder.addLayoutDecoration(fieldParam, varLayoutBuilder.build()); - } - param->insertBefore(fieldParam); - fieldParams.add(fieldParam); - fieldIndex++; - } - builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); - auto reconstructedParam = - builder.emitMakeStruct(structType, fieldParams.getCount(), fieldParams.getBuffer()); - param->replaceUsesWith(reconstructedParam); - param->removeFromParent(); - } - fixUpFuncType(func); - } - - // Flattens all struct parameters of an entryPoint to ensure parameters are a flat struct - void flattenInputParameters(EntryPointInfo entryPoint) - { - // Goal is to ensure we have a flattened IRParam (0 nested IRStructType members). 
- /* - // Assume the following code - struct NestedFragment - { - float2 p3; - }; - struct Fragment - { - float4 p1; - float3 p2; - NestedFragment p3_nested; - }; - - // Fragment flattens into - struct Fragment - { - float4 p1; - float3 p2; - float2 p3; - }; - */ - - // This is important since Metal does not allow semantic's on a struct - /* - // Assume the following code - struct NestedFragment1 - { - float2 p3; - }; - struct Fragment1 - { - float4 p1 : SV_TARGET0; - float3 p2 : SV_TARGET1; - NestedFragment p3_nested : SV_TARGET2; // error, semantic on struct - }; - - */ - - // Metal does allow semantics on members of a nested struct but we are avoiding this - // approach since there are senarios where legalization (and verification) is - // hard/expensive without creating a flat struct: - // 1. Entry points may share structs, semantics may be inconsistent across entry points - // 2. Multiple of the same struct may be used in a param list - /* - // Assume the following code - struct NestedFragment - { - float2 p3; - }; - struct Fragment - { - float4 p1 : SV_TARGET0; - NestedFragment p2 : SV_TARGET1; - NestedFragment p3 : SV_TARGET2; - }; - - // Legalized without flattening -- abandoned - struct NestedFragment1 - { - float2 p3 : SV_TARGET1; - }; - struct NestedFragment2 - { - float2 p3 : SV_TARGET2; - }; - struct Fragment - { - float4 p1 : SV_TARGET0; - NestedFragment1 p2; - NestedFragment2 p3; - }; - - // Legalized with flattening -- current approach - struct Fragment - { - float4 p1 : SV_TARGET0; - float2 p2 : SV_TARGET1; - float2 p3 : SV_TARGET2; - }; - */ - - auto func = entryPoint.entryPointFunc; - bool modified = false; - for (auto param : func->getParams()) - { - auto layout = findVarLayout(param); - if (!layout) - continue; - if (!layout->findOffsetAttr(LayoutResourceKind::VaryingInput)) - continue; - if (param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration)) - continue; - // If we find a IRParam with a IRStructType member, we need to flatten the 
entire - // IRParam - if (auto structType = as(param->getDataType())) - { - IRBuilder builder(func); - MapStructToFlatStruct mapOldFieldToNewField; - - // Flatten struct if we have nested IRStructType - auto flattenedStruct = maybeFlattenNestedStructs( - builder, - structType, - mapOldFieldToNewField, - semanticInfoToRemove); - if (flattenedStruct != structType) - { - // Validate/rearange all semantics which overlap in our flat struct - fixFieldSemanticsOfFlatStruct(flattenedStruct); - - // Replace the 'old IRParam type' with a 'new IRParam type' - param->setFullType(flattenedStruct); - - // Emit a new variable at EntryPoint of 'old IRParam type' - builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); - auto dstVal = builder.emitVar(structType); - auto dstLoad = builder.emitLoad(dstVal); - param->replaceUsesWith(dstLoad); - builder.setInsertBefore(dstLoad); - // Copy the 'new IRParam type' to our 'old IRParam type' - mapOldFieldToNewField - .emitCopy<(int)MapStructToFlatStruct::CopyOptions::FlatStructIntoStruct>( - builder, - dstVal, - param); - - modified = true; - } - } - } - if (modified) - fixUpFuncType(func); - } - - void packStageInParameters(EntryPointInfo entryPoint) - { - // If the entry point has any parameters whose layout contains VaryingInput, - // we need to pack those parameters into a single `struct` type, and decorate - // the fields with the appropriate `[[attribute]]` decorations. - // For other parameters that are not `VaryingInput`, we need to leave them as is. - // - // For example, given this code after `hoistEntryPointParameterFromStruct`: - // ``` - // void main(float3 pos, float2 uv, int vertexId : SV_VertexID) { - // VertexInput vin = {pos,uv,vertexId}; - // ... 
- // } - // ``` - // We are going to transform it into: - // ``` - // struct VertexInput { - // float3 pos [[attribute(0)]]; - // float2 uv [[attribute(1)]]; - // }; - // void main(VertexInput vin, int vertexId : SV_VertexID) { - // let pos = vin.pos; - // let uv = vin.uv; - // ... - // } - - auto func = entryPoint.entryPointFunc; - - bool isGeometryStage = false; - switch (entryPoint.entryPointDecor->getProfile().getStage()) - { - case Stage::Vertex: - case Stage::Amplification: - case Stage::Mesh: - case Stage::Geometry: - case Stage::Domain: - case Stage::Hull: - isGeometryStage = true; - break; - } - - List paramsToPack; - for (auto param : func->getParams()) - { - auto layout = findVarLayout(param); - if (!layout) - continue; - if (!layout->findOffsetAttr(LayoutResourceKind::VaryingInput)) - continue; - if (param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration)) - continue; - paramsToPack.add(param); - } - - if (paramsToPack.getCount() == 0) - return; - - IRBuilder builder(func); - builder.setInsertBefore(func); - IRStructType* structType = builder.createStructType(); - auto stageText = getStageText(entryPoint.entryPointDecor->getProfile().getStage()); - builder.addNameHintDecoration( - structType, - (String(stageText) + toSlice("Input")).getUnownedSlice()); - List keys; - IRStructTypeLayout::Builder layoutBuilder(&builder); - for (auto param : paramsToPack) - { - auto paramVarLayout = findVarLayout(param); - auto key = builder.createStructKey(); - param->transferDecorationsTo(key); - builder.createStructField(structType, key, param->getDataType()); - if (auto varyingInOffsetAttr = - paramVarLayout->findOffsetAttr(LayoutResourceKind::VaryingInput)) - { - if (!key->findDecoration() && - !paramVarLayout->findAttr()) - { - // If the parameter doesn't have a semantic, we need to add one for semantic - // matching. 
- builder.addSemanticDecoration( - key, - toSlice("_slang_attr"), - (int)varyingInOffsetAttr->getOffset()); - } - } - if (isGeometryStage) - { - // For geometric stages, we need to translate VaryingInput offsets to MetalAttribute - // offsets. - IRVarLayout::Builder elementVarLayoutBuilder( - &builder, - paramVarLayout->getTypeLayout()); - elementVarLayoutBuilder.cloneEverythingButOffsetsFrom(paramVarLayout); - for (auto offsetAttr : paramVarLayout->getOffsetAttrs()) - { - auto resourceKind = offsetAttr->getResourceKind(); - if (resourceKind == LayoutResourceKind::VaryingInput) - { - resourceKind = LayoutResourceKind::MetalAttribute; - } - auto resInfo = elementVarLayoutBuilder.findOrAddResourceInfo(resourceKind); - resInfo->offset = offsetAttr->getOffset(); - resInfo->space = offsetAttr->getSpace(); - } - paramVarLayout = elementVarLayoutBuilder.build(); - } - layoutBuilder.addField(key, paramVarLayout); - builder.addLayoutDecoration(key, paramVarLayout); - keys.add(key); - } - builder.setInsertInto(func->getFirstBlock()); - auto packedParam = builder.emitParamAtHead(structType); - auto typeLayout = layoutBuilder.build(); - IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout); - - // Add a VaryingInput resource info to the packed parameter layout, so that we can emit - // the needed `[[stage_in]]` attribute in Metal emitter. 
- varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::VaryingInput); - auto paramVarLayout = varLayoutBuilder.build(); - builder.addLayoutDecoration(packedParam, paramVarLayout); - - // Replace the original parameters with the packed parameter - builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); - for (Index paramIndex = 0; paramIndex < paramsToPack.getCount(); paramIndex++) - { - auto param = paramsToPack[paramIndex]; - auto key = keys[paramIndex]; - auto paramField = builder.emitFieldExtract(param->getDataType(), packedParam, key); - param->replaceUsesWith(paramField); - param->removeFromParent(); - } - fixUpFuncType(func); - } - - struct MetalSystemValueInfo - { - String metalSystemValueName; - SystemValueSemanticName metalSystemValueNameEnum; - ShortList permittedTypes; - bool isUnsupported = false; - bool isSpecial = false; - MetalSystemValueInfo() - { - // most commonly need 2 - permittedTypes.reserveOverflowBuffer(2); - } - }; - - IRType* getGroupThreadIdType(IRBuilder& builder) - { - return builder.getVectorType( - builder.getBasicType(BaseType::UInt), - builder.getIntValue(builder.getIntType(), 3)); - } - - // Get all permitted types of "sv_target" for Metal - ShortList& getPermittedTypes_sv_target(IRBuilder& builder) - { - permittedTypes_sv_target.reserveOverflowBuffer(5 * 4); - if (permittedTypes_sv_target.getCount() == 0) - { - for (auto baseType : - {BaseType::Float, - BaseType::Half, - BaseType::Int, - BaseType::UInt, - BaseType::Int16, - BaseType::UInt16}) - { - for (IRIntegerValue i = 1; i <= 4; i++) - { - permittedTypes_sv_target.add( - builder.getVectorType(builder.getBasicType(baseType), i)); - } - } - } - return permittedTypes_sv_target; - } - - MetalSystemValueInfo getSystemValueInfo( - String inSemanticName, - String* optionalSemanticIndex, - IRInst* parentVar) - { - IRBuilder builder(m_module); - MetalSystemValueInfo result = {}; - UnownedStringSlice semanticName; - UnownedStringSlice semanticIndex; - - auto 
hasExplicitIndex = - splitNameAndIndex(inSemanticName.getUnownedSlice(), semanticName, semanticIndex); - if (!hasExplicitIndex && optionalSemanticIndex) - semanticIndex = optionalSemanticIndex->getUnownedSlice(); - - result.metalSystemValueNameEnum = convertSystemValueSemanticNameToEnum(semanticName); - - switch (result.metalSystemValueNameEnum) - { - case SystemValueSemanticName::Position: - { - result.metalSystemValueName = toSlice("position"); - result.permittedTypes.add(builder.getVectorType( - builder.getBasicType(BaseType::Float), - builder.getIntValue(builder.getIntType(), 4))); - break; - } - case SystemValueSemanticName::ClipDistance: - { - result.isSpecial = true; - break; - } - case SystemValueSemanticName::CullDistance: - { - result.isSpecial = true; - break; - } - case SystemValueSemanticName::Coverage: - { - result.metalSystemValueName = toSlice("sample_mask"); - result.permittedTypes.add(builder.getBasicType(BaseType::UInt)); - break; - } - case SystemValueSemanticName::InnerCoverage: - { - result.isSpecial = true; - break; - } - case SystemValueSemanticName::Depth: - { - result.metalSystemValueName = toSlice("depth(any)"); - result.permittedTypes.add(builder.getBasicType(BaseType::Float)); - break; - } - case SystemValueSemanticName::DepthGreaterEqual: - { - result.metalSystemValueName = toSlice("depth(greater)"); - result.permittedTypes.add(builder.getBasicType(BaseType::Float)); - break; - } - case SystemValueSemanticName::DepthLessEqual: - { - result.metalSystemValueName = toSlice("depth(less)"); - result.permittedTypes.add(builder.getBasicType(BaseType::Float)); - break; - } - case SystemValueSemanticName::DispatchThreadID: - { - result.metalSystemValueName = toSlice("thread_position_in_grid"); - result.permittedTypes.add(builder.getVectorType( - builder.getBasicType(BaseType::UInt), - builder.getIntValue(builder.getIntType(), 3))); - break; - } - case SystemValueSemanticName::DomainLocation: - { - result.metalSystemValueName = 
toSlice("position_in_patch"); - result.permittedTypes.add(builder.getVectorType( - builder.getBasicType(BaseType::Float), - builder.getIntValue(builder.getIntType(), 3))); - result.permittedTypes.add(builder.getVectorType( - builder.getBasicType(BaseType::Float), - builder.getIntValue(builder.getIntType(), 2))); - break; - } - case SystemValueSemanticName::GroupID: - { - result.metalSystemValueName = toSlice("threadgroup_position_in_grid"); - result.permittedTypes.add(builder.getVectorType( - builder.getBasicType(BaseType::UInt), - builder.getIntValue(builder.getIntType(), 3))); - break; - } - case SystemValueSemanticName::GroupIndex: - { - result.isSpecial = true; - break; - } - case SystemValueSemanticName::GroupThreadID: - { - result.metalSystemValueName = toSlice("thread_position_in_threadgroup"); - result.permittedTypes.add(getGroupThreadIdType(builder)); - break; - } - case SystemValueSemanticName::GSInstanceID: - { - result.isUnsupported = true; - break; - } - case SystemValueSemanticName::InstanceID: - { - result.metalSystemValueName = toSlice("instance_id"); - result.permittedTypes.add(builder.getBasicType(BaseType::UInt)); - break; - } - case SystemValueSemanticName::IsFrontFace: - { - result.metalSystemValueName = toSlice("front_facing"); - result.permittedTypes.add(builder.getBasicType(BaseType::Bool)); - break; - } - case SystemValueSemanticName::OutputControlPointID: - { - // In metal, a hull shader is just a compute shader. - // This needs to be handled separately, by lowering into an ordinary buffer. 
- break; - } - case SystemValueSemanticName::PointSize: - { - result.metalSystemValueName = toSlice("point_size"); - result.permittedTypes.add(builder.getBasicType(BaseType::Float)); - break; - } - case SystemValueSemanticName::PrimitiveID: - { - result.metalSystemValueName = toSlice("primitive_id"); - result.permittedTypes.add(builder.getBasicType(BaseType::UInt)); - result.permittedTypes.add(builder.getBasicType(BaseType::UInt16)); - break; - } - case SystemValueSemanticName::RenderTargetArrayIndex: - { - result.metalSystemValueName = toSlice("render_target_array_index"); - result.permittedTypes.add(builder.getBasicType(BaseType::UInt)); - result.permittedTypes.add(builder.getBasicType(BaseType::UInt16)); - break; - } - case SystemValueSemanticName::SampleIndex: - { - result.metalSystemValueName = toSlice("sample_id"); - result.permittedTypes.add(builder.getBasicType(BaseType::UInt)); - break; - } - case SystemValueSemanticName::StencilRef: - { - result.metalSystemValueName = toSlice("stencil"); - result.permittedTypes.add(builder.getBasicType(BaseType::UInt)); - break; - } - case SystemValueSemanticName::TessFactor: - { - // Tessellation factor outputs should be lowered into a write into a normal buffer. - break; - } - case SystemValueSemanticName::VertexID: - { - result.metalSystemValueName = toSlice("vertex_id"); - result.permittedTypes.add(builder.getBasicType(BaseType::UInt)); - break; - } - case SystemValueSemanticName::ViewID: - { - result.isUnsupported = true; - break; - } - case SystemValueSemanticName::ViewportArrayIndex: - { - result.metalSystemValueName = toSlice("viewport_array_index"); - result.permittedTypes.add(builder.getBasicType(BaseType::UInt)); - result.permittedTypes.add(builder.getBasicType(BaseType::UInt16)); - break; - } - case SystemValueSemanticName::Target: - { - result.metalSystemValueName = - (StringBuilder() - << "color(" << (semanticIndex.getLength() != 0 ? 
semanticIndex : toSlice("0")) - << ")") - .produceString(); - result.permittedTypes = getPermittedTypes_sv_target(builder); - - break; - } - case SystemValueSemanticName::StartVertexLocation: - { - result.metalSystemValueName = toSlice("base_vertex"); - result.permittedTypes.add(builder.getBasicType(BaseType::UInt)); - break; - } - case SystemValueSemanticName::StartInstanceLocation: - { - result.metalSystemValueName = toSlice("base_instance"); - result.permittedTypes.add(builder.getBasicType(BaseType::UInt)); - break; - } - default: - m_sink->diagnose( - parentVar, - Diagnostics::unimplementedSystemValueSemantic, - semanticName); - return result; - } - return result; - } - - void reportUnsupportedSystemAttribute(IRInst* param, String semanticName) - { - m_sink->diagnose( - param->sourceLoc, - Diagnostics::systemValueAttributeNotSupported, - semanticName); - } - - void ensureResultStructHasUserSemantic(IRStructType* structType, IRVarLayout* varLayout) - { - // Ensure each field in an output struct type has either a system semantic or a user - // semantic, so that signature matching can happen correctly. 
- auto typeLayout = as(varLayout->getTypeLayout()); - Index index = 0; - IRBuilder builder(structType); - for (auto field : structType->getFields()) - { - auto key = field->getKey(); - if (auto semanticDecor = key->findDecoration()) - { - if (semanticDecor->getSemanticName().startsWithCaseInsensitive(toSlice("sv_"))) - { - auto indexAsString = String(UInt(semanticDecor->getSemanticIndex())); - auto sysValInfo = - getSystemValueInfo(semanticDecor->getSemanticName(), &indexAsString, field); - if (sysValInfo.isUnsupported || sysValInfo.isSpecial) - { - reportUnsupportedSystemAttribute(field, semanticDecor->getSemanticName()); - } - else - { - builder.addTargetSystemValueDecoration( - key, - sysValInfo.metalSystemValueName.getUnownedSlice()); - semanticDecor->removeAndDeallocate(); - } - } - index++; - continue; - } - typeLayout->getFieldLayout(index); - auto fieldLayout = typeLayout->getFieldLayout(index); - if (auto offsetAttr = fieldLayout->findOffsetAttr(LayoutResourceKind::VaryingOutput)) - { - UInt varOffset = 0; - if (auto varOffsetAttr = - varLayout->findOffsetAttr(LayoutResourceKind::VaryingOutput)) - varOffset = varOffsetAttr->getOffset(); - varOffset += offsetAttr->getOffset(); - builder.addSemanticDecoration(key, toSlice("_slang_attr"), (int)varOffset); - } - index++; - } - } - - // Stores a hicharchy of members and children which map 'oldStruct->member' to - // 'flatStruct->member' Note: this map assumes we map to FlatStruct since it is easier/faster to - // process - struct MapStructToFlatStruct - { - /* - We need a hicharchy map to resolve dependencies for mapping - oldStruct to newStruct efficently. 
Example: - - MyStruct - | - / | \ - / | \ - / | \ - M0 M1 M2 - | | | - A_0 A_0 B_0 - - Without storing hicharchy information, there will be no way to tell apart - `myStruct.M0.A0` from `myStruct.M1.A0` since IRStructKey/IRStructField - only has 1 instance of `A::A0` - */ - - enum CopyOptions : int - { - // Copy a flattened-struct into a struct - FlatStructIntoStruct = 0, - - // Copy a struct into a flattened-struct - StructIntoFlatStruct = 1, - }; - - private: - // Children of member if applicable. - Dictionary members; - - // Field correlating to MapStructToFlatStruct Node. - IRInst* node; - IRStructKey* getKey() - { - SLANG_ASSERT(as(node)); - return as(node)->getKey(); - } - IRInst* getNode() { return node; } - IRType* getFieldType() - { - SLANG_ASSERT(as(node)); - return as(node)->getFieldType(); - } - - // Whom node maps to inside target flatStruct - IRStructField* targetMapping; - - auto begin() { return members.begin(); } - auto end() { return members.end(); } - - // Copies members of oldStruct to/from newFlatStruct. 
Assumes members of val1 maps to - // members in val2 using `MapStructToFlatStruct` - template - static void _emitCopy( - IRBuilder& builder, - IRInst* val1, - IRStructType* type1, - IRInst* val2, - IRStructType* type2, - MapStructToFlatStruct& node) - { - for (auto& field1Pair : node) - { - auto& field1 = field1Pair.second; - - // Get member of val1 - IRInst* fieldAddr1 = nullptr; - if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct) - { - fieldAddr1 = builder.emitFieldAddress(type1, val1, field1.getKey()); - } - else - { - if (as(val1)) - val1 = builder.emitLoad(val1); - fieldAddr1 = builder.emitFieldExtract(type1, val1, field1.getKey()); - } - - // If val1 is a struct, recurse - if (auto fieldAsStruct1 = as(field1.getFieldType())) - { - _emitCopy( - builder, - fieldAddr1, - fieldAsStruct1, - val2, - type2, - field1); - continue; - } - - // Get member of val2 which maps to val1.member - auto field2 = field1.getMapping(); - SLANG_ASSERT(field2); - IRInst* fieldAddr2 = nullptr; - if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct) - { - if (as(val2)) - val2 = builder.emitLoad(val1); - fieldAddr2 = builder.emitFieldExtract(type2, val2, field2->getKey()); - } - else - { - fieldAddr2 = builder.emitFieldAddress(type2, val2, field2->getKey()); - } - - // Copy val2/val1 member into val1/val2 member - if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct) - { - builder.emitStore(fieldAddr1, fieldAddr2); - } - else - { - builder.emitStore(fieldAddr2, fieldAddr1); - } - } - } - - public: - void setNode(IRInst* newNode) { node = newNode; } - // Get 'MapStructToFlatStruct' that is a child of 'parent'. - // Make 'MapStructToFlatStruct' if no 'member' is currently mapped to 'parent'. 
- MapStructToFlatStruct& getMember(IRStructField* member) { return members[member]; } - MapStructToFlatStruct& operator[](IRStructField* member) { return getMember(member); } - - void setMapping(IRStructField* newTargetMapping) { targetMapping = newTargetMapping; } - // Get 'MapStructToFlatStruct' that is a child of 'parent'. - // Return nullptr if no member is mapped to 'parent' - IRStructField* getMapping() { return targetMapping; } - - // Copies srcVal into dstVal using hicharchy map. - template - void emitCopy(IRBuilder& builder, IRInst* dstVal, IRInst* srcVal) - { - auto dstType = dstVal->getDataType(); - if (auto dstPtrType = as(dstType)) - dstType = dstPtrType->getValueType(); - auto dstStructType = as(dstType); - SLANG_ASSERT(dstStructType); - - auto srcType = srcVal->getDataType(); - if (auto srcPtrType = as(srcType)) - srcType = srcPtrType->getValueType(); - auto srcStructType = as(srcType); - SLANG_ASSERT(srcStructType); - - if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct) - { - // CopyOptions::FlatStructIntoStruct copy a flattened-struct (mapped member) into a - // struct - SLANG_ASSERT(node == dstStructType); - _emitCopy( - builder, - dstVal, - dstStructType, - srcVal, - srcStructType, - *this); - } - else - { - // CopyOptions::StructIntoFlatStruct copy a struct into a flattened-struct - SLANG_ASSERT(node == srcStructType); - _emitCopy( - builder, - srcVal, - srcStructType, - dstVal, - dstStructType, - *this); - } - } - }; - - IRStructType* _flattenNestedStructs( - IRBuilder& builder, - IRStructType* dst, - IRStructType* src, - IRSemanticDecoration* parentSemanticDecoration, - IRLayoutDecoration* parentLayout, - MapStructToFlatStruct& mapFieldToField, - HashSet& varsWithSemanticInfo) - { - // For all fields ('oldField') of a struct do the following: - // 1. Check for 'decorations which carry semantic info' (IRSemanticDecoration, - // IRLayoutDecoration), store these if found. 
- // * Do not propagate semantic info if the current node has *any* form of semantic - // information. - // Update varsWithSemanticInfo. - // 2. If IRStructType: - // 2a. Recurse this function with 'decorations that carry semantic info' from parent. - // 3. If not IRStructType: - // 3a. Emit 'newField' equal to 'oldField', add 'decorations which carry semantic info'. - // 3b. Store a mapping from 'oldField' to 'newField' in 'mapFieldToField'. This info is - // needed to copy between types. - for (auto oldField : src->getFields()) - { - auto& fieldMappingNode = mapFieldToField[oldField]; - fieldMappingNode.setNode(oldField); - - // step 1 - bool foundSemanticDecor = false; - auto oldKey = oldField->getKey(); - IRSemanticDecoration* fieldSemanticDecoration = parentSemanticDecoration; - if (auto oldSemanticDecoration = oldKey->findDecoration()) - { - foundSemanticDecor = true; - fieldSemanticDecoration = oldSemanticDecoration; - parentLayout = nullptr; - } - - IRLayoutDecoration* fieldLayout = parentLayout; - if (auto oldLayout = oldKey->findDecoration()) - { - fieldLayout = oldLayout; - if (!foundSemanticDecor) - fieldSemanticDecoration = nullptr; - } - if (fieldSemanticDecoration != parentSemanticDecoration || parentLayout != fieldLayout) - varsWithSemanticInfo.add(oldField); - - // step 2a - if (auto structFieldType = as(oldField->getFieldType())) - { - _flattenNestedStructs( - builder, - dst, - structFieldType, - fieldSemanticDecoration, - fieldLayout, - fieldMappingNode, - varsWithSemanticInfo); - continue; - } - - // step 3a - auto newKey = builder.createStructKey(); - copyNameHintAndDebugDecorations(newKey, oldKey); - - auto newField = builder.createStructField(dst, newKey, oldField->getFieldType()); - copyNameHintAndDebugDecorations(newField, oldField); - - if (fieldSemanticDecoration) - builder.addSemanticDecoration( - newKey, - fieldSemanticDecoration->getSemanticName(), - fieldSemanticDecoration->getSemanticIndex()); - - if (fieldLayout) - { - IRLayout* 
oldLayout = fieldLayout->getLayout(); - List instToCopy; - // Only copy certain decorations needed for resolving system semantics - for (UInt i = 0; i < oldLayout->getOperandCount(); i++) - { - auto operand = oldLayout->getOperand(i); - if (as(operand) || as(operand) || - as(operand) || as(operand)) - instToCopy.add(operand); - } - IRVarLayout* newLayout = builder.getVarLayout(instToCopy); - builder.addLayoutDecoration(newKey, newLayout); - } - // step 3b - fieldMappingNode.setMapping(newField); - } - - return dst; - } - - // Returns a `IRStructType*` without any `IRStructType*` members. `src` may be returned if there - // was no struct flattening. - // @param mapFieldToField Behavior maps all `IRStructField` of `src` to the new struct - // `IRStructFields`s - IRStructType* maybeFlattenNestedStructs( - IRBuilder& builder, - IRStructType* src, - MapStructToFlatStruct& mapFieldToField, - HashSet& varsWithSemanticInfo) - { - // Find all values inside struct that need flattening and legalization. - bool hasStructTypeMembers = false; - for (auto field : src->getFields()) - { - if (as(field->getFieldType())) - { - hasStructTypeMembers = true; - break; - } - } - if (!hasStructTypeMembers) - return src; - - // We need to: - // 1. Make new struct 1:1 with old struct but without nestested structs (flatten) - // 2. Ensure semantic attributes propegate. This will create overlapping semantics (can be - // handled later). - // 3. Store the mapping from old to new struct fields to allow copying a old-struct to - // new-struct. - builder.setInsertAfter(src); - auto newStruct = builder.createStructType(); - copyNameHintAndDebugDecorations(newStruct, src); - mapFieldToField.setNode(src); - return _flattenNestedStructs( - builder, - newStruct, - src, - nullptr, - nullptr, - mapFieldToField, - varsWithSemanticInfo); - } - - // Replaces all 'IRReturn' by copying the current 'IRReturn' to a new var of type 'newType'. 
- // Copying logic from 'IRReturn' to 'newType' is controlled by 'copyLogicFunc' function. - template - void _replaceAllReturnInst( - IRBuilder& builder, - IRFunc* targetFunc, - IRStructType* newType, - CopyLogicFunc copyLogicFunc) - { - for (auto block : targetFunc->getBlocks()) - { - if (auto returnInst = as(block->getTerminator())) - { - builder.setInsertBefore(returnInst); - auto returnVal = returnInst->getVal(); - returnInst->setOperand(0, copyLogicFunc(builder, newType, returnVal)); - } - } - } - - UInt _returnNonOverlappingAttributeIndex(std::set& usedSemanticIndex) - { - // Find first unused semantic index of equal semantic type - // to fill any gaps in user set semantic bindings - UInt prev = 0; - for (auto i : usedSemanticIndex) - { - if (i > prev + 1) - { - break; - } - prev = i; - } - usedSemanticIndex.insert(prev + 1); - return prev + 1; - } - - template - struct AttributeParentPair - { - IRLayoutDecoration* layoutDecor; - T* attr; - }; - - IRLayoutDecoration* _replaceAttributeOfLayout( - IRBuilder& builder, - IRLayoutDecoration* parentLayoutDecor, - IRInst* instToReplace, - IRInst* instToReplaceWith) - { - // Replace `instToReplace` with a `instToReplaceWith` - - auto layout = parentLayoutDecor->getLayout(); - // Find the exact same decoration `instToReplace` in-case multiple of the same type exist - List opList; - opList.add(instToReplaceWith); - for (UInt i = 0; i < layout->getOperandCount(); i++) - { - if (layout->getOperand(i) != instToReplace) - opList.add(layout->getOperand(i)); - } - auto newLayoutDecor = builder.addLayoutDecoration( - parentLayoutDecor->getParent(), - builder.getVarLayout(opList)); - parentLayoutDecor->removeAndDeallocate(); - return newLayoutDecor; - } - - IRLayoutDecoration* _simplifyUserSemanticNames( - IRBuilder& builder, - IRLayoutDecoration* layoutDecor) - { - // Ensure all 'ExplicitIndex' semantics such as "SV_TARGET0" are simplified into - // ("SV_TARGET", 0) using 'IRUserSemanticAttr' This is done to ensure we can 
check semantic - // groups using 'IRUserSemanticAttr1->getName() == IRUserSemanticAttr2->getName()' - SLANG_ASSERT(layoutDecor); - auto layout = layoutDecor->getLayout(); - List layoutOps; - layoutOps.reserve(3); - bool changed = false; - for (auto attr : layout->getAllAttrs()) - { - if (auto userSemantic = as(attr)) - { - UnownedStringSlice outName; - UnownedStringSlice outIndex; - bool hasStringIndex = splitNameAndIndex(userSemantic->getName(), outName, outIndex); - if (hasStringIndex) - { - changed = true; - auto loweredName = String(outName).toLower(); - auto loweredNameSlice = loweredName.getUnownedSlice(); - auto newDecoration = - builder.getUserSemanticAttr(loweredNameSlice, stringToInt(outIndex)); - userSemantic->replaceUsesWith(newDecoration); - userSemantic->removeAndDeallocate(); - userSemantic = newDecoration; - } - layoutOps.add(userSemantic); - continue; - } - layoutOps.add(attr); - } - if (changed) - { - auto parent = layoutDecor->parent; - layoutDecor->removeAndDeallocate(); - builder.addLayoutDecoration(parent, builder.getVarLayout(layoutOps)); - } - return layoutDecor; - } - // Find overlapping field semantics and legalize them - void fixFieldSemanticsOfFlatStruct(IRStructType* structType) - { - // Goal is to ensure we do not have overlapping semantics: - /* - // Assume the following code - struct Fragment - { - float4 p1 : SV_TARGET; - float3 p2 : SV_TARGET; - float2 p3 : SV_TARGET; - float2 p4 : SV_TARGET; - }; - - // Translates into - struct Fragment - { - float4 p1 : SV_TARGET0; - float3 p2 : SV_TARGET1; - float2 p3 : SV_TARGET2; - float2 p4 : SV_TARGET3; - }; - */ - - IRBuilder builder(this->m_module); - - List overlappingSemanticsDecor; - Dictionary>> - usedSemanticIndexSemanticDecor; - - List> overlappingVarOffset; - Dictionary>> usedSemanticIndexVarOffset; - - List> overlappingUserSemantic; - Dictionary>> - usedSemanticIndexUserSemantic; - - // We store a map from old `IRLayoutDecoration*` to new `IRLayoutDecoration*` since when - // 
legalizing we may destroy and remake a `IRLayoutDecoration*` - Dictionary oldLayoutDecorToNew; - - // Collect all "semantic info carrying decorations". Any collected decoration will - // fill up their respective 'Dictionary>' - // to keep track of in-use offsets for a semantic type. - // Example: IRSemanticDecoration with name of "SV_TARGET1". - // * This will have SEMANTIC_TYPE of "sv_target". - // * This will use up index '1' - // - // Now if a second equal semantic "SV_TARGET1" is found, we add this decoration to - // a list of 'overlapping semantic info decorations' so we can legalize this - // 'semantic info decoration' later. - // - // NOTE: this is a flat struct, all members are children of the initial - // IRStructType. - for (auto field : structType->getFields()) - { - auto key = field->getKey(); - if (auto semanticDecoration = key->findDecoration()) - { - // Ensure names are in a uniform lowercase format so we can bunch together simmilar - // semantics - UnownedStringSlice outName; - UnownedStringSlice outIndex; - bool hasStringIndex = - splitNameAndIndex(semanticDecoration->getSemanticName(), outName, outIndex); - if (hasStringIndex) - { - auto loweredName = String(outName).toLower(); - auto loweredNameSlice = loweredName.getUnownedSlice(); - auto newDecoration = - builder.addSemanticDecoration(key, loweredNameSlice, stringToInt(outIndex)); - semanticDecoration->replaceUsesWith(newDecoration); - semanticDecoration->removeAndDeallocate(); - semanticDecoration = newDecoration; - } - auto& semanticUse = - usedSemanticIndexSemanticDecor[semanticDecoration->getSemanticName()]; - if (semanticUse.find(semanticDecoration->getSemanticIndex()) != semanticUse.end()) - overlappingSemanticsDecor.add(semanticDecoration); - else - semanticUse.insert(semanticDecoration->getSemanticIndex()); - } - if (auto layoutDecor = key->findDecoration()) - { - // Ensure names are in a uniform lowercase format so we can bunch together simmilar - // semantics - layoutDecor = 
_simplifyUserSemanticNames(builder, layoutDecor); - oldLayoutDecorToNew[layoutDecor] = layoutDecor; - auto layout = layoutDecor->getLayout(); - for (auto attr : layout->getAllAttrs()) - { - if (auto offset = as(attr)) - { - auto& semanticUse = usedSemanticIndexVarOffset[offset->getResourceKind()]; - if (semanticUse.find(offset->getOffset()) != semanticUse.end()) - overlappingVarOffset.add({layoutDecor, offset}); - else - semanticUse.insert(offset->getOffset()); - } - else if (auto userSemantic = as(attr)) - { - auto& semanticUse = usedSemanticIndexUserSemantic[userSemantic->getName()]; - if (semanticUse.find(userSemantic->getIndex()) != semanticUse.end()) - overlappingUserSemantic.add({layoutDecor, userSemantic}); - else - semanticUse.insert(userSemantic->getIndex()); - } - } - } - } - - // Legalize all overlapping 'semantic info decorations' - for (auto decor : overlappingSemanticsDecor) - { - auto newOffset = _returnNonOverlappingAttributeIndex( - usedSemanticIndexSemanticDecor[decor->getSemanticName()]); - builder.addSemanticDecoration( - decor->getParent(), - decor->getSemanticName(), - (int)newOffset); - decor->removeAndDeallocate(); - } - for (auto& varOffset : overlappingVarOffset) - { - auto newOffset = _returnNonOverlappingAttributeIndex( - usedSemanticIndexVarOffset[varOffset.attr->getResourceKind()]); - auto newVarOffset = builder.getVarOffsetAttr( - varOffset.attr->getResourceKind(), - newOffset, - varOffset.attr->getSpace()); - oldLayoutDecorToNew[varOffset.layoutDecor] = _replaceAttributeOfLayout( - builder, - oldLayoutDecorToNew[varOffset.layoutDecor], - varOffset.attr, - newVarOffset); - } - for (auto& userSemantic : overlappingUserSemantic) - { - auto newOffset = _returnNonOverlappingAttributeIndex( - usedSemanticIndexUserSemantic[userSemantic.attr->getName()]); - auto newUserSemantic = - builder.getUserSemanticAttr(userSemantic.attr->getName(), newOffset); - oldLayoutDecorToNew[userSemantic.layoutDecor] = _replaceAttributeOfLayout( - builder, - 
oldLayoutDecorToNew[userSemantic.layoutDecor], - userSemantic.attr, - newUserSemantic); - } - } - - void wrapReturnValueInStruct(EntryPointInfo entryPoint) - { - // Wrap return value into a struct if it is not already a struct. - // For example, given this entry point: - // ``` - // float4 main() : SV_Target { return float3(1,2,3); } - // ``` - // We are going to transform it into: - // ``` - // struct Output { - // float4 value : SV_Target; - // }; - // Output main() { return {float3(1,2,3)}; } - - auto func = entryPoint.entryPointFunc; - - auto returnType = func->getResultType(); - if (as(returnType)) - return; - auto entryPointLayoutDecor = func->findDecoration(); - if (!entryPointLayoutDecor) - return; - auto entryPointLayout = as(entryPointLayoutDecor->getLayout()); - if (!entryPointLayout) - return; - auto resultLayout = entryPointLayout->getResultLayout(); - - // If return type is already a struct, just make sure every field has a semantic. - if (auto returnStructType = as(returnType)) - { - IRBuilder builder(func); - MapStructToFlatStruct mapOldFieldToNewField; - // Flatten result struct type to ensure we do not have nested semantics - auto flattenedStruct = maybeFlattenNestedStructs( - builder, - returnStructType, - mapOldFieldToNewField, - semanticInfoToRemove); - if (returnStructType != flattenedStruct) - { - // Replace all return-values with the flattenedStruct we made. 
- _replaceAllReturnInst( - builder, - func, - flattenedStruct, - [&](IRBuilder& copyBuilder, IRStructType* dstType, IRInst* srcVal) -> IRInst* - { - auto srcStructType = as(srcVal->getDataType()); - SLANG_ASSERT(srcStructType); - auto dstVal = copyBuilder.emitVar(dstType); - mapOldFieldToNewField.emitCopy<( - int)MapStructToFlatStruct::CopyOptions::StructIntoFlatStruct>( - copyBuilder, - dstVal, - srcVal); - return builder.emitLoad(dstVal); - }); - fixUpFuncType(func, flattenedStruct); - } - // Ensure non-overlapping semantics - fixFieldSemanticsOfFlatStruct(flattenedStruct); - ensureResultStructHasUserSemantic(flattenedStruct, resultLayout); - return; - } - - IRBuilder builder(func); - builder.setInsertBefore(func); - IRStructType* structType = builder.createStructType(); - auto stageText = getStageText(entryPoint.entryPointDecor->getProfile().getStage()); - builder.addNameHintDecoration( - structType, - (String(stageText) + toSlice("Output")).getUnownedSlice()); - auto key = builder.createStructKey(); - builder.addNameHintDecoration(key, toSlice("output")); - builder.addLayoutDecoration(key, resultLayout); - builder.createStructField(structType, key, returnType); - IRStructTypeLayout::Builder structTypeLayoutBuilder(&builder); - structTypeLayoutBuilder.addField(key, resultLayout); - auto typeLayout = structTypeLayoutBuilder.build(); - IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout); - auto varLayout = varLayoutBuilder.build(); - ensureResultStructHasUserSemantic(structType, varLayout); - - _replaceAllReturnInst( - builder, - func, - structType, - [](IRBuilder& copyBuilder, IRStructType* dstType, IRInst* srcVal) -> IRInst* - { return copyBuilder.emitMakeStruct(dstType, 1, &srcVal); }); - - // Assign an appropriate system value semantic for stage output - auto stage = entryPoint.entryPointDecor->getProfile().getStage(); - switch (stage) - { - case Stage::Compute: - case Stage::Fragment: - { - builder.addTargetSystemValueDecoration(key, 
toSlice("color(0)")); - break; - } - case Stage::Vertex: - { - builder.addTargetSystemValueDecoration(key, toSlice("position")); - break; - } - default: - SLANG_ASSERT(false); - return; - } - - fixUpFuncType(func, structType); - } - - void legalizeMeshEntryPoint(EntryPointInfo entryPoint) - { - auto func = entryPoint.entryPointFunc; - - IRBuilder builder{func->getModule()}; - for (auto param : func->getParams()) - { - if (param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration)) - { - IRVarLayout::Builder varLayoutBuilder( - &builder, - IRTypeLayout::Builder{&builder}.build()); - - varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload); - auto paramVarLayout = varLayoutBuilder.build(); - builder.addLayoutDecoration(param, paramVarLayout); - - IRPtrTypeBase* type = as(param->getDataType()); - - const auto annotatedPayloadType = builder.getPtrType( - kIROp_ConstRefType, - type->getValueType(), - AddressSpace::MetalObjectData); - - param->setFullType(annotatedPayloadType); - } - } - IROutputTopologyDecoration* outputDeco = - entryPoint.entryPointFunc->findDecoration(); - if (outputDeco == nullptr) - { - SLANG_UNEXPECTED("Mesh shader output decoration missing"); - return; - } - const auto topology = outputDeco->getTopology(); - const auto topStr = topology->getStringSlice(); - UInt topologyEnum = 0; - if (topStr.caseInsensitiveEquals(toSlice("point"))) - { - topologyEnum = 1; - } - else if (topStr.caseInsensitiveEquals(toSlice("line"))) - { - topologyEnum = 2; - } - else if (topStr.caseInsensitiveEquals(toSlice("triangle"))) - { - topologyEnum = 3; - } - else - { - SLANG_UNEXPECTED("unknown topology"); - return; - } - - IRInst* topologyConst = builder.getIntValue(builder.getIntType(), topologyEnum); - - IRType* vertexType = nullptr; - IRType* indicesType = nullptr; - IRType* primitiveType = nullptr; - - IRInst* maxVertices = nullptr; - IRInst* maxPrimitives = nullptr; - - IRInst* verticesParam = nullptr; - IRInst* indicesParam = nullptr; - IRInst* 
primitivesParam = nullptr; - for (auto param : func->getParams()) - { - if (param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration)) - { - IRVarLayout::Builder varLayoutBuilder( - &builder, - IRTypeLayout::Builder{&builder}.build()); - - varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload); - auto paramVarLayout = varLayoutBuilder.build(); - builder.addLayoutDecoration(param, paramVarLayout); - } - if (param->findDecorationImpl(kIROp_VerticesDecoration)) - { - auto vertexRefType = as(param->getDataType()); - auto vertexOutputType = as(vertexRefType->getValueType()); - vertexType = vertexOutputType->getElementType(); - maxVertices = vertexOutputType->getMaxElementCount(); - SLANG_ASSERT(vertexType); - - verticesParam = param; - auto vertStruct = as(vertexType); - for (auto field : vertStruct->getFields()) - { - auto key = field->getKey(); - if (auto deco = key->findDecoration()) - { - if (deco->getSemanticName().caseInsensitiveEquals(toSlice("sv_position"))) - { - builder.addTargetSystemValueDecoration(key, toSlice("position")); - } - } - } - } - if (param->findDecorationImpl(kIROp_IndicesDecoration)) - { - auto indicesRefType = (IRConstRefType*)param->getDataType(); - auto indicesOutputType = (IRIndicesType*)indicesRefType->getValueType(); - indicesType = indicesOutputType->getElementType(); - maxPrimitives = indicesOutputType->getMaxElementCount(); - SLANG_ASSERT(indicesType); - - indicesParam = param; - } - if (param->findDecorationImpl(kIROp_PrimitivesDecoration)) - { - auto primitivesRefType = (IRConstRefType*)param->getDataType(); - auto primitivesOutputType = (IRPrimitivesType*)primitivesRefType->getValueType(); - primitiveType = primitivesOutputType->getElementType(); - SLANG_ASSERT(primitiveType); - - primitivesParam = param; - auto primStruct = as(primitiveType); - for (auto field : primStruct->getFields()) - { - auto key = field->getKey(); - if (auto deco = key->findDecoration()) - { - if 
(deco->getSemanticName().caseInsensitiveEquals( - toSlice("sv_primitiveid"))) - { - builder.addTargetSystemValueDecoration(key, toSlice("primitive_id")); - } - } - } - } - } - if (primitiveType == nullptr) - { - primitiveType = builder.getVoidType(); - } - builder.setInsertBefore(entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst()); - - auto meshParam = builder.emitParam(builder.getMetalMeshType( - vertexType, - primitiveType, - maxVertices, - maxPrimitives, - topologyConst)); - builder.addExternCppDecoration(meshParam, toSlice("_slang_mesh")); - - - verticesParam->removeFromParent(); - verticesParam->removeAndDeallocate(); - - indicesParam->removeFromParent(); - indicesParam->removeAndDeallocate(); - - if (primitivesParam != nullptr) - { - primitivesParam->removeFromParent(); - primitivesParam->removeAndDeallocate(); - } - } - - void legalizeDispatchMeshPayloadForMetal(EntryPointInfo entryPoint) - { - // Find out DispatchMesh function - IRGlobalValueWithCode* dispatchMeshFunc = nullptr; - for (const auto globalInst : entryPoint.entryPointFunc->getModule()->getGlobalInsts()) - { - if (const auto func = as(globalInst)) - { - if (const auto dec = func->findDecoration()) - { - if (dec->getName() == "DispatchMesh") - { - SLANG_ASSERT(!dispatchMeshFunc && "Multiple DispatchMesh functions found"); - dispatchMeshFunc = func; - } - } - } - } - - if (!dispatchMeshFunc) - return; - - IRBuilder builder{entryPoint.entryPointFunc->getModule()}; - - // We'll rewrite the call to use mesh_grid_properties.set_threadgroups_per_grid - traverseUses( - dispatchMeshFunc, - [&](const IRUse* use) - { - if (const auto call = as(use->getUser())) - { - SLANG_ASSERT(call->getArgCount() == 4); - const auto payload = call->getArg(3); - - const auto payloadPtrType = - composeGetters(payload, &IRInst::getDataType); - SLANG_ASSERT(payloadPtrType); - const auto payloadType = payloadPtrType->getValueType(); - SLANG_ASSERT(payloadType); - - builder.setInsertBefore( - 
entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst()); - const auto annotatedPayloadType = builder.getPtrType( - kIROp_RefType, - payloadPtrType->getValueType(), - AddressSpace::MetalObjectData); - auto packedParam = builder.emitParam(annotatedPayloadType); - builder.addExternCppDecoration(packedParam, toSlice("_slang_mesh_payload")); - IRVarLayout::Builder varLayoutBuilder( - &builder, - IRTypeLayout::Builder{&builder}.build()); - - // Add the MetalPayload resource info, so we can emit [[payload]] - varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload); - auto paramVarLayout = varLayoutBuilder.build(); - builder.addLayoutDecoration(packedParam, paramVarLayout); - - // Now we replace the call to DispatchMesh with a call to the mesh grid - // properties But first we need to create the parameter - const auto meshGridPropertiesType = builder.getMetalMeshGridPropertiesType(); - auto mgp = builder.emitParam(meshGridPropertiesType); - builder.addExternCppDecoration(mgp, toSlice("_slang_mgp")); - } - }); - } - - IRInst* tryConvertValue(IRBuilder& builder, IRInst* val, IRType* toType) - { - auto fromType = val->getFullType(); - if (auto fromVector = as(fromType)) - { - if (auto toVector = as(toType)) - { - if (fromVector->getElementCount() != toVector->getElementCount()) - { - fromType = builder.getVectorType( - fromVector->getElementType(), - toVector->getElementCount()); - val = builder.emitVectorReshape(fromType, val); - } - } - else if (as(toType)) - { - UInt index = 0; - val = builder.emitSwizzle(fromVector->getElementType(), val, 1, &index); - if (toType->getOp() == kIROp_VoidType) - return nullptr; - } - } - else if (auto fromBasicType = as(fromType)) - { - if (fromBasicType->getOp() == kIROp_VoidType) - return nullptr; - if (!as(toType)) - return nullptr; - if (toType->getOp() == kIROp_VoidType) - return nullptr; - } - else - { - return nullptr; - } - return builder.emitCast(toType, val); - } - - struct SystemValLegalizationWorkItem 
- { - IRInst* var; - String attrName; - UInt attrIndex; - }; - - std::optional tryToMakeSystemValWorkItem(IRInst* var) - { - if (auto semanticDecoration = var->findDecoration()) - { - if (semanticDecoration->getSemanticName().startsWithCaseInsensitive(toSlice("sv_"))) - { - return { - {var, - String(semanticDecoration->getSemanticName()).toLower(), - (UInt)semanticDecoration->getSemanticIndex()}}; - } - } - - auto layoutDecor = var->findDecoration(); - if (!layoutDecor) - return {}; - auto sysValAttr = layoutDecor->findAttr(); - if (!sysValAttr) - return {}; - auto semanticName = String(sysValAttr->getName()); - auto sysAttrIndex = sysValAttr->getIndex(); - - return {{var, semanticName, sysAttrIndex}}; - } - - - List collectSystemValFromEntryPoint(EntryPointInfo entryPoint) - { - List systemValWorkItems; - for (auto param : entryPoint.entryPointFunc->getParams()) - { - auto maybeWorkItem = tryToMakeSystemValWorkItem(param); - if (maybeWorkItem.has_value()) - systemValWorkItems.add(std::move(maybeWorkItem.value())); - } - return systemValWorkItems; - } - - void legalizeSystemValue(EntryPointInfo entryPoint, SystemValLegalizationWorkItem& workItem) - { - IRBuilder builder(entryPoint.entryPointFunc); - - auto var = workItem.var; - auto semanticName = workItem.attrName; - - auto indexAsString = String(workItem.attrIndex); - auto info = getSystemValueInfo(semanticName, &indexAsString, var); - - if (info.isSpecial) - { - if (info.metalSystemValueNameEnum == SystemValueSemanticName::InnerCoverage) - { - // Metal does not support conservative rasterization, so this is always false. 
- auto val = builder.getBoolValue(false); - var->replaceUsesWith(val); - var->removeAndDeallocate(); - } - else if (info.metalSystemValueNameEnum == SystemValueSemanticName::GroupIndex) - { - // Ensure we have a cached "sv_groupthreadid" in our entry point - if (!entryPointToGroupThreadId.containsKey(entryPoint.entryPointFunc)) - { - auto systemValWorkItems = collectSystemValFromEntryPoint(entryPoint); - for (auto i : systemValWorkItems) - { - auto indexAsStringGroupThreadId = String(i.attrIndex); - if (getSystemValueInfo(i.attrName, &indexAsStringGroupThreadId, i.var) - .metalSystemValueNameEnum == SystemValueSemanticName::GroupThreadID) - { - entryPointToGroupThreadId[entryPoint.entryPointFunc] = i.var; - } - } - if (!entryPointToGroupThreadId.containsKey(entryPoint.entryPointFunc)) - { - // Add the missing groupthreadid needed to compute sv_groupindex - IRBuilder groupThreadIdBuilder(builder); - groupThreadIdBuilder.setInsertInto( - entryPoint.entryPointFunc->getFirstBlock()); - auto groupThreadId = groupThreadIdBuilder.emitParamAtHead( - getGroupThreadIdType(groupThreadIdBuilder)); - entryPointToGroupThreadId[entryPoint.entryPointFunc] = groupThreadId; - groupThreadIdBuilder.addNameHintDecoration( - groupThreadId, - groupThreadIDString); - - // Since "sv_groupindex" will be translated out to a global var and no - // longer be considered a system value we can reuse its layout and semantic - // info - Index foundRequiredDecorations = 0; - IRLayoutDecoration* layoutDecoration = nullptr; - UInt semanticIndex = 0; - for (auto decoration : var->getDecorations()) - { - if (auto layoutDecorationTmp = as(decoration)) - { - layoutDecoration = layoutDecorationTmp; - foundRequiredDecorations++; - } - else if (auto semanticDecoration = as(decoration)) - { - semanticIndex = semanticDecoration->getSemanticIndex(); - groupThreadIdBuilder.addSemanticDecoration( - groupThreadId, - groupThreadIDString, - (int)semanticIndex); - foundRequiredDecorations++; - } - if 
(foundRequiredDecorations >= 2) - break; - } - SLANG_ASSERT(layoutDecoration); - layoutDecoration->removeFromParent(); - layoutDecoration->insertAtStart(groupThreadId); - SystemValLegalizationWorkItem newWorkItem = { - groupThreadId, - groupThreadIDString, - semanticIndex}; - legalizeSystemValue(entryPoint, newWorkItem); - } - } - - IRBuilder svBuilder(builder.getModule()); - svBuilder.setInsertBefore(entryPoint.entryPointFunc->getFirstOrdinaryInst()); - auto computeExtent = emitCalcGroupExtents( - svBuilder, - entryPoint.entryPointFunc, - builder.getVectorType( - builder.getUIntType(), - builder.getIntValue(builder.getIntType(), 3))); - auto groupIndexCalc = emitCalcGroupIndex( - svBuilder, - entryPointToGroupThreadId[entryPoint.entryPointFunc], - computeExtent); - svBuilder.addNameHintDecoration( - groupIndexCalc, - UnownedStringSlice("sv_groupindex")); - - var->replaceUsesWith(groupIndexCalc); - var->removeAndDeallocate(); - } - } - if (info.isUnsupported) - { - reportUnsupportedSystemAttribute(var, semanticName); - return; - } - if (!info.permittedTypes.getCount()) - return; - - builder.addTargetSystemValueDecoration(var, info.metalSystemValueName.getUnownedSlice()); - - bool varTypeIsPermitted = false; - auto varType = var->getFullType(); - for (auto& permittedType : info.permittedTypes) - { - varTypeIsPermitted = varTypeIsPermitted || permittedType == varType; - } - - if (!varTypeIsPermitted) - { - // Note: we do not currently prefer any conversion - // example: - // * allowed types for semantic: `float4`, `uint4`, `int4` - // * user used, `float2` - // * Slang will equally prefer `float4` to `uint4` to `int4`. - // This means the type may lose data if slang selects `uint4` or `int4`. 
- bool foundAConversion = false; - for (auto permittedType : info.permittedTypes) - { - var->setFullType(permittedType); - builder.setInsertBefore( - entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst()); - - // get uses before we `tryConvertValue` since this creates a new use - List uses; - for (auto use = var->firstUse; use; use = use->nextUse) - uses.add(use); - - auto convertedValue = tryConvertValue(builder, var, varType); - if (convertedValue == nullptr) - continue; - - foundAConversion = true; - copyNameHintAndDebugDecorations(convertedValue, var); - - for (auto use : uses) - builder.replaceOperand(use, convertedValue); - } - if (!foundAConversion) - { - // If we can't convert the value, report an error. - for (auto permittedType : info.permittedTypes) - { - StringBuilder typeNameSB; - getTypeNameHint(typeNameSB, permittedType); - m_sink->diagnose( - var->sourceLoc, - Diagnostics::systemValueTypeIncompatible, - semanticName, - typeNameSB.produceString()); - } - } - } - } - - void legalizeSystemValueParameters(EntryPointInfo entryPoint) - { - List systemValWorkItems = - collectSystemValFromEntryPoint(entryPoint); - - for (auto index = 0; index < systemValWorkItems.getCount(); index++) - { - legalizeSystemValue(entryPoint, systemValWorkItems[index]); - } - fixUpFuncType(entryPoint.entryPointFunc); - } - - void legalizeEntryPointForMetal(EntryPointInfo entryPoint) - { - // Input Parameter Legalize - depointerizeInputParams(entryPoint.entryPointFunc); - hoistEntryPointParameterFromStruct(entryPoint); - packStageInParameters(entryPoint); - flattenInputParameters(entryPoint); - - // System Value Legalize - legalizeSystemValueParameters(entryPoint); - - // Output Value Legalize - wrapReturnValueInStruct(entryPoint); - - // Other Legalize - switch (entryPoint.entryPointDecor->getProfile().getStage()) - { - case Stage::Amplification: - legalizeDispatchMeshPayloadForMetal(entryPoint); - break; - case Stage::Mesh: - legalizeMeshEntryPoint(entryPoint); - 
break; - default: - break; - } - } -}; - // metal textures only support writing 4-component values, even if the texture is only 1, 2, or // 3-component in this case the other channels get ignored, but the signature still doesnt match so // now we have to replace the value being written with a 4-component vector where the new components @@ -2173,10 +233,7 @@ void legalizeIRForMetal(IRModule* module, DiagnosticSink* sink) } } - LegalizeMetalEntryPointContext context(sink, module); - for (auto entryPoint : entryPoints) - context.legalizeEntryPointForMetal(entryPoint); - context.removeSemanticLayoutsFromLegalizedStructs(); + legalizeEntryPointVaryingParamsForMetal(module, sink, entryPoints); MetalAddressSpaceAssigner metalAddressSpaceAssigner; specializeAddressSpace(module, &metalAddressSpaceAssigner); diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp index effc06f3ef..efa028703c 100644 --- a/source/slang/slang-ir-wgsl-legalize.cpp +++ b/source/slang/slang-ir-wgsl-legalize.cpp @@ -4,1537 +4,169 @@ #include "slang-ir-legalize-binary-operator.h" #include "slang-ir-legalize-global-values.h" #include "slang-ir-legalize-varying-params.h" -#include "slang-ir-util.h" #include "slang-ir.h" -#include "slang-parameter-binding.h" - -#include namespace Slang { -struct EntryPointInfo -{ - IRFunc* entryPointFunc; - IREntryPointDecoration* entryPointDecor; -}; - -struct LegalizeWGSLEntryPointContext +static void legalizeCall(IRCall* call) { - HashSet semanticInfoToRemove; - UnownedStringSlice userSemanticName = toSlice("user_semantic"); - - DiagnosticSink* m_sink; - IRModule* m_module; - - LegalizeWGSLEntryPointContext(DiagnosticSink* sink, IRModule* module) - : m_sink(sink), m_module(module) - { - } - - void removeSemanticLayoutsFromLegalizedStructs() - { - // WGSL does not allow duplicate attributes to appear in the same shader. - // If we emit our own struct with `[[color(0)]`, all existing uses of `[[color(0)]]` - // must be removed. 
- for (auto field : semanticInfoToRemove) - { - auto key = field->getKey(); - // Some decorations appear twice, destroy all found - for (;;) - { - if (auto semanticDecor = key->findDecoration()) - { - semanticDecor->removeAndDeallocate(); - continue; - } - else if (auto layoutDecor = key->findDecoration()) - { - layoutDecor->removeAndDeallocate(); - continue; - } - break; - } - } - } - - // Flattens all struct parameters of an entryPoint to ensure parameters are a flat struct - void flattenInputParameters(EntryPointInfo entryPoint) - { - // Goal is to ensure we have a flattened IRParam (0 nested IRStructType members). - /* - // Assume the following code - struct NestedFragment - { - float2 p3; - }; - struct Fragment - { - float4 p1; - float3 p2; - NestedFragment p3_nested; - }; - - // Fragment flattens into - struct Fragment - { - float4 p1; - float3 p2; - float2 p3; - }; - */ - - // This is important since WGSL does not allow semantic's on a struct - /* - // Assume the following code - struct NestedFragment1 - { - float2 p3; - }; - struct Fragment1 - { - float4 p1 : SV_TARGET0; - float3 p2 : SV_TARGET1; - NestedFragment p3_nested : SV_TARGET2; // error, semantic on struct - }; - - */ - - // Unlike Metal, WGSL does NOT allow semantics on members of a nested struct. 
- /* - // Assume the following code - struct NestedFragment - { - float2 p3; - }; - struct Fragment - { - float4 p1 : SV_TARGET0; - NestedFragment p2 : SV_TARGET1; - NestedFragment p3 : SV_TARGET2; - }; - - // Legalized with flattening - struct Fragment - { - float4 p1 : SV_TARGET0; - float2 p2 : SV_TARGET1; - float2 p3 : SV_TARGET2; - }; - */ - - auto func = entryPoint.entryPointFunc; - bool modified = false; - for (auto param : func->getParams()) - { - auto layout = findVarLayout(param); - if (!layout) - continue; - if (!layout->findOffsetAttr(LayoutResourceKind::VaryingInput)) - continue; - if (param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration)) - continue; - // If we find a IRParam with a IRStructType member, we need to flatten the entire - // IRParam - if (auto structType = as(param->getDataType())) - { - IRBuilder builder(func); - MapStructToFlatStruct mapOldFieldToNewField; - - // Flatten struct if we have nested IRStructType - auto flattenedStruct = maybeFlattenNestedStructs( - builder, - structType, - mapOldFieldToNewField, - semanticInfoToRemove); - // Validate/rearange all semantics which overlap in our flat struct. 
- fixFieldSemanticsOfFlatStruct(flattenedStruct); - ensureStructHasUserSemantic( - flattenedStruct, - layout); - if (flattenedStruct != structType) - { - // Replace the 'old IRParam type' with a 'new IRParam type' - param->setFullType(flattenedStruct); - - // Emit a new variable at EntryPoint of 'old IRParam type' - builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); - auto dstVal = builder.emitVar(structType); - auto dstLoad = builder.emitLoad(dstVal); - param->replaceUsesWith(dstLoad); - builder.setInsertBefore(dstLoad); - // Copy the 'new IRParam type' to our 'old IRParam type' - mapOldFieldToNewField - .emitCopy<(int)MapStructToFlatStruct::CopyOptions::FlatStructIntoStruct>( - builder, - dstVal, - param); - - modified = true; - } - } - } - if (modified) - fixUpFuncType(func); - } - - struct WGSLSystemValueInfo - { - String wgslSystemValueName; - SystemValueSemanticName wgslSystemValueNameEnum; - ShortList permittedTypes; - bool isUnsupported = false; - WGSLSystemValueInfo() - { - // most commonly need 2 - permittedTypes.reserveOverflowBuffer(2); - } + // WGSL does not allow forming a pointer to a sub part of a composite value. + // For example, if we have + // ``` + // struct S { float x; float y; }; + // void foo(inout float v) { v = 1.0f; } + // void main() { S s; foo(s.x); } + // ``` + // The call to `foo(s.x)` is illegal in WGSL because `s.x` is a sub part of `s`. + // And trying to form `&s.x` in WGSL is illegal. + // To work around this, we will create a local variable to hold the sub part of + // the composite value. + // And then pass the local variable to the function. + // After the call, we will write back the local variable to the sub part of the + // composite value. 
+ // + IRBuilder builder(call); + builder.setInsertBefore(call); + struct WritebackPair + { + IRInst* dest; + IRInst* value; }; + ShortList pendingWritebacks; - WGSLSystemValueInfo getSystemValueInfo( - String inSemanticName, - String* optionalSemanticIndex, - IRInst* parentVar) + for (UInt i = 0; i < call->getArgCount(); i++) { - IRBuilder builder(m_module); - WGSLSystemValueInfo result = {}; - UnownedStringSlice semanticName; - UnownedStringSlice semanticIndex; - - auto hasExplicitIndex = - splitNameAndIndex(inSemanticName.getUnownedSlice(), semanticName, semanticIndex); - if (!hasExplicitIndex && optionalSemanticIndex) - semanticIndex = optionalSemanticIndex->getUnownedSlice(); - - result.wgslSystemValueNameEnum = convertSystemValueSemanticNameToEnum(semanticName); - - switch (result.wgslSystemValueNameEnum) + auto arg = call->getArg(i); + auto ptrType = as(arg->getDataType()); + if (!ptrType) + continue; + switch (arg->getOp()) { - - case SystemValueSemanticName::CullDistance: - { - result.isUnsupported = true; - } - break; - - case SystemValueSemanticName::ClipDistance: - { - // TODO: Implement this based on the 'clip-distances' feature in WGSL - // https: // www.w3.org/TR/webgpu/#dom-gpufeaturename-clip-distances - result.isUnsupported = true; - } - break; - - case SystemValueSemanticName::Coverage: - { - result.wgslSystemValueName = toSlice("sample_mask"); - result.permittedTypes.add(builder.getUIntType()); - } - break; - - case SystemValueSemanticName::Depth: - { - result.wgslSystemValueName = toSlice("frag_depth"); - result.permittedTypes.add(builder.getBasicType(BaseType::Float)); - } - break; - - case SystemValueSemanticName::DepthGreaterEqual: - case SystemValueSemanticName::DepthLessEqual: - { - result.isUnsupported = true; - } - break; - - case SystemValueSemanticName::DispatchThreadID: - { - result.wgslSystemValueName = toSlice("global_invocation_id"); - result.permittedTypes.add(builder.getVectorType( - builder.getBasicType(BaseType::UInt), - 
builder.getIntValue(builder.getIntType(), 3))); - } - break; - - case SystemValueSemanticName::DomainLocation: - { - result.isUnsupported = true; - } - break; - - case SystemValueSemanticName::GroupID: - { - result.wgslSystemValueName = toSlice("workgroup_id"); - result.permittedTypes.add(builder.getVectorType( - builder.getBasicType(BaseType::UInt), - builder.getIntValue(builder.getIntType(), 3))); - } - break; - - case SystemValueSemanticName::GroupIndex: - { - result.wgslSystemValueName = toSlice("local_invocation_index"); - result.permittedTypes.add(builder.getUIntType()); - } - break; - - case SystemValueSemanticName::GroupThreadID: - { - result.wgslSystemValueName = toSlice("local_invocation_id"); - result.permittedTypes.add(builder.getVectorType( - builder.getBasicType(BaseType::UInt), - builder.getIntValue(builder.getIntType(), 3))); - } - break; - - case SystemValueSemanticName::GSInstanceID: - { - // No Geometry shaders in WGSL - result.isUnsupported = true; - } - break; - - case SystemValueSemanticName::InnerCoverage: - { - result.isUnsupported = true; - } - break; - - case SystemValueSemanticName::InstanceID: - { - result.wgslSystemValueName = toSlice("instance_index"); - result.permittedTypes.add(builder.getUIntType()); - } - break; - - case SystemValueSemanticName::IsFrontFace: - { - result.wgslSystemValueName = toSlice("front_facing"); - result.permittedTypes.add(builder.getBoolType()); - } - break; - - case SystemValueSemanticName::OutputControlPointID: - case SystemValueSemanticName::PointSize: - { - result.isUnsupported = true; - } - break; - - case SystemValueSemanticName::Position: - { - result.wgslSystemValueName = toSlice("position"); - result.permittedTypes.add(builder.getVectorType( - builder.getBasicType(BaseType::Float), - builder.getIntValue(builder.getIntType(), 4))); - break; - } - - case SystemValueSemanticName::PrimitiveID: - case SystemValueSemanticName::RenderTargetArrayIndex: - { - result.isUnsupported = true; - break; - } - - case 
SystemValueSemanticName::SampleIndex: - { - result.wgslSystemValueName = toSlice("sample_index"); - result.permittedTypes.add(builder.getUIntType()); - break; - } - - case SystemValueSemanticName::StencilRef: - case SystemValueSemanticName::Target: - case SystemValueSemanticName::TessFactor: - { - result.isUnsupported = true; - break; - } - - case SystemValueSemanticName::VertexID: - { - result.wgslSystemValueName = toSlice("vertex_index"); - result.permittedTypes.add(builder.getUIntType()); - break; - } - - case SystemValueSemanticName::ViewID: - case SystemValueSemanticName::ViewportArrayIndex: - case SystemValueSemanticName::StartVertexLocation: - case SystemValueSemanticName::StartInstanceLocation: - { - result.isUnsupported = true; - break; - } - + case kIROp_Var: + case kIROp_Param: + case kIROp_GlobalParam: + case kIROp_GlobalVar: + continue; default: - { - m_sink->diagnose( - parentVar, - Diagnostics::unimplementedSystemValueSemantic, - semanticName); - return result; - } + break; } - return result; - } + // Create a local variable to hold the input argument. + auto var = builder.emitVar(ptrType->getValueType(), AddressSpace::Function); - void reportUnsupportedSystemAttribute(IRInst* param, String semanticName) - { - m_sink->diagnose( - param->sourceLoc, - Diagnostics::systemValueAttributeNotSupported, - semanticName); + // Store the input argument into the local variable. + builder.emitStore(var, builder.emitLoad(arg)); + builder.replaceOperand(call->getArgs() + i, var); + pendingWritebacks.add({arg, var}); } - template - void ensureStructHasUserSemantic(IRStructType* structType, IRVarLayout* varLayout) + // Perform writebacks after the call. + builder.setInsertAfter(call); + for (auto& pair : pendingWritebacks) { - // Ensure each field in an output struct type has either a system semantic or a user - // semantic, so that signature matching can happen correctly. 
- auto typeLayout = as(varLayout->getTypeLayout()); - Index index = 0; - IRBuilder builder(structType); - for (auto field : structType->getFields()) - { - auto key = field->getKey(); - if (auto semanticDecor = key->findDecoration()) - { - if (semanticDecor->getSemanticName().startsWithCaseInsensitive(toSlice("sv_"))) - { - auto indexAsString = String(UInt(semanticDecor->getSemanticIndex())); - auto sysValInfo = - getSystemValueInfo(semanticDecor->getSemanticName(), &indexAsString, field); - if (sysValInfo.isUnsupported) - { - reportUnsupportedSystemAttribute(field, semanticDecor->getSemanticName()); - } - else - { - builder.addTargetSystemValueDecoration( - key, - sysValInfo.wgslSystemValueName.getUnownedSlice()); - semanticDecor->removeAndDeallocate(); - } - } - index++; - continue; - } - typeLayout->getFieldLayout(index); - auto fieldLayout = typeLayout->getFieldLayout(index); - if (auto offsetAttr = fieldLayout->findOffsetAttr(K)) - { - UInt varOffset = 0; - if (auto varOffsetAttr = varLayout->findOffsetAttr(K)) - varOffset = varOffsetAttr->getOffset(); - varOffset += offsetAttr->getOffset(); - builder.addSemanticDecoration(key, toSlice("_slang_attr"), (int)varOffset); - } - index++; - } - } - - // Stores a hicharchy of members and children which map 'oldStruct->member' to - // 'flatStruct->member' Note: this map assumes we map to FlatStruct since it is easier/faster to - // process - struct MapStructToFlatStruct - { - /* - We need a hicharchy map to resolve dependencies for mapping - oldStruct to newStruct efficently. 
Example: - - MyStruct - | - / | \ - / | \ - / | \ - M0 M1 M2 - | | | - A_0 A_0 B_0 - - Without storing hicharchy information, there will be no way to tell apart - `myStruct.M0.A0` from `myStruct.M1.A0` since IRStructKey/IRStructField - only has 1 instance of `A::A0` - */ - - enum CopyOptions : int - { - // Copy a flattened-struct into a struct - FlatStructIntoStruct = 0, - - // Copy a struct into a flattened-struct - StructIntoFlatStruct = 1, - }; - - private: - // Children of member if applicable. - Dictionary members; - - // Field correlating to MapStructToFlatStruct Node. - IRInst* node; - IRStructKey* getKey() - { - SLANG_ASSERT(as(node)); - return as(node)->getKey(); - } - IRInst* getNode() { return node; } - IRType* getFieldType() - { - SLANG_ASSERT(as(node)); - return as(node)->getFieldType(); - } - - // Whom node maps to inside target flatStruct - IRStructField* targetMapping; - - auto begin() { return members.begin(); } - auto end() { return members.end(); } - - // Copies members of oldStruct to/from newFlatStruct. 
Assumes members of val1 maps to - // members in val2 using `MapStructToFlatStruct` - template - static void _emitCopy( - IRBuilder& builder, - IRInst* val1, - IRStructType* type1, - IRInst* val2, - IRStructType* type2, - MapStructToFlatStruct& node) - { - for (auto& field1Pair : node) - { - auto& field1 = field1Pair.second; - - // Get member of val1 - IRInst* fieldAddr1 = nullptr; - if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct) - { - fieldAddr1 = builder.emitFieldAddress(type1, val1, field1.getKey()); - } - else - { - if (as(val1)) - val1 = builder.emitLoad(val1); - fieldAddr1 = builder.emitFieldExtract(type1, val1, field1.getKey()); - } - - // If val1 is a struct, recurse - if (auto fieldAsStruct1 = as(field1.getFieldType())) - { - _emitCopy( - builder, - fieldAddr1, - fieldAsStruct1, - val2, - type2, - field1); - continue; - } - - // Get member of val2 which maps to val1.member - auto field2 = field1.getMapping(); - SLANG_ASSERT(field2); - IRInst* fieldAddr2 = nullptr; - if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct) - { - if (as(val2)) - val2 = builder.emitLoad(val1); - fieldAddr2 = builder.emitFieldExtract(type2, val2, field2->getKey()); - } - else - { - fieldAddr2 = builder.emitFieldAddress(type2, val2, field2->getKey()); - } - - // Copy val2/val1 member into val1/val2 member - if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct) - { - builder.emitStore(fieldAddr1, fieldAddr2); - } - else - { - builder.emitStore(fieldAddr2, fieldAddr1); - } - } - } - - public: - void setNode(IRInst* newNode) { node = newNode; } - // Get 'MapStructToFlatStruct' that is a child of 'parent'. - // Make 'MapStructToFlatStruct' if no 'member' is currently mapped to 'parent'. 
- MapStructToFlatStruct& getMember(IRStructField* member) { return members[member]; } - MapStructToFlatStruct& operator[](IRStructField* member) { return getMember(member); } - - void setMapping(IRStructField* newTargetMapping) { targetMapping = newTargetMapping; } - // Get 'MapStructToFlatStruct' that is a child of 'parent'. - // Return nullptr if no member is mapped to 'parent' - IRStructField* getMapping() { return targetMapping; } - - // Copies srcVal into dstVal using hicharchy map. - template - void emitCopy(IRBuilder& builder, IRInst* dstVal, IRInst* srcVal) - { - auto dstType = dstVal->getDataType(); - if (auto dstPtrType = as(dstType)) - dstType = dstPtrType->getValueType(); - auto dstStructType = as(dstType); - SLANG_ASSERT(dstStructType); - - auto srcType = srcVal->getDataType(); - if (auto srcPtrType = as(srcType)) - srcType = srcPtrType->getValueType(); - auto srcStructType = as(srcType); - SLANG_ASSERT(srcStructType); - - if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct) - { - // CopyOptions::FlatStructIntoStruct copy a flattened-struct (mapped member) into a - // struct - SLANG_ASSERT(node == dstStructType); - _emitCopy( - builder, - dstVal, - dstStructType, - srcVal, - srcStructType, - *this); - } - else - { - // CopyOptions::StructIntoFlatStruct copy a struct into a flattened-struct - SLANG_ASSERT(node == srcStructType); - _emitCopy( - builder, - srcVal, - srcStructType, - dstVal, - dstStructType, - *this); - } - } - }; - - IRStructType* _flattenNestedStructs( - IRBuilder& builder, - IRStructType* dst, - IRStructType* src, - IRSemanticDecoration* parentSemanticDecoration, - IRLayoutDecoration* parentLayout, - MapStructToFlatStruct& mapFieldToField, - HashSet& varsWithSemanticInfo) - { - // For all fields ('oldField') of a struct do the following: - // 1. Check for 'decorations which carry semantic info' (IRSemanticDecoration, - // IRLayoutDecoration), store these if found. 
- // * Do not propagate semantic info if the current node has *any* form of semantic - // information. - // Update varsWithSemanticInfo. - // 2. If IRStructType: - // 2a. Recurse this function with 'decorations that carry semantic info' from parent. - // 3. If not IRStructType: - // 3a. Emit 'newField' with 'newKey' equal to 'oldField' and 'oldKey', respectively, - // where 'oldKey' is the key corresponding to 'oldField'. - // Add 'decorations which carry semantic info' to 'newField', and move all decorations - // of 'oldKey' to 'newKey'. - // 3b. Store a mapping from 'oldField' to 'newField' in 'mapFieldToField'. This info is - // needed to copy between types. - for (auto oldField : src->getFields()) - { - auto& fieldMappingNode = mapFieldToField[oldField]; - fieldMappingNode.setNode(oldField); - - // step 1 - bool foundSemanticDecor = false; - auto oldKey = oldField->getKey(); - IRSemanticDecoration* fieldSemanticDecoration = parentSemanticDecoration; - if (auto oldSemanticDecoration = oldKey->findDecoration()) - { - foundSemanticDecor = true; - fieldSemanticDecoration = oldSemanticDecoration; - parentLayout = nullptr; - } - - IRLayoutDecoration* fieldLayout = parentLayout; - if (auto oldLayout = oldKey->findDecoration()) - { - fieldLayout = oldLayout; - if (!foundSemanticDecor) - fieldSemanticDecoration = nullptr; - } - if (fieldSemanticDecoration != parentSemanticDecoration || parentLayout != fieldLayout) - varsWithSemanticInfo.add(oldField); - - // step 2a - if (auto structFieldType = as(oldField->getFieldType())) - { - _flattenNestedStructs( - builder, - dst, - structFieldType, - fieldSemanticDecoration, - fieldLayout, - fieldMappingNode, - varsWithSemanticInfo); - continue; - } - - // step 3a - auto newKey = builder.createStructKey(); - oldKey->transferDecorationsTo(newKey); - - auto newField = builder.createStructField(dst, newKey, oldField->getFieldType()); - copyNameHintAndDebugDecorations(newField, oldField); - - if (fieldSemanticDecoration) - 
builder.addSemanticDecoration( - newKey, - fieldSemanticDecoration->getSemanticName(), - fieldSemanticDecoration->getSemanticIndex()); - - if (fieldLayout) - { - IRLayout* oldLayout = fieldLayout->getLayout(); - List instToCopy; - // Only copy certain decorations needed for resolving system semantics - for (UInt i = 0; i < oldLayout->getOperandCount(); i++) - { - auto operand = oldLayout->getOperand(i); - if (as(operand) || as(operand) || - as(operand) || as(operand)) - instToCopy.add(operand); - } - IRVarLayout* newLayout = builder.getVarLayout(instToCopy); - builder.addLayoutDecoration(newKey, newLayout); - } - // step 3b - fieldMappingNode.setMapping(newField); - } - - return dst; - } - - // Returns a `IRStructType*` without any `IRStructType*` members. `src` may be returned if there - // was no struct flattening. - // @param mapFieldToField Behavior maps all `IRStructField` of `src` to the new struct - // `IRStructFields`s - IRStructType* maybeFlattenNestedStructs( - IRBuilder& builder, - IRStructType* src, - MapStructToFlatStruct& mapFieldToField, - HashSet& varsWithSemanticInfo) - { - // Find all values inside struct that need flattening and legalization. - bool hasStructTypeMembers = false; - for (auto field : src->getFields()) - { - if (as(field->getFieldType())) - { - hasStructTypeMembers = true; - break; - } - } - if (!hasStructTypeMembers) - return src; - - // We need to: - // 1. Make new struct 1:1 with old struct but without nestested structs (flatten) - // 2. Ensure semantic attributes propegate. This will create overlapping semantics (can be - // handled later). - // 3. Store the mapping from old to new struct fields to allow copying a old-struct to - // new-struct. 
- builder.setInsertAfter(src); - auto newStruct = builder.createStructType(); - copyNameHintAndDebugDecorations(newStruct, src); - mapFieldToField.setNode(src); - return _flattenNestedStructs( - builder, - newStruct, - src, - nullptr, - nullptr, - mapFieldToField, - varsWithSemanticInfo); + builder.emitStore(pair.dest, builder.emitLoad(pair.value)); } +} - // Replaces all 'IRReturn' by copying the current 'IRReturn' to a new var of type 'newType'. - // Copying logic from 'IRReturn' to 'newType' is controlled by 'copyLogicFunc' function. - template - void _replaceAllReturnInst( - IRBuilder& builder, - IRFunc* targetFunc, - IRStructType* newType, - CopyLogicFunc copyLogicFunc) +static void legalizeFunc(IRFunc* func) +{ + // Insert casts to convert integer return types + auto funcReturnType = func->getResultType(); + if (isIntegralType(funcReturnType)) { - for (auto block : targetFunc->getBlocks()) + for (auto block : func->getBlocks()) { if (auto returnInst = as(block->getTerminator())) { - builder.setInsertBefore(returnInst); - auto returnVal = returnInst->getVal(); - returnInst->setOperand(0, copyLogicFunc(builder, newType, returnVal)); - } - } - } - - UInt _returnNonOverlappingAttributeIndex(std::set& usedSemanticIndex) - { - // Find first unused semantic index of equal semantic type - // to fill any gaps in user set semantic bindings - UInt prev = 0; - for (auto i : usedSemanticIndex) - { - if (i > prev + 1) - { - break; - } - prev = i; - } - usedSemanticIndex.insert(prev + 1); - return prev + 1; - } - - template - struct AttributeParentPair - { - IRLayoutDecoration* layoutDecor; - T* attr; - }; - - IRLayoutDecoration* _replaceAttributeOfLayout( - IRBuilder& builder, - IRLayoutDecoration* parentLayoutDecor, - IRInst* instToReplace, - IRInst* instToReplaceWith) - { - // Replace `instToReplace` with a `instToReplaceWith` - - auto layout = parentLayoutDecor->getLayout(); - // Find the exact same decoration `instToReplace` in-case multiple of the same type exist - 
List opList; - opList.add(instToReplaceWith); - for (UInt i = 0; i < layout->getOperandCount(); i++) - { - if (layout->getOperand(i) != instToReplace) - opList.add(layout->getOperand(i)); - } - auto newLayoutDecor = builder.addLayoutDecoration( - parentLayoutDecor->getParent(), - builder.getVarLayout(opList)); - parentLayoutDecor->removeAndDeallocate(); - return newLayoutDecor; - } - - IRLayoutDecoration* _simplifyUserSemanticNames( - IRBuilder& builder, - IRLayoutDecoration* layoutDecor) - { - // Ensure all 'ExplicitIndex' semantics such as "SV_TARGET0" are simplified into - // ("SV_TARGET", 0) using 'IRUserSemanticAttr' This is done to ensure we can check semantic - // groups using 'IRUserSemanticAttr1->getName() == IRUserSemanticAttr2->getName()' - SLANG_ASSERT(layoutDecor); - auto layout = layoutDecor->getLayout(); - List layoutOps; - layoutOps.reserve(3); - bool changed = false; - for (auto attr : layout->getAllAttrs()) - { - if (auto userSemantic = as(attr)) - { - UnownedStringSlice outName; - UnownedStringSlice outIndex; - bool hasStringIndex = splitNameAndIndex(userSemantic->getName(), outName, outIndex); - - changed = true; - auto newDecoration = builder.getUserSemanticAttr( - userSemanticName, - hasStringIndex ? stringToInt(outIndex) : 0); - userSemantic->replaceUsesWith(newDecoration); - userSemantic->removeAndDeallocate(); - userSemantic = newDecoration; - - layoutOps.add(userSemantic); - continue; - } - layoutOps.add(attr); - } - if (changed) - { - auto parent = layoutDecor->parent; - layoutDecor->removeAndDeallocate(); - builder.addLayoutDecoration(parent, builder.getVarLayout(layoutOps)); - } - return layoutDecor; - } - - // Find overlapping field semantics and legalize them - void fixFieldSemanticsOfFlatStruct(IRStructType* structType) - { - // Goal is to ensure we do not have overlapping semantics for the user defined semantics: - // Note that in WGSL, the semantics can be either `builtin` without index or `location` with - // index. 
- /* - // Assume the following code - struct Fragment - { - float4 p0 : SV_POSITION; - float2 p1 : TEXCOORD0; - float2 p2 : TEXCOORD1; - float3 p3 : COLOR0; - float3 p4 : COLOR1; - }; - - // Translates into - struct Fragment - { - float4 p0 : BUILTIN_POSITION; - float2 p1 : LOCATION_0; - float2 p2 : LOCATION_1; - float3 p3 : LOCATION_2; - float3 p4 : LOCATION_3; - }; - */ - - // For Multi-Render-Target, the semantic index must be translated to `location` with - // the same index. Assume the following code - /* - struct Fragment - { - float4 p0 : SV_TARGET1; - float4 p1 : SV_TARGET0; - }; - - // Translates into - struct Fragment - { - float4 p0 : LOCATION_1; - float4 p1 : LOCATION_0; - }; - */ - - IRBuilder builder(this->m_module); - - List overlappingSemanticsDecor; - Dictionary>> - usedSemanticIndexSemanticDecor; - - List> overlappingVarOffset; - Dictionary>> usedSemanticIndexVarOffset; - - List> overlappingUserSemantic; - Dictionary>> - usedSemanticIndexUserSemantic; - - // We store a map from old `IRLayoutDecoration*` to new `IRLayoutDecoration*` since when - // legalizing we may destroy and remake a `IRLayoutDecoration*` - Dictionary oldLayoutDecorToNew; - - // Collect all "semantic info carrying decorations". Any collected decoration will - // fill up their respective 'Dictionary>' - // to keep track of in-use offsets for a semantic type. - // Example: IRSemanticDecoration with name of "SV_TARGET1". - // * This will have SEMANTIC_TYPE of "sv_target". - // * This will use up index '1' - // - // Now if a second equal semantic "SV_TARGET1" is found, we add this decoration to - // a list of 'overlapping semantic info decorations' so we can legalize this - // 'semantic info decoration' later. - // - // NOTE: this is a flat struct, all members are children of the initial - // IRStructType. 
- for (auto field : structType->getFields()) - { - auto key = field->getKey(); - if (auto semanticDecoration = key->findDecoration()) - { - auto semanticName = semanticDecoration->getSemanticName(); - - // sv_target is treated as a user-semantic because it should be emitted with - // @location like how the user semantics are emitted. - // For fragment shader, only sv_target will user @location, and for non-fragment - // shaders, sv_target is not valid. - bool isUserSemantic = - (semanticName.startsWithCaseInsensitive(toSlice("sv_target")) || - !semanticName.startsWithCaseInsensitive(toSlice("sv_"))); - - // Ensure names are in a uniform lowercase format so we can bunch together simmilar - // semantics. - UnownedStringSlice outName; - UnownedStringSlice outIndex; - bool hasStringIndex = splitNameAndIndex(semanticName, outName, outIndex); - - // user semantics gets all same semantic-name. - auto loweredName = String(outName).toLower(); - auto loweredNameSlice = - isUserSemantic ? userSemanticName : loweredName.getUnownedSlice(); - auto newDecoration = builder.addSemanticDecoration( - key, - loweredNameSlice, - hasStringIndex ? 
stringToInt(outIndex) : 0); - semanticDecoration->replaceUsesWith(newDecoration); - semanticDecoration->removeAndDeallocate(); - semanticDecoration = newDecoration; - - auto& semanticUse = - usedSemanticIndexSemanticDecor[semanticDecoration->getSemanticName()]; - if (semanticUse.find(semanticDecoration->getSemanticIndex()) != semanticUse.end()) - overlappingSemanticsDecor.add(semanticDecoration); - else - semanticUse.insert(semanticDecoration->getSemanticIndex()); - } - if (auto layoutDecor = key->findDecoration()) - { - // Ensure names are in a uniform lowercase format so we can bunch together simmilar - // semantics - layoutDecor = _simplifyUserSemanticNames(builder, layoutDecor); - oldLayoutDecorToNew[layoutDecor] = layoutDecor; - auto layout = layoutDecor->getLayout(); - for (auto attr : layout->getAllAttrs()) - { - if (auto offset = as(attr)) - { - auto& semanticUse = usedSemanticIndexVarOffset[offset->getResourceKind()]; - if (semanticUse.find(offset->getOffset()) != semanticUse.end()) - overlappingVarOffset.add({layoutDecor, offset}); - else - semanticUse.insert(offset->getOffset()); - } - else if (auto userSemantic = as(attr)) - { - auto& semanticUse = usedSemanticIndexUserSemantic[userSemantic->getName()]; - if (semanticUse.find(userSemantic->getIndex()) != semanticUse.end()) - overlappingUserSemantic.add({layoutDecor, userSemantic}); - else - semanticUse.insert(userSemantic->getIndex()); - } - } - } - } - - // Legalize all overlapping 'semantic info decorations' - for (auto decor : overlappingSemanticsDecor) - { - auto newOffset = _returnNonOverlappingAttributeIndex( - usedSemanticIndexSemanticDecor[decor->getSemanticName()]); - builder.addSemanticDecoration( - decor->getParent(), - decor->getSemanticName(), - (int)newOffset); - decor->removeAndDeallocate(); - } - for (auto& varOffset : overlappingVarOffset) - { - auto newOffset = _returnNonOverlappingAttributeIndex( - usedSemanticIndexVarOffset[varOffset.attr->getResourceKind()]); - auto newVarOffset = 
builder.getVarOffsetAttr( - varOffset.attr->getResourceKind(), - newOffset, - varOffset.attr->getSpace()); - oldLayoutDecorToNew[varOffset.layoutDecor] = _replaceAttributeOfLayout( - builder, - oldLayoutDecorToNew[varOffset.layoutDecor], - varOffset.attr, - newVarOffset); - } - for (auto& userSemantic : overlappingUserSemantic) - { - auto newOffset = _returnNonOverlappingAttributeIndex( - usedSemanticIndexUserSemantic[userSemantic.attr->getName()]); - auto newUserSemantic = - builder.getUserSemanticAttr(userSemantic.attr->getName(), newOffset); - oldLayoutDecorToNew[userSemantic.layoutDecor] = _replaceAttributeOfLayout( - builder, - oldLayoutDecorToNew[userSemantic.layoutDecor], - userSemantic.attr, - newUserSemantic); - } - } - - void wrapReturnValueInStruct(EntryPointInfo entryPoint) - { - // Wrap return value into a struct if it is not already a struct. - // For example, given this entry point: - // ``` - // float4 main() : SV_Target { return float3(1,2,3); } - // ``` - // We are going to transform it into: - // ``` - // struct Output { - // float4 value : SV_Target; - // }; - // Output main() { return {float3(1,2,3)}; } - - auto func = entryPoint.entryPointFunc; - - auto returnType = func->getResultType(); - if (as(returnType)) - return; - auto entryPointLayoutDecor = func->findDecoration(); - if (!entryPointLayoutDecor) - return; - auto entryPointLayout = as(entryPointLayoutDecor->getLayout()); - if (!entryPointLayout) - return; - auto resultLayout = entryPointLayout->getResultLayout(); - - // If return type is already a struct, just make sure every field has a semantic. 
- if (auto returnStructType = as(returnType)) - { - IRBuilder builder(func); - MapStructToFlatStruct mapOldFieldToNewField; - // Flatten result struct type to ensure we do not have nested semantics - auto flattenedStruct = maybeFlattenNestedStructs( - builder, - returnStructType, - mapOldFieldToNewField, - semanticInfoToRemove); - if (returnStructType != flattenedStruct) - { - // Replace all return-values with the flattenedStruct we made. - _replaceAllReturnInst( - builder, - func, - flattenedStruct, - [&](IRBuilder& copyBuilder, IRStructType* dstType, IRInst* srcVal) -> IRInst* - { - auto srcStructType = as(srcVal->getDataType()); - SLANG_ASSERT(srcStructType); - auto dstVal = copyBuilder.emitVar(dstType); - mapOldFieldToNewField.emitCopy<( - int)MapStructToFlatStruct::CopyOptions::StructIntoFlatStruct>( - copyBuilder, - dstVal, - srcVal); - return builder.emitLoad(dstVal); - }); - fixUpFuncType(func, flattenedStruct); - } - // Ensure non-overlapping semantics - fixFieldSemanticsOfFlatStruct(flattenedStruct); - ensureStructHasUserSemantic( - flattenedStruct, - resultLayout); - return; - } - - IRBuilder builder(func); - builder.setInsertBefore(func); - IRStructType* structType = builder.createStructType(); - auto stageText = getStageText(entryPoint.entryPointDecor->getProfile().getStage()); - builder.addNameHintDecoration( - structType, - (String(stageText) + toSlice("Output")).getUnownedSlice()); - auto key = builder.createStructKey(); - builder.addNameHintDecoration(key, toSlice("output")); - builder.addLayoutDecoration(key, resultLayout); - builder.createStructField(structType, key, returnType); - IRStructTypeLayout::Builder structTypeLayoutBuilder(&builder); - structTypeLayoutBuilder.addField(key, resultLayout); - auto typeLayout = structTypeLayoutBuilder.build(); - IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout); - auto varLayout = varLayoutBuilder.build(); - ensureStructHasUserSemantic(structType, varLayout); - - _replaceAllReturnInst( - builder, 
- func, - structType, - [](IRBuilder& copyBuilder, IRStructType* dstType, IRInst* srcVal) -> IRInst* - { return copyBuilder.emitMakeStruct(dstType, 1, &srcVal); }); - - // Assign an appropriate system value semantic for stage output - auto stage = entryPoint.entryPointDecor->getProfile().getStage(); - switch (stage) - { - case Stage::Compute: - case Stage::Fragment: - { - IRInst* operands[] = { - builder.getStringValue(userSemanticName), - builder.getIntValue(builder.getIntType(), 0)}; - builder.addDecoration( - key, - kIROp_SemanticDecoration, - operands, - SLANG_COUNT_OF(operands)); - break; - } - case Stage::Vertex: - { - builder.addTargetSystemValueDecoration(key, toSlice("position")); - break; - } - default: - SLANG_ASSERT(false); - return; - } - - fixUpFuncType(func, structType); - } - - IRInst* tryConvertValue(IRBuilder& builder, IRInst* val, IRType* toType) - { - auto fromType = val->getFullType(); - if (auto fromVector = as(fromType)) - { - if (auto toVector = as(toType)) - { - if (fromVector->getElementCount() != toVector->getElementCount()) - { - fromType = builder.getVectorType( - fromVector->getElementType(), - toVector->getElementCount()); - val = builder.emitVectorReshape(fromType, val); - } - } - else if (as(toType)) - { - UInt index = 0; - val = builder.emitSwizzle(fromVector->getElementType(), val, 1, &index); - if (toType->getOp() == kIROp_VoidType) - return nullptr; - } - } - else if (auto fromBasicType = as(fromType)) - { - if (fromBasicType->getOp() == kIROp_VoidType) - return nullptr; - if (!as(toType)) - return nullptr; - if (toType->getOp() == kIROp_VoidType) - return nullptr; - } - else - { - return nullptr; - } - return builder.emitCast(toType, val); - } - - struct SystemValLegalizationWorkItem - { - IRInst* var; - IRType* varType; - String attrName; - UInt attrIndex; - }; - - std::optional tryToMakeSystemValWorkItem( - IRInst* var, - IRType* varType) - { - if (auto semanticDecoration = var->findDecoration()) - { - if 
(semanticDecoration->getSemanticName().startsWithCaseInsensitive(toSlice("sv_"))) - { - return { - {var, - varType, - String(semanticDecoration->getSemanticName()).toLower(), - (UInt)semanticDecoration->getSemanticIndex()}}; - } - } - - auto layoutDecor = var->findDecoration(); - if (!layoutDecor) - return {}; - auto sysValAttr = layoutDecor->findAttr(); - if (!sysValAttr) - return {}; - auto semanticName = String(sysValAttr->getName()); - auto sysAttrIndex = sysValAttr->getIndex(); - - return {{var, varType, semanticName, sysAttrIndex}}; - } - - List collectSystemValFromEntryPoint(EntryPointInfo entryPoint) - { - List systemValWorkItems; - for (auto param : entryPoint.entryPointFunc->getParams()) - { - if (auto structType = as(param->getDataType())) - { - for (auto field : structType->getFields()) - { - // Nested struct-s are flattened already by flattenInputParameters(). - SLANG_ASSERT(!as(field->getFieldType())); - - auto key = field->getKey(); - auto fieldType = field->getFieldType(); - auto maybeWorkItem = tryToMakeSystemValWorkItem(key, fieldType); - if (maybeWorkItem.has_value()) - systemValWorkItems.add(std::move(maybeWorkItem.value())); - } - continue; - } - - auto maybeWorkItem = tryToMakeSystemValWorkItem(param, param->getFullType()); - if (maybeWorkItem.has_value()) - systemValWorkItems.add(std::move(maybeWorkItem.value())); - } - return systemValWorkItems; - } - - void legalizeSystemValue(EntryPointInfo entryPoint, SystemValLegalizationWorkItem& workItem) - { - IRBuilder builder(entryPoint.entryPointFunc); - - auto var = workItem.var; - auto varType = workItem.varType; - auto semanticName = workItem.attrName; - - auto indexAsString = String(workItem.attrIndex); - auto info = getSystemValueInfo(semanticName, &indexAsString, var); - - if (info.isUnsupported) - { - reportUnsupportedSystemAttribute(var, semanticName); - return; - } - if (!info.permittedTypes.getCount()) - return; - - builder.addTargetSystemValueDecoration(var, 
info.wgslSystemValueName.getUnownedSlice()); - - bool varTypeIsPermitted = false; - for (auto& permittedType : info.permittedTypes) - { - varTypeIsPermitted = varTypeIsPermitted || permittedType == varType; - } - - if (!varTypeIsPermitted) - { - // Note: we do not currently prefer any conversion - // example: - // * allowed types for semantic: `float4`, `uint4`, `int4` - // * user used, `float2` - // * Slang will equally prefer `float4` to `uint4` to `int4`. - // This means the type may lose data if slang selects `uint4` or `int4`. - bool foundAConversion = false; - for (auto permittedType : info.permittedTypes) - { - var->setFullType(permittedType); - builder.setInsertBefore( - entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst()); - - // get uses before we `tryConvertValue` since this creates a new use - List uses; - for (auto use = var->firstUse; use; use = use->nextUse) - uses.add(use); - - auto convertedValue = tryConvertValue(builder, var, varType); - if (convertedValue == nullptr) - continue; - - foundAConversion = true; - copyNameHintAndDebugDecorations(convertedValue, var); - - for (auto use : uses) - builder.replaceOperand(use, convertedValue); - } - if (!foundAConversion) - { - // If we can't convert the value, report an error. 
- for (auto permittedType : info.permittedTypes) + auto returnedValue = returnInst->getOperand(0); + auto returnedValueType = returnedValue->getDataType(); + if (isIntegralType(returnedValueType)) { - StringBuilder typeNameSB; - getTypeNameHint(typeNameSB, permittedType); - m_sink->diagnose( - var->sourceLoc, - Diagnostics::systemValueTypeIncompatible, - semanticName, - typeNameSB.produceString()); + IRBuilder builder(returnInst); + builder.setInsertBefore(returnInst); + auto newOp = builder.emitCast(funcReturnType, returnedValue); + builder.replaceOperand(returnInst->getOperands(), newOp); } } } } +} - void legalizeSystemValueParameters(EntryPointInfo entryPoint) - { - List systemValWorkItems = - collectSystemValFromEntryPoint(entryPoint); - - for (auto index = 0; index < systemValWorkItems.getCount(); index++) - { - legalizeSystemValue(entryPoint, systemValWorkItems[index]); - } - fixUpFuncType(entryPoint.entryPointFunc); - } - - void legalizeEntryPointForWGSL(EntryPointInfo entryPoint) - { - // If the entrypoint is receiving varying inputs as a pointer, turn it into a value. - depointerizeInputParams(entryPoint.entryPointFunc); - - // Input Parameter Legalize - flattenInputParameters(entryPoint); - - // System Value Legalize - legalizeSystemValueParameters(entryPoint); - - // Output Value Legalize - wrapReturnValueInStruct(entryPoint); - } - - void legalizeCall(IRCall* call) - { - // WGSL does not allow forming a pointer to a sub part of a composite value. - // For example, if we have - // ``` - // struct S { float x; float y; }; - // void foo(inout float v) { v = 1.0f; } - // void main() { S s; foo(s.x); } - // ``` - // The call to `foo(s.x)` is illegal in WGSL because `s.x` is a sub part of `s`. - // And trying to form `&s.x` in WGSL is illegal. - // To work around this, we will create a local variable to hold the sub part of - // the composite value. - // And then pass the local variable to the function. 
- // After the call, we will write back the local variable to the sub part of the - // composite value. - // - IRBuilder builder(call); - builder.setInsertBefore(call); - struct WritebackPair - { - IRInst* dest; - IRInst* value; - }; - ShortList pendingWritebacks; - - for (UInt i = 0; i < call->getArgCount(); i++) - { - auto arg = call->getArg(i); - auto ptrType = as(arg->getDataType()); - if (!ptrType) - continue; - switch (arg->getOp()) - { - case kIROp_Var: - case kIROp_Param: - case kIROp_GlobalParam: - case kIROp_GlobalVar: - continue; - default: - break; - } - - // Create a local variable to hold the input argument. - auto var = builder.emitVar(ptrType->getValueType(), AddressSpace::Function); - - // Store the input argument into the local variable. - builder.emitStore(var, builder.emitLoad(arg)); - builder.replaceOperand(call->getArgs() + i, var); - pendingWritebacks.add({arg, var}); - } - - // Perform writebacks after the call. - builder.setInsertAfter(call); - for (auto& pair : pendingWritebacks) - { - builder.emitStore(pair.dest, builder.emitLoad(pair.value)); - } - } - - void legalizeFunc(IRFunc* func) - { - // Insert casts to convert integer return types - auto funcReturnType = func->getResultType(); - if (isIntegralType(funcReturnType)) - { - for (auto block : func->getBlocks()) - { - if (auto returnInst = as(block->getTerminator())) - { - auto returnedValue = returnInst->getOperand(0); - auto returnedValueType = returnedValue->getDataType(); - if (isIntegralType(returnedValueType)) - { - IRBuilder builder(returnInst); - builder.setInsertBefore(returnInst); - auto newOp = builder.emitCast(funcReturnType, returnedValue); - builder.replaceOperand(returnInst->getOperands(), newOp); - } - } - } - } - } - - void legalizeSwitch(IRSwitch* switchInst) - { - // WGSL Requires all switch statements to contain a default case. - // If the switch statement does not contain a default case, we will add one. 
- if (switchInst->getDefaultLabel() != switchInst->getBreakLabel()) - return; - IRBuilder builder(switchInst); - auto defaultBlock = builder.createBlock(); - builder.setInsertInto(defaultBlock); - builder.emitBranch(switchInst->getBreakLabel()); - defaultBlock->insertBefore(switchInst->getBreakLabel()); - List cases; - for (UInt i = 0; i < switchInst->getCaseCount(); i++) - { - cases.add(switchInst->getCaseValue(i)); - cases.add(switchInst->getCaseLabel(i)); - } - builder.setInsertBefore(switchInst); - auto newSwitch = builder.emitSwitch( - switchInst->getCondition(), - switchInst->getBreakLabel(), - defaultBlock, - (UInt)cases.getCount(), - cases.getBuffer()); - switchInst->transferDecorationsTo(newSwitch); - switchInst->removeAndDeallocate(); - } - - void processInst(IRInst* inst) - { - switch (inst->getOp()) - { - case kIROp_Call: - legalizeCall(static_cast(inst)); - break; - - case kIROp_Switch: - legalizeSwitch(as(inst)); - break; - - // For all binary operators, make sure both side of the operator have the same type - // (vector-ness and matrix-ness). - case kIROp_Add: - case kIROp_Sub: - case kIROp_Mul: - case kIROp_Div: - case kIROp_FRem: - case kIROp_IRem: - case kIROp_And: - case kIROp_Or: - case kIROp_BitAnd: - case kIROp_BitOr: - case kIROp_BitXor: - case kIROp_Lsh: - case kIROp_Rsh: - case kIROp_Eql: - case kIROp_Neq: - case kIROp_Greater: - case kIROp_Less: - case kIROp_Geq: - case kIROp_Leq: - legalizeBinaryOp(inst); - break; +static void legalizeSwitch(IRSwitch* switchInst) +{ + // WGSL Requires all switch statements to contain a default case. + // If the switch statement does not contain a default case, we will add one. 
+ if (switchInst->getDefaultLabel() != switchInst->getBreakLabel()) + return; + IRBuilder builder(switchInst); + auto defaultBlock = builder.createBlock(); + builder.setInsertInto(defaultBlock); + builder.emitBranch(switchInst->getBreakLabel()); + defaultBlock->insertBefore(switchInst->getBreakLabel()); + List cases; + for (UInt i = 0; i < switchInst->getCaseCount(); i++) + { + cases.add(switchInst->getCaseValue(i)); + cases.add(switchInst->getCaseLabel(i)); + } + builder.setInsertBefore(switchInst); + auto newSwitch = builder.emitSwitch( + switchInst->getCondition(), + switchInst->getBreakLabel(), + defaultBlock, + (UInt)cases.getCount(), + cases.getBuffer()); + switchInst->transferDecorationsTo(newSwitch); + switchInst->removeAndDeallocate(); +} - case kIROp_Func: - legalizeFunc(static_cast(inst)); - [[fallthrough]]; - default: - for (auto child : inst->getModifiableChildren()) - { - processInst(child); - } +static void processInst(IRInst* inst) +{ + switch (inst->getOp()) + { + case kIROp_Call: + legalizeCall(static_cast(inst)); + break; + + case kIROp_Switch: + legalizeSwitch(as(inst)); + break; + + // For all binary operators, make sure both side of the operator have the same type + // (vector-ness and matrix-ness). 
+ case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + case kIROp_Div: + case kIROp_FRem: + case kIROp_IRem: + case kIROp_And: + case kIROp_Or: + case kIROp_BitAnd: + case kIROp_BitOr: + case kIROp_BitXor: + case kIROp_Lsh: + case kIROp_Rsh: + case kIROp_Eql: + case kIROp_Neq: + case kIROp_Greater: + case kIROp_Less: + case kIROp_Geq: + case kIROp_Leq: + legalizeBinaryOp(inst); + break; + + case kIROp_Func: + legalizeFunc(static_cast(inst)); + [[fallthrough]]; + default: + for (auto child : inst->getModifiableChildren()) + { + processInst(child); } } -}; +} struct GlobalInstInliningContext : public GlobalInstInliningContextGeneric { @@ -1583,13 +215,10 @@ void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink) entryPoints.add(info); } - LegalizeWGSLEntryPointContext context(sink, module); - for (auto entryPoint : entryPoints) - context.legalizeEntryPointForWGSL(entryPoint); - context.removeSemanticLayoutsFromLegalizedStructs(); + legalizeEntryPointVaryingParamsForWGSL(module, sink, entryPoints); // Go through every instruction in the module and legalize them as needed. - context.processInst(module->getModuleInst()); + processInst(module->getModuleInst()); // Some global insts are illegal, e.g. function calls. // We need to inline and remove those. From 413355337a8ad682a84ca8da1866927b1068b263 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Fri, 10 Jan 2025 23:24:12 -0500 Subject: [PATCH 02/18] refactor system val work item --- .../slang-ir-legalize-varying-params.cpp | 28 ++++++------------- 1 file changed, 9 insertions(+), 19 deletions(-) diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index 65d0ec28cc..98b6d2b8a6 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -1548,6 +1548,7 @@ void depointerizeInputParams(IRFunc* entryPointFunc) } +// Commony entry point legalization context for Metal and WGSL. 
struct LegalizeShaderEntryPointContext { enum class LegalizeTarget @@ -2817,6 +2818,7 @@ struct LegalizeShaderEntryPointContext } else { + SLANG_ASSERT(isTargetWGSL()); IRInst* operands[] = { builder.getStringValue(wgslContext.userSemanticName), builder.getIntValue(builder.getIntType(), 0)}; @@ -2884,7 +2886,7 @@ struct LegalizeShaderEntryPointContext { IRInst* var; - // Only valid for WGSL. + // Only used for WGSL. IRType* varType; String attrName; @@ -2925,13 +2927,7 @@ struct LegalizeShaderEntryPointContext List systemValWorkItems; for (auto param : entryPoint.entryPointFunc->getParams()) { - std::optional maybeWorkItem; - - if (isTargetMetal()) - { - maybeWorkItem = tryToMakeSystemValWorkItem(param, nullptr); - } - else + if (isTargetWGSL()) { if (auto structType = as(param->getDataType())) { @@ -2948,9 +2944,8 @@ struct LegalizeShaderEntryPointContext } continue; } - maybeWorkItem = tryToMakeSystemValWorkItem(param, param->getFullType()); } - + auto maybeWorkItem = tryToMakeSystemValWorkItem(param, param->getFullType()); if (maybeWorkItem.has_value()) systemValWorkItems.add(std::move(maybeWorkItem.value())); } @@ -2962,14 +2957,7 @@ struct LegalizeShaderEntryPointContext IRBuilder builder(entryPoint.entryPointFunc); auto var = workItem.var; - auto varType = workItem.varType; - // XXX: can remove this by also passing this to Metal SV info? - if (isTargetMetal()) - { - varType = var->getFullType(); - } - auto semanticName = workItem.attrName; auto indexAsString = String(workItem.attrIndex); @@ -3045,7 +3033,7 @@ struct LegalizeShaderEntryPointContext layoutDecoration->insertAtStart(groupThreadId); SystemValLegalizationWorkItem newWorkItem = { groupThreadId, - nullptr, + groupThreadId->getFullType(), metalContext.groupThreadIDString, semanticIndex}; legalizeSystemValue(entryPoint, newWorkItem); @@ -3153,7 +3141,9 @@ struct LegalizeShaderEntryPointContext // If the entrypoint is receiving varying inputs as a pointer, turn it into a value. 
depointerizeInputParams(entryPoint.entryPointFunc); - // XXX: Enable these for WGSL + // TODO FIXME: Enable these for WGSL + // WGSL entry point legalization currently only applies attributes to struct parameters, + // apply the same hoisting from Metal to WGSL. if (isTargetMetal()) { hoistEntryPointParameterFromStruct(entryPoint); From 66492bbdedbaf14d2092973bd4645e12dc4906e7 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Fri, 10 Jan 2025 23:45:49 -0500 Subject: [PATCH 03/18] refactor simplify user names --- .../slang-ir-legalize-varying-params.cpp | 47 ++++++------------- 1 file changed, 14 insertions(+), 33 deletions(-) diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index 98b6d2b8a6..48531617ee 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -2606,37 +2606,18 @@ struct LegalizeShaderEntryPointContext UnownedStringSlice outIndex; bool hasStringIndex = splitNameAndIndex(semanticName, outName, outIndex); - if (isTargetMetal()) - { - if (hasStringIndex) - { - auto loweredName = String(outName).toLower(); - auto loweredNameSlice = loweredName.getUnownedSlice(); - auto newDecoration = builder.addSemanticDecoration( - key, - loweredNameSlice, - stringToInt(outIndex)); - semanticDecoration->replaceUsesWith(newDecoration); - semanticDecoration->removeAndDeallocate(); - semanticDecoration = newDecoration; - } - } - else - { - // user semantics gets all same semantic-name. - auto loweredName = String(outName).toLower(); - auto loweredNameSlice = isUserSemantic ? wgslContext.userSemanticName - : loweredName.getUnownedSlice(); - auto newDecoration = builder.addSemanticDecoration( - key, - loweredNameSlice, - // hasStringIndex ? stringToInt(outIndex) : 0); - hasStringIndex ? 
stringToInt(outIndex) - : semanticDecoration->getSemanticIndex()); - semanticDecoration->replaceUsesWith(newDecoration); - semanticDecoration->removeAndDeallocate(); - semanticDecoration = newDecoration; - } + auto loweredName = String(outName).toLower(); + auto loweredNameSlice = isTargetMetal() || !isUserSemantic + ? loweredName.getUnownedSlice() + : wgslContext.userSemanticName; + auto semanticIndex = + hasStringIndex ? stringToInt(outIndex) : semanticDecoration->getSemanticIndex(); + auto newDecoration = + builder.addSemanticDecoration(key, loweredNameSlice, semanticIndex); + + semanticDecoration->replaceUsesWith(newDecoration); + semanticDecoration->removeAndDeallocate(); + semanticDecoration = newDecoration; auto& semanticUse = usedSemanticIndexSemanticDecor[semanticDecoration->getSemanticName()]; @@ -3143,8 +3124,8 @@ struct LegalizeShaderEntryPointContext // TODO FIXME: Enable these for WGSL // WGSL entry point legalization currently only applies attributes to struct parameters, - // apply the same hoisting from Metal to WGSL. - if (isTargetMetal()) + // apply the same hoisting from Metal to WGSL to fix it. + // if (isTargetMetal()) { hoistEntryPointParameterFromStruct(entryPoint); packStageInParameters(entryPoint); From f8865ed3bda591ad0e7c1ce9ac25ee6f9f856837 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Fri, 10 Jan 2025 23:59:00 -0500 Subject: [PATCH 04/18] clean up fix semantic field of struct --- .../slang-ir-legalize-varying-params.cpp | 22 +++++-------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index 48531617ee..e9796ed95e 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -1831,23 +1831,13 @@ struct LegalizeShaderEntryPointContext mapOldFieldToNewField, semanticInfoToRemove); - // XXX TODO: Clean this up maybe? 
- if (isTargetWGSL()) - { - // Validate/rearange all semantics which overlap in our flat struct. - fixFieldSemanticsOfFlatStruct(flattenedStruct); - ensureStructHasUserSemantic( - flattenedStruct, - layout); - } + // Validate/rearange all semantics which overlap in our flat struct. + fixFieldSemanticsOfFlatStruct(flattenedStruct); + ensureStructHasUserSemantic( + flattenedStruct, + layout); if (flattenedStruct != structType) { - if (isTargetMetal()) - { - // Validate/rearange all semantics which overlap in our flat struct - fixFieldSemanticsOfFlatStruct(flattenedStruct); - } - // Replace the 'old IRParam type' with a 'new IRParam type' param->setFullType(flattenedStruct); @@ -3125,7 +3115,7 @@ struct LegalizeShaderEntryPointContext // TODO FIXME: Enable these for WGSL // WGSL entry point legalization currently only applies attributes to struct parameters, // apply the same hoisting from Metal to WGSL to fix it. - // if (isTargetMetal()) + if (isTargetMetal()) { hoistEntryPointParameterFromStruct(entryPoint); packStageInParameters(entryPoint); From f5696ea124388b745c2afb3e309e70ebc4bb167e Mon Sep 17 00:00:00 2001 From: fairywreath Date: Sat, 11 Jan 2025 00:06:04 -0500 Subject: [PATCH 05/18] improve code layout --- .../slang-ir-legalize-varying-params.cpp | 44 ++++++++++--------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index e9796ed95e..dcc18989f4 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -1549,14 +1549,36 @@ void depointerizeInputParams(IRFunc* entryPointFunc) // Commony entry point legalization context for Metal and WGSL. 
-struct LegalizeShaderEntryPointContext +class LegalizeShaderEntryPointContext { +public: enum class LegalizeTarget { Metal, WGSL, }; + LegalizeShaderEntryPointContext(IRModule* module, DiagnosticSink* sink, LegalizeTarget target) + : m_module(module), m_sink(sink), m_target(target) + { + } + + void legalizeEntryPoints(List& entryPoints) + { + for (auto entryPoint : entryPoints) + legalizeEntryPoint(entryPoint); + removeSemanticLayoutsFromLegalizedStructs(); + } + +private: + IRModule* m_module; + DiagnosticSink* m_sink; + LegalizeTarget m_target; + HashSet semanticInfoToRemove; + + bool isTargetMetal() const { return m_target == LegalizeTarget::Metal; } + bool isTargetWGSL() const { return m_target == LegalizeTarget::WGSL; } + struct SystemValueInfo { String systemValueName; @@ -1591,19 +1613,6 @@ struct LegalizeShaderEntryPointContext } } - IRModule* m_module; - DiagnosticSink* m_sink; - LegalizeTarget m_target; - HashSet semanticInfoToRemove; - - bool isTargetMetal() const { return m_target == LegalizeTarget::Metal; } - bool isTargetWGSL() const { return m_target == LegalizeTarget::WGSL; } - - LegalizeShaderEntryPointContext(IRModule* module, DiagnosticSink* sink, LegalizeTarget target) - : m_module(module), m_sink(sink), m_target(target) - { - } - void removeSemanticLayoutsFromLegalizedStructs() { // Metal and WGSL does not allow duplicate attributes to appear in the same shader. 
@@ -3147,13 +3156,6 @@ struct LegalizeShaderEntryPointContext } } - void legalizeEntryPoints(List& entryPoints) - { - for (auto entryPoint : entryPoints) - legalizeEntryPoint(entryPoint); - removeSemanticLayoutsFromLegalizedStructs(); - } - // ****************************************************************** // Metal specific Legalization Logic // ****************************************************************** From a20deadf259a188085bd28b2df3d1a4a49be26d9 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Sat, 11 Jan 2025 17:28:04 -0500 Subject: [PATCH 06/18] split wgsl/metal to seperate classes and cleanup --- .../slang-ir-legalize-varying-params.cpp | 875 ++++++++++-------- 1 file changed, 477 insertions(+), 398 deletions(-) diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index dcc18989f4..9ddfd8bc74 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -1,11 +1,12 @@ // slang-ir-legalize-varying-params.cpp #include "slang-ir-legalize-varying-params.h" -#include "core/slang-common.h" #include "slang-ir-clone.h" #include "slang-ir-insts.h" #include "slang-ir-util.h" +#include "slang-ir.h" #include "slang-parameter-binding.h" +#include "slang.h" #include @@ -1548,18 +1549,11 @@ void depointerizeInputParams(IRFunc* entryPointFunc) } -// Commony entry point legalization context for Metal and WGSL. 
class LegalizeShaderEntryPointContext { public: - enum class LegalizeTarget - { - Metal, - WGSL, - }; - - LegalizeShaderEntryPointContext(IRModule* module, DiagnosticSink* sink, LegalizeTarget target) - : m_module(module), m_sink(sink), m_target(target) + LegalizeShaderEntryPointContext(IRModule* module, DiagnosticSink* sink, bool hoistParameters) + : m_module(module), m_sink(sink), hoistParameters(hoistParameters) { } @@ -1570,49 +1564,189 @@ class LegalizeShaderEntryPointContext removeSemanticLayoutsFromLegalizedStructs(); } -private: +protected: IRModule* m_module; DiagnosticSink* m_sink; - LegalizeTarget m_target; - HashSet semanticInfoToRemove; - - bool isTargetMetal() const { return m_target == LegalizeTarget::Metal; } - bool isTargetWGSL() const { return m_target == LegalizeTarget::WGSL; } struct SystemValueInfo { String systemValueName; SystemValueSemanticName systemValueNameEnum; - ShortList permittedTypes; - bool isUnsupported = false; - // Only used by Metal. + bool isUnsupported = false; bool isSpecial = false; + }; - SystemValueInfo() - { - // most commonly need 2 - permittedTypes.reserveOverflowBuffer(2); - } + struct SystemValLegalizationWorkItem + { + IRInst* var; + IRType* varType; + + String attrName; + UInt attrIndex; }; - SystemValueInfo getSystemValueInfo( + virtual SystemValueInfo getSystemValueInfo( String inSemanticName, String* optionalSemanticIndex, - IRInst* parentVar) + IRInst* parentVar) const = 0; + + virtual List collectSystemValFromEntryPoint( + EntryPointInfo entryPoint) const = 0; + + virtual void flattenNestedStructsTransferKeyDecorations(IRInst* newKey, IRInst* oldKey) + const = 0; + + virtual UnownedStringSlice getUserSemanticName(String& loweredName, bool isUserSemantic) + const = 0; + + virtual void addFragmentShaderReturnValueDecoration( + IRBuilder& builder, + IRInst* returnValueStructKey) const = 0; + + + virtual IRVarLayout* handleGeometryStageParameterVarLayout( + IRBuilder& builder, + IRVarLayout* paramVarLayout) const { 
- if (isTargetMetal()) + SLANG_UNUSED(builder); + return paramVarLayout; + } + + virtual void handleSpecialSystemValue( + const EntryPointInfo& entryPoint, + SystemValLegalizationWorkItem& workItem, + const SystemValueInfo& info, + IRBuilder& builder) + { + SLANG_UNUSED(entryPoint); + SLANG_UNUSED(workItem); + SLANG_UNUSED(info); + SLANG_UNUSED(builder); + } + + virtual void legalizeAmplificationStageEntryPoint(const EntryPointInfo& entryPoint) const + { + SLANG_UNUSED(entryPoint); + } + + virtual void legalizeMeshStageEntryPoint(const EntryPointInfo& entryPoint) const + { + SLANG_UNUSED(entryPoint); + } + + + std::optional tryToMakeSystemValWorkItem( + IRInst* var, + IRType* varType) const + { + if (auto semanticDecoration = var->findDecoration()) { - return getMetalSystemValueInfo(inSemanticName, optionalSemanticIndex, parentVar); + if (semanticDecoration->getSemanticName().startsWithCaseInsensitive(toSlice("sv_"))) + { + return { + {var, + varType, + String(semanticDecoration->getSemanticName()).toLower(), + (UInt)semanticDecoration->getSemanticIndex()}}; + } } - else + + auto layoutDecor = var->findDecoration(); + if (!layoutDecor) + return {}; + auto sysValAttr = layoutDecor->findAttr(); + if (!sysValAttr) + return {}; + auto semanticName = String(sysValAttr->getName()); + auto sysAttrIndex = sysValAttr->getIndex(); + + return {{var, varType, semanticName, sysAttrIndex}}; + } + + void legalizeSystemValue(EntryPointInfo entryPoint, SystemValLegalizationWorkItem& workItem) + { + IRBuilder builder(entryPoint.entryPointFunc); + + auto var = workItem.var; + auto varType = workItem.varType; + auto semanticName = workItem.attrName; + + auto indexAsString = String(workItem.attrIndex); + SystemValueInfo info = getSystemValueInfo(semanticName, &indexAsString, var); + if (info.isSpecial) + { + handleSpecialSystemValue(entryPoint, workItem, info, builder); + } + + if (info.isUnsupported) + { + reportUnsupportedSystemAttribute(var, semanticName); + return; + } + if 
(!info.permittedTypes.getCount()) + return; + + builder.addTargetSystemValueDecoration(var, info.systemValueName.getUnownedSlice()); + + bool varTypeIsPermitted = false; + for (auto& permittedType : info.permittedTypes) { - SLANG_ASSERT(isTargetWGSL()); - return getWGSLSystemValueInfo(inSemanticName, optionalSemanticIndex, parentVar); + varTypeIsPermitted = varTypeIsPermitted || permittedType == varType; + } + + if (!varTypeIsPermitted) + { + // Note: we do not currently prefer any conversion + // example: + // * allowed types for semantic: `float4`, `uint4`, `int4` + // * user used, `float2` + // * Slang will equally prefer `float4` to `uint4` to `int4`. + // This means the type may lose data if slang selects `uint4` or `int4`. + bool foundAConversion = false; + for (auto permittedType : info.permittedTypes) + { + var->setFullType(permittedType); + builder.setInsertBefore( + entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst()); + + // get uses before we `tryConvertValue` since this creates a new use + List uses; + for (auto use = var->firstUse; use; use = use->nextUse) + uses.add(use); + + auto convertedValue = tryConvertValue(builder, var, varType); + if (convertedValue == nullptr) + continue; + + foundAConversion = true; + copyNameHintAndDebugDecorations(convertedValue, var); + + for (auto use : uses) + builder.replaceOperand(use, convertedValue); + } + if (!foundAConversion) + { + // If we can't convert the value, report an error. + for (auto permittedType : info.permittedTypes) + { + StringBuilder typeNameSB; + getTypeNameHint(typeNameSB, permittedType); + m_sink->diagnose( + var->sourceLoc, + Diagnostics::systemValueTypeIncompatible, + semanticName, + typeNameSB.produceString()); + } + } } } +private: + const bool hoistParameters; + HashSet semanticInfoToRemove; + void removeSemanticLayoutsFromLegalizedStructs() { // Metal and WGSL does not allow duplicate attributes to appear in the same shader. 
@@ -1958,26 +2092,9 @@ class LegalizeShaderEntryPointContext } } - // For Metal geometric stages, we need to translate VaryingInput offsets to - // MetalAttribute offsets. - if (isGeometryStage && isTargetMetal()) + if (isGeometryStage) { - IRVarLayout::Builder elementVarLayoutBuilder( - &builder, - paramVarLayout->getTypeLayout()); - elementVarLayoutBuilder.cloneEverythingButOffsetsFrom(paramVarLayout); - for (auto offsetAttr : paramVarLayout->getOffsetAttrs()) - { - auto resourceKind = offsetAttr->getResourceKind(); - if (resourceKind == LayoutResourceKind::VaryingInput) - { - resourceKind = LayoutResourceKind::MetalAttribute; - } - auto resInfo = elementVarLayoutBuilder.findOrAddResourceInfo(resourceKind); - resInfo->offset = offsetAttr->getOffset(); - resInfo->space = offsetAttr->getSpace(); - } - paramVarLayout = elementVarLayoutBuilder.build(); + paramVarLayout = handleGeometryStageParameterVarLayout(builder, paramVarLayout); } layoutBuilder.addField(key, paramVarLayout); @@ -2313,15 +2430,16 @@ class LegalizeShaderEntryPointContext // step 3a auto newKey = builder.createStructKey(); - if (isTargetMetal()) - { - copyNameHintAndDebugDecorations(newKey, oldKey); - } - else - { - SLANG_ASSERT(isTargetWGSL()); - oldKey->transferDecorationsTo(newKey); - } + flattenNestedStructsTransferKeyDecorations(newKey, oldKey); + // if (isTargetMetal()) + // { + // copyNameHintAndDebugDecorations(newKey, oldKey); + // } + // else + // { + // SLANG_ASSERT(isTargetWGSL()); + // oldKey->transferDecorationsTo(newKey); + // } auto newField = builder.createStructField(dst, newKey, oldField->getFieldType()); copyNameHintAndDebugDecorations(newField, oldField); @@ -2606,9 +2724,12 @@ class LegalizeShaderEntryPointContext bool hasStringIndex = splitNameAndIndex(semanticName, outName, outIndex); auto loweredName = String(outName).toLower(); - auto loweredNameSlice = isTargetMetal() || !isUserSemantic - ? 
loweredName.getUnownedSlice() - : wgslContext.userSemanticName; + // auto loweredNameSlice = isTargetMetal() || !isUserSemantic + // ? loweredName.getUnownedSlice() + // : wgslContext.userSemanticName; + + auto loweredNameSlice = getUserSemanticName(loweredName, isUserSemantic); + auto semanticIndex = hasStringIndex ? stringToInt(outIndex) : semanticDecoration->getSemanticIndex(); auto newDecoration = @@ -2792,22 +2913,23 @@ class LegalizeShaderEntryPointContext case Stage::Compute: case Stage::Fragment: { - if (isTargetMetal()) - { - builder.addTargetSystemValueDecoration(key, toSlice("color(0)")); - } - else - { - SLANG_ASSERT(isTargetWGSL()); - IRInst* operands[] = { - builder.getStringValue(wgslContext.userSemanticName), - builder.getIntValue(builder.getIntType(), 0)}; - builder.addDecoration( - key, - kIROp_SemanticDecoration, - operands, - SLANG_COUNT_OF(operands)); - } + addFragmentShaderReturnValueDecoration(builder, key); + // if (isTargetMetal()) + // { + // builder.addTargetSystemValueDecoration(key, toSlice("color(0)")); + // } + // else + // { + // SLANG_ASSERT(isTargetWGSL()); + // IRInst* operands[] = { + // builder.getStringValue(wgslContext.userSemanticName), + // builder.getIntValue(builder.getIntType(), 0)}; + // builder.addDecoration( + // key, + // kIROp_SemanticDecoration, + // operands, + // SLANG_COUNT_OF(operands)); + // } break; } case Stage::Vertex: @@ -2862,248 +2984,6 @@ class LegalizeShaderEntryPointContext return builder.emitCast(toType, val); } - struct SystemValLegalizationWorkItem - { - IRInst* var; - - // Only used for WGSL. - IRType* varType; - - String attrName; - UInt attrIndex; - }; - - // varType is only valid for WGSL. 
- std::optional tryToMakeSystemValWorkItem( - IRInst* var, - IRType* varType) - { - if (auto semanticDecoration = var->findDecoration()) - { - if (semanticDecoration->getSemanticName().startsWithCaseInsensitive(toSlice("sv_"))) - { - return { - {var, - varType, - String(semanticDecoration->getSemanticName()).toLower(), - (UInt)semanticDecoration->getSemanticIndex()}}; - } - } - - auto layoutDecor = var->findDecoration(); - if (!layoutDecor) - return {}; - auto sysValAttr = layoutDecor->findAttr(); - if (!sysValAttr) - return {}; - auto semanticName = String(sysValAttr->getName()); - auto sysAttrIndex = sysValAttr->getIndex(); - - return {{var, varType, semanticName, sysAttrIndex}}; - } - - List collectSystemValFromEntryPoint(EntryPointInfo entryPoint) - { - List systemValWorkItems; - for (auto param : entryPoint.entryPointFunc->getParams()) - { - if (isTargetWGSL()) - { - if (auto structType = as(param->getDataType())) - { - for (auto field : structType->getFields()) - { - // Nested struct-s are flattened already by flattenInputParameters(). 
- SLANG_ASSERT(!as(field->getFieldType())); - - auto key = field->getKey(); - auto fieldType = field->getFieldType(); - auto maybeWorkItem = tryToMakeSystemValWorkItem(key, fieldType); - if (maybeWorkItem.has_value()) - systemValWorkItems.add(std::move(maybeWorkItem.value())); - } - continue; - } - } - auto maybeWorkItem = tryToMakeSystemValWorkItem(param, param->getFullType()); - if (maybeWorkItem.has_value()) - systemValWorkItems.add(std::move(maybeWorkItem.value())); - } - return systemValWorkItems; - } - - void legalizeSystemValue(EntryPointInfo entryPoint, SystemValLegalizationWorkItem& workItem) - { - IRBuilder builder(entryPoint.entryPointFunc); - - auto var = workItem.var; - auto varType = workItem.varType; - auto semanticName = workItem.attrName; - - auto indexAsString = String(workItem.attrIndex); - SystemValueInfo info = getSystemValueInfo(semanticName, &indexAsString, var); - if (info.isSpecial) - { - SLANG_ASSERT(isTargetMetal()); - if (info.systemValueNameEnum == SystemValueSemanticName::InnerCoverage) - { - // Metal does not support conservative rasterization, so this is always false. 
- auto val = builder.getBoolValue(false); - var->replaceUsesWith(val); - var->removeAndDeallocate(); - } - else if (info.systemValueNameEnum == SystemValueSemanticName::GroupIndex) - { - // Ensure we have a cached "sv_groupthreadid" in our entry point - if (!metalContext.entryPointToGroupThreadId.containsKey(entryPoint.entryPointFunc)) - { - auto systemValWorkItems = collectSystemValFromEntryPoint(entryPoint); - for (auto i : systemValWorkItems) - { - auto indexAsStringGroupThreadId = String(i.attrIndex); - if (getSystemValueInfo(i.attrName, &indexAsStringGroupThreadId, i.var) - .systemValueNameEnum == SystemValueSemanticName::GroupThreadID) - { - metalContext.entryPointToGroupThreadId[entryPoint.entryPointFunc] = - i.var; - } - } - if (!metalContext.entryPointToGroupThreadId.containsKey( - entryPoint.entryPointFunc)) - { - // Add the missing groupthreadid needed to compute sv_groupindex - IRBuilder groupThreadIdBuilder(builder); - groupThreadIdBuilder.setInsertInto( - entryPoint.entryPointFunc->getFirstBlock()); - auto groupThreadId = groupThreadIdBuilder.emitParamAtHead( - getMetalGroupThreadIdType(groupThreadIdBuilder)); - metalContext.entryPointToGroupThreadId[entryPoint.entryPointFunc] = - groupThreadId; - groupThreadIdBuilder.addNameHintDecoration( - groupThreadId, - metalContext.groupThreadIDString); - - // Since "sv_groupindex" will be translated out to a global var and no - // longer be considered a system value we can reuse its layout and semantic - // info - Index foundRequiredDecorations = 0; - IRLayoutDecoration* layoutDecoration = nullptr; - UInt semanticIndex = 0; - for (auto decoration : var->getDecorations()) - { - if (auto layoutDecorationTmp = as(decoration)) - { - layoutDecoration = layoutDecorationTmp; - foundRequiredDecorations++; - } - else if (auto semanticDecoration = as(decoration)) - { - semanticIndex = semanticDecoration->getSemanticIndex(); - groupThreadIdBuilder.addSemanticDecoration( - groupThreadId, - 
metalContext.groupThreadIDString, - (int)semanticIndex); - foundRequiredDecorations++; - } - if (foundRequiredDecorations >= 2) - break; - } - SLANG_ASSERT(layoutDecoration); - layoutDecoration->removeFromParent(); - layoutDecoration->insertAtStart(groupThreadId); - SystemValLegalizationWorkItem newWorkItem = { - groupThreadId, - groupThreadId->getFullType(), - metalContext.groupThreadIDString, - semanticIndex}; - legalizeSystemValue(entryPoint, newWorkItem); - } - } - - IRBuilder svBuilder(builder.getModule()); - svBuilder.setInsertBefore(entryPoint.entryPointFunc->getFirstOrdinaryInst()); - auto computeExtent = emitCalcGroupExtents( - svBuilder, - entryPoint.entryPointFunc, - builder.getVectorType( - builder.getUIntType(), - builder.getIntValue(builder.getIntType(), 3))); - auto groupIndexCalc = emitCalcGroupIndex( - svBuilder, - metalContext.entryPointToGroupThreadId[entryPoint.entryPointFunc], - computeExtent); - svBuilder.addNameHintDecoration( - groupIndexCalc, - UnownedStringSlice("sv_groupindex")); - - var->replaceUsesWith(groupIndexCalc); - var->removeAndDeallocate(); - } - } - - if (info.isUnsupported) - { - reportUnsupportedSystemAttribute(var, semanticName); - return; - } - if (!info.permittedTypes.getCount()) - return; - - builder.addTargetSystemValueDecoration(var, info.systemValueName.getUnownedSlice()); - - bool varTypeIsPermitted = false; - for (auto& permittedType : info.permittedTypes) - { - varTypeIsPermitted = varTypeIsPermitted || permittedType == varType; - } - - if (!varTypeIsPermitted) - { - // Note: we do not currently prefer any conversion - // example: - // * allowed types for semantic: `float4`, `uint4`, `int4` - // * user used, `float2` - // * Slang will equally prefer `float4` to `uint4` to `int4`. - // This means the type may lose data if slang selects `uint4` or `int4`. 
- bool foundAConversion = false; - for (auto permittedType : info.permittedTypes) - { - var->setFullType(permittedType); - builder.setInsertBefore( - entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst()); - - // get uses before we `tryConvertValue` since this creates a new use - List uses; - for (auto use = var->firstUse; use; use = use->nextUse) - uses.add(use); - - auto convertedValue = tryConvertValue(builder, var, varType); - if (convertedValue == nullptr) - continue; - - foundAConversion = true; - copyNameHintAndDebugDecorations(convertedValue, var); - - for (auto use : uses) - builder.replaceOperand(use, convertedValue); - } - if (!foundAConversion) - { - // If we can't convert the value, report an error. - for (auto permittedType : info.permittedTypes) - { - StringBuilder typeNameSB; - getTypeNameHint(typeNameSB, permittedType); - m_sink->diagnose( - var->sourceLoc, - Diagnostics::systemValueTypeIncompatible, - semanticName, - typeNameSB.produceString()); - } - } - } - } - void legalizeSystemValueParameters(EntryPointInfo entryPoint) { List systemValWorkItems = @@ -3121,10 +3001,10 @@ class LegalizeShaderEntryPointContext // If the entrypoint is receiving varying inputs as a pointer, turn it into a value. depointerizeInputParams(entryPoint.entryPointFunc); - // TODO FIXME: Enable these for WGSL + // TODO FIXME: Enable these for WGSL and remove the `hoistParemeters` member field. // WGSL entry point legalization currently only applies attributes to struct parameters, // apply the same hoisting from Metal to WGSL to fix it. 
- if (isTargetMetal()) + if (hoistParameters) { hoistEntryPointParameterFromStruct(entryPoint); packStageInParameters(entryPoint); @@ -3144,71 +3024,32 @@ class LegalizeShaderEntryPointContext switch (entryPoint.entryPointDecor->getProfile().getStage()) { case Stage::Amplification: - SLANG_ASSERT(isTargetMetal()); - legalizeMetalDispatchMeshPayload(entryPoint); + legalizeAmplificationStageEntryPoint(entryPoint); break; case Stage::Mesh: - SLANG_ASSERT(isTargetMetal()); - legalizeMetalMeshEntryPoint(entryPoint); + legalizeMeshStageEntryPoint(entryPoint); break; default: break; } } +}; - // ****************************************************************** - // Metal specific Legalization Logic - // ****************************************************************** - - struct MetalContext - { - ShortList permittedTypes_sv_target; - Dictionary entryPointToGroupThreadId; - const UnownedStringSlice groupThreadIDString = UnownedStringSlice("sv_groupthreadid"); - } metalContext; - - IRType* getMetalGroupThreadIdType(IRBuilder& builder) - { - SLANG_ASSERT(isTargetMetal()); - - return builder.getVectorType( - builder.getBasicType(BaseType::UInt), - builder.getIntValue(builder.getIntType(), 3)); - } - - // Get all permitted types of "sv_target" for Metal - ShortList& getMetalPermittedTypes_sv_target(IRBuilder& builder) +class LegalizeMetalEntryPointContext : public LegalizeShaderEntryPointContext +{ +public: + LegalizeMetalEntryPointContext(IRModule* module, DiagnosticSink* sink) + : LegalizeShaderEntryPointContext(module, sink, true) { - SLANG_ASSERT(isTargetMetal()); - - metalContext.permittedTypes_sv_target.reserveOverflowBuffer(5 * 4); - if (metalContext.permittedTypes_sv_target.getCount() == 0) - { - for (auto baseType : - {BaseType::Float, - BaseType::Half, - BaseType::Int, - BaseType::UInt, - BaseType::Int16, - BaseType::UInt16}) - { - for (IRIntegerValue i = 1; i <= 4; i++) - { - metalContext.permittedTypes_sv_target.add( - 
builder.getVectorType(builder.getBasicType(baseType), i)); - } - } - } - return metalContext.permittedTypes_sv_target; + generatePermittedTypes_sv_target(); } - SystemValueInfo getMetalSystemValueInfo( +protected: + SystemValueInfo getSystemValueInfo( String inSemanticName, String* optionalSemanticIndex, - IRInst* parentVar) + IRInst* parentVar) const SLANG_OVERRIDE { - SLANG_ASSERT(isTargetMetal()); - IRBuilder builder(m_module); SystemValueInfo result = {}; UnownedStringSlice semanticName; @@ -3305,7 +3146,7 @@ class LegalizeShaderEntryPointContext case SystemValueSemanticName::GroupThreadID: { result.systemValueName = toSlice("thread_position_in_threadgroup"); - result.permittedTypes.add(getMetalGroupThreadIdType(builder)); + result.permittedTypes.add(getGroupThreadIdType(builder)); break; } case SystemValueSemanticName::GSInstanceID: @@ -3393,7 +3234,7 @@ class LegalizeShaderEntryPointContext << "color(" << (semanticIndex.getLength() != 0 ? semanticIndex : toSlice("0")) << ")") .produceString(); - result.permittedTypes = getMetalPermittedTypes_sv_target(builder); + result.permittedTypes = permittedTypes_sv_target; break; } @@ -3419,10 +3260,163 @@ class LegalizeShaderEntryPointContext return result; } - void legalizeMetalDispatchMeshPayload(EntryPointInfo entryPoint) + + List collectSystemValFromEntryPoint( + EntryPointInfo entryPoint) const SLANG_OVERRIDE { - SLANG_ASSERT(isTargetMetal()); + List systemValWorkItems; + for (auto param : entryPoint.entryPointFunc->getParams()) + { + auto maybeWorkItem = tryToMakeSystemValWorkItem(param, param->getFullType()); + if (maybeWorkItem.has_value()) + systemValWorkItems.add(std::move(maybeWorkItem.value())); + } + return systemValWorkItems; + } + + void flattenNestedStructsTransferKeyDecorations(IRInst* newKey, IRInst* oldKey) const + SLANG_OVERRIDE + { + + copyNameHintAndDebugDecorations(newKey, oldKey); + } + + UnownedStringSlice getUserSemanticName(String& loweredName, bool isUserSemantic) const + SLANG_OVERRIDE + { 
+ SLANG_UNUSED(isUserSemantic); + return loweredName.getUnownedSlice(); + }; + + void addFragmentShaderReturnValueDecoration(IRBuilder& builder, IRInst* returnValueStructKey) + const SLANG_OVERRIDE + { + + builder.addTargetSystemValueDecoration(returnValueStructKey, toSlice("color(0)")); + } + + IRVarLayout* handleGeometryStageParameterVarLayout( + IRBuilder& builder, + IRVarLayout* paramVarLayout) const SLANG_OVERRIDE + { + // For Metal geometric stages, we need to translate VaryingInput offsets to + // MetalAttribute offsets. + IRVarLayout::Builder elementVarLayoutBuilder(&builder, paramVarLayout->getTypeLayout()); + elementVarLayoutBuilder.cloneEverythingButOffsetsFrom(paramVarLayout); + for (auto offsetAttr : paramVarLayout->getOffsetAttrs()) + { + auto resourceKind = offsetAttr->getResourceKind(); + if (resourceKind == LayoutResourceKind::VaryingInput) + { + resourceKind = LayoutResourceKind::MetalAttribute; + } + auto resInfo = elementVarLayoutBuilder.findOrAddResourceInfo(resourceKind); + resInfo->offset = offsetAttr->getOffset(); + resInfo->space = offsetAttr->getSpace(); + } + + return elementVarLayoutBuilder.build(); + } + + void handleSpecialSystemValue( + const EntryPointInfo& entryPoint, + SystemValLegalizationWorkItem& workItem, + const SystemValueInfo& info, + IRBuilder& builder) SLANG_OVERRIDE + { + auto var = workItem.var; + + if (info.systemValueNameEnum == SystemValueSemanticName::InnerCoverage) + { + // Metal does not support conservative rasterization, so this is always false. 
+ auto val = builder.getBoolValue(false); + var->replaceUsesWith(val); + var->removeAndDeallocate(); + } + else if (info.systemValueNameEnum == SystemValueSemanticName::GroupIndex) + { + // Ensure we have a cached "sv_groupthreadid" in our entry point + if (!entryPointToGroupThreadId.containsKey(entryPoint.entryPointFunc)) + { + auto systemValWorkItems = collectSystemValFromEntryPoint(entryPoint); + for (auto i : systemValWorkItems) + { + auto indexAsStringGroupThreadId = String(i.attrIndex); + if (getSystemValueInfo(i.attrName, &indexAsStringGroupThreadId, i.var) + .systemValueNameEnum == SystemValueSemanticName::GroupThreadID) + { + entryPointToGroupThreadId[entryPoint.entryPointFunc] = i.var; + } + } + if (!entryPointToGroupThreadId.containsKey(entryPoint.entryPointFunc)) + { + // Add the missing groupthreadid needed to compute sv_groupindex + IRBuilder groupThreadIdBuilder(builder); + groupThreadIdBuilder.setInsertInto(entryPoint.entryPointFunc->getFirstBlock()); + auto groupThreadId = groupThreadIdBuilder.emitParamAtHead( + getGroupThreadIdType(groupThreadIdBuilder)); + entryPointToGroupThreadId[entryPoint.entryPointFunc] = groupThreadId; + groupThreadIdBuilder.addNameHintDecoration(groupThreadId, groupThreadIDString); + + // Since "sv_groupindex" will be translated out to a global var and no + // longer be considered a system value we can reuse its layout and + // semantic info + Index foundRequiredDecorations = 0; + IRLayoutDecoration* layoutDecoration = nullptr; + UInt semanticIndex = 0; + for (auto decoration : var->getDecorations()) + { + if (auto layoutDecorationTmp = as(decoration)) + { + layoutDecoration = layoutDecorationTmp; + foundRequiredDecorations++; + } + else if (auto semanticDecoration = as(decoration)) + { + semanticIndex = semanticDecoration->getSemanticIndex(); + groupThreadIdBuilder.addSemanticDecoration( + groupThreadId, + groupThreadIDString, + (int)semanticIndex); + foundRequiredDecorations++; + } + if (foundRequiredDecorations >= 2) + 
break; + } + SLANG_ASSERT(layoutDecoration); + layoutDecoration->removeFromParent(); + layoutDecoration->insertAtStart(groupThreadId); + SystemValLegalizationWorkItem newWorkItem = { + groupThreadId, + groupThreadId->getFullType(), + groupThreadIDString, + semanticIndex}; + legalizeSystemValue(entryPoint, newWorkItem); + } + } + + IRBuilder svBuilder(builder.getModule()); + svBuilder.setInsertBefore(entryPoint.entryPointFunc->getFirstOrdinaryInst()); + auto computeExtent = emitCalcGroupExtents( + svBuilder, + entryPoint.entryPointFunc, + builder.getVectorType( + builder.getUIntType(), + builder.getIntValue(builder.getIntType(), 3))); + auto groupIndexCalc = emitCalcGroupIndex( + svBuilder, + entryPointToGroupThreadId[entryPoint.entryPointFunc], + computeExtent); + svBuilder.addNameHintDecoration(groupIndexCalc, UnownedStringSlice("sv_groupindex")); + + var->replaceUsesWith(groupIndexCalc); + var->removeAndDeallocate(); + } + } + + void legalizeAmplificationStageEntryPoint(const EntryPointInfo& entryPoint) const SLANG_OVERRIDE + { // Find out DispatchMesh function IRGlobalValueWithCode* dispatchMeshFunc = nullptr; for (const auto globalInst : entryPoint.entryPointFunc->getModule()->getGlobalInsts()) @@ -3487,10 +3481,8 @@ class LegalizeShaderEntryPointContext }); } - void legalizeMetalMeshEntryPoint(EntryPointInfo entryPoint) + void legalizeMeshStageEntryPoint(const EntryPointInfo& entryPoint) const SLANG_OVERRIDE { - SLANG_ASSERT(isTargetMetal()); - auto func = entryPoint.entryPointFunc; IRBuilder builder{func->getModule()}; @@ -3651,22 +3643,57 @@ class LegalizeShaderEntryPointContext } } - // ****************************************************************** - // WGSL specific Legalization Logic - // ****************************************************************** +private: + ShortList permittedTypes_sv_target; + Dictionary entryPointToGroupThreadId; + const UnownedStringSlice groupThreadIDString = UnownedStringSlice("sv_groupthreadid"); + + static IRType* 
getGroupThreadIdType(IRBuilder& builder) + { + return builder.getVectorType( + builder.getBasicType(BaseType::UInt), + builder.getIntValue(builder.getIntType(), 3)); + } + + void generatePermittedTypes_sv_target() + { + IRBuilder builder(m_module); + permittedTypes_sv_target.reserveOverflowBuffer(5 * 4); + if (permittedTypes_sv_target.getCount() == 0) + { + for (auto baseType : + {BaseType::Float, + BaseType::Half, + BaseType::Int, + BaseType::UInt, + BaseType::Int16, + BaseType::UInt16}) + { + for (IRIntegerValue i = 1; i <= 4; i++) + { + permittedTypes_sv_target.add( + builder.getVectorType(builder.getBasicType(baseType), i)); + } + } + } + } +}; + - struct WGSLContext +class LegalizeWGSLEntryPointContext : public LegalizeShaderEntryPointContext +{ +public: + LegalizeWGSLEntryPointContext(IRModule* module, DiagnosticSink* sink) + : LegalizeShaderEntryPointContext(module, sink, false) { - UnownedStringSlice userSemanticName = toSlice("user_semantic"); - } wgslContext; + } - SystemValueInfo getWGSLSystemValueInfo( +protected: + SystemValueInfo getSystemValueInfo( String inSemanticName, String* optionalSemanticIndex, - IRInst* parentVar) + IRInst* parentVar) const SLANG_OVERRIDE { - SLANG_ASSERT(isTargetWGSL()); - IRBuilder builder(m_module); SystemValueInfo result = {}; UnownedStringSlice semanticName; @@ -3850,6 +3877,64 @@ class LegalizeShaderEntryPointContext return result; } + void flattenNestedStructsTransferKeyDecorations(IRInst* newKey, IRInst* oldKey) const + SLANG_OVERRIDE + { + oldKey->transferDecorationsTo(newKey); + } + + UnownedStringSlice getUserSemanticName(String& loweredName, bool isUserSemantic) const + SLANG_OVERRIDE + { + + return isUserSemantic ? 
userSemanticName : loweredName.getUnownedSlice(); + } + + void addFragmentShaderReturnValueDecoration(IRBuilder& builder, IRInst* returnValueStructKey) + const SLANG_OVERRIDE + { + IRInst* operands[] = { + builder.getStringValue(userSemanticName), + builder.getIntValue(builder.getIntType(), 0)}; + builder.addDecoration( + returnValueStructKey, + kIROp_SemanticDecoration, + operands, + SLANG_COUNT_OF(operands)); + }; + + List collectSystemValFromEntryPoint( + EntryPointInfo entryPoint) const SLANG_OVERRIDE + { + List systemValWorkItems; + for (auto param : entryPoint.entryPointFunc->getParams()) + { + if (auto structType = as(param->getDataType())) + { + for (auto field : structType->getFields()) + { + // Nested struct-s are flattened already by flattenInputParameters(). + SLANG_ASSERT(!as(field->getFieldType())); + + auto key = field->getKey(); + auto fieldType = field->getFieldType(); + auto maybeWorkItem = tryToMakeSystemValWorkItem(key, fieldType); + if (maybeWorkItem.has_value()) + systemValWorkItems.add(std::move(maybeWorkItem.value())); + } + continue; + } + + auto maybeWorkItem = tryToMakeSystemValWorkItem(param, param->getFullType()); + if (maybeWorkItem.has_value()) + systemValWorkItems.add(std::move(maybeWorkItem.value())); + } + + return systemValWorkItems; + } + +private: + const UnownedStringSlice userSemanticName = toSlice("user_semantic"); }; void legalizeEntryPointVaryingParamsForMetal( @@ -3857,10 +3942,7 @@ void legalizeEntryPointVaryingParamsForMetal( DiagnosticSink* sink, List& entryPoints) { - LegalizeShaderEntryPointContext context( - module, - sink, - LegalizeShaderEntryPointContext::LegalizeTarget::Metal); + LegalizeMetalEntryPointContext context(module, sink); context.legalizeEntryPoints(entryPoints); } @@ -3869,10 +3951,7 @@ void legalizeEntryPointVaryingParamsForWGSL( DiagnosticSink* sink, List& entryPoints) { - LegalizeShaderEntryPointContext context( - module, - sink, - LegalizeShaderEntryPointContext::LegalizeTarget::WGSL); + 
LegalizeWGSLEntryPointContext context(module, sink); context.legalizeEntryPoints(entryPoints); } From 41fd3b9739ae5667e5b5e96308f4209e417fdcd6 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Sat, 11 Jan 2025 17:28:57 -0500 Subject: [PATCH 07/18] remove extra includes --- source/slang/slang-ir-legalize-varying-params.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index 9ddfd8bc74..ca2eca0c0f 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -4,9 +4,7 @@ #include "slang-ir-clone.h" #include "slang-ir-insts.h" #include "slang-ir-util.h" -#include "slang-ir.h" #include "slang-parameter-binding.h" -#include "slang.h" #include From 2a86fa4ce381754d708f1786fdf827e6769e7cb4 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Sat, 11 Jan 2025 17:31:41 -0500 Subject: [PATCH 08/18] remove dead code comments --- .../slang-ir-legalize-varying-params.cpp | 29 ------------------- 1 file changed, 29 deletions(-) diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index ca2eca0c0f..db1c71115c 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -2429,15 +2429,6 @@ class LegalizeShaderEntryPointContext // step 3a auto newKey = builder.createStructKey(); flattenNestedStructsTransferKeyDecorations(newKey, oldKey); - // if (isTargetMetal()) - // { - // copyNameHintAndDebugDecorations(newKey, oldKey); - // } - // else - // { - // SLANG_ASSERT(isTargetWGSL()); - // oldKey->transferDecorationsTo(newKey); - // } auto newField = builder.createStructField(dst, newKey, oldField->getFieldType()); copyNameHintAndDebugDecorations(newField, oldField); @@ -2722,10 +2713,6 @@ class LegalizeShaderEntryPointContext bool hasStringIndex = splitNameAndIndex(semanticName, outName, outIndex); auto 
loweredName = String(outName).toLower(); - // auto loweredNameSlice = isTargetMetal() || !isUserSemantic - // ? loweredName.getUnownedSlice() - // : wgslContext.userSemanticName; - auto loweredNameSlice = getUserSemanticName(loweredName, isUserSemantic); auto semanticIndex = @@ -2912,22 +2899,6 @@ class LegalizeShaderEntryPointContext case Stage::Fragment: { addFragmentShaderReturnValueDecoration(builder, key); - // if (isTargetMetal()) - // { - // builder.addTargetSystemValueDecoration(key, toSlice("color(0)")); - // } - // else - // { - // SLANG_ASSERT(isTargetWGSL()); - // IRInst* operands[] = { - // builder.getStringValue(wgslContext.userSemanticName), - // builder.getIntValue(builder.getIntType(), 0)}; - // builder.addDecoration( - // key, - // kIROp_SemanticDecoration, - // operands, - // SLANG_COUNT_OF(operands)); - // } break; } case Stage::Vertex: From 30824736ea9960195232bd4df0b2dec0815674fe Mon Sep 17 00:00:00 2001 From: fairywreath Date: Sat, 11 Jan 2025 17:38:12 -0500 Subject: [PATCH 09/18] minor cleanup --- .../slang-ir-legalize-varying-params.cpp | 24 ++++++++----------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index db1c71115c..e267e8343c 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -1550,11 +1550,6 @@ void depointerizeInputParams(IRFunc* entryPointFunc) class LegalizeShaderEntryPointContext { public: - LegalizeShaderEntryPointContext(IRModule* module, DiagnosticSink* sink, bool hoistParameters) - : m_module(module), m_sink(sink), hoistParameters(hoistParameters) - { - } - void legalizeEntryPoints(List& entryPoints) { for (auto entryPoint : entryPoints) @@ -1563,6 +1558,11 @@ class LegalizeShaderEntryPointContext } protected: + LegalizeShaderEntryPointContext(IRModule* module, DiagnosticSink* sink, bool hoistParameters) + : m_module(module), 
m_sink(sink), hoistParameters(hoistParameters) + { + } + IRModule* m_module; DiagnosticSink* m_sink; @@ -1596,7 +1596,7 @@ class LegalizeShaderEntryPointContext virtual void flattenNestedStructsTransferKeyDecorations(IRInst* newKey, IRInst* oldKey) const = 0; - virtual UnownedStringSlice getUserSemanticName(String& loweredName, bool isUserSemantic) + virtual UnownedStringSlice getUserSemanticNameSlice(String& loweredName, bool isUserSemantic) const = 0; virtual void addFragmentShaderReturnValueDecoration( @@ -2713,8 +2713,7 @@ class LegalizeShaderEntryPointContext bool hasStringIndex = splitNameAndIndex(semanticName, outName, outIndex); auto loweredName = String(outName).toLower(); - auto loweredNameSlice = getUserSemanticName(loweredName, isUserSemantic); - + auto loweredNameSlice = getUserSemanticNameSlice(loweredName, isUserSemantic); auto semanticIndex = hasStringIndex ? stringToInt(outIndex) : semanticDecoration->getSemanticIndex(); auto newDecoration = @@ -3247,11 +3246,10 @@ class LegalizeMetalEntryPointContext : public LegalizeShaderEntryPointContext void flattenNestedStructsTransferKeyDecorations(IRInst* newKey, IRInst* oldKey) const SLANG_OVERRIDE { - copyNameHintAndDebugDecorations(newKey, oldKey); } - UnownedStringSlice getUserSemanticName(String& loweredName, bool isUserSemantic) const + UnownedStringSlice getUserSemanticNameSlice(String& loweredName, bool isUserSemantic) const SLANG_OVERRIDE { SLANG_UNUSED(isUserSemantic); @@ -3261,7 +3259,6 @@ class LegalizeMetalEntryPointContext : public LegalizeShaderEntryPointContext void addFragmentShaderReturnValueDecoration(IRBuilder& builder, IRInst* returnValueStructKey) const SLANG_OVERRIDE { - builder.addTargetSystemValueDecoration(returnValueStructKey, toSlice("color(0)")); } @@ -3294,7 +3291,7 @@ class LegalizeMetalEntryPointContext : public LegalizeShaderEntryPointContext const SystemValueInfo& info, IRBuilder& builder) SLANG_OVERRIDE { - auto var = workItem.var; + const auto var = workItem.var; if 
(info.systemValueNameEnum == SystemValueSemanticName::InnerCoverage) { @@ -3852,10 +3849,9 @@ class LegalizeWGSLEntryPointContext : public LegalizeShaderEntryPointContext oldKey->transferDecorationsTo(newKey); } - UnownedStringSlice getUserSemanticName(String& loweredName, bool isUserSemantic) const + UnownedStringSlice getUserSemanticNameSlice(String& loweredName, bool isUserSemantic) const SLANG_OVERRIDE { - return isUserSemantic ? userSemanticName : loweredName.getUnownedSlice(); } From 06db88ef7001bdfe93fb23af35af0d026b255dee Mon Sep 17 00:00:00 2001 From: fairywreath Date: Tue, 14 Jan 2025 19:53:51 -0500 Subject: [PATCH 10/18] squash merge from master and resolve conflict --- cmake/SlangTarget.cmake | 12 +- docs/cuda-target.md | 11 + docs/user-guide/03-convenience-features.md | 8 +- external/slang-rhi | 2 +- source/compiler-core/slang-nvrtc-compiler.cpp | 257 +++++++++++++----- source/slang/hlsl.meta.slang | 16 +- source/slang/slang-ast-modifier.h | 20 +- source/slang/slang-check-impl.h | 2 + source/slang/slang-check-modifier.cpp | 108 ++++++-- source/slang/slang-diagnostic-defs.h | 8 +- source/slang/slang-emit-c-like.cpp | 40 ++- source/slang/slang-emit-c-like.h | 13 +- source/slang/slang-emit-glsl.cpp | 16 +- source/slang/slang-emit-spirv.cpp | 55 ++-- source/slang/slang-ir-autodiff-rev.cpp | 10 +- .../slang-ir-collect-global-uniforms.cpp | 10 + source/slang/slang-ir-insts.h | 11 +- .../slang-ir-legalize-varying-params.cpp | 16 +- source/slang/slang-ir-simplify-cfg.cpp | 12 +- source/slang/slang-ir-specialize.cpp | 45 +++ .../slang-ir-translate-glsl-global-var.cpp | 17 +- source/slang/slang-ir-util.cpp | 13 + source/slang/slang-ir-util.h | 2 + source/slang/slang-lower-to-ir.cpp | 46 +++- source/slang/slang-parser.cpp | 11 +- source/slang/slang-reflection-api.cpp | 20 +- source/slang/slang.cpp | 3 +- tests/autodiff/out-parameters-2.slang | 49 ++++ tests/bugs/simplify-if-else.slang | 26 ++ .../diagnostics/missing-return.slang.expected | 4 +- 
tests/glsl/compute-shader-layout-id.slang | 19 ++ tests/spirv/spec-constant-numthreads.slang | 35 +++ 32 files changed, 755 insertions(+), 162 deletions(-) create mode 100644 tests/autodiff/out-parameters-2.slang create mode 100644 tests/bugs/simplify-if-else.slang create mode 100644 tests/glsl/compute-shader-layout-id.slang create mode 100644 tests/spirv/spec-constant-numthreads.slang diff --git a/cmake/SlangTarget.cmake b/cmake/SlangTarget.cmake index 45e7cf1e1d..eae5cf35e4 100644 --- a/cmake/SlangTarget.cmake +++ b/cmake/SlangTarget.cmake @@ -505,10 +505,14 @@ function(slang_add_target dir type) endif() install( TARGETS ${target} ${export_args} - ARCHIVE DESTINATION ${archive_subdir} ${ARGN} - LIBRARY DESTINATION ${library_subdir} ${ARGN} - RUNTIME DESTINATION ${runtime_subdir} ${ARGN} - PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} ${ARGN} + ARCHIVE DESTINATION ${archive_subdir} + ${ARGN} + LIBRARY DESTINATION ${library_subdir} + ${ARGN} + RUNTIME DESTINATION ${runtime_subdir} + ${ARGN} + PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} + ${ARGN} ) endmacro() diff --git a/docs/cuda-target.md b/docs/cuda-target.md index a80dc59f9c..241f253fbe 100644 --- a/docs/cuda-target.md +++ b/docs/cuda-target.md @@ -301,6 +301,17 @@ There is potential to calculate the lane id using the [numthreads] markup in Sla * Intrinsics which only work in pixel shaders + QuadXXXX intrinsics +OptiX Support +============= + +Slang supports OptiX for raytracing. To compile raytracing programs, NVRTC must have access to the `optix.h` and dependent files that are typically distributed as part of the OptiX SDK. When Slang detects the use of raytracing in source, it will define `SLANG_CUDA_ENABLE_OPTIX` when `slang-cuda-prelude.h` is included. This will in turn try to include `optix.h`. + +Slang tries several mechanisms to locate `optix.h` when NVRTC is initiated. The first mechanism is to look in the include paths that are passed to Slang. 
If `optix.h` can be found in one of these paths, no more searching will be performed. + +If this fails, the default OptiX SDK install locations are searched. On Windows this is `%{PROGRAMDATA}\NVIDIA Corporation\OptiX SDK X.X.X\include`. On Linux this is `${HOME}/NVIDIA-OptiX-SDK-X.X.X-suffix`. + +If OptiX headers cannot be found, compilation will fail. + Limitations =========== diff --git a/docs/user-guide/03-convenience-features.md b/docs/user-guide/03-convenience-features.md index e6b337eed1..29e8fd2aaa 100644 --- a/docs/user-guide/03-convenience-features.md +++ b/docs/user-guide/03-convenience-features.md @@ -149,7 +149,7 @@ int rs = foo.staticMethod(a,b); ### Mutability of member function -For GPU performance considerations, the `this` argument in a member function is immutable by default. If you modify the content in `this` argument, the modification will be discarded after the call and does not affect the input object. If you intend to define a member function that mutates the object, use `[mutating]` attribute on the member function as shown in the following example. +For GPU performance considerations, the `this` argument in a member function is immutable by default. Attempting to modify `this` will result in a compile error. If you intend to define a member function that mutates the object, use `[mutating]` attribute on the member function as shown in the following example. ```hlsl struct Foo @@ -159,14 +159,14 @@ struct Foo [mutating] void setCount(int x) { count = x; } - void setCount2(int x) { count = x; } + // This would fail to compile. + // void setCount2(int x) { count = x; } } void test() { Foo f; - f.setCount(1); // f.count is 1 after the call. - f.setCount2(2); // f.count is still 1 after the call. 
+ f.setCount(1); // Compiles } ``` diff --git a/external/slang-rhi b/external/slang-rhi index 19bc575bc1..d1f2718165 160000 --- a/external/slang-rhi +++ b/external/slang-rhi @@ -1 +1 @@ -Subproject commit 19bc575bc193e92210649d6d84ac202b199b29af +Subproject commit d1f2718165d0d540c8fc1eacf20b9edd2d6faac0 diff --git a/source/compiler-core/slang-nvrtc-compiler.cpp b/source/compiler-core/slang-nvrtc-compiler.cpp index c5ccc8e23a..0042ad7085 100644 --- a/source/compiler-core/slang-nvrtc-compiler.cpp +++ b/source/compiler-core/slang-nvrtc-compiler.cpp @@ -127,11 +127,14 @@ class NVRTCDownstreamCompiler : public DownstreamCompilerBase nvrtcProgram m_program; }; - SlangResult _findIncludePath(String& outIncludePath); + SlangResult _findCUDAIncludePath(String& outIncludePath); + SlangResult _getCUDAIncludePath(String& outIncludePath); - SlangResult _getIncludePath(String& outIncludePath); + SlangResult _findOptixIncludePath(String& outIncludePath); + SlangResult _getOptixIncludePath(String& outIncludePath); SlangResult _maybeAddHalfSupport(const CompileOptions& options, CommandLine& ioCmdLine); + SlangResult _maybeAddOptixSupport(const CompileOptions& options, CommandLine& ioCmdLine); #define SLANG_NVTRC_MEMBER_FUNCS(ret, name, params) ret(*m_##name) params; @@ -140,9 +143,16 @@ class NVRTCDownstreamCompiler : public DownstreamCompilerBase // Holds list of paths passed in where cuda_fp16.h is found. Does *NOT* include cuda_fp16.h. List m_cudaFp16FoundPaths; - bool m_includeSearched = false; + bool m_cudaIncludeSearched = false; // Holds location of where include (for cuda_fp16.h) is found. - String m_includePath; + String m_cudaIncludePath; + + // Holds list of paths passed in where optix.h is found. Does *NOT* include optix.h. + List m_optixFoundPaths; + + bool m_optixIncludeSearched = false; + // Holds location of where include (for optix.h) is found. 
+ String m_optixIncludePath; ComPtr m_sharedLibrary; }; @@ -602,21 +612,8 @@ static SlangResult _findNVRTC(NVRTCPathVisitor& visitor) } static const UnownedStringSlice g_fp16HeaderName = UnownedStringSlice::fromLiteral("cuda_fp16.h"); +static const UnownedStringSlice g_optixHeaderName = UnownedStringSlice::fromLiteral("optix.h"); -SlangResult NVRTCDownstreamCompiler::_getIncludePath(String& outPath) -{ - if (!m_includeSearched) - { - m_includeSearched = true; - - SLANG_ASSERT(m_includePath.getLength() == 0); - - _findIncludePath(m_includePath); - } - - outPath = m_includePath; - return m_includePath.getLength() ? SLANG_OK : SLANG_E_NOT_FOUND; -} SlangResult _findFileInIncludePath( const String& path, @@ -650,7 +647,7 @@ SlangResult _findFileInIncludePath( return SLANG_E_NOT_FOUND; } -SlangResult NVRTCDownstreamCompiler::_findIncludePath(String& outPath) +SlangResult NVRTCDownstreamCompiler::_findCUDAIncludePath(String& outPath) { outPath = String(); @@ -711,6 +708,130 @@ SlangResult NVRTCDownstreamCompiler::_findIncludePath(String& outPath) return SLANG_E_NOT_FOUND; } +SlangResult NVRTCDownstreamCompiler::_getCUDAIncludePath(String& outPath) +{ + if (!m_cudaIncludeSearched) + { + m_cudaIncludeSearched = true; + + SLANG_ASSERT(m_cudaIncludePath.getLength() == 0); + + _findCUDAIncludePath(m_cudaIncludePath); + } + + outPath = m_cudaIncludePath; + return m_cudaIncludePath.getLength() ? 
SLANG_OK : SLANG_E_NOT_FOUND; +} + +SlangResult NVRTCDownstreamCompiler::_findOptixIncludePath(String& outPath) +{ + outPath = String(); + + List rootPaths; + +#if SLANG_WINDOWS_FAMILY + const char* searchPattern = "OptiX SDK *"; + StringBuilder builder; + if (SLANG_SUCCEEDED(PlatformUtil::getEnvironmentVariable( + UnownedStringSlice::fromLiteral("PROGRAMDATA"), + builder))) + { + rootPaths.add(Path::combine(builder, "NVIDIA Corporation")); + } +#else + const char* searchPattern = "NVIDIA-OptiX-SDK-*"; + StringBuilder builder; + if (SLANG_SUCCEEDED( + PlatformUtil::getEnvironmentVariable(UnownedStringSlice::fromLiteral("HOME"), builder))) + { + rootPaths.add(builder); + } +#endif + + struct OptixHeaders + { + String path; + SemanticVersion version; + }; + + // Visitor to find Optix headers. + struct Visitor : public Path::Visitor + { + const String& rootPath; + List& optixPaths; + Visitor(const String& rootPath, List& optixPaths) + : rootPath(rootPath), optixPaths(optixPaths) + { + } + void accept(Path::Type type, const UnownedStringSlice& path) SLANG_OVERRIDE + { + if (type != Path::Type::Directory) + return; + + OptixHeaders optixPath; +#if SLANG_WINDOWS_FAMILY + // Paths are expected to look like ".\OptiX SDK X.X.X" + auto versionString = path.subString(path.lastIndexOf(' ') + 1, path.getLength()); +#else + // Paths are expected to look like "./NVIDIA-OptiX-SDK-X.X.X-suffix" + auto versionString = path.subString(0, path.lastIndexOf('-')); + versionString = + versionString.subString(path.lastIndexOf('-') + 1, versionString.getLength()); +#endif + if (SLANG_SUCCEEDED(SemanticVersion::parse(versionString, '.', optixPath.version))) + { + optixPath.path = Path::combine(Path::combine(rootPath, path), "include"); + String optixHeader = Path::combine(optixPath.path, g_optixHeaderName); + if (File::exists(optixHeader)) + { + optixPaths.add(optixPath); + } + } + } + }; + + List optixPaths; + + for (const String& rootPath : rootPaths) + { + Visitor visitor(rootPath, 
optixPaths); + Path::find(rootPath, searchPattern, &visitor); + } + + // Find newest version + const OptixHeaders* newest = nullptr; + for (Index i = 0; i < optixPaths.getCount(); ++i) + { + if (!newest || optixPaths[i].version > newest->version) + { + newest = &optixPaths[i]; + } + } + + if (newest) + { + outPath = newest->path; + return SLANG_OK; + } + + return SLANG_E_NOT_FOUND; +} + +SlangResult NVRTCDownstreamCompiler::_getOptixIncludePath(String& outPath) +{ + if (!m_optixIncludeSearched) + { + m_optixIncludeSearched = true; + + SLANG_ASSERT(m_optixIncludePath.getLength() == 0); + + _findOptixIncludePath(m_optixIncludePath); + } + + outPath = m_optixIncludePath; + return m_optixIncludePath.getLength() ? SLANG_OK : SLANG_E_NOT_FOUND; +} + SlangResult NVRTCDownstreamCompiler::_maybeAddHalfSupport( const DownstreamCompileOptions& options, CommandLine& ioCmdLine) @@ -747,7 +868,7 @@ SlangResult NVRTCDownstreamCompiler::_maybeAddHalfSupport( } String includePath; - SLANG_RETURN_ON_FAIL(_getIncludePath(includePath)); + SLANG_RETURN_ON_FAIL(_getCUDAIncludePath(includePath)); // Add the found include path ioCmdLine.addArg("-I"); @@ -758,6 +879,48 @@ SlangResult NVRTCDownstreamCompiler::_maybeAddHalfSupport( return SLANG_OK; } +SlangResult NVRTCDownstreamCompiler::_maybeAddOptixSupport( + const DownstreamCompileOptions& options, + CommandLine& ioCmdLine) +{ + // First check if we know if one of the include paths contains optix.h + for (const auto& includePath : options.includePaths) + { + if (m_optixFoundPaths.indexOf(includePath) >= 0) + { + // Okay we have an include path that we know works. 
+ // Just need to enable OptiX in prelude + ioCmdLine.addArg("-DSLANG_CUDA_ENABLE_OPTIX"); + return SLANG_OK; + } + } + + // Let's see if one of the paths finds optix.h + for (const auto& curIncludePath : options.includePaths) + { + const String includePath = asString(curIncludePath); + const String checkPath = Path::combine(includePath, g_optixHeaderName); + if (File::exists(checkPath)) + { + m_optixFoundPaths.add(includePath); + // Just need to enable OptiX in prelude + ioCmdLine.addArg("-DSLANG_CUDA_ENABLE_OPTIX"); + return SLANG_OK; + } + } + + String includePath; + SLANG_RETURN_ON_FAIL(_getOptixIncludePath(includePath)); + + // Add the found include path + ioCmdLine.addArg("-I"); + ioCmdLine.addArg(includePath); + + ioCmdLine.addArg("-DSLANG_CUDA_ENABLE_OPTIX"); + + return SLANG_OK; +} + SlangResult NVRTCDownstreamCompiler::compile( const DownstreamCompileOptions& inOptions, IArtifact** outArtifact) @@ -780,6 +943,9 @@ SlangResult NVRTCDownstreamCompiler::compile( CommandLine cmdLine; + // --dopt option is only available in CUDA 11.7 and later + bool hasDoptOption = m_desc.version >= SemanticVersion(11, 7); + switch (options.debugInfoType) { case DebugInfoType::None: @@ -789,12 +955,20 @@ SlangResult NVRTCDownstreamCompiler::compile( default: { cmdLine.addArg("--device-debug"); + if (hasDoptOption) + { + cmdLine.addArg("--dopt=on"); + } break; } case DebugInfoType::Maximal: { cmdLine.addArg("--device-debug"); cmdLine.addArg("--generate-line-info"); + if (hasDoptOption) + { + cmdLine.addArg("--dopt=on"); + } break; } } @@ -910,48 +1084,7 @@ SlangResult NVRTCDownstreamCompiler::compile( // if (options.pipelineType == PipelineType::RayTracing) { - // The device-side OptiX API is accessed through a constellation - // of headers provided by the OptiX SDK, so we need to set an - // include path for the compile that makes those visible. 
- // - // TODO: The OptiX SDK installer doesn't set any kind of environment - // variable to indicate where the SDK was installed, so we seemingly - // need to probe paths instead. The form of the path will differ - // betwene Windows and Unix-y platforms, and we will need some kind - // of approach to probe multiple versiosn and use the latest. - // - // HACK: For now I'm using the fixed path for the most recent SDK - // release on Windows. This means that OptiX cross-compilation will - // only "work" on a subset of platforms, but that doesn't matter - // for now since it doesn't really "work" at all. - // - cmdLine.addArg("-I"); - cmdLine.addArg("C:/ProgramData/NVIDIA Corporation/OptiX SDK 7.0.0/include/"); - - // The OptiX headers in turn `#include ` and expect that - // to work. We could try to also add in an include path from the CUDA - // SDK (which seems to provide a `stddef.h` in the most recent version), - // but using that version doesn't seem to work (and also bakes in a - // requirement that the user have the CUDA SDK installed in addition - // to the OptiX SDK). - // - // Instead, we will rely on the NVRTC feature that lets us set up - // memory buffers to be used as include files by the we compile. - // We will define a dummy `stddef.h` that includes the bare minimum - // lines required to get the OptiX headers to compile without complaint. - // - // TODO: Confirm that the `LP64` definition here is actually needed. - // - headerIncludeNames.add("stddef.h"); - headers.add("#pragma once\n" - "#define LP64\n"); - - // Finally, we want the CUDA prelude to be able to react to whether - // or not OptiX is required (most notably by `#include`ing the appropriate - // header(s)), so we will insert a preprocessor define to indicate - // the requirement. 
- // - cmdLine.addArg("-DSLANG_CUDA_ENABLE_OPTIX"); + SLANG_RETURN_ON_FAIL(_maybeAddOptixSupport(options, cmdLine)); } // Add any compiler specific options diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 7964e26d8d..11c4ab6f45 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -20932,6 +20932,8 @@ struct ConstBufferPointer // new aliased bindings for each distinct cast type. // +//@public: + /// Represent the kind of a descriptor type. enum DescriptorKind { @@ -21048,8 +21050,18 @@ ${{{{ } }}}} -/// Represents a bindless resource handle. A bindless resource handle is always a concrete type and can be +/// Represents a bindless handle to a descriptor. A descriptor handle is always an ordinary data type and can be /// declared in any memory location. +/// @remarks Opaque descriptor types such as textures(`Texture2D` etc.), `SamplerState` and buffers (e.g. `StructuredBuffer`) +/// can have undefined size and data representation on many targets. On platforms such as Vulkan and D3D12, descriptors are +/// communicated to the shader code by calling the host side API to write the descriptor into a descriptor set or table, instead +/// of directly writing bytes into an ordinary GPU accessible buffer. As a result, oapque handle types cannot be used in places +/// that refer to a ordinary buffer location, such as as element types of a `StructuredBuffer`. +/// However, a `DescriptorHandle` stores a handle (or address) to the actual descriptor, and is always an ordinary data type +/// that can be manipulated directly in the shader code. This gives the developer the flexibility to embed and pass around descriptor +/// parameters throughout the code, to enable cleaner modular designs. +/// See [User Guide](https://shader-slang.com/slang/user-guide/convenience-features.html#descriptorhandle-for-bindless-descriptor-access) +/// for more information on how to use `DescriptorHandle` in your code. 
__magic_type(DescriptorHandleType) __intrinsic_type($(kIROp_DescriptorHandleType)) struct DescriptorHandle : IComparable @@ -21140,6 +21152,8 @@ extern T getDescriptorFromHandle(DescriptorHandle handle __intrinsic_op($(kIROp_NonUniformResourceIndex)) DescriptorHandle nonuniform(DescriptorHandle ptr); +//@hidden: + __glsl_version(450) __glsl_extension(GL_ARB_shader_clock) [require(glsl_spirv, GL_ARB_shader_clock)] diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index f5dd86df15..ee29750a6a 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -973,9 +973,14 @@ class GLSLLayoutLocalSizeAttribute : public Attribute // // TODO: These should be accessors that use the // ordinary `args` list, rather than side data. - IntVal* x; - IntVal* y; - IntVal* z; + IntVal* extents[3]; + + bool axisIsSpecConstId[3]; + + // References to specialization constants, for defining the number of + // threads with them. If set, the corresponding axis is set to nullptr + // above. + DeclRef specConstExtents[3]; }; class GLSLLayoutDerivativeGroupQuadAttribute : public Attribute @@ -1038,9 +1043,12 @@ class NumThreadsAttribute : public Attribute // // TODO: These should be accessors that use the // ordinary `args` list, rather than side data. - IntVal* x; - IntVal* y; - IntVal* z; + IntVal* extents[3]; + + // References to specialization constants, for defining the number of + // threads with them. If set, the corresponding axis is set to nullptr + // above. 
+ DeclRef specConstExtents[3]; }; class WaveSizeAttribute : public Attribute diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index b3e30dbc23..3ef1e8f3be 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1656,6 +1656,8 @@ struct SemanticsVisitor : public SemanticsContext void visitModifier(Modifier*); + DeclRef tryGetIntSpecializationConstant(Expr* expr); + AttributeDecl* lookUpAttributeDecl(Name* attributeName, Scope* scope); bool hasIntArgs(Attribute* attr, int numArgs); diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index 3723c98f86..6e451b5cf9 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -114,6 +114,36 @@ void SemanticsVisitor::visitModifier(Modifier*) // Do nothing with modifiers for now } +DeclRef SemanticsVisitor::tryGetIntSpecializationConstant(Expr* expr) +{ + // First type-check the expression as normal + expr = CheckExpr(expr); + + if (IsErrorExpr(expr)) + return DeclRef(); + + if (!isScalarIntegerType(expr->type)) + return DeclRef(); + + auto specConstVar = as(expr); + if (!specConstVar || !specConstVar->declRef) + return DeclRef(); + + auto decl = specConstVar->declRef.getDecl(); + if (!decl) + return DeclRef(); + + for (auto modifier : decl->modifiers) + { + if (as(modifier) || as(modifier)) + { + return specConstVar->declRef.as(); + } + } + + return DeclRef(); +} + static bool _isDeclAllowedAsAttribute(DeclRef declRef) { if (as(declRef.getDecl())) @@ -350,8 +380,6 @@ Modifier* SemanticsVisitor::validateAttribute( { SLANG_ASSERT(attr->args.getCount() == 3); - IntVal* values[3]; - for (int i = 0; i < 3; ++i) { IntVal* value = nullptr; @@ -359,6 +387,14 @@ Modifier* SemanticsVisitor::validateAttribute( auto arg = attr->args[i]; if (arg) { + auto specConstDecl = tryGetIntSpecializationConstant(arg); + if (specConstDecl) + { + numThreadsAttr->extents[i] = nullptr; + 
numThreadsAttr->specConstExtents[i] = specConstDecl; + continue; + } + auto intValue = checkLinkTimeConstantIntVal(arg); if (!intValue) { @@ -390,12 +426,8 @@ Modifier* SemanticsVisitor::validateAttribute( { value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); } - values[i] = value; + numThreadsAttr->extents[i] = value; } - - numThreadsAttr->x = values[0]; - numThreadsAttr->y = values[1]; - numThreadsAttr->z = values[2]; } else if (auto waveSizeAttr = as(attr)) { @@ -1831,15 +1863,24 @@ Modifier* SemanticsVisitor::checkModifier( { SLANG_ASSERT(attr->args.getCount() == 3); - IntVal* values[3]; + // GLSLLayoutLocalSizeAttribute is always attached to an EmptyDecl. + auto decl = as(syntaxNode); + SLANG_ASSERT(decl); for (int i = 0; i < 3; ++i) { - IntVal* value = nullptr; + attr->extents[i] = nullptr; auto arg = attr->args[i]; if (arg) { + auto specConstDecl = tryGetIntSpecializationConstant(arg); + if (specConstDecl) + { + attr->specConstExtents[i] = specConstDecl; + continue; + } + auto intValue = checkConstantIntVal(arg); if (!intValue) { @@ -1847,7 +1888,45 @@ Modifier* SemanticsVisitor::checkModifier( } if (auto cintVal = as(intValue)) { - if (cintVal->getValue() < 1) + if (attr->axisIsSpecConstId[i]) + { + // This integer should actually be a reference to a + // specialization constant with this ID. + Int specConstId = cintVal->getValue(); + + for (auto member : decl->parentDecl->members) + { + auto constantId = member->findModifier(); + if (constantId) + { + SLANG_ASSERT(constantId->args.getCount() == 1); + auto id = checkConstantIntVal(constantId->args[0]); + if (id->getValue() == specConstId) + { + attr->specConstExtents[i] = + DeclRef(member->getDefaultDeclRef()); + break; + } + } + } + + // If not found, we need to create a new specialization + // constant with this ID. 
+ if (!attr->specConstExtents[i]) + { + auto specConstVarDecl = getASTBuilder()->create(); + auto constantIdModifier = + getASTBuilder()->create(); + constantIdModifier->location = (int32_t)specConstId; + specConstVarDecl->type.type = getASTBuilder()->getIntType(); + addModifier(specConstVarDecl, constantIdModifier); + decl->parentDecl->addMember(specConstVarDecl); + attr->specConstExtents[i] = + DeclRef(specConstVarDecl->getDefaultDeclRef()); + } + continue; + } + else if (cintVal->getValue() < 1) { getSink()->diagnose( attr, @@ -1856,18 +1935,13 @@ Modifier* SemanticsVisitor::checkModifier( return nullptr; } } - value = intValue; + attr->extents[i] = intValue; } else { - value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); + attr->extents[i] = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); } - values[i] = value; } - - attr->x = values[0]; - attr->y = values[1]; - attr->z = values[2]; } // Default behavior is to leave things as they are, diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 821a895bc7..d86cd8be2a 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -2060,7 +2060,7 @@ DIAGNOSTIC( DIAGNOSTIC(41000, Warning, unreachableCode, "unreachable code detected") DIAGNOSTIC(41001, Error, recursiveType, "type '$0' contains cyclic reference to itself.") -DIAGNOSTIC(41010, Warning, missingReturn, "control flow may reach end of non-'void' function") +DIAGNOSTIC(41010, Warning, missingReturn, "non-void function does not return in all cases") DIAGNOSTIC( 41011, Error, @@ -2459,6 +2459,12 @@ DIAGNOSTIC( Error, unsupportedTargetIntrinsic, "intrinsic operation '$0' is not supported for the current target.") +DIAGNOSTIC( + 55205, + Error, + unsupportedSpecializationConstantForNumThreads, + "Specialization constants are not supported in the 'numthreads' attribute for the current " + "target.") DIAGNOSTIC( 56001, Error, diff --git a/source/slang/slang-emit-c-like.cpp 
b/source/slang/slang-emit-c-like.cpp index 7b51495e2b..d3a9359ff2 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -295,14 +295,48 @@ void CLikeSourceEmitter::emitSimpleType(IRType* type) } -/* static */ IRNumThreadsDecoration* CLikeSourceEmitter::getComputeThreadGroupSize( +IRNumThreadsDecoration* CLikeSourceEmitter::getComputeThreadGroupSize( IRFunc* func, Int outNumThreads[kThreadGroupAxisCount]) +{ + Int specializationConstantIds[kThreadGroupAxisCount]; + IRNumThreadsDecoration* decor = + getComputeThreadGroupSize(func, outNumThreads, specializationConstantIds); + + for (auto id : specializationConstantIds) + { + if (id >= 0) + { + getSink()->diagnose(decor, Diagnostics::unsupportedSpecializationConstantForNumThreads); + break; + } + } + return decor; +} + +/* static */ IRNumThreadsDecoration* CLikeSourceEmitter::getComputeThreadGroupSize( + IRFunc* func, + Int outNumThreads[kThreadGroupAxisCount], + Int outSpecializationConstantIds[kThreadGroupAxisCount]) { IRNumThreadsDecoration* decor = func->findDecoration(); - for (int i = 0; i < 3; ++i) + for (int i = 0; i < kThreadGroupAxisCount; ++i) { - outNumThreads[i] = decor ? Int(getIntVal(decor->getOperand(i))) : 1; + if (!decor) + { + outNumThreads[i] = 1; + outSpecializationConstantIds[i] = -1; + } + else if (auto specConst = as(decor->getOperand(i))) + { + outNumThreads[i] = 1; + outSpecializationConstantIds[i] = getSpecializationConstantId(specConst); + } + else + { + outNumThreads[i] = Int(getIntVal(decor->getOperand(i))); + outSpecializationConstantIds[i] = -1; + } } return decor; } diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index e5080f731b..1354b7cbd8 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -500,11 +500,20 @@ class CLikeSourceEmitter : public SourceEmitterBase /// different. 
Returns an empty slice if not a built in type static UnownedStringSlice getDefaultBuiltinTypeName(IROp op); - /// Finds the IRNumThreadsDecoration and gets the size from that or sets all dimensions to 1 - static IRNumThreadsDecoration* getComputeThreadGroupSize( + /// Finds the IRNumThreadsDecoration and gets the size from that or sets all + /// dimensions to 1 + IRNumThreadsDecoration* getComputeThreadGroupSize( IRFunc* func, Int outNumThreads[kThreadGroupAxisCount]); + /// Finds the IRNumThreadsDecoration and gets the size from that or sets all + /// dimensions to 1. If specialization constants are used for an axis, their + /// IDs is reported in non-negative entries of outSpecializationConstantIds. + static IRNumThreadsDecoration* getComputeThreadGroupSize( + IRFunc* func, + Int outNumThreads[kThreadGroupAxisCount], + Int outSpecializationConstantIds[kThreadGroupAxisCount]); + /// Finds the IRWaveSizeDecoration and gets the size from that. static IRWaveSizeDecoration* getComputeWaveSize(IRFunc* func, Int* outWaveSize); diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index 23fff37acb..0dab07cfce 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -1335,7 +1335,8 @@ void GLSLSourceEmitter::emitEntryPointAttributesImpl( auto emitLocalSizeLayout = [&]() { Int sizeAlongAxis[kThreadGroupAxisCount]; - getComputeThreadGroupSize(irFunc, sizeAlongAxis); + Int specializationConstantIds[kThreadGroupAxisCount]; + getComputeThreadGroupSize(irFunc, sizeAlongAxis, specializationConstantIds); m_writer->emit("layout("); char const* axes[] = {"x", "y", "z"}; @@ -1345,8 +1346,17 @@ void GLSLSourceEmitter::emitEntryPointAttributesImpl( m_writer->emit(", "); m_writer->emit("local_size_"); m_writer->emit(axes[ii]); - m_writer->emit(" = "); - m_writer->emit(sizeAlongAxis[ii]); + + if (specializationConstantIds[ii] >= 0) + { + m_writer->emit("_id = "); + m_writer->emit(specializationConstantIds[ii]); + } + else + { + 
m_writer->emit(" = "); + m_writer->emit(sizeAlongAxis[ii]); + } } m_writer->emit(") in;\n"); }; diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 068e1563ca..2cf84a8540 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -4353,23 +4353,36 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex // [3.6. Execution Mode]: LocalSize case kIROp_NumThreadsDecoration: { - // TODO: The `LocalSize` execution mode option requires - // literal values for the X,Y,Z thread-group sizes. - // There is a `LocalSizeId` variant that takes ``s - // for those sizes, and we should consider using that - // and requiring the appropriate capabilities - // if any of the operands to the decoration are not - // literals (in a future where we support non-literals - // in those positions in the Slang IR). - // auto numThreads = cast(decoration); - requireSPIRVExecutionMode( - decoration, - dstID, - SpvExecutionModeLocalSize, - SpvLiteralInteger::from32(int32_t(numThreads->getX()->getValue())), - SpvLiteralInteger::from32(int32_t(numThreads->getY()->getValue())), - SpvLiteralInteger::from32(int32_t(numThreads->getZ()->getValue()))); + if (numThreads->getXSpecConst() || numThreads->getYSpecConst() || + numThreads->getZSpecConst()) + { + // If any of the dimensions needs an ID, we need to emit + // all dimensions as an ID due to how LocalSizeId works. + int32_t ids[3]; + for (int i = 0; i < 3; ++i) + ids[i] = ensureInst(numThreads->getOperand(i))->id; + + // LocalSizeId is supported from SPIR-V 1.2 onwards without + // any extra capabilities. 
+ requireSPIRVExecutionMode( + decoration, + dstID, + SpvExecutionModeLocalSizeId, + SpvLiteralInteger::from32(int32_t(ids[0])), + SpvLiteralInteger::from32(int32_t(ids[1])), + SpvLiteralInteger::from32(int32_t(ids[2]))); + } + else + { + requireSPIRVExecutionMode( + decoration, + dstID, + SpvExecutionModeLocalSize, + SpvLiteralInteger::from32(int32_t(numThreads->getX()->getValue())), + SpvLiteralInteger::from32(int32_t(numThreads->getY()->getValue())), + SpvLiteralInteger::from32(int32_t(numThreads->getZ()->getValue()))); + } } break; case kIROp_MaxVertexCountDecoration: @@ -7977,10 +7990,18 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex { if (m_executionModes[entryPoint].add(executionMode)) { + SpvOp execModeOp = SpvOpExecutionMode; + if (executionMode == SpvExecutionModeLocalSizeId || + executionMode == SpvExecutionModeLocalSizeHintId || + executionMode == SpvExecutionModeSubgroupsPerWorkgroupId) + { + execModeOp = SpvOpExecutionModeId; + } + emitInst( getSection(SpvLogicalSectionID::ExecutionModes), parentInst, - SpvOpExecutionMode, + execModeOp, entryPoint, executionMode, ops...); diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 65ce69877f..3237ba3b26 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -528,10 +528,12 @@ InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRF // If primal parameter is mutable, we need to pass in a temp var. auto tempVar = builder.emitVar(primalParamPtrType->getValueType()); - // We also need to setup the initial value of the temp var, otherwise - // the temp var will be uninitialized which could cause undefined behavior - // in the primal function. 
- builder.emitStore(tempVar, primalArg); + // If the parameter is not a pure 'out' param, we also need to setup the initial + // value of the temp var, otherwise the temp var will be uninitialized which could + // cause undefined behavior in the primal function. + // + if (!as(primalParamType)) + builder.emitStore(tempVar, primalArg); primalArgs.add(tempVar); } diff --git a/source/slang/slang-ir-collect-global-uniforms.cpp b/source/slang/slang-ir-collect-global-uniforms.cpp index 1c833a2948..372ef298e7 100644 --- a/source/slang/slang-ir-collect-global-uniforms.cpp +++ b/source/slang/slang-ir-collect-global-uniforms.cpp @@ -279,6 +279,16 @@ struct CollectGlobalUniformParametersContext continue; } + // NumThreadsDecoration may sometimes be the user for a global + // parameter. This occurs when the parameter was supposed to be + // a specialization constant, but isn't due to that not being + // supported for the target. These can be skipped here and + // diagnosed later. + if (as(user)) + { + continue; + } + // For each use site for the global parameter, we will // insert new code right before the instruction that uses // the parameter. 
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index a58c2e900c..f46586aa2b 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -570,6 +570,7 @@ struct IRInstanceDecoration : IRDecoration IRIntLit* getCount() { return cast(getOperand(0)); } }; +struct IRGlobalParam; struct IRNumThreadsDecoration : IRDecoration { enum @@ -578,11 +579,13 @@ struct IRNumThreadsDecoration : IRDecoration }; IR_LEAF_ISA(NumThreadsDecoration) - IRIntLit* getX() { return cast(getOperand(0)); } - IRIntLit* getY() { return cast(getOperand(1)); } - IRIntLit* getZ() { return cast(getOperand(2)); } + IRIntLit* getX() { return as(getOperand(0)); } + IRIntLit* getY() { return as(getOperand(1)); } + IRIntLit* getZ() { return as(getOperand(2)); } - IRIntLit* getExtentAlongAxis(int axis) { return cast(getOperand(axis)); } + IRGlobalParam* getXSpecConst() { return as(getOperand(0)); } + IRGlobalParam* getYSpecConst() { return as(getOperand(1)); } + IRGlobalParam* getZSpecConst() { return as(getOperand(2)); } }; struct IRWaveSizeDecoration : IRDecoration diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index e267e8343c..6840197721 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -190,7 +190,7 @@ IRInst* emitCalcGroupExtents(IRBuilder& builder, IRFunc* entryPoint, IRVectorTyp for (int axis = 0; axis < kAxisCount; axis++) { - auto litValue = as(numThreadsDecor->getExtentAlongAxis(axis)); + auto litValue = as(numThreadsDecor->getOperand(axis)); if (!litValue) return nullptr; @@ -1434,6 +1434,20 @@ struct CPUEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegalize // groupExtents = emitCalcGroupExtents(builder, m_entryPointFunc, uint3Type); + if (!groupExtents) + { + m_sink->diagnose( + m_entryPointFunc, + Diagnostics::unsupportedSpecializationConstantForNumThreads); + + // Fill in placeholder 
values. + static const int kAxisCount = 3; + IRInst* groupExtentAlongAxis[kAxisCount] = {}; + for (int axis = 0; axis < kAxisCount; axis++) + groupExtentAlongAxis[axis] = builder.getIntValue(uint3Type->getElementType(), 1); + groupExtents = builder.emitMakeVector(uint3Type, kAxisCount, groupExtentAlongAxis); + } + dispatchThreadID = emitCalcDispatchThreadID(builder, uint3Type, groupID, groupThreadID, groupExtents); diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp index 90d30dcc77..68d79617a8 100644 --- a/source/slang/slang-ir-simplify-cfg.cpp +++ b/source/slang/slang-ir-simplify-cfg.cpp @@ -490,11 +490,19 @@ static bool trySimplifyIfElse(IRBuilder& builder, IRIfElse* ifElseInst) bool isFalseBranchTrivial = false; if (isTrivialIfElse(ifElseInst, isTrueBranchTrivial, isFalseBranchTrivial)) { - // If both branches of `if-else` are trivial jumps into after block, + // If either branch of `if-else` is a trivial jump into after block, // we can get rid of the entire conditional branch and replace it // with a jump into the after block. 
- if (auto termInst = as(ifElseInst->getTrueBlock()->getTerminator())) + IRUnconditionalBranch* termInst = + as(ifElseInst->getTrueBlock()->getTerminator()); + if (!termInst || (termInst->getTargetBlock() != ifElseInst->getAfterBlock())) { + termInst = as(ifElseInst->getFalseBlock()->getTerminator()); + } + + if (termInst) + { + SLANG_ASSERT(termInst->getTargetBlock() == ifElseInst->getAfterBlock()); List args; for (UInt i = 0; i < termInst->getArgCount(); i++) args.add(termInst->getArg(i)); diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 40cd40758a..a9b0d44121 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -71,6 +71,42 @@ struct SpecializationContext module->getContainerPool().free(&cleanInsts); } + bool isUnsimplifiedArithmeticInst(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + case kIROp_Div: + case kIROp_Neg: + case kIROp_Not: + case kIROp_Eql: + case kIROp_Neq: + case kIROp_Leq: + case kIROp_Geq: + case kIROp_Less: + case kIROp_IRem: + case kIROp_FRem: + case kIROp_Greater: + case kIROp_Lsh: + case kIROp_Rsh: + case kIROp_BitAnd: + case kIROp_BitOr: + case kIROp_BitXor: + case kIROp_BitNot: + case kIROp_BitCast: + case kIROp_CastIntToFloat: + case kIROp_CastFloatToInt: + case kIROp_IntCast: + case kIROp_FloatCast: + case kIROp_Select: + return true; + default: + return false; + } + } + // An instruction is then fully specialized if and only // if it is in our set. // @@ -133,6 +169,14 @@ struct SpecializationContext return areAllOperandsFullySpecialized(inst); } + if (isUnsimplifiedArithmeticInst(inst)) + { + // For arithmetic insts, we want to wait for simplification before specialization, + // since different insts can simplify to the same value. + // + return false; + } + // The default case is that a global value is always specialized. 
if (inst->getParent() == module->getModuleInst()) { @@ -1092,6 +1136,7 @@ struct SpecializationContext { this->changed = true; eliminateDeadCode(module->getModuleInst()); + applySparseConditionalConstantPropagationForGlobalScope(this->module, this->sink); } // Once the work list has gone dry, we should have the invariant diff --git a/source/slang/slang-ir-translate-glsl-global-var.cpp b/source/slang/slang-ir-translate-glsl-global-var.cpp index a44e16a7ce..077cdb98d0 100644 --- a/source/slang/slang-ir-translate-glsl-global-var.cpp +++ b/source/slang/slang-ir-translate-glsl-global-var.cpp @@ -282,10 +282,11 @@ struct GlobalVarTranslationContext if (!numthreadsDecor) return; builder.setInsertBefore(use->getUser()); - IRInst* values[] = { - numthreadsDecor->getExtentAlongAxis(0), - numthreadsDecor->getExtentAlongAxis(1), - numthreadsDecor->getExtentAlongAxis(2)}; + IRInst* values[3] = { + numthreadsDecor->getOperand(0), + numthreadsDecor->getOperand(1), + numthreadsDecor->getOperand(2)}; + auto workgroupSize = builder.emitMakeVector( builder.getVectorType(builder.getIntType(), 3), 3, @@ -328,10 +329,10 @@ struct GlobalVarTranslationContext if (!firstBlock) continue; builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); - IRInst* args[] = { - numthreadsDecor->getExtentAlongAxis(0), - numthreadsDecor->getExtentAlongAxis(1), - numthreadsDecor->getExtentAlongAxis(2)}; + IRInst* args[3] = { + numthreadsDecor->getOperand(0), + numthreadsDecor->getOperand(1), + numthreadsDecor->getOperand(2)}; auto workgroupSize = builder.emitMakeVector(workgroupSizeInst->getFullType(), 3, args); builder.emitStore(globalVar, workgroupSize); diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index c753600a7c..d05e1db7d4 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -1973,4 +1973,17 @@ IRType* getIRVectorBaseType(IRType* type) return as(type)->getElementType(); } +Int getSpecializationConstantId(IRGlobalParam* param) +{ + 
auto layout = findVarLayout(param); + if (!layout) + return 0; + + auto offset = layout->findOffsetAttr(LayoutResourceKind::SpecializationConstant); + if (!offset) + return 0; + + return offset->getOffset(); +} + } // namespace Slang diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index e23aeb6180..666ac71c03 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -373,6 +373,8 @@ inline bool isSPIRV(CodeGenTarget codeGenTarget) int getIRVectorElementSize(IRType* type); IRType* getIRVectorBaseType(IRType* type); +Int getSpecializationConstantId(IRGlobalParam* param); + } // namespace Slang #endif diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index e82fc03fde..0863457198 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -7625,12 +7625,29 @@ struct DeclLoweringVisitor : DeclVisitor { verifyComputeDerivativeGroupModifier = true; getAllEntryPointsNoOverride(entryPoints); + + LoweredValInfo extents[3]; + + for (int i = 0; i < 3; ++i) + { + extents[i] = layoutLocalSizeAttr->specConstExtents[i] + ? 
emitDeclRef( + context, + layoutLocalSizeAttr->specConstExtents[i], + lowerType( + context, + getType( + context->astBuilder, + layoutLocalSizeAttr->specConstExtents[i]))) + : lowerVal(context, layoutLocalSizeAttr->extents[i]); + } + for (auto d : entryPoints) as(getBuilder()->addNumThreadsDecoration( d, - getSimpleVal(context, lowerVal(context, layoutLocalSizeAttr->x)), - getSimpleVal(context, lowerVal(context, layoutLocalSizeAttr->y)), - getSimpleVal(context, lowerVal(context, layoutLocalSizeAttr->z)))); + getSimpleVal(context, extents[0]), + getSimpleVal(context, extents[1]), + getSimpleVal(context, extents[2]))); } else if (as(modifier)) { @@ -10336,11 +10353,28 @@ struct DeclLoweringVisitor : DeclVisitor } else if (auto numThreadsAttr = as(modifier)) { + LoweredValInfo extents[3]; + + for (int i = 0; i < 3; ++i) + { + extents[i] = numThreadsAttr->specConstExtents[i] + ? emitDeclRef( + context, + numThreadsAttr->specConstExtents[i], + lowerType( + context, + getType( + context->astBuilder, + numThreadsAttr->specConstExtents[i]))) + : lowerVal(context, numThreadsAttr->extents[i]); + } + numThreadsDecor = as(getBuilder()->addNumThreadsDecoration( irFunc, - getSimpleVal(context, lowerVal(context, numThreadsAttr->x)), - getSimpleVal(context, lowerVal(context, numThreadsAttr->y)), - getSimpleVal(context, lowerVal(context, numThreadsAttr->z)))); + getSimpleVal(context, extents[0]), + getSimpleVal(context, extents[1]), + getSimpleVal(context, extents[2]))); + numThreadsDecor->sourceLoc = numThreadsAttr->loc; } else if (auto waveSizeAttr = as(modifier)) { diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index c275a868b5..6ae41a2eb9 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -8437,7 +8437,9 @@ static NodeBase* parseLayoutModifier(Parser* parser, void* /*userData*/) int localSizeIndex = -1; if (nameText.startsWith(localSizePrefix) && - nameText.getLength() == SLANG_COUNT_OF(localSizePrefix) - 1 + 1) + 
(nameText.getLength() == SLANG_COUNT_OF(localSizePrefix) - 1 + 1 || + (nameText.endsWith("_id") && + (nameText.getLength() == SLANG_COUNT_OF(localSizePrefix) - 1 + 4)))) { char lastChar = nameText[SLANG_COUNT_OF(localSizePrefix) - 1]; localSizeIndex = (lastChar >= 'x' && lastChar <= 'z') ? (lastChar - 'x') : -1; @@ -8451,6 +8453,8 @@ static NodeBase* parseLayoutModifier(Parser* parser, void* /*userData*/) numThreadsAttrib->args.setCount(3); for (auto& i : numThreadsAttrib->args) i = nullptr; + for (auto& b : numThreadsAttrib->axisIsSpecConstId) + b = false; // Just mark the loc and name from the first in the list numThreadsAttrib->keywordName = getName(parser, "numthreads"); @@ -8467,6 +8471,11 @@ static NodeBase* parseLayoutModifier(Parser* parser, void* /*userData*/) } numThreadsAttrib->args[localSizeIndex] = expr; + + // We can't resolve the specialization constant declaration + // here, because it may not even exist. IDs pointing to unnamed + // specialization constants are allowed in GLSL. 
+ numThreadsAttrib->axisIsSpecConstId[localSizeIndex] = nameText.endsWith("_id"); } } else if (nameText == "derivative_group_quadsNV") diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index d235c82703..d1adfedc0b 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -4033,18 +4033,14 @@ SLANG_API void spReflectionEntryPoint_getComputeThreadGroupSize( auto numThreadsAttribute = entryPointFunc.getDecl()->findModifier(); if (numThreadsAttribute) { - if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->x)) - sizeAlongAxis[0] = (SlangUInt)cint->getValue(); - else if (numThreadsAttribute->x) - sizeAlongAxis[0] = 0; - if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->y)) - sizeAlongAxis[1] = (SlangUInt)cint->getValue(); - else if (numThreadsAttribute->y) - sizeAlongAxis[1] = 0; - if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->z)) - sizeAlongAxis[2] = (SlangUInt)cint->getValue(); - else if (numThreadsAttribute->z) - sizeAlongAxis[2] = 0; + for (int i = 0; i < 3; ++i) + { + if (auto cint = + entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->extents[i])) + sizeAlongAxis[i] = (SlangUInt)cint->getValue(); + else if (numThreadsAttribute->extents[i]) + sizeAlongAxis[i] = 0; + } } // diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 5ec1996581..efc1c6fd11 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -1493,7 +1493,8 @@ DeclRef Linkage::specializeWithArgTypes( DiagnosticSink* sink) { SemanticsVisitor visitor(getSemanticsForReflection()); - visitor = visitor.withSink(sink); + SemanticsVisitor::ExprLocalScope scope; + visitor = visitor.withSink(sink).withExprLocalScope(&scope); SLANG_AST_BUILDER_RAII(getASTBuilder()); diff --git a/tests/autodiff/out-parameters-2.slang b/tests/autodiff/out-parameters-2.slang new file mode 100644 index 0000000000..b4c4b07c61 --- 
/dev/null +++ b/tests/autodiff/out-parameters-2.slang @@ -0,0 +1,49 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +typedef DifferentialPair dpfloat; + +struct Foo : IDifferentiable +{ + float a; + int b; +} + +[PreferCheckpoint] +float k() +{ + return outputBuffer[3] + 1; +} + +[Differentiable] +void h(float x, float y, out Foo result) +{ + float p = no_diff k(); + float m = x + y + p; + float n = x - y; + float r = m * n + 2 * x * y; + + result = {r, 2}; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + float x = 2.0; + float y = 3.5; + float dx = 1.0; + float dy = 0.5; + + dpfloat dresult; + dpfloat dpx = diffPair(x); + dpfloat dpy = diffPair(y); + Foo.Differential dFoo; + dFoo.a = 1.0; + bwd_diff(h)(dpx, dpy, dFoo); + + outputBuffer[0] = dpx.d; // CHECK: 12.0 + outputBuffer[1] = dpy.d; // CHECK: -4.0 +} \ No newline at end of file diff --git a/tests/bugs/simplify-if-else.slang b/tests/bugs/simplify-if-else.slang new file mode 100644 index 0000000000..8719a15995 --- /dev/null +++ b/tests/bugs/simplify-if-else.slang @@ -0,0 +1,26 @@ +//TEST:SIMPLE(filecheck=CHECK): -stage compute -entry computeMain -target hlsl +//CHECK: computeMain + +//TEST_INPUT:ubuffer(data=[9 9 9 9], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + vector vvv = vector(0); + float32_t ret = 0.0f; + if (vvv.y < 1.0f) + { + ret = 1.0f; + } + else + { + if (vvv.y > 1.0f && outputBuffer[3] == 3) + { + ret = 0.0f; + } else { + if (true) {} + } + } + outputBuffer[int(dispatchThreadID.x)] = int(ret); +} diff --git a/tests/diagnostics/missing-return.slang.expected 
b/tests/diagnostics/missing-return.slang.expected index e41e756ff4..7626665241 100644 --- a/tests/diagnostics/missing-return.slang.expected +++ b/tests/diagnostics/missing-return.slang.expected @@ -1,9 +1,9 @@ result code = 0 standard error = { -tests/diagnostics/missing-return.slang(7): warning 41010: control flow may reach end of non-'void' function +tests/diagnostics/missing-return.slang(7): warning 41010: non-void function does not return in all cases int bad(int a, int b) ^~~ -tests/diagnostics/missing-return.slang(14): warning 41010: control flow may reach end of non-'void' function +tests/diagnostics/missing-return.slang(14): warning 41010: non-void function does not return in all cases int alsoBad(int a, int b) ^~~~~~~ } diff --git a/tests/glsl/compute-shader-layout-id.slang b/tests/glsl/compute-shader-layout-id.slang new file mode 100644 index 0000000000..bee8137d82 --- /dev/null +++ b/tests/glsl/compute-shader-layout-id.slang @@ -0,0 +1,19 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv -stage compute -entry main -allow-glsl +#version 450 + +[vk::constant_id(1)] +const int constValue1 = 0; + +[vk::constant_id(2)] +const int constValue3 = 5; + +// CHECK-DAG: OpExecutionModeId %main LocalSizeId %[[C0:[0-9A-Za-z_]+]] %[[C1:[0-9A-Za-z_]+]] %[[C2:[0-9A-Za-z_]+]] +// CHECK-DAG: OpDecorate %[[C0]] SpecId 1 +// CHECK-DAG: OpDecorate %[[C1]] SpecId 0 +// CHECK-DAG: OpDecorate %[[C2]] SpecId 2 + +layout(local_size_x_id = 1, local_size_y_id = 0, local_size_z = constValue3) in; +void main() +{ +} + diff --git a/tests/spirv/spec-constant-numthreads.slang b/tests/spirv/spec-constant-numthreads.slang new file mode 100644 index 0000000000..5c133219cf --- /dev/null +++ b/tests/spirv/spec-constant-numthreads.slang @@ -0,0 +1,35 @@ +//TEST:SIMPLE(filecheck=GLSL): -target glsl -allow-glsl +//TEST:SIMPLE(filecheck=GLSL): -target glsl +//TEST:SIMPLE(filecheck=CHECK): -target spirv -allow-glsl +//TEST:SIMPLE(filecheck=CHECK): -target spirv + +// CHECK-DAG: OpExecutionModeId 
%computeMain1 LocalSizeId %[[C0:[0-9A-Za-z_]+]] %[[C1:[0-9A-Za-z_]+]] %[[C2:[0-9A-Za-z_]+]] +// CHECK-DAG: OpDecorate %[[C0]] SpecId 1 +// CHECK-DAG: OpDecorate %[[C1]] SpecId 0 +// CHECK-DAG: %[[C2]] = OpConstant %int 4 +// CHECK-DAG: OpStore %{{.*}} %[[C0]] +// CHECK-DAG: OpStore %{{.*}} %[[C1]] +// CHECK-DAG: OpStore %{{.*}} %[[C2]] + +// GLSL-DAG: layout(constant_id = 1) +// GLSL-DAG: int constValue0_0 = 0; +// GLSL-DAG: layout(constant_id = 0) +// GLSL-DAG: int constValue1_0 = 0; +// GLSL-DAG: layout(local_size_x_id = 1, local_size_y_id = 0, local_size_z = 4) in; + +[vk::specialization_constant] +const int constValue0 = 0; + +[vk::constant_id(0)] +const int constValue1 = 0; + +RWStructuredBuffer outputBuffer; + +[numthreads(constValue0, constValue1, 4)] +void computeMain1() +{ + int3 size = WorkgroupSize(); + outputBuffer[0] = size.x; + outputBuffer[1] = size.y; + outputBuffer[2] = size.z; +} From c42d707fd25ee0328598650d3235cd2322810ccc Mon Sep 17 00:00:00 2001 From: fairywreath Date: Tue, 14 Jan 2025 20:04:54 -0500 Subject: [PATCH 11/18] apply metal spec const thread count changes --- .../slang-ir-legalize-varying-params.cpp | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index 6840197721..4858d6e31c 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -3378,12 +3378,25 @@ class LegalizeMetalEntryPointContext : public LegalizeShaderEntryPointContext IRBuilder svBuilder(builder.getModule()); svBuilder.setInsertBefore(entryPoint.entryPointFunc->getFirstOrdinaryInst()); - auto computeExtent = emitCalcGroupExtents( - svBuilder, - entryPoint.entryPointFunc, - builder.getVectorType( - builder.getUIntType(), - builder.getIntValue(builder.getIntType(), 3))); + auto uint3Type = builder.getVectorType( + builder.getUIntType(), + 
builder.getIntValue(builder.getIntType(), 3)); + auto computeExtent = + emitCalcGroupExtents(svBuilder, entryPoint.entryPointFunc, uint3Type); + if (!computeExtent) + { + m_sink->diagnose( + entryPoint.entryPointFunc, + Diagnostics::unsupportedSpecializationConstantForNumThreads); + + // Fill in placeholder values. + static const int kAxisCount = 3; + IRInst* groupExtentAlongAxis[kAxisCount] = {}; + for (int axis = 0; axis < kAxisCount; axis++) + groupExtentAlongAxis[axis] = + builder.getIntValue(uint3Type->getElementType(), 1); + computeExtent = builder.emitMakeVector(uint3Type, kAxisCount, groupExtentAlongAxis); + } auto groupIndexCalc = emitCalcGroupIndex( svBuilder, entryPointToGroupThreadId[entryPoint.entryPointFunc], From 5a4d84dc19c925fad66e8a9cc5497bc13a36dcb6 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Tue, 14 Jan 2025 20:08:11 -0500 Subject: [PATCH 12/18] Revert "apply metal spec const thread count changes" This reverts commit c42d707fd25ee0328598650d3235cd2322810ccc. --- .../slang-ir-legalize-varying-params.cpp | 25 +++++-------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index 4858d6e31c..6840197721 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -3378,25 +3378,12 @@ class LegalizeMetalEntryPointContext : public LegalizeShaderEntryPointContext IRBuilder svBuilder(builder.getModule()); svBuilder.setInsertBefore(entryPoint.entryPointFunc->getFirstOrdinaryInst()); - auto uint3Type = builder.getVectorType( - builder.getUIntType(), - builder.getIntValue(builder.getIntType(), 3)); - auto computeExtent = - emitCalcGroupExtents(svBuilder, entryPoint.entryPointFunc, uint3Type); - if (!computeExtent) - { - m_sink->diagnose( - entryPoint.entryPointFunc, - Diagnostics::unsupportedSpecializationConstantForNumThreads); - - // Fill in placeholder values. 
- static const int kAxisCount = 3; - IRInst* groupExtentAlongAxis[kAxisCount] = {}; - for (int axis = 0; axis < kAxisCount; axis++) - groupExtentAlongAxis[axis] = - builder.getIntValue(uint3Type->getElementType(), 1); - computeExtent = builder.emitMakeVector(uint3Type, kAxisCount, groupExtentAlongAxis); - } + auto computeExtent = emitCalcGroupExtents( + svBuilder, + entryPoint.entryPointFunc, + builder.getVectorType( + builder.getUIntType(), + builder.getIntValue(builder.getIntType(), 3))); auto groupIndexCalc = emitCalcGroupIndex( svBuilder, entryPointToGroupThreadId[entryPoint.entryPointFunc], From 6c512e11c30b4f07ea47bfda101879b5d822b27c Mon Sep 17 00:00:00 2001 From: fairywreath Date: Tue, 14 Jan 2025 20:08:23 -0500 Subject: [PATCH 13/18] Revert "squash merge from master and resolve conflict" This reverts commit 06db88ef7001bdfe93fb23af35af0d026b255dee. --- cmake/SlangTarget.cmake | 12 +- docs/cuda-target.md | 11 - docs/user-guide/03-convenience-features.md | 8 +- external/slang-rhi | 2 +- source/compiler-core/slang-nvrtc-compiler.cpp | 257 +++++------------- source/slang/hlsl.meta.slang | 16 +- source/slang/slang-ast-modifier.h | 20 +- source/slang/slang-check-impl.h | 2 - source/slang/slang-check-modifier.cpp | 108 ++------ source/slang/slang-diagnostic-defs.h | 8 +- source/slang/slang-emit-c-like.cpp | 40 +-- source/slang/slang-emit-c-like.h | 13 +- source/slang/slang-emit-glsl.cpp | 16 +- source/slang/slang-emit-spirv.cpp | 55 ++-- source/slang/slang-ir-autodiff-rev.cpp | 10 +- .../slang-ir-collect-global-uniforms.cpp | 10 - source/slang/slang-ir-insts.h | 11 +- .../slang-ir-legalize-varying-params.cpp | 16 +- source/slang/slang-ir-simplify-cfg.cpp | 12 +- source/slang/slang-ir-specialize.cpp | 45 --- .../slang-ir-translate-glsl-global-var.cpp | 17 +- source/slang/slang-ir-util.cpp | 13 - source/slang/slang-ir-util.h | 2 - source/slang/slang-lower-to-ir.cpp | 46 +--- source/slang/slang-parser.cpp | 11 +- source/slang/slang-reflection-api.cpp | 20 +- 
source/slang/slang.cpp | 3 +- tests/autodiff/out-parameters-2.slang | 49 ---- tests/bugs/simplify-if-else.slang | 26 -- .../diagnostics/missing-return.slang.expected | 4 +- tests/glsl/compute-shader-layout-id.slang | 19 -- tests/spirv/spec-constant-numthreads.slang | 35 --- 32 files changed, 162 insertions(+), 755 deletions(-) delete mode 100644 tests/autodiff/out-parameters-2.slang delete mode 100644 tests/bugs/simplify-if-else.slang delete mode 100644 tests/glsl/compute-shader-layout-id.slang delete mode 100644 tests/spirv/spec-constant-numthreads.slang diff --git a/cmake/SlangTarget.cmake b/cmake/SlangTarget.cmake index eae5cf35e4..45e7cf1e1d 100644 --- a/cmake/SlangTarget.cmake +++ b/cmake/SlangTarget.cmake @@ -505,14 +505,10 @@ function(slang_add_target dir type) endif() install( TARGETS ${target} ${export_args} - ARCHIVE DESTINATION ${archive_subdir} - ${ARGN} - LIBRARY DESTINATION ${library_subdir} - ${ARGN} - RUNTIME DESTINATION ${runtime_subdir} - ${ARGN} - PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} - ${ARGN} + ARCHIVE DESTINATION ${archive_subdir} ${ARGN} + LIBRARY DESTINATION ${library_subdir} ${ARGN} + RUNTIME DESTINATION ${runtime_subdir} ${ARGN} + PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} ${ARGN} ) endmacro() diff --git a/docs/cuda-target.md b/docs/cuda-target.md index 241f253fbe..a80dc59f9c 100644 --- a/docs/cuda-target.md +++ b/docs/cuda-target.md @@ -301,17 +301,6 @@ There is potential to calculate the lane id using the [numthreads] markup in Sla * Intrinsics which only work in pixel shaders + QuadXXXX intrinsics -OptiX Support -============= - -Slang supports OptiX for raytracing. To compile raytracing programs, NVRTC must have access to the `optix.h` and dependent files that are typically distributed as part of the OptiX SDK. When Slang detects the use of raytracing in source, it will define `SLANG_CUDA_ENABLE_OPTIX` when `slang-cuda-prelude.h` is included. This will in turn try to include `optix.h`. 
- -Slang tries several mechanisms to locate `optix.h` when NVRTC is initiated. The first mechanism is to look in the include paths that are passed to Slang. If `optix.h` can be found in one of these paths, no more searching will be performed. - -If this fails, the default OptiX SDK install locations are searched. On Windows this is `%{PROGRAMDATA}\NVIDIA Corporation\OptiX SDK X.X.X\include`. On Linux this is `${HOME}/NVIDIA-OptiX-SDK-X.X.X-suffix`. - -If OptiX headers cannot be found, compilation will fail. - Limitations =========== diff --git a/docs/user-guide/03-convenience-features.md b/docs/user-guide/03-convenience-features.md index 29e8fd2aaa..e6b337eed1 100644 --- a/docs/user-guide/03-convenience-features.md +++ b/docs/user-guide/03-convenience-features.md @@ -149,7 +149,7 @@ int rs = foo.staticMethod(a,b); ### Mutability of member function -For GPU performance considerations, the `this` argument in a member function is immutable by default. Attempting to modify `this` will result in a compile error. If you intend to define a member function that mutates the object, use `[mutating]` attribute on the member function as shown in the following example. +For GPU performance considerations, the `this` argument in a member function is immutable by default. If you modify the content in `this` argument, the modification will be discarded after the call and does not affect the input object. If you intend to define a member function that mutates the object, use `[mutating]` attribute on the member function as shown in the following example. ```hlsl struct Foo @@ -159,14 +159,14 @@ struct Foo [mutating] void setCount(int x) { count = x; } - // This would fail to compile. - // void setCount2(int x) { count = x; } + void setCount2(int x) { count = x; } } void test() { Foo f; - f.setCount(1); // Compiles + f.setCount(1); // f.count is 1 after the call. + f.setCount2(2); // f.count is still 1 after the call. 
} ``` diff --git a/external/slang-rhi b/external/slang-rhi index d1f2718165..19bc575bc1 160000 --- a/external/slang-rhi +++ b/external/slang-rhi @@ -1 +1 @@ -Subproject commit d1f2718165d0d540c8fc1eacf20b9edd2d6faac0 +Subproject commit 19bc575bc193e92210649d6d84ac202b199b29af diff --git a/source/compiler-core/slang-nvrtc-compiler.cpp b/source/compiler-core/slang-nvrtc-compiler.cpp index 0042ad7085..c5ccc8e23a 100644 --- a/source/compiler-core/slang-nvrtc-compiler.cpp +++ b/source/compiler-core/slang-nvrtc-compiler.cpp @@ -127,14 +127,11 @@ class NVRTCDownstreamCompiler : public DownstreamCompilerBase nvrtcProgram m_program; }; - SlangResult _findCUDAIncludePath(String& outIncludePath); - SlangResult _getCUDAIncludePath(String& outIncludePath); + SlangResult _findIncludePath(String& outIncludePath); - SlangResult _findOptixIncludePath(String& outIncludePath); - SlangResult _getOptixIncludePath(String& outIncludePath); + SlangResult _getIncludePath(String& outIncludePath); SlangResult _maybeAddHalfSupport(const CompileOptions& options, CommandLine& ioCmdLine); - SlangResult _maybeAddOptixSupport(const CompileOptions& options, CommandLine& ioCmdLine); #define SLANG_NVTRC_MEMBER_FUNCS(ret, name, params) ret(*m_##name) params; @@ -143,16 +140,9 @@ class NVRTCDownstreamCompiler : public DownstreamCompilerBase // Holds list of paths passed in where cuda_fp16.h is found. Does *NOT* include cuda_fp16.h. List m_cudaFp16FoundPaths; - bool m_cudaIncludeSearched = false; + bool m_includeSearched = false; // Holds location of where include (for cuda_fp16.h) is found. - String m_cudaIncludePath; - - // Holds list of paths passed in where optix.h is found. Does *NOT* include optix.h. - List m_optixFoundPaths; - - bool m_optixIncludeSearched = false; - // Holds location of where include (for optix.h) is found. 
- String m_optixIncludePath; + String m_includePath; ComPtr m_sharedLibrary; }; @@ -612,8 +602,21 @@ static SlangResult _findNVRTC(NVRTCPathVisitor& visitor) } static const UnownedStringSlice g_fp16HeaderName = UnownedStringSlice::fromLiteral("cuda_fp16.h"); -static const UnownedStringSlice g_optixHeaderName = UnownedStringSlice::fromLiteral("optix.h"); +SlangResult NVRTCDownstreamCompiler::_getIncludePath(String& outPath) +{ + if (!m_includeSearched) + { + m_includeSearched = true; + + SLANG_ASSERT(m_includePath.getLength() == 0); + + _findIncludePath(m_includePath); + } + + outPath = m_includePath; + return m_includePath.getLength() ? SLANG_OK : SLANG_E_NOT_FOUND; +} SlangResult _findFileInIncludePath( const String& path, @@ -647,7 +650,7 @@ SlangResult _findFileInIncludePath( return SLANG_E_NOT_FOUND; } -SlangResult NVRTCDownstreamCompiler::_findCUDAIncludePath(String& outPath) +SlangResult NVRTCDownstreamCompiler::_findIncludePath(String& outPath) { outPath = String(); @@ -708,130 +711,6 @@ SlangResult NVRTCDownstreamCompiler::_findCUDAIncludePath(String& outPath) return SLANG_E_NOT_FOUND; } -SlangResult NVRTCDownstreamCompiler::_getCUDAIncludePath(String& outPath) -{ - if (!m_cudaIncludeSearched) - { - m_cudaIncludeSearched = true; - - SLANG_ASSERT(m_cudaIncludePath.getLength() == 0); - - _findCUDAIncludePath(m_cudaIncludePath); - } - - outPath = m_cudaIncludePath; - return m_cudaIncludePath.getLength() ? 
SLANG_OK : SLANG_E_NOT_FOUND; -} - -SlangResult NVRTCDownstreamCompiler::_findOptixIncludePath(String& outPath) -{ - outPath = String(); - - List rootPaths; - -#if SLANG_WINDOWS_FAMILY - const char* searchPattern = "OptiX SDK *"; - StringBuilder builder; - if (SLANG_SUCCEEDED(PlatformUtil::getEnvironmentVariable( - UnownedStringSlice::fromLiteral("PROGRAMDATA"), - builder))) - { - rootPaths.add(Path::combine(builder, "NVIDIA Corporation")); - } -#else - const char* searchPattern = "NVIDIA-OptiX-SDK-*"; - StringBuilder builder; - if (SLANG_SUCCEEDED( - PlatformUtil::getEnvironmentVariable(UnownedStringSlice::fromLiteral("HOME"), builder))) - { - rootPaths.add(builder); - } -#endif - - struct OptixHeaders - { - String path; - SemanticVersion version; - }; - - // Visitor to find Optix headers. - struct Visitor : public Path::Visitor - { - const String& rootPath; - List& optixPaths; - Visitor(const String& rootPath, List& optixPaths) - : rootPath(rootPath), optixPaths(optixPaths) - { - } - void accept(Path::Type type, const UnownedStringSlice& path) SLANG_OVERRIDE - { - if (type != Path::Type::Directory) - return; - - OptixHeaders optixPath; -#if SLANG_WINDOWS_FAMILY - // Paths are expected to look like ".\OptiX SDK X.X.X" - auto versionString = path.subString(path.lastIndexOf(' ') + 1, path.getLength()); -#else - // Paths are expected to look like "./NVIDIA-OptiX-SDK-X.X.X-suffix" - auto versionString = path.subString(0, path.lastIndexOf('-')); - versionString = - versionString.subString(path.lastIndexOf('-') + 1, versionString.getLength()); -#endif - if (SLANG_SUCCEEDED(SemanticVersion::parse(versionString, '.', optixPath.version))) - { - optixPath.path = Path::combine(Path::combine(rootPath, path), "include"); - String optixHeader = Path::combine(optixPath.path, g_optixHeaderName); - if (File::exists(optixHeader)) - { - optixPaths.add(optixPath); - } - } - } - }; - - List optixPaths; - - for (const String& rootPath : rootPaths) - { - Visitor visitor(rootPath, 
optixPaths); - Path::find(rootPath, searchPattern, &visitor); - } - - // Find newest version - const OptixHeaders* newest = nullptr; - for (Index i = 0; i < optixPaths.getCount(); ++i) - { - if (!newest || optixPaths[i].version > newest->version) - { - newest = &optixPaths[i]; - } - } - - if (newest) - { - outPath = newest->path; - return SLANG_OK; - } - - return SLANG_E_NOT_FOUND; -} - -SlangResult NVRTCDownstreamCompiler::_getOptixIncludePath(String& outPath) -{ - if (!m_optixIncludeSearched) - { - m_optixIncludeSearched = true; - - SLANG_ASSERT(m_optixIncludePath.getLength() == 0); - - _findOptixIncludePath(m_optixIncludePath); - } - - outPath = m_optixIncludePath; - return m_optixIncludePath.getLength() ? SLANG_OK : SLANG_E_NOT_FOUND; -} - SlangResult NVRTCDownstreamCompiler::_maybeAddHalfSupport( const DownstreamCompileOptions& options, CommandLine& ioCmdLine) @@ -868,7 +747,7 @@ SlangResult NVRTCDownstreamCompiler::_maybeAddHalfSupport( } String includePath; - SLANG_RETURN_ON_FAIL(_getCUDAIncludePath(includePath)); + SLANG_RETURN_ON_FAIL(_getIncludePath(includePath)); // Add the found include path ioCmdLine.addArg("-I"); @@ -879,48 +758,6 @@ SlangResult NVRTCDownstreamCompiler::_maybeAddHalfSupport( return SLANG_OK; } -SlangResult NVRTCDownstreamCompiler::_maybeAddOptixSupport( - const DownstreamCompileOptions& options, - CommandLine& ioCmdLine) -{ - // First check if we know if one of the include paths contains optix.h - for (const auto& includePath : options.includePaths) - { - if (m_optixFoundPaths.indexOf(includePath) >= 0) - { - // Okay we have an include path that we know works. 
- // Just need to enable OptiX in prelude - ioCmdLine.addArg("-DSLANG_CUDA_ENABLE_OPTIX"); - return SLANG_OK; - } - } - - // Let's see if one of the paths finds optix.h - for (const auto& curIncludePath : options.includePaths) - { - const String includePath = asString(curIncludePath); - const String checkPath = Path::combine(includePath, g_optixHeaderName); - if (File::exists(checkPath)) - { - m_optixFoundPaths.add(includePath); - // Just need to enable OptiX in prelude - ioCmdLine.addArg("-DSLANG_CUDA_ENABLE_OPTIX"); - return SLANG_OK; - } - } - - String includePath; - SLANG_RETURN_ON_FAIL(_getOptixIncludePath(includePath)); - - // Add the found include path - ioCmdLine.addArg("-I"); - ioCmdLine.addArg(includePath); - - ioCmdLine.addArg("-DSLANG_CUDA_ENABLE_OPTIX"); - - return SLANG_OK; -} - SlangResult NVRTCDownstreamCompiler::compile( const DownstreamCompileOptions& inOptions, IArtifact** outArtifact) @@ -943,9 +780,6 @@ SlangResult NVRTCDownstreamCompiler::compile( CommandLine cmdLine; - // --dopt option is only available in CUDA 11.7 and later - bool hasDoptOption = m_desc.version >= SemanticVersion(11, 7); - switch (options.debugInfoType) { case DebugInfoType::None: @@ -955,20 +789,12 @@ SlangResult NVRTCDownstreamCompiler::compile( default: { cmdLine.addArg("--device-debug"); - if (hasDoptOption) - { - cmdLine.addArg("--dopt=on"); - } break; } case DebugInfoType::Maximal: { cmdLine.addArg("--device-debug"); cmdLine.addArg("--generate-line-info"); - if (hasDoptOption) - { - cmdLine.addArg("--dopt=on"); - } break; } } @@ -1084,7 +910,48 @@ SlangResult NVRTCDownstreamCompiler::compile( // if (options.pipelineType == PipelineType::RayTracing) { - SLANG_RETURN_ON_FAIL(_maybeAddOptixSupport(options, cmdLine)); + // The device-side OptiX API is accessed through a constellation + // of headers provided by the OptiX SDK, so we need to set an + // include path for the compile that makes those visible. 
+ // + // TODO: The OptiX SDK installer doesn't set any kind of environment + // variable to indicate where the SDK was installed, so we seemingly + // need to probe paths instead. The form of the path will differ + // between Windows and Unix-y platforms, and we will need some kind + // of approach to probe multiple versions and use the latest. + // + // HACK: For now I'm using the fixed path for the most recent SDK + // release on Windows. This means that OptiX cross-compilation will + // only "work" on a subset of platforms, but that doesn't matter + // for now since it doesn't really "work" at all. + // + cmdLine.addArg("-I"); + cmdLine.addArg("C:/ProgramData/NVIDIA Corporation/OptiX SDK 7.0.0/include/"); + // The OptiX headers in turn `#include <stddef.h>` and expect that + // to work. We could try to also add in an include path from the CUDA + // SDK (which seems to provide a `stddef.h` in the most recent version), + // but using that version doesn't seem to work (and also bakes in a + // requirement that the user have the CUDA SDK installed in addition + // to the OptiX SDK). + // + // Instead, we will rely on the NVRTC feature that lets us set up + // memory buffers to be used as include files by the code we compile. + // We will define a dummy `stddef.h` that includes the bare minimum + // lines required to get the OptiX headers to compile without complaint. + // + // TODO: Confirm that the `LP64` definition here is actually needed. + // + headerIncludeNames.add("stddef.h"); + headers.add("#pragma once\n" + "#define LP64\n"); + // Finally, we want the CUDA prelude to be able to react to whether + // or not OptiX is required (most notably by `#include`ing the appropriate + // header(s)), so we will insert a preprocessor define to indicate + // the requirement.
+ // + cmdLine.addArg("-DSLANG_CUDA_ENABLE_OPTIX"); } // Add any compiler specific options diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 11c4ab6f45..7964e26d8d 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -20932,8 +20932,6 @@ struct ConstBufferPointer // new aliased bindings for each distinct cast type. // -//@public: - /// Represent the kind of a descriptor type. enum DescriptorKind { @@ -21050,18 +21048,8 @@ ${{{{ } }}}} -/// Represents a bindless handle to a descriptor. A descriptor handle is always an ordinary data type and can be +/// Represents a bindless resource handle. A bindless resource handle is always a concrete type and can be /// declared in any memory location. -/// @remarks Opaque descriptor types such as textures(`Texture2D` etc.), `SamplerState` and buffers (e.g. `StructuredBuffer`) -/// can have undefined size and data representation on many targets. On platforms such as Vulkan and D3D12, descriptors are -/// communicated to the shader code by calling the host side API to write the descriptor into a descriptor set or table, instead -/// of directly writing bytes into an ordinary GPU accessible buffer. As a result, oapque handle types cannot be used in places -/// that refer to a ordinary buffer location, such as as element types of a `StructuredBuffer`. -/// However, a `DescriptorHandle` stores a handle (or address) to the actual descriptor, and is always an ordinary data type -/// that can be manipulated directly in the shader code. This gives the developer the flexibility to embed and pass around descriptor -/// parameters throughout the code, to enable cleaner modular designs. -/// See [User Guide](https://shader-slang.com/slang/user-guide/convenience-features.html#descriptorhandle-for-bindless-descriptor-access) -/// for more information on how to use `DescriptorHandle` in your code. 
__magic_type(DescriptorHandleType) __intrinsic_type($(kIROp_DescriptorHandleType)) struct DescriptorHandle : IComparable @@ -21152,8 +21140,6 @@ extern T getDescriptorFromHandle(DescriptorHandle handle __intrinsic_op($(kIROp_NonUniformResourceIndex)) DescriptorHandle nonuniform(DescriptorHandle ptr); -//@hidden: - __glsl_version(450) __glsl_extension(GL_ARB_shader_clock) [require(glsl_spirv, GL_ARB_shader_clock)] diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index ee29750a6a..f5dd86df15 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -973,14 +973,9 @@ class GLSLLayoutLocalSizeAttribute : public Attribute // // TODO: These should be accessors that use the // ordinary `args` list, rather than side data. - IntVal* extents[3]; - - bool axisIsSpecConstId[3]; - - // References to specialization constants, for defining the number of - // threads with them. If set, the corresponding axis is set to nullptr - // above. - DeclRef specConstExtents[3]; + IntVal* x; + IntVal* y; + IntVal* z; }; class GLSLLayoutDerivativeGroupQuadAttribute : public Attribute @@ -1043,12 +1038,9 @@ class NumThreadsAttribute : public Attribute // // TODO: These should be accessors that use the // ordinary `args` list, rather than side data. - IntVal* extents[3]; - - // References to specialization constants, for defining the number of - // threads with them. If set, the corresponding axis is set to nullptr - // above. 
- DeclRef specConstExtents[3]; + IntVal* x; + IntVal* y; + IntVal* z; }; class WaveSizeAttribute : public Attribute diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 3ef1e8f3be..b3e30dbc23 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1656,8 +1656,6 @@ struct SemanticsVisitor : public SemanticsContext void visitModifier(Modifier*); - DeclRef tryGetIntSpecializationConstant(Expr* expr); - AttributeDecl* lookUpAttributeDecl(Name* attributeName, Scope* scope); bool hasIntArgs(Attribute* attr, int numArgs); diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index 6e451b5cf9..3723c98f86 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -114,36 +114,6 @@ void SemanticsVisitor::visitModifier(Modifier*) // Do nothing with modifiers for now } -DeclRef SemanticsVisitor::tryGetIntSpecializationConstant(Expr* expr) -{ - // First type-check the expression as normal - expr = CheckExpr(expr); - - if (IsErrorExpr(expr)) - return DeclRef(); - - if (!isScalarIntegerType(expr->type)) - return DeclRef(); - - auto specConstVar = as(expr); - if (!specConstVar || !specConstVar->declRef) - return DeclRef(); - - auto decl = specConstVar->declRef.getDecl(); - if (!decl) - return DeclRef(); - - for (auto modifier : decl->modifiers) - { - if (as(modifier) || as(modifier)) - { - return specConstVar->declRef.as(); - } - } - - return DeclRef(); -} - static bool _isDeclAllowedAsAttribute(DeclRef declRef) { if (as(declRef.getDecl())) @@ -380,6 +350,8 @@ Modifier* SemanticsVisitor::validateAttribute( { SLANG_ASSERT(attr->args.getCount() == 3); + IntVal* values[3]; + for (int i = 0; i < 3; ++i) { IntVal* value = nullptr; @@ -387,14 +359,6 @@ Modifier* SemanticsVisitor::validateAttribute( auto arg = attr->args[i]; if (arg) { - auto specConstDecl = tryGetIntSpecializationConstant(arg); - if (specConstDecl) - { - numThreadsAttr->extents[i] = 
nullptr; - numThreadsAttr->specConstExtents[i] = specConstDecl; - continue; - } - auto intValue = checkLinkTimeConstantIntVal(arg); if (!intValue) { @@ -426,8 +390,12 @@ Modifier* SemanticsVisitor::validateAttribute( { value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); } - numThreadsAttr->extents[i] = value; + values[i] = value; } + + numThreadsAttr->x = values[0]; + numThreadsAttr->y = values[1]; + numThreadsAttr->z = values[2]; } else if (auto waveSizeAttr = as(attr)) { @@ -1863,24 +1831,15 @@ Modifier* SemanticsVisitor::checkModifier( { SLANG_ASSERT(attr->args.getCount() == 3); - // GLSLLayoutLocalSizeAttribute is always attached to an EmptyDecl. - auto decl = as(syntaxNode); - SLANG_ASSERT(decl); + IntVal* values[3]; for (int i = 0; i < 3; ++i) { - attr->extents[i] = nullptr; + IntVal* value = nullptr; auto arg = attr->args[i]; if (arg) { - auto specConstDecl = tryGetIntSpecializationConstant(arg); - if (specConstDecl) - { - attr->specConstExtents[i] = specConstDecl; - continue; - } - auto intValue = checkConstantIntVal(arg); if (!intValue) { @@ -1888,45 +1847,7 @@ Modifier* SemanticsVisitor::checkModifier( } if (auto cintVal = as(intValue)) { - if (attr->axisIsSpecConstId[i]) - { - // This integer should actually be a reference to a - // specialization constant with this ID. - Int specConstId = cintVal->getValue(); - - for (auto member : decl->parentDecl->members) - { - auto constantId = member->findModifier(); - if (constantId) - { - SLANG_ASSERT(constantId->args.getCount() == 1); - auto id = checkConstantIntVal(constantId->args[0]); - if (id->getValue() == specConstId) - { - attr->specConstExtents[i] = - DeclRef(member->getDefaultDeclRef()); - break; - } - } - } - - // If not found, we need to create a new specialization - // constant with this ID. 
- if (!attr->specConstExtents[i]) - { - auto specConstVarDecl = getASTBuilder()->create(); - auto constantIdModifier = - getASTBuilder()->create(); - constantIdModifier->location = (int32_t)specConstId; - specConstVarDecl->type.type = getASTBuilder()->getIntType(); - addModifier(specConstVarDecl, constantIdModifier); - decl->parentDecl->addMember(specConstVarDecl); - attr->specConstExtents[i] = - DeclRef(specConstVarDecl->getDefaultDeclRef()); - } - continue; - } - else if (cintVal->getValue() < 1) + if (cintVal->getValue() < 1) { getSink()->diagnose( attr, @@ -1935,13 +1856,18 @@ Modifier* SemanticsVisitor::checkModifier( return nullptr; } } - attr->extents[i] = intValue; + value = intValue; } else { - attr->extents[i] = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); + value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); } + values[i] = value; } + + attr->x = values[0]; + attr->y = values[1]; + attr->z = values[2]; } // Default behavior is to leave things as they are, diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index d86cd8be2a..821a895bc7 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -2060,7 +2060,7 @@ DIAGNOSTIC( DIAGNOSTIC(41000, Warning, unreachableCode, "unreachable code detected") DIAGNOSTIC(41001, Error, recursiveType, "type '$0' contains cyclic reference to itself.") -DIAGNOSTIC(41010, Warning, missingReturn, "non-void function does not return in all cases") +DIAGNOSTIC(41010, Warning, missingReturn, "control flow may reach end of non-'void' function") DIAGNOSTIC( 41011, Error, @@ -2459,12 +2459,6 @@ DIAGNOSTIC( Error, unsupportedTargetIntrinsic, "intrinsic operation '$0' is not supported for the current target.") -DIAGNOSTIC( - 55205, - Error, - unsupportedSpecializationConstantForNumThreads, - "Specialization constants are not supported in the 'numthreads' attribute for the current " - "target.") DIAGNOSTIC( 56001, Error, diff --git 
a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index d3a9359ff2..7b51495e2b 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -295,48 +295,14 @@ void CLikeSourceEmitter::emitSimpleType(IRType* type) } -IRNumThreadsDecoration* CLikeSourceEmitter::getComputeThreadGroupSize( - IRFunc* func, - Int outNumThreads[kThreadGroupAxisCount]) -{ - Int specializationConstantIds[kThreadGroupAxisCount]; - IRNumThreadsDecoration* decor = - getComputeThreadGroupSize(func, outNumThreads, specializationConstantIds); - - for (auto id : specializationConstantIds) - { - if (id >= 0) - { - getSink()->diagnose(decor, Diagnostics::unsupportedSpecializationConstantForNumThreads); - break; - } - } - return decor; -} - /* static */ IRNumThreadsDecoration* CLikeSourceEmitter::getComputeThreadGroupSize( IRFunc* func, - Int outNumThreads[kThreadGroupAxisCount], - Int outSpecializationConstantIds[kThreadGroupAxisCount]) + Int outNumThreads[kThreadGroupAxisCount]) { IRNumThreadsDecoration* decor = func->findDecoration(); - for (int i = 0; i < kThreadGroupAxisCount; ++i) + for (int i = 0; i < 3; ++i) { - if (!decor) - { - outNumThreads[i] = 1; - outSpecializationConstantIds[i] = -1; - } - else if (auto specConst = as(decor->getOperand(i))) - { - outNumThreads[i] = 1; - outSpecializationConstantIds[i] = getSpecializationConstantId(specConst); - } - else - { - outNumThreads[i] = Int(getIntVal(decor->getOperand(i))); - outSpecializationConstantIds[i] = -1; - } + outNumThreads[i] = decor ? Int(getIntVal(decor->getOperand(i))) : 1; } return decor; } diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index 1354b7cbd8..e5080f731b 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -500,19 +500,10 @@ class CLikeSourceEmitter : public SourceEmitterBase /// different. 
Returns an empty slice if not a built in type static UnownedStringSlice getDefaultBuiltinTypeName(IROp op); - /// Finds the IRNumThreadsDecoration and gets the size from that or sets all - /// dimensions to 1 - IRNumThreadsDecoration* getComputeThreadGroupSize( - IRFunc* func, - Int outNumThreads[kThreadGroupAxisCount]); - - /// Finds the IRNumThreadsDecoration and gets the size from that or sets all - /// dimensions to 1. If specialization constants are used for an axis, their - /// IDs is reported in non-negative entries of outSpecializationConstantIds. + /// Finds the IRNumThreadsDecoration and gets the size from that or sets all dimensions to 1 static IRNumThreadsDecoration* getComputeThreadGroupSize( IRFunc* func, - Int outNumThreads[kThreadGroupAxisCount], - Int outSpecializationConstantIds[kThreadGroupAxisCount]); + Int outNumThreads[kThreadGroupAxisCount]); /// Finds the IRWaveSizeDecoration and gets the size from that. static IRWaveSizeDecoration* getComputeWaveSize(IRFunc* func, Int* outWaveSize); diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index 0dab07cfce..23fff37acb 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -1335,8 +1335,7 @@ void GLSLSourceEmitter::emitEntryPointAttributesImpl( auto emitLocalSizeLayout = [&]() { Int sizeAlongAxis[kThreadGroupAxisCount]; - Int specializationConstantIds[kThreadGroupAxisCount]; - getComputeThreadGroupSize(irFunc, sizeAlongAxis, specializationConstantIds); + getComputeThreadGroupSize(irFunc, sizeAlongAxis); m_writer->emit("layout("); char const* axes[] = {"x", "y", "z"}; @@ -1346,17 +1345,8 @@ void GLSLSourceEmitter::emitEntryPointAttributesImpl( m_writer->emit(", "); m_writer->emit("local_size_"); m_writer->emit(axes[ii]); - - if (specializationConstantIds[ii] >= 0) - { - m_writer->emit("_id = "); - m_writer->emit(specializationConstantIds[ii]); - } - else - { - m_writer->emit(" = "); - m_writer->emit(sizeAlongAxis[ii]); - } + 
m_writer->emit(" = "); + m_writer->emit(sizeAlongAxis[ii]); } m_writer->emit(") in;\n"); }; diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 2cf84a8540..068e1563ca 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -4353,36 +4353,23 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex // [3.6. Execution Mode]: LocalSize case kIROp_NumThreadsDecoration: { + // TODO: The `LocalSize` execution mode option requires + // literal values for the X,Y,Z thread-group sizes. + // There is a `LocalSizeId` variant that takes `<id>`s + // for those sizes, and we should consider using that + // and requiring the appropriate capabilities + // if any of the operands to the decoration are not + // literals (in a future where we support non-literals + // in those positions in the Slang IR). + // auto numThreads = cast(decoration); - if (numThreads->getXSpecConst() || numThreads->getYSpecConst() || - numThreads->getZSpecConst()) - { - // If any of the dimensions needs an ID, we need to emit - // all dimensions as an ID due to how LocalSizeId works. - int32_t ids[3]; - for (int i = 0; i < 3; ++i) - ids[i] = ensureInst(numThreads->getOperand(i))->id; - - // LocalSizeId is supported from SPIR-V 1.2 onwards without - // any extra capabilities.
- requireSPIRVExecutionMode( - decoration, - dstID, - SpvExecutionModeLocalSizeId, - SpvLiteralInteger::from32(int32_t(ids[0])), - SpvLiteralInteger::from32(int32_t(ids[1])), - SpvLiteralInteger::from32(int32_t(ids[2]))); - } - else - { - requireSPIRVExecutionMode( - decoration, - dstID, - SpvExecutionModeLocalSize, - SpvLiteralInteger::from32(int32_t(numThreads->getX()->getValue())), - SpvLiteralInteger::from32(int32_t(numThreads->getY()->getValue())), - SpvLiteralInteger::from32(int32_t(numThreads->getZ()->getValue()))); - } + requireSPIRVExecutionMode( + decoration, + dstID, + SpvExecutionModeLocalSize, + SpvLiteralInteger::from32(int32_t(numThreads->getX()->getValue())), + SpvLiteralInteger::from32(int32_t(numThreads->getY()->getValue())), + SpvLiteralInteger::from32(int32_t(numThreads->getZ()->getValue()))); } break; case kIROp_MaxVertexCountDecoration: @@ -7990,18 +7977,10 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex { if (m_executionModes[entryPoint].add(executionMode)) { - SpvOp execModeOp = SpvOpExecutionMode; - if (executionMode == SpvExecutionModeLocalSizeId || - executionMode == SpvExecutionModeLocalSizeHintId || - executionMode == SpvExecutionModeSubgroupsPerWorkgroupId) - { - execModeOp = SpvOpExecutionModeId; - } - emitInst( getSection(SpvLogicalSectionID::ExecutionModes), parentInst, - execModeOp, + SpvOpExecutionMode, entryPoint, executionMode, ops...); diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 3237ba3b26..65ce69877f 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -528,12 +528,10 @@ InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRF // If primal parameter is mutable, we need to pass in a temp var. 
auto tempVar = builder.emitVar(primalParamPtrType->getValueType()); - // If the parameter is not a pure 'out' param, we also need to setup the initial - // value of the temp var, otherwise the temp var will be uninitialized which could - // cause undefined behavior in the primal function. - // - if (!as(primalParamType)) - builder.emitStore(tempVar, primalArg); + // We also need to setup the initial value of the temp var, otherwise + // the temp var will be uninitialized which could cause undefined behavior + // in the primal function. + builder.emitStore(tempVar, primalArg); primalArgs.add(tempVar); } diff --git a/source/slang/slang-ir-collect-global-uniforms.cpp b/source/slang/slang-ir-collect-global-uniforms.cpp index 372ef298e7..1c833a2948 100644 --- a/source/slang/slang-ir-collect-global-uniforms.cpp +++ b/source/slang/slang-ir-collect-global-uniforms.cpp @@ -279,16 +279,6 @@ struct CollectGlobalUniformParametersContext continue; } - // NumThreadsDecoration may sometimes be the user for a global - // parameter. This occurs when the parameter was supposed to be - // a specialization constant, but isn't due to that not being - // supported for the target. These can be skipped here and - // diagnosed later. - if (as(user)) - { - continue; - } - // For each use site for the global parameter, we will // insert new code right before the instruction that uses // the parameter. 
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index f46586aa2b..a58c2e900c 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -570,7 +570,6 @@ struct IRInstanceDecoration : IRDecoration IRIntLit* getCount() { return cast(getOperand(0)); } }; -struct IRGlobalParam; struct IRNumThreadsDecoration : IRDecoration { enum @@ -579,13 +578,11 @@ struct IRNumThreadsDecoration : IRDecoration }; IR_LEAF_ISA(NumThreadsDecoration) - IRIntLit* getX() { return as(getOperand(0)); } - IRIntLit* getY() { return as(getOperand(1)); } - IRIntLit* getZ() { return as(getOperand(2)); } + IRIntLit* getX() { return cast(getOperand(0)); } + IRIntLit* getY() { return cast(getOperand(1)); } + IRIntLit* getZ() { return cast(getOperand(2)); } - IRGlobalParam* getXSpecConst() { return as(getOperand(0)); } - IRGlobalParam* getYSpecConst() { return as(getOperand(1)); } - IRGlobalParam* getZSpecConst() { return as(getOperand(2)); } + IRIntLit* getExtentAlongAxis(int axis) { return cast(getOperand(axis)); } }; struct IRWaveSizeDecoration : IRDecoration diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index 6840197721..e267e8343c 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -190,7 +190,7 @@ IRInst* emitCalcGroupExtents(IRBuilder& builder, IRFunc* entryPoint, IRVectorTyp for (int axis = 0; axis < kAxisCount; axis++) { - auto litValue = as(numThreadsDecor->getOperand(axis)); + auto litValue = as(numThreadsDecor->getExtentAlongAxis(axis)); if (!litValue) return nullptr; @@ -1434,20 +1434,6 @@ struct CPUEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegalize // groupExtents = emitCalcGroupExtents(builder, m_entryPointFunc, uint3Type); - if (!groupExtents) - { - m_sink->diagnose( - m_entryPointFunc, - Diagnostics::unsupportedSpecializationConstantForNumThreads); - - // Fill in placeholder 
values. - static const int kAxisCount = 3; - IRInst* groupExtentAlongAxis[kAxisCount] = {}; - for (int axis = 0; axis < kAxisCount; axis++) - groupExtentAlongAxis[axis] = builder.getIntValue(uint3Type->getElementType(), 1); - groupExtents = builder.emitMakeVector(uint3Type, kAxisCount, groupExtentAlongAxis); - } - dispatchThreadID = emitCalcDispatchThreadID(builder, uint3Type, groupID, groupThreadID, groupExtents); diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp index 68d79617a8..90d30dcc77 100644 --- a/source/slang/slang-ir-simplify-cfg.cpp +++ b/source/slang/slang-ir-simplify-cfg.cpp @@ -490,19 +490,11 @@ static bool trySimplifyIfElse(IRBuilder& builder, IRIfElse* ifElseInst) bool isFalseBranchTrivial = false; if (isTrivialIfElse(ifElseInst, isTrueBranchTrivial, isFalseBranchTrivial)) { - // If either branch of `if-else` is a trivial jump into after block, + // If both branches of `if-else` are trivial jumps into after block, // we can get rid of the entire conditional branch and replace it // with a jump into the after block. 
- IRUnconditionalBranch* termInst = - as(ifElseInst->getTrueBlock()->getTerminator()); - if (!termInst || (termInst->getTargetBlock() != ifElseInst->getAfterBlock())) + if (auto termInst = as(ifElseInst->getTrueBlock()->getTerminator())) { - termInst = as(ifElseInst->getFalseBlock()->getTerminator()); - } - - if (termInst) - { - SLANG_ASSERT(termInst->getTargetBlock() == ifElseInst->getAfterBlock()); List args; for (UInt i = 0; i < termInst->getArgCount(); i++) args.add(termInst->getArg(i)); diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index a9b0d44121..40cd40758a 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -71,42 +71,6 @@ struct SpecializationContext module->getContainerPool().free(&cleanInsts); } - bool isUnsimplifiedArithmeticInst(IRInst* inst) - { - switch (inst->getOp()) - { - case kIROp_Add: - case kIROp_Sub: - case kIROp_Mul: - case kIROp_Div: - case kIROp_Neg: - case kIROp_Not: - case kIROp_Eql: - case kIROp_Neq: - case kIROp_Leq: - case kIROp_Geq: - case kIROp_Less: - case kIROp_IRem: - case kIROp_FRem: - case kIROp_Greater: - case kIROp_Lsh: - case kIROp_Rsh: - case kIROp_BitAnd: - case kIROp_BitOr: - case kIROp_BitXor: - case kIROp_BitNot: - case kIROp_BitCast: - case kIROp_CastIntToFloat: - case kIROp_CastFloatToInt: - case kIROp_IntCast: - case kIROp_FloatCast: - case kIROp_Select: - return true; - default: - return false; - } - } - // An instruction is then fully specialized if and only // if it is in our set. // @@ -169,14 +133,6 @@ struct SpecializationContext return areAllOperandsFullySpecialized(inst); } - if (isUnsimplifiedArithmeticInst(inst)) - { - // For arithmetic insts, we want to wait for simplification before specialization, - // since different insts can simplify to the same value. - // - return false; - } - // The default case is that a global value is always specialized. 
if (inst->getParent() == module->getModuleInst()) { @@ -1136,7 +1092,6 @@ struct SpecializationContext { this->changed = true; eliminateDeadCode(module->getModuleInst()); - applySparseConditionalConstantPropagationForGlobalScope(this->module, this->sink); } // Once the work list has gone dry, we should have the invariant diff --git a/source/slang/slang-ir-translate-glsl-global-var.cpp b/source/slang/slang-ir-translate-glsl-global-var.cpp index 077cdb98d0..a44e16a7ce 100644 --- a/source/slang/slang-ir-translate-glsl-global-var.cpp +++ b/source/slang/slang-ir-translate-glsl-global-var.cpp @@ -282,11 +282,10 @@ struct GlobalVarTranslationContext if (!numthreadsDecor) return; builder.setInsertBefore(use->getUser()); - IRInst* values[3] = { - numthreadsDecor->getOperand(0), - numthreadsDecor->getOperand(1), - numthreadsDecor->getOperand(2)}; - + IRInst* values[] = { + numthreadsDecor->getExtentAlongAxis(0), + numthreadsDecor->getExtentAlongAxis(1), + numthreadsDecor->getExtentAlongAxis(2)}; auto workgroupSize = builder.emitMakeVector( builder.getVectorType(builder.getIntType(), 3), 3, @@ -329,10 +328,10 @@ struct GlobalVarTranslationContext if (!firstBlock) continue; builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); - IRInst* args[3] = { - numthreadsDecor->getOperand(0), - numthreadsDecor->getOperand(1), - numthreadsDecor->getOperand(2)}; + IRInst* args[] = { + numthreadsDecor->getExtentAlongAxis(0), + numthreadsDecor->getExtentAlongAxis(1), + numthreadsDecor->getExtentAlongAxis(2)}; auto workgroupSize = builder.emitMakeVector(workgroupSizeInst->getFullType(), 3, args); builder.emitStore(globalVar, workgroupSize); diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index d05e1db7d4..c753600a7c 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -1973,17 +1973,4 @@ IRType* getIRVectorBaseType(IRType* type) return as(type)->getElementType(); } -Int getSpecializationConstantId(IRGlobalParam* param) -{ - 
auto layout = findVarLayout(param); - if (!layout) - return 0; - - auto offset = layout->findOffsetAttr(LayoutResourceKind::SpecializationConstant); - if (!offset) - return 0; - - return offset->getOffset(); -} - } // namespace Slang diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 666ac71c03..e23aeb6180 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -373,8 +373,6 @@ inline bool isSPIRV(CodeGenTarget codeGenTarget) int getIRVectorElementSize(IRType* type); IRType* getIRVectorBaseType(IRType* type); -Int getSpecializationConstantId(IRGlobalParam* param); - } // namespace Slang #endif diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 0863457198..e82fc03fde 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -7625,29 +7625,12 @@ struct DeclLoweringVisitor : DeclVisitor { verifyComputeDerivativeGroupModifier = true; getAllEntryPointsNoOverride(entryPoints); - - LoweredValInfo extents[3]; - - for (int i = 0; i < 3; ++i) - { - extents[i] = layoutLocalSizeAttr->specConstExtents[i] - ? 
emitDeclRef( - context, - layoutLocalSizeAttr->specConstExtents[i], - lowerType( - context, - getType( - context->astBuilder, - layoutLocalSizeAttr->specConstExtents[i]))) - : lowerVal(context, layoutLocalSizeAttr->extents[i]); - } - for (auto d : entryPoints) as(getBuilder()->addNumThreadsDecoration( d, - getSimpleVal(context, extents[0]), - getSimpleVal(context, extents[1]), - getSimpleVal(context, extents[2]))); + getSimpleVal(context, lowerVal(context, layoutLocalSizeAttr->x)), + getSimpleVal(context, lowerVal(context, layoutLocalSizeAttr->y)), + getSimpleVal(context, lowerVal(context, layoutLocalSizeAttr->z)))); } else if (as(modifier)) { @@ -10353,28 +10336,11 @@ struct DeclLoweringVisitor : DeclVisitor } else if (auto numThreadsAttr = as(modifier)) { - LoweredValInfo extents[3]; - - for (int i = 0; i < 3; ++i) - { - extents[i] = numThreadsAttr->specConstExtents[i] - ? emitDeclRef( - context, - numThreadsAttr->specConstExtents[i], - lowerType( - context, - getType( - context->astBuilder, - numThreadsAttr->specConstExtents[i]))) - : lowerVal(context, numThreadsAttr->extents[i]); - } - numThreadsDecor = as(getBuilder()->addNumThreadsDecoration( irFunc, - getSimpleVal(context, extents[0]), - getSimpleVal(context, extents[1]), - getSimpleVal(context, extents[2]))); - numThreadsDecor->sourceLoc = numThreadsAttr->loc; + getSimpleVal(context, lowerVal(context, numThreadsAttr->x)), + getSimpleVal(context, lowerVal(context, numThreadsAttr->y)), + getSimpleVal(context, lowerVal(context, numThreadsAttr->z)))); } else if (auto waveSizeAttr = as(modifier)) { diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 6ae41a2eb9..c275a868b5 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -8437,9 +8437,7 @@ static NodeBase* parseLayoutModifier(Parser* parser, void* /*userData*/) int localSizeIndex = -1; if (nameText.startsWith(localSizePrefix) && - (nameText.getLength() == SLANG_COUNT_OF(localSizePrefix) - 1 + 1 || - 
(nameText.endsWith("_id") && - (nameText.getLength() == SLANG_COUNT_OF(localSizePrefix) - 1 + 4)))) + nameText.getLength() == SLANG_COUNT_OF(localSizePrefix) - 1 + 1) { char lastChar = nameText[SLANG_COUNT_OF(localSizePrefix) - 1]; localSizeIndex = (lastChar >= 'x' && lastChar <= 'z') ? (lastChar - 'x') : -1; @@ -8453,8 +8451,6 @@ static NodeBase* parseLayoutModifier(Parser* parser, void* /*userData*/) numThreadsAttrib->args.setCount(3); for (auto& i : numThreadsAttrib->args) i = nullptr; - for (auto& b : numThreadsAttrib->axisIsSpecConstId) - b = false; // Just mark the loc and name from the first in the list numThreadsAttrib->keywordName = getName(parser, "numthreads"); @@ -8471,11 +8467,6 @@ static NodeBase* parseLayoutModifier(Parser* parser, void* /*userData*/) } numThreadsAttrib->args[localSizeIndex] = expr; - - // We can't resolve the specialization constant declaration - // here, because it may not even exist. IDs pointing to unnamed - // specialization constants are allowed in GLSL. 
- numThreadsAttrib->axisIsSpecConstId[localSizeIndex] = nameText.endsWith("_id"); } } else if (nameText == "derivative_group_quadsNV") diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index d1adfedc0b..d235c82703 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -4033,14 +4033,18 @@ SLANG_API void spReflectionEntryPoint_getComputeThreadGroupSize( auto numThreadsAttribute = entryPointFunc.getDecl()->findModifier(); if (numThreadsAttribute) { - for (int i = 0; i < 3; ++i) - { - if (auto cint = - entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->extents[i])) - sizeAlongAxis[i] = (SlangUInt)cint->getValue(); - else if (numThreadsAttribute->extents[i]) - sizeAlongAxis[i] = 0; - } + if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->x)) + sizeAlongAxis[0] = (SlangUInt)cint->getValue(); + else if (numThreadsAttribute->x) + sizeAlongAxis[0] = 0; + if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->y)) + sizeAlongAxis[1] = (SlangUInt)cint->getValue(); + else if (numThreadsAttribute->y) + sizeAlongAxis[1] = 0; + if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->z)) + sizeAlongAxis[2] = (SlangUInt)cint->getValue(); + else if (numThreadsAttribute->z) + sizeAlongAxis[2] = 0; } // diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index efc1c6fd11..5ec1996581 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -1493,8 +1493,7 @@ DeclRef Linkage::specializeWithArgTypes( DiagnosticSink* sink) { SemanticsVisitor visitor(getSemanticsForReflection()); - SemanticsVisitor::ExprLocalScope scope; - visitor = visitor.withSink(sink).withExprLocalScope(&scope); + visitor = visitor.withSink(sink); SLANG_AST_BUILDER_RAII(getASTBuilder()); diff --git a/tests/autodiff/out-parameters-2.slang b/tests/autodiff/out-parameters-2.slang deleted file mode 100644 index b4c4b07c61..0000000000 --- 
a/tests/autodiff/out-parameters-2.slang +++ /dev/null @@ -1,49 +0,0 @@ -//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type -//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type - -//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer -RWStructuredBuffer outputBuffer; - -typedef DifferentialPair dpfloat; - -struct Foo : IDifferentiable -{ - float a; - int b; -} - -[PreferCheckpoint] -float k() -{ - return outputBuffer[3] + 1; -} - -[Differentiable] -void h(float x, float y, out Foo result) -{ - float p = no_diff k(); - float m = x + y + p; - float n = x - y; - float r = m * n + 2 * x * y; - - result = {r, 2}; -} - -[numthreads(1, 1, 1)] -void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) -{ - float x = 2.0; - float y = 3.5; - float dx = 1.0; - float dy = 0.5; - - dpfloat dresult; - dpfloat dpx = diffPair(x); - dpfloat dpy = diffPair(y); - Foo.Differential dFoo; - dFoo.a = 1.0; - bwd_diff(h)(dpx, dpy, dFoo); - - outputBuffer[0] = dpx.d; // CHECK: 12.0 - outputBuffer[1] = dpy.d; // CHECK: -4.0 -} \ No newline at end of file diff --git a/tests/bugs/simplify-if-else.slang b/tests/bugs/simplify-if-else.slang deleted file mode 100644 index 8719a15995..0000000000 --- a/tests/bugs/simplify-if-else.slang +++ /dev/null @@ -1,26 +0,0 @@ -//TEST:SIMPLE(filecheck=CHECK): -stage compute -entry computeMain -target hlsl -//CHECK: computeMain - -//TEST_INPUT:ubuffer(data=[9 9 9 9], stride=4):out,name=outputBuffer -RWStructuredBuffer outputBuffer; - -[numthreads(4, 1, 1)] -void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) -{ - vector vvv = vector(0); - float32_t ret = 0.0f; - if (vvv.y < 1.0f) - { - ret = 1.0f; - } - else - { - if (vvv.y > 1.0f && outputBuffer[3] == 3) - { - ret = 0.0f; - } else { - if (true) {} - } - } - outputBuffer[int(dispatchThreadID.x)] = int(ret); -} diff --git a/tests/diagnostics/missing-return.slang.expected 
b/tests/diagnostics/missing-return.slang.expected index 7626665241..e41e756ff4 100644 --- a/tests/diagnostics/missing-return.slang.expected +++ b/tests/diagnostics/missing-return.slang.expected @@ -1,9 +1,9 @@ result code = 0 standard error = { -tests/diagnostics/missing-return.slang(7): warning 41010: non-void function does not return in all cases +tests/diagnostics/missing-return.slang(7): warning 41010: control flow may reach end of non-'void' function int bad(int a, int b) ^~~ -tests/diagnostics/missing-return.slang(14): warning 41010: non-void function does not return in all cases +tests/diagnostics/missing-return.slang(14): warning 41010: control flow may reach end of non-'void' function int alsoBad(int a, int b) ^~~~~~~ } diff --git a/tests/glsl/compute-shader-layout-id.slang b/tests/glsl/compute-shader-layout-id.slang deleted file mode 100644 index bee8137d82..0000000000 --- a/tests/glsl/compute-shader-layout-id.slang +++ /dev/null @@ -1,19 +0,0 @@ -//TEST:SIMPLE(filecheck=CHECK): -target spirv -stage compute -entry main -allow-glsl -#version 450 - -[vk::constant_id(1)] -const int constValue1 = 0; - -[vk::constant_id(2)] -const int constValue3 = 5; - -// CHECK-DAG: OpExecutionModeId %main LocalSizeId %[[C0:[0-9A-Za-z_]+]] %[[C1:[0-9A-Za-z_]+]] %[[C2:[0-9A-Za-z_]+]] -// CHECK-DAG: OpDecorate %[[C0]] SpecId 1 -// CHECK-DAG: OpDecorate %[[C1]] SpecId 0 -// CHECK-DAG: OpDecorate %[[C2]] SpecId 2 - -layout(local_size_x_id = 1, local_size_y_id = 0, local_size_z = constValue3) in; -void main() -{ -} - diff --git a/tests/spirv/spec-constant-numthreads.slang b/tests/spirv/spec-constant-numthreads.slang deleted file mode 100644 index 5c133219cf..0000000000 --- a/tests/spirv/spec-constant-numthreads.slang +++ /dev/null @@ -1,35 +0,0 @@ -//TEST:SIMPLE(filecheck=GLSL): -target glsl -allow-glsl -//TEST:SIMPLE(filecheck=GLSL): -target glsl -//TEST:SIMPLE(filecheck=CHECK): -target spirv -allow-glsl -//TEST:SIMPLE(filecheck=CHECK): -target spirv - -// CHECK-DAG: 
OpExecutionModeId %computeMain1 LocalSizeId %[[C0:[0-9A-Za-z_]+]] %[[C1:[0-9A-Za-z_]+]] %[[C2:[0-9A-Za-z_]+]] -// CHECK-DAG: OpDecorate %[[C0]] SpecId 1 -// CHECK-DAG: OpDecorate %[[C1]] SpecId 0 -// CHECK-DAG: %[[C2]] = OpConstant %int 4 -// CHECK-DAG: OpStore %{{.*}} %[[C0]] -// CHECK-DAG: OpStore %{{.*}} %[[C1]] -// CHECK-DAG: OpStore %{{.*}} %[[C2]] - -// GLSL-DAG: layout(constant_id = 1) -// GLSL-DAG: int constValue0_0 = 0; -// GLSL-DAG: layout(constant_id = 0) -// GLSL-DAG: int constValue1_0 = 0; -// GLSL-DAG: layout(local_size_x_id = 1, local_size_y_id = 0, local_size_z = 4) in; - -[vk::specialization_constant] -const int constValue0 = 0; - -[vk::constant_id(0)] -const int constValue1 = 0; - -RWStructuredBuffer outputBuffer; - -[numthreads(constValue0, constValue1, 4)] -void computeMain1() -{ - int3 size = WorkgroupSize(); - outputBuffer[0] = size.x; - outputBuffer[1] = size.y; - outputBuffer[2] = size.z; -} From 99869d573a46dadeb24445405f5a1e37a8e03d0d Mon Sep 17 00:00:00 2001 From: fairywreath Date: Tue, 14 Jan 2025 20:20:09 -0500 Subject: [PATCH 14/18] Merge remote-tracking branch 'origin/master' --- cmake/SlangTarget.cmake | 12 +- docs/cuda-target.md | 11 + docs/user-guide/03-convenience-features.md | 8 +- external/slang-rhi | 2 +- source/compiler-core/slang-nvrtc-compiler.cpp | 257 +++++++++++++----- source/slang/hlsl.meta.slang | 16 +- source/slang/slang-ast-modifier.h | 20 +- source/slang/slang-check-impl.h | 2 + source/slang/slang-check-modifier.cpp | 108 ++++++-- source/slang/slang-diagnostic-defs.h | 8 +- source/slang/slang-emit-c-like.cpp | 40 ++- source/slang/slang-emit-c-like.h | 13 +- source/slang/slang-emit-glsl.cpp | 16 +- source/slang/slang-emit-spirv.cpp | 55 ++-- source/slang/slang-ir-autodiff-rev.cpp | 10 +- .../slang-ir-collect-global-uniforms.cpp | 10 + source/slang/slang-ir-insts.h | 11 +- .../slang-ir-legalize-varying-params.cpp | 16 +- source/slang/slang-ir-simplify-cfg.cpp | 12 +- source/slang/slang-ir-specialize.cpp | 45 +++ 
.../slang-ir-translate-glsl-global-var.cpp | 17 +- source/slang/slang-ir-util.cpp | 13 + source/slang/slang-ir-util.h | 2 + source/slang/slang-lower-to-ir.cpp | 46 +++- source/slang/slang-parser.cpp | 11 +- source/slang/slang-reflection-api.cpp | 20 +- source/slang/slang.cpp | 3 +- tests/autodiff/out-parameters-2.slang | 49 ++++ tests/bugs/simplify-if-else.slang | 26 ++ .../diagnostics/missing-return.slang.expected | 4 +- tests/glsl/compute-shader-layout-id.slang | 19 ++ tests/spirv/spec-constant-numthreads.slang | 35 +++ 32 files changed, 755 insertions(+), 162 deletions(-) create mode 100644 tests/autodiff/out-parameters-2.slang create mode 100644 tests/bugs/simplify-if-else.slang create mode 100644 tests/glsl/compute-shader-layout-id.slang create mode 100644 tests/spirv/spec-constant-numthreads.slang diff --git a/cmake/SlangTarget.cmake b/cmake/SlangTarget.cmake index 45e7cf1e1d..eae5cf35e4 100644 --- a/cmake/SlangTarget.cmake +++ b/cmake/SlangTarget.cmake @@ -505,10 +505,14 @@ function(slang_add_target dir type) endif() install( TARGETS ${target} ${export_args} - ARCHIVE DESTINATION ${archive_subdir} ${ARGN} - LIBRARY DESTINATION ${library_subdir} ${ARGN} - RUNTIME DESTINATION ${runtime_subdir} ${ARGN} - PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} ${ARGN} + ARCHIVE DESTINATION ${archive_subdir} + ${ARGN} + LIBRARY DESTINATION ${library_subdir} + ${ARGN} + RUNTIME DESTINATION ${runtime_subdir} + ${ARGN} + PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} + ${ARGN} ) endmacro() diff --git a/docs/cuda-target.md b/docs/cuda-target.md index a80dc59f9c..241f253fbe 100644 --- a/docs/cuda-target.md +++ b/docs/cuda-target.md @@ -301,6 +301,17 @@ There is potential to calculate the lane id using the [numthreads] markup in Sla * Intrinsics which only work in pixel shaders + QuadXXXX intrinsics +OptiX Support +============= + +Slang supports OptiX for raytracing. 
To compile raytracing programs, NVRTC must have access to the `optix.h` and dependent files that are typically distributed as part of the OptiX SDK. When Slang detects the use of raytracing in source, it will define `SLANG_CUDA_ENABLE_OPTIX` when `slang-cuda-prelude.h` is included. This will in turn try to include `optix.h`. + +Slang tries several mechanisms to locate `optix.h` when NVRTC is initiated. The first mechanism is to look in the include paths that are passed to Slang. If `optix.h` can be found in one of these paths, no more searching will be performed. + +If this fails, the default OptiX SDK install locations are searched. On Windows this is `%{PROGRAMDATA}\NVIDIA Corporation\OptiX SDK X.X.X\include`. On Linux this is `${HOME}/NVIDIA-OptiX-SDK-X.X.X-suffix/include`. + +If OptiX headers cannot be found, compilation will fail. + Limitations =========== diff --git a/docs/user-guide/03-convenience-features.md b/docs/user-guide/03-convenience-features.md index e6b337eed1..29e8fd2aaa 100644 --- a/docs/user-guide/03-convenience-features.md +++ b/docs/user-guide/03-convenience-features.md @@ -149,7 +149,7 @@ int rs = foo.staticMethod(a,b); ### Mutability of member function -For GPU performance considerations, the `this` argument in a member function is immutable by default. If you modify the content in `this` argument, the modification will be discarded after the call and does not affect the input object. If you intend to define a member function that mutates the object, use `[mutating]` attribute on the member function as shown in the following example. +For GPU performance considerations, the `this` argument in a member function is immutable by default. Attempting to modify `this` will result in a compile error. If you intend to define a member function that mutates the object, use the `[mutating]` attribute on the member function as shown in the following example.
```hlsl struct Foo @@ -159,14 +159,14 @@ struct Foo [mutating] void setCount(int x) { count = x; } - void setCount2(int x) { count = x; } + // This would fail to compile. + // void setCount2(int x) { count = x; } } void test() { Foo f; - f.setCount(1); // f.count is 1 after the call. - f.setCount2(2); // f.count is still 1 after the call. + f.setCount(1); // Compiles } ``` diff --git a/external/slang-rhi b/external/slang-rhi index 19bc575bc1..d1f2718165 160000 --- a/external/slang-rhi +++ b/external/slang-rhi @@ -1 +1 @@ -Subproject commit 19bc575bc193e92210649d6d84ac202b199b29af +Subproject commit d1f2718165d0d540c8fc1eacf20b9edd2d6faac0 diff --git a/source/compiler-core/slang-nvrtc-compiler.cpp b/source/compiler-core/slang-nvrtc-compiler.cpp index c5ccc8e23a..0042ad7085 100644 --- a/source/compiler-core/slang-nvrtc-compiler.cpp +++ b/source/compiler-core/slang-nvrtc-compiler.cpp @@ -127,11 +127,14 @@ class NVRTCDownstreamCompiler : public DownstreamCompilerBase nvrtcProgram m_program; }; - SlangResult _findIncludePath(String& outIncludePath); + SlangResult _findCUDAIncludePath(String& outIncludePath); + SlangResult _getCUDAIncludePath(String& outIncludePath); - SlangResult _getIncludePath(String& outIncludePath); + SlangResult _findOptixIncludePath(String& outIncludePath); + SlangResult _getOptixIncludePath(String& outIncludePath); SlangResult _maybeAddHalfSupport(const CompileOptions& options, CommandLine& ioCmdLine); + SlangResult _maybeAddOptixSupport(const CompileOptions& options, CommandLine& ioCmdLine); #define SLANG_NVTRC_MEMBER_FUNCS(ret, name, params) ret(*m_##name) params; @@ -140,9 +143,16 @@ class NVRTCDownstreamCompiler : public DownstreamCompilerBase // Holds list of paths passed in where cuda_fp16.h is found. Does *NOT* include cuda_fp16.h. List m_cudaFp16FoundPaths; - bool m_includeSearched = false; + bool m_cudaIncludeSearched = false; // Holds location of where include (for cuda_fp16.h) is found. 
- String m_includePath; + String m_cudaIncludePath; + + // Holds list of paths passed in where optix.h is found. Does *NOT* include optix.h. + List m_optixFoundPaths; + + bool m_optixIncludeSearched = false; + // Holds location of where include (for optix.h) is found. + String m_optixIncludePath; ComPtr m_sharedLibrary; }; @@ -602,21 +612,8 @@ static SlangResult _findNVRTC(NVRTCPathVisitor& visitor) } static const UnownedStringSlice g_fp16HeaderName = UnownedStringSlice::fromLiteral("cuda_fp16.h"); +static const UnownedStringSlice g_optixHeaderName = UnownedStringSlice::fromLiteral("optix.h"); -SlangResult NVRTCDownstreamCompiler::_getIncludePath(String& outPath) -{ - if (!m_includeSearched) - { - m_includeSearched = true; - - SLANG_ASSERT(m_includePath.getLength() == 0); - - _findIncludePath(m_includePath); - } - - outPath = m_includePath; - return m_includePath.getLength() ? SLANG_OK : SLANG_E_NOT_FOUND; -} SlangResult _findFileInIncludePath( const String& path, @@ -650,7 +647,7 @@ SlangResult _findFileInIncludePath( return SLANG_E_NOT_FOUND; } -SlangResult NVRTCDownstreamCompiler::_findIncludePath(String& outPath) +SlangResult NVRTCDownstreamCompiler::_findCUDAIncludePath(String& outPath) { outPath = String(); @@ -711,6 +708,130 @@ SlangResult NVRTCDownstreamCompiler::_findIncludePath(String& outPath) return SLANG_E_NOT_FOUND; } +SlangResult NVRTCDownstreamCompiler::_getCUDAIncludePath(String& outPath) +{ + if (!m_cudaIncludeSearched) + { + m_cudaIncludeSearched = true; + + SLANG_ASSERT(m_cudaIncludePath.getLength() == 0); + + _findCUDAIncludePath(m_cudaIncludePath); + } + + outPath = m_cudaIncludePath; + return m_cudaIncludePath.getLength() ? 
SLANG_OK : SLANG_E_NOT_FOUND; +} + +SlangResult NVRTCDownstreamCompiler::_findOptixIncludePath(String& outPath) +{ + outPath = String(); + + List rootPaths; + +#if SLANG_WINDOWS_FAMILY + const char* searchPattern = "OptiX SDK *"; + StringBuilder builder; + if (SLANG_SUCCEEDED(PlatformUtil::getEnvironmentVariable( + UnownedStringSlice::fromLiteral("PROGRAMDATA"), + builder))) + { + rootPaths.add(Path::combine(builder, "NVIDIA Corporation")); + } +#else + const char* searchPattern = "NVIDIA-OptiX-SDK-*"; + StringBuilder builder; + if (SLANG_SUCCEEDED( + PlatformUtil::getEnvironmentVariable(UnownedStringSlice::fromLiteral("HOME"), builder))) + { + rootPaths.add(builder); + } +#endif + + struct OptixHeaders + { + String path; + SemanticVersion version; + }; + + // Visitor to find Optix headers. + struct Visitor : public Path::Visitor + { + const String& rootPath; + List& optixPaths; + Visitor(const String& rootPath, List& optixPaths) + : rootPath(rootPath), optixPaths(optixPaths) + { + } + void accept(Path::Type type, const UnownedStringSlice& path) SLANG_OVERRIDE + { + if (type != Path::Type::Directory) + return; + + OptixHeaders optixPath; +#if SLANG_WINDOWS_FAMILY + // Paths are expected to look like ".\OptiX SDK X.X.X" + auto versionString = path.subString(path.lastIndexOf(' ') + 1, path.getLength()); +#else + // Paths are expected to look like "./NVIDIA-OptiX-SDK-X.X.X-suffix" + auto versionString = path.subString(0, path.lastIndexOf('-')); + versionString = + versionString.subString(path.lastIndexOf('-') + 1, versionString.getLength()); +#endif + if (SLANG_SUCCEEDED(SemanticVersion::parse(versionString, '.', optixPath.version))) + { + optixPath.path = Path::combine(Path::combine(rootPath, path), "include"); + String optixHeader = Path::combine(optixPath.path, g_optixHeaderName); + if (File::exists(optixHeader)) + { + optixPaths.add(optixPath); + } + } + } + }; + + List optixPaths; + + for (const String& rootPath : rootPaths) + { + Visitor visitor(rootPath, 
optixPaths); + Path::find(rootPath, searchPattern, &visitor); + } + + // Find newest version + const OptixHeaders* newest = nullptr; + for (Index i = 0; i < optixPaths.getCount(); ++i) + { + if (!newest || optixPaths[i].version > newest->version) + { + newest = &optixPaths[i]; + } + } + + if (newest) + { + outPath = newest->path; + return SLANG_OK; + } + + return SLANG_E_NOT_FOUND; +} + +SlangResult NVRTCDownstreamCompiler::_getOptixIncludePath(String& outPath) +{ + if (!m_optixIncludeSearched) + { + m_optixIncludeSearched = true; + + SLANG_ASSERT(m_optixIncludePath.getLength() == 0); + + _findOptixIncludePath(m_optixIncludePath); + } + + outPath = m_optixIncludePath; + return m_optixIncludePath.getLength() ? SLANG_OK : SLANG_E_NOT_FOUND; +} + SlangResult NVRTCDownstreamCompiler::_maybeAddHalfSupport( const DownstreamCompileOptions& options, CommandLine& ioCmdLine) @@ -747,7 +868,7 @@ SlangResult NVRTCDownstreamCompiler::_maybeAddHalfSupport( } String includePath; - SLANG_RETURN_ON_FAIL(_getIncludePath(includePath)); + SLANG_RETURN_ON_FAIL(_getCUDAIncludePath(includePath)); // Add the found include path ioCmdLine.addArg("-I"); @@ -758,6 +879,48 @@ SlangResult NVRTCDownstreamCompiler::_maybeAddHalfSupport( return SLANG_OK; } +SlangResult NVRTCDownstreamCompiler::_maybeAddOptixSupport( + const DownstreamCompileOptions& options, + CommandLine& ioCmdLine) +{ + // First check if we know if one of the include paths contains optix.h + for (const auto& includePath : options.includePaths) + { + if (m_optixFoundPaths.indexOf(includePath) >= 0) + { + // Okay we have an include path that we know works. 
+ // Just need to enable OptiX in prelude + ioCmdLine.addArg("-DSLANG_CUDA_ENABLE_OPTIX"); + return SLANG_OK; + } + } + + // Let's see if one of the paths finds optix.h + for (const auto& curIncludePath : options.includePaths) + { + const String includePath = asString(curIncludePath); + const String checkPath = Path::combine(includePath, g_optixHeaderName); + if (File::exists(checkPath)) + { + m_optixFoundPaths.add(includePath); + // Just need to enable OptiX in prelude + ioCmdLine.addArg("-DSLANG_CUDA_ENABLE_OPTIX"); + return SLANG_OK; + } + } + + String includePath; + SLANG_RETURN_ON_FAIL(_getOptixIncludePath(includePath)); + + // Add the found include path + ioCmdLine.addArg("-I"); + ioCmdLine.addArg(includePath); + + ioCmdLine.addArg("-DSLANG_CUDA_ENABLE_OPTIX"); + + return SLANG_OK; +} + SlangResult NVRTCDownstreamCompiler::compile( const DownstreamCompileOptions& inOptions, IArtifact** outArtifact) @@ -780,6 +943,9 @@ SlangResult NVRTCDownstreamCompiler::compile( CommandLine cmdLine; + // --dopt option is only available in CUDA 11.7 and later + bool hasDoptOption = m_desc.version >= SemanticVersion(11, 7); + switch (options.debugInfoType) { case DebugInfoType::None: @@ -789,12 +955,20 @@ SlangResult NVRTCDownstreamCompiler::compile( default: { cmdLine.addArg("--device-debug"); + if (hasDoptOption) + { + cmdLine.addArg("--dopt=on"); + } break; } case DebugInfoType::Maximal: { cmdLine.addArg("--device-debug"); cmdLine.addArg("--generate-line-info"); + if (hasDoptOption) + { + cmdLine.addArg("--dopt=on"); + } break; } } @@ -910,48 +1084,7 @@ SlangResult NVRTCDownstreamCompiler::compile( // if (options.pipelineType == PipelineType::RayTracing) { - // The device-side OptiX API is accessed through a constellation - // of headers provided by the OptiX SDK, so we need to set an - // include path for the compile that makes those visible. 
- // - // TODO: The OptiX SDK installer doesn't set any kind of environment - // variable to indicate where the SDK was installed, so we seemingly - // need to probe paths instead. The form of the path will differ - // betwene Windows and Unix-y platforms, and we will need some kind - // of approach to probe multiple versiosn and use the latest. - // - // HACK: For now I'm using the fixed path for the most recent SDK - // release on Windows. This means that OptiX cross-compilation will - // only "work" on a subset of platforms, but that doesn't matter - // for now since it doesn't really "work" at all. - // - cmdLine.addArg("-I"); - cmdLine.addArg("C:/ProgramData/NVIDIA Corporation/OptiX SDK 7.0.0/include/"); - - // The OptiX headers in turn `#include ` and expect that - // to work. We could try to also add in an include path from the CUDA - // SDK (which seems to provide a `stddef.h` in the most recent version), - // but using that version doesn't seem to work (and also bakes in a - // requirement that the user have the CUDA SDK installed in addition - // to the OptiX SDK). - // - // Instead, we will rely on the NVRTC feature that lets us set up - // memory buffers to be used as include files by the we compile. - // We will define a dummy `stddef.h` that includes the bare minimum - // lines required to get the OptiX headers to compile without complaint. - // - // TODO: Confirm that the `LP64` definition here is actually needed. - // - headerIncludeNames.add("stddef.h"); - headers.add("#pragma once\n" - "#define LP64\n"); - - // Finally, we want the CUDA prelude to be able to react to whether - // or not OptiX is required (most notably by `#include`ing the appropriate - // header(s)), so we will insert a preprocessor define to indicate - // the requirement. 
- // - cmdLine.addArg("-DSLANG_CUDA_ENABLE_OPTIX"); + SLANG_RETURN_ON_FAIL(_maybeAddOptixSupport(options, cmdLine)); } // Add any compiler specific options diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 7964e26d8d..11c4ab6f45 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -20932,6 +20932,8 @@ struct ConstBufferPointer // new aliased bindings for each distinct cast type. // +//@public: + /// Represent the kind of a descriptor type. enum DescriptorKind { @@ -21048,8 +21050,18 @@ ${{{{ } }}}} -/// Represents a bindless resource handle. A bindless resource handle is always a concrete type and can be +/// Represents a bindless handle to a descriptor. A descriptor handle is always an ordinary data type and can be /// declared in any memory location. +/// @remarks Opaque descriptor types such as textures (`Texture2D` etc.), `SamplerState` and buffers (e.g. `StructuredBuffer`) +/// can have undefined size and data representation on many targets. On platforms such as Vulkan and D3D12, descriptors are +/// communicated to the shader code by calling the host side API to write the descriptor into a descriptor set or table, instead +/// of directly writing bytes into an ordinary GPU accessible buffer. As a result, opaque handle types cannot be used in places +/// that refer to an ordinary buffer location, such as element types of a `StructuredBuffer`. +/// However, a `DescriptorHandle` stores a handle (or address) to the actual descriptor, and is always an ordinary data type +/// that can be manipulated directly in the shader code. This gives the developer the flexibility to embed and pass around descriptor +/// parameters throughout the code, to enable cleaner modular designs. +/// See [User Guide](https://shader-slang.com/slang/user-guide/convenience-features.html#descriptorhandle-for-bindless-descriptor-access) +/// for more information on how to use `DescriptorHandle` in your code.
__magic_type(DescriptorHandleType) __intrinsic_type($(kIROp_DescriptorHandleType)) struct DescriptorHandle : IComparable @@ -21140,6 +21152,8 @@ extern T getDescriptorFromHandle(DescriptorHandle handle __intrinsic_op($(kIROp_NonUniformResourceIndex)) DescriptorHandle nonuniform(DescriptorHandle ptr); +//@hidden: + __glsl_version(450) __glsl_extension(GL_ARB_shader_clock) [require(glsl_spirv, GL_ARB_shader_clock)] diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index f5dd86df15..ee29750a6a 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -973,9 +973,14 @@ class GLSLLayoutLocalSizeAttribute : public Attribute // // TODO: These should be accessors that use the // ordinary `args` list, rather than side data. - IntVal* x; - IntVal* y; - IntVal* z; + IntVal* extents[3]; + + bool axisIsSpecConstId[3]; + + // References to specialization constants, for defining the number of + // threads with them. If set, the corresponding axis is set to nullptr + // above. + DeclRef specConstExtents[3]; }; class GLSLLayoutDerivativeGroupQuadAttribute : public Attribute @@ -1038,9 +1043,12 @@ class NumThreadsAttribute : public Attribute // // TODO: These should be accessors that use the // ordinary `args` list, rather than side data. - IntVal* x; - IntVal* y; - IntVal* z; + IntVal* extents[3]; + + // References to specialization constants, for defining the number of + // threads with them. If set, the corresponding axis is set to nullptr + // above. 
+ DeclRef specConstExtents[3]; }; class WaveSizeAttribute : public Attribute diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index b3e30dbc23..3ef1e8f3be 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1656,6 +1656,8 @@ struct SemanticsVisitor : public SemanticsContext void visitModifier(Modifier*); + DeclRef tryGetIntSpecializationConstant(Expr* expr); + AttributeDecl* lookUpAttributeDecl(Name* attributeName, Scope* scope); bool hasIntArgs(Attribute* attr, int numArgs); diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index 3723c98f86..6e451b5cf9 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -114,6 +114,36 @@ void SemanticsVisitor::visitModifier(Modifier*) // Do nothing with modifiers for now } +DeclRef SemanticsVisitor::tryGetIntSpecializationConstant(Expr* expr) +{ + // First type-check the expression as normal + expr = CheckExpr(expr); + + if (IsErrorExpr(expr)) + return DeclRef(); + + if (!isScalarIntegerType(expr->type)) + return DeclRef(); + + auto specConstVar = as(expr); + if (!specConstVar || !specConstVar->declRef) + return DeclRef(); + + auto decl = specConstVar->declRef.getDecl(); + if (!decl) + return DeclRef(); + + for (auto modifier : decl->modifiers) + { + if (as(modifier) || as(modifier)) + { + return specConstVar->declRef.as(); + } + } + + return DeclRef(); +} + static bool _isDeclAllowedAsAttribute(DeclRef declRef) { if (as(declRef.getDecl())) @@ -350,8 +380,6 @@ Modifier* SemanticsVisitor::validateAttribute( { SLANG_ASSERT(attr->args.getCount() == 3); - IntVal* values[3]; - for (int i = 0; i < 3; ++i) { IntVal* value = nullptr; @@ -359,6 +387,14 @@ Modifier* SemanticsVisitor::validateAttribute( auto arg = attr->args[i]; if (arg) { + auto specConstDecl = tryGetIntSpecializationConstant(arg); + if (specConstDecl) + { + numThreadsAttr->extents[i] = nullptr; + 
numThreadsAttr->specConstExtents[i] = specConstDecl; + continue; + } + auto intValue = checkLinkTimeConstantIntVal(arg); if (!intValue) { @@ -390,12 +426,8 @@ Modifier* SemanticsVisitor::validateAttribute( { value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); } - values[i] = value; + numThreadsAttr->extents[i] = value; } - - numThreadsAttr->x = values[0]; - numThreadsAttr->y = values[1]; - numThreadsAttr->z = values[2]; } else if (auto waveSizeAttr = as(attr)) { @@ -1831,15 +1863,24 @@ Modifier* SemanticsVisitor::checkModifier( { SLANG_ASSERT(attr->args.getCount() == 3); - IntVal* values[3]; + // GLSLLayoutLocalSizeAttribute is always attached to an EmptyDecl. + auto decl = as(syntaxNode); + SLANG_ASSERT(decl); for (int i = 0; i < 3; ++i) { - IntVal* value = nullptr; + attr->extents[i] = nullptr; auto arg = attr->args[i]; if (arg) { + auto specConstDecl = tryGetIntSpecializationConstant(arg); + if (specConstDecl) + { + attr->specConstExtents[i] = specConstDecl; + continue; + } + auto intValue = checkConstantIntVal(arg); if (!intValue) { @@ -1847,7 +1888,45 @@ Modifier* SemanticsVisitor::checkModifier( } if (auto cintVal = as(intValue)) { - if (cintVal->getValue() < 1) + if (attr->axisIsSpecConstId[i]) + { + // This integer should actually be a reference to a + // specialization constant with this ID. + Int specConstId = cintVal->getValue(); + + for (auto member : decl->parentDecl->members) + { + auto constantId = member->findModifier(); + if (constantId) + { + SLANG_ASSERT(constantId->args.getCount() == 1); + auto id = checkConstantIntVal(constantId->args[0]); + if (id->getValue() == specConstId) + { + attr->specConstExtents[i] = + DeclRef(member->getDefaultDeclRef()); + break; + } + } + } + + // If not found, we need to create a new specialization + // constant with this ID. 
+ if (!attr->specConstExtents[i]) + { + auto specConstVarDecl = getASTBuilder()->create(); + auto constantIdModifier = + getASTBuilder()->create(); + constantIdModifier->location = (int32_t)specConstId; + specConstVarDecl->type.type = getASTBuilder()->getIntType(); + addModifier(specConstVarDecl, constantIdModifier); + decl->parentDecl->addMember(specConstVarDecl); + attr->specConstExtents[i] = + DeclRef(specConstVarDecl->getDefaultDeclRef()); + } + continue; + } + else if (cintVal->getValue() < 1) { getSink()->diagnose( attr, @@ -1856,18 +1935,13 @@ Modifier* SemanticsVisitor::checkModifier( return nullptr; } } - value = intValue; + attr->extents[i] = intValue; } else { - value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); + attr->extents[i] = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); } - values[i] = value; } - - attr->x = values[0]; - attr->y = values[1]; - attr->z = values[2]; } // Default behavior is to leave things as they are, diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 821a895bc7..d86cd8be2a 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -2060,7 +2060,7 @@ DIAGNOSTIC( DIAGNOSTIC(41000, Warning, unreachableCode, "unreachable code detected") DIAGNOSTIC(41001, Error, recursiveType, "type '$0' contains cyclic reference to itself.") -DIAGNOSTIC(41010, Warning, missingReturn, "control flow may reach end of non-'void' function") +DIAGNOSTIC(41010, Warning, missingReturn, "non-void function does not return in all cases") DIAGNOSTIC( 41011, Error, @@ -2459,6 +2459,12 @@ DIAGNOSTIC( Error, unsupportedTargetIntrinsic, "intrinsic operation '$0' is not supported for the current target.") +DIAGNOSTIC( + 55205, + Error, + unsupportedSpecializationConstantForNumThreads, + "Specialization constants are not supported in the 'numthreads' attribute for the current " + "target.") DIAGNOSTIC( 56001, Error, diff --git a/source/slang/slang-emit-c-like.cpp 
b/source/slang/slang-emit-c-like.cpp index 7b51495e2b..d3a9359ff2 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -295,14 +295,48 @@ void CLikeSourceEmitter::emitSimpleType(IRType* type) } -/* static */ IRNumThreadsDecoration* CLikeSourceEmitter::getComputeThreadGroupSize( +IRNumThreadsDecoration* CLikeSourceEmitter::getComputeThreadGroupSize( IRFunc* func, Int outNumThreads[kThreadGroupAxisCount]) +{ + Int specializationConstantIds[kThreadGroupAxisCount]; + IRNumThreadsDecoration* decor = + getComputeThreadGroupSize(func, outNumThreads, specializationConstantIds); + + for (auto id : specializationConstantIds) + { + if (id >= 0) + { + getSink()->diagnose(decor, Diagnostics::unsupportedSpecializationConstantForNumThreads); + break; + } + } + return decor; +} + +/* static */ IRNumThreadsDecoration* CLikeSourceEmitter::getComputeThreadGroupSize( + IRFunc* func, + Int outNumThreads[kThreadGroupAxisCount], + Int outSpecializationConstantIds[kThreadGroupAxisCount]) { IRNumThreadsDecoration* decor = func->findDecoration(); - for (int i = 0; i < 3; ++i) + for (int i = 0; i < kThreadGroupAxisCount; ++i) { - outNumThreads[i] = decor ? Int(getIntVal(decor->getOperand(i))) : 1; + if (!decor) + { + outNumThreads[i] = 1; + outSpecializationConstantIds[i] = -1; + } + else if (auto specConst = as(decor->getOperand(i))) + { + outNumThreads[i] = 1; + outSpecializationConstantIds[i] = getSpecializationConstantId(specConst); + } + else + { + outNumThreads[i] = Int(getIntVal(decor->getOperand(i))); + outSpecializationConstantIds[i] = -1; + } } return decor; } diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index e5080f731b..1354b7cbd8 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -500,11 +500,20 @@ class CLikeSourceEmitter : public SourceEmitterBase /// different. 
Returns an empty slice if not a built in type static UnownedStringSlice getDefaultBuiltinTypeName(IROp op); - /// Finds the IRNumThreadsDecoration and gets the size from that or sets all dimensions to 1 - static IRNumThreadsDecoration* getComputeThreadGroupSize( + /// Finds the IRNumThreadsDecoration and gets the size from that or sets all + /// dimensions to 1 + IRNumThreadsDecoration* getComputeThreadGroupSize( IRFunc* func, Int outNumThreads[kThreadGroupAxisCount]); + /// Finds the IRNumThreadsDecoration and gets the size from that or sets all + /// dimensions to 1. If specialization constants are used for an axis, their + /// IDs is reported in non-negative entries of outSpecializationConstantIds. + static IRNumThreadsDecoration* getComputeThreadGroupSize( + IRFunc* func, + Int outNumThreads[kThreadGroupAxisCount], + Int outSpecializationConstantIds[kThreadGroupAxisCount]); + /// Finds the IRWaveSizeDecoration and gets the size from that. static IRWaveSizeDecoration* getComputeWaveSize(IRFunc* func, Int* outWaveSize); diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index 23fff37acb..0dab07cfce 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -1335,7 +1335,8 @@ void GLSLSourceEmitter::emitEntryPointAttributesImpl( auto emitLocalSizeLayout = [&]() { Int sizeAlongAxis[kThreadGroupAxisCount]; - getComputeThreadGroupSize(irFunc, sizeAlongAxis); + Int specializationConstantIds[kThreadGroupAxisCount]; + getComputeThreadGroupSize(irFunc, sizeAlongAxis, specializationConstantIds); m_writer->emit("layout("); char const* axes[] = {"x", "y", "z"}; @@ -1345,8 +1346,17 @@ void GLSLSourceEmitter::emitEntryPointAttributesImpl( m_writer->emit(", "); m_writer->emit("local_size_"); m_writer->emit(axes[ii]); - m_writer->emit(" = "); - m_writer->emit(sizeAlongAxis[ii]); + + if (specializationConstantIds[ii] >= 0) + { + m_writer->emit("_id = "); + m_writer->emit(specializationConstantIds[ii]); + } + else + { + 
m_writer->emit(" = "); + m_writer->emit(sizeAlongAxis[ii]); + } } m_writer->emit(") in;\n"); }; diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 068e1563ca..2cf84a8540 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -4353,23 +4353,36 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex // [3.6. Execution Mode]: LocalSize case kIROp_NumThreadsDecoration: { - // TODO: The `LocalSize` execution mode option requires - // literal values for the X,Y,Z thread-group sizes. - // There is a `LocalSizeId` variant that takes ``s - // for those sizes, and we should consider using that - // and requiring the appropriate capabilities - // if any of the operands to the decoration are not - // literals (in a future where we support non-literals - // in those positions in the Slang IR). - // auto numThreads = cast(decoration); - requireSPIRVExecutionMode( - decoration, - dstID, - SpvExecutionModeLocalSize, - SpvLiteralInteger::from32(int32_t(numThreads->getX()->getValue())), - SpvLiteralInteger::from32(int32_t(numThreads->getY()->getValue())), - SpvLiteralInteger::from32(int32_t(numThreads->getZ()->getValue()))); + if (numThreads->getXSpecConst() || numThreads->getYSpecConst() || + numThreads->getZSpecConst()) + { + // If any of the dimensions needs an ID, we need to emit + // all dimensions as an ID due to how LocalSizeId works. + int32_t ids[3]; + for (int i = 0; i < 3; ++i) + ids[i] = ensureInst(numThreads->getOperand(i))->id; + + // LocalSizeId is supported from SPIR-V 1.2 onwards without + // any extra capabilities. 
+ requireSPIRVExecutionMode( + decoration, + dstID, + SpvExecutionModeLocalSizeId, + SpvLiteralInteger::from32(int32_t(ids[0])), + SpvLiteralInteger::from32(int32_t(ids[1])), + SpvLiteralInteger::from32(int32_t(ids[2]))); + } + else + { + requireSPIRVExecutionMode( + decoration, + dstID, + SpvExecutionModeLocalSize, + SpvLiteralInteger::from32(int32_t(numThreads->getX()->getValue())), + SpvLiteralInteger::from32(int32_t(numThreads->getY()->getValue())), + SpvLiteralInteger::from32(int32_t(numThreads->getZ()->getValue()))); + } } break; case kIROp_MaxVertexCountDecoration: @@ -7977,10 +7990,18 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex { if (m_executionModes[entryPoint].add(executionMode)) { + SpvOp execModeOp = SpvOpExecutionMode; + if (executionMode == SpvExecutionModeLocalSizeId || + executionMode == SpvExecutionModeLocalSizeHintId || + executionMode == SpvExecutionModeSubgroupsPerWorkgroupId) + { + execModeOp = SpvOpExecutionModeId; + } + emitInst( getSection(SpvLogicalSectionID::ExecutionModes), parentInst, - SpvOpExecutionMode, + execModeOp, entryPoint, executionMode, ops...); diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 65ce69877f..3237ba3b26 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -528,10 +528,12 @@ InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRF // If primal parameter is mutable, we need to pass in a temp var. auto tempVar = builder.emitVar(primalParamPtrType->getValueType()); - // We also need to setup the initial value of the temp var, otherwise - // the temp var will be uninitialized which could cause undefined behavior - // in the primal function. 
- builder.emitStore(tempVar, primalArg); + // If the parameter is not a pure 'out' param, we also need to setup the initial + // value of the temp var, otherwise the temp var will be uninitialized which could + // cause undefined behavior in the primal function. + // + if (!as(primalParamType)) + builder.emitStore(tempVar, primalArg); primalArgs.add(tempVar); } diff --git a/source/slang/slang-ir-collect-global-uniforms.cpp b/source/slang/slang-ir-collect-global-uniforms.cpp index 1c833a2948..372ef298e7 100644 --- a/source/slang/slang-ir-collect-global-uniforms.cpp +++ b/source/slang/slang-ir-collect-global-uniforms.cpp @@ -279,6 +279,16 @@ struct CollectGlobalUniformParametersContext continue; } + // NumThreadsDecoration may sometimes be the user for a global + // parameter. This occurs when the parameter was supposed to be + // a specialization constant, but isn't due to that not being + // supported for the target. These can be skipped here and + // diagnosed later. + if (as(user)) + { + continue; + } + // For each use site for the global parameter, we will // insert new code right before the instruction that uses // the parameter. 
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index a58c2e900c..f46586aa2b 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -570,6 +570,7 @@ struct IRInstanceDecoration : IRDecoration IRIntLit* getCount() { return cast(getOperand(0)); } }; +struct IRGlobalParam; struct IRNumThreadsDecoration : IRDecoration { enum @@ -578,11 +579,13 @@ struct IRNumThreadsDecoration : IRDecoration }; IR_LEAF_ISA(NumThreadsDecoration) - IRIntLit* getX() { return cast(getOperand(0)); } - IRIntLit* getY() { return cast(getOperand(1)); } - IRIntLit* getZ() { return cast(getOperand(2)); } + IRIntLit* getX() { return as(getOperand(0)); } + IRIntLit* getY() { return as(getOperand(1)); } + IRIntLit* getZ() { return as(getOperand(2)); } - IRIntLit* getExtentAlongAxis(int axis) { return cast(getOperand(axis)); } + IRGlobalParam* getXSpecConst() { return as(getOperand(0)); } + IRGlobalParam* getYSpecConst() { return as(getOperand(1)); } + IRGlobalParam* getZSpecConst() { return as(getOperand(2)); } }; struct IRWaveSizeDecoration : IRDecoration diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index e267e8343c..6840197721 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -190,7 +190,7 @@ IRInst* emitCalcGroupExtents(IRBuilder& builder, IRFunc* entryPoint, IRVectorTyp for (int axis = 0; axis < kAxisCount; axis++) { - auto litValue = as(numThreadsDecor->getExtentAlongAxis(axis)); + auto litValue = as(numThreadsDecor->getOperand(axis)); if (!litValue) return nullptr; @@ -1434,6 +1434,20 @@ struct CPUEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegalize // groupExtents = emitCalcGroupExtents(builder, m_entryPointFunc, uint3Type); + if (!groupExtents) + { + m_sink->diagnose( + m_entryPointFunc, + Diagnostics::unsupportedSpecializationConstantForNumThreads); + + // Fill in placeholder 
values. + static const int kAxisCount = 3; + IRInst* groupExtentAlongAxis[kAxisCount] = {}; + for (int axis = 0; axis < kAxisCount; axis++) + groupExtentAlongAxis[axis] = builder.getIntValue(uint3Type->getElementType(), 1); + groupExtents = builder.emitMakeVector(uint3Type, kAxisCount, groupExtentAlongAxis); + } + dispatchThreadID = emitCalcDispatchThreadID(builder, uint3Type, groupID, groupThreadID, groupExtents); diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp index 90d30dcc77..68d79617a8 100644 --- a/source/slang/slang-ir-simplify-cfg.cpp +++ b/source/slang/slang-ir-simplify-cfg.cpp @@ -490,11 +490,19 @@ static bool trySimplifyIfElse(IRBuilder& builder, IRIfElse* ifElseInst) bool isFalseBranchTrivial = false; if (isTrivialIfElse(ifElseInst, isTrueBranchTrivial, isFalseBranchTrivial)) { - // If both branches of `if-else` are trivial jumps into after block, + // If either branch of `if-else` is a trivial jump into after block, // we can get rid of the entire conditional branch and replace it // with a jump into the after block. 
- if (auto termInst = as(ifElseInst->getTrueBlock()->getTerminator())) + IRUnconditionalBranch* termInst = + as(ifElseInst->getTrueBlock()->getTerminator()); + if (!termInst || (termInst->getTargetBlock() != ifElseInst->getAfterBlock())) { + termInst = as(ifElseInst->getFalseBlock()->getTerminator()); + } + + if (termInst) + { + SLANG_ASSERT(termInst->getTargetBlock() == ifElseInst->getAfterBlock()); List args; for (UInt i = 0; i < termInst->getArgCount(); i++) args.add(termInst->getArg(i)); diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 40cd40758a..a9b0d44121 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -71,6 +71,42 @@ struct SpecializationContext module->getContainerPool().free(&cleanInsts); } + bool isUnsimplifiedArithmeticInst(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + case kIROp_Div: + case kIROp_Neg: + case kIROp_Not: + case kIROp_Eql: + case kIROp_Neq: + case kIROp_Leq: + case kIROp_Geq: + case kIROp_Less: + case kIROp_IRem: + case kIROp_FRem: + case kIROp_Greater: + case kIROp_Lsh: + case kIROp_Rsh: + case kIROp_BitAnd: + case kIROp_BitOr: + case kIROp_BitXor: + case kIROp_BitNot: + case kIROp_BitCast: + case kIROp_CastIntToFloat: + case kIROp_CastFloatToInt: + case kIROp_IntCast: + case kIROp_FloatCast: + case kIROp_Select: + return true; + default: + return false; + } + } + // An instruction is then fully specialized if and only // if it is in our set. // @@ -133,6 +169,14 @@ struct SpecializationContext return areAllOperandsFullySpecialized(inst); } + if (isUnsimplifiedArithmeticInst(inst)) + { + // For arithmetic insts, we want to wait for simplification before specialization, + // since different insts can simplify to the same value. + // + return false; + } + // The default case is that a global value is always specialized. 
if (inst->getParent() == module->getModuleInst()) { @@ -1092,6 +1136,7 @@ struct SpecializationContext { this->changed = true; eliminateDeadCode(module->getModuleInst()); + applySparseConditionalConstantPropagationForGlobalScope(this->module, this->sink); } // Once the work list has gone dry, we should have the invariant diff --git a/source/slang/slang-ir-translate-glsl-global-var.cpp b/source/slang/slang-ir-translate-glsl-global-var.cpp index a44e16a7ce..077cdb98d0 100644 --- a/source/slang/slang-ir-translate-glsl-global-var.cpp +++ b/source/slang/slang-ir-translate-glsl-global-var.cpp @@ -282,10 +282,11 @@ struct GlobalVarTranslationContext if (!numthreadsDecor) return; builder.setInsertBefore(use->getUser()); - IRInst* values[] = { - numthreadsDecor->getExtentAlongAxis(0), - numthreadsDecor->getExtentAlongAxis(1), - numthreadsDecor->getExtentAlongAxis(2)}; + IRInst* values[3] = { + numthreadsDecor->getOperand(0), + numthreadsDecor->getOperand(1), + numthreadsDecor->getOperand(2)}; + auto workgroupSize = builder.emitMakeVector( builder.getVectorType(builder.getIntType(), 3), 3, @@ -328,10 +329,10 @@ struct GlobalVarTranslationContext if (!firstBlock) continue; builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); - IRInst* args[] = { - numthreadsDecor->getExtentAlongAxis(0), - numthreadsDecor->getExtentAlongAxis(1), - numthreadsDecor->getExtentAlongAxis(2)}; + IRInst* args[3] = { + numthreadsDecor->getOperand(0), + numthreadsDecor->getOperand(1), + numthreadsDecor->getOperand(2)}; auto workgroupSize = builder.emitMakeVector(workgroupSizeInst->getFullType(), 3, args); builder.emitStore(globalVar, workgroupSize); diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index c753600a7c..d05e1db7d4 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -1973,4 +1973,17 @@ IRType* getIRVectorBaseType(IRType* type) return as(type)->getElementType(); } +Int getSpecializationConstantId(IRGlobalParam* param) +{ + 
auto layout = findVarLayout(param); + if (!layout) + return 0; + + auto offset = layout->findOffsetAttr(LayoutResourceKind::SpecializationConstant); + if (!offset) + return 0; + + return offset->getOffset(); +} + } // namespace Slang diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index e23aeb6180..666ac71c03 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -373,6 +373,8 @@ inline bool isSPIRV(CodeGenTarget codeGenTarget) int getIRVectorElementSize(IRType* type); IRType* getIRVectorBaseType(IRType* type); +Int getSpecializationConstantId(IRGlobalParam* param); + } // namespace Slang #endif diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index e82fc03fde..0863457198 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -7625,12 +7625,29 @@ struct DeclLoweringVisitor : DeclVisitor { verifyComputeDerivativeGroupModifier = true; getAllEntryPointsNoOverride(entryPoints); + + LoweredValInfo extents[3]; + + for (int i = 0; i < 3; ++i) + { + extents[i] = layoutLocalSizeAttr->specConstExtents[i] + ? 
emitDeclRef( + context, + layoutLocalSizeAttr->specConstExtents[i], + lowerType( + context, + getType( + context->astBuilder, + layoutLocalSizeAttr->specConstExtents[i]))) + : lowerVal(context, layoutLocalSizeAttr->extents[i]); + } + for (auto d : entryPoints) as(getBuilder()->addNumThreadsDecoration( d, - getSimpleVal(context, lowerVal(context, layoutLocalSizeAttr->x)), - getSimpleVal(context, lowerVal(context, layoutLocalSizeAttr->y)), - getSimpleVal(context, lowerVal(context, layoutLocalSizeAttr->z)))); + getSimpleVal(context, extents[0]), + getSimpleVal(context, extents[1]), + getSimpleVal(context, extents[2]))); } else if (as(modifier)) { @@ -10336,11 +10353,28 @@ struct DeclLoweringVisitor : DeclVisitor } else if (auto numThreadsAttr = as(modifier)) { + LoweredValInfo extents[3]; + + for (int i = 0; i < 3; ++i) + { + extents[i] = numThreadsAttr->specConstExtents[i] + ? emitDeclRef( + context, + numThreadsAttr->specConstExtents[i], + lowerType( + context, + getType( + context->astBuilder, + numThreadsAttr->specConstExtents[i]))) + : lowerVal(context, numThreadsAttr->extents[i]); + } + numThreadsDecor = as(getBuilder()->addNumThreadsDecoration( irFunc, - getSimpleVal(context, lowerVal(context, numThreadsAttr->x)), - getSimpleVal(context, lowerVal(context, numThreadsAttr->y)), - getSimpleVal(context, lowerVal(context, numThreadsAttr->z)))); + getSimpleVal(context, extents[0]), + getSimpleVal(context, extents[1]), + getSimpleVal(context, extents[2]))); + numThreadsDecor->sourceLoc = numThreadsAttr->loc; } else if (auto waveSizeAttr = as(modifier)) { diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index c275a868b5..6ae41a2eb9 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -8437,7 +8437,9 @@ static NodeBase* parseLayoutModifier(Parser* parser, void* /*userData*/) int localSizeIndex = -1; if (nameText.startsWith(localSizePrefix) && - nameText.getLength() == SLANG_COUNT_OF(localSizePrefix) - 1 + 1) + 
(nameText.getLength() == SLANG_COUNT_OF(localSizePrefix) - 1 + 1 || + (nameText.endsWith("_id") && + (nameText.getLength() == SLANG_COUNT_OF(localSizePrefix) - 1 + 4)))) { char lastChar = nameText[SLANG_COUNT_OF(localSizePrefix) - 1]; localSizeIndex = (lastChar >= 'x' && lastChar <= 'z') ? (lastChar - 'x') : -1; @@ -8451,6 +8453,8 @@ static NodeBase* parseLayoutModifier(Parser* parser, void* /*userData*/) numThreadsAttrib->args.setCount(3); for (auto& i : numThreadsAttrib->args) i = nullptr; + for (auto& b : numThreadsAttrib->axisIsSpecConstId) + b = false; // Just mark the loc and name from the first in the list numThreadsAttrib->keywordName = getName(parser, "numthreads"); @@ -8467,6 +8471,11 @@ static NodeBase* parseLayoutModifier(Parser* parser, void* /*userData*/) } numThreadsAttrib->args[localSizeIndex] = expr; + + // We can't resolve the specialization constant declaration + // here, because it may not even exist. IDs pointing to unnamed + // specialization constants are allowed in GLSL. 
+ numThreadsAttrib->axisIsSpecConstId[localSizeIndex] = nameText.endsWith("_id"); } } else if (nameText == "derivative_group_quadsNV") diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index d235c82703..d1adfedc0b 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -4033,18 +4033,14 @@ SLANG_API void spReflectionEntryPoint_getComputeThreadGroupSize( auto numThreadsAttribute = entryPointFunc.getDecl()->findModifier(); if (numThreadsAttribute) { - if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->x)) - sizeAlongAxis[0] = (SlangUInt)cint->getValue(); - else if (numThreadsAttribute->x) - sizeAlongAxis[0] = 0; - if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->y)) - sizeAlongAxis[1] = (SlangUInt)cint->getValue(); - else if (numThreadsAttribute->y) - sizeAlongAxis[1] = 0; - if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->z)) - sizeAlongAxis[2] = (SlangUInt)cint->getValue(); - else if (numThreadsAttribute->z) - sizeAlongAxis[2] = 0; + for (int i = 0; i < 3; ++i) + { + if (auto cint = + entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->extents[i])) + sizeAlongAxis[i] = (SlangUInt)cint->getValue(); + else if (numThreadsAttribute->extents[i]) + sizeAlongAxis[i] = 0; + } } // diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 5ec1996581..efc1c6fd11 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -1493,7 +1493,8 @@ DeclRef Linkage::specializeWithArgTypes( DiagnosticSink* sink) { SemanticsVisitor visitor(getSemanticsForReflection()); - visitor = visitor.withSink(sink); + SemanticsVisitor::ExprLocalScope scope; + visitor = visitor.withSink(sink).withExprLocalScope(&scope); SLANG_AST_BUILDER_RAII(getASTBuilder()); diff --git a/tests/autodiff/out-parameters-2.slang b/tests/autodiff/out-parameters-2.slang new file mode 100644 index 0000000000..b4c4b07c61 --- 
/dev/null +++ b/tests/autodiff/out-parameters-2.slang @@ -0,0 +1,49 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +typedef DifferentialPair dpfloat; + +struct Foo : IDifferentiable +{ + float a; + int b; +} + +[PreferCheckpoint] +float k() +{ + return outputBuffer[3] + 1; +} + +[Differentiable] +void h(float x, float y, out Foo result) +{ + float p = no_diff k(); + float m = x + y + p; + float n = x - y; + float r = m * n + 2 * x * y; + + result = {r, 2}; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + float x = 2.0; + float y = 3.5; + float dx = 1.0; + float dy = 0.5; + + dpfloat dresult; + dpfloat dpx = diffPair(x); + dpfloat dpy = diffPair(y); + Foo.Differential dFoo; + dFoo.a = 1.0; + bwd_diff(h)(dpx, dpy, dFoo); + + outputBuffer[0] = dpx.d; // CHECK: 12.0 + outputBuffer[1] = dpy.d; // CHECK: -4.0 +} \ No newline at end of file diff --git a/tests/bugs/simplify-if-else.slang b/tests/bugs/simplify-if-else.slang new file mode 100644 index 0000000000..8719a15995 --- /dev/null +++ b/tests/bugs/simplify-if-else.slang @@ -0,0 +1,26 @@ +//TEST:SIMPLE(filecheck=CHECK): -stage compute -entry computeMain -target hlsl +//CHECK: computeMain + +//TEST_INPUT:ubuffer(data=[9 9 9 9], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + vector vvv = vector(0); + float32_t ret = 0.0f; + if (vvv.y < 1.0f) + { + ret = 1.0f; + } + else + { + if (vvv.y > 1.0f && outputBuffer[3] == 3) + { + ret = 0.0f; + } else { + if (true) {} + } + } + outputBuffer[int(dispatchThreadID.x)] = int(ret); +} diff --git a/tests/diagnostics/missing-return.slang.expected 
b/tests/diagnostics/missing-return.slang.expected index e41e756ff4..7626665241 100644 --- a/tests/diagnostics/missing-return.slang.expected +++ b/tests/diagnostics/missing-return.slang.expected @@ -1,9 +1,9 @@ result code = 0 standard error = { -tests/diagnostics/missing-return.slang(7): warning 41010: control flow may reach end of non-'void' function +tests/diagnostics/missing-return.slang(7): warning 41010: non-void function does not return in all cases int bad(int a, int b) ^~~ -tests/diagnostics/missing-return.slang(14): warning 41010: control flow may reach end of non-'void' function +tests/diagnostics/missing-return.slang(14): warning 41010: non-void function does not return in all cases int alsoBad(int a, int b) ^~~~~~~ } diff --git a/tests/glsl/compute-shader-layout-id.slang b/tests/glsl/compute-shader-layout-id.slang new file mode 100644 index 0000000000..bee8137d82 --- /dev/null +++ b/tests/glsl/compute-shader-layout-id.slang @@ -0,0 +1,19 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv -stage compute -entry main -allow-glsl +#version 450 + +[vk::constant_id(1)] +const int constValue1 = 0; + +[vk::constant_id(2)] +const int constValue3 = 5; + +// CHECK-DAG: OpExecutionModeId %main LocalSizeId %[[C0:[0-9A-Za-z_]+]] %[[C1:[0-9A-Za-z_]+]] %[[C2:[0-9A-Za-z_]+]] +// CHECK-DAG: OpDecorate %[[C0]] SpecId 1 +// CHECK-DAG: OpDecorate %[[C1]] SpecId 0 +// CHECK-DAG: OpDecorate %[[C2]] SpecId 2 + +layout(local_size_x_id = 1, local_size_y_id = 0, local_size_z = constValue3) in; +void main() +{ +} + diff --git a/tests/spirv/spec-constant-numthreads.slang b/tests/spirv/spec-constant-numthreads.slang new file mode 100644 index 0000000000..5c133219cf --- /dev/null +++ b/tests/spirv/spec-constant-numthreads.slang @@ -0,0 +1,35 @@ +//TEST:SIMPLE(filecheck=GLSL): -target glsl -allow-glsl +//TEST:SIMPLE(filecheck=GLSL): -target glsl +//TEST:SIMPLE(filecheck=CHECK): -target spirv -allow-glsl +//TEST:SIMPLE(filecheck=CHECK): -target spirv + +// CHECK-DAG: OpExecutionModeId 
%computeMain1 LocalSizeId %[[C0:[0-9A-Za-z_]+]] %[[C1:[0-9A-Za-z_]+]] %[[C2:[0-9A-Za-z_]+]] +// CHECK-DAG: OpDecorate %[[C0]] SpecId 1 +// CHECK-DAG: OpDecorate %[[C1]] SpecId 0 +// CHECK-DAG: %[[C2]] = OpConstant %int 4 +// CHECK-DAG: OpStore %{{.*}} %[[C0]] +// CHECK-DAG: OpStore %{{.*}} %[[C1]] +// CHECK-DAG: OpStore %{{.*}} %[[C2]] + +// GLSL-DAG: layout(constant_id = 1) +// GLSL-DAG: int constValue0_0 = 0; +// GLSL-DAG: layout(constant_id = 0) +// GLSL-DAG: int constValue1_0 = 0; +// GLSL-DAG: layout(local_size_x_id = 1, local_size_y_id = 0, local_size_z = 4) in; + +[vk::specialization_constant] +const int constValue0 = 0; + +[vk::constant_id(0)] +const int constValue1 = 0; + +RWStructuredBuffer outputBuffer; + +[numthreads(constValue0, constValue1, 4)] +void computeMain1() +{ + int3 size = WorkgroupSize(); + outputBuffer[0] = size.x; + outputBuffer[1] = size.y; + outputBuffer[2] = size.z; +} From 3b9e6f53cee2e6076ac2b7a0d015a1ed2cbbd627 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Tue, 14 Jan 2025 20:22:36 -0500 Subject: [PATCH 15/18] apply metal spec const thread count changes --- .../slang-ir-legalize-varying-params.cpp | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index 6840197721..4858d6e31c 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -3378,12 +3378,25 @@ class LegalizeMetalEntryPointContext : public LegalizeShaderEntryPointContext IRBuilder svBuilder(builder.getModule()); svBuilder.setInsertBefore(entryPoint.entryPointFunc->getFirstOrdinaryInst()); - auto computeExtent = emitCalcGroupExtents( - svBuilder, - entryPoint.entryPointFunc, - builder.getVectorType( - builder.getUIntType(), - builder.getIntValue(builder.getIntType(), 3))); + auto uint3Type = builder.getVectorType( + builder.getUIntType(), + 
builder.getIntValue(builder.getIntType(), 3)); + auto computeExtent = + emitCalcGroupExtents(svBuilder, entryPoint.entryPointFunc, uint3Type); + if (!computeExtent) + { + m_sink->diagnose( + entryPoint.entryPointFunc, + Diagnostics::unsupportedSpecializationConstantForNumThreads); + + // Fill in placeholder values. + static const int kAxisCount = 3; + IRInst* groupExtentAlongAxis[kAxisCount] = {}; + for (int axis = 0; axis < kAxisCount; axis++) + groupExtentAlongAxis[axis] = + builder.getIntValue(uint3Type->getElementType(), 1); + computeExtent = builder.emitMakeVector(uint3Type, kAxisCount, groupExtentAlongAxis); + } auto groupIndexCalc = emitCalcGroupIndex( svBuilder, entryPointToGroupThreadId[entryPoint.entryPointFunc], From 6bf9d4e501a73edb765cb1cc785d586016ccb3d4 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Tue, 14 Jan 2025 20:29:04 -0500 Subject: [PATCH 16/18] Revert "apply metal spec const thread count changes" This reverts commit 3b9e6f53cee2e6076ac2b7a0d015a1ed2cbbd627. --- .../slang-ir-legalize-varying-params.cpp | 25 +++++-------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index 4858d6e31c..6840197721 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -3378,25 +3378,12 @@ class LegalizeMetalEntryPointContext : public LegalizeShaderEntryPointContext IRBuilder svBuilder(builder.getModule()); svBuilder.setInsertBefore(entryPoint.entryPointFunc->getFirstOrdinaryInst()); - auto uint3Type = builder.getVectorType( - builder.getUIntType(), - builder.getIntValue(builder.getIntType(), 3)); - auto computeExtent = - emitCalcGroupExtents(svBuilder, entryPoint.entryPointFunc, uint3Type); - if (!computeExtent) - { - m_sink->diagnose( - entryPoint.entryPointFunc, - Diagnostics::unsupportedSpecializationConstantForNumThreads); - - // Fill in placeholder values. 
- static const int kAxisCount = 3; - IRInst* groupExtentAlongAxis[kAxisCount] = {}; - for (int axis = 0; axis < kAxisCount; axis++) - groupExtentAlongAxis[axis] = - builder.getIntValue(uint3Type->getElementType(), 1); - computeExtent = builder.emitMakeVector(uint3Type, kAxisCount, groupExtentAlongAxis); - } + auto computeExtent = emitCalcGroupExtents( + svBuilder, + entryPoint.entryPointFunc, + builder.getVectorType( + builder.getUIntType(), + builder.getIntValue(builder.getIntType(), 3))); auto groupIndexCalc = emitCalcGroupIndex( svBuilder, entryPointToGroupThreadId[entryPoint.entryPointFunc], From 3817bb98c5571a98f4a8179bb6135418240e8811 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Tue, 14 Jan 2025 20:29:13 -0500 Subject: [PATCH 17/18] Revert "Merge remote-tracking branch 'origin/master'" This reverts commit 99869d573a46dadeb24445405f5a1e37a8e03d0d. --- cmake/SlangTarget.cmake | 12 +- docs/cuda-target.md | 11 - docs/user-guide/03-convenience-features.md | 8 +- external/slang-rhi | 2 +- source/compiler-core/slang-nvrtc-compiler.cpp | 257 +++++------------- source/slang/hlsl.meta.slang | 16 +- source/slang/slang-ast-modifier.h | 20 +- source/slang/slang-check-impl.h | 2 - source/slang/slang-check-modifier.cpp | 108 ++------ source/slang/slang-diagnostic-defs.h | 8 +- source/slang/slang-emit-c-like.cpp | 40 +-- source/slang/slang-emit-c-like.h | 13 +- source/slang/slang-emit-glsl.cpp | 16 +- source/slang/slang-emit-spirv.cpp | 55 ++-- source/slang/slang-ir-autodiff-rev.cpp | 10 +- .../slang-ir-collect-global-uniforms.cpp | 10 - source/slang/slang-ir-insts.h | 11 +- .../slang-ir-legalize-varying-params.cpp | 16 +- source/slang/slang-ir-simplify-cfg.cpp | 12 +- source/slang/slang-ir-specialize.cpp | 45 --- .../slang-ir-translate-glsl-global-var.cpp | 17 +- source/slang/slang-ir-util.cpp | 13 - source/slang/slang-ir-util.h | 2 - source/slang/slang-lower-to-ir.cpp | 46 +--- source/slang/slang-parser.cpp | 11 +- source/slang/slang-reflection-api.cpp | 20 +- 
source/slang/slang.cpp | 3 +- tests/autodiff/out-parameters-2.slang | 49 ---- tests/bugs/simplify-if-else.slang | 26 -- .../diagnostics/missing-return.slang.expected | 4 +- tests/glsl/compute-shader-layout-id.slang | 19 -- tests/spirv/spec-constant-numthreads.slang | 35 --- 32 files changed, 162 insertions(+), 755 deletions(-) delete mode 100644 tests/autodiff/out-parameters-2.slang delete mode 100644 tests/bugs/simplify-if-else.slang delete mode 100644 tests/glsl/compute-shader-layout-id.slang delete mode 100644 tests/spirv/spec-constant-numthreads.slang diff --git a/cmake/SlangTarget.cmake b/cmake/SlangTarget.cmake index eae5cf35e4..45e7cf1e1d 100644 --- a/cmake/SlangTarget.cmake +++ b/cmake/SlangTarget.cmake @@ -505,14 +505,10 @@ function(slang_add_target dir type) endif() install( TARGETS ${target} ${export_args} - ARCHIVE DESTINATION ${archive_subdir} - ${ARGN} - LIBRARY DESTINATION ${library_subdir} - ${ARGN} - RUNTIME DESTINATION ${runtime_subdir} - ${ARGN} - PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} - ${ARGN} + ARCHIVE DESTINATION ${archive_subdir} ${ARGN} + LIBRARY DESTINATION ${library_subdir} ${ARGN} + RUNTIME DESTINATION ${runtime_subdir} ${ARGN} + PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} ${ARGN} ) endmacro() diff --git a/docs/cuda-target.md b/docs/cuda-target.md index 241f253fbe..a80dc59f9c 100644 --- a/docs/cuda-target.md +++ b/docs/cuda-target.md @@ -301,17 +301,6 @@ There is potential to calculate the lane id using the [numthreads] markup in Sla * Intrinsics which only work in pixel shaders + QuadXXXX intrinsics -OptiX Support -============= - -Slang supports OptiX for raytracing. To compile raytracing programs, NVRTC must have access to the `optix.h` and dependent files that are typically distributed as part of the OptiX SDK. When Slang detects the use of raytracing in source, it will define `SLANG_CUDA_ENABLE_OPTIX` when `slang-cuda-prelude.h` is included. This will in turn try to include `optix.h`. 
- -Slang tries several mechanisms to locate `optix.h` when NVRTC is initiated. The first mechanism is to look in the include paths that are passed to Slang. If `optix.h` can be found in one of these paths, no more searching will be performed. - -If this fails, the default OptiX SDK install locations are searched. On Windows this is `%{PROGRAMDATA}\NVIDIA Corporation\OptiX SDK X.X.X\include`. On Linux this is `${HOME}/NVIDIA-OptiX-SDK-X.X.X-suffix`. - -If OptiX headers cannot be found, compilation will fail. - Limitations =========== diff --git a/docs/user-guide/03-convenience-features.md b/docs/user-guide/03-convenience-features.md index 29e8fd2aaa..e6b337eed1 100644 --- a/docs/user-guide/03-convenience-features.md +++ b/docs/user-guide/03-convenience-features.md @@ -149,7 +149,7 @@ int rs = foo.staticMethod(a,b); ### Mutability of member function -For GPU performance considerations, the `this` argument in a member function is immutable by default. Attempting to modify `this` will result in a compile error. If you intend to define a member function that mutates the object, use `[mutating]` attribute on the member function as shown in the following example. +For GPU performance considerations, the `this` argument in a member function is immutable by default. If you modify the content in `this` argument, the modification will be discarded after the call and does not affect the input object. If you intend to define a member function that mutates the object, use `[mutating]` attribute on the member function as shown in the following example. ```hlsl struct Foo @@ -159,14 +159,14 @@ struct Foo [mutating] void setCount(int x) { count = x; } - // This would fail to compile. - // void setCount2(int x) { count = x; } + void setCount2(int x) { count = x; } } void test() { Foo f; - f.setCount(1); // Compiles + f.setCount(1); // f.count is 1 after the call. + f.setCount2(2); // f.count is still 1 after the call. 
} ``` diff --git a/external/slang-rhi b/external/slang-rhi index d1f2718165..19bc575bc1 160000 --- a/external/slang-rhi +++ b/external/slang-rhi @@ -1 +1 @@ -Subproject commit d1f2718165d0d540c8fc1eacf20b9edd2d6faac0 +Subproject commit 19bc575bc193e92210649d6d84ac202b199b29af diff --git a/source/compiler-core/slang-nvrtc-compiler.cpp b/source/compiler-core/slang-nvrtc-compiler.cpp index 0042ad7085..c5ccc8e23a 100644 --- a/source/compiler-core/slang-nvrtc-compiler.cpp +++ b/source/compiler-core/slang-nvrtc-compiler.cpp @@ -127,14 +127,11 @@ class NVRTCDownstreamCompiler : public DownstreamCompilerBase nvrtcProgram m_program; }; - SlangResult _findCUDAIncludePath(String& outIncludePath); - SlangResult _getCUDAIncludePath(String& outIncludePath); + SlangResult _findIncludePath(String& outIncludePath); - SlangResult _findOptixIncludePath(String& outIncludePath); - SlangResult _getOptixIncludePath(String& outIncludePath); + SlangResult _getIncludePath(String& outIncludePath); SlangResult _maybeAddHalfSupport(const CompileOptions& options, CommandLine& ioCmdLine); - SlangResult _maybeAddOptixSupport(const CompileOptions& options, CommandLine& ioCmdLine); #define SLANG_NVTRC_MEMBER_FUNCS(ret, name, params) ret(*m_##name) params; @@ -143,16 +140,9 @@ class NVRTCDownstreamCompiler : public DownstreamCompilerBase // Holds list of paths passed in where cuda_fp16.h is found. Does *NOT* include cuda_fp16.h. List m_cudaFp16FoundPaths; - bool m_cudaIncludeSearched = false; + bool m_includeSearched = false; // Holds location of where include (for cuda_fp16.h) is found. - String m_cudaIncludePath; - - // Holds list of paths passed in where optix.h is found. Does *NOT* include optix.h. - List m_optixFoundPaths; - - bool m_optixIncludeSearched = false; - // Holds location of where include (for optix.h) is found. 
- String m_optixIncludePath; + String m_includePath; ComPtr m_sharedLibrary; }; @@ -612,8 +602,21 @@ static SlangResult _findNVRTC(NVRTCPathVisitor& visitor) } static const UnownedStringSlice g_fp16HeaderName = UnownedStringSlice::fromLiteral("cuda_fp16.h"); -static const UnownedStringSlice g_optixHeaderName = UnownedStringSlice::fromLiteral("optix.h"); +SlangResult NVRTCDownstreamCompiler::_getIncludePath(String& outPath) +{ + if (!m_includeSearched) + { + m_includeSearched = true; + + SLANG_ASSERT(m_includePath.getLength() == 0); + + _findIncludePath(m_includePath); + } + + outPath = m_includePath; + return m_includePath.getLength() ? SLANG_OK : SLANG_E_NOT_FOUND; +} SlangResult _findFileInIncludePath( const String& path, @@ -647,7 +650,7 @@ SlangResult _findFileInIncludePath( return SLANG_E_NOT_FOUND; } -SlangResult NVRTCDownstreamCompiler::_findCUDAIncludePath(String& outPath) +SlangResult NVRTCDownstreamCompiler::_findIncludePath(String& outPath) { outPath = String(); @@ -708,130 +711,6 @@ SlangResult NVRTCDownstreamCompiler::_findCUDAIncludePath(String& outPath) return SLANG_E_NOT_FOUND; } -SlangResult NVRTCDownstreamCompiler::_getCUDAIncludePath(String& outPath) -{ - if (!m_cudaIncludeSearched) - { - m_cudaIncludeSearched = true; - - SLANG_ASSERT(m_cudaIncludePath.getLength() == 0); - - _findCUDAIncludePath(m_cudaIncludePath); - } - - outPath = m_cudaIncludePath; - return m_cudaIncludePath.getLength() ? 
SLANG_OK : SLANG_E_NOT_FOUND; -} - -SlangResult NVRTCDownstreamCompiler::_findOptixIncludePath(String& outPath) -{ - outPath = String(); - - List rootPaths; - -#if SLANG_WINDOWS_FAMILY - const char* searchPattern = "OptiX SDK *"; - StringBuilder builder; - if (SLANG_SUCCEEDED(PlatformUtil::getEnvironmentVariable( - UnownedStringSlice::fromLiteral("PROGRAMDATA"), - builder))) - { - rootPaths.add(Path::combine(builder, "NVIDIA Corporation")); - } -#else - const char* searchPattern = "NVIDIA-OptiX-SDK-*"; - StringBuilder builder; - if (SLANG_SUCCEEDED( - PlatformUtil::getEnvironmentVariable(UnownedStringSlice::fromLiteral("HOME"), builder))) - { - rootPaths.add(builder); - } -#endif - - struct OptixHeaders - { - String path; - SemanticVersion version; - }; - - // Visitor to find Optix headers. - struct Visitor : public Path::Visitor - { - const String& rootPath; - List& optixPaths; - Visitor(const String& rootPath, List& optixPaths) - : rootPath(rootPath), optixPaths(optixPaths) - { - } - void accept(Path::Type type, const UnownedStringSlice& path) SLANG_OVERRIDE - { - if (type != Path::Type::Directory) - return; - - OptixHeaders optixPath; -#if SLANG_WINDOWS_FAMILY - // Paths are expected to look like ".\OptiX SDK X.X.X" - auto versionString = path.subString(path.lastIndexOf(' ') + 1, path.getLength()); -#else - // Paths are expected to look like "./NVIDIA-OptiX-SDK-X.X.X-suffix" - auto versionString = path.subString(0, path.lastIndexOf('-')); - versionString = - versionString.subString(path.lastIndexOf('-') + 1, versionString.getLength()); -#endif - if (SLANG_SUCCEEDED(SemanticVersion::parse(versionString, '.', optixPath.version))) - { - optixPath.path = Path::combine(Path::combine(rootPath, path), "include"); - String optixHeader = Path::combine(optixPath.path, g_optixHeaderName); - if (File::exists(optixHeader)) - { - optixPaths.add(optixPath); - } - } - } - }; - - List optixPaths; - - for (const String& rootPath : rootPaths) - { - Visitor visitor(rootPath, 
optixPaths); - Path::find(rootPath, searchPattern, &visitor); - } - - // Find newest version - const OptixHeaders* newest = nullptr; - for (Index i = 0; i < optixPaths.getCount(); ++i) - { - if (!newest || optixPaths[i].version > newest->version) - { - newest = &optixPaths[i]; - } - } - - if (newest) - { - outPath = newest->path; - return SLANG_OK; - } - - return SLANG_E_NOT_FOUND; -} - -SlangResult NVRTCDownstreamCompiler::_getOptixIncludePath(String& outPath) -{ - if (!m_optixIncludeSearched) - { - m_optixIncludeSearched = true; - - SLANG_ASSERT(m_optixIncludePath.getLength() == 0); - - _findOptixIncludePath(m_optixIncludePath); - } - - outPath = m_optixIncludePath; - return m_optixIncludePath.getLength() ? SLANG_OK : SLANG_E_NOT_FOUND; -} - SlangResult NVRTCDownstreamCompiler::_maybeAddHalfSupport( const DownstreamCompileOptions& options, CommandLine& ioCmdLine) @@ -868,7 +747,7 @@ SlangResult NVRTCDownstreamCompiler::_maybeAddHalfSupport( } String includePath; - SLANG_RETURN_ON_FAIL(_getCUDAIncludePath(includePath)); + SLANG_RETURN_ON_FAIL(_getIncludePath(includePath)); // Add the found include path ioCmdLine.addArg("-I"); @@ -879,48 +758,6 @@ SlangResult NVRTCDownstreamCompiler::_maybeAddHalfSupport( return SLANG_OK; } -SlangResult NVRTCDownstreamCompiler::_maybeAddOptixSupport( - const DownstreamCompileOptions& options, - CommandLine& ioCmdLine) -{ - // First check if we know if one of the include paths contains optix.h - for (const auto& includePath : options.includePaths) - { - if (m_optixFoundPaths.indexOf(includePath) >= 0) - { - // Okay we have an include path that we know works. 
- // Just need to enable OptiX in prelude - ioCmdLine.addArg("-DSLANG_CUDA_ENABLE_OPTIX"); - return SLANG_OK; - } - } - - // Let's see if one of the paths finds optix.h - for (const auto& curIncludePath : options.includePaths) - { - const String includePath = asString(curIncludePath); - const String checkPath = Path::combine(includePath, g_optixHeaderName); - if (File::exists(checkPath)) - { - m_optixFoundPaths.add(includePath); - // Just need to enable OptiX in prelude - ioCmdLine.addArg("-DSLANG_CUDA_ENABLE_OPTIX"); - return SLANG_OK; - } - } - - String includePath; - SLANG_RETURN_ON_FAIL(_getOptixIncludePath(includePath)); - - // Add the found include path - ioCmdLine.addArg("-I"); - ioCmdLine.addArg(includePath); - - ioCmdLine.addArg("-DSLANG_CUDA_ENABLE_OPTIX"); - - return SLANG_OK; -} - SlangResult NVRTCDownstreamCompiler::compile( const DownstreamCompileOptions& inOptions, IArtifact** outArtifact) @@ -943,9 +780,6 @@ SlangResult NVRTCDownstreamCompiler::compile( CommandLine cmdLine; - // --dopt option is only available in CUDA 11.7 and later - bool hasDoptOption = m_desc.version >= SemanticVersion(11, 7); - switch (options.debugInfoType) { case DebugInfoType::None: @@ -955,20 +789,12 @@ SlangResult NVRTCDownstreamCompiler::compile( default: { cmdLine.addArg("--device-debug"); - if (hasDoptOption) - { - cmdLine.addArg("--dopt=on"); - } break; } case DebugInfoType::Maximal: { cmdLine.addArg("--device-debug"); cmdLine.addArg("--generate-line-info"); - if (hasDoptOption) - { - cmdLine.addArg("--dopt=on"); - } break; } } @@ -1084,7 +910,48 @@ SlangResult NVRTCDownstreamCompiler::compile( // if (options.pipelineType == PipelineType::RayTracing) { - SLANG_RETURN_ON_FAIL(_maybeAddOptixSupport(options, cmdLine)); + // The device-side OptiX API is accessed through a constellation + // of headers provided by the OptiX SDK, so we need to set an + // include path for the compile that makes those visible. 
+ // + // TODO: The OptiX SDK installer doesn't set any kind of environment + // variable to indicate where the SDK was installed, so we seemingly + // need to probe paths instead. The form of the path will differ + // betwene Windows and Unix-y platforms, and we will need some kind + // of approach to probe multiple versiosn and use the latest. + // + // HACK: For now I'm using the fixed path for the most recent SDK + // release on Windows. This means that OptiX cross-compilation will + // only "work" on a subset of platforms, but that doesn't matter + // for now since it doesn't really "work" at all. + // + cmdLine.addArg("-I"); + cmdLine.addArg("C:/ProgramData/NVIDIA Corporation/OptiX SDK 7.0.0/include/"); + + // The OptiX headers in turn `#include ` and expect that + // to work. We could try to also add in an include path from the CUDA + // SDK (which seems to provide a `stddef.h` in the most recent version), + // but using that version doesn't seem to work (and also bakes in a + // requirement that the user have the CUDA SDK installed in addition + // to the OptiX SDK). + // + // Instead, we will rely on the NVRTC feature that lets us set up + // memory buffers to be used as include files by the we compile. + // We will define a dummy `stddef.h` that includes the bare minimum + // lines required to get the OptiX headers to compile without complaint. + // + // TODO: Confirm that the `LP64` definition here is actually needed. + // + headerIncludeNames.add("stddef.h"); + headers.add("#pragma once\n" + "#define LP64\n"); + + // Finally, we want the CUDA prelude to be able to react to whether + // or not OptiX is required (most notably by `#include`ing the appropriate + // header(s)), so we will insert a preprocessor define to indicate + // the requirement. 
+ // + cmdLine.addArg("-DSLANG_CUDA_ENABLE_OPTIX"); } // Add any compiler specific options diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 11c4ab6f45..7964e26d8d 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -20932,8 +20932,6 @@ struct ConstBufferPointer // new aliased bindings for each distinct cast type. // -//@public: - /// Represent the kind of a descriptor type. enum DescriptorKind { @@ -21050,18 +21048,8 @@ ${{{{ } }}}} -/// Represents a bindless handle to a descriptor. A descriptor handle is always an ordinary data type and can be +/// Represents a bindless resource handle. A bindless resource handle is always a concrete type and can be /// declared in any memory location. -/// @remarks Opaque descriptor types such as textures(`Texture2D` etc.), `SamplerState` and buffers (e.g. `StructuredBuffer`) -/// can have undefined size and data representation on many targets. On platforms such as Vulkan and D3D12, descriptors are -/// communicated to the shader code by calling the host side API to write the descriptor into a descriptor set or table, instead -/// of directly writing bytes into an ordinary GPU accessible buffer. As a result, oapque handle types cannot be used in places -/// that refer to a ordinary buffer location, such as as element types of a `StructuredBuffer`. -/// However, a `DescriptorHandle` stores a handle (or address) to the actual descriptor, and is always an ordinary data type -/// that can be manipulated directly in the shader code. This gives the developer the flexibility to embed and pass around descriptor -/// parameters throughout the code, to enable cleaner modular designs. -/// See [User Guide](https://shader-slang.com/slang/user-guide/convenience-features.html#descriptorhandle-for-bindless-descriptor-access) -/// for more information on how to use `DescriptorHandle` in your code. 
__magic_type(DescriptorHandleType) __intrinsic_type($(kIROp_DescriptorHandleType)) struct DescriptorHandle : IComparable @@ -21152,8 +21140,6 @@ extern T getDescriptorFromHandle(DescriptorHandle handle __intrinsic_op($(kIROp_NonUniformResourceIndex)) DescriptorHandle nonuniform(DescriptorHandle ptr); -//@hidden: - __glsl_version(450) __glsl_extension(GL_ARB_shader_clock) [require(glsl_spirv, GL_ARB_shader_clock)] diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index ee29750a6a..f5dd86df15 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -973,14 +973,9 @@ class GLSLLayoutLocalSizeAttribute : public Attribute // // TODO: These should be accessors that use the // ordinary `args` list, rather than side data. - IntVal* extents[3]; - - bool axisIsSpecConstId[3]; - - // References to specialization constants, for defining the number of - // threads with them. If set, the corresponding axis is set to nullptr - // above. - DeclRef specConstExtents[3]; + IntVal* x; + IntVal* y; + IntVal* z; }; class GLSLLayoutDerivativeGroupQuadAttribute : public Attribute @@ -1043,12 +1038,9 @@ class NumThreadsAttribute : public Attribute // // TODO: These should be accessors that use the // ordinary `args` list, rather than side data. - IntVal* extents[3]; - - // References to specialization constants, for defining the number of - // threads with them. If set, the corresponding axis is set to nullptr - // above. 
- DeclRef specConstExtents[3]; + IntVal* x; + IntVal* y; + IntVal* z; }; class WaveSizeAttribute : public Attribute diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 3ef1e8f3be..b3e30dbc23 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1656,8 +1656,6 @@ struct SemanticsVisitor : public SemanticsContext void visitModifier(Modifier*); - DeclRef tryGetIntSpecializationConstant(Expr* expr); - AttributeDecl* lookUpAttributeDecl(Name* attributeName, Scope* scope); bool hasIntArgs(Attribute* attr, int numArgs); diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index 6e451b5cf9..3723c98f86 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -114,36 +114,6 @@ void SemanticsVisitor::visitModifier(Modifier*) // Do nothing with modifiers for now } -DeclRef SemanticsVisitor::tryGetIntSpecializationConstant(Expr* expr) -{ - // First type-check the expression as normal - expr = CheckExpr(expr); - - if (IsErrorExpr(expr)) - return DeclRef(); - - if (!isScalarIntegerType(expr->type)) - return DeclRef(); - - auto specConstVar = as(expr); - if (!specConstVar || !specConstVar->declRef) - return DeclRef(); - - auto decl = specConstVar->declRef.getDecl(); - if (!decl) - return DeclRef(); - - for (auto modifier : decl->modifiers) - { - if (as(modifier) || as(modifier)) - { - return specConstVar->declRef.as(); - } - } - - return DeclRef(); -} - static bool _isDeclAllowedAsAttribute(DeclRef declRef) { if (as(declRef.getDecl())) @@ -380,6 +350,8 @@ Modifier* SemanticsVisitor::validateAttribute( { SLANG_ASSERT(attr->args.getCount() == 3); + IntVal* values[3]; + for (int i = 0; i < 3; ++i) { IntVal* value = nullptr; @@ -387,14 +359,6 @@ Modifier* SemanticsVisitor::validateAttribute( auto arg = attr->args[i]; if (arg) { - auto specConstDecl = tryGetIntSpecializationConstant(arg); - if (specConstDecl) - { - numThreadsAttr->extents[i] = 
nullptr; - numThreadsAttr->specConstExtents[i] = specConstDecl; - continue; - } - auto intValue = checkLinkTimeConstantIntVal(arg); if (!intValue) { @@ -426,8 +390,12 @@ Modifier* SemanticsVisitor::validateAttribute( { value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); } - numThreadsAttr->extents[i] = value; + values[i] = value; } + + numThreadsAttr->x = values[0]; + numThreadsAttr->y = values[1]; + numThreadsAttr->z = values[2]; } else if (auto waveSizeAttr = as(attr)) { @@ -1863,24 +1831,15 @@ Modifier* SemanticsVisitor::checkModifier( { SLANG_ASSERT(attr->args.getCount() == 3); - // GLSLLayoutLocalSizeAttribute is always attached to an EmptyDecl. - auto decl = as(syntaxNode); - SLANG_ASSERT(decl); + IntVal* values[3]; for (int i = 0; i < 3; ++i) { - attr->extents[i] = nullptr; + IntVal* value = nullptr; auto arg = attr->args[i]; if (arg) { - auto specConstDecl = tryGetIntSpecializationConstant(arg); - if (specConstDecl) - { - attr->specConstExtents[i] = specConstDecl; - continue; - } - auto intValue = checkConstantIntVal(arg); if (!intValue) { @@ -1888,45 +1847,7 @@ Modifier* SemanticsVisitor::checkModifier( } if (auto cintVal = as(intValue)) { - if (attr->axisIsSpecConstId[i]) - { - // This integer should actually be a reference to a - // specialization constant with this ID. - Int specConstId = cintVal->getValue(); - - for (auto member : decl->parentDecl->members) - { - auto constantId = member->findModifier(); - if (constantId) - { - SLANG_ASSERT(constantId->args.getCount() == 1); - auto id = checkConstantIntVal(constantId->args[0]); - if (id->getValue() == specConstId) - { - attr->specConstExtents[i] = - DeclRef(member->getDefaultDeclRef()); - break; - } - } - } - - // If not found, we need to create a new specialization - // constant with this ID. 
- if (!attr->specConstExtents[i]) - { - auto specConstVarDecl = getASTBuilder()->create(); - auto constantIdModifier = - getASTBuilder()->create(); - constantIdModifier->location = (int32_t)specConstId; - specConstVarDecl->type.type = getASTBuilder()->getIntType(); - addModifier(specConstVarDecl, constantIdModifier); - decl->parentDecl->addMember(specConstVarDecl); - attr->specConstExtents[i] = - DeclRef(specConstVarDecl->getDefaultDeclRef()); - } - continue; - } - else if (cintVal->getValue() < 1) + if (cintVal->getValue() < 1) { getSink()->diagnose( attr, @@ -1935,13 +1856,18 @@ Modifier* SemanticsVisitor::checkModifier( return nullptr; } } - attr->extents[i] = intValue; + value = intValue; } else { - attr->extents[i] = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); + value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); } + values[i] = value; } + + attr->x = values[0]; + attr->y = values[1]; + attr->z = values[2]; } // Default behavior is to leave things as they are, diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index d86cd8be2a..821a895bc7 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -2060,7 +2060,7 @@ DIAGNOSTIC( DIAGNOSTIC(41000, Warning, unreachableCode, "unreachable code detected") DIAGNOSTIC(41001, Error, recursiveType, "type '$0' contains cyclic reference to itself.") -DIAGNOSTIC(41010, Warning, missingReturn, "non-void function does not return in all cases") +DIAGNOSTIC(41010, Warning, missingReturn, "control flow may reach end of non-'void' function") DIAGNOSTIC( 41011, Error, @@ -2459,12 +2459,6 @@ DIAGNOSTIC( Error, unsupportedTargetIntrinsic, "intrinsic operation '$0' is not supported for the current target.") -DIAGNOSTIC( - 55205, - Error, - unsupportedSpecializationConstantForNumThreads, - "Specialization constants are not supported in the 'numthreads' attribute for the current " - "target.") DIAGNOSTIC( 56001, Error, diff --git 
a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index d3a9359ff2..7b51495e2b 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -295,48 +295,14 @@ void CLikeSourceEmitter::emitSimpleType(IRType* type) } -IRNumThreadsDecoration* CLikeSourceEmitter::getComputeThreadGroupSize( - IRFunc* func, - Int outNumThreads[kThreadGroupAxisCount]) -{ - Int specializationConstantIds[kThreadGroupAxisCount]; - IRNumThreadsDecoration* decor = - getComputeThreadGroupSize(func, outNumThreads, specializationConstantIds); - - for (auto id : specializationConstantIds) - { - if (id >= 0) - { - getSink()->diagnose(decor, Diagnostics::unsupportedSpecializationConstantForNumThreads); - break; - } - } - return decor; -} - /* static */ IRNumThreadsDecoration* CLikeSourceEmitter::getComputeThreadGroupSize( IRFunc* func, - Int outNumThreads[kThreadGroupAxisCount], - Int outSpecializationConstantIds[kThreadGroupAxisCount]) + Int outNumThreads[kThreadGroupAxisCount]) { IRNumThreadsDecoration* decor = func->findDecoration(); - for (int i = 0; i < kThreadGroupAxisCount; ++i) + for (int i = 0; i < 3; ++i) { - if (!decor) - { - outNumThreads[i] = 1; - outSpecializationConstantIds[i] = -1; - } - else if (auto specConst = as(decor->getOperand(i))) - { - outNumThreads[i] = 1; - outSpecializationConstantIds[i] = getSpecializationConstantId(specConst); - } - else - { - outNumThreads[i] = Int(getIntVal(decor->getOperand(i))); - outSpecializationConstantIds[i] = -1; - } + outNumThreads[i] = decor ? Int(getIntVal(decor->getOperand(i))) : 1; } return decor; } diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index 1354b7cbd8..e5080f731b 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -500,19 +500,10 @@ class CLikeSourceEmitter : public SourceEmitterBase /// different. 
Returns an empty slice if not a built in type static UnownedStringSlice getDefaultBuiltinTypeName(IROp op); - /// Finds the IRNumThreadsDecoration and gets the size from that or sets all - /// dimensions to 1 - IRNumThreadsDecoration* getComputeThreadGroupSize( - IRFunc* func, - Int outNumThreads[kThreadGroupAxisCount]); - - /// Finds the IRNumThreadsDecoration and gets the size from that or sets all - /// dimensions to 1. If specialization constants are used for an axis, their - /// IDs is reported in non-negative entries of outSpecializationConstantIds. + /// Finds the IRNumThreadsDecoration and gets the size from that or sets all dimensions to 1 static IRNumThreadsDecoration* getComputeThreadGroupSize( IRFunc* func, - Int outNumThreads[kThreadGroupAxisCount], - Int outSpecializationConstantIds[kThreadGroupAxisCount]); + Int outNumThreads[kThreadGroupAxisCount]); /// Finds the IRWaveSizeDecoration and gets the size from that. static IRWaveSizeDecoration* getComputeWaveSize(IRFunc* func, Int* outWaveSize); diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index 0dab07cfce..23fff37acb 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -1335,8 +1335,7 @@ void GLSLSourceEmitter::emitEntryPointAttributesImpl( auto emitLocalSizeLayout = [&]() { Int sizeAlongAxis[kThreadGroupAxisCount]; - Int specializationConstantIds[kThreadGroupAxisCount]; - getComputeThreadGroupSize(irFunc, sizeAlongAxis, specializationConstantIds); + getComputeThreadGroupSize(irFunc, sizeAlongAxis); m_writer->emit("layout("); char const* axes[] = {"x", "y", "z"}; @@ -1346,17 +1345,8 @@ void GLSLSourceEmitter::emitEntryPointAttributesImpl( m_writer->emit(", "); m_writer->emit("local_size_"); m_writer->emit(axes[ii]); - - if (specializationConstantIds[ii] >= 0) - { - m_writer->emit("_id = "); - m_writer->emit(specializationConstantIds[ii]); - } - else - { - m_writer->emit(" = "); - m_writer->emit(sizeAlongAxis[ii]); - } + 
m_writer->emit(" = "); + m_writer->emit(sizeAlongAxis[ii]); } m_writer->emit(") in;\n"); }; diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 2cf84a8540..068e1563ca 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -4353,36 +4353,23 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex // [3.6. Execution Mode]: LocalSize case kIROp_NumThreadsDecoration: { + // TODO: The `LocalSize` execution mode option requires + // literal values for the X,Y,Z thread-group sizes. + // There is a `LocalSizeId` variant that takes ``s + // for those sizes, and we should consider using that + // and requiring the appropriate capabilities + // if any of the operands to the decoration are not + // literals (in a future where we support non-literals + // in those positions in the Slang IR). + // auto numThreads = cast(decoration); - if (numThreads->getXSpecConst() || numThreads->getYSpecConst() || - numThreads->getZSpecConst()) - { - // If any of the dimensions needs an ID, we need to emit - // all dimensions as an ID due to how LocalSizeId works. - int32_t ids[3]; - for (int i = 0; i < 3; ++i) - ids[i] = ensureInst(numThreads->getOperand(i))->id; - - // LocalSizeId is supported from SPIR-V 1.2 onwards without - // any extra capabilities. 
- requireSPIRVExecutionMode( - decoration, - dstID, - SpvExecutionModeLocalSizeId, - SpvLiteralInteger::from32(int32_t(ids[0])), - SpvLiteralInteger::from32(int32_t(ids[1])), - SpvLiteralInteger::from32(int32_t(ids[2]))); - } - else - { - requireSPIRVExecutionMode( - decoration, - dstID, - SpvExecutionModeLocalSize, - SpvLiteralInteger::from32(int32_t(numThreads->getX()->getValue())), - SpvLiteralInteger::from32(int32_t(numThreads->getY()->getValue())), - SpvLiteralInteger::from32(int32_t(numThreads->getZ()->getValue()))); - } + requireSPIRVExecutionMode( + decoration, + dstID, + SpvExecutionModeLocalSize, + SpvLiteralInteger::from32(int32_t(numThreads->getX()->getValue())), + SpvLiteralInteger::from32(int32_t(numThreads->getY()->getValue())), + SpvLiteralInteger::from32(int32_t(numThreads->getZ()->getValue()))); } break; case kIROp_MaxVertexCountDecoration: @@ -7990,18 +7977,10 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex { if (m_executionModes[entryPoint].add(executionMode)) { - SpvOp execModeOp = SpvOpExecutionMode; - if (executionMode == SpvExecutionModeLocalSizeId || - executionMode == SpvExecutionModeLocalSizeHintId || - executionMode == SpvExecutionModeSubgroupsPerWorkgroupId) - { - execModeOp = SpvOpExecutionModeId; - } - emitInst( getSection(SpvLogicalSectionID::ExecutionModes), parentInst, - execModeOp, + SpvOpExecutionMode, entryPoint, executionMode, ops...); diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 3237ba3b26..65ce69877f 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -528,12 +528,10 @@ InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRF // If primal parameter is mutable, we need to pass in a temp var. 
auto tempVar = builder.emitVar(primalParamPtrType->getValueType()); - // If the parameter is not a pure 'out' param, we also need to setup the initial - // value of the temp var, otherwise the temp var will be uninitialized which could - // cause undefined behavior in the primal function. - // - if (!as(primalParamType)) - builder.emitStore(tempVar, primalArg); + // We also need to setup the initial value of the temp var, otherwise + // the temp var will be uninitialized which could cause undefined behavior + // in the primal function. + builder.emitStore(tempVar, primalArg); primalArgs.add(tempVar); } diff --git a/source/slang/slang-ir-collect-global-uniforms.cpp b/source/slang/slang-ir-collect-global-uniforms.cpp index 372ef298e7..1c833a2948 100644 --- a/source/slang/slang-ir-collect-global-uniforms.cpp +++ b/source/slang/slang-ir-collect-global-uniforms.cpp @@ -279,16 +279,6 @@ struct CollectGlobalUniformParametersContext continue; } - // NumThreadsDecoration may sometimes be the user for a global - // parameter. This occurs when the parameter was supposed to be - // a specialization constant, but isn't due to that not being - // supported for the target. These can be skipped here and - // diagnosed later. - if (as(user)) - { - continue; - } - // For each use site for the global parameter, we will // insert new code right before the instruction that uses // the parameter. 
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index f46586aa2b..a58c2e900c 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -570,7 +570,6 @@ struct IRInstanceDecoration : IRDecoration IRIntLit* getCount() { return cast(getOperand(0)); } }; -struct IRGlobalParam; struct IRNumThreadsDecoration : IRDecoration { enum @@ -579,13 +578,11 @@ struct IRNumThreadsDecoration : IRDecoration }; IR_LEAF_ISA(NumThreadsDecoration) - IRIntLit* getX() { return as(getOperand(0)); } - IRIntLit* getY() { return as(getOperand(1)); } - IRIntLit* getZ() { return as(getOperand(2)); } + IRIntLit* getX() { return cast(getOperand(0)); } + IRIntLit* getY() { return cast(getOperand(1)); } + IRIntLit* getZ() { return cast(getOperand(2)); } - IRGlobalParam* getXSpecConst() { return as(getOperand(0)); } - IRGlobalParam* getYSpecConst() { return as(getOperand(1)); } - IRGlobalParam* getZSpecConst() { return as(getOperand(2)); } + IRIntLit* getExtentAlongAxis(int axis) { return cast(getOperand(axis)); } }; struct IRWaveSizeDecoration : IRDecoration diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index 6840197721..e267e8343c 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -190,7 +190,7 @@ IRInst* emitCalcGroupExtents(IRBuilder& builder, IRFunc* entryPoint, IRVectorTyp for (int axis = 0; axis < kAxisCount; axis++) { - auto litValue = as(numThreadsDecor->getOperand(axis)); + auto litValue = as(numThreadsDecor->getExtentAlongAxis(axis)); if (!litValue) return nullptr; @@ -1434,20 +1434,6 @@ struct CPUEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegalize // groupExtents = emitCalcGroupExtents(builder, m_entryPointFunc, uint3Type); - if (!groupExtents) - { - m_sink->diagnose( - m_entryPointFunc, - Diagnostics::unsupportedSpecializationConstantForNumThreads); - - // Fill in placeholder 
values. - static const int kAxisCount = 3; - IRInst* groupExtentAlongAxis[kAxisCount] = {}; - for (int axis = 0; axis < kAxisCount; axis++) - groupExtentAlongAxis[axis] = builder.getIntValue(uint3Type->getElementType(), 1); - groupExtents = builder.emitMakeVector(uint3Type, kAxisCount, groupExtentAlongAxis); - } - dispatchThreadID = emitCalcDispatchThreadID(builder, uint3Type, groupID, groupThreadID, groupExtents); diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp index 68d79617a8..90d30dcc77 100644 --- a/source/slang/slang-ir-simplify-cfg.cpp +++ b/source/slang/slang-ir-simplify-cfg.cpp @@ -490,19 +490,11 @@ static bool trySimplifyIfElse(IRBuilder& builder, IRIfElse* ifElseInst) bool isFalseBranchTrivial = false; if (isTrivialIfElse(ifElseInst, isTrueBranchTrivial, isFalseBranchTrivial)) { - // If either branch of `if-else` is a trivial jump into after block, + // If both branches of `if-else` are trivial jumps into after block, // we can get rid of the entire conditional branch and replace it // with a jump into the after block. 
- IRUnconditionalBranch* termInst = - as(ifElseInst->getTrueBlock()->getTerminator()); - if (!termInst || (termInst->getTargetBlock() != ifElseInst->getAfterBlock())) + if (auto termInst = as(ifElseInst->getTrueBlock()->getTerminator())) { - termInst = as(ifElseInst->getFalseBlock()->getTerminator()); - } - - if (termInst) - { - SLANG_ASSERT(termInst->getTargetBlock() == ifElseInst->getAfterBlock()); List args; for (UInt i = 0; i < termInst->getArgCount(); i++) args.add(termInst->getArg(i)); diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index a9b0d44121..40cd40758a 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -71,42 +71,6 @@ struct SpecializationContext module->getContainerPool().free(&cleanInsts); } - bool isUnsimplifiedArithmeticInst(IRInst* inst) - { - switch (inst->getOp()) - { - case kIROp_Add: - case kIROp_Sub: - case kIROp_Mul: - case kIROp_Div: - case kIROp_Neg: - case kIROp_Not: - case kIROp_Eql: - case kIROp_Neq: - case kIROp_Leq: - case kIROp_Geq: - case kIROp_Less: - case kIROp_IRem: - case kIROp_FRem: - case kIROp_Greater: - case kIROp_Lsh: - case kIROp_Rsh: - case kIROp_BitAnd: - case kIROp_BitOr: - case kIROp_BitXor: - case kIROp_BitNot: - case kIROp_BitCast: - case kIROp_CastIntToFloat: - case kIROp_CastFloatToInt: - case kIROp_IntCast: - case kIROp_FloatCast: - case kIROp_Select: - return true; - default: - return false; - } - } - // An instruction is then fully specialized if and only // if it is in our set. // @@ -169,14 +133,6 @@ struct SpecializationContext return areAllOperandsFullySpecialized(inst); } - if (isUnsimplifiedArithmeticInst(inst)) - { - // For arithmetic insts, we want to wait for simplification before specialization, - // since different insts can simplify to the same value. - // - return false; - } - // The default case is that a global value is always specialized. 
if (inst->getParent() == module->getModuleInst()) { @@ -1136,7 +1092,6 @@ struct SpecializationContext { this->changed = true; eliminateDeadCode(module->getModuleInst()); - applySparseConditionalConstantPropagationForGlobalScope(this->module, this->sink); } // Once the work list has gone dry, we should have the invariant diff --git a/source/slang/slang-ir-translate-glsl-global-var.cpp b/source/slang/slang-ir-translate-glsl-global-var.cpp index 077cdb98d0..a44e16a7ce 100644 --- a/source/slang/slang-ir-translate-glsl-global-var.cpp +++ b/source/slang/slang-ir-translate-glsl-global-var.cpp @@ -282,11 +282,10 @@ struct GlobalVarTranslationContext if (!numthreadsDecor) return; builder.setInsertBefore(use->getUser()); - IRInst* values[3] = { - numthreadsDecor->getOperand(0), - numthreadsDecor->getOperand(1), - numthreadsDecor->getOperand(2)}; - + IRInst* values[] = { + numthreadsDecor->getExtentAlongAxis(0), + numthreadsDecor->getExtentAlongAxis(1), + numthreadsDecor->getExtentAlongAxis(2)}; auto workgroupSize = builder.emitMakeVector( builder.getVectorType(builder.getIntType(), 3), 3, @@ -329,10 +328,10 @@ struct GlobalVarTranslationContext if (!firstBlock) continue; builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); - IRInst* args[3] = { - numthreadsDecor->getOperand(0), - numthreadsDecor->getOperand(1), - numthreadsDecor->getOperand(2)}; + IRInst* args[] = { + numthreadsDecor->getExtentAlongAxis(0), + numthreadsDecor->getExtentAlongAxis(1), + numthreadsDecor->getExtentAlongAxis(2)}; auto workgroupSize = builder.emitMakeVector(workgroupSizeInst->getFullType(), 3, args); builder.emitStore(globalVar, workgroupSize); diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index d05e1db7d4..c753600a7c 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -1973,17 +1973,4 @@ IRType* getIRVectorBaseType(IRType* type) return as(type)->getElementType(); } -Int getSpecializationConstantId(IRGlobalParam* param) -{ - 
auto layout = findVarLayout(param); - if (!layout) - return 0; - - auto offset = layout->findOffsetAttr(LayoutResourceKind::SpecializationConstant); - if (!offset) - return 0; - - return offset->getOffset(); -} - } // namespace Slang diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 666ac71c03..e23aeb6180 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -373,8 +373,6 @@ inline bool isSPIRV(CodeGenTarget codeGenTarget) int getIRVectorElementSize(IRType* type); IRType* getIRVectorBaseType(IRType* type); -Int getSpecializationConstantId(IRGlobalParam* param); - } // namespace Slang #endif diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 0863457198..e82fc03fde 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -7625,29 +7625,12 @@ struct DeclLoweringVisitor : DeclVisitor { verifyComputeDerivativeGroupModifier = true; getAllEntryPointsNoOverride(entryPoints); - - LoweredValInfo extents[3]; - - for (int i = 0; i < 3; ++i) - { - extents[i] = layoutLocalSizeAttr->specConstExtents[i] - ? 
emitDeclRef( - context, - layoutLocalSizeAttr->specConstExtents[i], - lowerType( - context, - getType( - context->astBuilder, - layoutLocalSizeAttr->specConstExtents[i]))) - : lowerVal(context, layoutLocalSizeAttr->extents[i]); - } - for (auto d : entryPoints) as(getBuilder()->addNumThreadsDecoration( d, - getSimpleVal(context, extents[0]), - getSimpleVal(context, extents[1]), - getSimpleVal(context, extents[2]))); + getSimpleVal(context, lowerVal(context, layoutLocalSizeAttr->x)), + getSimpleVal(context, lowerVal(context, layoutLocalSizeAttr->y)), + getSimpleVal(context, lowerVal(context, layoutLocalSizeAttr->z)))); } else if (as(modifier)) { @@ -10353,28 +10336,11 @@ struct DeclLoweringVisitor : DeclVisitor } else if (auto numThreadsAttr = as(modifier)) { - LoweredValInfo extents[3]; - - for (int i = 0; i < 3; ++i) - { - extents[i] = numThreadsAttr->specConstExtents[i] - ? emitDeclRef( - context, - numThreadsAttr->specConstExtents[i], - lowerType( - context, - getType( - context->astBuilder, - numThreadsAttr->specConstExtents[i]))) - : lowerVal(context, numThreadsAttr->extents[i]); - } - numThreadsDecor = as(getBuilder()->addNumThreadsDecoration( irFunc, - getSimpleVal(context, extents[0]), - getSimpleVal(context, extents[1]), - getSimpleVal(context, extents[2]))); - numThreadsDecor->sourceLoc = numThreadsAttr->loc; + getSimpleVal(context, lowerVal(context, numThreadsAttr->x)), + getSimpleVal(context, lowerVal(context, numThreadsAttr->y)), + getSimpleVal(context, lowerVal(context, numThreadsAttr->z)))); } else if (auto waveSizeAttr = as(modifier)) { diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 6ae41a2eb9..c275a868b5 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -8437,9 +8437,7 @@ static NodeBase* parseLayoutModifier(Parser* parser, void* /*userData*/) int localSizeIndex = -1; if (nameText.startsWith(localSizePrefix) && - (nameText.getLength() == SLANG_COUNT_OF(localSizePrefix) - 1 + 1 || - 
(nameText.endsWith("_id") && - (nameText.getLength() == SLANG_COUNT_OF(localSizePrefix) - 1 + 4)))) + nameText.getLength() == SLANG_COUNT_OF(localSizePrefix) - 1 + 1) { char lastChar = nameText[SLANG_COUNT_OF(localSizePrefix) - 1]; localSizeIndex = (lastChar >= 'x' && lastChar <= 'z') ? (lastChar - 'x') : -1; @@ -8453,8 +8451,6 @@ static NodeBase* parseLayoutModifier(Parser* parser, void* /*userData*/) numThreadsAttrib->args.setCount(3); for (auto& i : numThreadsAttrib->args) i = nullptr; - for (auto& b : numThreadsAttrib->axisIsSpecConstId) - b = false; // Just mark the loc and name from the first in the list numThreadsAttrib->keywordName = getName(parser, "numthreads"); @@ -8471,11 +8467,6 @@ static NodeBase* parseLayoutModifier(Parser* parser, void* /*userData*/) } numThreadsAttrib->args[localSizeIndex] = expr; - - // We can't resolve the specialization constant declaration - // here, because it may not even exist. IDs pointing to unnamed - // specialization constants are allowed in GLSL. 
- numThreadsAttrib->axisIsSpecConstId[localSizeIndex] = nameText.endsWith("_id"); } } else if (nameText == "derivative_group_quadsNV") diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index d1adfedc0b..d235c82703 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -4033,14 +4033,18 @@ SLANG_API void spReflectionEntryPoint_getComputeThreadGroupSize( auto numThreadsAttribute = entryPointFunc.getDecl()->findModifier(); if (numThreadsAttribute) { - for (int i = 0; i < 3; ++i) - { - if (auto cint = - entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->extents[i])) - sizeAlongAxis[i] = (SlangUInt)cint->getValue(); - else if (numThreadsAttribute->extents[i]) - sizeAlongAxis[i] = 0; - } + if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->x)) + sizeAlongAxis[0] = (SlangUInt)cint->getValue(); + else if (numThreadsAttribute->x) + sizeAlongAxis[0] = 0; + if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->y)) + sizeAlongAxis[1] = (SlangUInt)cint->getValue(); + else if (numThreadsAttribute->y) + sizeAlongAxis[1] = 0; + if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->z)) + sizeAlongAxis[2] = (SlangUInt)cint->getValue(); + else if (numThreadsAttribute->z) + sizeAlongAxis[2] = 0; } // diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index efc1c6fd11..5ec1996581 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -1493,8 +1493,7 @@ DeclRef Linkage::specializeWithArgTypes( DiagnosticSink* sink) { SemanticsVisitor visitor(getSemanticsForReflection()); - SemanticsVisitor::ExprLocalScope scope; - visitor = visitor.withSink(sink).withExprLocalScope(&scope); + visitor = visitor.withSink(sink); SLANG_AST_BUILDER_RAII(getASTBuilder()); diff --git a/tests/autodiff/out-parameters-2.slang b/tests/autodiff/out-parameters-2.slang deleted file mode 100644 index b4c4b07c61..0000000000 --- 
a/tests/autodiff/out-parameters-2.slang +++ /dev/null @@ -1,49 +0,0 @@ -//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type -//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type - -//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer -RWStructuredBuffer outputBuffer; - -typedef DifferentialPair dpfloat; - -struct Foo : IDifferentiable -{ - float a; - int b; -} - -[PreferCheckpoint] -float k() -{ - return outputBuffer[3] + 1; -} - -[Differentiable] -void h(float x, float y, out Foo result) -{ - float p = no_diff k(); - float m = x + y + p; - float n = x - y; - float r = m * n + 2 * x * y; - - result = {r, 2}; -} - -[numthreads(1, 1, 1)] -void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) -{ - float x = 2.0; - float y = 3.5; - float dx = 1.0; - float dy = 0.5; - - dpfloat dresult; - dpfloat dpx = diffPair(x); - dpfloat dpy = diffPair(y); - Foo.Differential dFoo; - dFoo.a = 1.0; - bwd_diff(h)(dpx, dpy, dFoo); - - outputBuffer[0] = dpx.d; // CHECK: 12.0 - outputBuffer[1] = dpy.d; // CHECK: -4.0 -} \ No newline at end of file diff --git a/tests/bugs/simplify-if-else.slang b/tests/bugs/simplify-if-else.slang deleted file mode 100644 index 8719a15995..0000000000 --- a/tests/bugs/simplify-if-else.slang +++ /dev/null @@ -1,26 +0,0 @@ -//TEST:SIMPLE(filecheck=CHECK): -stage compute -entry computeMain -target hlsl -//CHECK: computeMain - -//TEST_INPUT:ubuffer(data=[9 9 9 9], stride=4):out,name=outputBuffer -RWStructuredBuffer outputBuffer; - -[numthreads(4, 1, 1)] -void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) -{ - vector vvv = vector(0); - float32_t ret = 0.0f; - if (vvv.y < 1.0f) - { - ret = 1.0f; - } - else - { - if (vvv.y > 1.0f && outputBuffer[3] == 3) - { - ret = 0.0f; - } else { - if (true) {} - } - } - outputBuffer[int(dispatchThreadID.x)] = int(ret); -} diff --git a/tests/diagnostics/missing-return.slang.expected 
b/tests/diagnostics/missing-return.slang.expected index 7626665241..e41e756ff4 100644 --- a/tests/diagnostics/missing-return.slang.expected +++ b/tests/diagnostics/missing-return.slang.expected @@ -1,9 +1,9 @@ result code = 0 standard error = { -tests/diagnostics/missing-return.slang(7): warning 41010: non-void function does not return in all cases +tests/diagnostics/missing-return.slang(7): warning 41010: control flow may reach end of non-'void' function int bad(int a, int b) ^~~ -tests/diagnostics/missing-return.slang(14): warning 41010: non-void function does not return in all cases +tests/diagnostics/missing-return.slang(14): warning 41010: control flow may reach end of non-'void' function int alsoBad(int a, int b) ^~~~~~~ } diff --git a/tests/glsl/compute-shader-layout-id.slang b/tests/glsl/compute-shader-layout-id.slang deleted file mode 100644 index bee8137d82..0000000000 --- a/tests/glsl/compute-shader-layout-id.slang +++ /dev/null @@ -1,19 +0,0 @@ -//TEST:SIMPLE(filecheck=CHECK): -target spirv -stage compute -entry main -allow-glsl -#version 450 - -[vk::constant_id(1)] -const int constValue1 = 0; - -[vk::constant_id(2)] -const int constValue3 = 5; - -// CHECK-DAG: OpExecutionModeId %main LocalSizeId %[[C0:[0-9A-Za-z_]+]] %[[C1:[0-9A-Za-z_]+]] %[[C2:[0-9A-Za-z_]+]] -// CHECK-DAG: OpDecorate %[[C0]] SpecId 1 -// CHECK-DAG: OpDecorate %[[C1]] SpecId 0 -// CHECK-DAG: OpDecorate %[[C2]] SpecId 2 - -layout(local_size_x_id = 1, local_size_y_id = 0, local_size_z = constValue3) in; -void main() -{ -} - diff --git a/tests/spirv/spec-constant-numthreads.slang b/tests/spirv/spec-constant-numthreads.slang deleted file mode 100644 index 5c133219cf..0000000000 --- a/tests/spirv/spec-constant-numthreads.slang +++ /dev/null @@ -1,35 +0,0 @@ -//TEST:SIMPLE(filecheck=GLSL): -target glsl -allow-glsl -//TEST:SIMPLE(filecheck=GLSL): -target glsl -//TEST:SIMPLE(filecheck=CHECK): -target spirv -allow-glsl -//TEST:SIMPLE(filecheck=CHECK): -target spirv - -// CHECK-DAG: 
OpExecutionModeId %computeMain1 LocalSizeId %[[C0:[0-9A-Za-z_]+]] %[[C1:[0-9A-Za-z_]+]] %[[C2:[0-9A-Za-z_]+]] -// CHECK-DAG: OpDecorate %[[C0]] SpecId 1 -// CHECK-DAG: OpDecorate %[[C1]] SpecId 0 -// CHECK-DAG: %[[C2]] = OpConstant %int 4 -// CHECK-DAG: OpStore %{{.*}} %[[C0]] -// CHECK-DAG: OpStore %{{.*}} %[[C1]] -// CHECK-DAG: OpStore %{{.*}} %[[C2]] - -// GLSL-DAG: layout(constant_id = 1) -// GLSL-DAG: int constValue0_0 = 0; -// GLSL-DAG: layout(constant_id = 0) -// GLSL-DAG: int constValue1_0 = 0; -// GLSL-DAG: layout(local_size_x_id = 1, local_size_y_id = 0, local_size_z = 4) in; - -[vk::specialization_constant] -const int constValue0 = 0; - -[vk::constant_id(0)] -const int constValue1 = 0; - -RWStructuredBuffer outputBuffer; - -[numthreads(constValue0, constValue1, 4)] -void computeMain1() -{ - int3 size = WorkgroupSize(); - outputBuffer[0] = size.x; - outputBuffer[1] = size.y; - outputBuffer[2] = size.z; -} From a5e5dc126759dec7de6c61712def75cdaff0d2b0 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Tue, 14 Jan 2025 20:22:36 -0500 Subject: [PATCH 18/18] apply metal spec const thread count changes --- .../slang-ir-legalize-varying-params.cpp | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index 6840197721..4858d6e31c 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -3378,12 +3378,25 @@ class LegalizeMetalEntryPointContext : public LegalizeShaderEntryPointContext IRBuilder svBuilder(builder.getModule()); svBuilder.setInsertBefore(entryPoint.entryPointFunc->getFirstOrdinaryInst()); - auto computeExtent = emitCalcGroupExtents( - svBuilder, - entryPoint.entryPointFunc, - builder.getVectorType( - builder.getUIntType(), - builder.getIntValue(builder.getIntType(), 3))); + auto uint3Type = builder.getVectorType( + builder.getUIntType(), + 
builder.getIntValue(builder.getIntType(), 3)); + auto computeExtent = + emitCalcGroupExtents(svBuilder, entryPoint.entryPointFunc, uint3Type); + if (!computeExtent) + { + m_sink->diagnose( + entryPoint.entryPointFunc, + Diagnostics::unsupportedSpecializationConstantForNumThreads); + + // Fill in placeholder values. + static const int kAxisCount = 3; + IRInst* groupExtentAlongAxis[kAxisCount] = {}; + for (int axis = 0; axis < kAxisCount; axis++) + groupExtentAlongAxis[axis] = + builder.getIntValue(uint3Type->getElementType(), 1); + computeExtent = builder.emitMakeVector(uint3Type, kAxisCount, groupExtentAlongAxis); + } auto groupIndexCalc = emitCalcGroupIndex( svBuilder, entryPointToGroupThreadId[entryPoint.entryPointFunc],