Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open existential on arguments after overload resolution. #4982

Merged
merged 4 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion source/slang/slang-check-decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4226,7 +4226,7 @@ namespace Slang
//
DiagnosticSink tempSink(getSourceManager(), nullptr);
ExprLocalScope localScope;
SemanticsVisitor subVisitor(withSink(&tempSink).withExprLocalScope(&localScope));
SemanticsVisitor subVisitor(withSink(&tempSink).withParentFunc(synFuncDecl).withExprLocalScope(&localScope));

// With our temporary diagnostic sink soaking up any messages
// from overload resolution, we can now try to resolve
Expand Down
61 changes: 40 additions & 21 deletions source/slang/slang-check-overload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2238,7 +2238,6 @@ namespace Slang

// Look at the base expression for the call, and figure out how to invoke it.
auto funcExpr = expr->functionExpr;
auto funcExprType = funcExpr->type;

// If we are trying to apply an erroneous expression, then just bail out now.
if(IsErrorExpr(funcExpr))
Expand Down Expand Up @@ -2267,25 +2266,6 @@ namespace Slang
for (auto& arg : expr->arguments)
{
arg = maybeOpenRef(arg);
}

auto funcType = as<FuncType>(funcExprType);
for (Index i = 0; i < expr->arguments.getCount(); i++)
{
auto& arg = expr->arguments[i];
if (funcType && i < funcType->getParamCount())
{
switch (funcType->getParamDirection(i))
{
case kParameterDirection_Out:
case kParameterDirection_InOut:
case kParameterDirection_Ref:
case kParameterDirection_ConstRef:
continue;
default:
break;
}
}
arg = maybeOpenExistential(arg);
}

Expand Down Expand Up @@ -2415,6 +2395,45 @@ namespace Slang
// the user the most help we can.
if (shouldAddToCache)
typeCheckingCache->resolvedOperatorOverloadCache[key] = *context.bestCandidate;

// Now that we have resolved the overload candidate, we need to undo an `openExistential`
// operation that was applied to `out` arguments.
//
auto funcType = context.bestCandidate->funcType;
ShortList<ParameterDirection> paramDirections;
if (funcType)
{
for (Index i = 0; i < funcType->getParamCount(); i++)
{
paramDirections.add(funcType->getParamDirection(i));
}
}
else if (auto callableDeclRef = context.bestCandidate->item.declRef.as<CallableDecl>())
{
for (auto param : callableDeclRef.getDecl()->getMembersOfType<ParamDecl>())
csyonghe marked this conversation as resolved.
Show resolved Hide resolved
{
paramDirections.add(getParameterDirection(param));
}
}
for (Index i = 0; i < expr->arguments.getCount(); i++)
{
auto& arg = expr->arguments[i];
if (i < paramDirections.getCount())
{
switch (paramDirections[i])
{
case kParameterDirection_Out:
case kParameterDirection_InOut:
case kParameterDirection_Ref:
case kParameterDirection_ConstRef:
break;
default:
continue;
}
}
if (auto extractExistentialExpr = as<ExtractExistentialValueExpr>(arg))
arg = extractExistentialExpr->originalExpr;
}
return CompleteOverloadCandidate(context, *context.bestCandidate);
}

Expand Down Expand Up @@ -2447,7 +2466,7 @@ namespace Slang

// Nothing at all was found that we could even consider invoking.
// In all other cases, this is an error.
getSink()->diagnose(expr->functionExpr, Diagnostics::expectedFunction, funcExprType);
getSink()->diagnose(expr->functionExpr, Diagnostics::expectedFunction, funcExpr->type);
expr->type = QualType(m_astBuilder->getErrorType());
return expr;
}
Expand Down
43 changes: 43 additions & 0 deletions tests/bugs/gh-4467.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK): -d3d12 -compute -shaderobj -output-using-type
//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHK): -vk -compute -shaderobj -output-using-type


// Test that we can synthesize a [mutating] interface requirement from a nonmutating implementation,
// and the interface requirement signature contains an output interface-typed parameter.

//TEST_INPUT: ubuffer(data=[0 0], stride=4):out,name outputBuffer
RWStructuredBuffer<int> outputBuffer;

interface IFoo
{
int getVal();
};

interface IBar
{
[mutating] void method(out IFoo o);
};

struct FooImpl : IFoo
{
int x;
int getVal() { return x; }
}

struct BarImpl : IBar
{
void method(out IFoo o)
{
o = FooImpl(1);
}
};

[numthreads(1,1,1)]
void computeMain()
{
BarImpl bar;
IFoo foo;
bar.method(foo);
// CHK: 1
outputBuffer[0] = foo.getVal();
}
Loading