diff --git a/examples/dag.rs b/examples/dag.rs index 758c41e..a5443e6 100644 --- a/examples/dag.rs +++ b/examples/dag.rs @@ -1,21 +1,22 @@ +use std::cell::RefCell; +use std::rc::Rc; use circuit_cas::circuit::probabilistic::ProbCircuit; use circuit_cas::circuit::dag::CircuitExt; -use circuit_cas::var; fn main() { - let circuit = ProbCircuit::new(); + let circuit: Rc> = ProbCircuit::new(); // Build (x + y) * (x + z) - let x = circuit.leaf(var!("x")); - let y = circuit.leaf(var!("y")); - let z = circuit.leaf(var!("z")); + let x = circuit.var("x"); + let y = circuit.var("y"); + let z = circuit.var("z"); - let x_plus_y = circuit.leaf(var!("x")) + y; - let x_plus_z = circuit.leaf(var!("x")) + z; + let x_plus_y = circuit.var("x") + y; + let x_plus_z = circuit.var("x") + z; let expr = x_plus_y * x_plus_z; // Deduplication: both x leaves share the same NodeId - let x2 = circuit.leaf(var!("x")); + let x2 = circuit.var("x"); assert_eq!(x.id, x2.id); println!("(x + y) * (x + z) root node id: {:?}", expr.id); diff --git a/examples/quotient.rs b/examples/quotient.rs index 8371884..617b486 100644 --- a/examples/quotient.rs +++ b/examples/quotient.rs @@ -1,5 +1,8 @@ -use circuit_cas::circuit::quotient::Quotient; -use circuit_cas::poly::var::StaticVar; +use std::cell::RefCell; +use std::rc::Rc; +use circuit_cas::circuit::dag::CircuitExt; +use circuit_cas::circuit::quotient::QuotientCircuit; +use circuit_cas::circuit::traits::Circuit; use circuit_cas::var; fn main() { @@ -12,7 +15,15 @@ fn main() { 1 * ((&x ^ 1) * (&nx ^ 1)) - 1 * (&x ^ 1), ]; - let quotient: Quotient = idem.into_iter().collect(); + let quotient: Rc> = idem.into_iter().collect(); - println!("{quotient:?}"); + // Build x * x̄ + x in the DAG + let xn = quotient.var("x"); + let nxn = quotient.var("x\u{0304}"); + let prod = xn * nxn; + let xn2 = quotient.var("x"); + let expr = prod + xn2; + + println!("dag size: {}", quotient.borrow().len()); + println!("expr node id: {:?}", expr.id); } diff --git a/src/circuit/dag.rs b/src/circuit/dag.rs index 50e21ef..49fe59e 100644 --- a/src/circuit/dag.rs +++ b/src/circuit/dag.rs @@ -56,16 +56,9 @@ pub struct RefNode { pub(super) circuit: Rc>, } -impl RefNode { - pub fn get_node(&self, id: NodeId) -> Option { - self.circuit.borrow().get(id)?; - Some(RefNode { id, circuit: self.circuit.clone() }) - } -} pub trait CircuitExt { type C: Circuit; type Var; - fn leaf(&self, var: Self::Var) -> RefNode; - fn len(&self) -> usize; + fn var(&self, v: impl Into) -> RefNode; } diff --git a/src/circuit/probabilistic.rs b/src/circuit/probabilistic.rs index 51ca2db..4e92dd3 100644 --- a/src/circuit/probabilistic.rs +++ b/src/circuit/probabilistic.rs @@ -1,15 +1,15 @@ use std::cell::RefCell; -use std::ops::{Add, Deref, DerefMut, Mul}; +use std::ops::{Deref, DerefMut}; use std::rc::Rc; -use crate::poly::var::Var; -use super::dag::{CircuitExt, Dag, NodeId, RefNode}; -use super::traits::Node; +use crate::poly::var::{StaticVar, Var}; +use super::dag::{Dag, NodeId}; +use super::traits::{Node, SumProdCircuit}; #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum PNode { - Leaf(V), + Var(V), Sum(NodeId, NodeId), Prod(NodeId, NodeId), } @@ -17,7 +17,7 @@ pub enum PNode { impl Node for PNode { fn children(&self) -> impl Iterator { match self { - Self::Leaf(_) => [None, None], + Self::Var(_) => [None, None], Self::Sum(l, r) | Self::Prod(l, r) => [Some(*l), Some(*r)], } .into_iter() @@ -26,7 +26,7 @@ impl Node for PNode { } #[derive(Clone, Debug)] -pub struct ProbCircuit { +pub struct ProbCircuit { dag: Dag>, } @@ -38,15 +38,19 @@ impl ProbCircuit { pub fn new() -> Rc> { Rc::new(RefCell::new(Self::default())) } +} - pub fn leaf(&mut self, v: V) -> NodeId { self.node(PNode::Leaf(v)) } +impl SumProdCircuit for ProbCircuit { + type Var = V; - pub fn add(&mut self, l: NodeId, r: NodeId) -> NodeId { + fn var(&mut self, v: V) -> NodeId { self.node(PNode::Var(v)) } + + fn add(&mut self, l: NodeId, r: NodeId) -> NodeId { let (l, r) = if l <= r { (l, r) } else { (r, l) }; self.node(PNode::Sum(l, r)) } - pub fn mul(&mut self, l: NodeId, r: NodeId) -> NodeId { + fn mul(&mut self, l: NodeId, r: NodeId) -> NodeId { let (l, r) = if l <= r { (l, r) } else { (r, l) }; self.node(PNode::Prod(l, r)) } @@ -61,32 +65,4 @@ impl DerefMut for ProbCircuit { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.dag } } -impl CircuitExt for Rc>> { - type C = ProbCircuit; - type Var = V; - fn leaf(&self, var: V) -> RefNode> { - let id = self.borrow_mut().leaf(var); - RefNode { id, circuit: self.clone() } - } - - fn len(&self) -> usize { self.borrow().len() } -} - -impl Add for RefNode> { - type Output = Self; - - fn add(self, rhs: Self) -> Self { - let id = self.circuit.borrow_mut().add(self.id, rhs.id); - RefNode { id, circuit: self.circuit } - } -} - -impl Mul for RefNode> { - type Output = Self; - - fn mul(self, rhs: Self) -> Self { - let id = self.circuit.borrow_mut().mul(self.id, rhs.id); - RefNode { id, circuit: self.circuit } - } -} diff --git a/src/circuit/quotient.rs b/src/circuit/quotient.rs index 799c49f..e39593c 100644 --- a/src/circuit/quotient.rs +++ b/src/circuit/quotient.rs @@ -1,17 +1,17 @@ use std::cell::RefCell; -use std::ops::{Add, Deref, DerefMut, Mul}; +use std::ops::{Deref, DerefMut}; use std::rc::Rc; -use crate::poly::var::Var; +use crate::poly::var::{StaticVar, Var}; use crate::poly::flat::Poly; use crate::poly::ideal::{Generators, GroebnerBasis, Ideal}; -use super::dag::{CircuitExt, Dag, NodeId, RefNode}; -use super::traits::Node; +use super::dag::{Dag, NodeId}; +use super::traits::{Node, SumProdCircuit}; #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum QNode { - Leaf(V), + Var(V), Sum(NodeId, NodeId), Prod(NodeId, NodeId), DivStep(NodeId, NodeId), @@ -20,7 +20,7 @@ pub enum QNode { impl Node for QNode { fn children(&self) -> impl Iterator { match self { - Self::Leaf(_) => [None, None], + Self::Var(_) => [None, None], Self::Sum(l, r) | Self::Prod(l, r) | Self::DivStep(l, r) => [Some(*l), Some(*r)], } .into_iter() @@ -29,25 +29,33 @@ impl Node for QNode { } #[derive(Clone, Debug)] -pub struct QuotientCircuit { +pub struct QuotientCircuit { basis: Ideal, dag: Dag>, } -impl From> for QuotientCircuit { - fn from(basis: Ideal) -> Self { - Self { basis, dag: Default::default() } +impl QuotientCircuit { + pub fn from_ideal(basis: Ideal) -> Rc> { + Rc::new(RefCell::new(Self { basis, dag: Default::default() })) } -} -impl FromIterator> for QuotientCircuit { - fn from_iter>>(iter: T) -> Self { + pub fn from_polys(iter: impl IntoIterator>) -> Rc> { let ideal: Ideal = iter.into_iter().collect(); - Self { basis: ideal.groebner_basis(), dag: Default::default() } + Rc::new(RefCell::new(Self { basis: ideal.groebner_basis(), dag: Default::default() })) } } -pub type Quotient = QuotientCircuit; +impl From> for Rc>> { + fn from(basis: Ideal) -> Self { + QuotientCircuit::from_ideal(basis) + } +} + +impl FromIterator> for Rc>> { + fn from_iter>>(iter: T) -> Self { + QuotientCircuit::from_polys(iter) + } +} impl Deref for QuotientCircuit { type Target = Dag>; @@ -58,46 +66,20 @@ impl DerefMut for QuotientCircuit { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.dag } } -impl QuotientCircuit { - pub fn leaf(&mut self, v: V) -> NodeId { self.node(QNode::Leaf(v)) } +impl SumProdCircuit for QuotientCircuit { + type Var = V; - pub fn add(&mut self, l: NodeId, r: NodeId) -> NodeId { + fn var(&mut self, v: V) -> NodeId { self.node(QNode::Var(v)) } + + fn add(&mut self, l: NodeId, r: NodeId) -> NodeId { let (l, r) = if l <= r { (l, r) } else { (r, l) }; self.node(QNode::Sum(l, r)) } - pub fn mul(&mut self, l: NodeId, r: NodeId) -> NodeId { + fn mul(&mut self, l: NodeId, r: NodeId) -> NodeId { let (l, r) = if l <= r { (l, r) } else { (r, l) }; self.node(QNode::Prod(l, r)) } } -impl CircuitExt for Rc>> { - type C = QuotientCircuit; - type Var = V; - fn leaf(&self, var: V) -> RefNode> { - let id = self.borrow_mut().leaf(var); - RefNode { id, circuit: self.clone() } - } - - fn len(&self) -> usize { self.borrow().len() } -} - -impl Add for RefNode> { - type Output = Self; - - fn add(self, rhs: Self) -> Self { - let id = self.circuit.borrow_mut().add(self.id, rhs.id); - RefNode { id, circuit: self.circuit } - } -} - -impl Mul for RefNode> { - type Output = Self; - - fn mul(self, rhs: Self) -> Self { - let id = self.circuit.borrow_mut().mul(self.id, rhs.id); - RefNode { id, circuit: self.circuit } - } -} diff --git a/src/circuit/tests.rs b/src/circuit/tests.rs index fd86cbb..397e8d7 100644 --- a/src/circuit/tests.rs +++ b/src/circuit/tests.rs @@ -1,50 +1,50 @@ +use std::cell::RefCell; +use std::rc::Rc; use super::probabilistic::ProbCircuit; use super::dag::CircuitExt; -use crate::poly::var::StaticVar; - #[test] fn test_deduplication() { - let circuit = ProbCircuit::new(); + let circuit: Rc> = ProbCircuit::new(); // Same leaf constructed twice returns the same NodeId - let x1 = circuit.leaf(StaticVar::from("x")); - let x2 = circuit.leaf(StaticVar::from("x")); + let x1 = circuit.var("x"); + let x2 = circuit.var("x"); assert_eq!(x1.id, x2.id); - assert_eq!(circuit.len(), 1); + assert_eq!(circuit.borrow().len(), 1); // Same sum constructed twice returns the same NodeId - let _y = circuit.leaf(StaticVar::from("y")); - let sum1 = circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y")); - let sum2 = circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y")); + let _y = circuit.var("y"); + let sum1 = circuit.var("x") + circuit.var("y"); + let sum2 = circuit.var("x") + circuit.var("y"); assert_eq!(sum1.id, sum2.id); - assert_eq!(circuit.len(), 3); // x, y, x+y + assert_eq!(circuit.borrow().len(), 3); // x, y, x+y // Shared subexpression: (x + y) * (x + y) reuses the x+y node - let xy = circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y")); - let xy2 = circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y")); + let xy = circuit.var("x") + circuit.var("y"); + let xy2 = circuit.var("x") + circuit.var("y"); let _sq = xy * xy2; - assert_eq!(circuit.len(), 4); // x, y, x+y, (x+y)*(x+y) + assert_eq!(circuit.borrow().len(), 4); // x, y, x+y, (x+y)*(x+y) // Commutativity: x+y and y+x are the same node - let xy = circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y")); - let yx = circuit.leaf(StaticVar::from("y")) + circuit.leaf(StaticVar::from("x")); + let xy = circuit.var("x") + circuit.var("y"); + let yx = circuit.var("y") + circuit.var("x"); assert_eq!(xy.id, yx.id); // Associativity: (x+y)+z and x+(y+z) are distinct nodes - let _z = circuit.leaf(StaticVar::from("z")); - let xy_z = (circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y"))) - + circuit.leaf(StaticVar::from("z")); - let x_yz = circuit.leaf(StaticVar::from("x")) - + (circuit.leaf(StaticVar::from("y")) + circuit.leaf(StaticVar::from("z"))); + let _z = circuit.var("z"); + let xy_z = (circuit.var("x") + circuit.var("y")) + + circuit.var("z"); + let x_yz = circuit.var("x") + + (circuit.var("y") + circuit.var("z")); assert_ne!(xy_z.id, x_yz.id); // Deep shared structure: (x+y)*z appears twice in ((x+y)*z) + ((x+y)*z) - let xyz1 = (circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y"))) - * circuit.leaf(StaticVar::from("z")); - let xyz2 = (circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y"))) - * circuit.leaf(StaticVar::from("z")); + let xyz1 = (circuit.var("x") + circuit.var("y")) + * circuit.var("z"); + let xyz2 = (circuit.var("x") + circuit.var("y")) + * circuit.var("z"); assert_eq!(xyz1.id, xyz2.id); let _sum = xyz1 + xyz2; // x, y, z, x+y(==y+x), (x+y)*z, (x+y)+z, y+z, x+(y+z), (x+y)*z+(x+y)*z, sq - assert_eq!(circuit.len(), 10); + assert_eq!(circuit.borrow().len(), 10); } diff --git a/src/circuit/traits.rs b/src/circuit/traits.rs index e7d6900..60f5e88 100644 --- a/src/circuit/traits.rs +++ b/src/circuit/traits.rs @@ -1,7 +1,9 @@ +use std::cell::RefCell; use std::hash::Hash; -use std::ops::DerefMut; +use std::ops::{Add, DerefMut, Mul}; +use std::rc::Rc; -use super::dag::{Dag, NodeId}; +use super::dag::{CircuitExt, Dag, NodeId, RefNode}; pub trait Node: Clone + PartialEq + Eq + Hash { fn children(&self) -> impl Iterator; @@ -24,3 +26,38 @@ where { type Node = N; } + +pub trait SumProdCircuit: Circuit { + type Var; + fn var(&mut self, v: Self::Var) -> NodeId; + fn add(&mut self, l: NodeId, r: NodeId) -> NodeId; + fn mul(&mut self, l: NodeId, r: NodeId) -> NodeId; +} + +impl Add for RefNode { + type Output = Self; + + fn add(self, rhs: Self) -> Self { + let id = self.circuit.borrow_mut().add(self.id, rhs.id); + RefNode { id, circuit: self.circuit } + } +} + +impl Mul for RefNode { + type Output = Self; + + fn mul(self, rhs: Self) -> Self { + let id = self.circuit.borrow_mut().mul(self.id, rhs.id); + RefNode { id, circuit: self.circuit } + } +} + +impl CircuitExt for Rc> { + type C = C; + type Var = C::Var; + + fn var(&self, v: impl Into) -> RefNode { + let id = self.borrow_mut().var(v.into()); + RefNode { id, circuit: self.clone() } + } +}