Skip to content

Commit

Permalink
[spv-out] Cache constant composites (#2257)
Browse files Browse the repository at this point in the history
  • Loading branch information
evahop authored Feb 20, 2023
1 parent 58105a0 commit 60c0fc0
Show file tree
Hide file tree
Showing 5 changed files with 563 additions and 551 deletions.
112 changes: 44 additions & 68 deletions src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -711,36 +711,22 @@ impl<'w> BlockContext<'w> {
.get_constant_scalar(crate::ScalarValue::Float(1.0), width);

if let Some(size) = maybe_size {
let value = LocalType::Value {
let ty = LocalType::Value {
vector_size: Some(size),
kind: crate::ScalarKind::Float,
width,
pointer_space: None,
};

let result_type_id = self.get_type_id(LookupType::Local(value));
}
.into();

self.temp_list.clear();
self.temp_list.resize(size as _, arg1_id);

let id = self.gen_id();
block.body.push(Instruction::constant_composite(
result_type_id,
id,
&self.temp_list,
));
arg1_id = id;
arg1_id = self.writer.get_constant_composite(ty, &self.temp_list);

self.temp_list.clear();
self.temp_list.resize(size as _, arg2_id);
self.temp_list.fill(arg2_id);

let id = self.gen_id();
block.body.push(Instruction::constant_composite(
result_type_id,
id,
&self.temp_list,
));
arg2_id = id;
arg2_id = self.writer.get_constant_composite(ty, &self.temp_list);
}

MathOp::Custom(Instruction::ext_inst(
Expand Down Expand Up @@ -893,25 +879,22 @@ impl<'w> BlockContext<'w> {

let (int_type_id, int_id) = match *arg_ty {
crate::TypeInner::Vector { size, width, .. } => {
let ty = self.get_type_id(LookupType::Local(LocalType::Value {
let ty = LocalType::Value {
vector_size: Some(size),
kind: crate::ScalarKind::Sint,
width,
pointer_space: None,
}));
}
.into();

self.temp_list.clear();
self.temp_list
.resize(size as _, self.writer.get_constant_scalar(int, width));

let id = self.gen_id();
block.body.push(Instruction::constant_composite(
ty,
id,
&self.temp_list,
));

(ty, id)
(
self.get_type_id(ty),
self.writer.get_constant_composite(ty, &self.temp_list),
)
}
crate::TypeInner::Scalar { width, .. } => (
self.get_type_id(LookupType::Local(LocalType::Value {
Expand Down Expand Up @@ -1127,22 +1110,18 @@ impl<'w> BlockContext<'w> {
let zero_scalar_id = self.writer.get_constant_scalar(value, src_width);
let zero_id = match src_size {
Some(size) => {
let vector_type_id =
self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: Some(size),
kind: src_kind,
width: src_width,
pointer_space: None,
}));
let components = [zero_scalar_id; 4];

let zero_id = self.gen_id();
block.body.push(Instruction::composite_construct(
vector_type_id,
zero_id,
&components[..size as usize],
));
zero_id
let ty = LocalType::Value {
vector_size: Some(size),
kind: src_kind,
width: src_width,
pointer_space: None,
}
.into();

self.temp_list.clear();
self.temp_list.resize(size as _, zero_scalar_id);

self.writer.get_constant_composite(ty, &self.temp_list)
}
None => zero_scalar_id,
};
Expand All @@ -1168,28 +1147,25 @@ impl<'w> BlockContext<'w> {
let scalar1_id = self.writer.get_constant_scalar(val1, dst_width);
let (accept_id, reject_id) = match src_size {
Some(size) => {
let vector_type_id =
self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: Some(size),
kind,
width: dst_width,
pointer_space: None,
}));
let components0 = [scalar0_id; 4];
let components1 = [scalar1_id; 4];

let vec0_id = self.gen_id();
block.body.push(Instruction::composite_construct(
vector_type_id,
vec0_id,
&components0[..size as usize],
));
let vec1_id = self.gen_id();
block.body.push(Instruction::composite_construct(
vector_type_id,
vec1_id,
&components1[..size as usize],
));
let ty = LocalType::Value {
vector_size: Some(size),
kind,
width: dst_width,
pointer_space: None,
}
.into();

self.temp_list.clear();
self.temp_list.resize(size as _, scalar0_id);

let vec0_id =
self.writer.get_constant_composite(ty, &self.temp_list);

self.temp_list.fill(scalar1_id);

let vec1_id =
self.writer.get_constant_composite(ty, &self.temp_list);

(vec1_id, vec0_id)
}
None => (scalar1_id, scalar0_id),
Expand Down
14 changes: 13 additions & 1 deletion src/back/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,18 @@ impl recyclable::Recyclable for CachedExpressions {
}
}

#[derive(Eq, Hash, PartialEq)]
enum CachedConstant {
Scalar {
value: crate::ScalarValue,
width: crate::Bytes,
},
Composite {
ty: LookupType,
constituent_ids: Vec<Word>,
},
}

#[derive(Clone)]
struct GlobalVariable {
/// ID of the OpVariable that declares the global.
Expand Down Expand Up @@ -589,7 +601,7 @@ pub struct Writer {
lookup_function: crate::FastHashMap<Handle<crate::Function>, Word>,
lookup_function_type: crate::FastHashMap<LookupFunctionType, Word>,
constant_ids: Vec<Word>,
cached_constants: crate::FastHashMap<(crate::ScalarValue, crate::Bytes), Word>,
cached_constants: crate::FastHashMap<CachedConstant, Word>,
global_variables: Vec<GlobalVariable>,
binding_map: BindingMap,

Expand Down
72 changes: 49 additions & 23 deletions src/back/spv/writer.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::{
helpers::{contains_builtin, global_needs_wrapper, map_storage_class},
make_local, Block, BlockContext, CachedExpressions, EntryPointContext, Error, Function,
FunctionArgument, GlobalVariable, IdGenerator, Instruction, LocalType, LocalVariable,
make_local, Block, BlockContext, CachedConstant, CachedExpressions, EntryPointContext, Error,
Function, FunctionArgument, GlobalVariable, IdGenerator, Instruction, LocalType, LocalVariable,
LogicalLayout, LookupFunctionType, LookupType, LoopContext, Options, PhysicalLayout,
PipelineOptions, ResultMember, Writer, WriterFlags, BITS_PER_BYTE,
};
Expand Down Expand Up @@ -1101,12 +1101,13 @@ impl Writer {
value: crate::ScalarValue,
width: crate::Bytes,
) -> Word {
if let Some(&id) = self.cached_constants.get(&(value, width)) {
let scalar = CachedConstant::Scalar { value, width };
if let Some(&id) = self.cached_constants.get(&scalar) {
return id;
}
let id = self.id_gen.next();
self.write_constant_scalar(id, &value, width, None);
self.cached_constants.insert((value, width), id);
self.cached_constants.insert(scalar, id);
id
}

Expand Down Expand Up @@ -1180,22 +1181,39 @@ impl Writer {
instruction.to_words(&mut self.logical_layout.declarations);
}

pub(super) fn get_constant_composite(
&mut self,
ty: LookupType,
constituent_ids: &[Word],
) -> Word {
let composite = CachedConstant::Composite {
ty,
constituent_ids: constituent_ids.to_vec(),
};
if let Some(&id) = self.cached_constants.get(&composite) {
return id;
}
let id = self.id_gen.next();
self.write_constant_composite(id, ty, constituent_ids, None);
self.cached_constants.insert(composite, id);
id
}

fn write_constant_composite(
&mut self,
id: Word,
ty: Handle<crate::Type>,
components: &[Handle<crate::Constant>],
) -> Result<(), Error> {
let mut constituent_ids = Vec::with_capacity(components.len());
for constituent in components.iter() {
let constituent_id = self.constant_ids[constituent.index()];
constituent_ids.push(constituent_id);
ty: LookupType,
constituent_ids: &[Word],
debug_name: Option<&String>,
) {
if self.flags.contains(WriterFlags::DEBUG) {
if let Some(name) = debug_name {
self.debugs.push(Instruction::name(id, name));
}
}

let type_id = self.get_type_id(LookupType::Handle(ty));
Instruction::constant_composite(type_id, id, constituent_ids.as_slice())
let type_id = self.get_type_id(ty);
Instruction::constant_composite(type_id, id, constituent_ids)
.to_words(&mut self.logical_layout.declarations);
Ok(())
}

pub(super) fn write_constant_null(&mut self, type_id: Word) -> Word {
Expand Down Expand Up @@ -1776,19 +1794,27 @@ impl Writer {
self.write_type_declaration_arena(&ir_module.types, handle)?;
}

// the all the composite constants, they rely on types
// then all the composite constants, they rely on types
for (handle, constant) in ir_module.constants.iter() {
match constant.inner {
crate::ConstantInner::Scalar { .. } => continue,
crate::ConstantInner::Composite { ty, ref components } => {
let id = self.id_gen.next();
self.constant_ids[handle.index()] = id;
if self.flags.contains(WriterFlags::DEBUG) {
if let Some(ref name) = constant.name {
self.debugs.push(Instruction::name(id, name));
}
let ty = LookupType::Handle(ty);

let mut constituent_ids = Vec::with_capacity(components.len());
for constituent in components.iter() {
let constituent_id = self.constant_ids[constituent.index()];
constituent_ids.push(constituent_id);
}
self.write_constant_composite(id, ty, components)?;

self.constant_ids[handle.index()] = match constant.name {
Some(ref name) => {
let id = self.id_gen.next();
self.write_constant_composite(id, ty, &constituent_ids, Some(name));
id
}
None => self.get_constant_composite(ty, &constituent_ids),
};
}
}
}
Expand Down
13 changes: 6 additions & 7 deletions tests/out/spv/math-functions.spvasm
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 55
; Bound: 54
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
Expand All @@ -20,13 +20,12 @@ OpEntryPoint Vertex %16 "main"
%13 = OpTypeVector %7 2
%14 = OpConstantComposite %13 %6 %6
%17 = OpTypeFunction %2
%40 = OpConstant %7 31
%49 = OpTypeVector %9 2
%25 = OpConstantComposite %12 %5 %5 %5 %5
%26 = OpConstantComposite %12 %3 %3 %3 %3
%29 = OpConstantNull %7
%40 = OpConstant %7 31
%47 = OpConstantComposite %13 %40 %40
%52 = OpConstantComposite %13 %40 %40
%49 = OpTypeVector %9 2
%29 = OpConstantNull %7
%16 = OpFunction %2 None %17
%15 = OpLabel
OpBranch %18
Expand Down Expand Up @@ -58,7 +57,7 @@ OpBranch %18
%48 = OpISub %13 %47 %46
%50 = OpCompositeConstruct %49 %11 %11
%51 = OpExtInst %13 %1 FindUMsb %50
%53 = OpISub %13 %52 %51
%54 = OpBitcast %49 %53
%52 = OpISub %13 %47 %51
%53 = OpBitcast %49 %52
OpReturn
OpFunctionEnd
Loading

0 comments on commit 60c0fc0

Please sign in to comment.