|  | 
|  | 1 | +use crate::{ | 
|  | 2 | +    feature_info::with_feature_info, | 
|  | 3 | +    nvsdk_ngx::{ | 
|  | 4 | +        DlssError, NVSDK_NGX_VULKAN_GetFeatureDeviceExtensionRequirements, | 
|  | 5 | +        NVSDK_NGX_VULKAN_GetFeatureInstanceExtensionRequirements, check_ngx_result, | 
|  | 6 | +    }, | 
|  | 7 | +}; | 
|  | 8 | +use ash::{Entry, vk::PhysicalDevice}; | 
|  | 9 | +use std::{ffi::CStr, ptr, slice}; | 
|  | 10 | +use uuid::Uuid; | 
|  | 11 | +use wgpu::{ | 
|  | 12 | +    Adapter, Device, DeviceDescriptor, Instance, InstanceDescriptor, Queue, RequestDeviceError, | 
|  | 13 | +    hal::{DeviceError, InstanceError, api::Vulkan}, | 
|  | 14 | +}; | 
|  | 15 | + | 
|  | 16 | +/// Creates a wgpu [`Instance`] with the extensions required for DLSS. | 
|  | 17 | +/// | 
|  | 18 | +/// If the system does not support DLSS, it will set `dlss_supported` to false. | 
|  | 19 | +pub fn create_instance( | 
|  | 20 | +    project_id: Uuid, | 
|  | 21 | +    instance_descriptor: &InstanceDescriptor, | 
|  | 22 | +    dlss_supported: &mut bool, | 
|  | 23 | +) -> Result<Instance, InitializationError> { | 
|  | 24 | +    unsafe { | 
|  | 25 | +        let mut result = Ok(()); | 
|  | 26 | +        let raw_instance = wgpu::hal::vulkan::Instance::init_with_callback( | 
|  | 27 | +            &wgpu::hal::InstanceDescriptor { | 
|  | 28 | +                name: "wgpu", | 
|  | 29 | +                flags: instance_descriptor.flags, | 
|  | 30 | +                memory_budget_thresholds: instance_descriptor.memory_budget_thresholds, | 
|  | 31 | +                backend_options: instance_descriptor.backend_options.clone(), | 
|  | 32 | +            }, | 
|  | 33 | +            Some(Box::new(|args| { | 
|  | 34 | +                match required_instance_extensions(project_id, args.entry) { | 
|  | 35 | +                    Ok((extensions, true)) => args.extensions.extend(extensions), | 
|  | 36 | +                    Ok((_, false)) => *dlss_supported = false, | 
|  | 37 | +                    Err(err) => result = Err(err), | 
|  | 38 | +                } | 
|  | 39 | +            })), | 
|  | 40 | +        )?; | 
|  | 41 | +        result?; | 
|  | 42 | + | 
|  | 43 | +        Ok(Instance::from_hal::<Vulkan>(raw_instance)) | 
|  | 44 | +    } | 
|  | 45 | +} | 
|  | 46 | + | 
|  | 47 | +/// Creates a wgpu [`Device`] and [`Queue`] with the extensions required for DLSS. | 
|  | 48 | +/// | 
|  | 49 | +/// If the system does not support DLSS, it will set `dlss_supported` to false. | 
|  | 50 | +/// | 
|  | 51 | +/// The provided [`Adapter`] must be using the Vulkan backend. | 
|  | 52 | +pub fn request_device( | 
|  | 53 | +    project_id: Uuid, | 
|  | 54 | +    adapter: &Adapter, | 
|  | 55 | +    device_descriptor: &DeviceDescriptor, | 
|  | 56 | +    dlss_supported: &mut bool, | 
|  | 57 | +) -> Result<(Device, Queue), InitializationError> { | 
|  | 58 | +    unsafe { | 
|  | 59 | +        let raw_adapter = adapter | 
|  | 60 | +            .as_hal::<Vulkan>() | 
|  | 61 | +            .ok_or(InitializationError::UnsupportedBackend)?; | 
|  | 62 | +        let raw_instance = raw_adapter.shared_instance().raw_instance(); | 
|  | 63 | +        let raw_physical_device = raw_adapter.raw_physical_device(); | 
|  | 64 | + | 
|  | 65 | +        let mut result = Ok(()); | 
|  | 66 | +        let open_device = raw_adapter.open_with_callback( | 
|  | 67 | +            device_descriptor.required_features, | 
|  | 68 | +            &device_descriptor.memory_hints, | 
|  | 69 | +            Some(Box::new(|args| { | 
|  | 70 | +                match required_device_extensions( | 
|  | 71 | +                    project_id, | 
|  | 72 | +                    &raw_adapter, | 
|  | 73 | +                    raw_instance.handle(), | 
|  | 74 | +                    raw_physical_device, | 
|  | 75 | +                ) { | 
|  | 76 | +                    Ok((extensions, true)) => args.extensions.extend(extensions), | 
|  | 77 | +                    Ok((_, false)) => *dlss_supported = false, | 
|  | 78 | +                    Err(err) => result = Err(err), | 
|  | 79 | +                } | 
|  | 80 | +            })), | 
|  | 81 | +        )?; | 
|  | 82 | +        result?; | 
|  | 83 | + | 
|  | 84 | +        Ok(adapter.create_device_from_hal::<Vulkan>(open_device, device_descriptor)?) | 
|  | 85 | +    } | 
|  | 86 | +} | 
|  | 87 | + | 
|  | 88 | +fn required_instance_extensions( | 
|  | 89 | +    project_id: Uuid, | 
|  | 90 | +    entry: &Entry, | 
|  | 91 | +) -> Result<(impl Iterator<Item = &'static CStr>, bool), InitializationError> { | 
|  | 92 | +    with_feature_info(project_id, |feature_info| unsafe { | 
|  | 93 | +        // Get required extension names | 
|  | 94 | +        let mut required_extensions = ptr::null_mut(); | 
|  | 95 | +        let mut required_extension_count = 0; | 
|  | 96 | +        check_ngx_result(NVSDK_NGX_VULKAN_GetFeatureInstanceExtensionRequirements( | 
|  | 97 | +            feature_info, | 
|  | 98 | +            &mut required_extension_count, | 
|  | 99 | +            &mut required_extensions, | 
|  | 100 | +        ))?; | 
|  | 101 | +        let required_extensions = | 
|  | 102 | +            slice::from_raw_parts(required_extensions, required_extension_count as usize); | 
|  | 103 | +        let required_extensions = required_extensions | 
|  | 104 | +            .iter() | 
|  | 105 | +            .map(|extension| CStr::from_ptr(extension.extension_name.as_ptr())); | 
|  | 106 | + | 
|  | 107 | +        // Check that the required extensions are supported | 
|  | 108 | +        let supported_extensions = entry.enumerate_instance_extension_properties(None)?; | 
|  | 109 | +        let extensions_supported = required_extensions.clone().all(|required_extension| { | 
|  | 110 | +            supported_extensions | 
|  | 111 | +                .iter() | 
|  | 112 | +                .any(|extension| extension.extension_name_as_c_str() == Ok(required_extension)) | 
|  | 113 | +        }); | 
|  | 114 | + | 
|  | 115 | +        Ok((required_extensions, extensions_supported)) | 
|  | 116 | +    }) | 
|  | 117 | +} | 
|  | 118 | + | 
|  | 119 | +fn required_device_extensions( | 
|  | 120 | +    project_id: Uuid, | 
|  | 121 | +    raw_adapter: &wgpu::hal::vulkan::Adapter, | 
|  | 122 | +    raw_instance: ash::vk::Instance, | 
|  | 123 | +    raw_physical_device: PhysicalDevice, | 
|  | 124 | +) -> Result<(impl Iterator<Item = &'static CStr>, bool), InitializationError> { | 
|  | 125 | +    with_feature_info(project_id, |feature_info| unsafe { | 
|  | 126 | +        // Get required extension names | 
|  | 127 | +        let mut required_extensions = ptr::null_mut(); | 
|  | 128 | +        let mut required_extension_count = 0; | 
|  | 129 | +        check_ngx_result(NVSDK_NGX_VULKAN_GetFeatureDeviceExtensionRequirements( | 
|  | 130 | +            raw_instance, | 
|  | 131 | +            raw_physical_device, | 
|  | 132 | +            feature_info, | 
|  | 133 | +            &mut required_extension_count, | 
|  | 134 | +            &mut required_extensions, | 
|  | 135 | +        ))?; | 
|  | 136 | +        let required_extensions = | 
|  | 137 | +            slice::from_raw_parts(required_extensions, required_extension_count as usize); | 
|  | 138 | +        let required_extensions = required_extensions | 
|  | 139 | +            .iter() | 
|  | 140 | +            .map(|extension| CStr::from_ptr(extension.extension_name.as_ptr())); | 
|  | 141 | + | 
|  | 142 | +        // Check that the required extensions are supported | 
|  | 143 | +        let extensions_supported = required_extensions.clone().all(|required_extension| { | 
|  | 144 | +            raw_adapter | 
|  | 145 | +                .physical_device_capabilities() | 
|  | 146 | +                .supports_extension(required_extension) | 
|  | 147 | +        }); | 
|  | 148 | + | 
|  | 149 | +        Ok((required_extensions, extensions_supported)) | 
|  | 150 | +    }) | 
|  | 151 | +} | 
|  | 152 | + | 
|  | 153 | +/// Error returned by [`request_device`]. | 
|  | 154 | +#[derive(thiserror::Error, Debug)] | 
|  | 155 | +pub enum InitializationError { | 
|  | 156 | +    #[error(transparent)] | 
|  | 157 | +    InstanceError(#[from] InstanceError), | 
|  | 158 | +    #[error(transparent)] | 
|  | 159 | +    RequestDeviceError(#[from] RequestDeviceError), | 
|  | 160 | +    #[error(transparent)] | 
|  | 161 | +    DeviceError(#[from] DeviceError), | 
|  | 162 | +    #[error(transparent)] | 
|  | 163 | +    VulkanError(#[from] ash::vk::Result), | 
|  | 164 | +    #[error(transparent)] | 
|  | 165 | +    DlssError(#[from] DlssError), | 
|  | 166 | +    #[error("Provided adapter is not using the Vulkan backend")] | 
|  | 167 | +    UnsupportedBackend, | 
|  | 168 | +} | 
0 commit comments