Skip to content

Commit

Permalink
Merge pull request #618 from Sergio0694/dev/graphics-resources-loading
Browse files Browse the repository at this point in the history
Rework IShader APIs to load constant buffer and resources
  • Loading branch information
Sergio0694 authored Oct 15, 2023
2 parents befe904 + be6a39f commit c31c63c
Show file tree
Hide file tree
Showing 32 changed files with 652 additions and 1,026 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,53 +11,43 @@ partial class IShaderGenerator
partial class LoadDispatchData
{
/// <summary>
/// Writes the <c>LoadDispatchData</c> method.
/// Writes the <c>LoadConstantBuffer</c> method.
/// </summary>
/// <param name="info">The input <see cref="ShaderInfo"/> instance with gathered shader info.</param>
/// <param name="writer">The <see cref="IndentedTextWriter"/> instance to write into.</param>
public static void WriteSyntax(ShaderInfo info, IndentedTextWriter writer)
public static void WriteLoadConstantBufferSyntax(ShaderInfo info, IndentedTextWriter writer)
{
writer.WriteLine("/// <inheritdoc/>");
writer.WriteGeneratedAttributes(GeneratorName);
writer.WriteLine("[global::System.Runtime.CompilerServices.SkipLocalsInit]");
writer.WriteLine("readonly void global::ComputeSharp.__Internals.IShader.LoadDispatchData<TLoader>(ref TLoader loader, global::ComputeSharp.GraphicsDevice device, int x, int y, int z)");
writer.WriteLine("readonly void global::ComputeSharp.__Internals.IShader.LoadConstantBuffer<TLoader>(ref TLoader loader, int x, int y, int z)");

using (writer.WriteBlock())
{
writer.WriteLine($"global::System.Span<uint> span0 = stackalloc uint[{info.ConstantBufferSizeInBytes}];");
writer.WriteLineIf($"global::System.Span<ulong> span1 = stackalloc ulong[{info.ResourceCount}];", info.ResourceCount > 0);
writer.WriteLine("ref uint r0 = ref span0[0];");
writer.WriteLineIf("ref ulong r1 = ref span1[0];", info.ResourceCount > 0);
writer.WriteLine($"global::System.Span<byte> span = stackalloc byte[{info.ConstantBufferSizeInBytes}];");

// Append the statements for the dispatch ranges
writer.WriteLine("span0[0] = (uint)x;");
writer.WriteLine("span0[1] = (uint)y;");
writer.WriteLineIf("span0[2] = (uint)z;", !info.IsPixelShaderLike);
writer.WriteLine("global::System.Runtime.CompilerServices.Unsafe.As<byte, uint>(ref span[0]) = (uint)x;");
writer.WriteLine("global::System.Runtime.CompilerServices.Unsafe.As<byte, uint>(ref span[4]) = (uint)y;");
writer.WriteLineIf("global::System.Runtime.CompilerServices.Unsafe.As<byte, uint>(ref span[8]) = (uint)z;", !info.IsPixelShaderLike);

// Generate loading statements for each captured field
foreach (FieldInfo fieldInfo in info.Fields)
{
switch (fieldInfo)
{
case FieldInfo.Resource resource:

// Validate the resource and get a handle for it
writer.WriteLine(
$"global::System.Runtime.CompilerServices.Unsafe.Add(ref r1, {resource.Offset}) = " +
$"global::ComputeSharp.__Internals.GraphicsResourceHelper.ValidateAndGetGpuDescriptorHandle({resource.FieldName}, device);");
break;
case FieldInfo.Primitive { TypeName: "System.Boolean" } primitive:

// Read a boolean value and cast it to Bool first, which will apply the correct size expansion
writer.WriteLine(
$"global::System.Runtime.CompilerServices.Unsafe.As<uint, global::ComputeSharp.Bool>(ref global::System.Runtime.CompilerServices.Unsafe.AddByteOffset(ref r0, (nint){primitive.Offset})) = " +
$"global::System.Runtime.CompilerServices.Unsafe.As<byte, global::ComputeSharp.Bool>(ref span[{primitive.Offset}]) = " +
$"(global::ComputeSharp.Bool){string.Join(".", primitive.FieldPath)};");
break;
case FieldInfo.Primitive primitive:

// Read a primitive value and serialize it into the target buffer
writer.WriteLine(
$"global::System.Runtime.CompilerServices.Unsafe.As<uint, global::{primitive.TypeName}>(ref global::System.Runtime.CompilerServices.Unsafe.AddByteOffset(ref r0, (nint){primitive.Offset})) = " +
$"global::System.Runtime.CompilerServices.Unsafe.As<byte, global::{primitive.TypeName}>(ref span[{primitive.Offset}]) = " +
$"{string.Join(".", primitive.FieldPath)};");
break;

Expand All @@ -82,17 +72,40 @@ public static void WriteSyntax(ShaderInfo info, IndentedTextWriter writer)
for (int j = 0; j < matrix.Rows; j++)
{
writer.WriteLine(
$"global::System.Runtime.CompilerServices.Unsafe.As<uint, {rowTypeName}>(ref global::System.Runtime.CompilerServices.Unsafe.AddByteOffset(ref r0, (nint){matrix.Offsets[j]})) = " +
$"global::System.Runtime.CompilerServices.Unsafe.As<byte, {rowTypeName}>(ref span[{matrix.Offsets[j]}]) = " +
$"global::System.Runtime.CompilerServices.Unsafe.Add(ref {rowLocalName}, {j});");
}

break;
}
}

// Load the prepared buffers
writer.WriteLine("loader.LoadCapturedValues(span0);");
writer.WriteLineIf("loader.LoadCapturedResources(span1);", info.ResourceCount > 0);
// Load the prepared buffer
writer.WriteLine("loader.LoadConstantBuffer(span);");
}
}

/// <summary>
/// Writes the <c>LoadGraphicsResources</c> method.
/// </summary>
/// <param name="info">The input <see cref="ShaderInfo"/> instance with gathered shader info.</param>
/// <param name="writer">The <see cref="IndentedTextWriter"/> instance to write into.</param>
public static void WriteLoadGraphicsResourcesSyntax(ShaderInfo info, IndentedTextWriter writer)
{
writer.WriteLine("/// <inheritdoc/>");
writer.WriteGeneratedAttributes(GeneratorName);
writer.WriteLine("readonly void global::ComputeSharp.__Internals.IShader.LoadGraphicsResources<TLoader>(ref TLoader loader)");

using (writer.WriteBlock())
{
// Generate loading statements for each captured resource
foreach (FieldInfo fieldInfo in info.Fields)
{
if (fieldInfo is FieldInfo.Resource resource)
{
writer.WriteLine($"loader.LoadGraphicsResource({resource.FieldName}, {resource.Offset});");
}
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ private static partial class LoadDispatchData
/// </summary>
/// <param name="diagnostics">The collection of produced <see cref="DiagnosticInfo"/> instances.</param>
/// <param name="structDeclarationSymbol">The current shader type being explored.</param>
/// <param name="constantBufferSizeInBytes">The size of the shader constant buffer.</param>
/// <param name="isPixelShaderLike">Whether the compute shader is "pixel shader like", ie. outputting a pixel into a target texture.</param>
/// <param name="constantBufferSizeInBytes">The size of the shader constant buffer.</param>
/// <param name="resourceCount">The total number of captured resources in the shader.</param>
/// <returns>The sequence of <see cref="FieldInfo"/> instances for all captured resources and values.</returns>
public static ImmutableArray<FieldInfo> GetInfo(
Expand Down Expand Up @@ -108,7 +108,7 @@ from fieldSymbol in currentTypeSymbol.GetMembers().OfType<IFieldSymbol>()
// the local variables is padded to a multiple of a 32 bit value. This is necessary to
// enable loading all the dispatch data after reinterpreting it to a sequence of values
// of size 32 bits (via SetComputeRoot32BitConstants), without reading out of bounds.
constantBufferSizeInBytes = AlignmentHelper.Pad(rawDataOffsetAsBox.Value, sizeof(int)) / sizeof(int);
constantBufferSizeInBytes = AlignmentHelper.Pad(rawDataOffsetAsBox.Value, sizeof(int));

// A shader root signature has a maximum size of 64 DWORDs, so 256 bytes.
// Loaded values in the root signature have the following costs:
Expand All @@ -118,9 +118,9 @@ from fieldSymbol in currentTypeSymbol.GetMembers().OfType<IFieldSymbol>()
// So here we check whether the current signature respects that constraint,
// and emit a build error otherwise. For more info on this, see the docs here:
// https://docs.microsoft.com/windows/win32/direct3d12/root-signature-limits.
int rootSignatureDwordSize = constantBufferSizeInBytes + resourceCount;
int root32BitConstantCount = (constantBufferSizeInBytes / sizeof(int)) + resourceCount;

if (rootSignatureDwordSize > 64)
if (root32BitConstantCount > 64)
{
diagnostics.Add(ShaderDispatchDataSizeExceeded, structDeclarationSymbol, structDeclarationSymbol);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,11 @@ private static partial class LoadDispatchMetadata
/// <summary>
/// Gets the data related to the shader metadata for a given shader type.
/// </summary>
/// <param name="root32BitConstantCount">The total number of needed 32 bit constants in the shader root signature.</param>
/// <param name="isImplicitTextureUsed">Indicates whether the current shader uses an implicit texture.</param>
/// <param name="isSamplerUsed">Whether the static sampler is used by the shader.</param>
/// <param name="capturedFields">The sequence of captured fields for the shader.</param>
/// <returns>The metadata info for the shader.</returns>
public static ImmutableArray<ResourceDescriptor> GetInfo(
int root32BitConstantCount,
bool isImplicitTextureUsed,
bool isSamplerUsed,
ImmutableArray<FieldInfo> capturedFields)
Expand Down
8 changes: 4 additions & 4 deletions src/ComputeSharp.SourceGenerators/IShaderGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
diagnostics,
typeSymbol,
isPixelShaderLike,
out int root32BitConstantCount,
out int constantBufferSizeInBytes,
out int resourceCount);

token.ThrowIfCancellationRequested();
Expand Down Expand Up @@ -92,7 +92,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context)

// GetDispatchMetadata() info
ImmutableArray<ResourceDescriptor> resourceDescriptors = LoadDispatchMetadata.GetInfo(
root32BitConstantCount,
isImplicitTextureUsed,
isSamplerUsed,
fieldInfos);
Expand Down Expand Up @@ -123,7 +122,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
ThreadsZ: threadsZ,
IsPixelShaderLike: isPixelShaderLike,
IsSamplerUsed: isSamplerUsed,
ConstantBufferSizeInBytes: root32BitConstantCount,
ConstantBufferSizeInBytes: constantBufferSizeInBytes,
ResourceCount: resourceCount,
Fields: fieldInfos,
ResourceDescriptors: resourceDescriptors,
Expand Down Expand Up @@ -153,7 +152,8 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
declaredMembers.Add(LoadDispatchMetadata.WriteSyntax);
declaredMembers.Add(BuildHlslSource.WriteSyntax);
declaredMembers.Add(LoadBytecode.WriteHlslBytecodeSyntax);
declaredMembers.Add(LoadDispatchData.WriteSyntax);
declaredMembers.Add(LoadDispatchData.WriteLoadConstantBufferSyntax);
declaredMembers.Add(LoadDispatchData.WriteLoadGraphicsResourcesSyntax);

using ImmutableArrayBuilder<IndentedTextWriter.Callback<ShaderInfo>> additionalTypes = new();
using ImmutableHashSetBuilder<string> usingDirectives = new();
Expand Down
2 changes: 1 addition & 1 deletion src/ComputeSharp.SourceGenerators/Models/ShaderInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace ComputeSharp.SourceGenerators.Models;
/// <param name="ThreadsZ">The thread ids value for the Z axis.</param>
/// <param name="IsPixelShaderLike">Whether the compute shader is "pixel shader like", ie. outputting a pixel into a target texture.</param>
/// <param name="IsSamplerUsed">Whether or not the static sampler is used.</param>
/// /// <param name="ConstantBufferSizeInBytes">The size of the shader constant buffer.</param>
/// <param name="ConstantBufferSizeInBytes">The size of the shader constant buffer.</param>
/// <param name="ResourceCount">The total number of captured resources in the shader.</param>
/// <param name="Fields">The description on shader instance fields.</param>
/// <param name="ResourceDescriptors">The sequence of resource descriptors for the shader.</param>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
using System;

namespace ComputeSharp.__Internals;

/// <summary>
/// A type representing a data loader for the constant buffer of a compute shader.
/// </summary>
public interface IConstantBufferLoader
{
/// <summary>
/// Loads the constant buffer of a compute shader.
/// </summary>
/// <param name="data">The constant buffer for the compute shader (the size must be a multiple of the size of a DWORD value).</param>
/// <exception cref="ArgumentException">Thrown if the size of <paramref name="data"/> is not a multiple of the size of a DWORD value).</exception>
void LoadConstantBuffer(ReadOnlySpan<byte> data);
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using System;

namespace ComputeSharp.__Internals;

/// <summary>
/// A type representing a loader object for <see cref="IGraphicsResource"/> values used in compute shaders.
/// </summary>
public interface IGraphicsResourceLoader
{
/// <summary>
/// Loads a resource used by the shader to be dispatched.
/// </summary>
/// <param name="resource">The input resource to be loaded.</param>
/// <param name="index">The index to use to bind the resource to the shader.</param>
/// <exception cref="ArgumentNullException">Thrown if <paramref name="resource"/> is <see langword="null"/>.</exception>
/// <exception cref="ArgumentException">Thrown if <paramref name="resource"/> is not a valid resource object.</exception>
/// <exception cref="GraphicsDeviceMismatchException">Thrown if the resource isn't associated with the same device used to dispatch the shader.</exception>
void LoadGraphicsResource(IGraphicsResource resource, uint index);
}
19 changes: 12 additions & 7 deletions src/ComputeSharp/Core/Interfaces/__Internals/IShader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public interface IShader
/// <remarks>
/// Constant buffer data is bound to shaders via 32 bit root constants, and loaded before dispatching via
/// <see href="https://learn.microsoft.com/windows/win32/api/d3d12/nf-d3d12-id3d12graphicscommandlist-setcomputeroot32bitconstants"><c>ID3D12GraphicsCommandList::SetComputeRoot32BitConstants</c></see>,
/// so the size must be a multiple of 4.
/// so the size must be a multiple of the size of a DWORD value (ie. 4 bytes).
/// </remarks>
int ConstantBufferSize { get; }

Expand Down Expand Up @@ -60,16 +60,21 @@ public interface IShader
ReadOnlyMemory<byte> HlslBytecode { get; }

/// <summary>
/// Loads the dispatch data for the shader.
/// Loads the constant buffer of a given input shader.
/// </summary>
/// <typeparam name="TLoader">The type of data loader being used.</typeparam>
/// <param name="loader">The <typeparamref name="TLoader"/> instance to use to load the data.</param>
/// <param name="device">The device the shader is being dispatched on.</param>
/// <param name="x">The number of iterations to run on the X axis.</param>
/// <param name="y">The number of iterations to run on the Y axis.</param>
/// <param name="z">The number of iterations to run on the Z axis.</param>
[EditorBrowsable(EditorBrowsableState.Never)]
[Obsolete("This method is not intended to be called directly by user code")]
void LoadDispatchData<TLoader>(ref TLoader loader, GraphicsDevice device, int x, int y, int z)
where TLoader : struct, IDispatchDataLoader;
void LoadConstantBuffer<TLoader>(ref TLoader loader, int x, int y, int z)
where TLoader : struct, IConstantBufferLoader;

/// <summary>
/// Loads the graphics resources for the shader.
/// </summary>
/// <typeparam name="TLoader">The type of graphics resource loader being used.</typeparam>
/// <param name="loader">The <typeparamref name="TLoader"/> instance to use to load the data.</param>
void LoadGraphicsResources<TLoader>(ref TLoader loader)
where TLoader : struct, IGraphicsResourceLoader;
}
Original file line number Diff line number Diff line change
Expand Up @@ -288,11 +288,11 @@ public static unsafe void FillUnorderedAccessView(
/// </summary>
/// <param name="d3D12GraphicsCommandList">The <see cref="ID3D12GraphicsCommandList"/> instance in use.</param>
/// <param name="data">The input buffer with the constant data to bind.</param>
public static void SetComputeRoot32BitConstants(this ref ID3D12GraphicsCommandList d3D12GraphicsCommandList, ReadOnlySpan<uint> data)
public static void SetComputeRoot32BitConstants(this ref ID3D12GraphicsCommandList d3D12GraphicsCommandList, ReadOnlySpan<byte> data)
{
fixed (uint* p = data)
fixed (byte* p = data)
{
d3D12GraphicsCommandList.SetComputeRoot32BitConstants(0, (uint)data.Length, p, 0);
d3D12GraphicsCommandList.SetComputeRoot32BitConstants(0, (uint)data.Length / sizeof(uint), p, 0);
}
}
}
Loading

0 comments on commit c31c63c

Please sign in to comment.