Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 104 additions & 47 deletions src/initialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ use std::{ffi::CStr, ptr, slice};
use uuid::Uuid;
use wgpu::{
Adapter, Device, DeviceDescriptor, Instance, InstanceDescriptor, Queue, RequestDeviceError,
hal::{DeviceError, InstanceError, api::Vulkan},
hal::{
DeviceError, InstanceError,
api::Vulkan,
vulkan::{CreateDeviceCallbackArgs, CreateInstanceCallbackArgs},
},
};

/// Creates a wgpu [`Instance`] with the extensions required for DLSS.
Expand All @@ -24,25 +28,8 @@ pub fn create_instance(
memory_budget_thresholds: instance_descriptor.memory_budget_thresholds,
backend_options: instance_descriptor.backend_options.clone(),
},
Some(Box::new(|args| {
match required_instance_extensions(
project_id,
NVSDK_NGX_Feature_NVSDK_NGX_Feature_SuperSampling,
args.entry,
) {
Ok((extensions, true)) => args.extensions.extend(extensions),
Ok((_, false)) => feature_support.super_resolution_supported = false,
Err(err) => result = Err(err),
};
match required_instance_extensions(
project_id,
NVSDK_NGX_Feature_NVSDK_NGX_Feature_RayReconstruction,
args.entry,
) {
Ok((extensions, true)) => args.extensions.extend(extensions),
Ok((_, false)) => feature_support.ray_reconstruction_supported = false,
Err(err) => result = Err(err),
};
Some(Box::new(|mut args| {
result = register_instance_extensions(project_id, &mut args, feature_support);
})),
)?;
result?;
Expand All @@ -51,6 +38,35 @@ pub fn create_instance(
}
}

/// Call this inside of [`wgpu::hal::vulkan::Instance::init_with_callback`] to register wgpu instance extensions
/// necessary for DLSS.
pub fn register_instance_extensions(
project_id: Uuid,
args: &mut CreateInstanceCallbackArgs,
feature_support: &mut FeatureSupport,
) -> Result<(), RegisterInstanceExtensionsError> {
let mut result = Ok(());
match required_instance_extensions(
project_id,
NVSDK_NGX_Feature_NVSDK_NGX_Feature_SuperSampling,
args.entry,
) {
Ok((extensions, true)) => args.extensions.extend(extensions),
Ok((_, false)) => feature_support.super_resolution_supported = false,
Err(err) => result = Err(err),
};
match required_instance_extensions(
project_id,
NVSDK_NGX_Feature_NVSDK_NGX_Feature_RayReconstruction,
args.entry,
) {
Ok((extensions, true)) => args.extensions.extend(extensions),
Ok((_, false)) => feature_support.ray_reconstruction_supported = false,
Err(err) => result = Err(err),
};
result
}

/// Creates a wgpu [`Device`] and [`Queue`] with the extensions required for DLSS.
///
/// If the current system does not support a given feature, it will set the corresponding variable in `feature_support` to false.
Expand All @@ -66,36 +82,17 @@ pub fn request_device(
let raw_adapter = adapter
.as_hal::<Vulkan>()
.ok_or(InitializationError::UnsupportedBackend)?;
let raw_instance = raw_adapter.shared_instance().raw_instance();
let raw_physical_device = raw_adapter.raw_physical_device();

let mut result = Ok(());
let open_device = raw_adapter.open_with_callback(
device_descriptor.required_features,
&device_descriptor.memory_hints,
Some(Box::new(|args| {
match required_device_extensions(
Some(Box::new(|mut args| {
result = register_device_extensions(
project_id,
NVSDK_NGX_Feature_NVSDK_NGX_Feature_SuperSampling,
&raw_adapter,
raw_instance.handle(),
raw_physical_device,
) {
Ok((extensions, true)) => args.extensions.extend(extensions),
Ok((_, false)) => feature_support.super_resolution_supported = false,
Err(err) => result = Err(err),
};
match required_device_extensions(
project_id,
NVSDK_NGX_Feature_NVSDK_NGX_Feature_RayReconstruction,
&raw_adapter,
raw_instance.handle(),
raw_physical_device,
) {
Ok((extensions, true)) => args.extensions.extend(extensions),
Ok((_, false)) => feature_support.ray_reconstruction_supported = false,
Err(err) => result = Err(err),
};
&mut args,
&*raw_adapter,
feature_support,
);
})),
)?;
result?;
Expand All @@ -104,11 +101,49 @@ pub fn request_device(
}
}

/// Call this inside of [`wgpu::hal::vulkan::Instance::init_with_callback`] to register wgpu instance extensions
/// necessary for DLSS.
pub fn register_device_extensions(
project_id: Uuid,
args: &mut CreateDeviceCallbackArgs,
raw_adapter: &wgpu::hal::vulkan::Adapter,
feature_support: &mut FeatureSupport,
) -> Result<(), RegisterInstanceExtensionsError> {
let raw_instance = raw_adapter.shared_instance().raw_instance();
let raw_physical_device = raw_adapter.raw_physical_device();
let mut result = Ok(());

match required_device_extensions(
project_id,
NVSDK_NGX_Feature_NVSDK_NGX_Feature_SuperSampling,
raw_adapter,
raw_instance.handle(),
raw_physical_device,
) {
Ok((extensions, true)) => args.extensions.extend(extensions),
Ok((_, false)) => feature_support.super_resolution_supported = false,
Err(err) => result = Err(err),
};

match required_device_extensions(
project_id,
NVSDK_NGX_Feature_NVSDK_NGX_Feature_RayReconstruction,
raw_adapter,
raw_instance.handle(),
raw_physical_device,
) {
Ok((extensions, true)) => args.extensions.extend(extensions),
Ok((_, false)) => feature_support.ray_reconstruction_supported = false,
Err(err) => result = Err(err),
};
result
}

fn required_instance_extensions(
project_id: Uuid,
feature_id: NVSDK_NGX_Feature,
entry: &Entry,
) -> Result<(impl Iterator<Item = &'static CStr>, bool), InitializationError> {
) -> Result<(impl Iterator<Item = &'static CStr>, bool), RegisterInstanceExtensionsError> {
with_feature_info(project_id, feature_id, |feature_info| unsafe {
// Get required extension names
let mut required_extensions = ptr::null_mut();
Expand Down Expand Up @@ -142,7 +177,7 @@ fn required_device_extensions(
raw_adapter: &wgpu::hal::vulkan::Adapter,
raw_instance: ash::vk::Instance,
raw_physical_device: PhysicalDevice,
) -> Result<(impl Iterator<Item = &'static CStr>, bool), InitializationError> {
) -> Result<(impl Iterator<Item = &'static CStr>, bool), RegisterInstanceExtensionsError> {
with_feature_info(project_id, feature_id, |feature_info| unsafe {
// Get required extension names
let mut required_extensions = ptr::null_mut();
Expand Down Expand Up @@ -204,3 +239,25 @@ pub enum InitializationError {
#[error("Provided adapter is not using the Vulkan backend")]
UnsupportedBackend,
}

/// Error returned by [`register_instance_extensions`].
#[derive(thiserror::Error, Debug)]
pub enum RegisterInstanceExtensionsError {
#[error(transparent)]
VulkanError(#[from] ash::vk::Result),
#[error(transparent)]
DlssError(#[from] DlssError),
}

impl From<RegisterInstanceExtensionsError> for InitializationError {
fn from(value: RegisterInstanceExtensionsError) -> Self {
match value {
RegisterInstanceExtensionsError::VulkanError(err) => {
InitializationError::VulkanError(err)
}
RegisterInstanceExtensionsError::DlssError(dlss_error) => {
InitializationError::DlssError(dlss_error)
}
}
}
}
5 changes: 4 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ pub mod ray_reconstruction;
pub mod super_resolution;

#[cfg(not(feature = "mock"))]
pub use initialization::{FeatureSupport, InitializationError, create_instance, request_device};
pub use initialization::{
FeatureSupport, InitializationError, create_instance, register_device_extensions,
register_instance_extensions, request_device,
};
#[cfg(not(feature = "mock"))]
pub use nvsdk_ngx::{DlssError, DlssFeatureFlags, DlssPerfQualityMode};
#[cfg(not(feature = "mock"))]
Expand Down