Skip to content

Commit

Permalink
Use MSL directly for SDL_GPU (#74)
Browse files Browse the repository at this point in the history
  • Loading branch information
TheSpydog authored Nov 3, 2024
1 parent 1df57d1 commit 32ae391
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 1,149 deletions.
113 changes: 62 additions & 51 deletions mojoshader_sdlgpu.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,21 @@
#ifdef USE_SDL3 /* Private define, for now */

#include <SDL3/SDL.h>

#define SDL_GPU_SHADERCROSS_IMPLEMENTATION
#include "spirv/SDL_gpu_shadercross.h"
#include <spirv/spirv.h>

/* Max entries for each register file type */
#define MAX_REG_FILE_F 8192
#define MAX_REG_FILE_I 2047
#define MAX_REG_FILE_B 2047

/* The destination shader format to use */
static SDL_GPUShaderFormat shader_format =
#ifdef __APPLE__
SDL_GPU_SHADERFORMAT_MSL;
#else
SDL_GPU_SHADERFORMAT_SPIRV;
#endif

typedef struct ShaderEntry
{
uint64_t hash;
Expand Down Expand Up @@ -248,8 +254,7 @@ static uint8_t update_uniform_buffer(

unsigned int MOJOSHADER_sdlGetShaderFormats(void)
{
SDL_ShaderCross_Init();
return SDL_ShaderCross_GetSPIRVShaderFormats();
return shader_format;
} // MOJOSHADER_sdlGetShaderFormats

static bool load_precompiled_blob(MOJOSHADER_sdlContext *ctx)
Expand Down Expand Up @@ -349,8 +354,7 @@ MOJOSHADER_sdlContext *MOJOSHADER_sdlCreateContext(
}
else
{
/* Always use spirv and interop with SDL_gpu_spirvcross */
resultCtx->profile = "spirv";
resultCtx->profile = (shader_format == SDL_GPU_SHADERFORMAT_SPIRV) ? "spirv" : "metal";
}

resultCtx->malloc_fn = m;
Expand Down Expand Up @@ -390,8 +394,6 @@ void MOJOSHADER_sdlDestroyContext(
} // if

ctx->free_fn(ctx, ctx->malloc_data);

SDL_ShaderCross_Quit();
} // MOJOSHADER_sdlDestroyContext

static uint16_t shaderTagCounter = 1;
Expand Down Expand Up @@ -579,7 +581,7 @@ static MOJOSHADER_sdlProgram *compile_blob_program(
return program;
} // compile_blob_program

static MOJOSHADER_sdlProgram *compile_spirv_program(
static MOJOSHADER_sdlProgram *compile_program(
MOJOSHADER_sdlContext *ctx,
MOJOSHADER_sdlShaderData *vshader,
MOJOSHADER_sdlShaderData *pshader,
Expand All @@ -595,56 +597,65 @@ static MOJOSHADER_sdlProgram *compile_spirv_program(
return NULL;
} // if

// We have to patch the SPIR-V output to ensure type consistency. The non-float types are:
// BYTE4 - 5
// SHORT2 - 6
// SHORT4 - 7
int vDataLen = vshader->parseData->output_len - sizeof(SpirvPatchTable);
SpirvPatchTable *vTable = (SpirvPatchTable *) &vshader->parseData->output[vDataLen];
size_t vshaderCodeSize = vshader->parseData->output_len;
size_t pshaderCodeSize = pshader->parseData->output_len;

for (int i = 0; i < vertexAttributeCount; i += 1)
if (shader_format == SDL_GPU_SHADERFORMAT_SPIRV)
{
MOJOSHADER_sdlVertexAttribute *element = &vertexAttributes[i];
uint32 typeDecl, typeLoad;
SpvOp opcodeLoad;
// We have to patch the SPIR-V output to ensure type consistency. The non-float types are:
// BYTE4 - 5
// SHORT2 - 6
// SHORT4 - 7
int vDataLen = vshader->parseData->output_len - sizeof(SpirvPatchTable);
SpirvPatchTable *vTable = (SpirvPatchTable *) &vshader->parseData->output[vDataLen];

if (element->vertexElementFormat >= 5 && element->vertexElementFormat <= 7)
for (int i = 0; i < vertexAttributeCount; i += 1)
{
typeDecl = element->vertexElementFormat == 5 ? vTable->tid_uvec4_p : vTable->tid_ivec4_p;
typeLoad = element->vertexElementFormat == 5 ? vTable->tid_uvec4 : vTable->tid_ivec4;
opcodeLoad = element->vertexElementFormat == 5 ? SpvOpConvertUToF : SpvOpConvertSToF;
}
else
{
typeDecl = vTable->tid_vec4_p;
typeLoad = vTable->tid_vec4;
opcodeLoad = SpvOpCopyObject;
}
MOJOSHADER_sdlVertexAttribute *element = &vertexAttributes[i];
uint32 typeDecl, typeLoad;
SpvOp opcodeLoad;

uint32_t typeDeclOffset = vTable->attrib_type_offsets[element->usage][element->usageIndex];
((uint32_t*)vshader->parseData->output)[typeDeclOffset] = typeDecl;
for (uint32_t j = 0; j < vTable->attrib_type_load_offsets[element->usage][element->usageIndex].num_loads; j += 1)
{
uint32_t typeLoadOffset = vTable->attrib_type_load_offsets[element->usage][element->usageIndex].load_types[j];
uint32_t opcodeLoadOffset = vTable->attrib_type_load_offsets[element->usage][element->usageIndex].load_opcodes[j];
uint32_t *ptr_to_opcode_u32 = &((uint32_t*)vshader->parseData->output)[opcodeLoadOffset];
((uint32_t*)vshader->parseData->output)[typeLoadOffset] = typeLoad;
*ptr_to_opcode_u32 = (*ptr_to_opcode_u32 & 0xFFFF0000) | opcodeLoad;
if (element->vertexElementFormat >= 5 && element->vertexElementFormat <= 7)
{
typeDecl = element->vertexElementFormat == 5 ? vTable->tid_uvec4_p : vTable->tid_ivec4_p;
typeLoad = element->vertexElementFormat == 5 ? vTable->tid_uvec4 : vTable->tid_ivec4;
opcodeLoad = element->vertexElementFormat == 5 ? SpvOpConvertUToF : SpvOpConvertSToF;
}
else
{
typeDecl = vTable->tid_vec4_p;
typeLoad = vTable->tid_vec4;
opcodeLoad = SpvOpCopyObject;
}

uint32_t typeDeclOffset = vTable->attrib_type_offsets[element->usage][element->usageIndex];
((uint32_t*)vshader->parseData->output)[typeDeclOffset] = typeDecl;
for (uint32_t j = 0; j < vTable->attrib_type_load_offsets[element->usage][element->usageIndex].num_loads; j += 1)
{
uint32_t typeLoadOffset = vTable->attrib_type_load_offsets[element->usage][element->usageIndex].load_types[j];
uint32_t opcodeLoadOffset = vTable->attrib_type_load_offsets[element->usage][element->usageIndex].load_opcodes[j];
uint32_t *ptr_to_opcode_u32 = &((uint32_t*)vshader->parseData->output)[opcodeLoadOffset];
((uint32_t*)vshader->parseData->output)[typeLoadOffset] = typeLoad;
*ptr_to_opcode_u32 = (*ptr_to_opcode_u32 & 0xFFFF0000) | opcodeLoad;
}
}
}

MOJOSHADER_spirv_link_attributes(vshader->parseData, pshader->parseData, 0);
MOJOSHADER_spirv_link_attributes(vshader->parseData, pshader->parseData, 0);

vshaderCodeSize -= sizeof(SpirvPatchTable);
pshaderCodeSize -= sizeof(SpirvPatchTable);
}

SDL_zero(createInfo);
createInfo.code = (const Uint8*) vshader->parseData->output;
createInfo.code_size = vshader->parseData->output_len - sizeof(SpirvPatchTable);
createInfo.code_size = vshaderCodeSize;
createInfo.entrypoint = vshader->parseData->mainfn;
createInfo.format = SDL_GPU_SHADERFORMAT_SPIRV;
createInfo.format = shader_format;
createInfo.stage = SDL_GPU_SHADERSTAGE_VERTEX;
createInfo.num_samplers = vshader->samplerSlots;
createInfo.num_uniform_buffers = 1;

program->vertexShader = SDL_ShaderCross_CompileGraphicsShaderFromSPIRV(
program->vertexShader = SDL_CreateGPUShader(
ctx->device,
&createInfo
);
Expand All @@ -657,13 +668,13 @@ static MOJOSHADER_sdlProgram *compile_spirv_program(
} // if

createInfo.code = (const Uint8*) pshader->parseData->output;
createInfo.code_size = pshader->parseData->output_len - sizeof(SpirvPatchTable);
createInfo.code_size = pshaderCodeSize;
createInfo.entrypoint = pshader->parseData->mainfn;
createInfo.format = SDL_GPU_SHADERFORMAT_SPIRV;
createInfo.format = shader_format;
createInfo.stage = SDL_GPU_SHADERSTAGE_FRAGMENT;
createInfo.num_samplers = pshader->samplerSlots;

program->pixelShader = SDL_ShaderCross_CompileGraphicsShaderFromSPIRV(
program->pixelShader = SDL_CreateGPUShader(
ctx->device,
&createInfo
);
Expand All @@ -677,7 +688,7 @@ static MOJOSHADER_sdlProgram *compile_spirv_program(
} // if

return program;
} // compile_spirv_program
} // compile_program

MOJOSHADER_sdlProgram *MOJOSHADER_sdlLinkProgram(
MOJOSHADER_sdlContext *ctx,
Expand Down Expand Up @@ -730,8 +741,8 @@ MOJOSHADER_sdlProgram *MOJOSHADER_sdlLinkProgram(
} // if
else
{
program = compile_spirv_program(ctx, vshader, pshader,
vertexAttributes, vertexAttributeCount);
program = compile_program(ctx, vshader, pshader,
vertexAttributes, vertexAttributeCount);
} // else

if (program == NULL)
Expand Down
2 changes: 1 addition & 1 deletion profiles/mojoshader_profile_metal.c
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ void emit_METAL_finalize(Context *ctx)
output_line(ctx, "};");
pop_output(ctx);

output_line(ctx, "constant %s_Uniforms &uniforms [[buffer(16)]]%s", ctx->mainfn, commas ? "," : "");
output_line(ctx, "constant %s_Uniforms &uniforms [[buffer(0)]]%s", ctx->mainfn, commas ? "," : "");
commas--;
} // if

Expand Down
Loading

0 comments on commit 32ae391

Please sign in to comment.