Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add API method to specialize function reference with argument types #4966

Merged
Merged
6 changes: 6 additions & 0 deletions include/slang.h
Original file line number Diff line number Diff line change
Expand Up @@ -2589,6 +2589,7 @@ extern "C"
SLANG_API SlangReflectionType* spReflectionFunction_GetResultType(SlangReflectionFunction* func);
SLANG_API SlangReflectionGeneric* spReflectionFunction_GetGenericContainer(SlangReflectionFunction* func);
SLANG_API SlangReflectionFunction* spReflectionFunction_applySpecializations(SlangReflectionFunction* func, SlangReflectionGeneric* generic);
SLANG_API SlangReflectionFunction* spReflectionFunction_specializeWithArgTypes(SlangReflectionFunction* func, SlangInt argTypeCount, SlangReflectionType* const* argTypes);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does apply specializations do? Is it just get default specialized DeclRef? This imo is an internal detail in generic representation and should not be exposed in the api. What are the scenarios that this is needed from the user?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea here is that the generic parameters & their arguments (if any) are represented through a SlangReflectionGeneric. The fact that the generic is represented as a parent in the hierarchy is not directly exposed to the user. It's just that every SlangReflectionFunction/SlangReflectionType/SlangReflectionVariable declaration may have an associated generic which can be obtained using getGenericContainer. Further, each declaration can be nested inside another one, so you can traverse all generic parameters affecting a reflected function/variable/type by using getOuterGenericContainer repeatedly.

In our case, we're actually not referencing the parent generic decl directly, the decl-ref for SlangReflectionFunction and SlangReflectionGeneric are actually the same decl-ref, but they expose different methods (the generic reflection is exclusively for looking through generic parameters, and is common to functions, variables & types)

Each SlangReflectionGeneric effectively represents a set of specializations. applySpecializations creates a specialized decl-ref out of an unspecialized one by applying the bag of specializations to the referenced decl-ref.
In this case, I can lookup an unspecialized function, extract the generic-reflection, set some parameters on the generic-reflection, then apply it to the function-reflection to get the specialized version.

I couldn't come up with a sleeker way to do this, given that we can have many levels of generics, all of which may or may not already have specializations. This idea of putting specializations behind a special generic-reflection is to just avoid having to duplicate the generic-specific methods for all the different types of reflections.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. As long as we make sure the API doesn't assume generic decl is the parent of the decl itself, and just treat generic as a different aspect of the same decl, then we should be fine.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should be careful on the ast - children enumeration logic to not list the generic decl as the children, but instead the inner decl.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is currently the case. If the decl-walking API hits a generic decl, the user must cast it to SlangReflectionGeneric and call getInnerDecl() to get the inner node.


// Abstract Decl Reflection

Expand Down Expand Up @@ -3587,6 +3588,11 @@ namespace slang
{
return (FunctionReflection*)spReflectionFunction_applySpecializations((SlangReflectionFunction*)this, (SlangReflectionGeneric*)generic);
}

FunctionReflection* specializeWithArgTypes(unsigned int argCount, TypeReflection* const* types)
{
return (FunctionReflection*)spReflectionFunction_specializeWithArgTypes((SlangReflectionFunction*)this, argCount, (SlangReflectionType* const*)types);
}
};

struct GenericReflection
Expand Down
2 changes: 1 addition & 1 deletion source/slang/slang-check-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ namespace Slang
};

/// Shared state for a semantics-checking session.
struct SharedSemanticsContext
struct SharedSemanticsContext : public RefObject
{
Linkage* m_linkage = nullptr;

Expand Down
14 changes: 13 additions & 1 deletion source/slang/slang-compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@
namespace Slang
{
struct PathInfo;
struct IncludeHandler;
struct IncludeHandler;
struct SharedSemanticsContext;

class ProgramLayout;
class PtrType;
class TargetProgram;
Expand Down Expand Up @@ -2170,6 +2172,11 @@ namespace Slang
DeclRef<Decl> declRef,
List<Expr*> argExprs,
DiagnosticSink* sink);

DeclRef<Decl> specializeWithArgTypes(
DeclRef<Decl> funcDeclRef,
List<Type*> argTypes,
DiagnosticSink* sink);

DiagnosticSink::Flags diagnosticSinkFlags = 0;

Expand All @@ -2183,6 +2190,9 @@ namespace Slang
m_retainedSession = nullptr;
}

// Get shared semantics information for reflection purposes.
SharedSemanticsContext* getSemanticsForReflection();

private:
/// The global Slang library session that this linkage is a child of
Session* m_session = nullptr;
Expand Down Expand Up @@ -2236,6 +2246,8 @@ namespace Slang

List<Type*> m_specializedTypes;

RefPtr<SharedSemanticsContext> m_semanticsForReflection;

};

/// Shared functionality between front- and back-end compile requests.
Expand Down
69 changes: 55 additions & 14 deletions source/slang/slang-reflection-api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -797,9 +797,18 @@ SLANG_API SlangReflectionFunction* spReflection_FindFunctionByName(SlangReflecti
programLayout->getTargetReq()->getLinkage()->getSourceManager(),
Lexer::sourceLocationLexer);

auto astBuilder = program->getLinkage()->getASTBuilder();
try
{
auto result = program->findDeclFromString(name, &sink);

if (auto genericDeclRef = result.as<GenericDecl>())
{
auto innerDeclRef = substituteDeclRef(
SubstitutionSet(genericDeclRef), astBuilder, genericDeclRef.getDecl()->inner);
result = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, innerDeclRef);
}

if (auto funcDeclRef = result.as<FunctionDeclBase>())
return convert(funcDeclRef);
}
Expand Down Expand Up @@ -924,35 +933,36 @@ SLANG_API bool spReflection_isSubType(
}
}

SlangReflectionGeneric* getInnermostGenericParent(DeclRef<Decl> declRef)
DeclRef<Decl> getInnermostGenericParent(DeclRef<Decl> declRef)
{
auto decl = declRef.getDecl();
auto astBuilder = getModule(decl)->getLinkage()->getASTBuilder();
auto parentDecl = decl;
while(parentDecl)
{
if(parentDecl->parentDecl && as<GenericDecl>(parentDecl->parentDecl))
return convertDeclToGeneric(
substituteDeclRef(
return substituteDeclRef(
SubstitutionSet(declRef),
astBuilder,
createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, DeclRef(parentDecl))));
createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, DeclRef(parentDecl)));
parentDecl = parentDecl->parentDecl;
}

return nullptr;
return DeclRef<Decl>();
}

SLANG_API SlangReflectionGeneric* spReflectionType_GetGenericContainer(SlangReflectionType* type)
{
auto slangType = convert(type);
if (auto declRefType = as<DeclRefType>(slangType))
{
return getInnermostGenericParent(declRefType->getDeclRef());
return convertDeclToGeneric(
getInnermostGenericParent(declRefType->getDeclRef()));
}
else if (auto genericDeclRefType = as<GenericDeclRefType>(slangType))
{
return getInnermostGenericParent(genericDeclRefType->getDeclRef());
return convertDeclToGeneric(
getInnermostGenericParent(genericDeclRefType->getDeclRef()));
}

return nullptr;
Expand Down Expand Up @@ -2835,7 +2845,7 @@ SLANG_API bool spReflectionVariable_HasDefaultValue(SlangReflectionVariable* inV
SLANG_API SlangReflectionGeneric* spReflectionVariable_GetGenericContainer(SlangReflectionVariable* var)
{
auto declRef = convert(var);
return getInnermostGenericParent(declRef);
return convertDeclToGeneric(getInnermostGenericParent(declRef));
}

SLANG_API SlangReflectionVariable* spReflectionVariable_applySpecializations(SlangReflectionVariable* var, SlangReflectionGeneric* generic)
Expand Down Expand Up @@ -3072,7 +3082,7 @@ SLANG_API SlangReflectionVariable* spReflectionFunction_GetParameter(SlangReflec
SLANG_API SlangReflectionGeneric* spReflectionFunction_GetGenericContainer(SlangReflectionFunction* func)
{
auto declRef = convert(func);
return getInnermostGenericParent(declRef);
return convertDeclToGeneric(getInnermostGenericParent(declRef));
}

SLANG_API SlangReflectionFunction* spReflectionFunction_applySpecializations(SlangReflectionFunction* func, SlangReflectionGeneric* generic)
Expand All @@ -3088,6 +3098,36 @@ SLANG_API SlangReflectionFunction* spReflectionFunction_applySpecializations(Sla
return convert(substDeclRef.as<FunctionDeclBase>());
}

SLANG_API SlangReflectionFunction* spReflectionFunction_specializeWithArgTypes(
SlangReflectionFunction* func,
SlangInt argTypeCount,
SlangReflectionType* const* argTypes)
{
auto declRef = convert(func);
if (!declRef)
return nullptr;


auto linkage = getModule(declRef.getDecl())->getLinkage();

List<Type*> argTypeList;
for (SlangInt ii = 0; ii < argTypeCount; ++ii)
{
auto argType = convert(argTypes[ii]);
argTypeList.add(argType);
}

try
{
DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer);
return convert(linkage->specializeWithArgTypes(declRef, argTypeList, &sink).as<FunctionDeclBase>());
}
catch (...)
{
return nullptr;
}
}

// Abstract decl reflection

SLANG_API unsigned int spReflectionDecl_getChildrenCount(SlangReflectionDecl* parentDecl)
Expand Down Expand Up @@ -3329,11 +3369,12 @@ SLANG_API SlangReflectionGeneric* spReflectionGeneric_GetOuterGenericContainer(S

auto astBuilder = getModule(declRef.getDecl())->getLinkage()->getASTBuilder();

return getInnermostGenericParent(
substituteDeclRef(
SubstitutionSet(declRef),
astBuilder,
createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, DeclRef(declRef.getDecl()->parentDecl))));
return convertDeclToGeneric(
getInnermostGenericParent(
substituteDeclRef(
SubstitutionSet(declRef),
astBuilder,
createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, DeclRef(declRef.getDecl()->parentDecl)))));
}

SLANG_API SlangReflectionType* spReflectionGeneric_GetConcreteType(SlangReflectionGeneric* generic, SlangReflectionVariable* typeParam)
Expand Down
100 changes: 67 additions & 33 deletions source/slang/slang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
#include "slang-type-layout.h"
#include "slang-lookup.h"

#
#include "slang-options.h"

#include "slang-repro.h"
Expand Down Expand Up @@ -1069,8 +1068,12 @@ Linkage::Linkage(Session* session, ASTBuilder* astBuilder, Linkage* builtinLinka
for (const auto& nameToMod : builtinLinkage->mapNameToLoadedModules)
mapNameToLoadedModules.add(nameToMod);
}

m_semanticsForReflection = new SharedSemanticsContext(this, nullptr, nullptr);
}

SharedSemanticsContext* Linkage::getSemanticsForReflection() { return m_semanticsForReflection.get(); }

ISlangUnknown* Linkage::getInterface(const Guid& guid)
{
if(guid == ISlangUnknown::getTypeGuid() || guid == ISession::getTypeGuid())
Expand Down Expand Up @@ -1348,18 +1351,11 @@ SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::specializeType(
return asExternal(specializedType);
}


DeclRef<Decl> Linkage::specializeGeneric(
DeclRef<Decl> declRef,
List<Expr*> argExprs,
DiagnosticSink* sink)
DeclRef<GenericDecl> getGenericParentDeclRef(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can just use getInnermostGenericParent?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit different unfortunately.
This function returns the decl-ref of the generic-decl while getInnermostGenericParent returns the decl-ref of the inner decl of the closest generic-decl (this is so that substitutions are available)

ASTBuilder* astBuilder,
SemanticsVisitor* visitor,
DeclRef<Decl> declRef)
{
SLANG_AST_BUILDER_RAII(getASTBuilder());
SLANG_ASSERT(declRef);

SharedSemanticsContext sharedSemanticsContext(this, nullptr, sink);
SemanticsVisitor visitor(&sharedSemanticsContext);

// Create substituted parent decl ref.
auto decl = declRef.getDecl();

Expand All @@ -1369,9 +1365,58 @@ DeclRef<Decl> Linkage::specializeGeneric(
}

auto genericDecl = as<GenericDecl>(decl);
auto genericDeclRef = createDefaultSubstitutionsIfNeeded(getASTBuilder(), &visitor, DeclRef(genericDecl)).as<GenericDecl>();
genericDeclRef = substituteDeclRef(SubstitutionSet(declRef), getASTBuilder(), genericDeclRef).as<GenericDecl>();
auto genericDeclRef = createDefaultSubstitutionsIfNeeded(astBuilder, visitor, DeclRef(genericDecl)).as<GenericDecl>();
return substituteDeclRef(SubstitutionSet(declRef), astBuilder, genericDeclRef).as<GenericDecl>();
}

DeclRef<Decl> Linkage::specializeWithArgTypes(
DeclRef<Decl> funcDeclRef,
List<Type*> argTypes,
DiagnosticSink* sink)
{
SemanticsVisitor visitor(getSemanticsForReflection());
visitor = visitor.withSink(sink);

ASTBuilder* astBuilder = getASTBuilder();

List<Expr*> argExprs;
for (SlangInt aa = 0; aa < argTypes.getCount(); ++aa)
{
auto argType = argTypes[aa];

// Create an 'empty' expr with the given type. Ideally, the expression itself should not matter
// only its checked type.
//
auto argExpr = astBuilder->create<VarExpr>();
argExpr->type = argType;
argExprs.add(argExpr);
}

// Construct invoke expr.
auto invokeExpr = astBuilder->create<InvokeExpr>();
auto declRefExpr = astBuilder->create<DeclRefExpr>();

declRefExpr->declRef = getGenericParentDeclRef(getASTBuilder(), &visitor, funcDeclRef);
invokeExpr->functionExpr = declRefExpr;
invokeExpr->arguments = argExprs;

auto checkedInvokeExpr = visitor.CheckInvokeExprWithCheckedOperands(invokeExpr);
return as<DeclRefExpr>(as<InvokeExpr>(checkedInvokeExpr)->functionExpr)->declRef;
}


DeclRef<Decl> Linkage::specializeGeneric(
DeclRef<Decl> declRef,
List<Expr*> argExprs,
DiagnosticSink* sink)
{
SLANG_AST_BUILDER_RAII(getASTBuilder());
SLANG_ASSERT(declRef);

SemanticsVisitor visitor(getSemanticsForReflection());
visitor = visitor.withSink(sink);

auto genericDeclRef = getGenericParentDeclRef(getASTBuilder(), &visitor, declRef);

DeclRefExpr* declRefExpr = getASTBuilder()->create<DeclRefExpr>();
declRefExpr->declRef = genericDeclRef;
Expand Down Expand Up @@ -1561,8 +1606,9 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createTypeConformanceComponentTy

try
{
SharedSemanticsContext sharedSemanticsContext(this, nullptr, &sink);
SemanticsVisitor visitor(&sharedSemanticsContext);
SemanticsVisitor visitor(getSemanticsForReflection());
visitor = visitor.withSink(&sink);

auto witness =
visitor.isSubtype((Slang::Type*)type, (Slang::Type*)interfaceType, IsSubTypeOptions::None);
if (auto subtypeWitness = as<SubtypeWitness>(witness))
Expand Down Expand Up @@ -2318,12 +2364,8 @@ DeclRef<Decl> ComponentType::findDeclFromString(

Expr* expr = linkage->parseTermString(name, scope);

SharedSemanticsContext sharedSemanticsContext(
linkage,
nullptr,
sink);
SemanticsContext context(&sharedSemanticsContext);
context = context.allowStaticReferenceToNonStaticMember();
SemanticsContext context(linkage->getSemanticsForReflection());
context = context.allowStaticReferenceToNonStaticMember().withSink(sink);

SemanticsVisitor visitor(context);

Expand Down Expand Up @@ -2377,12 +2419,8 @@ DeclRef<Decl> ComponentType::findDeclFromStringInType(

Expr* expr = linkage->parseTermString(name, scope);

SharedSemanticsContext sharedSemanticsContext(
linkage,
nullptr,
sink);
SemanticsContext context(&sharedSemanticsContext);
context = context.allowStaticReferenceToNonStaticMember();
SemanticsContext context(linkage->getSemanticsForReflection());
context = context.allowStaticReferenceToNonStaticMember().withSink(sink);

SemanticsVisitor visitor(context);

Expand Down Expand Up @@ -2433,11 +2471,7 @@ DeclRef<Decl> ComponentType::findDeclFromStringInType(

bool ComponentType::isSubType(Type* subType, Type* superType)
{
SharedSemanticsContext sharedSemanticsContext(
getLinkage(),
nullptr,
nullptr);
SemanticsContext context(&sharedSemanticsContext);
SemanticsContext context(getLinkage()->getSemanticsForReflection());
SemanticsVisitor visitor(context);

return (visitor.isSubtype(subType, superType, IsSubTypeOptions::None) != nullptr);
Expand Down
Loading
Loading