Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Subgroup intrinsics #1151

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Next Next commit
add trait VectorOrScalar, representing either a vector or a scalar type
  • Loading branch information
Firestar99 committed Jun 11, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
commit 9080fc8a456c65c0205ac4ba65a9f352a40470e8
4 changes: 4 additions & 0 deletions crates/spirv-std/src/float.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Traits and helper functions related to floats.

use crate::scalar::VectorOrScalar;
use crate::vector::Vector;
#[cfg(target_arch = "spirv")]
use core::arch::asm;
@@ -71,6 +72,9 @@ struct F32x2 {
x: f32,
y: f32,
}
unsafe impl VectorOrScalar for F32x2 {
type Scalar = f32;
}
unsafe impl Vector<f32, 2> for F32x2 {}

/// Converts an f32 (float) into an f16 (half). The result is a u32, not a u16, due to GPU support
49 changes: 48 additions & 1 deletion crates/spirv-std/src/scalar.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,58 @@
//! Traits related to scalars.

/// Abstract trait representing either a vector or a scalar type.
///
/// # Safety
/// Implementing this trait on non-scalar or non-vector types may break assumptions about other
/// unsafe code, and should not be done.
pub unsafe trait VectorOrScalar: Default {
/// Either the scalar component type of the vector or the scalar itself.
type Scalar: Scalar;
}

unsafe impl VectorOrScalar for bool {
type Scalar = bool;
}
unsafe impl VectorOrScalar for f32 {
type Scalar = f32;
}
unsafe impl VectorOrScalar for f64 {
type Scalar = f64;
}
unsafe impl VectorOrScalar for u8 {
type Scalar = u8;
}
unsafe impl VectorOrScalar for u16 {
type Scalar = u16;
}
unsafe impl VectorOrScalar for u32 {
type Scalar = u32;
}
unsafe impl VectorOrScalar for u64 {
type Scalar = u64;
}
unsafe impl VectorOrScalar for i8 {
type Scalar = i8;
}
unsafe impl VectorOrScalar for i16 {
type Scalar = i16;
}
unsafe impl VectorOrScalar for i32 {
type Scalar = i32;
}
unsafe impl VectorOrScalar for i64 {
type Scalar = i64;
}

/// Abstract trait representing a SPIR-V scalar type.
///
/// # Safety
/// Implementing this trait on non-scalar types breaks assumptions of other unsafe code, and should
/// not be done.
pub unsafe trait Scalar: Copy + Default + crate::sealed::Sealed {}
pub unsafe trait Scalar:
VectorOrScalar<Scalar = Self> + Copy + Default + crate::sealed::Sealed
{
}

unsafe impl Scalar for bool {}
unsafe impl Scalar for f32 {}
48 changes: 46 additions & 2 deletions crates/spirv-std/src/vector.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,57 @@
//! Traits related to vectors.

use crate::scalar::{Scalar, VectorOrScalar};
use glam::{Vec3Swizzles, Vec4Swizzles};

unsafe impl VectorOrScalar for glam::Vec2 {
type Scalar = f32;
}
unsafe impl VectorOrScalar for glam::Vec3 {
type Scalar = f32;
}
unsafe impl VectorOrScalar for glam::Vec3A {
type Scalar = f32;
}
unsafe impl VectorOrScalar for glam::Vec4 {
type Scalar = f32;
}

unsafe impl VectorOrScalar for glam::DVec2 {
type Scalar = f64;
}
unsafe impl VectorOrScalar for glam::DVec3 {
type Scalar = f64;
}
unsafe impl VectorOrScalar for glam::DVec4 {
type Scalar = f64;
}

unsafe impl VectorOrScalar for glam::UVec2 {
type Scalar = u32;
}
unsafe impl VectorOrScalar for glam::UVec3 {
type Scalar = u32;
}
unsafe impl VectorOrScalar for glam::UVec4 {
type Scalar = u32;
}

unsafe impl VectorOrScalar for glam::IVec2 {
type Scalar = i32;
}
unsafe impl VectorOrScalar for glam::IVec3 {
type Scalar = i32;
}
unsafe impl VectorOrScalar for glam::IVec4 {
type Scalar = i32;
}

/// Abstract trait representing a SPIR-V vector type.
///
/// # Safety
/// Implementing this trait on non-simd-vector types breaks assumptions of other unsafe code, and
/// should not be done.
pub unsafe trait Vector<T: crate::scalar::Scalar, const N: usize>: Default {}
pub unsafe trait Vector<T: Scalar, const N: usize>: VectorOrScalar<Scalar = T> {}

unsafe impl Vector<f32, 2> for glam::Vec2 {}
unsafe impl Vector<f32, 3> for glam::Vec3 {}
@@ -27,7 +71,7 @@ unsafe impl Vector<i32, 3> for glam::IVec3 {}
unsafe impl Vector<i32, 4> for glam::IVec4 {}

/// Trait that implements slicing of a vector into a scalar or vector of lower dimensions, by
/// ignoring the highter dimensions
/// ignoring the higter dimensions
pub trait VectorTruncateInto<T> {
/// Slices the vector into a lower dimensional type by ignoring the higher components
fn truncate_into(self) -> T;
5 changes: 4 additions & 1 deletion tests/ui/arch/all.rs
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@
#![feature(repr_simd)]

use spirv_std::spirv;
use spirv_std::{scalar::Scalar, vector::Vector};
use spirv_std::{scalar::Scalar, scalar::VectorOrScalar, vector::Vector};

/// HACK(shesp). Rust doesn't allow us to declare regular (tuple-)structs containing `bool` members
/// as `#[repl(simd)]`. But we need this for `spirv_std::arch::any()` and `spirv_std::arch::all()`
@@ -12,6 +12,9 @@ use spirv_std::{scalar::Scalar, vector::Vector};
/// it (for now at least)
#[repr(simd)]
struct Vec2<T>(T, T);
unsafe impl<T: Scalar> VectorOrScalar for Vec2<T> {
type Scalar = T;
}
unsafe impl<T: Scalar> Vector<T, 2> for Vec2<T> {}

impl<T: Scalar> Default for Vec2<T> {
5 changes: 4 additions & 1 deletion tests/ui/arch/any.rs
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@
#![feature(repr_simd)]

use spirv_std::spirv;
use spirv_std::{scalar::Scalar, vector::Vector};
use spirv_std::{scalar::Scalar, scalar::VectorOrScalar, vector::Vector};

/// HACK(shesp). Rust doesn't allow us to declare regular (tuple-)structs containing `bool` members
/// as `#[repl(simd)]`. But we need this for `spirv_std::arch::any()` and `spirv_std::arch::all()`
@@ -12,6 +12,9 @@ use spirv_std::{scalar::Scalar, vector::Vector};
/// it (for now at least)
#[repr(simd)]
struct Vec2<T>(T, T);
unsafe impl<T: Scalar> VectorOrScalar for Vec2<T> {
type Scalar = T;
}
unsafe impl<T: Scalar> Vector<T, 2> for Vec2<T> {}

impl<T: Scalar> Default for Vec2<T> {