Skip to content

Commit

Permalink
Implement extensions interface as described in #691
Browse files Browse the repository at this point in the history
  • Loading branch information
cwfitzgerald committed Jun 6, 2020
1 parent 581863a commit f32cb10
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 26 deletions.
1 change: 1 addition & 0 deletions player/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ fn main() {
#[cfg(not(feature = "winit"))]
compatible_surface: None,
},
unsafe { wgt::UnsafeExtensions::allow() },
wgc::instance::AdapterInputs::IdSet(
&[wgc::id::TypedId::zip(0, 0, backend)],
|id| id.backend(),
Expand Down
79 changes: 56 additions & 23 deletions wgpu-core/src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,38 @@ pub struct Surface {
#[derive(Debug)]
pub struct Adapter<B: hal::Backend> {
pub(crate) raw: hal::adapter::Adapter<B>,
extensions: wgt::Extensions,
limits: wgt::Limits,
unsafe_extensions: wgt::UnsafeExtensions,
life_guard: LifeGuard,
}

impl<B: hal::Backend> Adapter<B> {
fn new(raw: hal::adapter::Adapter<B>) -> Self {
fn new(raw: hal::adapter::Adapter<B>, unsafe_extensions: wgt::UnsafeExtensions) -> Self {
let adapter_features = raw.physical_device.features();

let mut extensions = wgt::Extensions::default();
extensions.set(
wgt::Extensions::ANISOTROPIC_FILTERING,
adapter_features.contains(hal::Features::SAMPLER_ANISOTROPY),
);
if unsafe_extensions.allowed() {
// Unsafe extensions go here
}

let adapter_limits = raw.physical_device.limits();

let limits = wgt::Limits {
max_bind_groups: (adapter_limits.max_bound_descriptor_sets as u32)
.min(MAX_BIND_GROUPS as u32),
_non_exhaustive: unsafe { wgt::NonExhaustive::new() },
};

Adapter {
raw,
extensions,
limits,
unsafe_extensions,
life_guard: LifeGuard::new(),
}
}
Expand Down Expand Up @@ -251,7 +276,11 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
self.surfaces.register_identity(id_in, surface, &mut token)
}

pub fn enumerate_adapters(&self, inputs: AdapterInputs<Input<G, AdapterId>>) -> Vec<AdapterId> {
pub fn enumerate_adapters(
&self,
unsafe_extensions: wgt::UnsafeExtensions,
inputs: AdapterInputs<Input<G, AdapterId>>,
) -> Vec<AdapterId> {
let instance = &self.instance;
let mut token = Token::root();
let mut adapters = Vec::new();
Expand All @@ -264,7 +293,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
if let Some(ref inst) = instance.vulkan {
if let Some(id_vulkan) = inputs.find(Backend::Vulkan) {
for raw in inst.enumerate_adapters() {
let adapter = Adapter::new(raw);
let adapter = Adapter::new(raw, unsafe_extensions);
log::info!("Adapter Vulkan {:?}", adapter.raw.info);
adapters.push(backend::Vulkan::hub(self).adapters.register_identity(
id_vulkan.clone(),
Expand All @@ -279,7 +308,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
{
if let Some(id_metal) = inputs.find(Backend::Metal) {
for raw in instance.metal.enumerate_adapters() {
let adapter = Adapter::new(raw);
let adapter = Adapter::new(raw, unsafe_extensions);
log::info!("Adapter Metal {:?}", adapter.raw.info);
adapters.push(backend::Metal::hub(self).adapters.register_identity(
id_metal.clone(),
Expand All @@ -294,7 +323,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
if let Some(ref inst) = instance.dx12 {
if let Some(id_dx12) = inputs.find(Backend::Dx12) {
for raw in inst.enumerate_adapters() {
let adapter = Adapter::new(raw);
let adapter = Adapter::new(raw, unsafe_extensions);
log::info!("Adapter Dx12 {:?}", adapter.raw.info);
adapters.push(backend::Dx12::hub(self).adapters.register_identity(
id_dx12.clone(),
Expand All @@ -307,7 +336,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {

if let Some(id_dx11) = inputs.find(Backend::Dx11) {
for raw in instance.dx11.enumerate_adapters() {
let adapter = Adapter::new(raw);
let adapter = Adapter::new(raw, unsafe_extensions);
log::info!("Adapter Dx11 {:?}", adapter.raw.info);
adapters.push(backend::Dx11::hub(self).adapters.register_identity(
id_dx11.clone(),
Expand All @@ -324,6 +353,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
pub fn pick_adapter(
&self,
desc: &RequestAdapterOptions,
unsafe_extensions: wgt::UnsafeExtensions,
inputs: AdapterInputs<Input<G, AdapterId>>,
) -> Option<AdapterId> {
let instance = &self.instance;
Expand Down Expand Up @@ -462,7 +492,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
))]
{
if selected < adapters_vk.len() {
let adapter = Adapter::new(adapters_vk.swap_remove(selected));
let adapter = Adapter::new(adapters_vk.swap_remove(selected), unsafe_extensions);
log::info!("Adapter Vulkan {:?}", adapter.raw.info);
let id = backend::Vulkan::hub(self).adapters.register_identity(
id_vulkan.unwrap(),
Expand All @@ -476,7 +506,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
#[cfg(any(target_os = "ios", target_os = "macos"))]
{
if selected < adapters_mtl.len() {
let adapter = Adapter::new(adapters_mtl.swap_remove(selected));
let adapter = Adapter::new(adapters_mtl.swap_remove(selected), unsafe_extensions);
log::info!("Adapter Metal {:?}", adapter.raw.info);
let id = backend::Metal::hub(self).adapters.register_identity(
id_metal.unwrap(),
Expand All @@ -490,7 +520,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
#[cfg(windows)]
{
if selected < adapters_dx12.len() {
let adapter = Adapter::new(adapters_dx12.swap_remove(selected));
let adapter = Adapter::new(adapters_dx12.swap_remove(selected), unsafe_extensions);
log::info!("Adapter Dx12 {:?}", adapter.raw.info);
let id = backend::Dx12::hub(self).adapters.register_identity(
id_dx12.unwrap(),
Expand All @@ -501,7 +531,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
}
selected -= adapters_dx12.len();
if selected < adapters_dx11.len() {
let adapter = Adapter::new(adapters_dx11.swap_remove(selected));
let adapter = Adapter::new(adapters_dx11.swap_remove(selected), unsafe_extensions);
log::info!("Adapter Dx11 {:?}", adapter.raw.info);
let id = backend::Dx11::hub(self).adapters.register_identity(
id_dx11.unwrap(),
Expand Down Expand Up @@ -532,14 +562,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
let (adapter_guard, _) = hub.adapters.read(&mut token);
let adapter = &adapter_guard[adapter_id];

let features = adapter.raw.physical_device.features();

let mut extensions = wgt::Extensions::default();
extensions.set(
wgt::Extensions::ANISOTROPIC_FILTERING,
features.contains(hal::Features::SAMPLER_ANISOTROPY),
);
extensions
adapter.extensions
}

pub fn adapter_limits<B: GfxBackend>(&self, adapter_id: AdapterId) -> wgt::Limits {
Expand All @@ -548,11 +571,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
let (adapter_guard, _) = hub.adapters.read(&mut token);
let adapter = &adapter_guard[adapter_id];

let limits = adapter.raw.physical_device.limits();

wgt::Limits {
max_bind_groups: (limits.max_bound_descriptor_sets as u32).min(MAX_BIND_GROUPS as u32),
}
adapter.limits.clone()
}

pub fn adapter_destroy<B: GfxBackend>(&self, adapter_id: AdapterId) {
Expand Down Expand Up @@ -603,6 +622,20 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
);
}

// Verify all extensions were exposed by the adapter
if !adapter.unsafe_extensions.allowed() {
assert!(
!desc.extensions.intersects(wgt::Extensions::ALL_UNSAFE),
"Cannot enable unsafe extensions without passing UnsafeExtensions::allow() when getting an adapter. Enabled unsafe extensions: {:?}",
desc.extensions & wgt::Extensions::ALL_UNSAFE
)
}
assert!(
adapter.extensions.contains(desc.extensions),
"Cannot enable extensions that adapter doesn't support. Unsupported extensions: {:?}",
desc.extensions - adapter.extensions
);

// Check features needed by extensions
if desc
.extensions
Expand Down
62 changes: 59 additions & 3 deletions wgpu-types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,30 @@ impl From<Backend> for BackendBit {
}
}

/// This type is not to be constructed by any users of wgpu. If you construct this type, any semver
/// guarantees made by wgpu are invalidated and a non-breaking change may break your code.
///
/// If you are here trying to construct it, the solution is to use partial construction with the
/// default:
///
/// ```ignore
/// let limits = Limits {
/// max_bind_groups: 2,
/// ..Limits::default()
/// }
/// ```
#[doc(hidden)]
#[derive(Debug, Copy, Clone, Default, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "trace", derive(Serialize))]
#[cfg_attr(feature = "replay", derive(Deserialize))]
pub struct NonExhaustive(());

impl NonExhaustive {
pub unsafe fn new() -> Self {
Self(())
}
}

bitflags::bitflags! {
#[repr(transparent)]
#[derive(Default)]
Expand All @@ -110,7 +134,33 @@ bitflags::bitflags! {
/// but it is not yet implemented.
///
/// https://github.com/gpuweb/gpuweb/issues/696
const ANISOTROPIC_FILTERING = 0x01;
const ANISOTROPIC_FILTERING = 0x0000_0000_0001_0000;
/// Extensions which are part of the upstream webgpu standard
const ALL_WEBGPU = 0x0000_0000_0000_FFFF;
/// Extensions that require activating the unsafe extension flag
const ALL_UNSAFE = 0xFFFF_0000_0000_0000;
/// Extensions that are only available when targeting native (not web)
const ALL_NATIVE = 0xFFFF_FFFF_FFFF_0000;
}
}

#[derive(Debug, Copy, Clone, Default, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "trace", derive(Serialize))]
#[cfg_attr(feature = "replay", derive(Deserialize))]
pub struct UnsafeExtensions {
allow_unsafe: bool,
}
impl UnsafeExtensions {
pub unsafe fn allow() -> Self {
Self { allow_unsafe: true }
}
pub fn disallow() -> Self {
Self {
allow_unsafe: false,
}
}
pub fn allowed(self) -> bool {
self.allow_unsafe
}
}

Expand All @@ -120,11 +170,15 @@ bitflags::bitflags! {
#[cfg_attr(feature = "replay", derive(Deserialize))]
pub struct Limits {
pub max_bind_groups: u32,
pub _non_exhaustive: NonExhaustive,
}

impl Default for Limits {
fn default() -> Self {
Limits { max_bind_groups: 4 }
Limits {
max_bind_groups: 4,
_non_exhaustive: unsafe { NonExhaustive::new() },
}
}
}

Expand Down Expand Up @@ -941,7 +995,7 @@ impl Default for FilterMode {
}
}

#[derive(Clone, Debug, PartialEq)]
#[derive(Default, Clone, Debug, PartialEq)]
#[cfg_attr(feature = "trace", derive(Serialize))]
#[cfg_attr(feature = "replay", derive(Deserialize))]
pub struct SamplerDescriptor<L> {
Expand All @@ -960,6 +1014,7 @@ pub struct SamplerDescriptor<L> {
///
/// Valid values: 1, 2, 4, 8, and 16.
pub anisotropy_clamp: Option<u8>,
pub _non_exhaustive: NonExhaustive,
}

impl<L> SamplerDescriptor<L> {
Expand All @@ -976,6 +1031,7 @@ impl<L> SamplerDescriptor<L> {
lod_max_clamp: self.lod_max_clamp,
compare: self.compare,
anisotropy_clamp: self.anisotropy_clamp,
_non_exhaustive: self._non_exhaustive,
}
}
}
Expand Down

0 comments on commit f32cb10

Please sign in to comment.