add tests

This commit is contained in:
2026-04-22 12:13:49 +02:00
parent 99fee298c7
commit 891546069c

View File

@@ -53,11 +53,17 @@ impl<V: Var> Circuit<V> {
}
pub fn add(&mut self, left: NodeId, right: NodeId) -> NodeId {
self.node(Node::Sum(left, right))
let (l, r) = if left <= right { (left, right) } else { (right, left) };
self.node(Node::Sum(l, r))
}
pub fn mul(&mut self, left: NodeId, right: NodeId) -> NodeId {
self.node(Node::Prod(left, right))
let (l, r) = if left <= right { (left, right) } else { (right, left) };
self.node(Node::Prod(l, r))
}
pub fn len(&self) -> usize {
self.nodes.len()
}
pub fn children(&self, id: NodeId) -> impl Iterator<Item = NodeId> + '_ {
@@ -75,7 +81,6 @@ pub struct CircuitNode<V: Var> {
pub id: NodeId,
circuit: Rc<RefCell<Circuit<V>>>,
}
pub trait CircuitExt<V: Var> {
fn leaf(&self, v: V) -> CircuitNode<V>;
fn get_node(&self, id: NodeId) -> Option<CircuitNode<V>>;
@@ -97,7 +102,7 @@ impl<V: Var> Add for CircuitNode<V> {
type Output = Self;
fn add(self, rhs: Self) -> Self {
let id = self.circuit.borrow_mut().node(Node::Sum(self.id, rhs.id));
let id = self.circuit.borrow_mut().add(self.id, rhs.id);
CircuitNode { id, circuit: self.circuit }
}
}
@@ -106,7 +111,62 @@ impl<V: Var> Mul for CircuitNode<V> {
type Output = Self;
fn mul(self, rhs: Self) -> Self {
let id = self.circuit.borrow_mut().node(Node::Prod(self.id, rhs.id));
let id = self.circuit.borrow_mut().mul(self.id, rhs.id);
CircuitNode { id, circuit: self.circuit }
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::poly::var::StaticVar;
#[test]
fn test_deduplication() {
let circuit = Rc::new(RefCell::new(Circuit::new()));
// Same leaf constructed twice returns the same NodeId
let x1 = circuit.leaf(StaticVar::from("x"));
let x2 = circuit.leaf(StaticVar::from("x"));
assert_eq!(x1.id, x2.id);
assert_eq!(circuit.borrow().len(), 1);
// Same sum constructed twice returns the same NodeId
let _y = circuit.leaf(StaticVar::from("y"));
let sum1 = circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y"));
let sum2 = circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y"));
assert_eq!(sum1.id, sum2.id);
assert_eq!(circuit.borrow().len(), 3); // x, y, x+y
// Shared subexpression: (x + y) * (x + y) reuses the x+y node
let xy = circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y"));
let xy2 = circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y"));
let _sq = xy * xy2;
assert_eq!(circuit.borrow().len(), 4); // x, y, x+y, (x+y)*(x+y)
// Commutativity: x+y and y+x are the same node
let xy = circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y"));
let yx = circuit.leaf(StaticVar::from("y")) + circuit.leaf(StaticVar::from("x"));
assert_eq!(xy.id, yx.id);
// Associativity: (x+y)+z and x+(y+z) are distinct nodes
let _z = circuit.leaf(StaticVar::from("z"));
let xy_z = (circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y")))
+ circuit.leaf(StaticVar::from("z"));
let x_yz = circuit.leaf(StaticVar::from("x"))
+ (circuit.leaf(StaticVar::from("y")) + circuit.leaf(StaticVar::from("z")));
assert_ne!(xy_z.id, x_yz.id);
// Deep shared structure: (x+y)*z appears twice in ((x+y)*z) + ((x+y)*z)
let xyz1 = (circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y")))
* circuit.leaf(StaticVar::from("z"));
let xyz2 = (circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y")))
* circuit.leaf(StaticVar::from("z"));
assert_eq!(xyz1.id, xyz2.id);
let _sum = xyz1 + xyz2;
// x, y, z, x+y(==y+x), (x+y)*z, (x+y)+z, y+z, x+(y+z), (x+y)*z+(x+y)*z, sq
assert_eq!(circuit.borrow().len(), 10);
}
}