From 627a2d88f4924cf450e7fcc3a6e5973dee3c5c03 Mon Sep 17 00:00:00 2001 From: asteri Date: Mon, 27 Apr 2026 22:40:55 +0200 Subject: [PATCH] refactor wip --- examples/dag.rs | 9 +-- examples/quotient.rs | 7 +-- src/circuit/dag.rs | 46 +++++++-------- src/circuit/probabilistic.rs | 107 ++++++++++++++++++++--------------- src/circuit/quotient.rs | 99 +++++++++++++++++--------------- src/circuit/tests.rs | 15 ++--- src/circuit/traits.rs | 57 ++++++------------- 7 files changed, 163 insertions(+), 177 deletions(-) diff --git a/examples/dag.rs b/examples/dag.rs index 4ef9222..758c41e 100644 --- a/examples/dag.rs +++ b/examples/dag.rs @@ -1,12 +1,9 @@ -use std::cell::RefCell; -use std::rc::Rc; - -use circuit_cas::circuit::traits::Circuit; -use circuit_cas::circuit::dag::{ProbCircuit, CircuitExt}; +use circuit_cas::circuit::probabilistic::ProbCircuit; +use circuit_cas::circuit::dag::CircuitExt; use circuit_cas::var; fn main() { - let mut circuit = ProbCircuit::new(); + let circuit = ProbCircuit::new(); // Build (x + y) * (x + z) let x = circuit.leaf(var!("x")); diff --git a/examples/quotient.rs b/examples/quotient.rs index 67ed28d..8371884 100644 --- a/examples/quotient.rs +++ b/examples/quotient.rs @@ -1,6 +1,3 @@ -use std::cell::RefCell; -use std::rc::Rc; - use circuit_cas::circuit::quotient::Quotient; use circuit_cas::poly::var::StaticVar; use circuit_cas::var; @@ -15,7 +12,7 @@ fn main() { 1 * ((&x ^ 1) * (&nx ^ 1)) - 1 * (&x ^ 1), ]; - let mut quotient: Quotient = idem.into_iter().collect(); + let quotient: Quotient = idem.into_iter().collect(); - println!("{quotient}"); + println!("{quotient:?}"); } diff --git a/src/circuit/dag.rs b/src/circuit/dag.rs index ac5c335..50e21ef 100644 --- a/src/circuit/dag.rs +++ b/src/circuit/dag.rs @@ -1,7 +1,9 @@ use slotmap::{SlotMap, new_key_type}; +use std::cell::RefCell; use std::collections::HashMap; +use std::rc::Rc; -use super::traits::{Circuit,Node}; +use super::traits::{Circuit, Node}; new_key_type! { pub struct NodeId; } @@ -20,8 +22,8 @@ impl Default for Dag { } } -impl Dag{ - fn node(&mut self, n: N) -> NodeId { +impl Dag { + pub(super) fn node(&mut self, n: N) -> NodeId { if let Some(&id) = self.intern.get(&n) { return id; } @@ -30,44 +32,40 @@ impl Dag{ id } - fn get(&self, id: NodeId) -> Option<&N> { + pub(super) fn get(&self, id: NodeId) -> Option<&N> { self.nodes.get(id) } - fn len(&self) -> usize { + pub(super) fn len(&self) -> usize { self.nodes.len() } - fn children(&self, id: NodeId) -> impl Iterator + '_ { + pub(super) fn children(&self, id: NodeId) -> impl Iterator + '_ { self.nodes.get(id).into_iter().flat_map(Node::children) } - fn remove(&mut self, id: NodeId) { + pub(super) fn remove(&mut self, id: NodeId) { if let Some(node) = self.nodes.remove(id) { self.intern.remove(&node); } } } -pub struct RefNode>{ - pub id:NodeId, - circuit: Rc> +pub struct RefNode { + pub id: NodeId, + pub(super) circuit: Rc>, } -impl RefNode{ - fn leaf(&mut self, variable: C::Var)->Self{ - let mut c = self.circuit.borrow_mut(); - let id = c.leaf(variable); - RefNode { - id, - circuit: self.circuit.clone(), - } - } - fn get_node(&self, id:NodeId)->Option{ +impl RefNode { + pub fn get_node(&self, id: NodeId) -> Option { self.circuit.borrow().get(id)?; - Some(RefNode { - id, - circuit: self.circuit.clone(), - }) + Some(RefNode { id, circuit: self.circuit.clone() }) } } + +pub trait CircuitExt { + type C: Circuit; + type Var; + fn leaf(&self, var: Self::Var) -> RefNode; + fn len(&self) -> usize; +} diff --git a/src/circuit/probabilistic.rs b/src/circuit/probabilistic.rs index b0e7652..51ca2db 100644 --- a/src/circuit/probabilistic.rs +++ b/src/circuit/probabilistic.rs @@ -1,9 +1,10 @@ -use std::ops::{Deref, DerefMut}; +use std::cell::RefCell; +use std::ops::{Add, Deref, DerefMut, Mul}; use std::rc::Rc; use crate::poly::var::Var; -use super::dag::{Dag,NodeId}; -use super::traits::{Circuit,Node,RefNode}; +use super::dag::{CircuitExt, Dag, NodeId, RefNode}; +use super::traits::Node; #[derive(Clone, Debug, PartialEq, Eq, Hash)] @@ -13,8 +14,7 @@ pub enum PNode { Prod(NodeId, NodeId), } - -impl Node for PNode{ +impl Node for PNode { fn children(&self) -> impl Iterator { match self { Self::Leaf(_) => [None, None], @@ -25,55 +25,68 @@ impl Node for PNode{ } } -#[derive(Clone, Debug, Default)] +#[derive(Clone, Debug)] pub struct ProbCircuit { dag: Dag>, } -impl Deref for ProbCircuit{ +impl Default for ProbCircuit { + fn default() -> Self { ProbCircuit { dag: Dag::default() } } +} + +impl ProbCircuit { + pub fn new() -> Rc> { + Rc::new(RefCell::new(Self::default())) + } + + pub fn leaf(&mut self, v: V) -> NodeId { self.node(PNode::Leaf(v)) } + + pub fn add(&mut self, l: NodeId, r: NodeId) -> NodeId { + let (l, r) = if l <= r { (l, r) } else { (r, l) }; + self.node(PNode::Sum(l, r)) + } + + pub fn mul(&mut self, l: NodeId, r: NodeId) -> NodeId { + let (l, r) = if l <= r { (l, r) } else { (r, l) }; + self.node(PNode::Prod(l, r)) + } +} + +impl Deref for ProbCircuit { type Target = Dag>; - fn deref(&self)->&Self::Target{ - &self.dag + fn deref(&self) -> &Self::Target { &self.dag } +} + +impl DerefMut for ProbCircuit { + fn deref_mut(&mut self) -> &mut Self::Target { &mut self.dag } +} + +impl CircuitExt for Rc>> { + type C = ProbCircuit; + type Var = V; + + fn leaf(&self, var: V) -> RefNode> { + let id = self.borrow_mut().leaf(var); + RefNode { id, circuit: self.clone() } + } + + fn len(&self) -> usize { self.borrow().len() } +} + +impl Add for RefNode> { + type Output = Self; + + fn add(self, rhs: Self) -> Self { + let id = self.circuit.borrow_mut().add(self.id, rhs.id); + RefNode { id, circuit: self.circuit } } } -impl DerefMut for ProbCircuit{ - fn deref_mut(&mut self)->&mut Self::Target{ - &mut self.dag +impl Mul for RefNode> { + type Output = Self; + + fn mul(self, rhs: Self) -> Self { + let id = self.circuit.borrow_mut().mul(self.id, rhs.id); + RefNode { id, circuit: self.circuit } } } - - -impl Rc> { - pub fn var>(variable:T)->RefNode{ - todo!() - } -} - -impl Circuit for ProbCircuit{ - type Node=PNode; - type Var=V; - - fn leaf(&mut self, variable:V)->NodeId{ - self.node(Self::Node::Leaf(variable)) - } - - fn add(&mut self, left: NodeId, right: NodeId) -> NodeId { - let (l, r) = if left <= right { - (left, right) - } else { - (right, left) - }; - self.node(Self::Node::Sum(l, r)) - } - - fn mul(&mut self, left: NodeId, right: NodeId) -> NodeId { - let (l, r) = if left <= right { - (left, right) - } else { - (right, left) - }; - self.node(Self::Node::Prod(l, r)) - } -} - diff --git a/src/circuit/quotient.rs b/src/circuit/quotient.rs index dcd2b56..799c49f 100644 --- a/src/circuit/quotient.rs +++ b/src/circuit/quotient.rs @@ -1,10 +1,12 @@ -use std::ops::{Deref, DerefMut}; +use std::cell::RefCell; +use std::ops::{Add, Deref, DerefMut, Mul}; +use std::rc::Rc; use crate::poly::var::Var; use crate::poly::flat::Poly; use crate::poly::ideal::{Generators, GroebnerBasis, Ideal}; -use super::dag::{Dag,NodeId}; -use super::traits::{Circuit,Node}; +use super::dag::{CircuitExt, Dag, NodeId, RefNode}; +use super::traits::Node; #[derive(Clone, Debug, PartialEq, Eq, Hash)] @@ -12,15 +14,14 @@ pub enum QNode { Leaf(V), Sum(NodeId, NodeId), Prod(NodeId, NodeId), - DivStep(NodeId, NodeId) + DivStep(NodeId, NodeId), } - -impl Node for QNode{ +impl Node for QNode { fn children(&self) -> impl Iterator { match self { Self::Leaf(_) => [None, None], - Self::Sum(l, r) | Self::Prod(l, r) | Self::DivStep(l,r) => [Some(*l), Some(*r)], + Self::Sum(l, r) | Self::Prod(l, r) | Self::DivStep(l, r) => [Some(*l), Some(*r)], } .into_iter() .flatten() @@ -35,60 +36,68 @@ pub struct QuotientCircuit { impl From> for QuotientCircuit { fn from(basis: Ideal) -> Self { - Self { - basis, - dag: Default::default(), - } + Self { basis, dag: Default::default() } } } impl FromIterator> for QuotientCircuit { fn from_iter>>(iter: T) -> Self { let ideal: Ideal = iter.into_iter().collect(); - Self { - basis: ideal.groebner_basis(), - dag: Default::default(), - } + Self { basis: ideal.groebner_basis(), dag: Default::default() } } } -impl Deref for QuotientCircuit{ +pub type Quotient = QuotientCircuit; + +impl Deref for QuotientCircuit { type Target = Dag>; - fn deref(&self)->&Self::Target{ - &self.dag + fn deref(&self) -> &Self::Target { &self.dag } +} + +impl DerefMut for QuotientCircuit { + fn deref_mut(&mut self) -> &mut Self::Target { &mut self.dag } +} + +impl QuotientCircuit { + pub fn leaf(&mut self, v: V) -> NodeId { self.node(QNode::Leaf(v)) } + + pub fn add(&mut self, l: NodeId, r: NodeId) -> NodeId { + let (l, r) = if l <= r { (l, r) } else { (r, l) }; + self.node(QNode::Sum(l, r)) + } + + pub fn mul(&mut self, l: NodeId, r: NodeId) -> NodeId { + let (l, r) = if l <= r { (l, r) } else { (r, l) }; + self.node(QNode::Prod(l, r)) } } -impl DerefMut for QuotientCircuit{ - fn deref_mut(&mut self)->&mut Self::Target{ - &mut self.dag +impl CircuitExt for Rc>> { + type C = QuotientCircuit; + type Var = V; + + fn leaf(&self, var: V) -> RefNode> { + let id = self.borrow_mut().leaf(var); + RefNode { id, circuit: self.clone() } + } + + fn len(&self) -> usize { self.borrow().len() } +} + +impl Add for RefNode> { + type Output = Self; + + fn add(self, rhs: Self) -> Self { + let id = self.circuit.borrow_mut().add(self.id, rhs.id); + RefNode { id, circuit: self.circuit } } } -impl Circuit for QuotientCircuit{ - type Node=QNode; - type Var=V; +impl Mul for RefNode> { + type Output = Self; - fn leaf(&mut self, variable:V)->NodeId{ - self.node(Self::Node::Leaf(variable)) - } - - fn add(&mut self, left: NodeId, right: NodeId) -> NodeId { - let (l, r) = if left <= right { - (left, right) - } else { - (right, left) - }; - self.node(Self::Node::Sum(l, r)) - } - - fn mul(&mut self, left: NodeId, right: NodeId) -> NodeId { - let (l, r) = if left <= right { - (left, right) - } else { - (right, left) - }; - self.node(Self::Node::Prod(l, r)) + fn mul(self, rhs: Self) -> Self { + let id = self.circuit.borrow_mut().mul(self.id, rhs.id); + RefNode { id, circuit: self.circuit } } } - diff --git a/src/circuit/tests.rs b/src/circuit/tests.rs index 158614d..fd86cbb 100644 --- a/src/circuit/tests.rs +++ b/src/circuit/tests.rs @@ -1,32 +1,29 @@ -use std::cell::RefCell; -use std::rc::Rc; - -use super::dag::{ProbCircuit, CircuitExt}; -use super::traits::Circuit; +use super::probabilistic::ProbCircuit; +use super::dag::CircuitExt; use crate::poly::var::StaticVar; #[test] fn test_deduplication() { - let mut circuit = ProbCircuit::new(); + let circuit = ProbCircuit::new(); // Same leaf constructed twice returns the same NodeId let x1 = circuit.leaf(StaticVar::from("x")); let x2 = circuit.leaf(StaticVar::from("x")); assert_eq!(x1.id, x2.id); - assert_eq!(circuit.borrow().len(), 1); + assert_eq!(circuit.len(), 1); // Same sum constructed twice returns the same NodeId let _y = circuit.leaf(StaticVar::from("y")); let sum1 = circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y")); let sum2 = circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y")); assert_eq!(sum1.id, sum2.id); - assert_eq!(circuit.borrow().len(), 3); // x, y, x+y + assert_eq!(circuit.len(), 3); // x, y, x+y // Shared subexpression: (x + y) * (x + y) reuses the x+y node let xy = circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y")); let xy2 = circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y")); let _sq = xy * xy2; - assert_eq!(circuit.borrow().len(), 4); // x, y, x+y, (x+y)*(x+y) + assert_eq!(circuit.len(), 4); // x, y, x+y, (x+y)*(x+y) // Commutativity: x+y and y+x are the same node let xy = circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y")); diff --git a/src/circuit/traits.rs b/src/circuit/traits.rs index 75b7bf0..e7d6900 100644 --- a/src/circuit/traits.rs +++ b/src/circuit/traits.rs @@ -1,51 +1,26 @@ use std::hash::Hash; -use std::ops::{Add,Mul}; -use std::{rc::Rc,cell::RefCell}; +use std::ops::DerefMut; -use super::dag::NodeId; +use super::dag::{Dag, NodeId}; -pub trait Node:Clone+PartialEq+Eq+Hash{ +pub trait Node: Clone + PartialEq + Eq + Hash { fn children(&self) -> impl Iterator; } -pub trait Circuit:Clone{ - type Var; - type Node; - - fn node(&mut self, n:Self::Node)->NodeId; - fn remove(&mut self, id:NodeId); - fn get(&self, id:NodeId)->Option<&Self::Node>; +pub trait Circuit: Clone + DerefMut> { + type Node: Node; - fn leaf(&mut self, var: Self::Var)->NodeId; - fn add(&mut self, left: NodeId, right: NodeId)->NodeId; - fn mul(&mut self, left: NodeId, right: NodeId)->NodeId; - - fn len(&self)->usize; - fn children(&self, id:NodeId)->impl Iterator+'_; + fn node(&mut self, n: Self::Node) -> NodeId { (**self).node(n) } + fn remove(&mut self, id: NodeId) { (**self).remove(id) } + fn get(&self, id: NodeId) -> Option<&Self::Node> { (**self).get(id) } + fn len(&self) -> usize { (**self).len() } + fn children(&self, id: NodeId) -> impl Iterator + '_ { (**self).children(id) } } - -impl Add for RefNode { - type Output = Self; - - fn add(self, rhs: Self) -> Self { - let id = self.circuit.borrow_mut().add(self.id, rhs.id); - RefNode { - id, - circuit: self.circuit, - } - } +impl Circuit for T +where + T: Clone + DerefMut>, + N: Node, +{ + type Node = N; } - -impl Mul for RefNode { - type Output = Self; - - fn mul(self, rhs: Self) -> Self { - let id = self.circuit.borrow_mut().mul(self.id, rhs.id); - RefNode { - id, - circuit: self.circuit, - } - } -} -