Skip to content

Commit 0b36119

Browse files
add reference countin for circuit_outputs
1 parent fea770c commit 0b36119

File tree

2 files changed

+53
-32
lines changed

2 files changed

+53
-32
lines changed

src/libfuncs/circuit.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,7 @@ fn build_eval<'ctx, 'this>(
397397

398398
// Build output struct
399399
let outputs_type_id = &info.branch_signatures()[0].vars[2].ty;
400+
let ref_count = entry.const_int(context, location, 1, 8)?;
400401
let outputs = build_struct_value(
401402
context,
402403
registry,
@@ -405,7 +406,7 @@ fn build_eval<'ctx, 'this>(
405406
helper,
406407
metadata,
407408
outputs_type_id,
408-
&[outputs_ptr, modulus_struct],
409+
&[ref_count, outputs_ptr, modulus_struct],
409410
)?;
410411

411412
helper.br(ok_block, 0, &[add_mod, mul_mod, outputs], location)?;
@@ -919,7 +920,7 @@ fn build_get_output<'ctx, 'this>(
919920
location,
920921
outputs,
921922
llvm::r#type::pointer(context, 0),
922-
0,
923+
1,
923924
)?;
924925
let modulus_struct = entry.extract_value(
925926
context,

src/types/circuit.rs

Lines changed: 50 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use cairo_lang_sierra::{
2121
program_registry::ProgramRegistry,
2222
};
2323
use melior::{
24-
dialect::{func, llvm},
24+
dialect::{arith::CmpiPredicate, func, llvm, scf},
2525
helpers::{ArithBlockExt, BuiltinBlockExt, LlvmBlockExt},
2626
ir::{r#type::IntegerType, Block, BlockLike, Location, Module, Region, Type, Value},
2727
Context,
@@ -302,14 +302,7 @@ pub fn build_circuit_outputs<'ctx>(
302302
metadata: &mut MetadataStorage,
303303
info: WithSelf<InfoOnlyConcreteType>,
304304
) -> Result<Type<'ctx>> {
305-
let Some(GenericArg::Type(circuit_type_id)) = info.info.long_id.generic_args.first() else {
306-
return Err(SierraAssertError::BadTypeInfo.into());
307-
};
308-
let CoreTypeConcrete::Circuit(CircuitTypeConcrete::Circuit(circuit)) =
309-
registry.get_type(circuit_type_id)?
310-
else {
311-
return Err(SierraAssertError::BadTypeInfo.into());
312-
};
305+
let u8_ty = IntegerType::new(context, 8).into();
313306

314307
DupOverridesMeta::register_with(
315308
context,
@@ -322,29 +315,15 @@ pub fn build_circuit_outputs<'ctx>(
322315
let region = Region::new();
323316
let value_ty = registry.build_type(context, module, metadata, info.self_ty())?;
324317
let entry = region.append_block(Block::new(&[(value_ty, location)]));
318+
let k1 = entry.const_int(context, location, 1, 8)?;
325319

326320
let outputs = entry.arg(0)?;
327-
let gates_ptr = entry.extract_value(
328-
context,
329-
location,
330-
outputs,
331-
llvm::r#type::pointer(context, 0),
332-
0,
333-
)?;
321+
let ref_count = entry.extract_value(context, location, outputs, u8_ty, 0)?;
322+
let ref_count_inc = entry.addi(ref_count, k1, location)?;
334323

335-
let u384_integer_layout = get_integer_layout(384);
336-
337-
let new_gates_ptr = build_array_dup(
338-
context,
339-
&entry,
340-
location,
341-
gates_ptr,
342-
circuit.circuit_info.values.len(),
343-
u384_integer_layout,
344-
)?;
324+
entry.insert_value(context, location, outputs, ref_count_inc, 0)?;
345325

346-
let new_outputs = entry.insert_value(context, location, outputs, new_gates_ptr, 0)?;
347-
entry.append_operation(func::r#return(&[outputs, new_outputs], location));
326+
entry.append_operation(func::r#return(&[outputs, outputs], location));
348327

349328
Ok(Some(region))
350329
},
@@ -360,17 +339,57 @@ pub fn build_circuit_outputs<'ctx>(
360339
let region = Region::new();
361340
let value_ty = registry.build_type(context, module, metadata, info.self_ty())?;
362341
let entry = region.append_block(Block::new(&[(value_ty, location)]));
342+
let k1 = entry.const_int(context, location, 1, 8)?;
363343

364344
let outputs = entry.arg(0)?;
365-
let gates_ptr = entry.extract_value(
345+
let ref_count = entry.extract_value(
366346
context,
367347
location,
368348
outputs,
369349
llvm::r#type::pointer(context, 0),
370350
0,
371351
)?;
372352

373-
entry.append_operation(ReallocBindingsMeta::free(context, gates_ptr, location)?);
353+
// Check that the reference counting is different from 1. If it is equeal to 1, then it is shared.
354+
let is_shared = entry.cmpi(context, CmpiPredicate::Ne, ref_count, k1, location)?;
355+
356+
entry.append_operation(scf::r#if(
357+
is_shared,
358+
&[],
359+
{
360+
// If it is shared then decrement the reference counting.
361+
let region = Region::new();
362+
let entry = region.append_block(Block::new(&[]));
363+
let ref_count_dec = entry.subi(ref_count, k1, location)?;
364+
365+
entry.insert_value(context, location, ref_count_dec, outputs, 0)?;
366+
367+
entry.append_operation(scf::r#yield(&[], location));
368+
369+
region
370+
},
371+
{
372+
// If it is not shared then free the memory.
373+
let region = Region::new();
374+
let entry = region.append_block(Block::new(&[]));
375+
376+
let gates_ptr = entry.extract_value(
377+
context,
378+
location,
379+
outputs,
380+
llvm::r#type::pointer(context, 0),
381+
1,
382+
)?;
383+
384+
entry
385+
.append_operation(ReallocBindingsMeta::free(context, gates_ptr, location)?);
386+
entry.append_operation(scf::r#yield(&[], location));
387+
388+
region
389+
},
390+
location,
391+
));
392+
374393
entry.append_operation(func::r#return(&[], location));
375394

376395
Ok(Some(region))
@@ -380,6 +399,7 @@ pub fn build_circuit_outputs<'ctx>(
380399
Ok(llvm::r#type::r#struct(
381400
context,
382401
&[
402+
u8_ty,
383403
llvm::r#type::pointer(context, 0),
384404
build_u384_struct_type(context),
385405
],

0 commit comments

Comments
 (0)