move tests and implement dag

This commit is contained in:
2026-04-22 15:59:19 +02:00
parent 891546069c
commit fd05f6b024
10 changed files with 203 additions and 156 deletions

View File

@@ -2,8 +2,7 @@ use circuit_cas::poly::flat;
use circuit_cas::var; use circuit_cas::var;
fn main() { fn main() {
let poly = (2 let poly = (2 * ((var!("x", 1, 5) ^ 5) * (var!("x", 1, 2) ^ 5) * (var!("x", 2, 5) ^ 1)))
* ((var!("x", 1, 5) ^ 5) * (var!("x", 1, 2) ^ 5) * (var!("x", 2, 5) ^ 1)))
+ (3 * ((var!("x", 1, 9) ^ 5) * (var!("x", 1, 2) ^ 5) * (var!("x", 2, 5) ^ 1))); + (3 * ((var!("x", 1, 9) ^ 5) * (var!("x", 1, 2) ^ 5) * (var!("x", 2, 5) ^ 1)));
let x = var!("x"); let x = var!("x");

View File

@@ -1,4 +1,4 @@
use slotmap::{new_key_type, SlotMap}; use slotmap::{SlotMap, new_key_type};
use std::cell::RefCell; use std::cell::RefCell;
use std::collections::HashMap; use std::collections::HashMap;
use std::ops::{Add, Mul}; use std::ops::{Add, Mul};
@@ -53,12 +53,20 @@ impl<V: Var> Circuit<V> {
} }
pub fn add(&mut self, left: NodeId, right: NodeId) -> NodeId { pub fn add(&mut self, left: NodeId, right: NodeId) -> NodeId {
let (l, r) = if left <= right { (left, right) } else { (right, left) }; let (l, r) = if left <= right {
(left, right)
} else {
(right, left)
};
self.node(Node::Sum(l, r)) self.node(Node::Sum(l, r))
} }
pub fn mul(&mut self, left: NodeId, right: NodeId) -> NodeId { pub fn mul(&mut self, left: NodeId, right: NodeId) -> NodeId {
let (l, r) = if left <= right { (left, right) } else { (right, left) }; let (l, r) = if left <= right {
(left, right)
} else {
(right, left)
};
self.node(Node::Prod(l, r)) self.node(Node::Prod(l, r))
} }
@@ -89,12 +97,18 @@ pub trait CircuitExt<V: Var> {
impl<V: Var> CircuitExt<V> for Rc<RefCell<Circuit<V>>> { impl<V: Var> CircuitExt<V> for Rc<RefCell<Circuit<V>>> {
fn leaf(&self, v: V) -> CircuitNode<V> { fn leaf(&self, v: V) -> CircuitNode<V> {
let id = self.borrow_mut().node(Node::Leaf(v)); let id = self.borrow_mut().node(Node::Leaf(v));
CircuitNode { id, circuit: self.clone() } CircuitNode {
id,
circuit: self.clone(),
}
} }
fn get_node(&self, id: NodeId) -> Option<CircuitNode<V>> { fn get_node(&self, id: NodeId) -> Option<CircuitNode<V>> {
self.borrow().get(id)?; self.borrow().get(id)?;
Some(CircuitNode { id, circuit: self.clone() }) Some(CircuitNode {
id,
circuit: self.clone(),
})
} }
} }
@@ -103,7 +117,10 @@ impl<V: Var> Add for CircuitNode<V> {
fn add(self, rhs: Self) -> Self { fn add(self, rhs: Self) -> Self {
let id = self.circuit.borrow_mut().add(self.id, rhs.id); let id = self.circuit.borrow_mut().add(self.id, rhs.id);
CircuitNode { id, circuit: self.circuit } CircuitNode {
id,
circuit: self.circuit,
}
} }
} }
@@ -112,7 +129,10 @@ impl<V: Var> Mul for CircuitNode<V> {
fn mul(self, rhs: Self) -> Self { fn mul(self, rhs: Self) -> Self {
let id = self.circuit.borrow_mut().mul(self.id, rhs.id); let id = self.circuit.borrow_mut().mul(self.id, rhs.id);
CircuitNode { id, circuit: self.circuit } CircuitNode {
id,
circuit: self.circuit,
}
} }
} }
@@ -168,5 +188,3 @@ mod tests {
assert_eq!(circuit.borrow().len(), 10); assert_eq!(circuit.borrow().len(), 10);
} }
} }

View File

@@ -1 +1,2 @@
pub mod dag; pub mod dag;
pub mod quotient;

1
src/circuit/quotient.rs Normal file
View File

@@ -0,0 +1 @@

View File

@@ -1,6 +1,6 @@
pub fn num_to_subscript(s: String) -> String { pub fn num_to_subscript(s: String) -> String {
s.chars().map(|c| match c{ s.chars()
.map(|c| match c {
'0' => '\u{2080}', '0' => '\u{2080}',
'1' => '\u{2081}', '1' => '\u{2081}',
'2' => '\u{2082}', '2' => '\u{2082}',
@@ -11,12 +11,14 @@ pub fn num_to_subscript(s:String)->String{
'7' => '\u{2087}', '7' => '\u{2087}',
'8' => '\u{2088}', '8' => '\u{2088}',
'9' => '\u{2089}', '9' => '\u{2089}',
_=>c _ => c,
}).collect() })
.collect()
} }
pub fn num_to_superscript(s: String) -> String { pub fn num_to_superscript(s: String) -> String {
s.chars().map(|c| match c{ s.chars()
.map(|c| match c {
'0' => '\u{2070}', '0' => '\u{2070}',
'1' => '\u{20B9}', '1' => '\u{20B9}',
'2' => '\u{00B2}', '2' => '\u{00B2}',
@@ -27,8 +29,7 @@ pub fn num_to_superscript(s:String)->String{
'7' => '\u{2077}', '7' => '\u{2077}',
'8' => '\u{2078}', '8' => '\u{2078}',
'9' => '\u{2079}', '9' => '\u{2079}',
_=>c _ => c,
}).collect() })
.collect()
} }

View File

@@ -1,3 +1,3 @@
pub mod poly;
pub mod circuit; pub mod circuit;
pub mod fmt; pub mod fmt;
pub mod poly;

View File

@@ -53,7 +53,7 @@ impl<V: Var, U: Into<Mono<V>>> FromIterator<(i32, U)> for Poly<V> {
} }
} }
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] #[derive(Clone, Debug, Default, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct Mono<V: Var> { pub struct Mono<V: Var> {
pub term: Vec<(V, u32)>, pub term: Vec<(V, u32)>,
} }
@@ -103,9 +103,11 @@ impl<V: Var, U: Into<V>> FromIterator<(U, u32)> for Mono<V> {
term.sort(); term.sort();
// Check duplicate variables // Check duplicate variables
assert!((term[..]) assert!(
(term[..])
.windows(2) .windows(2)
.all(|window| window[0].0 != window[1].0)); .all(|window| window[0].0 != window[1].0)
);
Mono { term } Mono { term }
} }
@@ -195,6 +197,22 @@ impl<V: Var> Mul<Mono<V>> for i32 {
} }
} }
impl<V: Var> Mul for Poly<V> {
type Output = Poly<V>;
fn mul(self, other: Poly<V>) -> Self::Output {
let mut result = Poly::default();
for (m1, c1) in &self.mono {
for (m2, c2) in &other.mono {
let entry = result.mono.entry(m1.clone() * m2.clone()).or_insert(0);
*entry += c1 * c2;
}
}
result.mono.retain(|_, &mut c| c != 0);
result
}
}
impl<V: Var> Add for Poly<V> { impl<V: Var> Add for Poly<V> {
type Output = Poly<V>; type Output = Poly<V>;
@@ -220,107 +238,3 @@ impl<V: Var> Sub for Poly<V> {
self self
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_mono_contains() {
let a: Mono<StaticVar> = [("x", 2), ("y", 1)].into();
// Lower exponent of same variable is contained
assert!(a.contains(&Mono::from([("x", 1)])));
// Higher exponent of same variable is not contained
assert!(!a.contains(&Mono::from([("x", 3)])));
// Identical monomial is contained
assert!(a.contains(&Mono::from([("x", 2), ("y", 1)])));
// Variable absent from self is not contained
assert!(!a.contains(&Mono::from([("x", 2), ("z", 1)])));
// Subset of variables with lower exponents is contained
assert!(a.contains(&Mono::from([("x", 1), ("y", 1)])));
// Single variable with exact exponent is contained
assert!(a.contains(&Mono::from([("x", 2)])));
// Insufficient exponent in self means not contained
assert!(!Mono::<StaticVar>::from([("x", 1)]).contains(&Mono::from([("x", 2)])));
// Missing variable in self means not contained
assert!(!Mono::<StaticVar>::from([("x", 1), ("y", 1)]).contains(&Mono::from([("x", 2)])));
assert!(!Mono::<StaticVar>::from([("x", 1)]).contains(&Mono::from([("x", 1), ("y", 1)])));
}
#[test]
fn test_mono_mul() {
// Same variable: exponents add
let a: Mono<StaticVar> = [("x", 2)].into();
let b: Mono<StaticVar> = [("x", 3)].into();
assert_eq!(a * b, Mono::from([("x", 5)]));
// Disjoint variables: both appear in result
let a: Mono<StaticVar> = [("x", 2)].into();
let b: Mono<StaticVar> = [("y", 3)].into();
assert_eq!(a * b, Mono::from([("x", 2), ("y", 3)]));
// Mixed: shared and disjoint variables
let a: Mono<StaticVar> = [("x", 1), ("y", 2)].into();
let b: Mono<StaticVar> = [("y", 1), ("z", 3)].into();
assert_eq!(a * b, Mono::from([("x", 1), ("y", 3), ("z", 3)]));
// Commutativity
let a: Mono<StaticVar> = [("x", 2), ("z", 1)].into();
let b: Mono<StaticVar> = [("y", 3)].into();
assert_eq!(a.clone() * b.clone(), b * a);
// Multiply by constant monomial (empty term vec = 1)
let a: Mono<StaticVar> = [("x", 4)].into();
let one: Mono<StaticVar> = Mono { term: vec![] };
assert_eq!(a.clone() * one, a);
}
#[test]
fn test_poly_add() {
// Distinct monomials are collected as separate terms
let a: Poly<StaticVar> = [(1, [("x", 2)]), (2, [("y", 1)])].into();
let b: Poly<StaticVar> = [(3, [("z", 1)])].into();
let expected: Poly<StaticVar> = [(1, [("x", 2)]), (2, [("y", 1)]), (3, [("z", 1)])].into();
assert_eq!(a + b, expected);
// Coefficients of matching monomials are summed
let a: Poly<StaticVar> = [(2, [("x", 1)])].into();
let b: Poly<StaticVar> = [(3, [("x", 1)])].into();
let expected: Poly<StaticVar> = [(5, [("x", 1)])].into();
assert_eq!(a + b, expected);
// Terms that cancel sum to zero are dropped
let a: Poly<StaticVar> = [(1, [("x", 1)])].into();
let b: Poly<StaticVar> = [(-1, [("x", 1)])].into();
let expected: Poly<StaticVar> = Poly::default();
assert_eq!(a + b, expected);
}
#[test]
fn test_poly_sub() {
// Distinct monomials are collected as separate terms with negated rhs coefficients
let a: Poly<StaticVar> = [(3, [("x", 2)])].into();
let b: Poly<StaticVar> = [(1, [("y", 1)])].into();
let expected: Poly<StaticVar> = [(3, [("x", 2)]), (-1, [("y", 1)])].into();
assert_eq!(a - b, expected);
// Coefficients of matching monomials are subtracted
let a: Poly<StaticVar> = [(5, [("x", 1)])].into();
let b: Poly<StaticVar> = [(3, [("x", 1)])].into();
let expected: Poly<StaticVar> = [(2, [("x", 1)])].into();
assert_eq!(a - b, expected);
// Subtracting equal polynomials yields zero
let a: Poly<StaticVar> = [(4, [("x", 2)]), (1, [("y", 1)])].into();
let b: Poly<StaticVar> = [(4, [("x", 2)]), (1, [("y", 1)])].into();
assert_eq!(a - b, Poly::default());
}
}

View File

@@ -1,2 +1,5 @@
pub mod flat; pub mod flat;
pub mod var; pub mod var;
#[cfg(test)]
mod tests;

102
src/poly/tests.rs Normal file
View File

@@ -0,0 +1,102 @@
use super::flat::{Mono, Poly};
use super::var::StaticVar;
#[test]
fn test_mono_contains() {
let a: Mono<StaticVar> = [("x", 2), ("y", 1)].into();
// Lower exponent of same variable is contained
assert!(a.contains(&Mono::from([("x", 1)])));
// Higher exponent of same variable is not contained
assert!(!a.contains(&Mono::from([("x", 3)])));
// Identical monomial is contained
assert!(a.contains(&Mono::from([("x", 2), ("y", 1)])));
// Variable absent from self is not contained
assert!(!a.contains(&Mono::from([("x", 2), ("z", 1)])));
// Subset of variables with lower exponents is contained
assert!(a.contains(&Mono::from([("x", 1), ("y", 1)])));
// Single variable with exact exponent is contained
assert!(a.contains(&Mono::from([("x", 2)])));
// Insufficient exponent in self means not contained
assert!(!Mono::<StaticVar>::from([("x", 1)]).contains(&Mono::from([("x", 2)])));
// Missing variable in self means not contained
assert!(!Mono::<StaticVar>::from([("x", 1), ("y", 1)]).contains(&Mono::from([("x", 2)])));
assert!(!Mono::<StaticVar>::from([("x", 1)]).contains(&Mono::from([("x", 1), ("y", 1)])));
}
#[test]
fn test_mono_mul() {
// Same variable: exponents add
let a: Mono<StaticVar> = [("x", 2)].into();
let b: Mono<StaticVar> = [("x", 3)].into();
assert_eq!(a * b, Mono::from([("x", 5)]));
// Disjoint variables: both appear in result
let a: Mono<StaticVar> = [("x", 2)].into();
let b: Mono<StaticVar> = [("y", 3)].into();
assert_eq!(a * b, Mono::from([("x", 2), ("y", 3)]));
// Mixed: shared and disjoint variables
let a: Mono<StaticVar> = [("x", 1), ("y", 2)].into();
let b: Mono<StaticVar> = [("y", 1), ("z", 3)].into();
assert_eq!(a * b, Mono::from([("x", 1), ("y", 3), ("z", 3)]));
// Commutativity
let a: Mono<StaticVar> = [("x", 2), ("z", 1)].into();
let b: Mono<StaticVar> = [("y", 3)].into();
assert_eq!(a.clone() * b.clone(), b * a);
// Multiply by constant monomial (empty term vec = 1)
let a: Mono<StaticVar> = [("x", 4)].into();
let one: Mono<StaticVar> = Mono { term: vec![] };
assert_eq!(a.clone() * one, a);
}
#[test]
fn test_poly_add() {
// Distinct monomials are collected as separate terms
let a: Poly<StaticVar> = [(1, [("x", 2)]), (2, [("y", 1)])].into();
let b: Poly<StaticVar> = [(3, [("z", 1)])].into();
let expected: Poly<StaticVar> = [(1, [("x", 2)]), (2, [("y", 1)]), (3, [("z", 1)])].into();
assert_eq!(a + b, expected);
// Coefficients of matching monomials are summed
let a: Poly<StaticVar> = [(2, [("x", 1)])].into();
let b: Poly<StaticVar> = [(3, [("x", 1)])].into();
let expected: Poly<StaticVar> = [(5, [("x", 1)])].into();
assert_eq!(a + b, expected);
// Terms that cancel sum to zero are dropped
let a: Poly<StaticVar> = [(1, [("x", 1)])].into();
let b: Poly<StaticVar> = [(-1, [("x", 1)])].into();
let expected: Poly<StaticVar> = Poly::default();
assert_eq!(a + b, expected);
}
#[test]
fn test_poly_sub() {
// Distinct monomials are collected as separate terms with negated rhs coefficients
let a: Poly<StaticVar> = [(3, [("x", 2)])].into();
let b: Poly<StaticVar> = [(1, [("y", 1)])].into();
let expected: Poly<StaticVar> = [(3, [("x", 2)]), (-1, [("y", 1)])].into();
assert_eq!(a - b, expected);
// Coefficients of matching monomials are subtracted
let a: Poly<StaticVar> = [(5, [("x", 1)])].into();
let b: Poly<StaticVar> = [(3, [("x", 1)])].into();
let expected: Poly<StaticVar> = [(2, [("x", 1)])].into();
assert_eq!(a - b, expected);
// Subtracting equal polynomials yields zero
let a: Poly<StaticVar> = [(4, [("x", 2)]), (1, [("y", 1)])].into();
let b: Poly<StaticVar> = [(4, [("x", 2)]), (1, [("y", 1)])].into();
assert_eq!(a - b, Poly::default());
}

View File

@@ -17,7 +17,15 @@ impl Display for StaticVar {
let num_indices = self.indices.len(); let num_indices = self.indices.len();
match num_indices { match num_indices {
0 => write!(fmt, "{}", self.name), 0 => write!(fmt, "{}", self.name),
_ => write!(fmt, "{}{}", self.name, self.indices.iter().map(|x| num_to_subscript(x.to_string())).join(",")), _ => write!(
fmt,
"{}{}",
self.name,
self.indices
.iter()
.map(|x| num_to_subscript(x.to_string()))
.join(",")
),
} }
} }
} }