From 5caccbaf50d0177e548b4f3b03233369c91e6ede Mon Sep 17 00:00:00 2001 From: asteri Date: Tue, 28 Apr 2026 17:29:11 +0200 Subject: [PATCH] improve var creation for SumProdCircuit --- examples/dag.rs | 10 +++++----- examples/quotient.rs | 1 + src/circuit/dag.rs | 2 +- src/circuit/probabilistic.rs | 2 +- src/circuit/quotient.rs | 2 +- src/circuit/traits.rs | 6 +++--- 6 files changed, 12 insertions(+), 11 deletions(-) diff --git a/examples/dag.rs b/examples/dag.rs index a5443e6..2530dc0 100644 --- a/examples/dag.rs +++ b/examples/dag.rs @@ -6,10 +6,10 @@ use circuit_cas::circuit::dag::CircuitExt; fn main() { let circuit: Rc> = ProbCircuit::new(); - // Build (x + y) * (x + z) - let x = circuit.var("x"); - let y = circuit.var("y"); - let z = circuit.var("z"); + // vars accept anything that implements Into: &'static str, (&str, u32), (&str, u32, u32) + let x = circuit.var("x"); + let y = circuit.var(("y", 1)); // indexed variable y_1 + let z = circuit.var(("z", 0, 1)); // doubly-indexed variable z_{0,1} let x_plus_y = circuit.var("x") + y; let x_plus_z = circuit.var("x") + z; @@ -19,7 +19,7 @@ fn main() { let x2 = circuit.var("x"); assert_eq!(x.id, x2.id); - println!("(x + y) * (x + z) root node id: {:?}", expr.id); + println!("(x + y_1) * (x + z_{{0,1}}) root node id: {:?}", expr.id); println!("x node id: {:?}", x.id); println!("x deduplicated node id: {:?}", x2.id); } diff --git a/examples/quotient.rs b/examples/quotient.rs index 66c09f9..c045c6b 100644 --- a/examples/quotient.rs +++ b/examples/quotient.rs @@ -19,6 +19,7 @@ fn main() { let quotient: Rc> = idem.into(); // Build x * x̄ + x in the DAG + // var accepts anything that implements Into: &'static str, (&str, u32), (&str, u32, u32) let xn = quotient.var("x"); let nxn = quotient.var("x\u{0304}"); let prod = xn * nxn; diff --git a/src/circuit/dag.rs b/src/circuit/dag.rs index 49fe59e..000703f 100644 --- a/src/circuit/dag.rs +++ b/src/circuit/dag.rs @@ -60,5 +60,5 @@ pub struct RefNode { pub trait CircuitExt { type C: Circuit; type Var; - fn var(&self, v: impl Into) -> RefNode; + fn var>(&self, v: T) -> RefNode; } diff --git a/src/circuit/probabilistic.rs b/src/circuit/probabilistic.rs index 4e92dd3..a15c621 100644 --- a/src/circuit/probabilistic.rs +++ b/src/circuit/probabilistic.rs @@ -43,7 +43,7 @@ impl ProbCircuit { impl SumProdCircuit for ProbCircuit { type Var = V; - fn var(&mut self, v: V) -> NodeId { self.node(PNode::Var(v)) } + fn var>(&mut self, v: T) -> NodeId { self.node(PNode::Var(v.into())) } fn add(&mut self, l: NodeId, r: NodeId) -> NodeId { let (l, r) = if l <= r { (l, r) } else { (r, l) }; diff --git a/src/circuit/quotient.rs b/src/circuit/quotient.rs index d216f33..b88e04e 100644 --- a/src/circuit/quotient.rs +++ b/src/circuit/quotient.rs @@ -57,7 +57,7 @@ impl DerefMut for QuotientCircuit { impl SumProdCircuit for QuotientCircuit { type Var = V; - fn var(&mut self, v: V) -> NodeId { self.node(QNode::Var(v)) } + fn var>(&mut self, v: T) -> NodeId { self.node(QNode::Var(v.into())) } fn add(&mut self, l: NodeId, r: NodeId) -> NodeId { let (l, r) = if l <= r { (l, r) } else { (r, l) }; diff --git a/src/circuit/traits.rs b/src/circuit/traits.rs index 60f5e88..aa79281 100644 --- a/src/circuit/traits.rs +++ b/src/circuit/traits.rs @@ -29,7 +29,7 @@ where pub trait SumProdCircuit: Circuit { type Var; - fn var(&mut self, v: Self::Var) -> NodeId; + fn var>(&mut self, v: T) -> NodeId; fn add(&mut self, l: NodeId, r: NodeId) -> NodeId; fn mul(&mut self, l: NodeId, r: NodeId) -> NodeId; } @@ -56,8 +56,8 @@ impl CircuitExt for Rc> { type C = C; type Var = C::Var; - fn var(&self, v: impl Into) -> RefNode { - let id = self.borrow_mut().var(v.into()); + fn var>(&self, v: T) -> RefNode { + let id = self.borrow_mut().var(v); RefNode { id, circuit: self.clone() } } }