improve var creation for SumProdCircuit

This commit is contained in:
2026-04-28 17:29:11 +02:00
parent 0c4bddf3b0
commit 5caccbaf50
6 changed files with 12 additions and 11 deletions

View File

@@ -6,10 +6,10 @@ use circuit_cas::circuit::dag::CircuitExt;
fn main() { fn main() {
let circuit: Rc<RefCell<ProbCircuit>> = ProbCircuit::new(); let circuit: Rc<RefCell<ProbCircuit>> = ProbCircuit::new();
// Build (x + y) * (x + z) // vars accept anything that implements Into<StaticVar>: &'static str, (&str, u32), (&str, u32, u32)
let x = circuit.var("x"); let x = circuit.var("x");
let y = circuit.var("y"); let y = circuit.var(("y", 1)); // indexed variable y_1
let z = circuit.var("z"); let z = circuit.var(("z", 0, 1)); // doubly-indexed variable z_{0,1}
let x_plus_y = circuit.var("x") + y; let x_plus_y = circuit.var("x") + y;
let x_plus_z = circuit.var("x") + z; let x_plus_z = circuit.var("x") + z;
@@ -19,7 +19,7 @@ fn main() {
let x2 = circuit.var("x"); let x2 = circuit.var("x");
assert_eq!(x.id, x2.id); assert_eq!(x.id, x2.id);
println!("(x + y) * (x + z) root node id: {:?}", expr.id); println!("(x + y_1) * (x + z_{{0,1}}) root node id: {:?}", expr.id);
println!("x node id: {:?}", x.id); println!("x node id: {:?}", x.id);
println!("x deduplicated node id: {:?}", x2.id); println!("x deduplicated node id: {:?}", x2.id);
} }

View File

@@ -19,6 +19,7 @@ fn main() {
let quotient: Rc<RefCell<QuotientCircuit>> = idem.into(); let quotient: Rc<RefCell<QuotientCircuit>> = idem.into();
// Build x * x̄ + x in the DAG // Build x * x̄ + x in the DAG
// var accepts anything that implements Into<StaticVar>: &'static str, (&str, u32), (&str, u32, u32)
let xn = quotient.var("x"); let xn = quotient.var("x");
let nxn = quotient.var("x\u{0304}"); let nxn = quotient.var("x\u{0304}");
let prod = xn * nxn; let prod = xn * nxn;

View File

@@ -60,5 +60,5 @@ pub struct RefNode<C: Circuit> {
pub trait CircuitExt { pub trait CircuitExt {
type C: Circuit; type C: Circuit;
type Var; type Var;
fn var(&self, v: impl Into<Self::Var>) -> RefNode<Self::C>; fn var<T: Into<Self::Var>>(&self, v: T) -> RefNode<Self::C>;
} }

View File

@@ -43,7 +43,7 @@ impl<V: Var> ProbCircuit<V> {
impl<V: Var> SumProdCircuit for ProbCircuit<V> { impl<V: Var> SumProdCircuit for ProbCircuit<V> {
type Var = V; type Var = V;
fn var(&mut self, v: V) -> NodeId { self.node(PNode::Var(v)) } fn var<T: Into<V>>(&mut self, v: T) -> NodeId { self.node(PNode::Var(v.into())) }
fn add(&mut self, l: NodeId, r: NodeId) -> NodeId { fn add(&mut self, l: NodeId, r: NodeId) -> NodeId {
let (l, r) = if l <= r { (l, r) } else { (r, l) }; let (l, r) = if l <= r { (l, r) } else { (r, l) };

View File

@@ -57,7 +57,7 @@ impl<V: Var> DerefMut for QuotientCircuit<V> {
impl<V: Var> SumProdCircuit for QuotientCircuit<V> { impl<V: Var> SumProdCircuit for QuotientCircuit<V> {
type Var = V; type Var = V;
fn var(&mut self, v: V) -> NodeId { self.node(QNode::Var(v)) } fn var<T: Into<V>>(&mut self, v: T) -> NodeId { self.node(QNode::Var(v.into())) }
fn add(&mut self, l: NodeId, r: NodeId) -> NodeId { fn add(&mut self, l: NodeId, r: NodeId) -> NodeId {
let (l, r) = if l <= r { (l, r) } else { (r, l) }; let (l, r) = if l <= r { (l, r) } else { (r, l) };

View File

@@ -29,7 +29,7 @@ where
pub trait SumProdCircuit: Circuit { pub trait SumProdCircuit: Circuit {
type Var; type Var;
fn var(&mut self, v: Self::Var) -> NodeId; fn var<T: Into<Self::Var>>(&mut self, v: T) -> NodeId;
fn add(&mut self, l: NodeId, r: NodeId) -> NodeId; fn add(&mut self, l: NodeId, r: NodeId) -> NodeId;
fn mul(&mut self, l: NodeId, r: NodeId) -> NodeId; fn mul(&mut self, l: NodeId, r: NodeId) -> NodeId;
} }
@@ -56,8 +56,8 @@ impl<C: SumProdCircuit> CircuitExt for Rc<RefCell<C>> {
type C = C; type C = C;
type Var = C::Var; type Var = C::Var;
fn var(&self, v: impl Into<C::Var>) -> RefNode<C> { fn var<T: Into<C::Var>>(&self, v: T) -> RefNode<C> {
let id = self.borrow_mut().var(v.into()); let id = self.borrow_mut().var(v);
RefNode { id, circuit: self.clone() } RefNode { id, circuit: self.clone() }
} }
} }