diff --git a/examples/dag.rs b/examples/dag.rs index a5d188c..4ef9222 100644 --- a/examples/dag.rs +++ b/examples/dag.rs @@ -1,11 +1,12 @@ use std::cell::RefCell; use std::rc::Rc; -use circuit_cas::circuit::dag::{Circuit, CircuitExt}; +use circuit_cas::circuit::traits::Circuit; +use circuit_cas::circuit::dag::{ProbCircuit, CircuitExt}; use circuit_cas::var; fn main() { - let circuit = Rc::new(RefCell::new(Circuit::new())); + let mut circuit = ProbCircuit::new(); // Build (x + y) * (x + z) let x = circuit.leaf(var!("x")); diff --git a/src/circuit/dag.rs b/src/circuit/dag.rs index dc2d4cf..ac5c335 100644 --- a/src/circuit/dag.rs +++ b/src/circuit/dag.rs @@ -1,57 +1,27 @@ use slotmap::{SlotMap, new_key_type}; -use std::cell::RefCell; use std::collections::HashMap; -use std::ops::{Add, Mul}; -use std::rc::Rc; -use crate::poly::var::Var; +use super::traits::{Circuit,Node}; new_key_type! { pub struct NodeId; } -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub enum Node { - Leaf(V), - Scale(NodeId, i32), - Sum(NodeId, NodeId), - Prod(NodeId, NodeId), -} - -impl Node { - pub fn children(&self) -> impl Iterator { - match self { - Node::Leaf(_) => [None, None], - Node::Scale(n, _) => [Some(*n), None], - Node::Sum(l, r) | Node::Prod(l, r) => [Some(*l), Some(*r)], - } - .into_iter() - .flatten() - } -} - #[derive(Clone, Debug)] -pub struct Circuit { - nodes: SlotMap>, - intern: HashMap, NodeId>, +pub struct Dag { + nodes: SlotMap, + intern: HashMap, } -impl Default for Circuit { +impl Default for Dag { fn default() -> Self { - Circuit { + Dag { nodes: Default::default(), intern: Default::default(), } } } -impl Circuit { - pub fn new() -> Self { - Circuit { - nodes: SlotMap::with_key(), - intern: HashMap::new(), - } - } - - pub fn node(&mut self, n: Node) -> NodeId { +impl Dag{ + fn node(&mut self, n: N) -> NodeId { if let Some(&id) = self.intern.get(&n) { return id; } @@ -60,90 +30,44 @@ impl Circuit { id } - pub fn get(&self, id: NodeId) -> Option<&Node> { + fn get(&self, id: NodeId) -> Option<&N> { self.nodes.get(id) } - pub fn add(&mut self, left: NodeId, right: NodeId) -> NodeId { - let (l, r) = if left <= right { - (left, right) - } else { - (right, left) - }; - self.node(Node::Sum(l, r)) - } - - pub fn mul(&mut self, left: NodeId, right: NodeId) -> NodeId { - let (l, r) = if left <= right { - (left, right) - } else { - (right, left) - }; - self.node(Node::Prod(l, r)) - } - - pub fn len(&self) -> usize { + fn len(&self) -> usize { self.nodes.len() } - pub fn children(&self, id: NodeId) -> impl Iterator + '_ { + fn children(&self, id: NodeId) -> impl Iterator + '_ { self.nodes.get(id).into_iter().flat_map(Node::children) } - pub fn remove(&mut self, id: NodeId) { + fn remove(&mut self, id: NodeId) { if let Some(node) = self.nodes.remove(id) { self.intern.remove(&node); } } } -pub struct CircuitNode { - pub id: NodeId, - circuit: Rc>>, -} -pub trait CircuitExt { - fn leaf(&self, v: V) -> CircuitNode; - fn get_node(&self, id: NodeId) -> Option>; +pub struct RefNode>{ + pub id:NodeId, + circuit: Rc> } -impl CircuitExt for Rc>> { - fn leaf(&self, v: V) -> CircuitNode { - let id = self.borrow_mut().node(Node::Leaf(v)); - CircuitNode { +impl RefNode{ + fn leaf(&mut self, variable: C::Var)->Self{ + let mut c = self.circuit.borrow_mut(); + let id = c.leaf(variable); + RefNode { id, - circuit: self.clone(), + circuit: self.circuit.clone(), } } - - fn get_node(&self, id: NodeId) -> Option> { - self.borrow().get(id)?; - Some(CircuitNode { + fn get_node(&self, id:NodeId)->Option{ + self.circuit.borrow().get(id)?; + Some(RefNode { id, - circuit: self.clone(), + circuit: self.circuit.clone(), }) } } - -impl Add for CircuitNode { - type Output = Self; - - fn add(self, rhs: Self) -> Self { - let id = self.circuit.borrow_mut().add(self.id, rhs.id); - CircuitNode { - id, - circuit: self.circuit, - } - } -} - -impl Mul for CircuitNode { - type Output = Self; - - fn mul(self, rhs: Self) -> Self { - let id = self.circuit.borrow_mut().mul(self.id, rhs.id); - CircuitNode { - id, - circuit: self.circuit, - } - } -} diff --git a/src/circuit/mod.rs b/src/circuit/mod.rs index e0e3241..efe1130 100644 --- a/src/circuit/mod.rs +++ b/src/circuit/mod.rs @@ -1,4 +1,6 @@ +pub mod traits; pub mod dag; +pub mod probabilistic; pub mod quotient; #[cfg(test)] diff --git a/src/circuit/probabilistic.rs b/src/circuit/probabilistic.rs new file mode 100644 index 0000000..b0e7652 --- /dev/null +++ b/src/circuit/probabilistic.rs @@ -0,0 +1,79 @@ +use std::ops::{Deref, DerefMut}; +use std::rc::Rc; + +use crate::poly::var::Var; +use super::dag::{Dag,NodeId}; +use super::traits::{Circuit,Node,RefNode}; + + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum PNode { + Leaf(V), + Sum(NodeId, NodeId), + Prod(NodeId, NodeId), +} + + +impl Node for PNode{ + fn children(&self) -> impl Iterator { + match self { + Self::Leaf(_) => [None, None], + Self::Sum(l, r) | Self::Prod(l, r) => [Some(*l), Some(*r)], + } + .into_iter() + .flatten() + } +} + +#[derive(Clone, Debug, Default)] +pub struct ProbCircuit { + dag: Dag>, +} + +impl Deref for ProbCircuit{ + type Target = Dag>; + fn deref(&self)->&Self::Target{ + &self.dag + } +} + +impl DerefMut for ProbCircuit{ + fn deref_mut(&mut self)->&mut Self::Target{ + &mut self.dag + } +} + + +impl Rc> { + pub fn var>(variable:T)->RefNode{ + todo!() + } +} + +impl Circuit for ProbCircuit{ + type Node=PNode; + type Var=V; + + fn leaf(&mut self, variable:V)->NodeId{ + self.node(Self::Node::Leaf(variable)) + } + + fn add(&mut self, left: NodeId, right: NodeId) -> NodeId { + let (l, r) = if left <= right { + (left, right) + } else { + (right, left) + }; + self.node(Self::Node::Sum(l, r)) + } + + fn mul(&mut self, left: NodeId, right: NodeId) -> NodeId { + let (l, r) = if left <= right { + (left, right) + } else { + (right, left) + }; + self.node(Self::Node::Prod(l, r)) + } +} + diff --git a/src/circuit/quotient.rs b/src/circuit/quotient.rs index f5c86c4..dcd2b56 100644 --- a/src/circuit/quotient.rs +++ b/src/circuit/quotient.rs @@ -1,52 +1,94 @@ -use super::dag::{Circuit, Node, NodeId}; -use crate::poly::{ - flat::Poly, - ideal::{Generators, GroebnerBasis, Ideal}, - var::Var, -}; +use std::ops::{Deref, DerefMut}; -use std::fmt::{self, Display}; +use crate::poly::var::Var; +use crate::poly::flat::Poly; +use crate::poly::ideal::{Generators, GroebnerBasis, Ideal}; +use super::dag::{Dag,NodeId}; +use super::traits::{Circuit,Node}; + + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum QNode { + Leaf(V), + Sum(NodeId, NodeId), + Prod(NodeId, NodeId), + DivStep(NodeId, NodeId) +} + + +impl Node for QNode{ + fn children(&self) -> impl Iterator { + match self { + Self::Leaf(_) => [None, None], + Self::Sum(l, r) | Self::Prod(l, r) | Self::DivStep(l,r) => [Some(*l), Some(*r)], + } + .into_iter() + .flatten() + } +} #[derive(Clone, Debug)] -pub struct Quotient { +pub struct QuotientCircuit { basis: Ideal, - circuit: Circuit, + dag: Dag>, } -impl From> for Quotient { +impl From> for QuotientCircuit { fn from(basis: Ideal) -> Self { - Quotient { + Self { basis, - circuit: Default::default(), + dag: Default::default(), } } } -impl FromIterator> for Quotient { +impl FromIterator> for QuotientCircuit { fn from_iter>>(iter: T) -> Self { let ideal: Ideal = iter.into_iter().collect(); - Quotient { + Self { basis: ideal.groebner_basis(), - circuit: Default::default(), + dag: Default::default(), } } } -impl Display for Quotient { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - write!(fmt, "C/{}", self.basis) +impl Deref for QuotientCircuit{ + type Target = Dag>; + fn deref(&self)->&Self::Target{ + &self.dag } } -impl Quotient { - pub fn node(&mut self, n: Node) -> NodeId { - self.circuit.node(n) - } - - pub fn add(&mut self, left: NodeId, right: NodeId) -> NodeId { - self.circuit.add(left, right) - } - pub fn mul(&mut self, left: NodeId, right: NodeId) -> NodeId { - self.circuit.mul(left, right) +impl DerefMut for QuotientCircuit{ + fn deref_mut(&mut self)->&mut Self::Target{ + &mut self.dag } } + +impl Circuit for QuotientCircuit{ + type Node=QNode; + type Var=V; + + fn leaf(&mut self, variable:V)->NodeId{ + self.node(Self::Node::Leaf(variable)) + } + + fn add(&mut self, left: NodeId, right: NodeId) -> NodeId { + let (l, r) = if left <= right { + (left, right) + } else { + (right, left) + }; + self.node(Self::Node::Sum(l, r)) + } + + fn mul(&mut self, left: NodeId, right: NodeId) -> NodeId { + let (l, r) = if left <= right { + (left, right) + } else { + (right, left) + }; + self.node(Self::Node::Prod(l, r)) + } +} + diff --git a/src/circuit/tests.rs b/src/circuit/tests.rs index 201dfa7..158614d 100644 --- a/src/circuit/tests.rs +++ b/src/circuit/tests.rs @@ -1,12 +1,13 @@ use std::cell::RefCell; use std::rc::Rc; -use super::dag::{Circuit, CircuitExt}; +use super::dag::{ProbCircuit, CircuitExt}; +use super::traits::Circuit; use crate::poly::var::StaticVar; #[test] fn test_deduplication() { - let circuit = Rc::new(RefCell::new(Circuit::new())); + let mut circuit = ProbCircuit::new(); // Same leaf constructed twice returns the same NodeId let x1 = circuit.leaf(StaticVar::from("x")); @@ -48,5 +49,5 @@ fn test_deduplication() { assert_eq!(xyz1.id, xyz2.id); let _sum = xyz1 + xyz2; // x, y, z, x+y(==y+x), (x+y)*z, (x+y)+z, y+z, x+(y+z), (x+y)*z+(x+y)*z, sq - assert_eq!(circuit.borrow().len(), 10); + assert_eq!(circuit.len(), 10); } diff --git a/src/circuit/traits.rs b/src/circuit/traits.rs new file mode 100644 index 0000000..75b7bf0 --- /dev/null +++ b/src/circuit/traits.rs @@ -0,0 +1,51 @@ +use std::hash::Hash; +use std::ops::{Add,Mul}; +use std::{rc::Rc,cell::RefCell}; + +use super::dag::NodeId; + +pub trait Node:Clone+PartialEq+Eq+Hash{ + fn children(&self) -> impl Iterator; +} + +pub trait Circuit:Clone{ + type Var; + type Node; + + fn node(&mut self, n:Self::Node)->NodeId; + fn remove(&mut self, id:NodeId); + fn get(&self, id:NodeId)->Option<&Self::Node>; + + fn leaf(&mut self, var: Self::Var)->NodeId; + fn add(&mut self, left: NodeId, right: NodeId)->NodeId; + fn mul(&mut self, left: NodeId, right: NodeId)->NodeId; + + fn len(&self)->usize; + fn children(&self, id:NodeId)->impl Iterator+'_; +} + + +impl Add for RefNode { + type Output = Self; + + fn add(self, rhs: Self) -> Self { + let id = self.circuit.borrow_mut().add(self.id, rhs.id); + RefNode { + id, + circuit: self.circuit, + } + } +} + +impl Mul for RefNode { + type Output = Self; + + fn mul(self, rhs: Self) -> Self { + let id = self.circuit.borrow_mut().mul(self.id, rhs.id); + RefNode { + id, + circuit: self.circuit, + } + } +} +