diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index e5e8e8acc2..1dc81e861c 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -4226,7 +4226,7 @@ namespace Slang // DiagnosticSink tempSink(getSourceManager(), nullptr); ExprLocalScope localScope; - SemanticsVisitor subVisitor(withSink(&tempSink).withExprLocalScope(&localScope)); + SemanticsVisitor subVisitor(withSink(&tempSink).withParentFunc(synFuncDecl).withExprLocalScope(&localScope)); // With our temporary diagnostic sink soaking up any messages // from overload resolution, we can now try to resolve diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 0e01eeed27..7318fa390b 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -2266,7 +2266,6 @@ namespace Slang // Look at the base expression for the call, and figure out how to invoke it. auto funcExpr = expr->functionExpr; - auto funcExprType = funcExpr->type; // If we are trying to apply an erroneous expression, then just bail out now. if(IsErrorExpr(funcExpr)) @@ -2295,25 +2294,6 @@ namespace Slang for (auto& arg : expr->arguments) { arg = maybeOpenRef(arg); - } - - auto funcType = as(funcExprType); - for (Index i = 0; i < expr->arguments.getCount(); i++) - { - auto& arg = expr->arguments[i]; - if (funcType && i < funcType->getParamCount()) - { - switch (funcType->getParamDirection(i)) - { - case kParameterDirection_Out: - case kParameterDirection_InOut: - case kParameterDirection_Ref: - case kParameterDirection_ConstRef: - continue; - default: - break; - } - } arg = maybeOpenExistential(arg); } @@ -2443,6 +2423,45 @@ namespace Slang // the user the most help we can. if (shouldAddToCache) typeCheckingCache->resolvedOperatorOverloadCache[key] = *context.bestCandidate; + + // Now that we have resolved the overload candidate, we need to undo an `openExistential` + // operation that was applied to `out` arguments. + // + auto funcType = context.bestCandidate->funcType; + ShortList paramDirections; + if (funcType) + { + for (Index i = 0; i < funcType->getParamCount(); i++) + { + paramDirections.add(funcType->getParamDirection(i)); + } + } + else if (auto callableDeclRef = context.bestCandidate->item.declRef.as()) + { + for (auto param : callableDeclRef.getDecl()->getParameters()) + { + paramDirections.add(getParameterDirection(param)); + } + } + for (Index i = 0; i < expr->arguments.getCount(); i++) + { + auto& arg = expr->arguments[i]; + if (i < paramDirections.getCount()) + { + switch (paramDirections[i]) + { + case kParameterDirection_Out: + case kParameterDirection_InOut: + case kParameterDirection_Ref: + case kParameterDirection_ConstRef: + break; + default: + continue; + } + } + if (auto extractExistentialExpr = as(arg)) + arg = extractExistentialExpr->originalExpr; + } return CompleteOverloadCandidate(context, *context.bestCandidate); } @@ -2475,7 +2494,7 @@ namespace Slang // Nothing at all was found that we could even consider invoking. // In all other cases, this is an error. - getSink()->diagnose(expr->functionExpr, Diagnostics::expectedFunction, funcExprType); + getSink()->diagnose(expr->functionExpr, Diagnostics::expectedFunction, funcExpr->type); expr->type = QualType(m_astBuilder->getErrorType()); return expr; } diff --git a/tests/bugs/gh-4467.slang b/tests/bugs/gh-4467.slang new file mode 100644 index 0000000000..dc97b9890d --- /dev/null +++ b/tests/bugs/gh-4467.slang @@ -0,0 +1,43 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK): -d3d12 -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHK): -vk -compute -shaderobj -output-using-type + + +// Test that we can synthesize a [mutating] interface requirement from a nonmutating implementation, +// and the interface requirement signature contains an output interface-typed parameter. + +//TEST_INPUT: ubuffer(data=[0 0], stride=4):out,name outputBuffer +RWStructuredBuffer outputBuffer; + +interface IFoo +{ + int getVal(); +}; + +interface IBar +{ + [mutating] void method(out IFoo o); +}; + +struct FooImpl : IFoo +{ + int x; + int getVal() { return x; } +} + +struct BarImpl : IBar +{ + void method(out IFoo o) + { + o = FooImpl(1); + } +}; + +[numthreads(1,1,1)] +void computeMain() +{ + BarImpl bar; + IFoo foo; + bar.method(foo); + // CHK: 1 + outputBuffer[0] = foo.getVal(); +} \ No newline at end of file