Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added required_subgroup_size to PipelineShaderStageCreateInfo #2235

Merged
merged 16 commits into from
Aug 18, 2023
Merged
44 changes: 43 additions & 1 deletion vulkano-shaders/src/entry_point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,49 @@ fn write_shader_execution(execution: &ShaderExecution) -> TokenStream {
)
}
}
ShaderExecution::Compute => quote! { ::vulkano::shader::ShaderExecution::Compute },
ShaderExecution::Compute(execution) => {
use ::quote::ToTokens;
use ::vulkano::shader::{ComputeShaderExecution, LocalSize};

struct LocalSizeToTokens(LocalSize);

impl ToTokens for LocalSizeToTokens {
fn to_tokens(&self, tokens: &mut TokenStream) {
match self.0 {
LocalSize::Literal(literal) => quote! {
::vulkano::shader::LocalSize::Literal(#literal)
},
LocalSize::SpecId(id) => quote! {
::vulkano::shader::LocalSize::SpecId(#id)
},
}
.to_tokens(tokens);
}
}

match execution {
ComputeShaderExecution::LocalSize([x, y, z]) => {
let [x, y, z] = [
LocalSizeToTokens(*x),
LocalSizeToTokens(*y),
LocalSizeToTokens(*z),
];
quote! { ::vulkano::shader::ShaderExecution::Compute(
::vulkano::shader::ComputeShaderExecution::LocalSize([#x, #y, #z])
) }
}
ComputeShaderExecution::LocalSizeId([x, y, z]) => {
let [x, y, z] = [
LocalSizeToTokens(*x),
LocalSizeToTokens(*y),
LocalSizeToTokens(*z),
];
quote! { ::vulkano::shader::ShaderExecution::Compute(
::vulkano::shader::ComputeShaderExecution::LocalSizeId([#x, #y, #z])
) }
}
}
}
ShaderExecution::RayGeneration => {
quote! { ::vulkano::shader::ShaderExecution::RayGeneration }
}
Expand Down
151 changes: 148 additions & 3 deletions vulkano/src/pipeline/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,9 @@ impl ComputePipeline {
if let Some(cache) = &cache {
assert_eq!(device, cache.device().as_ref());
}

create_info
.validate(device)
.map_err(|err| err.add_context("create_info"))?;

Ok(())
}

Expand All @@ -100,12 +98,14 @@ impl ComputePipeline {
let specialization_info_vk;
let specialization_map_entries_vk: Vec<_>;
let mut specialization_data_vk: Vec<u8>;
let required_subgroup_size_create_info;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All the variables holding data to be passed to Vulkan end in _vk. Could you include that as well please?


{
let &PipelineShaderStageCreateInfo {
flags,
ref entry_point,
ref specialization_info,
ref required_subgroup_size,
_ne: _,
} = stage;

Expand Down Expand Up @@ -135,7 +135,20 @@ impl ComputePipeline {
data_size: specialization_data_vk.len(),
p_data: specialization_data_vk.as_ptr() as *const _,
};
required_subgroup_size_create_info =
required_subgroup_size.map(|required_subgroup_size| {
ash::vk::PipelineShaderStageRequiredSubgroupSizeCreateInfo {
required_subgroup_size,
..Default::default()
}
});
stage_vk = ash::vk::PipelineShaderStageCreateInfo {
p_next: required_subgroup_size_create_info.as_ref().map_or(
ptr::null(),
|required_subgroup_size_create_info| {
required_subgroup_size_create_info as *const _ as _
},
),
flags: flags.into(),
stage: ShaderStage::from(&entry_point_info.execution).into(),
module: entry_point.module().handle(),
Expand Down Expand Up @@ -336,12 +349,13 @@ impl ComputePipelineCreateInfo {
flags: _,
ref entry_point,
specialization_info: _,
required_subgroup_size: _vk,
_ne: _,
} = &stage;

let entry_point_info = entry_point.info();

if !matches!(entry_point_info.execution, ShaderExecution::Compute) {
if !matches!(entry_point_info.execution, ShaderExecution::Compute(_)) {
return Err(Box::new(ValidationError {
context: "stage.entry_point".into(),
problem: "is not a `ShaderStage::Compute` entry point".into(),
Expand Down Expand Up @@ -515,4 +529,135 @@ mod tests {
let data_buffer_content = data_buffer.read().unwrap();
assert_eq!(*data_buffer_content, 0x12345678);
}

#[test]
fn required_subgroup_size() {
// This test checks whether required_subgroup_size works.
// It executes a single compute shader (one invocation) that writes the subgroup size
// to a buffer. The buffer content is then checked for the right value.

let (device, queue) = gfx_dev_and_queue!(subgroup_size_control);

let cs = unsafe {
/*
#version 450

#extension GL_KHR_shader_subgroup_basic: enable

layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;

layout(set = 0, binding = 0) buffer Output {
uint write;
} write;

void main() {
if (gl_GlobalInvocationID.x == 0) {
write.write = gl_SubgroupSize;
}
}
*/
const MODULE: [u32; 246] = [
119734787, 65536, 851978, 30, 0, 131089, 1, 131089, 61, 393227, 1, 1280527431,
1685353262, 808793134, 0, 196622, 0, 1, 458767, 5, 4, 1852399981, 0, 9, 23, 393232,
4, 17, 128, 1, 1, 196611, 2, 450, 655364, 1197427783, 1279741775, 1885560645,
1953718128, 1600482425, 1701734764, 1919509599, 1769235301, 25974, 524292,
1197427783, 1279741775, 1852399429, 1685417059, 1768185701, 1952671090, 6649449,
589828, 1264536647, 1935626824, 1701077352, 1970495346, 1869768546, 1650421877,
1667855201, 0, 262149, 4, 1852399981, 0, 524293, 9, 1197436007, 1633841004,
1986939244, 1952539503, 1231974249, 68, 262149, 18, 1886680399, 29813, 327686, 18,
0, 1953067639, 101, 262149, 20, 1953067639, 101, 393221, 23, 1398762599,
1919378037, 1399879023, 6650473, 262215, 9, 11, 28, 327752, 18, 0, 35, 0, 196679,
18, 3, 262215, 20, 34, 0, 262215, 20, 33, 0, 196679, 23, 0, 262215, 23, 11, 36,
196679, 24, 0, 262215, 29, 11, 25, 131091, 2, 196641, 3, 2, 262165, 6, 32, 0,
262167, 7, 6, 3, 262176, 8, 1, 7, 262203, 8, 9, 1, 262187, 6, 10, 0, 262176, 11, 1,
6, 131092, 14, 196638, 18, 6, 262176, 19, 2, 18, 262203, 19, 20, 2, 262165, 21, 32,
1, 262187, 21, 22, 0, 262203, 11, 23, 1, 262176, 25, 2, 6, 262187, 6, 27, 128,
262187, 6, 28, 1, 393260, 7, 29, 27, 28, 28, 327734, 2, 4, 0, 3, 131320, 5, 327745,
11, 12, 9, 10, 262205, 6, 13, 12, 327850, 14, 15, 13, 10, 196855, 17, 0, 262394,
15, 16, 17, 131320, 16, 262205, 6, 24, 23, 327745, 25, 26, 20, 22, 196670, 26, 24,
131321, 17, 131320, 17, 65789, 65592,
];
let module =
ShaderModule::new(device.clone(), ShaderModuleCreateInfo::new(&MODULE)).unwrap();
module.entry_point("main").unwrap()
};

let properties = device.physical_device().properties();
let subgroup_size = properties.min_subgroup_size.unwrap_or(1);

let pipeline = {
let stage = PipelineShaderStageCreateInfo {
required_subgroup_size: Some(subgroup_size),
..PipelineShaderStageCreateInfo::new(cs)
};
let layout = PipelineLayout::new(
device.clone(),
PipelineDescriptorSetLayoutCreateInfo::from_stages([&stage])
.into_pipeline_layout_create_info(device.clone())
.unwrap(),
)
.unwrap();
ComputePipeline::new(
device.clone(),
None,
ComputePipelineCreateInfo::stage_layout(stage, layout),
)
.unwrap()
};

let memory_allocator = StandardMemoryAllocator::new_default(device.clone());
let data_buffer = Buffer::from_data(
&memory_allocator,
BufferCreateInfo {
usage: BufferUsage::STORAGE_BUFFER,
..Default::default()
},
AllocationCreateInfo {
memory_type_filter: MemoryTypeFilter::PREFER_DEVICE
| MemoryTypeFilter::HOST_RANDOM_ACCESS,
..Default::default()
},
0,
)
.unwrap();

let ds_allocator = StandardDescriptorSetAllocator::new(device.clone());
let set = PersistentDescriptorSet::new(
&ds_allocator,
pipeline.layout().set_layouts().get(0).unwrap().clone(),
[WriteDescriptorSet::buffer(0, data_buffer.clone())],
[],
)
.unwrap();

let cb_allocator = StandardCommandBufferAllocator::new(device.clone(), Default::default());
let mut cbb = AutoCommandBufferBuilder::primary(
&cb_allocator,
queue.queue_family_index(),
CommandBufferUsage::OneTimeSubmit,
)
.unwrap();
cbb.bind_pipeline_compute(pipeline.clone())
.unwrap()
.bind_descriptor_sets(
PipelineBindPoint::Compute,
pipeline.layout().clone(),
0,
set,
)
.unwrap()
.dispatch([128, 1, 1])
.unwrap();
let cb = cbb.build().unwrap();

let future = now(device)
.then_execute(queue, cb)
.unwrap()
.then_signal_fence_and_flush()
.unwrap();
future.wait(None).unwrap();

let data_buffer_content = data_buffer.read().unwrap();
assert_eq!(*data_buffer_content, subgroup_size);
}
}
21 changes: 19 additions & 2 deletions vulkano/src/pipeline/graphics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,6 @@ impl GraphicsPipeline {
create_info
.validate(device)
.map_err(|err| err.add_context("create_info"))?;

Ok(())
}

Expand Down Expand Up @@ -204,6 +203,8 @@ impl GraphicsPipeline {
specialization_info_vk: ash::vk::SpecializationInfo,
specialization_map_entries_vk: Vec<ash::vk::SpecializationMapEntry>,
specialization_data_vk: Vec<u8>,
required_subgroup_size_create_info:
Option<ash::vk::PipelineShaderStageRequiredSubgroupSizeCreateInfo>,
}

let (mut stages_vk, mut per_stage_vk): (SmallVec<[_; 5]>, SmallVec<[_; 5]>) = stages
Expand All @@ -213,6 +214,7 @@ impl GraphicsPipeline {
flags,
ref entry_point,
ref specialization_info,
ref required_subgroup_size,
_ne: _,
} = stage;

Expand All @@ -235,7 +237,13 @@ impl GraphicsPipeline {
}
})
.collect();

let required_subgroup_size_create_info =
required_subgroup_size.map(|required_subgroup_size| {
ash::vk::PipelineShaderStageRequiredSubgroupSizeCreateInfo {
required_subgroup_size,
..Default::default()
}
});
(
ash::vk::PipelineShaderStageCreateInfo {
flags: flags.into(),
Expand All @@ -255,6 +263,7 @@ impl GraphicsPipeline {
},
specialization_map_entries_vk,
specialization_data_vk,
required_subgroup_size_create_info,
},
)
})
Expand All @@ -267,10 +276,17 @@ impl GraphicsPipeline {
specialization_info_vk,
specialization_map_entries_vk,
specialization_data_vk,
required_subgroup_size_create_info,
},
) in (stages_vk.iter_mut()).zip(per_stage_vk.iter_mut())
{
*stage_vk = ash::vk::PipelineShaderStageCreateInfo {
p_next: required_subgroup_size_create_info.as_ref().map_or(
ptr::null(),
|required_subgroup_size_create_info| {
required_subgroup_size_create_info as *const _ as _
},
),
p_name: name_vk.as_ptr(),
p_specialization_info: specialization_info_vk,
..*stage_vk
Expand Down Expand Up @@ -2423,6 +2439,7 @@ impl GraphicsPipelineCreateInfo {
flags: _,
ref entry_point,
specialization_info: _,
required_subgroup_size: _vk,
_ne: _,
} = stage;

Expand Down
Loading