From bec82c4a846b080c43fa97d572cd49f902a120b2 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru Date: Wed, 18 Sep 2024 16:32:10 -0400 Subject: [PATCH 1/3] Allow lookups of overloaded methods. --- include/slang.h | 18 ++ source/slang/slang-check-shader.cpp | 13 +- source/slang/slang-compiler.h | 8 +- source/slang/slang-reflection-api.cpp | 169 ++++++++++++++---- source/slang/slang.cpp | 61 +++---- .../unit-test-function-reflection.cpp | 32 ++++ 6 files changed, 227 insertions(+), 74 deletions(-) diff --git a/include/slang.h b/include/slang.h index 3024aa8844..06e054370e 100644 --- a/include/slang.h +++ b/include/slang.h @@ -2591,6 +2591,9 @@ extern "C" SLANG_API SlangReflectionGeneric* spReflectionFunction_GetGenericContainer(SlangReflectionFunction* func); SLANG_API SlangReflectionFunction* spReflectionFunction_applySpecializations(SlangReflectionFunction* func, SlangReflectionGeneric* generic); SLANG_API SlangReflectionFunction* spReflectionFunction_specializeWithArgTypes(SlangReflectionFunction* func, SlangInt argTypeCount, SlangReflectionType* const* argTypes); + SLANG_API bool spReflectionFunction_isOverloaded(SlangReflectionFunction* func); + SLANG_API unsigned int spReflectionFunction_getOverloadCount(SlangReflectionFunction* func); + SLANG_API SlangReflectionFunction* spReflectionFunction_getOverload(SlangReflectionFunction* func, unsigned int index); // Abstract Decl Reflection @@ -3594,6 +3597,21 @@ namespace slang { return (FunctionReflection*)spReflectionFunction_specializeWithArgTypes((SlangReflectionFunction*)this, argCount, (SlangReflectionType* const*)types); } + + bool isOverloaded() + { + return spReflectionFunction_isOverloaded((SlangReflectionFunction*)this); + } + + unsigned int getOverloadCount() + { + return spReflectionFunction_getOverloadCount((SlangReflectionFunction*)this); + } + + FunctionReflection* getOverload(unsigned int index) + { + return (FunctionReflection*)spReflectionFunction_getOverload((SlangReflectionFunction*)this, index); + } }; struct GenericReflection diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index 1718c3afd3..08bac1f78c 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -235,8 +235,17 @@ namespace Slang Name* name, DiagnosticSink* sink) { - auto declRef = translationUnit->findDeclFromString(getText(name), sink); - FuncDecl* entryPointFuncDecl = declRef.as().getDecl(); + FuncDecl* entryPointFuncDecl = nullptr; + + auto expr = translationUnit->findDeclFromString(getText(name), sink); + if (auto declRefExpr = as(expr)) + { + auto declRef = declRefExpr->declRef; + entryPointFuncDecl = declRef.as().getDecl(); + + if (entryPointFuncDecl && getModule(entryPointFuncDecl) != translationUnit) + entryPointFuncDecl = nullptr; + } if (entryPointFuncDecl && getModule(entryPointFuncDecl) != translationUnit) entryPointFuncDecl = nullptr; diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 0c788ae182..7227fcbb4c 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -420,11 +420,11 @@ namespace Slang String const& typeStr, DiagnosticSink* sink); - DeclRef findDeclFromString( + Expr* findDeclFromString( String const& name, DiagnosticSink* sink); - DeclRef findDeclFromStringInType( + Expr* findDeclFromStringInType( Type* type, String const& name, LookupMask mask, @@ -576,7 +576,7 @@ namespace Slang Dictionary m_types; // Any decls looked up dynamically using `findDeclFromString`. - Dictionary> m_decls; + Dictionary m_decls; Scope* m_lookupScope = nullptr; std::unique_ptr> m_mapMangledNameToIntVal; @@ -2174,7 +2174,7 @@ namespace Slang DiagnosticSink* sink); DeclRef specializeWithArgTypes( - DeclRef funcDeclRef, + Expr* funcExpr, List argTypes, DiagnosticSink* sink); diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index 38129babf5..1342666439 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -66,10 +66,21 @@ static inline SlangReflectionVariable* convert(DeclRef var) return (SlangReflectionVariable*) var.declRefBase; } -static inline DeclRef convert(SlangReflectionFunction* func) +static inline DeclRef convertToFunc(SlangReflectionFunction* func) { - DeclRefBase* declBase = (DeclRefBase*)func; - return DeclRef(declBase); + NodeBase* nodeBase = (NodeBase*)func; + if (DeclRefBase* declRefBase = as(nodeBase)) + { + return DeclRef(declRefBase); + } + + return DeclRef(); +} + +static inline OverloadedExpr* convertToOverloadedFunc(SlangReflectionFunction* func) +{ + NodeBase* nodeBase = (NodeBase*)func; + return as(nodeBase); } static inline SlangReflectionFunction* convert(DeclRef func) @@ -77,6 +88,11 @@ static inline SlangReflectionFunction* convert(DeclRef func) return (SlangReflectionFunction*)func.declRefBase; } +static inline SlangReflectionFunction* convert(OverloadedExpr* overloadedFunc) +{ + return (SlangReflectionFunction*)overloadedFunc; +} + static inline DeclRef convertGenericToDeclRef(SlangReflectionGeneric* func) { DeclRefBase* declBase = (DeclRefBase*)func; @@ -785,6 +801,27 @@ SLANG_API SlangResult spReflectionType_GetFullName(SlangReflectionType* inType, return SLANG_OK; } +SlangReflectionFunction* tryConvertExprToFunctionReflection(ASTBuilder* astBuilder, Expr* expr) +{ + if (auto declRefExpr = as(expr)) + { + auto declRef = declRefExpr->declRef; + if (auto genericDeclRef = declRef.as()) + { + auto innerDeclRef = substituteDeclRef( + SubstitutionSet(genericDeclRef), astBuilder, genericDeclRef.getDecl()->inner); + declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, innerDeclRef); + } + + if (auto funcDeclRef = declRef.as()) + return convert(funcDeclRef); + } + else if (auto overloadedExpr = as(expr)) + return convert(overloadedExpr); + + return nullptr; +} + SLANG_API SlangReflectionFunction* spReflection_FindFunctionByName(SlangReflection* reflection, char const* name) { auto programLayout = convert(reflection); @@ -800,17 +837,9 @@ SLANG_API SlangReflectionFunction* spReflection_FindFunctionByName(SlangReflecti auto astBuilder = program->getLinkage()->getASTBuilder(); try { - auto result = program->findDeclFromString(name, &sink); - - if (auto genericDeclRef = result.as()) - { - auto innerDeclRef = substituteDeclRef( - SubstitutionSet(genericDeclRef), astBuilder, genericDeclRef.getDecl()->inner); - result = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, innerDeclRef); - } - - if (auto funcDeclRef = result.as()) - return convert(funcDeclRef); + return tryConvertExprToFunctionReflection( + astBuilder, + program->findDeclFromString(name, &sink)); } catch (...) { @@ -828,12 +857,13 @@ SLANG_API SlangReflectionFunction* spReflection_FindFunctionByNameInType(SlangRe Slang::DiagnosticSink sink( programLayout->getTargetReq()->getLinkage()->getSourceManager(), Lexer::sourceLocationLexer); - + + auto astBuilder = program->getLinkage()->getASTBuilder(); + try { auto result = program->findDeclFromStringInType(type, name, LookupMask::Function, &sink); - if (auto funcDeclRef = result.as()) - return convert(funcDeclRef); + return tryConvertExprToFunctionReflection(astBuilder, result); } catch (...) { @@ -855,8 +885,11 @@ SLANG_API SlangReflectionVariable* spReflection_FindVarByNameInType(SlangReflect try { auto result = program->findDeclFromStringInType(type, name, LookupMask::Value, &sink); - if (auto varDeclRef = result.as()) - return convert(varDeclRef.as()); + if (auto declRefExpr = as(result)) + { + if (auto varDeclRef = declRefExpr->declRef.as()) + return convert(varDeclRef.as()); + } } catch (...) { @@ -3009,21 +3042,23 @@ SLANG_API SlangStage spReflectionVariableLayout_getStage( SLANG_API SlangReflectionDecl* spReflectionFunction_asDecl(SlangReflectionFunction* inFunc) { - auto func = convert(inFunc); + auto func = convertToFunc(inFunc); if (!func) return nullptr; + return (SlangReflectionDecl*)func.getDecl(); } SLANG_API char const* spReflectionFunction_GetName(SlangReflectionFunction* inFunc) { - auto func = convert(inFunc); + auto func = convertToFunc(inFunc); if (!func) return nullptr; + return getText(func.getDecl()->getName()).getBuffer(); } SLANG_API SlangReflectionType* spReflectionFunction_GetResultType(SlangReflectionFunction* inFunc) { - auto func = convert(inFunc); + auto func = convertToFunc(inFunc); if (!func) return nullptr; auto rawType = func.getDecl()->returnType.type; @@ -3034,7 +3069,9 @@ SLANG_API SlangReflectionType* spReflectionFunction_GetResultType(SlangReflectio SLANG_API SlangReflectionModifier* spReflectionFunction_FindModifier(SlangReflectionFunction* inFunc, SlangModifierID modifierID) { - auto funcDeclRef = convert(inFunc); + auto funcDeclRef = convertToFunc(inFunc); + if (!funcDeclRef) return nullptr; + auto varRefl = convert(funcDeclRef.as()); if (!varRefl) return nullptr; @@ -3043,35 +3080,38 @@ SLANG_API SlangReflectionModifier* spReflectionFunction_FindModifier(SlangReflec SLANG_API unsigned int spReflectionFunction_GetUserAttributeCount(SlangReflectionFunction* inFunc) { - auto func = convert(inFunc); + auto func = convertToFunc(inFunc); if (!func) return 0; + return getUserAttributeCount(func.getDecl()); } SLANG_API SlangReflectionUserAttribute* spReflectionFunction_GetUserAttribute(SlangReflectionFunction* inFunc, unsigned int index) { - auto func = convert(inFunc); + auto func = convertToFunc(inFunc); if (!func) return nullptr; return getUserAttributeByIndex(func.getDecl(), index); } SLANG_API SlangReflectionUserAttribute* spReflectionFunction_FindUserAttributeByName(SlangReflectionFunction* inFunc, SlangSession* session, char const* name) { - auto func = convert(inFunc); + auto func = convertToFunc(inFunc); if (!func) return nullptr; + return findUserAttributeByName(asInternal(session), func.getDecl(), name); } SLANG_API unsigned int spReflectionFunction_GetParameterCount(SlangReflectionFunction* inFunc) { - auto func = convert(inFunc); + auto func = convertToFunc(inFunc); if (!func) return 0; + return (unsigned int)func.getDecl()->getParameters().getCount(); } SLANG_API SlangReflectionVariable* spReflectionFunction_GetParameter(SlangReflectionFunction* inFunc, unsigned int index) { - auto func = convert(inFunc); + auto func = convertToFunc(inFunc); if (!func) return nullptr; auto astBuilder = getModule(func.getDecl())->getLinkage()->getASTBuilder(); @@ -3081,13 +3121,16 @@ SLANG_API SlangReflectionVariable* spReflectionFunction_GetParameter(SlangReflec SLANG_API SlangReflectionGeneric* spReflectionFunction_GetGenericContainer(SlangReflectionFunction* func) { - auto declRef = convert(func); + auto declRef = convertToFunc(func); + if (!declRef) + return nullptr; + return convertDeclToGeneric(getInnermostGenericParent(declRef)); } SLANG_API SlangReflectionFunction* spReflectionFunction_applySpecializations(SlangReflectionFunction* func, SlangReflectionGeneric* generic) { - auto declRef = convert(func); + auto declRef = convertToFunc(func); auto genericDeclRef = convertGenericToDeclRef(generic); if (!declRef || !genericDeclRef) return nullptr; @@ -3103,12 +3146,27 @@ SLANG_API SlangReflectionFunction* spReflectionFunction_specializeWithArgTypes( SlangInt argTypeCount, SlangReflectionType* const* argTypes) { - auto declRef = convert(func); - if (!declRef) + Linkage* linkage = nullptr; + Expr* funcExpr = nullptr; + + if (auto funcDeclRef = convertToFunc(func)) + { + linkage = getModule(funcDeclRef.getDecl())->getLinkage(); + auto declRefExpr = linkage->getASTBuilder()->create(); + declRefExpr->declRef = funcDeclRef; + funcExpr = declRefExpr; + } + else if (auto overloadedExpr = convertToOverloadedFunc(func)) + { + linkage = getModule(overloadedExpr->lookupResult2.items[0].declRef.getDecl())->getLinkage(); + funcExpr = overloadedExpr; + } + else + { return nullptr; - - - auto linkage = getModule(declRef.getDecl())->getLinkage(); + } + + auto astBuilder = linkage->getASTBuilder(); List argTypeList; for (SlangInt ii = 0; ii < argTypeCount; ++ii) @@ -3120,7 +3178,7 @@ SLANG_API SlangReflectionFunction* spReflectionFunction_specializeWithArgTypes( try { DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer); - return convert(linkage->specializeWithArgTypes(declRef, argTypeList, &sink).as()); + return convert(linkage->specializeWithArgTypes(funcExpr, argTypeList, &sink).as()); } catch (...) { @@ -3128,6 +3186,45 @@ SLANG_API SlangReflectionFunction* spReflectionFunction_specializeWithArgTypes( } } +SLANG_API bool spReflectionFunction_isOverloaded( + SlangReflectionFunction* func) +{ + return (convertToOverloadedFunc(func) != nullptr); +} + +SLANG_API unsigned int spReflectionFunction_getOverloadCount( + SlangReflectionFunction* func) +{ + auto overloadedFunc = convertToOverloadedFunc(func); + if (!overloadedFunc) return 1; + + return (unsigned int) overloadedFunc->lookupResult2.items.getCount(); +} + +SLANG_API SlangReflectionFunction* spReflectionFunction_getOverload( + SlangReflectionFunction* func, + unsigned int index) +{ + auto overloadedFunc = convertToOverloadedFunc(func); + if (!overloadedFunc) return nullptr; + + auto declRef = overloadedFunc->lookupResult2.items[index].declRef; + if (auto funcDeclRef = declRef.as()) + { + return convert(declRef.as()); + } + else if (auto genericDeclRef = declRef.as()) + { + auto astBuilder = getModule(genericDeclRef.getDecl())->getLinkage()->getASTBuilder(); + auto innerDeclRef = substituteDeclRef( + SubstitutionSet(genericDeclRef), astBuilder, genericDeclRef.getDecl()->inner); + return convert( + createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, innerDeclRef).as()); + } + + return nullptr; +} + // Abstract decl reflection SLANG_API unsigned int spReflectionDecl_getChildrenCount(SlangReflectionDecl* parentDecl) diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 6c152cdddc..fa35a382f3 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -1370,7 +1370,7 @@ DeclRef getGenericParentDeclRef( } DeclRef Linkage::specializeWithArgTypes( - DeclRef funcDeclRef, + Expr* funcExpr, List argTypes, DiagnosticSink* sink) { @@ -1378,6 +1378,16 @@ DeclRef Linkage::specializeWithArgTypes( visitor = visitor.withSink(sink); ASTBuilder* astBuilder = getASTBuilder(); + + if (auto declRefFuncExpr = as(funcExpr)) + { + auto genericDeclRefExpr = astBuilder->create(); + genericDeclRefExpr->declRef = getGenericParentDeclRef( + getASTBuilder(), + &visitor, + declRefFuncExpr->declRef); + funcExpr = genericDeclRefExpr; + } List argExprs; for (SlangInt aa = 0; aa < argTypes.getCount(); ++aa) @@ -1394,10 +1404,7 @@ DeclRef Linkage::specializeWithArgTypes( // Construct invoke expr. auto invokeExpr = astBuilder->create(); - auto declRefExpr = astBuilder->create(); - - declRefExpr->declRef = getGenericParentDeclRef(getASTBuilder(), &visitor, funcDeclRef); - invokeExpr->functionExpr = declRefExpr; + invokeExpr->functionExpr = funcExpr; invokeExpr->arguments = argExprs; auto checkedInvokeExpr = visitor.CheckInvokeExprWithCheckedOperands(invokeExpr); @@ -2331,14 +2338,14 @@ Type* ComponentType::getTypeFromString( return type; } -DeclRef ComponentType::findDeclFromString( +Expr* ComponentType::findDeclFromString( String const& name, DiagnosticSink* sink) { // If we've looked up this type name before, // then we can re-use it. // - DeclRef result; + Expr* result; if (m_decls.tryGetValue(name, result)) return result; @@ -2369,34 +2376,26 @@ DeclRef ComponentType::findDeclFromString( SemanticsVisitor visitor(context); - auto checkedExpr = visitor.CheckExpr(expr); - if (auto declRefExpr = as(checkedExpr)) - { - result = declRefExpr->declRef; - } - else if (auto overloadedExpr = as(checkedExpr)) + auto checkedExpr = visitor.CheckTerm(expr); + + if (as(checkedExpr) || as(checkedExpr)) { - sink->diagnose(SourceLoc(), Diagnostics::ambiguousReference, name); - for (auto candidate : overloadedExpr->lookupResult2) - { - sink->diagnose(candidate.declRef.getDecl(), Diagnostics::overloadCandidate, candidate.declRef); - } + result = checkedExpr; } + m_decls[name] = result; return result; } -DeclRef ComponentType::findDeclFromStringInType( +Expr* ComponentType::findDeclFromStringInType( Type* type, String const& name, LookupMask mask, DiagnosticSink* sink) { - DeclRef result; - // Only look up in the type if it is a DeclRefType if (!as(type)) - return DeclRef(); + return nullptr; // TODO(JS): For now just used the linkages ASTBuilder to keep on scope // @@ -2433,7 +2432,7 @@ DeclRef ComponentType::findDeclFromStringInType( } if (!as(expr)) - return result; + return nullptr; auto rs = astBuilder->create(); auto typeExpr = astBuilder->create(); @@ -2453,20 +2452,18 @@ DeclRef ComponentType::findDeclFromStringInType( auto checkedTerm = visitor.CheckTerm(expr); auto resolvedTerm = visitor.maybeResolveOverloadedExpr(checkedTerm, mask, sink); + - if (auto declRefExpr = as(resolvedTerm)) + if (auto overloadedExpr = as(resolvedTerm)) { - result = declRefExpr->declRef; + return overloadedExpr; } - - if (auto genericDeclRef = result.as()) - { - result = createDefaultSubstitutionsIfNeeded( - astBuilder, &visitor, DeclRef(genericDeclRef.getDecl()->inner)); - result = substituteDeclRef(SubstitutionSet(genericDeclRef), astBuilder, result); + if (auto declRefExpr = as(resolvedTerm)) + { + return declRefExpr; } - return result; + return nullptr; } bool ComponentType::isSubType(Type* subType, Type* superType) diff --git a/tools/slang-unit-test/unit-test-function-reflection.cpp b/tools/slang-unit-test/unit-test-function-reflection.cpp index ddcb81adc4..2b52a86917 100644 --- a/tools/slang-unit-test/unit-test-function-reflection.cpp +++ b/tools/slang-unit-test/unit-test-function-reflection.cpp @@ -36,6 +36,9 @@ SLANG_UNIT_TEST(functionReflection) { return pos; } + + float foo(float x) { return x; } + float foo(float x, uint i) { return x + i; } )"; auto moduleName = "moduleG" + String(Process::getId()); @@ -90,5 +93,34 @@ SLANG_UNIT_TEST(functionReflection) SLANG_CHECK(result == SLANG_OK); SLANG_CHECK(val == 1024); SLANG_CHECK(funcReflection->findUserAttributeByName(globalSession.get(), "MyFuncProperty") == userAttribute); + + // Check overloaded method resolution + auto overloadReflection = module->getLayout()->findFunctionByName("foo"); + SLANG_CHECK(overloadReflection != nullptr); + SLANG_CHECK(overloadReflection->isOverloaded() == true); + SLANG_CHECK(overloadReflection->getOverloadCount() == 2); + + auto firstOverload = overloadReflection->getOverload(0); + SLANG_CHECK(firstOverload != nullptr); + SLANG_CHECK(UnownedStringSlice(firstOverload->getName()) == "foo"); + SLANG_CHECK(firstOverload->getParameterCount() == 2); + SLANG_CHECK(UnownedStringSlice(firstOverload->getParameterByIndex(0)->getName()) == "x"); + SLANG_CHECK(getTypeFullName(firstOverload->getParameterByIndex(0)->getType()) == "float"); + SLANG_CHECK(UnownedStringSlice(firstOverload->getParameterByIndex(1)->getName()) == "i"); + SLANG_CHECK(getTypeFullName(firstOverload->getParameterByIndex(1)->getType()) == "uint"); + + auto secondOverload = overloadReflection->getOverload(1); + SLANG_CHECK(secondOverload != nullptr); + SLANG_CHECK(UnownedStringSlice(secondOverload->getName()) == "foo"); + SLANG_CHECK(secondOverload->getParameterCount() == 1); + SLANG_CHECK(UnownedStringSlice(secondOverload->getParameterByIndex(0)->getName()) == "x"); + + // Check overload resolution via argument types. + slang::TypeReflection* argTypes[] = { + module->getLayout()->findTypeByName("float"), + module->getLayout()->findTypeByName("uint"), + }; + auto resolvedFunctionReflection = overloadReflection->specializeWithArgTypes(2, argTypes); + SLANG_CHECK(resolvedFunctionReflection == firstOverload); } From b69b9d195f63dea53d26e54814e367e3175f674e Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru Date: Wed, 18 Sep 2024 17:00:47 -0400 Subject: [PATCH 2/3] Update slang-reflection-api.cpp --- source/slang/slang-reflection-api.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index 1342666439..b6fc059867 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -3165,8 +3165,6 @@ SLANG_API SlangReflectionFunction* spReflectionFunction_specializeWithArgTypes( { return nullptr; } - - auto astBuilder = linkage->getASTBuilder(); List argTypeList; for (SlangInt ii = 0; ii < argTypeCount; ++ii) From 5a1fcfe8ae50021c1d6832af161ea5f71a64aa1d Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru Date: Wed, 18 Sep 2024 17:10:34 -0400 Subject: [PATCH 3/3] Update slang.cpp --- source/slang/slang.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index fa35a382f3..dc5f9a7550 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -2345,7 +2345,7 @@ Expr* ComponentType::findDeclFromString( // If we've looked up this type name before, // then we can re-use it. // - Expr* result; + Expr* result = nullptr; if (m_decls.tryGetValue(name, result)) return result;