From 829ef5fb00a43e718c7b4cd9bf4b5b728f8c5137 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Fri, 17 Nov 2023 17:16:45 +0000 Subject: [PATCH 1/8] feat: tk2circuit hash --- tket2-py/src/circuit/convert.rs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tket2-py/src/circuit/convert.rs b/tket2-py/src/circuit/convert.rs index 47993ccd..3813754b 100644 --- a/tket2-py/src/circuit/convert.rs +++ b/tket2-py/src/circuit/convert.rs @@ -6,6 +6,7 @@ use pyo3::{prelude::*, PyTypeInfo}; use derive_more::From; use hugr::{Hugr, HugrView}; use serde::Serialize; +use tket2::circuit::CircuitHash; use tket2::extension::REGISTRY; use tket2::json::TKETDecode; use tket2::passes::CircuitChunks; @@ -73,6 +74,16 @@ impl Tk2Circuit { hugr: tk1.decode()?, }) } + + /// Returns a hash of the circuit. + pub fn hash(&self) -> u64 { + self.hugr.circuit_hash() + } + + /// Hash the circuit + pub fn __hash__(&self) -> isize { + self.hash() as isize + } } impl Tk2Circuit { /// Tries to extract a Tk2Circuit from a python object. From a14a7f09f203493c249abffc992befbbebcedc83 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Mon, 20 Nov 2023 15:32:51 +0000 Subject: [PATCH 2/8] feat: Bindings for the circuit cost --- tket2-py/src/circuit.rs | 8 +- tket2-py/src/circuit/convert.rs | 25 ++++- tket2-py/src/circuit/cost.rs | 168 ++++++++++++++++++++++++++++++++ tket2-py/test/test_bindings.py | 93 ------------------ tket2-py/test/test_circuit.py | 49 ++++++++++ tket2-py/test/test_pass.py | 76 ++++++++++++++- tket2/src/circuit.rs | 7 +- 7 files changed, 326 insertions(+), 100 deletions(-) create mode 100644 tket2-py/src/circuit/cost.rs delete mode 100644 tket2-py/test/test_bindings.py create mode 100644 tket2-py/test/test_circuit.py diff --git a/tket2-py/src/circuit.rs b/tket2-py/src/circuit.rs index a410d31b..175a1d7d 100644 --- a/tket2-py/src/circuit.rs +++ b/tket2-py/src/circuit.rs @@ -2,6 +2,7 @@ #![allow(unused)] pub mod convert; +pub mod cost; use derive_more::{From, Into}; use pyo3::prelude::*; @@ -14,14 +15,17 @@ use tket2::rewrite::CircuitRewrite; use tket_json_rs::circuit_json::SerialCircuit; pub use self::convert::{try_update_hugr, try_with_hugr, update_hugr, with_hugr, Tk2Circuit}; +pub use self::cost::PyCircuitCost; +pub use tket2::{Pauli, T2Op}; /// The module definition pub fn module(py: Python) -> PyResult<&PyModule> { let m = PyModule::new(py, "_circuit")?; m.add_class::()?; m.add_class::()?; - m.add_class::()?; - m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_function(wrap_pyfunction!(validate_hugr, m)?)?; m.add_function(wrap_pyfunction!(to_hugr_dot, m)?)?; diff --git a/tket2-py/src/circuit/convert.rs b/tket2-py/src/circuit/convert.rs index 3813754b..2e614d13 100644 --- a/tket2-py/src/circuit/convert.rs +++ b/tket2-py/src/circuit/convert.rs @@ -1,6 +1,6 @@ //! Utilities for calling Hugr functions on generic python objects. -use pyo3::exceptions::PyAttributeError; +use pyo3::exceptions::{PyAttributeError, PyValueError}; use pyo3::{prelude::*, PyTypeInfo}; use derive_more::From; @@ -10,10 +10,13 @@ use tket2::circuit::CircuitHash; use tket2::extension::REGISTRY; use tket2::json::TKETDecode; use tket2::passes::CircuitChunks; +use tket2::{Circuit, T2Op}; use tket_json_rs::circuit_json::SerialCircuit; use crate::pattern::rewrite::PyCircuitRewrite; +use super::PyCircuitCost; + /// A manager for tket 2 operations on a tket 1 Circuit. #[pyclass] #[derive(Clone, Debug, PartialEq, From)] @@ -75,6 +78,26 @@ impl Tk2Circuit { }) } + /// Compute the cost of the circuit based on a per-operation cost function. + /// + /// :param cost_fn: A function that takes a `T2Op` and returns an arbitrary cost. + /// The cost must implement `__add__`. + /// :returns: The sum of all operation costs. + pub fn circuit_cost<'py>(&self, cost_fn: &'py PyAny) -> PyResult<&'py PyAny> { + let py = cost_fn.py(); + let circ_cost = self.hugr.circuit_cost(|op| { + let tk2_op: T2Op = op.try_into().map_err(|e| { + PyErr::new::(format!( + "Could not convert circuit operation to a `T2Op`: {e}" + )) + })?; + cost_fn.call1((tk2_op,)).map(|cost| PyCircuitCost { + cost: cost.to_object(py), + }) + })?; + Ok(circ_cost.cost.clone().into_ref(py)) + } + /// Returns a hash of the circuit. pub fn hash(&self) -> u64 { self.hugr.circuit_hash() diff --git a/tket2-py/src/circuit/cost.rs b/tket2-py/src/circuit/cost.rs new file mode 100644 index 00000000..a7b2d6bd --- /dev/null +++ b/tket2-py/src/circuit/cost.rs @@ -0,0 +1,168 @@ +//! + +use std::cmp::Ordering; +use std::iter::Sum; +use std::ops::{Add, AddAssign, Sub}; + +use pyo3::{prelude::*, PyTypeInfo}; +use tket2::circuit::cost::{CircuitCost, CostDelta}; + +/// A generic circuit cost, backed by an arbitrary python object. +#[pyclass] +#[derive(Clone, Debug)] +#[pyo3(name = "CircuitCost")] +pub struct PyCircuitCost { + /// Generic python cost object. + pub cost: PyObject, +} + +#[pymethods] +impl PyCircuitCost { + /// Create a new circuit cost. + #[new] + pub fn new(cost: PyObject) -> Self { + Self { cost } + } +} + +impl Default for PyCircuitCost { + fn default() -> Self { + Python::with_gil(|py| PyCircuitCost { cost: py.None() }) + } +} + +impl Add for PyCircuitCost { + type Output = PyCircuitCost; + + fn add(self, rhs: PyCircuitCost) -> Self::Output { + Python::with_gil(|py| { + let cost = self + .cost + .call_method1(py, "__add__", (rhs.cost,)) + .expect("Could not add circuit cost objects."); + PyCircuitCost { cost } + }) + } +} + +impl AddAssign for PyCircuitCost { + fn add_assign(&mut self, rhs: Self) { + Python::with_gil(|py| { + let cost = self + .cost + .call_method1(py, "__add__", (rhs.cost,)) + .expect("Could not add circuit cost objects."); + self.cost = cost; + }) + } +} + +impl Sub for PyCircuitCost { + type Output = PyCircuitCost; + + fn sub(self, rhs: PyCircuitCost) -> Self::Output { + Python::with_gil(|py| { + let cost = self + .cost + .call_method1(py, "__sub__", (rhs.cost,)) + .expect("Could not subtract circuit cost objects."); + PyCircuitCost { cost } + }) + } +} + +impl Sum for PyCircuitCost { + fn sum>(iter: I) -> Self { + Python::with_gil(|py| { + let mut acc = None; + for c in iter { + match &mut acc { + None => acc = Some(c.cost), + Some(cost) => { + *cost = cost + .call_method1(py, "__add__", (c.cost,)) + .expect("Could not add circuit cost objects.") + } + } + } + PyCircuitCost { + cost: acc.unwrap_or_else(|| py.None()), + } + }) + } +} + +impl PartialEq for PyCircuitCost { + fn eq(&self, other: &Self) -> bool { + Python::with_gil(|py| { + let res = self + .cost + .call_method1(py, "__eq__", (&other.cost,)) + .expect("Could not compare circuit cost objects."); + res.is_true(py) + .expect("Could not compare circuit cost objects.") + }) + } +} + +impl Eq for PyCircuitCost {} + +impl PartialOrd for PyCircuitCost { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for PyCircuitCost { + fn cmp(&self, other: &Self) -> Ordering { + Python::with_gil(|py| { + let res = self + .cost + .call_method1(py, "__lt__", (&other.cost,)) + .expect("Could not compare circuit cost objects."); + match res.is_true(py) { + Ok(true) => Ordering::Less, + _ => Ordering::Greater, + } + }) + } +} + +impl CostDelta for PyCircuitCost { + fn as_isize(&self) -> isize { + Python::with_gil(|py| { + let res = self + .cost + .call_method0(py, "__int__") + .expect("Could not convert the circuit cost object to an integer."); + res.extract(py) + .expect("Could not convert the circuit cost object to an integer.") + }) + } +} + +impl CircuitCost for PyCircuitCost { + type CostDelta = PyCircuitCost; + + fn as_usize(&self) -> usize { + self.as_isize() as usize + } + + fn sub_cost(&self, other: &Self) -> Self::CostDelta { + self.clone() - other.clone() + } + + fn add_delta(&self, delta: &Self::CostDelta) -> Self { + self.clone() + delta.clone() + } + + fn div_cost(&self, n: std::num::NonZeroUsize) -> Self { + Python::with_gil(|py| { + let res = self + .cost + .call_method0(py, "__div__") + .expect("Could not divide the circuit cost object."); + Self { cost: res } + }) + } +} diff --git a/tket2-py/test/test_bindings.py b/tket2-py/test/test_bindings.py deleted file mode 100644 index 28aa8c0c..00000000 --- a/tket2-py/test/test_bindings.py +++ /dev/null @@ -1,93 +0,0 @@ -from dataclasses import dataclass -from pytket.circuit import Circuit - -from tket2 import passes -from tket2.passes import greedy_depth_reduce -from tket2.circuit import Tk2Circuit, to_hugr_dot -from tket2.pattern import Rule, RuleMatcher - - -def test_conversion(): - tk1 = Circuit(4).CX(0, 2).CX(1, 2).CX(1, 3) - tk1_dot = to_hugr_dot(tk1) - - tk2 = Tk2Circuit(tk1) - tk2_dot = to_hugr_dot(tk2) - - assert type(tk2) == Tk2Circuit - assert tk1_dot == tk2_dot - - tk1_back = tk2.to_tket1() - - assert tk1_back == tk1 - assert type(tk1_back) == Circuit - - -@dataclass -class DepthOptimisePass: - def apply(self, circ: Circuit) -> Circuit: - (circ, n_moves) = greedy_depth_reduce(circ) - return circ - - -def test_depth_optimise(): - c = Circuit(4).CX(0, 2).CX(1, 2).CX(1, 3) - - assert c.depth() == 3 - - c = DepthOptimisePass().apply(c) - - assert c.depth() == 2 - - -def test_chunks(): - c = Circuit(4).CX(0, 2).CX(1, 3).CX(1, 2).CX(0, 3).CX(1, 3) - - assert c.depth() == 3 - - chunks = passes.chunks(c, 2) - circuits = chunks.circuits() - chunks.update_circuit(0, circuits[0]) - c2 = chunks.reassemble() - - assert c2.depth() == 3 - assert type(c2) == Circuit - - # Split and reassemble, with a tket2 circuit - tk2_chunks = passes.chunks(Tk2Circuit(c2), 2) - tk2 = tk2_chunks.reassemble() - - assert type(tk2) == Tk2Circuit - - -def test_cx_rule(): - c = Tk2Circuit(Circuit(4).CX(0, 2).CX(1, 2).CX(1, 2)) - - rule = Rule(Circuit(2).CX(0, 1).CX(0, 1), Circuit(2)) - matcher = RuleMatcher([rule]) - - mtch = matcher.find_match(c) - - c.apply_match(mtch) - - out = c.to_tket1() - - assert out == Circuit(4).CX(0, 2) - - -def test_multiple_rules(): - circ = Tk2Circuit(Circuit(3).CX(0, 1).H(0).H(1).H(2).Z(0).H(0).H(1).H(2)) - - rule1 = Rule(Circuit(1).H(0).Z(0).H(0), Circuit(1).X(0)) - rule2 = Rule(Circuit(1).H(0).H(0), Circuit(1)) - matcher = RuleMatcher([rule1, rule2]) - - match_count = 0 - while match := matcher.find_match(circ): - match_count += 1 - circ.apply_match(match) - - assert match_count == 3 - - out = circ.to_tket1() - assert out == Circuit(3).CX(0, 1).X(0) diff --git a/tket2-py/test/test_circuit.py b/tket2-py/test/test_circuit.py new file mode 100644 index 00000000..b3379268 --- /dev/null +++ b/tket2-py/test/test_circuit.py @@ -0,0 +1,49 @@ +from dataclasses import dataclass +from pytket.circuit import Circuit + +from tket2.circuit import Tk2Circuit, T2Op, to_hugr_dot + + +@dataclass +class CustomCost: + gate_count: int + h_count: int + + def __add__(self, other): + return CustomCost( + self.gate_count + other.gate_count, self.h_count + other.h_count + ) + + +def test_cost(): + circ = Tk2Circuit(Circuit(4).CX(0, 1).H(1).CX(1, 2).CX(0, 3).H(0)) + + print(circ.circuit_cost(lambda op: int(op == T2Op.CX))) + + assert circ.circuit_cost(lambda op: int(op == T2Op.CX)) == 3 + assert circ.circuit_cost(lambda op: CustomCost(1, op == T2Op.H)) == CustomCost(5, 2) + + +def test_hash(): + circA = Tk2Circuit(Circuit(4).CX(0, 1).CX(1, 2).CX(0, 3)) + circB = Tk2Circuit(Circuit(4).CX(1, 2).CX(0, 1).CX(0, 3)) + circC = Tk2Circuit(Circuit(4).CX(0, 1).CX(0, 3).CX(1, 2)) + + assert hash(circA) != hash(circB) + assert hash(circA) == hash(circC) + + +def test_conversion(): + tk1 = Circuit(4).CX(0, 2).CX(1, 2).CX(1, 3) + tk1_dot = to_hugr_dot(tk1) + + tk2 = Tk2Circuit(tk1) + tk2_dot = to_hugr_dot(tk2) + + assert type(tk2) == Tk2Circuit + assert tk1_dot == tk2_dot + + tk1_back = tk2.to_tket1() + + assert tk1_back == tk1 + assert type(tk1_back) == Circuit diff --git a/tket2-py/test/test_pass.py b/tket2-py/test/test_pass.py index b0f1284b..1fafea9e 100644 --- a/tket2-py/test/test_pass.py +++ b/tket2-py/test/test_pass.py @@ -1,5 +1,9 @@ from pytket import Circuit, OpType -from tket2.passes import badger_pass +from dataclasses import dataclass + +from tket2.passes import badger_pass, greedy_depth_reduce, chunks +from tket2.circuit import Tk2Circuit +from tket2.pattern import Rule, RuleMatcher def test_simple_badger_pass_no_opt(): @@ -7,3 +11,73 @@ def test_simple_badger_pass_no_opt(): badger = badger_pass(max_threads=1, timeout=0) badger.apply(c) assert c.n_gates_of_type(OpType.CX) == 6 + + +@dataclass +class DepthOptimisePass: + def apply(self, circ: Circuit) -> Circuit: + (circ, n_moves) = greedy_depth_reduce(circ) + return circ + + +def test_depth_optimise(): + c = Circuit(4).CX(0, 2).CX(1, 2).CX(1, 3) + + assert c.depth() == 3 + + c = DepthOptimisePass().apply(c) + + assert c.depth() == 2 + + +def test_chunks(): + c = Circuit(4).CX(0, 2).CX(1, 3).CX(1, 2).CX(0, 3).CX(1, 3) + + assert c.depth() == 3 + + circ_chunks = chunks(c, 2) + circuits = circ_chunks.circuits() + circ_chunks.update_circuit(0, circuits[0]) + c2 = circ_chunks.reassemble() + + assert c2.depth() == 3 + assert type(c2) == Circuit + + # Split and reassemble, with a tket2 circuit + tk2_chunks = chunks(Tk2Circuit(c2), 2) + tk2 = tk2_chunks.reassemble() + + assert type(tk2) == Tk2Circuit + + +def test_cx_rule(): + c = Tk2Circuit(Circuit(4).CX(0, 2).CX(1, 2).CX(1, 2)) + + rule = Rule(Circuit(2).CX(0, 1).CX(0, 1), Circuit(2)) + matcher = RuleMatcher([rule]) + + mtch = matcher.find_match(c) + + c.apply_match(mtch) + + out = c.to_tket1() + + assert out == Circuit(4).CX(0, 2) + + +def test_multiple_rules(): + circ = Tk2Circuit(Circuit(3).CX(0, 1).H(0).H(1).H(2).Z(0).H(0).H(1).H(2)) + + rule1 = Rule(Circuit(1).H(0).Z(0).H(0), Circuit(1).X(0)) + rule2 = Rule(Circuit(1).H(0).H(0), Circuit(1)) + matcher = RuleMatcher([rule1, rule2]) + + match_count = 0 + while match := matcher.find_match(circ): + match_count += 1 + circ.apply_match(match) + + assert match_count == 3 + + out = circ.to_tket1() + assert out == Circuit(3).CX(0, 1).X(0) diff --git a/tket2/src/circuit.rs b/tket2/src/circuit.rs index d8e55577..643902f3 100644 --- a/tket2/src/circuit.rs +++ b/tket2/src/circuit.rs @@ -5,6 +5,8 @@ pub mod cost; mod hash; pub mod units; +use std::iter::Sum; + pub use command::{Command, CommandIterator}; pub use hash::CircuitHash; use itertools::Either::{Left, Right}; @@ -25,7 +27,6 @@ pub use hugr::ops::OpType; pub use hugr::types::{EdgeKind, Signature, Type, TypeRow}; pub use hugr::{Node, Port, Wire}; -use self::cost::CircuitCost; use self::units::{filter, FilteredUnits, Units}; /// An object behaving like a quantum circuit. @@ -135,7 +136,7 @@ pub trait Circuit: HugrView { fn circuit_cost(&self, op_cost: F) -> C where Self: Sized, - C: CircuitCost, + C: Sum, F: Fn(&OpType) -> C, { self.commands().map(|cmd| op_cost(cmd.optype())).sum() @@ -146,7 +147,7 @@ pub trait Circuit: HugrView { #[inline] fn nodes_cost(&self, nodes: impl IntoIterator, op_cost: F) -> C where - C: CircuitCost, + C: Sum, F: Fn(&OpType) -> C, { nodes.into_iter().map(|n| op_cost(self.get_optype(n))).sum() From 60d9c12bed615d984fc01a4b7b4df3a80a832748 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Mon, 20 Nov 2023 17:07:12 +0000 Subject: [PATCH 3/8] feat: Add Tk2Circuit::__copy__ and deepcopy --- tket2-py/src/circuit/convert.rs | 12 +++++++++++- tket2-py/test/test_pass.py | 4 ++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/tket2-py/src/circuit/convert.rs b/tket2-py/src/circuit/convert.rs index 2e614d13..63ba9532 100644 --- a/tket2-py/src/circuit/convert.rs +++ b/tket2-py/src/circuit/convert.rs @@ -41,7 +41,7 @@ impl Tk2Circuit { } /// Apply a rewrite on the circuit. - pub fn apply_match(&mut self, rw: PyCircuitRewrite) { + pub fn apply_rewrite(&mut self, rw: PyCircuitRewrite) { rw.rewrite.apply(&mut self.hugr).expect("Apply error."); } @@ -107,6 +107,16 @@ impl Tk2Circuit { pub fn __hash__(&self) -> isize { self.hash() as isize } + + /// Copy the circuit. + pub fn __copy__(&self) -> PyResult { + Ok(self.clone()) + } + + /// Copy the circuit. + pub fn __deepcopy__(&self, _memo: Py) -> PyResult { + Ok(self.clone()) + } } impl Tk2Circuit { /// Tries to extract a Tk2Circuit from a python object. diff --git a/tket2-py/test/test_pass.py b/tket2-py/test/test_pass.py index 1fafea9e..366f34ff 100644 --- a/tket2-py/test/test_pass.py +++ b/tket2-py/test/test_pass.py @@ -58,7 +58,7 @@ def test_cx_rule(): mtch = matcher.find_match(c) - c.apply_match(mtch) + c.apply_rewrite(mtch) out = c.to_tket1() @@ -75,7 +75,7 @@ def test_multiple_rules(): match_count = 0 while match := matcher.find_match(circ): match_count += 1 - circ.apply_match(match) + circ.apply_rewrite(match) assert match_count == 3 From becf8462f261501477a829fb210ccb0090df79f8 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Tue, 21 Nov 2023 16:37:31 +0000 Subject: [PATCH 4/8] Update `Tk2Circuit` doc --- tket2-py/src/circuit/convert.rs | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/tket2-py/src/circuit/convert.rs b/tket2-py/src/circuit/convert.rs index a89e7031..ce24c36a 100644 --- a/tket2-py/src/circuit/convert.rs +++ b/tket2-py/src/circuit/convert.rs @@ -17,7 +17,26 @@ use crate::rewrite::PyCircuitRewrite; use super::PyCircuitCost; -/// A manager for tket 2 operations on a tket 1 Circuit. +/// A circuit in tket2 format. +/// +/// This can be freely converted to and from a `pytket.Circuit`. Prefer using +/// this class when applying multiple tket2 operations on a circuit, as it +/// avoids the overhead of converting to and from a `pytket.Circuit` each time. +/// +/// Node indices returned by this class are not stable across conversion to and +/// from a `pytket.Circuit`. +/// +/// # Examples +/// +/// Convert between `pytket.Circuit`s and `Tk2Circuit`s: +/// ```python +/// from pytket import Circuit +/// c = Circuit(2).H(0).CX(0, 1) +/// # Convert to a Tk2Circuit +/// t2c = Tk2Circuit(c) +/// # Convert back to a pytket.Circuit +/// c2 = t2c.to_tket1() +/// ``` #[pyclass] #[derive(Clone, Debug, PartialEq, From)] pub struct Tk2Circuit { From 9fa62bb95f1e9ea4df2958b26e269bf6bb58f62d Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Tue, 21 Nov 2023 16:43:08 +0000 Subject: [PATCH 5/8] Cleanup `Tk2Circuit::circuit_cost` --- tket2-py/src/circuit/convert.rs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tket2-py/src/circuit/convert.rs b/tket2-py/src/circuit/convert.rs index ce24c36a..c84f8573 100644 --- a/tket2-py/src/circuit/convert.rs +++ b/tket2-py/src/circuit/convert.rs @@ -1,5 +1,6 @@ //! Utilities for calling Hugr functions on generic python objects. +use hugr::ops::OpType; use pyo3::exceptions::{PyAttributeError, PyValueError}; use pyo3::{prelude::*, PyTypeInfo}; @@ -15,7 +16,7 @@ use tket_json_rs::circuit_json::SerialCircuit; use crate::rewrite::PyCircuitRewrite; -use super::PyCircuitCost; +use super::{cost, PyCircuitCost}; /// A circuit in tket2 format. /// @@ -104,17 +105,19 @@ impl Tk2Circuit { /// :returns: The sum of all operation costs. pub fn circuit_cost<'py>(&self, cost_fn: &'py PyAny) -> PyResult<&'py PyAny> { let py = cost_fn.py(); - let circ_cost = self.hugr.circuit_cost(|op| { + let cost_fn = |op: &OpType| -> PyResult { let tk2_op: T2Op = op.try_into().map_err(|e| { PyErr::new::(format!( "Could not convert circuit operation to a `T2Op`: {e}" )) })?; - cost_fn.call1((tk2_op,)).map(|cost| PyCircuitCost { + let cost = cost_fn.call1((tk2_op,))?; + Ok(PyCircuitCost { cost: cost.to_object(py), }) - })?; - Ok(circ_cost.cost.clone().into_ref(py)) + }; + let circ_cost = self.hugr.circuit_cost(cost_fn)?; + Ok(circ_cost.cost.into_ref(py)) } /// Returns a hash of the circuit. From cb6fce3f217e63ea5badc1a4ff19395b9880c1d4 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Tue, 21 Nov 2023 16:51:17 +0000 Subject: [PATCH 6/8] Use fold --- tket2-py/src/circuit/cost.rs | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/tket2-py/src/circuit/cost.rs b/tket2-py/src/circuit/cost.rs index a7b2d6bd..edb60ccd 100644 --- a/tket2-py/src/circuit/cost.rs +++ b/tket2-py/src/circuit/cost.rs @@ -74,20 +74,17 @@ impl Sub for PyCircuitCost { impl Sum for PyCircuitCost { fn sum>(iter: I) -> Self { Python::with_gil(|py| { - let mut acc = None; - for c in iter { - match &mut acc { - None => acc = Some(c.cost), - Some(cost) => { - *cost = cost + let cost = iter + .fold(None, |acc: Option, c| { + Some(match acc { + None => c.cost, + Some(cost) => cost .call_method1(py, "__add__", (c.cost,)) - .expect("Could not add circuit cost objects.") - } - } - } - PyCircuitCost { - cost: acc.unwrap_or_else(|| py.None()), - } + .expect("Could not add circuit cost objects."), + }) + }) + .unwrap_or_else(|| py.None()); + PyCircuitCost { cost } }) } } From b06b0bc5ba9dd9e57b485f3c7dce711424cba2ff Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Thu, 23 Nov 2023 13:37:00 +0000 Subject: [PATCH 7/8] Update docs on circ cost requirements --- tket2-py/src/circuit/convert.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tket2-py/src/circuit/convert.rs b/tket2-py/src/circuit/convert.rs index 8d849c5e..1980af87 100644 --- a/tket2-py/src/circuit/convert.rs +++ b/tket2-py/src/circuit/convert.rs @@ -101,7 +101,9 @@ impl Tk2Circuit { /// Compute the cost of the circuit based on a per-operation cost function. /// /// :param cost_fn: A function that takes a `Tk2Op` and returns an arbitrary cost. - /// The cost must implement `__add__`. + /// The cost must implement `__add__`, `__sub__`, `__lt__`, + /// `__eq__`, `__int__`, and integer `__div__`. + /// /// :returns: The sum of all operation costs. pub fn circuit_cost<'py>(&self, cost_fn: &'py PyAny) -> PyResult<&'py PyAny> { let py = cost_fn.py(); From 66eb31c8d7b19b3493f77c45bbc3580aea88cd92 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Thu, 23 Nov 2023 13:47:00 +0000 Subject: [PATCH 8/8] Fix PyCircuitCost Ord impl --- tket2-py/src/circuit/cost.rs | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tket2-py/src/circuit/cost.rs b/tket2-py/src/circuit/cost.rs index edb60ccd..cbd734cf 100644 --- a/tket2-py/src/circuit/cost.rs +++ b/tket2-py/src/circuit/cost.rs @@ -112,16 +112,18 @@ impl PartialOrd for PyCircuitCost { impl Ord for PyCircuitCost { fn cmp(&self, other: &Self) -> Ordering { - Python::with_gil(|py| { - let res = self - .cost - .call_method1(py, "__lt__", (&other.cost,)) - .expect("Could not compare circuit cost objects."); - match res.is_true(py) { - Ok(true) => Ordering::Less, - _ => Ordering::Greater, + Python::with_gil(|py| -> PyResult { + let res = self.cost.call_method1(py, "__lt__", (&other.cost,))?; + if res.is_true(py)? { + return Ok(Ordering::Less); + } + let res = self.cost.call_method1(py, "__eq__", (&other.cost,))?; + if res.is_true(py)? { + return Ok(Ordering::Equal); } + Ok(Ordering::Greater) }) + .expect("Could not compare circuit cost objects.") } }