wip refactor

This commit is contained in:
2026-04-27 22:05:46 +02:00
parent 3a83340c8f
commit 723033f3aa
7 changed files with 234 additions and 134 deletions

View File

@@ -1,11 +1,12 @@
use std::cell::RefCell; use std::cell::RefCell;
use std::rc::Rc; use std::rc::Rc;
use circuit_cas::circuit::dag::{Circuit, CircuitExt}; use circuit_cas::circuit::traits::Circuit;
use circuit_cas::circuit::dag::{ProbCircuit, CircuitExt};
use circuit_cas::var; use circuit_cas::var;
fn main() { fn main() {
let circuit = Rc::new(RefCell::new(Circuit::new())); let mut circuit = ProbCircuit::new();
// Build (x + y) * (x + z) // Build (x + y) * (x + z)
let x = circuit.leaf(var!("x")); let x = circuit.leaf(var!("x"));

View File

@@ -1,57 +1,27 @@
use slotmap::{SlotMap, new_key_type}; use slotmap::{SlotMap, new_key_type};
use std::cell::RefCell;
use std::collections::HashMap; use std::collections::HashMap;
use std::ops::{Add, Mul};
use std::rc::Rc;
use crate::poly::var::Var; use super::traits::{Circuit,Node};
new_key_type! { pub struct NodeId; } new_key_type! { pub struct NodeId; }
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum Node<V: Var> {
Leaf(V),
Scale(NodeId, i32),
Sum(NodeId, NodeId),
Prod(NodeId, NodeId),
}
impl<V: Var> Node<V> {
pub fn children(&self) -> impl Iterator<Item = NodeId> {
match self {
Node::Leaf(_) => [None, None],
Node::Scale(n, _) => [Some(*n), None],
Node::Sum(l, r) | Node::Prod(l, r) => [Some(*l), Some(*r)],
}
.into_iter()
.flatten()
}
}
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Circuit<V: Var> { pub struct Dag<N: Node> {
nodes: SlotMap<NodeId, Node<V>>, nodes: SlotMap<NodeId, N>,
intern: HashMap<Node<V>, NodeId>, intern: HashMap<N, NodeId>,
} }
impl<V: Var> Default for Circuit<V> { impl<N: Node> Default for Dag<N> {
fn default() -> Self { fn default() -> Self {
Circuit { Dag {
nodes: Default::default(), nodes: Default::default(),
intern: Default::default(), intern: Default::default(),
} }
} }
} }
impl<V: Var> Circuit<V> { impl<N:Node> Dag<N>{
pub fn new() -> Self { fn node(&mut self, n: N) -> NodeId {
Circuit {
nodes: SlotMap::with_key(),
intern: HashMap::new(),
}
}
pub fn node(&mut self, n: Node<V>) -> NodeId {
if let Some(&id) = self.intern.get(&n) { if let Some(&id) = self.intern.get(&n) {
return id; return id;
} }
@@ -60,90 +30,44 @@ impl<V: Var> Circuit<V> {
id id
} }
pub fn get(&self, id: NodeId) -> Option<&Node<V>> { fn get(&self, id: NodeId) -> Option<&N> {
self.nodes.get(id) self.nodes.get(id)
} }
pub fn add(&mut self, left: NodeId, right: NodeId) -> NodeId { fn len(&self) -> usize {
let (l, r) = if left <= right {
(left, right)
} else {
(right, left)
};
self.node(Node::Sum(l, r))
}
pub fn mul(&mut self, left: NodeId, right: NodeId) -> NodeId {
let (l, r) = if left <= right {
(left, right)
} else {
(right, left)
};
self.node(Node::Prod(l, r))
}
pub fn len(&self) -> usize {
self.nodes.len() self.nodes.len()
} }
pub fn children(&self, id: NodeId) -> impl Iterator<Item = NodeId> + '_ { fn children(&self, id: NodeId) -> impl Iterator<Item = NodeId> + '_ {
self.nodes.get(id).into_iter().flat_map(Node::children) self.nodes.get(id).into_iter().flat_map(Node::children)
} }
pub fn remove(&mut self, id: NodeId) { fn remove(&mut self, id: NodeId) {
if let Some(node) = self.nodes.remove(id) { if let Some(node) = self.nodes.remove(id) {
self.intern.remove(&node); self.intern.remove(&node);
} }
} }
} }
pub struct CircuitNode<V: Var> { pub struct RefNode<T:DerefMut<Target=Dag>>{
pub id: NodeId, pub id:NodeId,
circuit: Rc<RefCell<Circuit<V>>>, circuit: Rc<RefCell<T>>
}
pub trait CircuitExt<V: Var> {
fn leaf(&self, v: V) -> CircuitNode<V>;
fn get_node(&self, id: NodeId) -> Option<CircuitNode<V>>;
} }
impl<V: Var> CircuitExt<V> for Rc<RefCell<Circuit<V>>> { impl<C:Circuit> RefNode<C>{
fn leaf(&self, v: V) -> CircuitNode<V> { fn leaf(&mut self, variable: C::Var)->Self{
let id = self.borrow_mut().node(Node::Leaf(v)); let mut c = self.circuit.borrow_mut();
CircuitNode { let id = c.leaf(variable);
RefNode {
id, id,
circuit: self.clone(), circuit: self.circuit.clone(),
} }
} }
fn get_node(&self, id:NodeId)->Option<Self>{
fn get_node(&self, id: NodeId) -> Option<CircuitNode<V>> { self.circuit.borrow().get(id)?;
self.borrow().get(id)?; Some(RefNode {
Some(CircuitNode {
id, id,
circuit: self.clone(), circuit: self.circuit.clone(),
}) })
} }
} }
impl<V: Var> Add for CircuitNode<V> {
type Output = Self;
fn add(self, rhs: Self) -> Self {
let id = self.circuit.borrow_mut().add(self.id, rhs.id);
CircuitNode {
id,
circuit: self.circuit,
}
}
}
impl<V: Var> Mul for CircuitNode<V> {
type Output = Self;
fn mul(self, rhs: Self) -> Self {
let id = self.circuit.borrow_mut().mul(self.id, rhs.id);
CircuitNode {
id,
circuit: self.circuit,
}
}
}

View File

@@ -1,4 +1,6 @@
pub mod traits;
pub mod dag; pub mod dag;
pub mod probabilistic;
pub mod quotient; pub mod quotient;
#[cfg(test)] #[cfg(test)]

View File

@@ -0,0 +1,79 @@
use std::ops::{Deref, DerefMut};
use std::rc::Rc;
use crate::poly::var::Var;
use super::dag::{Dag,NodeId};
use super::traits::{Circuit,Node,RefNode};
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum PNode<V: Var> {
Leaf(V),
Sum(NodeId, NodeId),
Prod(NodeId, NodeId),
}
impl<V:Var> Node for PNode<V>{
fn children(&self) -> impl Iterator<Item = NodeId> {
match self {
Self::Leaf(_) => [None, None],
Self::Sum(l, r) | Self::Prod(l, r) => [Some(*l), Some(*r)],
}
.into_iter()
.flatten()
}
}
#[derive(Clone, Debug, Default)]
pub struct ProbCircuit<V: Var> {
dag: Dag<PNode<V>>,
}
impl<V:Var> Deref for ProbCircuit<V>{
type Target = Dag<PNode<V>>;
fn deref(&self)->&Self::Target{
&self.dag
}
}
impl<V:Var> DerefMut for ProbCircuit<V>{
fn deref_mut(&mut self)->&mut Self::Target{
&mut self.dag
}
}
impl<V: Var> Rc<ProbCircuit<V>> {
pub fn var<T:Into<V>>(variable:T)->RefNode<Self>{
todo!()
}
}
impl<V:Var> Circuit for ProbCircuit<V>{
type Node=PNode<V>;
type Var=V;
fn leaf(&mut self, variable:V)->NodeId{
self.node(Self::Node::Leaf(variable))
}
fn add(&mut self, left: NodeId, right: NodeId) -> NodeId {
let (l, r) = if left <= right {
(left, right)
} else {
(right, left)
};
self.node(Self::Node::Sum(l, r))
}
fn mul(&mut self, left: NodeId, right: NodeId) -> NodeId {
let (l, r) = if left <= right {
(left, right)
} else {
(right, left)
};
self.node(Self::Node::Prod(l, r))
}
}

View File

@@ -1,52 +1,94 @@
use super::dag::{Circuit, Node, NodeId}; use std::ops::{Deref, DerefMut};
use crate::poly::{
flat::Poly,
ideal::{Generators, GroebnerBasis, Ideal},
var::Var,
};
use std::fmt::{self, Display}; use crate::poly::var::Var;
use crate::poly::flat::Poly;
use crate::poly::ideal::{Generators, GroebnerBasis, Ideal};
use super::dag::{Dag,NodeId};
use super::traits::{Circuit,Node};
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum QNode<V: Var> {
Leaf(V),
Sum(NodeId, NodeId),
Prod(NodeId, NodeId),
DivStep(NodeId, NodeId)
}
impl<V:Var> Node for QNode<V>{
fn children(&self) -> impl Iterator<Item = NodeId> {
match self {
Self::Leaf(_) => [None, None],
Self::Sum(l, r) | Self::Prod(l, r) | Self::DivStep(l,r) => [Some(*l), Some(*r)],
}
.into_iter()
.flatten()
}
}
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Quotient<V: Var> { pub struct QuotientCircuit<V: Var> {
basis: Ideal<V, GroebnerBasis>, basis: Ideal<V, GroebnerBasis>,
circuit: Circuit<V>, dag: Dag<QNode<V>>,
} }
impl<V: Var> From<Ideal<V, GroebnerBasis>> for Quotient<V> { impl<V: Var> From<Ideal<V, GroebnerBasis>> for QuotientCircuit<V> {
fn from(basis: Ideal<V, GroebnerBasis>) -> Self { fn from(basis: Ideal<V, GroebnerBasis>) -> Self {
Quotient { Self {
basis, basis,
circuit: Default::default(), dag: Default::default(),
} }
} }
} }
impl<V: Var> FromIterator<Poly<V>> for Quotient<V> { impl<V: Var> FromIterator<Poly<V>> for QuotientCircuit<V> {
fn from_iter<T: IntoIterator<Item = Poly<V>>>(iter: T) -> Self { fn from_iter<T: IntoIterator<Item = Poly<V>>>(iter: T) -> Self {
let ideal: Ideal<V, Generators> = iter.into_iter().collect(); let ideal: Ideal<V, Generators> = iter.into_iter().collect();
Quotient { Self {
basis: ideal.groebner_basis(), basis: ideal.groebner_basis(),
circuit: Default::default(), dag: Default::default(),
} }
} }
} }
impl<V: Var> Display for Quotient<V> { impl<V:Var> Deref for QuotientCircuit<V>{
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { type Target = Dag<QNode<V>>;
write!(fmt, "C/{}", self.basis) fn deref(&self)->&Self::Target{
&self.dag
} }
} }
impl<V: Var> Quotient<V> { impl<V:Var> DerefMut for QuotientCircuit<V>{
pub fn node(&mut self, n: Node<V>) -> NodeId { fn deref_mut(&mut self)->&mut Self::Target{
self.circuit.node(n) &mut self.dag
}
pub fn add(&mut self, left: NodeId, right: NodeId) -> NodeId {
self.circuit.add(left, right)
}
pub fn mul(&mut self, left: NodeId, right: NodeId) -> NodeId {
self.circuit.mul(left, right)
} }
} }
impl<V:Var> Circuit for QuotientCircuit<V>{
type Node=QNode<V>;
type Var=V;
fn leaf(&mut self, variable:V)->NodeId{
self.node(Self::Node::Leaf(variable))
}
fn add(&mut self, left: NodeId, right: NodeId) -> NodeId {
let (l, r) = if left <= right {
(left, right)
} else {
(right, left)
};
self.node(Self::Node::Sum(l, r))
}
fn mul(&mut self, left: NodeId, right: NodeId) -> NodeId {
let (l, r) = if left <= right {
(left, right)
} else {
(right, left)
};
self.node(Self::Node::Prod(l, r))
}
}

View File

@@ -1,12 +1,13 @@
use std::cell::RefCell; use std::cell::RefCell;
use std::rc::Rc; use std::rc::Rc;
use super::dag::{Circuit, CircuitExt}; use super::dag::{ProbCircuit, CircuitExt};
use super::traits::Circuit;
use crate::poly::var::StaticVar; use crate::poly::var::StaticVar;
#[test] #[test]
fn test_deduplication() { fn test_deduplication() {
let circuit = Rc::new(RefCell::new(Circuit::new())); let mut circuit = ProbCircuit::new();
// Same leaf constructed twice returns the same NodeId // Same leaf constructed twice returns the same NodeId
let x1 = circuit.leaf(StaticVar::from("x")); let x1 = circuit.leaf(StaticVar::from("x"));
@@ -48,5 +49,5 @@ fn test_deduplication() {
assert_eq!(xyz1.id, xyz2.id); assert_eq!(xyz1.id, xyz2.id);
let _sum = xyz1 + xyz2; let _sum = xyz1 + xyz2;
// x, y, z, x+y(==y+x), (x+y)*z, (x+y)+z, y+z, x+(y+z), (x+y)*z+(x+y)*z, sq // x, y, z, x+y(==y+x), (x+y)*z, (x+y)+z, y+z, x+(y+z), (x+y)*z+(x+y)*z, sq
assert_eq!(circuit.borrow().len(), 10); assert_eq!(circuit.len(), 10);
} }

51
src/circuit/traits.rs Normal file
View File

@@ -0,0 +1,51 @@
use std::hash::Hash;
use std::ops::{Add,Mul};
use std::{rc::Rc,cell::RefCell};
use super::dag::NodeId;
pub trait Node:Clone+PartialEq+Eq+Hash{
fn children(&self) -> impl Iterator<Item = NodeId>;
}
pub trait Circuit:Clone{
type Var;
type Node;
fn node(&mut self, n:Self::Node)->NodeId;
fn remove(&mut self, id:NodeId);
fn get(&self, id:NodeId)->Option<&Self::Node>;
fn leaf(&mut self, var: Self::Var)->NodeId;
fn add(&mut self, left: NodeId, right: NodeId)->NodeId;
fn mul(&mut self, left: NodeId, right: NodeId)->NodeId;
fn len(&self)->usize;
fn children(&self, id:NodeId)->impl Iterator<Item=NodeId>+'_;
}
impl<C:Circuit> Add for RefNode<C> {
type Output = Self;
fn add(self, rhs: Self) -> Self {
let id = self.circuit.borrow_mut().add(self.id, rhs.id);
RefNode {
id,
circuit: self.circuit,
}
}
}
impl<C:Circuit> Mul for RefNode<C> {
type Output = Self;
fn mul(self, rhs: Self) -> Self {
let id = self.circuit.borrow_mut().mul(self.id, rhs.id);
RefNode {
id,
circuit: self.circuit,
}
}
}