From 891546069c7552e9fd24284c2a7de5624f3373c6 Mon Sep 17 00:00:00 2001 From: asteri Date: Wed, 22 Apr 2026 12:13:49 +0200 Subject: [PATCH] add tests --- src/circuit/dag.rs | 70 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 65 insertions(+), 5 deletions(-) diff --git a/src/circuit/dag.rs b/src/circuit/dag.rs index e21d793..8c56cb8 100644 --- a/src/circuit/dag.rs +++ b/src/circuit/dag.rs @@ -53,11 +53,17 @@ impl Circuit { } pub fn add(&mut self, left: NodeId, right: NodeId) -> NodeId { - self.node(Node::Sum(left, right)) + let (l, r) = if left <= right { (left, right) } else { (right, left) }; + self.node(Node::Sum(l, r)) } pub fn mul(&mut self, left: NodeId, right: NodeId) -> NodeId { - self.node(Node::Prod(left, right)) + let (l, r) = if left <= right { (left, right) } else { (right, left) }; + self.node(Node::Prod(l, r)) + } + + pub fn len(&self) -> usize { + self.nodes.len() } pub fn children(&self, id: NodeId) -> impl Iterator + '_ { @@ -75,7 +81,6 @@ pub struct CircuitNode { pub id: NodeId, circuit: Rc>>, } - pub trait CircuitExt { fn leaf(&self, v: V) -> CircuitNode; fn get_node(&self, id: NodeId) -> Option>; @@ -97,7 +102,7 @@ impl Add for CircuitNode { type Output = Self; fn add(self, rhs: Self) -> Self { - let id = self.circuit.borrow_mut().node(Node::Sum(self.id, rhs.id)); + let id = self.circuit.borrow_mut().add(self.id, rhs.id); CircuitNode { id, circuit: self.circuit } } } @@ -106,7 +111,62 @@ impl Mul for CircuitNode { type Output = Self; fn mul(self, rhs: Self) -> Self { - let id = self.circuit.borrow_mut().node(Node::Prod(self.id, rhs.id)); + let id = self.circuit.borrow_mut().mul(self.id, rhs.id); CircuitNode { id, circuit: self.circuit } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::poly::var::StaticVar; + + #[test] + fn test_deduplication() { + let circuit = Rc::new(RefCell::new(Circuit::new())); + + // Same leaf constructed twice returns the same NodeId + let x1 = circuit.leaf(StaticVar::from("x")); + let x2 = circuit.leaf(StaticVar::from("x")); + assert_eq!(x1.id, x2.id); + assert_eq!(circuit.borrow().len(), 1); + + // Same sum constructed twice returns the same NodeId + let _y = circuit.leaf(StaticVar::from("y")); + let sum1 = circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y")); + let sum2 = circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y")); + assert_eq!(sum1.id, sum2.id); + assert_eq!(circuit.borrow().len(), 3); // x, y, x+y + + // Shared subexpression: (x + y) * (x + y) reuses the x+y node + let xy = circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y")); + let xy2 = circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y")); + let _sq = xy * xy2; + assert_eq!(circuit.borrow().len(), 4); // x, y, x+y, (x+y)*(x+y) + + // Commutativity: x+y and y+x are the same node + let xy = circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y")); + let yx = circuit.leaf(StaticVar::from("y")) + circuit.leaf(StaticVar::from("x")); + assert_eq!(xy.id, yx.id); + + // Associativity: (x+y)+z and x+(y+z) are distinct nodes + let _z = circuit.leaf(StaticVar::from("z")); + let xy_z = (circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y"))) + + circuit.leaf(StaticVar::from("z")); + let x_yz = circuit.leaf(StaticVar::from("x")) + + (circuit.leaf(StaticVar::from("y")) + circuit.leaf(StaticVar::from("z"))); + assert_ne!(xy_z.id, x_yz.id); + + // Deep shared structure: (x+y)*z appears twice in ((x+y)*z) + ((x+y)*z) + let xyz1 = (circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y"))) + * circuit.leaf(StaticVar::from("z")); + let xyz2 = (circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y"))) + * circuit.leaf(StaticVar::from("z")); + assert_eq!(xyz1.id, xyz2.id); + let _sum = xyz1 + xyz2; + // x, y, z, x+y(==y+x), (x+y)*z, (x+y)+z, y+z, x+(y+z), (x+y)*z+(x+y)*z, sq + assert_eq!(circuit.borrow().len(), 10); + } +} + +