Skip to content

Commit

Permalink
Implemented Combined-texture for WGSL (#5130)
Browse files Browse the repository at this point in the history
* Implemented Combined-texture for WGSL

* Remove unnecessary comment

* Limit to std430 layout

* Fix compiler warning for unused variable

---------

Co-authored-by: Yong He <yonghe@outlook.com>
  • Loading branch information
jkwak-work and csyonghe authored Sep 23, 2024
1 parent 14b1098 commit 3e950e1
Show file tree
Hide file tree
Showing 6 changed files with 537 additions and 10 deletions.
1 change: 1 addition & 0 deletions source/slang/slang-compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,7 @@ namespace Slang
case CodeGenTarget::PyTorchCppBinding:
case CodeGenTarget::CSource:
case CodeGenTarget::Metal:
case CodeGenTarget::WGSL:
{
return PassThroughMode::None;
}
Expand Down
3 changes: 2 additions & 1 deletion source/slang/slang-emit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -973,8 +973,9 @@ Result linkAndOptimizeIR(
case CodeGenTarget::Metal:
case CodeGenTarget::MetalLib:
case CodeGenTarget::MetalLibAssembly:
case CodeGenTarget::WGSL:
if (requiredLoweringPassSet.combinedTextureSamplers)
lowerCombinedTextureSamplers(irModule, sink);
lowerCombinedTextureSamplers(codeGenContext, irModule, sink);
break;
}

Expand Down
28 changes: 20 additions & 8 deletions source/slang/slang-ir-lower-combined-texture-sampler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ namespace Slang
struct LowerCombinedSamplerContext
{
Dictionary<IRType*, LoweredCombinedSamplerStructInfo> mapTypeToLoweredInfo;
CodeGenTarget codeGenTarget;

LoweredCombinedSamplerStructInfo lowerCombinedTextureSamplerType(IRTextureTypeBase* textureType)
{
Expand Down Expand Up @@ -57,8 +58,16 @@ namespace Slang
builder.createStructField(structType, info.sampler, info.samplerType);

// Type layout.

auto textureResourceKind = isMutable ? LayoutResourceKind::UnorderedAccess : LayoutResourceKind::ShaderResource;

bool isWGSLTarget = codeGenTarget == CodeGenTarget::WGSL;
LayoutResourceKind textureResourceKind = isMutable ? LayoutResourceKind::UnorderedAccess : LayoutResourceKind::ShaderResource;
LayoutResourceKind samplerResourceKind = LayoutResourceKind::SamplerState;
if (isWGSLTarget)
{
textureResourceKind = LayoutResourceKind::DescriptorTableSlot;
samplerResourceKind = LayoutResourceKind::DescriptorTableSlot;
}

IRTypeLayout::Builder textureTypeLayoutBuilder(&builder);
textureTypeLayoutBuilder.addResourceUsage(
textureResourceKind,
Expand All @@ -67,7 +76,7 @@ namespace Slang

IRTypeLayout::Builder samplerTypeLayoutBuilder(&builder);
samplerTypeLayoutBuilder.addResourceUsage(
LayoutResourceKind::SamplerState,
samplerResourceKind,
LayoutSize(1));
auto samplerTypeLayout = samplerTypeLayoutBuilder.build();

Expand All @@ -76,7 +85,7 @@ namespace Slang
auto textureVarLayout = textureVarLayoutBuilder.build();

IRVarLayout::Builder samplerVarLayoutBuilder(&builder, samplerTypeLayout);
samplerVarLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::SamplerState)->offset = 0;
samplerVarLayoutBuilder.findOrAddResourceInfo(samplerResourceKind)->offset = isWGSLTarget ? 1 : 0;
auto samplerVarLayout = samplerVarLayoutBuilder.build();

IRStructTypeLayout::Builder layoutBuilder(&builder);
Expand All @@ -91,12 +100,14 @@ namespace Slang
};

void lowerCombinedTextureSamplers(
CodeGenContext* codeGenContext,
IRModule* module,
DiagnosticSink* sink)
{
SLANG_UNUSED(sink);

LowerCombinedSamplerContext context;
context.codeGenTarget = codeGenContext->getTargetFormat();

// Lower combined texture sampler type into a struct type.
for (auto globalInst : module->getGlobalInsts())
Expand Down Expand Up @@ -127,12 +138,13 @@ namespace Slang

for (auto offsetAttr : varLayout->getOffsetAttrs())
{
if (offsetAttr->getResourceKind() == LayoutResourceKind::UnorderedAccess ||
offsetAttr->getResourceKind() == LayoutResourceKind::ShaderResource)
LayoutResourceKind resKind = offsetAttr->getResourceKind();
if (resKind == LayoutResourceKind::UnorderedAccess ||
resKind == LayoutResourceKind::ShaderResource)
resOffsetAttr = offsetAttr;
else if (offsetAttr->getResourceKind() == LayoutResourceKind::DescriptorTableSlot)
else if (resKind == LayoutResourceKind::DescriptorTableSlot)
descriptorTableSlotOffsetAttr = offsetAttr;
auto info = newVarLayoutBuilder.findOrAddResourceInfo(offsetAttr->getResourceKind());
auto info = newVarLayoutBuilder.findOrAddResourceInfo(resKind);
info->offset = offsetAttr->getOffset();
info->space = offsetAttr->getSpace();
info->kind = offsetAttr->getResourceKind();
Expand Down
2 changes: 2 additions & 0 deletions source/slang/slang-ir-lower-combined-texture-sampler.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@

namespace Slang
{
struct CodeGenContext;
struct IRModule;
class DiagnosticSink;

// Lower combined texture sampler types to structs.
void lowerCombinedTextureSamplers(
CodeGenContext* codeGenContext,
IRModule* module,
DiagnosticSink* sink
);
Expand Down
120 changes: 119 additions & 1 deletion source/slang/slang-type-layout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,25 @@ struct HLSLObjectLayoutRulesImpl : ObjectLayoutRulesImpl
};
HLSLObjectLayoutRulesImpl kHLSLObjectLayoutRulesImpl;

struct WGSLObjectLayoutRulesImpl : GLSLObjectLayoutRulesImpl
{
virtual ObjectLayoutInfo GetObjectLayout(ShaderParameterKind kind, const Options& options) override
{
ObjectLayoutInfo info = GLSLObjectLayoutRulesImpl::GetObjectLayout(kind, options);

switch (kind)
{
case ShaderParameterKind::TextureSampler:
case ShaderParameterKind::MutableTextureSampler:
info.layoutInfos.add(SimpleLayoutInfo(LayoutResourceKind::DescriptorTableSlot, 1));
break;
}

return info;
}
};
WGSLObjectLayoutRulesImpl kWGSLObjectLayoutRulesImpl;

// HACK: Treating ray-tracing input/output as if it was another
// case of varying input/output when it really needs to be
// based on byte storage/layout.
Expand Down Expand Up @@ -1053,11 +1072,32 @@ struct MetalLayoutRulesFamilyImpl : LayoutRulesFamilyImpl
LayoutRulesImpl* getStructuredBufferRules(CompilerOptionSet& compilerOptions) override;
};

struct WGSLLayoutRulesFamilyImpl : LayoutRulesFamilyImpl
{
virtual LayoutRulesImpl* getAnyValueRules() override;
virtual LayoutRulesImpl* getConstantBufferRules(CompilerOptionSet& compilerOptions) override;
virtual LayoutRulesImpl* getPushConstantBufferRules() override;
virtual LayoutRulesImpl* getTextureBufferRules(CompilerOptionSet& compilerOptions) override;
virtual LayoutRulesImpl* getVaryingInputRules() override;
virtual LayoutRulesImpl* getVaryingOutputRules() override;
virtual LayoutRulesImpl* getSpecializationConstantRules() override;
virtual LayoutRulesImpl* getShaderStorageBufferRules(CompilerOptionSet& compilerOptions) override;
virtual LayoutRulesImpl* getParameterBlockRules(CompilerOptionSet& compilerOptions) override;

LayoutRulesImpl* getRayPayloadParameterRules() override;
LayoutRulesImpl* getCallablePayloadParameterRules() override;
LayoutRulesImpl* getHitAttributesParameterRules() override;

LayoutRulesImpl* getShaderRecordConstantBufferRules() override;
LayoutRulesImpl* getStructuredBufferRules(CompilerOptionSet& compilerOptions) override;
};

GLSLLayoutRulesFamilyImpl kGLSLLayoutRulesFamilyImpl;
HLSLLayoutRulesFamilyImpl kHLSLLayoutRulesFamilyImpl;
CPULayoutRulesFamilyImpl kCPULayoutRulesFamilyImpl;
CUDALayoutRulesFamilyImpl kCUDALayoutRulesFamilyImpl;
MetalLayoutRulesFamilyImpl kMetalLayoutRulesFamilyImpl;
WGSLLayoutRulesFamilyImpl kWGSLLayoutRulesFamilyImpl;

// CPU case

Expand Down Expand Up @@ -1816,6 +1856,82 @@ LayoutRulesImpl* MetalLayoutRulesFamilyImpl::getHitAttributesParameterRules()
return nullptr;
}

// WGSL Family

LayoutRulesImpl kWGSLConstantBufferLayoutRulesImpl_ = {
&kWGSLLayoutRulesFamilyImpl, &kStd140LayoutRulesImpl, &kWGSLObjectLayoutRulesImpl,
};

LayoutRulesImpl* WGSLLayoutRulesFamilyImpl::getAnyValueRules()
{
return &kGLSLAnyValueLayoutRulesImpl_;
}

LayoutRulesImpl* WGSLLayoutRulesFamilyImpl::getConstantBufferRules(CompilerOptionSet&)
{
return &kWGSLConstantBufferLayoutRulesImpl_;
}

LayoutRulesImpl* WGSLLayoutRulesFamilyImpl::getParameterBlockRules(CompilerOptionSet&)
{
return &kStd140LayoutRulesImpl_;
}

LayoutRulesImpl* WGSLLayoutRulesFamilyImpl::getPushConstantBufferRules()
{
return &kGLSLPushConstantLayoutRulesImpl_;
}

LayoutRulesImpl* WGSLLayoutRulesFamilyImpl::getShaderRecordConstantBufferRules()
{
return &kGLSLShaderRecordLayoutRulesImpl_;
}

LayoutRulesImpl* WGSLLayoutRulesFamilyImpl::getTextureBufferRules(CompilerOptionSet&)
{
return &kStd430LayoutRulesImpl_;
}

LayoutRulesImpl* WGSLLayoutRulesFamilyImpl::getVaryingInputRules()
{
return &kGLSLVaryingInputLayoutRulesImpl_;
}

LayoutRulesImpl* WGSLLayoutRulesFamilyImpl::getVaryingOutputRules()
{
return &kGLSLVaryingOutputLayoutRulesImpl_;
}

LayoutRulesImpl* WGSLLayoutRulesFamilyImpl::getSpecializationConstantRules()
{
return &kGLSLSpecializationConstantLayoutRulesImpl_;
}

LayoutRulesImpl* WGSLLayoutRulesFamilyImpl::getShaderStorageBufferRules(CompilerOptionSet&)
{
return &kStd430LayoutRulesImpl_;
}

LayoutRulesImpl* WGSLLayoutRulesFamilyImpl::getRayPayloadParameterRules()
{
return &kGLSLRayPayloadParameterLayoutRulesImpl_;
}

LayoutRulesImpl* WGSLLayoutRulesFamilyImpl::getCallablePayloadParameterRules()
{
return &kGLSLCallablePayloadParameterLayoutRulesImpl_;
}

LayoutRulesImpl* WGSLLayoutRulesFamilyImpl::getHitAttributesParameterRules()
{
return &kGLSLHitAttributesParameterLayoutRulesImpl_;
}

LayoutRulesImpl* WGSLLayoutRulesFamilyImpl::getStructuredBufferRules(CompilerOptionSet&)
{
return &kGLSLStructuredBufferLayoutRulesImpl_;
}


LayoutRulesFamilyImpl* getDefaultLayoutRulesFamilyForTarget(TargetRequest* targetReq)
{
Expand All @@ -1831,9 +1947,11 @@ LayoutRulesFamilyImpl* getDefaultLayoutRulesFamilyForTarget(TargetRequest* targe
case CodeGenTarget::GLSL:
case CodeGenTarget::SPIRV:
case CodeGenTarget::SPIRVAssembly:
case CodeGenTarget::WGSL:
return &kGLSLLayoutRulesFamilyImpl;

case CodeGenTarget::WGSL:
return &kWGSLLayoutRulesFamilyImpl;

case CodeGenTarget::HostHostCallable:
case CodeGenTarget::ShaderHostCallable:
case CodeGenTarget::HostExecutable:
Expand Down
Loading

0 comments on commit 3e950e1

Please sign in to comment.