From 913ae2ce8606d668b26ca8df279b4c1874f14f2f Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Fri, 16 Aug 2024 14:24:32 -0700 Subject: [PATCH] [Validator] enable validator hash by default. (#6853) The changes affect both the internal validator (used within the DXIL compiler) and external validation tools. Now, by default, validator hash is enabled for all validation processes. #6863 was created for tracking the skip hash discussion. This is second step for #6808. Fixes #6857 --- .../dxc/DxilContainer/DxcContainerBuilder.h | 13 +++- lib/DxilContainer/DxcContainerBuilder.cpp | 63 ++++++++++++++- tools/clang/tools/dxa/dxa.cpp | 24 ++++++ tools/clang/tools/dxclib/dxc.cpp | 2 +- tools/clang/tools/dxcompiler/dxclinker.cpp | 2 +- .../clang/tools/dxcompiler/dxcompilerobj.cpp | 2 +- tools/clang/tools/dxcompiler/dxcutil.cpp | 39 ++++------ tools/clang/tools/dxcompiler/dxcutil.h | 2 + tools/clang/tools/dxcvalidator/CMakeLists.txt | 1 + .../clang/tools/dxcvalidator/dxcvalidator.cpp | 76 +++++++++++++++---- .../tools/dxrfallbackcompiler/dxcutil.cpp | 10 --- tools/clang/unittests/HLSL/CMakeLists.txt | 1 + tools/clang/unittests/HLSL/ValidationTest.cpp | 40 +++++++++- 13 files changed, 219 insertions(+), 56 deletions(-) diff --git a/include/dxc/DxilContainer/DxcContainerBuilder.h b/include/dxc/DxilContainer/DxcContainerBuilder.h index 1947e617b2..9a3241525c 100644 --- a/include/dxc/DxilContainer/DxcContainerBuilder.h +++ b/include/dxc/DxilContainer/DxcContainerBuilder.h @@ -11,9 +11,12 @@ #pragma once -#include "dxc/DxilContainer/DxilContainer.h" +// Include Windows header early for DxilHash.h. #include "dxc/Support/Global.h" #include "dxc/Support/WinIncludes.h" + +#include "dxc/DxilContainer/DxilContainer.h" +#include "dxc/DxilHash/DxilHash.h" #include "dxc/Support/microcom.h" #include "dxc/dxcapi.h" #include "llvm/ADT/SmallVector.h" @@ -46,6 +49,7 @@ class DxcContainerBuilder : public IDxcContainerBuilder { m_warning = warning; m_RequireValidation = false; m_HasPrivateData = false; + m_HashFunction = nullptr; } protected: @@ -66,6 +70,13 @@ class DxcContainerBuilder : public IDxcContainerBuilder { const char *m_warning; bool m_RequireValidation; bool m_HasPrivateData; + // Function to compute hash when valid dxil container is built + // This is nullptr if loaded container has invalid hash + HASH_FUNCTION_PROTO *m_HashFunction; + + void DetermineHashFunctionFromContainerContents( + const DxilContainerHeader *ContainerHeader); + void HashAndUpdate(DxilContainerHeader *ContainerHeader); UINT32 ComputeContainerSize(); HRESULT UpdateContainerHeader(AbstractMemoryStream *pStream, diff --git a/lib/DxilContainer/DxcContainerBuilder.cpp b/lib/DxilContainer/DxcContainerBuilder.cpp index 6a5758dfa5..3c10b0e70a 100644 --- a/lib/DxilContainer/DxcContainerBuilder.cpp +++ b/lib/DxilContainer/DxcContainerBuilder.cpp @@ -13,8 +13,6 @@ #include "dxc/DxilContainer/DxilContainer.h" #include "dxc/Support/ErrorCodes.h" #include "dxc/Support/FileIOHelper.h" -#include "dxc/Support/Global.h" -#include "dxc/Support/WinIncludes.h" #include "dxc/Support/dxcapi.impl.h" #include "dxc/Support/microcom.h" #include "dxc/dxcapi.h" @@ -47,6 +45,10 @@ HRESULT STDMETHODCALLTYPE DxcContainerBuilder::Load(IDxcBlob *pSource) { pPartHeader->PartSize, &pBlob)); AddPart(DxilPart(pPartHeader->PartFourCC, pBlob)); } + // Collect hash function. + const DxilContainerHeader *Header = + (DxilContainerHeader *)pSource->GetBufferPointer(); + DetermineHashFunctionFromContainerContents(Header); return S_OK; } CATCH_CPP_RETURN_HRESULT(); @@ -164,9 +166,64 @@ DxcContainerBuilder::SerializeContainer(IDxcOperationResult **ppResult) { {DxcOutputObject::DataOutput(DXC_OUT_OBJECT, pResult, DxcOutNoName), DxcOutputObject::DataOutput(DXC_OUT_ERRORS, pErrorBlob, DxcOutNoName)}, ppResult)); - return S_OK; } CATCH_CPP_RETURN_HRESULT(); + + if (ppResult == nullptr || *ppResult == nullptr) + return S_OK; + + HRESULT HR; + (*ppResult)->GetStatus(&HR); + if (FAILED(HR)) + return HR; + + CComPtr pObject; + IFR((*ppResult)->GetResult(&pObject)); + + // Add Hash. + LPVOID PTR = pObject->GetBufferPointer(); + if (IsDxilContainerLike(PTR, pObject->GetBufferSize())) + HashAndUpdate((DxilContainerHeader *)PTR); + return S_OK; +} + +// Try hashing the source contained in ContainerHeader using retail and debug +// hashing functions. If either of them match the stored result, set the +// HashFunction to the matching variant. If neither match, set it to null. +void DxcContainerBuilder::DetermineHashFunctionFromContainerContents( + const DxilContainerHeader *ContainerHeader) { + DXASSERT(ContainerHeader != nullptr && + IsDxilContainerLike(ContainerHeader, + ContainerHeader->ContainerSizeInBytes), + "otherwise load function should have returned an error."); + constexpr uint32_t HashStartOffset = + offsetof(struct DxilContainerHeader, Version); + auto *DataToHash = (const BYTE *)ContainerHeader + HashStartOffset; + UINT AmountToHash = ContainerHeader->ContainerSizeInBytes - HashStartOffset; + BYTE Result[DxilContainerHashSize]; + ComputeHashRetail(DataToHash, AmountToHash, Result); + if (0 == memcmp(Result, ContainerHeader->Hash.Digest, sizeof(Result))) { + m_HashFunction = ComputeHashRetail; + } else { + ComputeHashDebug(DataToHash, AmountToHash, Result); + if (0 == memcmp(Result, ContainerHeader->Hash.Digest, sizeof(Result))) + m_HashFunction = ComputeHashDebug; + else + m_HashFunction = nullptr; + } +} + +// For Internal hash function. +void DxcContainerBuilder::HashAndUpdate(DxilContainerHeader *ContainerHeader) { + if (m_HashFunction != nullptr) { + DXASSERT(ContainerHeader != nullptr, + "Otherwise serialization should have failed."); + static const UINT32 HashStartOffset = + offsetof(struct DxilContainerHeader, Version); + const BYTE *DataToHash = (const BYTE *)ContainerHeader + HashStartOffset; + UINT AmountToHash = ContainerHeader->ContainerSizeInBytes - HashStartOffset; + m_HashFunction(DataToHash, AmountToHash, ContainerHeader->Hash.Digest); + } } UINT32 DxcContainerBuilder::ComputeContainerSize() { diff --git a/tools/clang/tools/dxa/dxa.cpp b/tools/clang/tools/dxa/dxa.cpp index 2bffe1da3a..db6d1e9b88 100644 --- a/tools/clang/tools/dxa/dxa.cpp +++ b/tools/clang/tools/dxa/dxa.cpp @@ -68,6 +68,9 @@ static cl::opt DumpReflection("dumpreflection", cl::desc("Dump reflection"), cl::init(false)); +static cl::opt DumpHash("dumphash", cl::desc("Dump validation hash"), + cl::init(false)); + class DxaContext { private: @@ -88,6 +91,7 @@ class DxaContext { void DumpRS(); void DumpRDAT(); void DumpReflection(); + void DumpValidationHash(); }; void DxaContext::Assemble() { @@ -466,6 +470,23 @@ void DxaContext::DumpReflection() { printf("%s", ss.str().c_str()); } +void DxaContext::DumpValidationHash() { + CComPtr pSource; + ReadFileIntoBlob(m_dxcSupport, StringRefWide(InputFilename), &pSource); + if (!hlsl::IsValidDxilContainer( + (hlsl::DxilContainerHeader *)pSource->GetBufferPointer(), + pSource->GetBufferSize())) { + printf("Invalid input file, use binary DxilContainer."); + return; + } + hlsl::DxilContainerHeader *pDxilContainerHeader = + (hlsl::DxilContainerHeader *)pSource->GetBufferPointer(); + printf("Validation hash: 0x"); + for (size_t i = 0; i < hlsl::DxilContainerHashSize; i++) { + printf("%02x", pDxilContainerHeader->Hash.Digest[i]); + } +} + using namespace hlsl::options; #ifdef _WIN32 @@ -527,6 +548,9 @@ int main(int argc, const char **argv) { } else if (DumpReflection) { pStage = "Dump Reflection"; context.DumpReflection(); + } else if (DumpHash) { + pStage = "Dump Validation Hash"; + context.DumpValidationHash(); } else { pStage = "Assembling"; context.Assemble(); diff --git a/tools/clang/tools/dxclib/dxc.cpp b/tools/clang/tools/dxclib/dxc.cpp index cdcfe2b3f6..1bcf5d8e3f 100644 --- a/tools/clang/tools/dxclib/dxc.cpp +++ b/tools/clang/tools/dxclib/dxc.cpp @@ -644,7 +644,7 @@ int DxcContext::VerifyRootSignature() { IFT(pContainerBuilder->AddPart(hlsl::DxilFourCC::DFCC_RootSignature, pRootSignature)); CComPtr pOperationResult; - IFT(pContainerBuilder->SerializeContainer(&pOperationResult)); + pContainerBuilder->SerializeContainer(&pOperationResult); HRESULT status = E_FAIL; CComPtr pResult; IFT(pOperationResult->GetStatus(&status)); diff --git a/tools/clang/tools/dxcompiler/dxclinker.cpp b/tools/clang/tools/dxcompiler/dxclinker.cpp index 2446593238..82c9b8e96b 100644 --- a/tools/clang/tools/dxcompiler/dxclinker.cpp +++ b/tools/clang/tools/dxcompiler/dxclinker.cpp @@ -413,7 +413,7 @@ HRESULT STDMETHODCALLTYPE DxcLinker::Link( HRESULT valHR = S_OK; dxcutil::AssembleInputs inputs( std::move(pM), pOutputBlob, DxcGetThreadMallocNoRef(), - SerializeFlags, pOutputStream, opts.DebugFile, &Diag, + SerializeFlags, pOutputStream, 0, opts.DebugFile, &Diag, &ShaderHashContent, pReflectionStream, pRootSigStream, nullptr, nullptr); if (needsValidation) { diff --git a/tools/clang/tools/dxcompiler/dxcompilerobj.cpp b/tools/clang/tools/dxcompiler/dxcompilerobj.cpp index 2af9a3d8fc..db31f59634 100644 --- a/tools/clang/tools/dxcompiler/dxcompilerobj.cpp +++ b/tools/clang/tools/dxcompiler/dxcompilerobj.cpp @@ -1039,7 +1039,7 @@ class DxcCompiler : public IDxcCompiler3, dxcutil::AssembleInputs inputs( std::move(serializeModule), pOutputBlob, m_pMalloc, - SerializeFlags, pOutputStream, opts.GetPDBName(), + SerializeFlags, pOutputStream, 0, opts.GetPDBName(), &compiler.getDiagnostics(), &ShaderHashContent, pReflectionStream, pRootSigStream, pRootSignatureBlob, pPrivateBlob, opts.SelectValidator); diff --git a/tools/clang/tools/dxcompiler/dxcutil.cpp b/tools/clang/tools/dxcompiler/dxcutil.cpp index 193b060937..d3a531d4c6 100644 --- a/tools/clang/tools/dxcompiler/dxcutil.cpp +++ b/tools/clang/tools/dxcompiler/dxcutil.cpp @@ -76,17 +76,18 @@ AssembleInputs::AssembleInputs( std::unique_ptr &&pM, CComPtr &pOutputContainerBlob, IMalloc *pMalloc, hlsl::SerializeDxilFlags SerializeFlags, CComPtr &pModuleBitcode, - llvm::StringRef DebugName, clang::DiagnosticsEngine *pDiag, - hlsl::DxilShaderHash *pShaderHashOut, AbstractMemoryStream *pReflectionOut, - AbstractMemoryStream *pRootSigOut, CComPtr pRootSigBlob, - CComPtr pPrivateBlob, + uint32_t ValidationFlags, llvm::StringRef DebugName, + clang::DiagnosticsEngine *pDiag, hlsl::DxilShaderHash *pShaderHashOut, + AbstractMemoryStream *pReflectionOut, AbstractMemoryStream *pRootSigOut, + CComPtr pRootSigBlob, CComPtr pPrivateBlob, hlsl::options::ValidatorSelection SelectValidator) : pM(std::move(pM)), pOutputContainerBlob(pOutputContainerBlob), pMalloc(pMalloc), SerializeFlags(SerializeFlags), - pModuleBitcode(pModuleBitcode), DebugName(DebugName), pDiag(pDiag), - pShaderHashOut(pShaderHashOut), pReflectionOut(pReflectionOut), - pRootSigOut(pRootSigOut), pRootSigBlob(pRootSigBlob), - pPrivateBlob(pPrivateBlob), SelectValidator(SelectValidator) {} + ValidationFlags(ValidationFlags), pModuleBitcode(pModuleBitcode), + DebugName(DebugName), pDiag(pDiag), pShaderHashOut(pShaderHashOut), + pReflectionOut(pReflectionOut), pRootSigOut(pRootSigOut), + pRootSigBlob(pRootSigBlob), pPrivateBlob(pPrivateBlob), + SelectValidator(SelectValidator) {} void GetValidatorVersion(unsigned *pMajor, unsigned *pMinor, hlsl::options::ValidatorSelection SelectValidator) { @@ -174,18 +175,6 @@ HRESULT ValidateAndAssembleToContainer(AssembleInputs &inputs) { pValidator.QueryInterface(&pValidator2); } - if (bInternalValidator && - inputs.SelectValidator != hlsl::options::ValidatorSelection::Internal) { - if (inputs.pDiag) { - unsigned diagID = inputs.pDiag->getCustomDiagID( - clang::DiagnosticsEngine::Level::Warning, - "DXIL signing library (dxil.dll,libdxil.so) not found. Resulting " - "DXIL will not be " - "signed for use in release environments.\r\n"); - inputs.pDiag->Report(diagID); - } - } - if (bInternalValidator || pValidator2) { // If using the internal validator or external validator supports // IDxcValidator2, we'll use the modules directly. In this case, we'll want @@ -222,13 +211,13 @@ HRESULT ValidateAndAssembleToContainer(AssembleInputs &inputs) { CComPtr pValResult; // Important: in-place edit is required so the blob is reused and thus // dxil.dll can be released. + inputs.ValidationFlags |= DxcValidatorFlags_InPlaceEdit; if (bInternalValidator) { IFT(RunInternalValidator(pValidator, llvmModuleWithDebugInfo.get(), inputs.pOutputContainerBlob, - DxcValidatorFlags_InPlaceEdit, &pValResult)); + inputs.ValidationFlags, &pValResult)); } else { if (pValidator2 && llvmModuleWithDebugInfo) { - // If metadata was stripped, re-serialize the input module. CComPtr pDebugModuleStream; IFT(CreateMemoryStream(DxcGetThreadMallocNoRef(), &pDebugModuleStream)); @@ -241,11 +230,11 @@ HRESULT ValidateAndAssembleToContainer(AssembleInputs &inputs) { debugModule.Size = pDebugModuleStream->GetPtrSize(); IFT(pValidator2->ValidateWithDebug(inputs.pOutputContainerBlob, - DxcValidatorFlags_InPlaceEdit, - &debugModule, &pValResult)); + inputs.ValidationFlags, &debugModule, + &pValResult)); } else { IFT(pValidator->Validate(inputs.pOutputContainerBlob, - DxcValidatorFlags_InPlaceEdit, &pValResult)); + inputs.ValidationFlags, &pValResult)); } } IFT(pValResult->GetStatus(&valHR)); diff --git a/tools/clang/tools/dxcompiler/dxcutil.h b/tools/clang/tools/dxcompiler/dxcutil.h index 580c5942df..45b3d4dc1a 100644 --- a/tools/clang/tools/dxcompiler/dxcutil.h +++ b/tools/clang/tools/dxcompiler/dxcutil.h @@ -47,6 +47,7 @@ struct AssembleInputs { CComPtr &pOutputContainerBlob, IMalloc *pMalloc, hlsl::SerializeDxilFlags SerializeFlags, CComPtr &pModuleBitcode, + uint32_t ValidationFlags = 0, llvm::StringRef DebugName = llvm::StringRef(), clang::DiagnosticsEngine *pDiag = nullptr, hlsl::DxilShaderHash *pShaderHashOut = nullptr, @@ -61,6 +62,7 @@ struct AssembleInputs { IDxcVersionInfo *pVersionInfo = nullptr; IMalloc *pMalloc; hlsl::SerializeDxilFlags SerializeFlags; + uint32_t ValidationFlags = 0; CComPtr &pModuleBitcode; llvm::StringRef DebugName = llvm::StringRef(); clang::DiagnosticsEngine *pDiag; diff --git a/tools/clang/tools/dxcvalidator/CMakeLists.txt b/tools/clang/tools/dxcvalidator/CMakeLists.txt index 3ad0a6bf75..4991c9f97b 100644 --- a/tools/clang/tools/dxcvalidator/CMakeLists.txt +++ b/tools/clang/tools/dxcvalidator/CMakeLists.txt @@ -7,6 +7,7 @@ set( LLVM_LINK_COMPONENTS dxcsupport DXIL DxilContainer + DxilHash DxilValidation Option # option library Support # just for assert and raw streams diff --git a/tools/clang/tools/dxcvalidator/dxcvalidator.cpp b/tools/clang/tools/dxcvalidator/dxcvalidator.cpp index a9d721bbec..b8b71ece62 100644 --- a/tools/clang/tools/dxcvalidator/dxcvalidator.cpp +++ b/tools/clang/tools/dxcvalidator/dxcvalidator.cpp @@ -9,13 +9,14 @@ // // /////////////////////////////////////////////////////////////////////////////// +#include "dxc/Support/WinIncludes.h" #include "llvm/Bitcode/ReaderWriter.h" #include "llvm/IR/DiagnosticPrinter.h" #include "llvm/IR/LLVMContext.h" #include "dxc/DxilContainer/DxilContainer.h" +#include "dxc/DxilHash/DxilHash.h" #include "dxc/DxilValidation/DxilValidation.h" -#include "dxc/Support/WinIncludes.h" #include "dxc/dxcapi.h" #include "dxcvalidator.h" @@ -31,6 +32,34 @@ using namespace llvm; using namespace hlsl; +static void HashAndUpdate(DxilContainerHeader *Container) { + // Compute hash and update stored hash. + // Hash the container from this offset to the end. + static const uint32_t DXBCHashStartOffset = + offsetof(struct DxilContainerHeader, Version); + const unsigned char *DataToHash = + (const unsigned char *)Container + DXBCHashStartOffset; + unsigned AmountToHash = Container->ContainerSizeInBytes - DXBCHashStartOffset; + ComputeHashRetail(DataToHash, AmountToHash, Container->Hash.Digest); +} + +static void HashAndUpdateOrCopy(uint32_t Flags, IDxcBlob *Shader, + IDxcBlob **Hashed) { + if (Flags & DxcValidatorFlags_InPlaceEdit) { + HashAndUpdate((DxilContainerHeader *)Shader->GetBufferPointer()); + *Hashed = Shader; + Shader->AddRef(); + } else { + CComPtr HashedBlobStream; + IFT(CreateMemoryStream(DxcGetThreadMallocNoRef(), &HashedBlobStream)); + unsigned long CB; + IFT(HashedBlobStream->Write(Shader->GetBufferPointer(), + Shader->GetBufferSize(), &CB)); + HashAndUpdate((DxilContainerHeader *)HashedBlobStream->GetPtr()); + IFT(HashedBlobStream.QueryInterface(Hashed)); + } +} + static uint32_t runValidation( IDxcBlob *Shader, uint32_t Flags, // Validation flags. @@ -180,18 +209,39 @@ uint32_t hlsl::validateWithOptDebugModule( ULONG cbWritten; DiagStream->Write(msg.c_str(), msg.size(), &cbWritten); } - // Assemble the result object. - CComPtr pDiagBlob; - hr = DiagStream.QueryInterface(&pDiagBlob); - DXASSERT_NOMSG(SUCCEEDED(hr)); - hr = DxcResult::Create( - validationStatus, DXC_OUT_NONE, - {DxcOutputObject::ErrorOutput( - CP_UTF8, // TODO Support DefaultTextCodePage - (LPCSTR)pDiagBlob->GetBufferPointer(), pDiagBlob->GetBufferSize())}, - Result); - if (FAILED(hr)) - throw hlsl::Exception(hr); + if (Flags & (DxcValidatorFlags_ModuleOnly)) { + // Validating a module only, return DXC_OUT_NONE instead of + // DXC_OUT_OBJECT. + CComPtr pDiagBlob; + hr = DiagStream.QueryInterface(&pDiagBlob); + DXASSERT_NOMSG(SUCCEEDED(hr)); + hr = DxcResult::Create(validationStatus, DXC_OUT_NONE, + {DxcOutputObject::ErrorOutput( + CP_UTF8, // TODO Support DefaultTextCodePage + (LPCSTR)pDiagBlob->GetBufferPointer(), + pDiagBlob->GetBufferSize())}, + Result); + if (FAILED(hr)) + throw hlsl::Exception(hr); + } else { + CComPtr HashedBlob; + // Assemble the result object. + CComPtr DiagBlob; + CComPtr DiagBlobEnconding; + hr = DiagStream.QueryInterface(&DiagBlob); + DXASSERT_NOMSG(SUCCEEDED(hr)); + hr = DxcCreateBlobWithEncodingSet(DiagBlob, CP_UTF8, &DiagBlobEnconding); + if (FAILED(hr)) + throw hlsl::Exception(hr); + HashAndUpdateOrCopy(Flags, Shader, &HashedBlob); + hr = DxcResult::Create( + validationStatus, DXC_OUT_OBJECT, + {DxcOutputObject::DataOutput(DXC_OUT_OBJECT, HashedBlob), + DxcOutputObject::DataOutput(DXC_OUT_ERRORS, DiagBlobEnconding)}, + Result); + if (FAILED(hr)) + throw hlsl::Exception(hr); + } } CATCH_CPP_ASSIGN_HRESULT(); diff --git a/tools/clang/tools/dxrfallbackcompiler/dxcutil.cpp b/tools/clang/tools/dxrfallbackcompiler/dxcutil.cpp index 6ef2f5a373..cb25b50136 100644 --- a/tools/clang/tools/dxrfallbackcompiler/dxcutil.cpp +++ b/tools/clang/tools/dxrfallbackcompiler/dxcutil.cpp @@ -142,16 +142,6 @@ HRESULT ValidateAndAssembleToContainer(AssembleInputs &inputs) { // Warning on internal Validator if (bInternalValidator) { -#if !DISABLE_GET_CUSTOM_DIAG_ID - if (inputs.pDiag) { - unsigned diagID = inputs.pDiag->getCustomDiagID( - clang::DiagnosticsEngine::Level::Warning, - "DXIL signing library (dxil.dll,libdxil.so) not found. Resulting " - "DXIL will not be " - "signed for use in release environments.\r\n"); - inputs.pDiag->Report(diagID); - } -#endif // If using the internal validator, we'll use the modules directly. // In this case, we'll want to make a clone to avoid // SerializeDxilContainerForModule stripping all the debug info. The debug diff --git a/tools/clang/unittests/HLSL/CMakeLists.txt b/tools/clang/unittests/HLSL/CMakeLists.txt index 0fd570a28b..eaba5049b2 100644 --- a/tools/clang/unittests/HLSL/CMakeLists.txt +++ b/tools/clang/unittests/HLSL/CMakeLists.txt @@ -14,6 +14,7 @@ set( LLVM_LINK_COMPONENTS dxilcontainer dxilrootsignature hlsl + dxilhash option bitreader bitwriter diff --git a/tools/clang/unittests/HLSL/ValidationTest.cpp b/tools/clang/unittests/HLSL/ValidationTest.cpp index 2b3bf38e06..8f6236b8a3 100644 --- a/tools/clang/unittests/HLSL/ValidationTest.cpp +++ b/tools/clang/unittests/HLSL/ValidationTest.cpp @@ -10,6 +10,7 @@ #define NOMINMAX +#include "dxc/Support/WinIncludes.h" #include #include #include @@ -17,7 +18,7 @@ #include "dxc/DxilContainer/DxilContainer.h" #include "dxc/DxilContainer/DxilContainerAssembler.h" -#include "dxc/Support/WinIncludes.h" +#include "dxc/DxilHash/DxilHash.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Regex.h" @@ -295,6 +296,7 @@ class ValidationTest : public ::testing::Test { TEST_METHOD(ValidateRootSigContainer) TEST_METHOD(ValidatePrintfNotAllowed) + TEST_METHOD(ValidateWithHash) TEST_METHOD(ValidateVersionNotAllowed) TEST_METHOD(CreateHandleNotAllowedSM66) @@ -4071,6 +4073,42 @@ TEST_F(ValidationTest, ValidatePrintfNotAllowed) { TestCheck(L"..\\CodeGenHLSL\\printf.hlsl"); } +TEST_F(ValidationTest, ValidateWithHash) { + if (m_ver.SkipDxilVersion(1, 8)) + return; + CComPtr pProgram; + CompileSource("float4 main(float a:A, float b:B) : SV_Target { return 1; }", + "ps_6_0", &pProgram); + + CComPtr pValidator; + CComPtr pResult; + unsigned Flags = 0; + VERIFY_SUCCEEDED( + m_dllSupport.CreateInstance(CLSID_DxcValidator, &pValidator)); + // With hash. + VERIFY_SUCCEEDED(pValidator->Validate(pProgram, Flags, &pResult)); + // Make sure the validation was successful. + HRESULT status; + VERIFY_IS_NOT_NULL(pResult); + CComPtr pValidationOutput; + pResult->GetStatus(&status); + VERIFY_SUCCEEDED(status); + pResult->GetResult(&pValidationOutput); + // Make sure the validation output is not null when hashing. + VERIFY_SUCCEEDED(pValidationOutput != nullptr); + + hlsl::DxilContainerHeader *pHeader = + (hlsl::DxilContainerHeader *)pProgram->GetBufferPointer(); + // Validate the hash. + constexpr uint32_t HashStartOffset = + offsetof(struct DxilContainerHeader, Version); + auto *DataToHash = (const BYTE *)pHeader + HashStartOffset; + UINT AmountToHash = pHeader->ContainerSizeInBytes - HashStartOffset; + BYTE Result[DxilContainerHashSize]; + ComputeHashRetail(DataToHash, AmountToHash, Result); + VERIFY_ARE_EQUAL(memcmp(Result, pHeader->Hash.Digest, sizeof(Result)), 0); +} + TEST_F(ValidationTest, ValidateVersionNotAllowed) { if (m_ver.SkipDxilVersion(1, 6)) return;