Skip to content

Commit

Permalink
Merge pull request #19 from charles-r-earp/subgroup-threads-range
Browse files Browse the repository at this point in the history
add (min/max)_subgroup_threads to DeviceInfo
  • Loading branch information
charles-r-earp authored Feb 27, 2024
2 parents 0d431b8 + 1f9bf30 commit b967015
Show file tree
Hide file tree
Showing 10 changed files with 133 additions and 120 deletions.
4 changes: 2 additions & 2 deletions benches/compute-benches/krnl-cache.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__krnl_cache!("0.1.0-alpha", "
abZy8000000@)b6zbTwm5xwZXCOSRsmkD?uWxi5f5pt@:^3xC?mo2{eT./.T5qp:W41Ppv/baw}@s57]EKdji=/slS:)WfsJcw(P=I&4ZZKbXxj^j-}}hywu9uc8Wd!SfS1#KFaS{OsE56A<yXMWYth(rD}{P6M#gB?^TT6qr-{P9zkZ(vMRE/90O8i/}:>NpuYh4XmW8vUQ>Jp!ZLjNjM:!1E]F4ly1^vqf<#Tz$E4Cj2do-}GZ{cBWHL2#X6zY1I^P(HyWa&*?o0b5c1[jmo[7Dp-%G}Ig+ap5Rr39HpD<-ZcwlomB9dSNJ&3o1qhxXvr3?)}}%f.*okV377mtZ2?G]XH*#z>Fa*$U]3P>u#-z6Gc!tf43?2?yOeLuHGH$UWb&/]{x:$:H/st5<ko>GFa-axiFi[&ksG%jyOb.COqynAYWsu}:bB{61k7o2/qqS>EoLZc<KMbV6zUsrGy5oMo6g2qgt0{RDERuWLFFTk6=N}!q#voXHP0Je5+qB{j3#lYu%itKG4KI*rLKDMYftt{>qkt>cvZQ=jB-9boP}XjkcH(ykfO#W8ncAa0ZtHyT)6[E4vRVEx29.BSCfryha:AY[giyTm4IvK:k1s>:lKK4HU6o-)UXk:*)p=iXeY/9tN4@hT^q4r8Vp2JR+I1yiq)1U+Ulfz(Pqe0rGFfFY&gBWv3UbRhx(}d{Y0bl{4*71E<&72<ePT7@%1})a25ecAoV#c!t6Z(mSg9boRe&0waLddJ9>fh?nFEOzw*i}-{T<.?FZNiOIL-z/-4kaWFeNwl?bKP7qb:cjze)<f(0r)NwB&kF5Rx+LkCbvXv%+7%4jp(s5fHA:un9Su7Yj}TksDi^*N$qRK/*kC{y!hE!g!t0@cl<J+O7fLv>wNP?O)g^*Llx/QP1A/pzj=VPDHmf0D%W7X2#O@(w{4VO5Em^WHrKqqTDOpzoV4)#)Y6Y?#J<w>M{Jc:Cq]t&9.N{5RAeH<-W/))Da6N0#i9{Pa0/
KU{PHp3+V]?zzGa6ry&0mB)sC@A<{WZCaXv++PZxZJaKfU<k0FXbuP4}SP%sO0&8>[+Tt@UKr1%)D/O8kdbr+mawXO(F7x?VgTx((Ovav?ySL(vPya%X}uCej&g&3[%uPTjDEPsvp?6^P6dz3xX@R1&Y0l#H$QgZz/+[J++e$<&Ae4K9XkadU@/)V-ibP1F08-o>q}xG[t%4Zd#5l:)JPqVHw8+Amr*#wgcA:)gV^0pZyNHktJ$&k%iFHXl1]9r9YM=@C%180-5@uZ=@ZVIVax$G-]xjf@lUcM3JHeb
abZy8000000@)b6)nxAN5xvP6CRoI-RoNI%x@!Wh7Z+W845*?mWV[jJ5V$)vlrYFjwc*z{h:qJ$uw0Z$MB)Hh*uC#Rc+>(kwxJbp=I&4ZZKbXvB0*<*]/1QOv%?:!rbH#^vNaC8<[:C^>=+XprCHS^n]]r8[1rZvWO2U{n]m!t+0Adc6$vyczK*RuzV}DDM+cJ}Szp5^l{&P{vYVia>2{>H@#c6-?/TjV.T=BKC>k)3@n$9/Jurn07pN@zB17lH?f%7E7vXbQ&[U!wb5c1[DV51bGquBM@IdDaRG7bvIotG{]ru]NceWBFsl<7!yp7RuD%A3354mb!qkryCGWj}/uuq-{*.l/R]iV8sspu-{USb:C.4HbT21%MAWnk/fn37Cwi}s37}&[f/-hS^f[GidlZXWRYOjuwGPc.j-Y?!GM=CV-i:(>[uv[9%3!7xn^?>wVyEnn<Z<KMa]6.$1gl*TF0e17+G:XW+6>#1UUpRcqMd(FQF}$IFw4]XbuQVsAqCaKeW.lS.86Tn[r!U4$:}-:bW>nVTOa+b7Z0!k@C+-C4XZm.x.keJl?7}*f5628Z]T)5h[4vIPDw*3pNSa<=Nha:AY[giz2l%WU.:k1s[ZO1@)96%QXXA?2mAuxJ<&DaiAA1hZ{Wo?(ky4Q-vO&i#R&s6udgCZsQU6i[GT0k3HDUkvxw0QCShy-Hl}#q]L27qD4kAU#m&%4Y)v3x-D@1T2Pt(jL$!t6Z(nPcAeoRe?lvgdK[JBgoi?n6)8)I-?F-{UZ1.eF65t@m0^/-4b7WFcqSl?bKP>kHTAjM(MFv:)N(gdQkz@#Z.73/ptp&qgrf@vSlND{67V3IdUnB]tKws]D6zO[r?EgXF%<=R6Vg]&>?(7qVu>z+M7kFJSL/IYIT(HH89&-PzO=@R*6@Nb4{TB=*!MYNWiV{Ijw/FPW8$Iqw%gJR.H!DXKJ8%a71T>-bA<gLn(WlY&aCF8>F*V3jbIcIBZP?Zuw1h#M#s{Pa0B
K-+c+oZBZR?z8o76ry?Bmk/$L@v-Ka?.Bfm++PZx?$825U<k0FXbv2g}SP%sGpgLO[+S]5Sp8be)D^rtkdbr^L=eJH!viyQw9UVp![87vYXz6y-hXcW@rxVADwZ4I<s+0ND7Hdd+fi%Qj/9vgK?7kH-bvy#c1YV[ZqG=Igk0.DGK.{?2cn9l-{KDY5XzE(@.hg8OeVBJ.N59)UnVmXXqP<V8@uu?7Vt/=eaxSVxuD{O@%h(YsjlGF&G(>9FZw>[18%=r]k5f&-r7u{=@.kYVhiv>-]vI4kLUu@3JHeb
");
1 change: 0 additions & 1 deletion benches/compute-benches/src/krnl_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ pub struct Saxpy {
impl Saxpy {
pub fn run(&mut self) -> Result<()> {
kernels::saxpy::builder()?
.with_threads(256)
.build(self.device.clone())?
.dispatch(
self.x_device.as_slice(),
Expand Down
148 changes: 74 additions & 74 deletions krnl-cache.rs

Large diffs are not rendered by default.

16 changes: 6 additions & 10 deletions krnl-core/src/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@ pub mod __private {
use super::{ItemKernel, Kernel};

pub struct KernelArgs {
pub global_threads: u32,
pub global_id: u32,
pub groups: u32,
pub group_id: u32,
pub subgroups: u32,
pub subgroup_id: u32,
pub subgroup_threads: u32,
//pub subgroup_threads: u32,
pub subgroup_thread_id: u32,
pub threads: u32,
pub thread_id: u32,
Expand All @@ -21,25 +20,24 @@ pub mod __private {
#[inline]
pub unsafe fn into_kernel(self) -> Kernel {
let Self {
global_threads,
global_id,
groups,
group_id,
subgroups,
subgroup_id,
subgroup_threads,
//subgroup_threads,
subgroup_thread_id,
threads,
thread_id,
} = self;
Kernel {
global_threads,
global_threads: groups * threads,
global_id,
groups,
group_id,
subgroups,
subgroup_id,
subgroup_threads,
//subgroup_threads,
subgroup_thread_id,
threads,
thread_id,
Expand Down Expand Up @@ -102,8 +100,7 @@ pub struct Kernel {
#[doc(hidden)]
#[deprecated(since = "0.0.4", note = "replaced with subgroup_id()")]
pub subgroup_id: u32,
#[allow(unused)]
subgroup_threads: u32,
//subgroup_threads: u32,
#[doc(hidden)]
#[deprecated(since = "0.0.4", note = "replaced with subgroup_thread_id()")]
pub subgroup_thread_id: u32,
Expand Down Expand Up @@ -151,8 +148,7 @@ impl Kernel {
pub fn subgroup_id(&self) -> usize {
self.subgroup_id as usize
}
// TODO: Intel Mesa driver uses variable subgroup size
// Fixed in https://github.com/charles-r-earp/krnl/tree/update-vulkano
// TODO: Potentially implement via subgroup ballot / reduce operation
/*
/// The number of threads per subgroup.
#[inline]
Expand Down
5 changes: 0 additions & 5 deletions krnl-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1375,9 +1375,6 @@ fn kernel_impl(item_tokens: TokenStream2) -> Result<TokenStream2> {
#[spirv(subgroup_id)]
__krnl_subgroup_id: u32,
#[allow(unused)]
#[spirv(subgroup_size)]
__krnl_subgroup_threads: u32,
#[allow(unused)]
#[spirv(subgroup_local_invocation_id)]
__krnl_subgroup_thread_id: u32,
#[allow(unused)]
Expand All @@ -1404,13 +1401,11 @@ fn kernel_impl(item_tokens: TokenStream2) -> Result<TokenStream2> {
#declare_threads
let mut kernel = unsafe {
::krnl_core::kernel::__private::KernelArgs {
global_threads: __krnl_groups.x * __krnl_threads,
global_id: __krnl_global_id.x,
groups: __krnl_groups.x,
group_id: __krnl_group_id.x,
subgroups: __krnl_subgroups,
subgroup_id: __krnl_subgroup_id,
subgroup_threads: __krnl_subgroup_threads,
subgroup_thread_id: __krnl_subgroup_thread_id,
threads: __krnl_threads,
thread_id: __krnl_thread_id.x,
Expand Down
7 changes: 3 additions & 4 deletions krnlc/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1249,7 +1249,7 @@ fn add_spec_constant_ops(module: &mut rspirv::dr::Module) {
| Op::UMod
| Op::SRem
| Op::SMod
| Op::ShiftRightLogical
/* | Op::ShiftRightLogical
| Op::ShiftRightArithmetic
| Op::ShiftLeftLogical
| Op::BitwiseOr
Expand All @@ -1261,7 +1261,7 @@ fn add_spec_constant_ops(module: &mut rspirv::dr::Module) {
| Op::LogicalAnd
| Op::LogicalNot
| Op::LogicalEqual
| Op::LogicalNotEqual
| Op::LogicalNotEqual */
| Op::Select
| Op::IEqual
| Op::INotEqual
Expand All @@ -1272,8 +1272,7 @@ fn add_spec_constant_ops(module: &mut rspirv::dr::Module) {
| Op::ULessThanEqual
| Op::SLessThanEqual
| Op::UGreaterThanEqual
| Op::SGreaterThanEqual
| Op::QuantizeToF16
| Op::SGreaterThanEqual /* | Op::QuantizeToF16 */
) {
if let Some(result_id) = inst.result_id {
let mut used_constants = FxHashSet::default();
Expand Down
29 changes: 21 additions & 8 deletions src/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,8 @@ pub struct DeviceInfo {
vendor_id: u32,
max_groups: u32,
max_threads: u32,
subgroup_threads: u32,
min_subgroup_threads: u32,
max_subgroup_threads: u32,
features: Features,
debug_printf: bool,
}
Expand All @@ -544,14 +545,26 @@ impl DeviceInfo {
pub fn max_threads(&self) -> u32 {
self.max_threads
}
// TODO: Intel Mesa driver uses variable subgroup size
// Fixed in https://github.com/charles-r-earp/krnl/tree/update-vulkano
/*
/// Subgroup threads.
pub fn subgroup_threads(&self) -> u32 {
self.subgroup_threads
/// Min threads per subgroup.
///
/// Power of 2 between 1 and 128.
///
/// For `subgroup_threads` between `min_subgroup_threads`
/// and `max_subgroup_threads`, each subgroup in a group has
/// `subgroup_threads` threads, unless `threads` per group is not an exact
/// multiple of `subgroup_threads`, in which case the last subgroup has only
/// the remaining threads.
///```text
/// subgroups * subgroup_threads >= threads
///```
pub fn min_subgroup_threads(&self) -> u32 {
self.min_subgroup_threads
}
/// Max threads per subgroup.
///
/// Power of 2 between 1 and 128.
///
/// See [`min_subgroup_threads`](Self::min_subgroup_threads) for how the
/// subgroup size within this range relates to `threads` per group.
pub fn max_subgroup_threads(&self) -> u32 {
self.max_subgroup_threads
}
*/
/// Device features.
pub fn features(&self) -> Features {
self.features
Expand Down
14 changes: 13 additions & 1 deletion src/device/vulkan_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ impl DeviceEngine for Engine {
let name = physical_device.properties().device_name.clone();
let optimal_device_extensions = vulkano::device::DeviceExtensions {
khr_vulkan_memory_model: true,
ext_subgroup_size_control: true,
..vulkano::device::DeviceExtensions::empty()
};
let device_extensions = physical_device
Expand All @@ -215,6 +216,7 @@ impl DeviceEngine for Engine {
let optimal_device_features = vulkano::device::Features {
vulkan_memory_model: true,
timeline_semaphore: true,
subgroup_size_control: true,
shader_int8: optimal_features.shader_int8,
shader_int16: optimal_features.shader_int16,
shader_int64: optimal_features.shader_int64,
Expand Down Expand Up @@ -299,14 +301,24 @@ impl DeviceEngine for Engine {
}
let kernels = DashMap::default();
let properties = device.physical_device().properties();
let (min_subgroup_threads, max_subgroup_threads) = if device_features.subgroup_size_control
{
(
properties.min_subgroup_size.unwrap_or(1),
properties.max_subgroup_size.unwrap_or(128),
)
} else {
(1, 128)
};
let info = Arc::new(DeviceInfo {
index,
name,
device_id: properties.device_id,
vendor_id: properties.vendor_id,
max_groups: properties.max_compute_work_group_count[0],
max_threads: properties.max_compute_work_group_size[0],
subgroup_threads: properties.subgroup_size.unwrap(),
min_subgroup_threads,
max_subgroup_threads,
features,
debug_printf,
});
Expand Down
3 changes: 1 addition & 2 deletions src/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1162,8 +1162,7 @@ pub mod __private {
let groups = items / threads + u32::from(items % threads != 0);
groups.min(max_groups)
} else {
#[cfg(debug_assertions)]
unreachable!("groups not provided!");
unreachable!("groups not provided!")
};
let debug_printf_panic = if info.debug_printf() {
Some(Arc::new(AtomicBool::default()))
Expand Down
Loading

0 comments on commit b967015

Please sign in to comment.