Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check whether array element is fully specialized #6000

Merged
merged 5 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions source/slang/slang-ir-specialize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,12 @@ struct SpecializationContext
}
}

if (isWrapperType(inst))
{
// For all the wrapper type, we need to make sure the operands are fully specialized.
return areAllOperandsFullySpecialized(inst);
}

// The default case is that a global value is always specialized.
if (inst->getParent() == module->getModuleInst())
{
Expand Down
25 changes: 25 additions & 0 deletions source/slang/slang-ir-util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,31 @@ bool isSimpleHLSLDataType(IRInst* inst)
return true;
}

bool isWrapperType(IRInst* inst)
{
switch (inst->getOp())
{
case kIROp_ArrayType:
case kIROp_TextureType:
case kIROp_VectorType:
case kIROp_MatrixType:
case kIROp_PtrType:
case kIROp_RefType:
case kIROp_ConstRefType:
case kIROp_HLSLStructuredBufferType:
case kIROp_HLSLRWStructuredBufferType:
case kIROp_HLSLRasterizerOrderedStructuredBufferType:
case kIROp_HLSLAppendStructuredBufferType:
case kIROp_HLSLConsumeStructuredBufferType:
case kIROp_TupleType:
case kIROp_OptionalType:
case kIROp_TypePack:
return true;
default:
return false;
}
}

SourceLoc findFirstUseLoc(IRInst* inst)
{
for (auto use = inst->firstUse; use; use = use->nextUse)
Expand Down
2 changes: 2 additions & 0 deletions source/slang/slang-ir-util.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ bool isSimpleDataType(IRType* type);

bool isSimpleHLSLDataType(IRInst* inst);

bool isWrapperType(IRInst* inst);

SourceLoc findFirstUseLoc(IRInst* inst);

inline bool isChildInstOf(IRInst* inst, IRInst* parent)
Expand Down
86 changes: 86 additions & 0 deletions tests/bugs/gh-5776.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -dx12 -profile sm_6_0 -use-dxil -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -cuda -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -cpu -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -wgpu -output-using-type


interface IFoo
{
associatedtype FooType : IFoo;
}

extension float : IFoo
{
typedef float FooType;
}

__generic<T:IFoo, let N:int>
extension Array<T, N> : IFoo
{
typedef Array<T.FooType, N> FooType;
}

__generic<T:IFoo, let N:int>
extension vector<T, N> : IFoo
{
typedef vector<T.FooType, N> FooType;
}

__generic<T:IFoo, let N:int, let M:int>
extension matrix<T, N, M> : IFoo
{
typedef matrix<T.FooType, N, M> FooType;
}

struct WrappedBuffer<T : IFoo>
{
StructuredBuffer<T> buffer;
int shape;

T get(int idx) { return buffer[idx]; }
}


struct GradInBuffer<T : IFoo>
{
WrappedBuffer<T.FooType> wrapBuffer;
}

struct CallData
{
GradInBuffer<float[2]> grad_in1;
GradInBuffer<vector<float, 2>> grad_in2;
GradInBuffer<float2x2> grad_in3;
}


//TEST_INPUT: set call_data.grad_in1.wrapBuffer.buffer = ubuffer(data=[1.0 2.0 3.0 4.0], stride=4);
//TEST_INPUT: set call_data.grad_in2.wrapBuffer.buffer = ubuffer(data=[5.0 6.0 7.0 8.0], stride=4);
//TEST_INPUT: set call_data.grad_in3.wrapBuffer.buffer = ubuffer(data=[1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0], stride=4);
ParameterBlock<CallData> call_data;


//TEST_INPUT:ubuffer(data=[0.0 0.0 0.0 0.0 0.0 0.0 0.0], stride=4):out, name outputBuffer
RWStructuredBuffer<float> outputBuffer;


[shader("compute")]
[numthreads(1, 1, 1)]
void computeMain()
{
float[2] data1 = call_data.grad_in1.wrapBuffer.buffer[0];
float[2] data2 = call_data.grad_in1.wrapBuffer.get(1);
outputBuffer[0] = data1[0];
outputBuffer[1] = data2[0];

vector<float, 2> data3 = call_data.grad_in2.wrapBuffer.buffer[0];
vector<float, 2> data4 = call_data.grad_in2.wrapBuffer.get(1);
outputBuffer[2] = data3[0];
outputBuffer[3] = data4[0];

float2x2 data5 = call_data.grad_in3.wrapBuffer.buffer[0];
float2x2 data6 = call_data.grad_in3.wrapBuffer.get(1);
outputBuffer[4] = data5[0][0];
outputBuffer[5] = data6[0][0];
}
7 changes: 7 additions & 0 deletions tests/bugs/gh-5776.slang.expected.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
type: float
1.000000
3.000000
5.000000
7.000000
1.000000
5.000000
Loading