improve api and simplify SumProdCircuit creation

This commit is contained in:
2026-04-27 23:15:06 +02:00
parent 627a2d88f4
commit 3fbd7e3773
7 changed files with 132 additions and 132 deletions

View File

@@ -1,21 +1,22 @@
use std::cell::RefCell;
use std::rc::Rc;
use circuit_cas::circuit::probabilistic::ProbCircuit;
use circuit_cas::circuit::dag::CircuitExt;
use circuit_cas::var;
fn main() {
let circuit = ProbCircuit::new();
let circuit: Rc<RefCell<ProbCircuit>> = ProbCircuit::new();
// Build (x + y) * (x + z)
let x = circuit.leaf(var!("x"));
let y = circuit.leaf(var!("y"));
let z = circuit.leaf(var!("z"));
let x = circuit.var("x");
let y = circuit.var("y");
let z = circuit.var("z");
let x_plus_y = circuit.leaf(var!("x")) + y;
let x_plus_z = circuit.leaf(var!("x")) + z;
let x_plus_y = circuit.var("x") + y;
let x_plus_z = circuit.var("x") + z;
let expr = x_plus_y * x_plus_z;
// Deduplication: both x leaves share the same NodeId
let x2 = circuit.leaf(var!("x"));
let x2 = circuit.var("x");
assert_eq!(x.id, x2.id);
println!("(x + y) * (x + z) root node id: {:?}", expr.id);

View File

@@ -1,5 +1,8 @@
use circuit_cas::circuit::quotient::Quotient;
use circuit_cas::poly::var::StaticVar;
use std::cell::RefCell;
use std::rc::Rc;
use circuit_cas::circuit::dag::CircuitExt;
use circuit_cas::circuit::quotient::QuotientCircuit;
use circuit_cas::circuit::traits::Circuit;
use circuit_cas::var;
fn main() {
@@ -12,7 +15,15 @@ fn main() {
1 * ((&x ^ 1) * (&nx ^ 1)) - 1 * (&x ^ 1),
];
let quotient: Quotient<StaticVar> = idem.into_iter().collect();
let quotient: Rc<RefCell<QuotientCircuit>> = idem.into_iter().collect();
println!("{quotient:?}");
// Build x * x̄ + x in the DAG
let xn = quotient.var("x");
let nxn = quotient.var("x\u{0304}");
let prod = xn * nxn;
let xn2 = quotient.var("x");
let expr = prod + xn2;
println!("dag size: {}", quotient.borrow().len());
println!("expr node id: {:?}", expr.id);
}

View File

@@ -56,16 +56,9 @@ pub struct RefNode<C: Circuit> {
pub(super) circuit: Rc<RefCell<C>>,
}
impl<C: Circuit> RefNode<C> {
pub fn get_node(&self, id: NodeId) -> Option<Self> {
self.circuit.borrow().get(id)?;
Some(RefNode { id, circuit: self.circuit.clone() })
}
}
pub trait CircuitExt {
type C: Circuit;
type Var;
fn leaf(&self, var: Self::Var) -> RefNode<Self::C>;
fn len(&self) -> usize;
fn var(&self, v: impl Into<Self::Var>) -> RefNode<Self::C>;
}

View File

@@ -1,15 +1,15 @@
use std::cell::RefCell;
use std::ops::{Add, Deref, DerefMut, Mul};
use std::ops::{Deref, DerefMut};
use std::rc::Rc;
use crate::poly::var::Var;
use super::dag::{CircuitExt, Dag, NodeId, RefNode};
use super::traits::Node;
use crate::poly::var::{StaticVar, Var};
use super::dag::{Dag, NodeId};
use super::traits::{Node, SumProdCircuit};
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum PNode<V: Var> {
Leaf(V),
Var(V),
Sum(NodeId, NodeId),
Prod(NodeId, NodeId),
}
@@ -17,7 +17,7 @@ pub enum PNode<V: Var> {
impl<V: Var> Node for PNode<V> {
fn children(&self) -> impl Iterator<Item = NodeId> {
match self {
Self::Leaf(_) => [None, None],
Self::Var(_) => [None, None],
Self::Sum(l, r) | Self::Prod(l, r) => [Some(*l), Some(*r)],
}
.into_iter()
@@ -26,7 +26,7 @@ impl<V: Var> Node for PNode<V> {
}
#[derive(Clone, Debug)]
pub struct ProbCircuit<V: Var> {
pub struct ProbCircuit<V: Var = StaticVar> {
dag: Dag<PNode<V>>,
}
@@ -38,15 +38,19 @@ impl<V: Var> ProbCircuit<V> {
pub fn new() -> Rc<RefCell<Self>> {
Rc::new(RefCell::new(Self::default()))
}
}
pub fn leaf(&mut self, v: V) -> NodeId { self.node(PNode::Leaf(v)) }
impl<V: Var> SumProdCircuit for ProbCircuit<V> {
type Var = V;
pub fn add(&mut self, l: NodeId, r: NodeId) -> NodeId {
fn var(&mut self, v: V) -> NodeId { self.node(PNode::Var(v)) }
fn add(&mut self, l: NodeId, r: NodeId) -> NodeId {
let (l, r) = if l <= r { (l, r) } else { (r, l) };
self.node(PNode::Sum(l, r))
}
pub fn mul(&mut self, l: NodeId, r: NodeId) -> NodeId {
fn mul(&mut self, l: NodeId, r: NodeId) -> NodeId {
let (l, r) = if l <= r { (l, r) } else { (r, l) };
self.node(PNode::Prod(l, r))
}
@@ -61,32 +65,4 @@ impl<V: Var> DerefMut for ProbCircuit<V> {
fn deref_mut(&mut self) -> &mut Self::Target { &mut self.dag }
}
impl<V: Var> CircuitExt for Rc<RefCell<ProbCircuit<V>>> {
type C = ProbCircuit<V>;
type Var = V;
fn leaf(&self, var: V) -> RefNode<ProbCircuit<V>> {
let id = self.borrow_mut().leaf(var);
RefNode { id, circuit: self.clone() }
}
fn len(&self) -> usize { self.borrow().len() }
}
impl<V: Var> Add for RefNode<ProbCircuit<V>> {
type Output = Self;
fn add(self, rhs: Self) -> Self {
let id = self.circuit.borrow_mut().add(self.id, rhs.id);
RefNode { id, circuit: self.circuit }
}
}
impl<V: Var> Mul for RefNode<ProbCircuit<V>> {
type Output = Self;
fn mul(self, rhs: Self) -> Self {
let id = self.circuit.borrow_mut().mul(self.id, rhs.id);
RefNode { id, circuit: self.circuit }
}
}

View File

@@ -1,17 +1,17 @@
use std::cell::RefCell;
use std::ops::{Add, Deref, DerefMut, Mul};
use std::ops::{Deref, DerefMut};
use std::rc::Rc;
use crate::poly::var::Var;
use crate::poly::var::{StaticVar, Var};
use crate::poly::flat::Poly;
use crate::poly::ideal::{Generators, GroebnerBasis, Ideal};
use super::dag::{CircuitExt, Dag, NodeId, RefNode};
use super::traits::Node;
use super::dag::{Dag, NodeId};
use super::traits::{Node, SumProdCircuit};
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum QNode<V: Var> {
Leaf(V),
Var(V),
Sum(NodeId, NodeId),
Prod(NodeId, NodeId),
DivStep(NodeId, NodeId),
@@ -20,7 +20,7 @@ pub enum QNode<V: Var> {
impl<V: Var> Node for QNode<V> {
fn children(&self) -> impl Iterator<Item = NodeId> {
match self {
Self::Leaf(_) => [None, None],
Self::Var(_) => [None, None],
Self::Sum(l, r) | Self::Prod(l, r) | Self::DivStep(l, r) => [Some(*l), Some(*r)],
}
.into_iter()
@@ -29,25 +29,33 @@ impl<V: Var> Node for QNode<V> {
}
#[derive(Clone, Debug)]
pub struct QuotientCircuit<V: Var> {
pub struct QuotientCircuit<V: Var = StaticVar> {
basis: Ideal<V, GroebnerBasis>,
dag: Dag<QNode<V>>,
}
impl<V: Var> From<Ideal<V, GroebnerBasis>> for QuotientCircuit<V> {
fn from(basis: Ideal<V, GroebnerBasis>) -> Self {
Self { basis, dag: Default::default() }
}
impl<V: Var> QuotientCircuit<V> {
pub fn from_ideal(basis: Ideal<V, GroebnerBasis>) -> Rc<RefCell<Self>> {
Rc::new(RefCell::new(Self { basis, dag: Default::default() }))
}
impl<V: Var> FromIterator<Poly<V>> for QuotientCircuit<V> {
fn from_iter<T: IntoIterator<Item = Poly<V>>>(iter: T) -> Self {
pub fn from_polys(iter: impl IntoIterator<Item = Poly<V>>) -> Rc<RefCell<Self>> {
let ideal: Ideal<V, Generators> = iter.into_iter().collect();
Self { basis: ideal.groebner_basis(), dag: Default::default() }
Rc::new(RefCell::new(Self { basis: ideal.groebner_basis(), dag: Default::default() }))
}
}
pub type Quotient<V> = QuotientCircuit<V>;
impl<V: Var> From<Ideal<V, GroebnerBasis>> for Rc<RefCell<QuotientCircuit<V>>> {
fn from(basis: Ideal<V, GroebnerBasis>) -> Self {
QuotientCircuit::from_ideal(basis)
}
}
impl<V: Var> FromIterator<Poly<V>> for Rc<RefCell<QuotientCircuit<V>>> {
fn from_iter<T: IntoIterator<Item = Poly<V>>>(iter: T) -> Self {
QuotientCircuit::from_polys(iter)
}
}
impl<V: Var> Deref for QuotientCircuit<V> {
type Target = Dag<QNode<V>>;
@@ -58,46 +66,20 @@ impl<V: Var> DerefMut for QuotientCircuit<V> {
fn deref_mut(&mut self) -> &mut Self::Target { &mut self.dag }
}
impl<V: Var> QuotientCircuit<V> {
pub fn leaf(&mut self, v: V) -> NodeId { self.node(QNode::Leaf(v)) }
impl<V: Var> SumProdCircuit for QuotientCircuit<V> {
type Var = V;
pub fn add(&mut self, l: NodeId, r: NodeId) -> NodeId {
fn var(&mut self, v: V) -> NodeId { self.node(QNode::Var(v)) }
fn add(&mut self, l: NodeId, r: NodeId) -> NodeId {
let (l, r) = if l <= r { (l, r) } else { (r, l) };
self.node(QNode::Sum(l, r))
}
pub fn mul(&mut self, l: NodeId, r: NodeId) -> NodeId {
fn mul(&mut self, l: NodeId, r: NodeId) -> NodeId {
let (l, r) = if l <= r { (l, r) } else { (r, l) };
self.node(QNode::Prod(l, r))
}
}
impl<V: Var> CircuitExt for Rc<RefCell<QuotientCircuit<V>>> {
type C = QuotientCircuit<V>;
type Var = V;
fn leaf(&self, var: V) -> RefNode<QuotientCircuit<V>> {
let id = self.borrow_mut().leaf(var);
RefNode { id, circuit: self.clone() }
}
fn len(&self) -> usize { self.borrow().len() }
}
impl<V: Var> Add for RefNode<QuotientCircuit<V>> {
type Output = Self;
fn add(self, rhs: Self) -> Self {
let id = self.circuit.borrow_mut().add(self.id, rhs.id);
RefNode { id, circuit: self.circuit }
}
}
impl<V: Var> Mul for RefNode<QuotientCircuit<V>> {
type Output = Self;
fn mul(self, rhs: Self) -> Self {
let id = self.circuit.borrow_mut().mul(self.id, rhs.id);
RefNode { id, circuit: self.circuit }
}
}

View File

@@ -1,50 +1,50 @@
use std::cell::RefCell;
use std::rc::Rc;
use super::probabilistic::ProbCircuit;
use super::dag::CircuitExt;
use crate::poly::var::StaticVar;
#[test]
fn test_deduplication() {
let circuit = ProbCircuit::new();
let circuit: Rc<RefCell<ProbCircuit>> = ProbCircuit::new();
// Same leaf constructed twice returns the same NodeId
let x1 = circuit.leaf(StaticVar::from("x"));
let x2 = circuit.leaf(StaticVar::from("x"));
let x1 = circuit.var("x");
let x2 = circuit.var("x");
assert_eq!(x1.id, x2.id);
assert_eq!(circuit.len(), 1);
assert_eq!(circuit.borrow().len(), 1);
// Same sum constructed twice returns the same NodeId
let _y = circuit.leaf(StaticVar::from("y"));
let sum1 = circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y"));
let sum2 = circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y"));
let _y = circuit.var("y");
let sum1 = circuit.var("x") + circuit.var("y");
let sum2 = circuit.var("x") + circuit.var("y");
assert_eq!(sum1.id, sum2.id);
assert_eq!(circuit.len(), 3); // x, y, x+y
assert_eq!(circuit.borrow().len(), 3); // x, y, x+y
// Shared subexpression: (x + y) * (x + y) reuses the x+y node
let xy = circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y"));
let xy2 = circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y"));
let xy = circuit.var("x") + circuit.var("y");
let xy2 = circuit.var("x") + circuit.var("y");
let _sq = xy * xy2;
assert_eq!(circuit.len(), 4); // x, y, x+y, (x+y)*(x+y)
assert_eq!(circuit.borrow().len(), 4); // x, y, x+y, (x+y)*(x+y)
// Commutativity: x+y and y+x are the same node
let xy = circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y"));
let yx = circuit.leaf(StaticVar::from("y")) + circuit.leaf(StaticVar::from("x"));
let xy = circuit.var("x") + circuit.var("y");
let yx = circuit.var("y") + circuit.var("x");
assert_eq!(xy.id, yx.id);
// Associativity: (x+y)+z and x+(y+z) are distinct nodes
let _z = circuit.leaf(StaticVar::from("z"));
let xy_z = (circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y")))
+ circuit.leaf(StaticVar::from("z"));
let x_yz = circuit.leaf(StaticVar::from("x"))
+ (circuit.leaf(StaticVar::from("y")) + circuit.leaf(StaticVar::from("z")));
let _z = circuit.var("z");
let xy_z = (circuit.var("x") + circuit.var("y"))
+ circuit.var("z");
let x_yz = circuit.var("x")
+ (circuit.var("y") + circuit.var("z"));
assert_ne!(xy_z.id, x_yz.id);
// Deep shared structure: (x+y)*z appears twice in ((x+y)*z) + ((x+y)*z)
let xyz1 = (circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y")))
* circuit.leaf(StaticVar::from("z"));
let xyz2 = (circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y")))
* circuit.leaf(StaticVar::from("z"));
let xyz1 = (circuit.var("x") + circuit.var("y"))
* circuit.var("z");
let xyz2 = (circuit.var("x") + circuit.var("y"))
* circuit.var("z");
assert_eq!(xyz1.id, xyz2.id);
let _sum = xyz1 + xyz2;
// x, y, z, x+y(==y+x), (x+y)*z, (x+y)+z, y+z, x+(y+z), (x+y)*z+(x+y)*z, sq
assert_eq!(circuit.len(), 10);
assert_eq!(circuit.borrow().len(), 10);
}

View File

@@ -1,7 +1,9 @@
use std::cell::RefCell;
use std::hash::Hash;
use std::ops::DerefMut;
use std::ops::{Add, DerefMut, Mul};
use std::rc::Rc;
use super::dag::{Dag, NodeId};
use super::dag::{CircuitExt, Dag, NodeId, RefNode};
pub trait Node: Clone + PartialEq + Eq + Hash {
fn children(&self) -> impl Iterator<Item = NodeId>;
@@ -24,3 +26,38 @@ where
{
type Node = N;
}
pub trait SumProdCircuit: Circuit {
type Var;
fn var(&mut self, v: Self::Var) -> NodeId;
fn add(&mut self, l: NodeId, r: NodeId) -> NodeId;
fn mul(&mut self, l: NodeId, r: NodeId) -> NodeId;
}
impl<C: SumProdCircuit> Add for RefNode<C> {
type Output = Self;
fn add(self, rhs: Self) -> Self {
let id = self.circuit.borrow_mut().add(self.id, rhs.id);
RefNode { id, circuit: self.circuit }
}
}
impl<C: SumProdCircuit> Mul for RefNode<C> {
type Output = Self;
fn mul(self, rhs: Self) -> Self {
let id = self.circuit.borrow_mut().mul(self.id, rhs.id);
RefNode { id, circuit: self.circuit }
}
}
impl<C: SumProdCircuit> CircuitExt for Rc<RefCell<C>> {
type C = C;
type Var = C::Var;
fn var(&self, v: impl Into<C::Var>) -> RefNode<C> {
let id = self.borrow_mut().var(v.into());
RefNode { id, circuit: self.clone() }
}
}