-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
179 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types: enable
#extension GL_EXT_shader_explicit_arithmetic_types_int8: require

// One invocation per output pixel, arranged in 32x32 work groups.
layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in;

layout(push_constant) uniform PushConstantData {
    // Size of the output image in pixels; main() addresses the source with a
    // row stride of twice this width.
    uint width;
    uint height;

    // these are actual coordinates of the first red pixel (unlike everywhere else)
    uint first_red_x;
    uint first_red_y;
} params;

// Input samples, one byte each, addressed row by row.
layout(set = 0, binding = 0) buffer readonly Source { uint8_t data[]; } source;
// Output image, three bytes per pixel.
layout(set = 0, binding = 1) buffer writeonly Sink { uint8_t data[]; } sink;
|
||
// Collapses one 2x2 cell of the source image into a single three-byte output
// pixel: the red and blue samples are copied as-is, the two green samples of
// the cell are averaged.
void main() {
    uvec2 pos = gl_GlobalInvocationID.xy;
    // Invocations outside of the output image have nothing to do.
    if (pos.x >= params.width || pos.y >= params.height) return;

    // Top-left corner of this pixel's 2x2 cell in the source image.
    uvec2 cell = pos * 2u;
    // The source is addressed with a row stride of twice the output width.
    uint src_stride = params.width * 2u;
    // Offset of the first of the three output bytes of this pixel.
    uint dst = (pos.y * params.width + pos.x) * 3u;

    // Position of the red sample inside the cell; blue sits diagonally opposite
    // and the two greens occupy the remaining corners.
    uvec2 red = uvec2(params.first_red_x, params.first_red_y);
    uvec2 red_in_cell = red % 2u;
    uvec2 blue = (red + 1u) % 2u;

    uint red_idx = (cell.x + red.x) + (cell.y + red.y) * src_stride;
    uint blue_idx = (cell.x + blue.x) + (cell.y + blue.y) * src_stride;
    uint green_a_idx = (cell.x + blue.x) + (cell.y + red_in_cell.y) * src_stride;
    uint green_b_idx = (cell.x + red_in_cell.x) + (cell.y + blue.y) * src_stride;

    sink.data[dst + 0u] = source.data[red_idx];
    sink.data[dst + 2u] = source.data[blue_idx];

    // Sum in 32 bit before halving: halving each sample first would throw away
    // the low bit of both and bias the average downwards, while the widened sum
    // (at most 510) cannot overflow.
    uint green_sum = uint(source.data[green_a_idx]) + uint(source.data[green_b_idx]);
    sink.data[dst + 1u] = uint8_t(green_sum / 2u);
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
use crate::pipeline_processing::{ | ||
buffers::GpuBuffer, | ||
frame::{Frame, FrameInterpretation, Raw, Rgb}, | ||
gpu_util::ensure_gpu_buffer, | ||
node::{Caps, InputProcessingNode, NodeID, ProcessingNode, Request}, | ||
parametrizable::prelude::*, | ||
payload::Payload, | ||
processing_context::ProcessingContext, | ||
}; | ||
use anyhow::{anyhow, Context, Result}; | ||
use async_trait::async_trait; | ||
use std::sync::Arc; | ||
use vulkano::{ | ||
buffer::{BufferUsage, DeviceLocalBuffer}, | ||
command_buffer::{AutoCommandBufferBuilder, CommandBufferUsage::OneTimeSubmit}, | ||
descriptor_set::{persistent::PersistentDescriptorSet, WriteDescriptorSet}, | ||
device::{Device, Queue}, | ||
pipeline::{ComputePipeline, Pipeline, PipelineBindPoint}, | ||
sync::GpuFuture, | ||
DeviceSize, | ||
}; | ||
|
||
// generated by the macro | ||
#[allow(clippy::needless_question_mark)] | ||
mod compute_shader { | ||
vulkano_shaders::shader! { | ||
ty: "compute", | ||
path: "src/nodes_gpu/debayer_loss.glsl" | ||
} | ||
} | ||
|
||
pub struct DebayerResolutionLoss { | ||
device: Arc<Device>, | ||
pipeline: Arc<ComputePipeline>, | ||
queue: Arc<Queue>, | ||
input: InputProcessingNode, | ||
} | ||
|
||
impl Parameterizable for DebayerResolutionLoss { | ||
fn describe_parameters() -> ParametersDescriptor { | ||
ParametersDescriptor::default().with("input", Mandatory(NodeInputParameter)) | ||
} | ||
fn from_parameters( | ||
mut parameters: Parameters, | ||
_is_input_to: &[NodeID], | ||
context: &ProcessingContext, | ||
) -> Result<Self> | ||
where | ||
Self: Sized, | ||
{ | ||
let (device, queues) = context.require_vulkan()?; | ||
let queue = queues.iter().find(|&q| q.family().supports_compute()).unwrap().clone(); | ||
|
||
let shader = compute_shader::load(device.clone()).unwrap(); | ||
let pipeline = ComputePipeline::new( | ||
device.clone(), | ||
shader.entry_point("main").unwrap(), | ||
&(), | ||
None, | ||
|_| {}, | ||
) | ||
.unwrap(); | ||
|
||
Ok(Self { device, pipeline, queue, input: parameters.take("input")? }) | ||
} | ||
} | ||
|
||
#[async_trait] | ||
impl ProcessingNode for DebayerResolutionLoss { | ||
async fn pull(&self, request: Request) -> Result<Payload> { | ||
let input = self.input.pull(request).await?; | ||
|
||
let (frame, fut) = ensure_gpu_buffer::<Raw>(&input, self.queue.clone()) | ||
.context("Wrong input format for Debayer")?; | ||
|
||
if frame.interp.bit_depth != 8 { | ||
return Err(anyhow!( | ||
"A frame with bit_depth=8 is required. Convert the bit depth of the frame!" | ||
)); | ||
} | ||
|
||
let interp = | ||
Rgb { width: frame.interp.width / 2, height: frame.interp.height / 2, fps: frame.interp.fps }; | ||
let sink_buffer = DeviceLocalBuffer::<[u8]>::array( | ||
self.device.clone(), | ||
interp.required_bytes() as DeviceSize, | ||
BufferUsage { | ||
storage_buffer: true, | ||
storage_texel_buffer: true, | ||
transfer_src: true, | ||
..BufferUsage::none() | ||
}, | ||
std::iter::once(self.queue.family()), | ||
)?; | ||
|
||
let new_width = frame.interp.width / 2; | ||
let new_height = frame.interp.height / 2; | ||
let push_constants = compute_shader::ty::PushConstantData { | ||
width: new_width as u32, | ||
height: new_height as u32, | ||
first_red_x: (!frame.interp.cfa.red_in_first_col) as u32, | ||
first_red_y: (!frame.interp.cfa.red_in_first_row) as u32, | ||
}; | ||
|
||
let layout = self.pipeline.layout().set_layouts()[0].clone(); | ||
let set = PersistentDescriptorSet::new( | ||
layout, | ||
[ | ||
WriteDescriptorSet::buffer(0, frame.storage.untyped()), | ||
WriteDescriptorSet::buffer(1, sink_buffer.clone()), | ||
], | ||
) | ||
.unwrap(); | ||
|
||
let mut builder = AutoCommandBufferBuilder::primary( | ||
self.device.clone(), | ||
self.queue.family(), | ||
OneTimeSubmit, | ||
) | ||
.unwrap(); | ||
builder | ||
.bind_descriptor_sets( | ||
PipelineBindPoint::Compute, | ||
self.pipeline.layout().clone(), | ||
0, | ||
set, | ||
) | ||
.push_constants(self.pipeline.layout().clone(), 0, push_constants) | ||
.bind_pipeline_compute(self.pipeline.clone()) | ||
.dispatch([ | ||
(new_width + 31) as u32 / 32, | ||
(new_height as u32 + 31) / 32, | ||
1, | ||
])?; | ||
let command_buffer = builder.build()?; | ||
|
||
let future = | ||
fut.then_execute(self.queue.clone(), command_buffer)?.then_signal_fence_and_flush()?; | ||
|
||
future.wait(None).unwrap(); | ||
Ok(Payload::from(Frame { interp, storage: GpuBuffer::from(sink_buffer) })) | ||
} | ||
|
||
fn get_caps(&self) -> Caps { self.input.get_caps() } | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters