Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow lookups of overloaded methods. #5110

Merged
merged 6 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions include/slang.h
Original file line number Diff line number Diff line change
Expand Up @@ -2592,6 +2592,9 @@ extern "C"
SLANG_API SlangReflectionGeneric* spReflectionFunction_GetGenericContainer(SlangReflectionFunction* func);
SLANG_API SlangReflectionFunction* spReflectionFunction_applySpecializations(SlangReflectionFunction* func, SlangReflectionGeneric* generic);
SLANG_API SlangReflectionFunction* spReflectionFunction_specializeWithArgTypes(SlangReflectionFunction* func, SlangInt argTypeCount, SlangReflectionType* const* argTypes);
SLANG_API bool spReflectionFunction_isOverloaded(SlangReflectionFunction* func);
SLANG_API unsigned int spReflectionFunction_getOverloadCount(SlangReflectionFunction* func);
SLANG_API SlangReflectionFunction* spReflectionFunction_getOverload(SlangReflectionFunction* func, unsigned int index);

// Abstract Decl Reflection

Expand Down Expand Up @@ -3595,6 +3598,21 @@ namespace slang
{
return (FunctionReflection*)spReflectionFunction_specializeWithArgTypes((SlangReflectionFunction*)this, argCount, (SlangReflectionType* const*)types);
}

bool isOverloaded()
{
return spReflectionFunction_isOverloaded((SlangReflectionFunction*)this);
}

unsigned int getOverloadCount()
{
return spReflectionFunction_getOverloadCount((SlangReflectionFunction*)this);
}

FunctionReflection* getOverload(unsigned int index)
{
return (FunctionReflection*)spReflectionFunction_getOverload((SlangReflectionFunction*)this, index);
}
};

struct GenericReflection
Expand Down
13 changes: 11 additions & 2 deletions source/slang/slang-check-shader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,17 @@ namespace Slang
Name* name,
DiagnosticSink* sink)
{
auto declRef = translationUnit->findDeclFromString(getText(name), sink);
FuncDecl* entryPointFuncDecl = declRef.as<FuncDecl>().getDecl();
FuncDecl* entryPointFuncDecl = nullptr;

auto expr = translationUnit->findDeclFromString(getText(name), sink);
if (auto declRefExpr = as<DeclRefExpr>(expr))
{
auto declRef = declRefExpr->declRef;
entryPointFuncDecl = declRef.as<FuncDecl>().getDecl();

if (entryPointFuncDecl && getModule(entryPointFuncDecl) != translationUnit)
entryPointFuncDecl = nullptr;
}

if (entryPointFuncDecl && getModule(entryPointFuncDecl) != translationUnit)
entryPointFuncDecl = nullptr;
Expand Down
8 changes: 4 additions & 4 deletions source/slang/slang-compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -420,11 +420,11 @@ namespace Slang
String const& typeStr,
DiagnosticSink* sink);

DeclRef<Decl> findDeclFromString(
Expr* findDeclFromString(
String const& name,
DiagnosticSink* sink);

DeclRef<Decl> findDeclFromStringInType(
Expr* findDeclFromStringInType(
Type* type,
String const& name,
LookupMask mask,
Expand Down Expand Up @@ -576,7 +576,7 @@ namespace Slang
Dictionary<String, Type*> m_types;

// Any decls looked up dynamically using `findDeclFromString`.
Dictionary<String, DeclRef<Decl>> m_decls;
Dictionary<String, Expr*> m_decls;

Scope* m_lookupScope = nullptr;
std::unique_ptr<Dictionary<String, IntVal*>> m_mapMangledNameToIntVal;
Expand Down Expand Up @@ -2174,7 +2174,7 @@ namespace Slang
DiagnosticSink* sink);

DeclRef<Decl> specializeWithArgTypes(
DeclRef<Decl> funcDeclRef,
Expr* funcExpr,
List<Type*> argTypes,
DiagnosticSink* sink);

Expand Down
167 changes: 131 additions & 36 deletions source/slang/slang-reflection-api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,33 @@ static inline SlangReflectionVariable* convert(DeclRef<Decl> var)
return (SlangReflectionVariable*) var.declRefBase;
}

static inline DeclRef<FunctionDeclBase> convert(SlangReflectionFunction* func)
static inline DeclRef<FunctionDeclBase> convertToFunc(SlangReflectionFunction* func)
{
DeclRefBase* declBase = (DeclRefBase*)func;
return DeclRef<FunctionDeclBase>(declBase);
NodeBase* nodeBase = (NodeBase*)func;
if (DeclRefBase* declRefBase = as<DeclRefBase>(nodeBase))
{
return DeclRef<FunctionDeclBase>(declRefBase);
}

return DeclRef<FunctionDeclBase>();
}

static inline OverloadedExpr* convertToOverloadedFunc(SlangReflectionFunction* func)
{
NodeBase* nodeBase = (NodeBase*)func;
return as<OverloadedExpr>(nodeBase);
}

static inline SlangReflectionFunction* convert(DeclRef<FunctionDeclBase> func)
{
return (SlangReflectionFunction*)func.declRefBase;
}

static inline SlangReflectionFunction* convert(OverloadedExpr* overloadedFunc)
{
return (SlangReflectionFunction*)overloadedFunc;
}

static inline DeclRef<Decl> convertGenericToDeclRef(SlangReflectionGeneric* func)
{
DeclRefBase* declBase = (DeclRefBase*)func;
Expand Down Expand Up @@ -785,6 +801,27 @@ SLANG_API SlangResult spReflectionType_GetFullName(SlangReflectionType* inType,
return SLANG_OK;
}

SlangReflectionFunction* tryConvertExprToFunctionReflection(ASTBuilder* astBuilder, Expr* expr)
{
if (auto declRefExpr = as<DeclRefExpr>(expr))
{
auto declRef = declRefExpr->declRef;
if (auto genericDeclRef = declRef.as<GenericDecl>())
{
auto innerDeclRef = substituteDeclRef(
SubstitutionSet(genericDeclRef), astBuilder, genericDeclRef.getDecl()->inner);
declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, innerDeclRef);
}

if (auto funcDeclRef = declRef.as<FunctionDeclBase>())
return convert(funcDeclRef);
}
else if (auto overloadedExpr = as<OverloadedExpr>(expr))
return convert(overloadedExpr);

return nullptr;
}

SLANG_API SlangReflectionFunction* spReflection_FindFunctionByName(SlangReflection* reflection, char const* name)
{
auto programLayout = convert(reflection);
Expand All @@ -800,17 +837,9 @@ SLANG_API SlangReflectionFunction* spReflection_FindFunctionByName(SlangReflecti
auto astBuilder = program->getLinkage()->getASTBuilder();
try
{
auto result = program->findDeclFromString(name, &sink);

if (auto genericDeclRef = result.as<GenericDecl>())
{
auto innerDeclRef = substituteDeclRef(
SubstitutionSet(genericDeclRef), astBuilder, genericDeclRef.getDecl()->inner);
result = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, innerDeclRef);
}

if (auto funcDeclRef = result.as<FunctionDeclBase>())
return convert(funcDeclRef);
return tryConvertExprToFunctionReflection(
astBuilder,
program->findDeclFromString(name, &sink));
}
catch (...)
{
Expand All @@ -828,12 +857,13 @@ SLANG_API SlangReflectionFunction* spReflection_FindFunctionByNameInType(SlangRe
Slang::DiagnosticSink sink(
programLayout->getTargetReq()->getLinkage()->getSourceManager(),
Lexer::sourceLocationLexer);


auto astBuilder = program->getLinkage()->getASTBuilder();

try
{
auto result = program->findDeclFromStringInType(type, name, LookupMask::Function, &sink);
if (auto funcDeclRef = result.as<FunctionDeclBase>())
return convert(funcDeclRef);
return tryConvertExprToFunctionReflection(astBuilder, result);
}
catch (...)
{
Expand All @@ -855,8 +885,11 @@ SLANG_API SlangReflectionVariable* spReflection_FindVarByNameInType(SlangReflect
try
{
auto result = program->findDeclFromStringInType(type, name, LookupMask::Value, &sink);
if (auto varDeclRef = result.as<VarDeclBase>())
return convert(varDeclRef.as<Decl>());
if (auto declRefExpr = as<DeclRefExpr>(result))
{
if (auto varDeclRef = declRefExpr->declRef.as<VarDeclBase>())
return convert(varDeclRef.as<Decl>());
}
}
catch (...)
{
Expand Down Expand Up @@ -3009,21 +3042,23 @@ SLANG_API SlangStage spReflectionVariableLayout_getStage(

SLANG_API SlangReflectionDecl* spReflectionFunction_asDecl(SlangReflectionFunction* inFunc)
{
auto func = convert(inFunc);
auto func = convertToFunc(inFunc);
if (!func) return nullptr;

return (SlangReflectionDecl*)func.getDecl();
}

SLANG_API char const* spReflectionFunction_GetName(SlangReflectionFunction* inFunc)
{
auto func = convert(inFunc);
auto func = convertToFunc(inFunc);
if (!func) return nullptr;

return getText(func.getDecl()->getName()).getBuffer();
}

SLANG_API SlangReflectionType* spReflectionFunction_GetResultType(SlangReflectionFunction* inFunc)
{
auto func = convert(inFunc);
auto func = convertToFunc(inFunc);
if (!func) return nullptr;

auto rawType = func.getDecl()->returnType.type;
Expand All @@ -3034,7 +3069,9 @@ SLANG_API SlangReflectionType* spReflectionFunction_GetResultType(SlangReflectio

SLANG_API SlangReflectionModifier* spReflectionFunction_FindModifier(SlangReflectionFunction* inFunc, SlangModifierID modifierID)
{
auto funcDeclRef = convert(inFunc);
auto funcDeclRef = convertToFunc(inFunc);
if (!funcDeclRef) return nullptr;

auto varRefl = convert(funcDeclRef.as<Decl>());
if (!varRefl) return nullptr;

Expand All @@ -3043,35 +3080,38 @@ SLANG_API SlangReflectionModifier* spReflectionFunction_FindModifier(SlangReflec

SLANG_API unsigned int spReflectionFunction_GetUserAttributeCount(SlangReflectionFunction* inFunc)
{
auto func = convert(inFunc);
auto func = convertToFunc(inFunc);
if (!func) return 0;

return getUserAttributeCount(func.getDecl());
}

SLANG_API SlangReflectionUserAttribute* spReflectionFunction_GetUserAttribute(SlangReflectionFunction* inFunc, unsigned int index)
{
auto func = convert(inFunc);
auto func = convertToFunc(inFunc);
if (!func) return nullptr;
return getUserAttributeByIndex(func.getDecl(), index);
}

SLANG_API SlangReflectionUserAttribute* spReflectionFunction_FindUserAttributeByName(SlangReflectionFunction* inFunc, SlangSession* session, char const* name)
{
auto func = convert(inFunc);
auto func = convertToFunc(inFunc);
if (!func) return nullptr;

return findUserAttributeByName(asInternal(session), func.getDecl(), name);
}

SLANG_API unsigned int spReflectionFunction_GetParameterCount(SlangReflectionFunction* inFunc)
{
auto func = convert(inFunc);
auto func = convertToFunc(inFunc);
if (!func) return 0;

return (unsigned int)func.getDecl()->getParameters().getCount();
}

SLANG_API SlangReflectionVariable* spReflectionFunction_GetParameter(SlangReflectionFunction* inFunc, unsigned int index)
{
auto func = convert(inFunc);
auto func = convertToFunc(inFunc);
if (!func) return nullptr;

auto astBuilder = getModule(func.getDecl())->getLinkage()->getASTBuilder();
Expand All @@ -3081,13 +3121,16 @@ SLANG_API SlangReflectionVariable* spReflectionFunction_GetParameter(SlangReflec

SLANG_API SlangReflectionGeneric* spReflectionFunction_GetGenericContainer(SlangReflectionFunction* func)
{
auto declRef = convert(func);
auto declRef = convertToFunc(func);
if (!declRef)
return nullptr;

return convertDeclToGeneric(getInnermostGenericParent(declRef));
}

SLANG_API SlangReflectionFunction* spReflectionFunction_applySpecializations(SlangReflectionFunction* func, SlangReflectionGeneric* generic)
{
auto declRef = convert(func);
auto declRef = convertToFunc(func);
auto genericDeclRef = convertGenericToDeclRef(generic);
if (!declRef || !genericDeclRef)
return nullptr;
Expand All @@ -3103,12 +3146,25 @@ SLANG_API SlangReflectionFunction* spReflectionFunction_specializeWithArgTypes(
SlangInt argTypeCount,
SlangReflectionType* const* argTypes)
{
auto declRef = convert(func);
if (!declRef)
Linkage* linkage = nullptr;
Expr* funcExpr = nullptr;

if (auto funcDeclRef = convertToFunc(func))
{
linkage = getModule(funcDeclRef.getDecl())->getLinkage();
auto declRefExpr = linkage->getASTBuilder()->create<DeclRefExpr>();
declRefExpr->declRef = funcDeclRef;
funcExpr = declRefExpr;
}
else if (auto overloadedExpr = convertToOverloadedFunc(func))
{
linkage = getModule(overloadedExpr->lookupResult2.items[0].declRef.getDecl())->getLinkage();
funcExpr = overloadedExpr;
}
else
{
return nullptr;


auto linkage = getModule(declRef.getDecl())->getLinkage();
}

List<Type*> argTypeList;
for (SlangInt ii = 0; ii < argTypeCount; ++ii)
Expand All @@ -3120,14 +3176,53 @@ SLANG_API SlangReflectionFunction* spReflectionFunction_specializeWithArgTypes(
try
{
DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer);
return convert(linkage->specializeWithArgTypes(declRef, argTypeList, &sink).as<FunctionDeclBase>());
return convert(linkage->specializeWithArgTypes(funcExpr, argTypeList, &sink).as<FunctionDeclBase>());
}
catch (...)
{
return nullptr;
}
}

SLANG_API bool spReflectionFunction_isOverloaded(
SlangReflectionFunction* func)
{
return (convertToOverloadedFunc(func) != nullptr);
}

SLANG_API unsigned int spReflectionFunction_getOverloadCount(
SlangReflectionFunction* func)
{
auto overloadedFunc = convertToOverloadedFunc(func);
if (!overloadedFunc) return 1;

return (unsigned int) overloadedFunc->lookupResult2.items.getCount();
}

SLANG_API SlangReflectionFunction* spReflectionFunction_getOverload(
SlangReflectionFunction* func,
unsigned int index)
{
auto overloadedFunc = convertToOverloadedFunc(func);
if (!overloadedFunc) return nullptr;

auto declRef = overloadedFunc->lookupResult2.items[index].declRef;
if (auto funcDeclRef = declRef.as<FunctionDeclBase>())
{
return convert(declRef.as<FunctionDeclBase>());
}
else if (auto genericDeclRef = declRef.as<GenericDecl>())
{
auto astBuilder = getModule(genericDeclRef.getDecl())->getLinkage()->getASTBuilder();
auto innerDeclRef = substituteDeclRef(
SubstitutionSet(genericDeclRef), astBuilder, genericDeclRef.getDecl()->inner);
return convert(
createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, innerDeclRef).as<FunctionDeclBase>());
}

return nullptr;
}

// Abstract decl reflection

SLANG_API unsigned int spReflectionDecl_getChildrenCount(SlangReflectionDecl* parentDecl)
Expand Down
Loading
Loading