diff --git a/clang/include/clang/Basic/LangOptions.h b/clang/include/clang/Basic/LangOptions.h index 569584bcc2297..a8943df5b39aa 100644 --- a/clang/include/clang/Basic/LangOptions.h +++ b/clang/include/clang/Basic/LangOptions.h @@ -552,6 +552,10 @@ class LangOptions : public LangOptionsBase { llvm::dxbc::RootSignatureVersion HLSLRootSigVer = llvm::dxbc::RootSignatureVersion::V1_1; + /// The HLSL root signature that will be used to overide the root signature + /// used for the shader entry point. + std::string HLSLRootSigOverride; + // Indicates if the wasm-opt binary must be ignored in the case of a // WebAssembly target. bool NoWasmOpt = false; diff --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td index 6a46fec1701f3..7313b360f521a 100644 --- a/clang/include/clang/Driver/Options.td +++ b/clang/include/clang/Driver/Options.td @@ -9436,6 +9436,18 @@ def dxc_rootsig_ver : Alias, Group, Visibility<[DXCOption]>; +def fdx_rootsignature_define : + Joined<["-"], "fdx-rootsignature-define=">, + Group, + Visibility<[ClangOption, CC1Option]>, + MarshallingInfoString, "\"\"">, + HelpText<"Override entry function root signature with root signature at " + "given macro name.">; +def dxc_rootsig_define : + Separate<["-"], "rootsig-define">, + Alias, + Group, + Visibility<[DXCOption]>; def hlsl_entrypoint : Option<["-"], "hlsl-entry", KIND_SEPARATE>, Group, Visibility<[ClangOption, CC1Option]>, diff --git a/clang/include/clang/Frontend/FrontendActions.h b/clang/include/clang/Frontend/FrontendActions.h index a5dfb770c58a2..73308c004bd23 100644 --- a/clang/include/clang/Frontend/FrontendActions.h +++ b/clang/include/clang/Frontend/FrontendActions.h @@ -329,6 +329,18 @@ class GetDependenciesByModuleNameAction : public PreprocessOnlyAction { : ModuleName(ModuleName) {} }; +//===----------------------------------------------------------------------===// +// HLSL Specific Actions +//===----------------------------------------------------------------------===// + +class HLSLFrontendAction : public WrapperFrontendAction { +protected: + void ExecuteAction() override; + +public: + HLSLFrontendAction(std::unique_ptr WrappedAction); +}; + } // end namespace clang #endif diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h index a49bdfd51fbee..c87e6637c7fce 100644 --- a/clang/include/clang/Parse/ParseHLSLRootSignature.h +++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h @@ -236,6 +236,10 @@ class RootSignatureParser { RootSignatureToken CurToken; }; +IdentifierInfo *ParseHLSLRootSignature(Sema &Actions, + llvm::dxbc::RootSignatureVersion Version, + StringLiteral *Signature); + } // namespace hlsl } // namespace clang diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h index 016456f241eed..5cbe1b658f5cd 100644 --- a/clang/include/clang/Sema/SemaHLSL.h +++ b/clang/include/clang/Sema/SemaHLSL.h @@ -153,6 +153,10 @@ class SemaHLSL : public SemaBase { ActOnFinishRootSignatureDecl(SourceLocation Loc, IdentifierInfo *DeclIdent, ArrayRef Elements); + void SetRootSignatureOverride(IdentifierInfo *DeclIdent) { + RootSigOverrideIdent = DeclIdent; + } + // Returns true if any RootSignatureElement is invalid and a diagnostic was // produced bool @@ -221,6 +225,8 @@ class SemaHLSL : public SemaBase { uint32_t ImplicitBindingNextOrderID = 0; + IdentifierInfo *RootSigOverrideIdent = nullptr; + private: void collectResourceBindingsOnVarDecl(VarDecl *D); void collectResourceBindingsOnUserRecordDecl(const VarDecl *VD, diff --git a/clang/lib/CodeGen/CodeGenModule.cpp b/clang/lib/CodeGen/CodeGenModule.cpp index 677d8bc82cb0a..7140b817d0a32 100644 --- a/clang/lib/CodeGen/CodeGenModule.cpp +++ b/clang/lib/CodeGen/CodeGenModule.cpp @@ -7534,6 +7534,9 @@ void CodeGenModule::EmitTopLevelDecl(Decl *D) { getContext().getCanonicalTagType(cast(D))); break; + case Decl::HLSLRootSignature: + // Will be handled by attached function + break; case Decl::HLSLBuffer: getHLSLRuntime().addBuffer(cast(D)); break; diff --git a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp index e8181dca59c17..1b44090534e82 100644 --- a/clang/lib/Driver/ToolChains/Clang.cpp +++ b/clang/lib/Driver/ToolChains/Clang.cpp @@ -3801,6 +3801,7 @@ static void RenderHLSLOptions(const ArgList &Args, ArgStringList &CmdArgs, options::OPT_disable_llvm_passes, options::OPT_fnative_half_type, options::OPT_hlsl_entrypoint, + options::OPT_fdx_rootsignature_define, options::OPT_fdx_rootsignature_version}; if (!types::isHLSL(InputType)) return; diff --git a/clang/lib/Driver/ToolChains/HLSL.cpp b/clang/lib/Driver/ToolChains/HLSL.cpp index eaa7d736719b5..5a0ed779262e9 100644 --- a/clang/lib/Driver/ToolChains/HLSL.cpp +++ b/clang/lib/Driver/ToolChains/HLSL.cpp @@ -351,6 +351,13 @@ HLSLToolChain::TranslateArgs(const DerivedArgList &Args, StringRef BoundArch, A->claim(); continue; } + if (A->getOption().getID() == options::OPT_dxc_rootsig_define) { + DAL->AddJoinedArg(nullptr, + Opts.getOption(options::OPT_fdx_rootsignature_define), + A->getValue()); + A->claim(); + continue; + } if (A->getOption().getID() == options::OPT__SLASH_O) { StringRef OStr = A->getValue(); if (OStr == "d") { diff --git a/clang/lib/Frontend/CompilerInvocation.cpp b/clang/lib/Frontend/CompilerInvocation.cpp index da96352e1d82c..29f9cf3a7f0e3 100644 --- a/clang/lib/Frontend/CompilerInvocation.cpp +++ b/clang/lib/Frontend/CompilerInvocation.cpp @@ -640,6 +640,10 @@ static bool FixupInvocation(CompilerInvocation &Invocation, Diags.Report(diag::err_drv_argument_not_allowed_with) << "-fdx-rootsignature-version" << GetInputKindName(IK); + if (Args.hasArg(OPT_fdx_rootsignature_define) && !LangOpts.HLSL) + Diags.Report(diag::err_drv_argument_not_allowed_with) + << "-fdx-rootsignature-define" << GetInputKindName(IK); + if (Args.hasArg(OPT_fgpu_allow_device_init) && !LangOpts.HIP) Diags.Report(diag::warn_ignored_hip_only_option) << Args.getLastArg(OPT_fgpu_allow_device_init)->getAsString(Args); diff --git a/clang/lib/Frontend/FrontendActions.cpp b/clang/lib/Frontend/FrontendActions.cpp index 685a9bbf2cde9..ccda2c4ce4b6d 100644 --- a/clang/lib/Frontend/FrontendActions.cpp +++ b/clang/lib/Frontend/FrontendActions.cpp @@ -22,6 +22,7 @@ #include "clang/Lex/HeaderSearch.h" #include "clang/Lex/Preprocessor.h" #include "clang/Lex/PreprocessorOptions.h" +#include "clang/Parse/ParseHLSLRootSignature.h" #include "clang/Sema/TemplateInstCallback.h" #include "clang/Serialization/ASTReader.h" #include "clang/Serialization/ASTWriter.h" @@ -1241,3 +1242,85 @@ void GetDependenciesByModuleNameAction::ExecuteAction() { PPCallbacks *CB = PP.getPPCallbacks(); CB->moduleImport(SourceLocation(), Path, ModResult); } + +//===----------------------------------------------------------------------===// +// HLSL Specific Actions +//===----------------------------------------------------------------------===// + +class InjectRootSignatureCallback : public PPCallbacks { +private: + Sema &Actions; + StringRef RootSigName; + llvm::dxbc::RootSignatureVersion Version; + + std::optional processStringLiteral(ArrayRef Tokens) { + for (Token Tok : Tokens) + if (!tok::isStringLiteral(Tok.getKind())) + return std::nullopt; + + ExprResult StringResult = Actions.ActOnUnevaluatedStringLiteral(Tokens); + if (StringResult.isInvalid()) + return std::nullopt; + + if (auto Signature = dyn_cast(StringResult.get())) + return Signature; + + return std::nullopt; + } + +public: + void MacroDefined(const Token &MacroNameTok, + const MacroDirective *MD) override { + if (RootSigName != MacroNameTok.getIdentifierInfo()->getName()) + return; + + const MacroInfo *MI = MD->getMacroInfo(); + auto Signature = processStringLiteral(MI->tokens()); + if (!Signature.has_value()) { + Actions.getDiagnostics().Report(MI->getDefinitionLoc(), + diag::err_expected_string_literal) + << /*in attributes...*/ 4 << "RootSignature"; + return; + } + + IdentifierInfo *DeclIdent = + hlsl::ParseHLSLRootSignature(Actions, Version, *Signature); + Actions.HLSL().SetRootSignatureOverride(DeclIdent); + } + + InjectRootSignatureCallback(Sema &Actions, StringRef RootSigName, + llvm::dxbc::RootSignatureVersion Version) + : PPCallbacks(), Actions(Actions), RootSigName(RootSigName), + Version(Version) {} +}; + +void HLSLFrontendAction::ExecuteAction() { + // Pre-requisites to invoke + CompilerInstance &CI = getCompilerInstance(); + if (!CI.hasASTContext() || !CI.hasPreprocessor()) + return WrapperFrontendAction::ExecuteAction(); + + // InjectRootSignatureCallback requires access to invoke Sema to lookup/ + // register a root signature declaration. The wrapped action is required to + // account for this by only creating a Sema if one doesn't already exist + // (like we have done, and, ASTFrontendAction::ExecuteAction) + if (!CI.hasSema()) + CI.createSema(getTranslationUnitKind(), + /*CodeCompleteConsumer=*/nullptr); + Sema &S = CI.getSema(); + + // Register HLSL specific callbacks + auto LangOpts = CI.getLangOpts(); + auto MacroCallback = std::make_unique( + S, LangOpts.HLSLRootSigOverride, LangOpts.HLSLRootSigVer); + + Preprocessor &PP = CI.getPreprocessor(); + PP.addPPCallbacks(std::move(MacroCallback)); + + // Invoke as normal + WrapperFrontendAction::ExecuteAction(); +} + +HLSLFrontendAction::HLSLFrontendAction( + std::unique_ptr WrappedAction) + : WrapperFrontendAction(std::move(WrappedAction)) {} diff --git a/clang/lib/FrontendTool/ExecuteCompilerInvocation.cpp b/clang/lib/FrontendTool/ExecuteCompilerInvocation.cpp index 443eb4f1a29bf..9a6844d5f7d40 100644 --- a/clang/lib/FrontendTool/ExecuteCompilerInvocation.cpp +++ b/clang/lib/FrontendTool/ExecuteCompilerInvocation.cpp @@ -181,6 +181,9 @@ CreateFrontendAction(CompilerInstance &CI) { const FrontendOptions &FEOpts = CI.getFrontendOpts(); + if (CI.getLangOpts().HLSL) + Act = std::make_unique(std::move(Act)); + if (FEOpts.FixAndRecompile) { Act = std::make_unique(std::move(Act)); } diff --git a/clang/lib/Parse/ParseDeclCXX.cpp b/clang/lib/Parse/ParseDeclCXX.cpp index 005ad524605ff..8135f4f603907 100644 --- a/clang/lib/Parse/ParseDeclCXX.cpp +++ b/clang/lib/Parse/ParseDeclCXX.cpp @@ -4944,33 +4944,20 @@ void Parser::ParseHLSLRootSignatureAttributeArgs(ParsedAttributes &Attrs) { return std::nullopt; }; - auto StrLiteral = ProcessStringLiteral(); - if (!StrLiteral.has_value()) { + auto Signature = ProcessStringLiteral(); + if (!Signature.has_value()) { Diag(Tok, diag::err_expected_string_literal) - << /*in attributes...*/ 4 << RootSignatureIdent->getName(); - SkipUntil(tok::r_paren, StopAtSemi | StopBeforeMatch); - T.consumeClose(); + << /*in attributes...*/ 4 << "RootSignature"; return; } // Construct our identifier - StringLiteral *Signature = StrLiteral.value(); - auto [DeclIdent, Found] = - Actions.HLSL().ActOnStartRootSignatureDecl(Signature->getString()); - // If we haven't found an already defined DeclIdent then parse the root - // signature string and construct the in-memory elements - if (!Found) { - // Invoke the root signature parser to construct the in-memory constructs - hlsl::RootSignatureParser Parser(getLangOpts().HLSLRootSigVer, Signature, - PP); - if (Parser.parse()) { - T.consumeClose(); - return; - } - - // Construct the declaration. - Actions.HLSL().ActOnFinishRootSignatureDecl(RootSignatureLoc, DeclIdent, - Parser.getElements()); + IdentifierInfo *DeclIdent = hlsl::ParseHLSLRootSignature( + Actions, getLangOpts().HLSLRootSigVer, *Signature); + if (!DeclIdent) { + SkipUntil(tok::r_paren, StopAtSemi | StopBeforeMatch); + T.consumeClose(); + return; } // Create the arg for the ParsedAttr diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp index 5490c61f52356..1af72f8b1c934 100644 --- a/clang/lib/Parse/ParseHLSLRootSignature.cpp +++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp @@ -9,6 +9,7 @@ #include "clang/Parse/ParseHLSLRootSignature.h" #include "clang/Lex/LiteralSupport.h" +#include "clang/Sema/Sema.h" using namespace llvm::hlsl::rootsig; @@ -1448,5 +1449,28 @@ SourceLocation RootSignatureParser::getTokenLocation(RootSignatureToken Tok) { PP.getLangOpts(), PP.getTargetInfo()); } +IdentifierInfo *ParseHLSLRootSignature(Sema &Actions, + llvm::dxbc::RootSignatureVersion Version, + StringLiteral *Signature) { + // Construct our identifier + auto [DeclIdent, Found] = + Actions.HLSL().ActOnStartRootSignatureDecl(Signature->getString()); + // If we haven't found an already defined DeclIdent then parse the root + // signature string and construct the in-memory elements + if (!Found) { + // Invoke the root signature parser to construct the in-memory constructs + hlsl::RootSignatureParser Parser(Version, Signature, + Actions.getPreprocessor()); + if (Parser.parse()) + return nullptr; + + // Construct the declaration. + Actions.HLSL().ActOnFinishRootSignatureDecl( + Signature->getBeginLoc(), DeclIdent, Parser.getElements()); + } + + return DeclIdent; +} + } // namespace hlsl } // namespace clang diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index f87715950c74c..29e092156010d 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -729,6 +729,23 @@ void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) { if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry) return; + // If we have specified a root signature to override the entry function then + // attach it now + if (RootSigOverrideIdent) { + LookupResult R(SemaRef, RootSigOverrideIdent, SourceLocation(), + Sema::LookupOrdinaryName); + if (SemaRef.LookupQualifiedName(R, FD->getDeclContext())) + if (auto *SignatureDecl = + dyn_cast(R.getFoundDecl())) { + FD->dropAttr(); + // We could look up the SourceRange of the macro here as well + AttributeCommonInfo AL(RootSigOverrideIdent, AttributeScopeInfo(), + SourceRange(), ParsedAttr::Form::Microsoft()); + FD->addAttr(::new (getASTContext()) RootSignatureAttr( + getASTContext(), AL, RootSigOverrideIdent, SignatureDecl)); + } + } + llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment(); if (HLSLShaderAttr::isValidShaderType(Env) && Env != llvm::Triple::Library) { if (const auto *Shader = FD->getAttr()) { diff --git a/clang/test/AST/HLSL/rootsignature-define-ast.hlsl b/clang/test/AST/HLSL/rootsignature-define-ast.hlsl new file mode 100644 index 0000000000000..9c17cbc9ad2eb --- /dev/null +++ b/clang/test/AST/HLSL/rootsignature-define-ast.hlsl @@ -0,0 +1,62 @@ +// Establish a baseline without define specified +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -ast-dump \ +// RUN: -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,NO-OVERRIDE + +// Check that we can set the entry function even if it doesn't have an attr +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -ast-dump \ +// RUN: -hlsl-entry none_main -fdx-rootsignature-define=SampleCBV \ +// RUN: -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,SET + +// Check that we can set the entry function overriding an attr +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -ast-dump \ +// RUN: -hlsl-entry uav_main -fdx-rootsignature-define=SampleCBV \ +// RUN: -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,OVERRIDE + +// Check that we can override with a command line root signature +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -ast-dump \ +// RUN: -hlsl-entry cbv_main -fdx-rootsignature-define=CmdRS -DCmdRS='"SRV(t0)"' \ +// RUN: -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CMD + +#define SampleCBV "CBV(b0)" +#define SampleUAV "UAV(u0)" + +// CMD: -HLSLRootSignatureDecl 0x{{.*}} {{.*}} implicit [[CMD_DECL:__hlsl_rootsig_decl_\d*]] +// CMD-SAME: version: 1.1, RootElements{ +// CMD-SAME: RootSRV(t0, +// CMD-SAME: space = 0, visibility = All, flags = DataStaticWhileSetAtExecute +// CMD-SAME: )} + +// CHECK: -HLSLRootSignatureDecl 0x{{.*}} {{.*}} implicit [[CBV_DECL:__hlsl_rootsig_decl_\d*]] +// CHECK-SAME: version: 1.1, RootElements{ +// CHECK-SAME: RootCBV(b0, +// CHECK-SAME: space = 0, visibility = All, flags = DataStaticWhileSetAtExecute +// CHECK-SAME: )} + +// CHECK-LABEL: -FunctionDecl 0x{{.*}} {{.*}} cbv_main +// NO-OVERRIDE: -RootSignatureAttr 0x{{.*}} {{.*}} [[CBV_DECL]] +// SET: -RootSignatureAttr 0x{{.*}} {{.*}} [[CBV_DECL]] +// CMD: -RootSignatureAttr 0x{{.*}} {{.*}} [[CMD_DECL]] + +[RootSignature(SampleCBV)] +void cbv_main() {} + +// CHECK: -HLSLRootSignatureDecl 0x{{.*}} {{.*}} implicit [[UAV_DECL:__hlsl_rootsig_decl_\d*]] +// CHECK-SAME: version: 1.1, RootElements{ +// CHECK-SAME: RootUAV(u0, +// CHECK-SAME: space = 0, visibility = All, flags = DataVolatile +// CHECK-SAME: )} + +// CHECK-LABEL: -FunctionDecl 0x{{.*}} {{.*}} uav_main +// NO-OVERRIDE: -RootSignatureAttr 0x{{.*}} {{.*}} [[UAV_DECL]] +// SET: -RootSignatureAttr 0x{{.*}} {{.*}} [[UAV_DECL]] +// OVERRIDE: -RootSignatureAttr 0x{{.*}} {{.*}} [[CBV_DECL]] + +[RootSignature(SampleUAV)] +void uav_main() {} + +// CHECK-LABEL: -FunctionDecl 0x{{.*}} {{.*}} none_main +// NO-OVERRIDE-NONE: -RootSignatureAttr +// SET: -RootSignatureAttr 0x{{.*}} {{.*}} [[CBV_DECL]] +// OVERRIDE-NONE: -RootSignatureAttr + +void none_main() {} diff --git a/clang/test/Driver/dxc_rootsig-define.hlsl b/clang/test/Driver/dxc_rootsig-define.hlsl new file mode 100644 index 0000000000000..40c3e127f94d5 --- /dev/null +++ b/clang/test/Driver/dxc_rootsig-define.hlsl @@ -0,0 +1,33 @@ +// RUN: %clang_dxc -T cs_6_0 -fcgl %s | FileCheck %s --check-prefixes=CHECK,REG +// RUN: %clang_dxc -T cs_6_0 -fcgl -rootsig-define EmptyRS %s | FileCheck %s --check-prefixes=CHECK,EMPTY +// RUN: %clang_dxc -T cs_6_0 -fcgl -rootsig-define CmdRS -D CmdRS='"SRV(t0)"' %s | FileCheck %s --check-prefixes=CHECK,CMD + +// Equivalent clang checks: +// RUN: %clang -target dxil-unknown-shadermodel6.0-compute -S -emit-llvm -o - %s \ +// RUN: | FileCheck %s --check-prefixes=CHECK,REG + +// RUN: %clang -target dxil-unknown-shadermodel6.0-compute -S -emit-llvm -o - %s \ +// RUN: -fdx-rootsignature-define=EmptyRS \ +// RUN: | FileCheck %s --check-prefixes=CHECK,EMPTY + +// RUN: %clang -target dxil-unknown-shadermodel6.0-compute -S -emit-llvm -o - %s \ +// RUN: -fdx-rootsignature-define=CmdRS -D CmdRS='"SRV(t0)"' \ +// RUN: | FileCheck %s --check-prefixes=CHECK,CMD + +#define EmptyRS "" +#define NotEmptyRS "CBV(b0)" + +// CHECK: !dx.rootsignatures = !{![[#ENTRY:]]} +// CHECK: ![[#ENTRY]] = !{ptr @main, ![[#RS:]], i32 2} + +// REG: ![[#RS]] = !{![[#CBV:]]} +// REG: ![[#CBV]] = !{!"RootCBV" + +// EMPTY: ![[#RS]] = !{} + +// CMD: ![[#RS]] = !{![[#SRV:]]} +// CMD: ![[#SRV]] = !{!"RootSRV" + +[shader("compute"), RootSignature(NotEmptyRS)] +[numthreads(1,1,1)] +void main() {}