Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WGSL support for slang-test #5174

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions external/slang-tint-headers/slang-tint.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#pragma once

#include <stddef.h>
#include <stdint.h>

struct tint_CompileRequest
{
const char* wgslCode;
size_t wgslCodeLength;
};

struct tint_CompileResult
{
const uint8_t* buffer;
size_t bufferSize;
const char* error;
};


typedef int (*tint_CompileFunc)(tint_CompileRequest* request, tint_CompileResult* result);

typedef void (*tint_FreeResultFunc)(tint_CompileResult* result);
1 change: 1 addition & 0 deletions include/slang-gfx.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ enum class DeviceType
Metal,
CPU,
CUDA,
WebGPU,
CountOf,
};

Expand Down
4 changes: 3 additions & 1 deletion include/slang.h
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,8 @@ extern "C"
SLANG_METAL_LIB_ASM, ///< Metal library assembly
SLANG_HOST_SHARED_LIBRARY, ///< A shared library/Dll for host code (for hosting CPU/OS)
SLANG_WGSL, ///< WebGPU shading language
SLANG_WGSL_SPIRV_ASM, ///< SPIR-V assembly via WebGPU shading language
SLANG_WGSL_SPIRV, ///< SPIR-V via WebGPU shading language
SLANG_TARGET_COUNT_OF,
};

Expand Down Expand Up @@ -643,7 +645,7 @@ extern "C"
SLANG_PASS_THROUGH_LLVM, ///< LLVM 'compiler' - includes LLVM and Clang
SLANG_PASS_THROUGH_SPIRV_OPT, ///< SPIRV-opt
SLANG_PASS_THROUGH_METAL, ///< Metal compiler
SLANG_PASS_THROUGH_WGSL, ///< WGSL compiler
SLANG_PASS_THROUGH_TINT, ///< Tint WGSL compiler
SLANG_PASS_THROUGH_COUNT_OF,
};

Expand Down
5 changes: 5 additions & 0 deletions source/compiler-core/slang-artifact-desc-util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ SLANG_HIERARCHICAL_ENUM(ArtifactKind, SLANG_ARTIFACT_KIND, SLANG_ARTIFACT_KIND_E
x(PTX, KernelLike) \
x(CuBin, KernelLike) \
x(MetalAIR, KernelLike) \
x(WGSL_SPIRV, KernelLike) \
x(CPULike, Base) \
x(UnknownCPU, CPULike) \
x(X86, CPULike) \
Expand Down Expand Up @@ -290,6 +291,8 @@ SLANG_HIERARCHICAL_ENUM(ArtifactStyle, SLANG_ARTIFACT_STYLE, SLANG_ARTIFACT_STYL
case SLANG_METAL_LIB: return Desc::make(Kind::Executable, Payload::MetalAIR, Style::Kernel, 0);
case SLANG_METAL_LIB_ASM: return Desc::make(Kind::Assembly, Payload::MetalAIR, Style::Kernel, 0);
case SLANG_WGSL: return Desc::make(Kind::Source, Payload::WGSL, Style::Kernel, 0);
case SLANG_WGSL_SPIRV_ASM: return Desc::make(Kind::Assembly, Payload::WGSL_SPIRV, Style::Kernel, 0);
case SLANG_WGSL_SPIRV: return Desc::make(Kind::Executable, Payload::WGSL_SPIRV, Style::Kernel, 0);
default: break;
}

Expand Down Expand Up @@ -346,6 +349,7 @@ SLANG_HIERARCHICAL_ENUM(ArtifactStyle, SLANG_ARTIFACT_STYLE, SLANG_ARTIFACT_STYL
case Payload::DXBC: return SLANG_DXBC_ASM;
case Payload::PTX: return SLANG_PTX;
case Payload::MetalAIR: return SLANG_METAL_LIB_ASM;
case Payload::WGSL_SPIRV: return SLANG_WGSL_SPIRV_ASM;
default: break;
}
}
Expand Down Expand Up @@ -374,6 +378,7 @@ SLANG_HIERARCHICAL_ENUM(ArtifactStyle, SLANG_ARTIFACT_STYLE, SLANG_ARTIFACT_STYL
case Payload::DXBC: return SLANG_DXBC;
case Payload::PTX: return SLANG_PTX;
case Payload::MetalAIR: return SLANG_METAL_LIB_ASM;
case Payload::WGSL_SPIRV: return SLANG_WGSL_SPIRV;
default: break;
}
}
Expand Down
1 change: 1 addition & 0 deletions source/compiler-core/slang-artifact.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ enum class ArtifactPayload : uint8_t
PTX, ///< PTX. NOTE! PTX is a text format, but is handable to CUDA API.
MetalAIR, ///< Metal AIR
CuBin, ///< CUDA binary
WGSL_SPIRV, ///< SPIR-V derived via WebGPU shading language

CPULike, ///< CPU code

Expand Down
2 changes: 2 additions & 0 deletions source/compiler-core/slang-downstream-compiler-util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "slang-glslang-compiler.h"
#include "slang-llvm-compiler.h"
#include "slang-metal-compiler.h"
#include "slang-tint-compiler.h"

namespace Slang
{
Expand Down Expand Up @@ -332,6 +333,7 @@ DownstreamCompilerMatchVersion DownstreamCompilerUtil::getCompiledVersion()
outFuncs[int(SLANG_PASS_THROUGH_LLVM)] = &LLVMDownstreamCompilerUtil::locateCompilers;
outFuncs[int(SLANG_PASS_THROUGH_SPIRV_DIS)] = &SpirvDisDownstreamCompilerUtil::locateCompilers;
outFuncs[int(SLANG_PASS_THROUGH_METAL)] = &MetalDownstreamCompilerUtil::locateCompilers;
outFuncs[int(SLANG_PASS_THROUGH_TINT)] = &TintDownstreamCompilerUtil::locateCompilers;
}

static String _getParentPath(const String& path)
Expand Down
4 changes: 3 additions & 1 deletion source/compiler-core/slang-glslang-compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,9 @@ SlangResult GlslangDownstreamCompiler::validate(const uint32_t* contents, int co
bool GlslangDownstreamCompiler::canConvert(const ArtifactDesc& from, const ArtifactDesc& to)
{
// Can only disassemble blobs that are SPIR-V
return ArtifactDescUtil::isDisassembly(from, to) && from.payload == ArtifactPayload::SPIRV;
return ArtifactDescUtil::isDisassembly(from, to) && (
(from.payload == ArtifactPayload::SPIRV) ||
(from.payload == ArtifactPayload::WGSL_SPIRV));
}

SlangResult GlslangDownstreamCompiler::convert(IArtifact* from, const ArtifactDesc& to, IArtifact** outArtifact)
Expand Down
164 changes: 164 additions & 0 deletions source/compiler-core/slang-tint-compiler.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
#include "slang-tint-compiler.h"

#include "slang-artifact-associated-impl.h"

#include "../../external/slang-tint-headers/slang-tint.h"

namespace Slang
{

class TintDownstreamCompiler : public DownstreamCompilerBase
{

public:

// IDownstreamCompiler
virtual SLANG_NO_THROW SlangResult SLANG_MCALL compile(
const CompileOptions& options, IArtifact** outResult) SLANG_OVERRIDE;

virtual SLANG_NO_THROW bool SLANG_MCALL canConvert(
const ArtifactDesc& from, const ArtifactDesc& to) SLANG_OVERRIDE;

virtual SLANG_NO_THROW SlangResult SLANG_MCALL convert(
IArtifact* from, const ArtifactDesc& to, IArtifact** outArtifact)
SLANG_OVERRIDE;

virtual SLANG_NO_THROW bool SLANG_MCALL isFileBased() SLANG_OVERRIDE
{
return false;
}

virtual SLANG_NO_THROW SlangResult SLANG_MCALL getVersionString(
slang::IBlob** outVersionString) SLANG_OVERRIDE;

SlangResult compile(IArtifact *const sourceArtifact, IArtifact** outArtifact);

SlangResult init(ISlangSharedLibrary* library);

protected:

ComPtr<ISlangSharedLibrary> m_sharedLibrary;

private:

tint_CompileFunc m_compile;
tint_FreeResultFunc m_freeResult;
};

SlangResult TintDownstreamCompiler::init(ISlangSharedLibrary* library)
{
tint_CompileFunc compile =
(tint_CompileFunc)library->findFuncByName("tint_compile");
if (compile == nullptr)
{
return SLANG_FAIL;
}

tint_FreeResultFunc freeResult =
(tint_FreeResultFunc)library->findFuncByName("tint_free_result");
if (freeResult == nullptr)
{
return SLANG_FAIL;
}

m_sharedLibrary = library;
m_desc = Desc(SLANG_PASS_THROUGH_TINT);
m_compile = compile;
m_freeResult = freeResult;
return SLANG_OK;
}

SlangResult TintDownstreamCompilerUtil::locateCompilers(
const String& path,
ISlangSharedLibraryLoader* loader,
DownstreamCompilerSet* set)
{
ComPtr<ISlangSharedLibrary> library;
SLANG_RETURN_ON_FAIL(DownstreamCompilerUtil::loadSharedLibrary(
path, loader, nullptr, "slang-tint", library));
SLANG_ASSERT(library);

ComPtr<IDownstreamCompiler> compiler = ComPtr<IDownstreamCompiler>(
new TintDownstreamCompiler());
SLANG_RETURN_ON_FAIL(static_cast<TintDownstreamCompiler*>(
compiler.get())->init(library));

set->addCompiler(compiler);
return SLANG_OK;
}

SlangResult TintDownstreamCompiler::compile(
const CompileOptions& options, IArtifact** outArtifact)
{
IArtifact * sourceArtifact = options.sourceArtifacts[0];
return compile(sourceArtifact, outArtifact);
}

SlangResult TintDownstreamCompiler::compile(
IArtifact *const sourceArtifact, IArtifact** outArtifact)
{
tint_CompileRequest req = {};

if (sourceArtifact == nullptr)
return SLANG_FAIL;

ComPtr<ISlangBlob> sourceBlob;
SLANG_RETURN_FALSE_ON_FAIL(sourceArtifact->loadBlob(
ArtifactKeep::Yes, sourceBlob.writeRef()));

String wgslCode(
(char*)sourceBlob->getBufferPointer(),
(char*)sourceBlob->getBufferPointer() + sourceBlob->getBufferSize());
req.wgslCode = wgslCode.begin();
req.wgslCodeLength = wgslCode.getLength();

tint_CompileResult result = {};
SLANG_DEFER(m_freeResult(&result));
bool compileSucceeded = m_compile(&req, &result) == 0;

ComPtr<ISlangBlob> spirvBlob = RawBlob::create(result.buffer, result.bufferSize);
result.buffer = nullptr;

ComPtr<IArtifact> resultArtifact = ArtifactUtil::createArtifactForCompileTarget(
SlangCompileTarget::SLANG_WGSL_SPIRV);
auto diagnostics = ArtifactDiagnostics::create();
diagnostics->setResult(compileSucceeded ? SLANG_OK : SLANG_FAIL);
ArtifactUtil::addAssociated(resultArtifact, diagnostics);
if (compileSucceeded)
{
resultArtifact->addRepresentationUnknown(spirvBlob);
}
else
{
diagnostics->setRaw(CharSlice(result.error));
diagnostics->requireErrorDiagnostic();
}

*outArtifact = resultArtifact.detach();
return SLANG_OK;
}

bool TintDownstreamCompiler::canConvert(
const ArtifactDesc& from, const ArtifactDesc& to)
{
return (from.payload == ArtifactPayload::WGSL) &&
(to.payload == ArtifactPayload::SPIRV);
}

SlangResult TintDownstreamCompiler::convert(
IArtifact* from, const ArtifactDesc& to, IArtifact** outArtifact)
{
if (!canConvert(from->getDesc(), to))
return SLANG_FAIL;
return compile(from, outArtifact);
}

SlangResult TintDownstreamCompiler::getVersionString(
slang::IBlob** /* outVersionString */)
{
// We just use Tint at whatever version is in our Dawn fork, so nobody should
// depend on the particular version at the moment.
return SLANG_FAIL;
}

}
17 changes: 17 additions & 0 deletions source/compiler-core/slang-tint-compiler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#pragma once

#include "slang-downstream-compiler-util.h"
#include "../core/slang-platform.h"

namespace Slang
{

struct TintDownstreamCompilerUtil
{
static SlangResult locateCompilers(
const String& path,
ISlangSharedLibraryLoader* loader,
DownstreamCompilerSet* set);
};

}
5 changes: 5 additions & 0 deletions source/core/slang-render-api-util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ namespace Slang {
{ RenderApiType::Metal, "mtl,metal", ""},
{ RenderApiType::CPU, "cpu", ""},
{ RenderApiType::CUDA, "cuda", "cuda,ptx"},
{ RenderApiType::WebGPU, "wgpu,webgpu", "wgsl"},
};

static int _calcAvailableApis()
Expand Down Expand Up @@ -265,6 +266,10 @@ static bool _canLoadSharedLibrary(const char* libName)
{
#if SLANG_WINDOWS_FAMILY
case RenderApiType::Vulkan: return _canLoadSharedLibrary("vulkan-1") || _canLoadSharedLibrary("vk_swiftshader");
case RenderApiType::WebGPU:
return _canLoadSharedLibrary("webgpu_dawn") &&
_canLoadSharedLibrary("dxcompiler") &&
aleino-nv marked this conversation as resolved.
Show resolved Hide resolved
_canLoadSharedLibrary("dxil");
#elif SLANG_APPLE_FAMILY
case RenderApiType::Vulkan: return true;
case RenderApiType::Metal: return true;
Expand Down
2 changes: 2 additions & 0 deletions source/core/slang-render-api-util.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ enum class RenderApiType
Metal,
CPU,
CUDA,
WebGPU,
CountOf,
};

Expand All @@ -31,6 +32,7 @@ struct RenderApiFlag
Metal = 1 << int(RenderApiType::Metal),
CPU = 1 << int(RenderApiType::CPU),
CUDA = 1 << int(RenderApiType::CUDA),
WebGPU = 1 << int(RenderApiType::WebGPU),
AllOf = (1 << int(RenderApiType::CountOf)) - 1 ///< All bits set
};
};
Expand Down
1 change: 1 addition & 0 deletions source/core/slang-type-convert-util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ namespace Slang
case SLANG_CPP_PYTORCH_BINDING:return SLANG_SOURCE_LANGUAGE_CPP;
case SLANG_HOST_CPP_SOURCE: return SLANG_SOURCE_LANGUAGE_CPP;
case SLANG_CUDA_SOURCE: return SLANG_SOURCE_LANGUAGE_CUDA;
case SLANG_WGSL: return SLANG_SOURCE_LANGUAGE_WGSL;
default: break;
}
return SLANG_SOURCE_LANGUAGE_UNKNOWN;
Expand Down
3 changes: 3 additions & 0 deletions source/core/slang-type-text-util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ static const TypeTextUtil::CompileTargetInfo s_compileTargetInfos[] =
{ SLANG_METAL_LIB, "metallib", "metallib", "Metal Library Bytecode" },
{ SLANG_METAL_LIB_ASM, "metallib-asm" "metallib-asm", "Metal Library Bytecode assembly" },
{ SLANG_WGSL, "wgsl", "wgsl", "WebGPU shading language source" },
{ SLANG_WGSL_SPIRV_ASM, "wgsl-spirv-asm", "wgsl-spirv-asm,wgsl-spirv-assembly", "SPIR-V assembly via WebGPU shading language" },
{ SLANG_WGSL_SPIRV, "wgsl-spirv", "wgsl-spirv", "SPIR-V via WebGPU shading language" },
};

static const NamesDescriptionValue s_languageInfos[] =
Expand Down Expand Up @@ -91,6 +93,7 @@ static const NamesDescriptionValue s_compilerInfos[] =
{ SLANG_PASS_THROUGH_LLVM, "llvm", "LLVM/Clang `slang-llvm`" },
{ SLANG_PASS_THROUGH_SPIRV_OPT, "spirv-opt", "spirv-tools SPIRV optimizer" },
{ SLANG_PASS_THROUGH_METAL, "metal", "Metal shader compiler" },
{ SLANG_PASS_THROUGH_TINT, "tint", "Tint compiler" },
};

static const NamesDescriptionValue s_archiveTypeInfos[] =
Expand Down
4 changes: 2 additions & 2 deletions source/slang/slang-artifact-output-util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ namespace Slang
}
return SLANG_FAIL;
}
// Get the downstream compiler that can be used for this target
// Get the downstream disassembler that can be used for this target
// TODO(JS):
// This could perhaps be performed in some other manner if there was more than one way to produce
// disassembly from a binary.

const CodeGenTarget target = (CodeGenTarget)ArtifactDescUtil::getCompileTargetFromDesc(desc);
const CodeGenTarget target = (CodeGenTarget)ArtifactDescUtil::getCompileTargetFromDesc(assemblyDesc);
if (target == CodeGenTarget::Unknown)
{
return SLANG_FAIL;
Expand Down
Loading
Loading