Compare commits
5 Commits
5ef4893f03
...
dee53e6339
| Author | SHA1 | Date | |
|---|---|---|---|
| dee53e6339 | |||
| 8c646fd920 | |||
| fd05f6b024 | |||
| 891546069c | |||
| 99fee298c7 |
@@ -1,9 +1,7 @@
|
|||||||
use circuit_cas::poly::flat;
|
|
||||||
use circuit_cas::var;
|
use circuit_cas::var;
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
let poly = (2
|
let poly = (2 * ((var!("x", 1, 5) ^ 5) * (var!("x", 1, 2) ^ 5) * (var!("x", 2, 5) ^ 1)))
|
||||||
* ((var!("x", 1, 5) ^ 5) * (var!("x", 1, 2) ^ 5) * (var!("x", 2, 5) ^ 1)))
|
|
||||||
+ (3 * ((var!("x", 1, 9) ^ 5) * (var!("x", 1, 2) ^ 5) * (var!("x", 2, 5) ^ 1)));
|
+ (3 * ((var!("x", 1, 9) ^ 5) * (var!("x", 1, 2) ^ 5) * (var!("x", 2, 5) ^ 1)));
|
||||||
|
|
||||||
let x = var!("x");
|
let x = var!("x");
|
||||||
@@ -11,13 +9,13 @@ fn main() {
|
|||||||
let z = var!("z");
|
let z = var!("z");
|
||||||
let other = -3 * ((&x ^ 2) * (&y ^ 4));
|
let other = -3 * ((&x ^ 2) * (&y ^ 4));
|
||||||
|
|
||||||
let mono = (&x^2)*(&y^4);
|
let mono = (&x ^ 2) * (&y ^ 4);
|
||||||
|
|
||||||
let inside = (&x^2)*(&y^2)*(&z^1);
|
let inside = (&x ^ 2) * (&y ^ 2) * (&z ^ 1);
|
||||||
|
|
||||||
if mono.contains(&inside){
|
if mono.contains(&inside) {
|
||||||
println!("{inside}\u{2286}{mono}");
|
println!("{inside}\u{2286}{mono}");
|
||||||
}else{
|
} else {
|
||||||
println!("{inside}\u{2284}{mono}");
|
println!("{inside}\u{2284}{mono}");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
use slotmap::{new_key_type, SlotMap};
|
use slotmap::{SlotMap, new_key_type};
|
||||||
use std::cell::RefCell;
|
use std::cell::RefCell;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::ops::{Add, Mul};
|
use std::ops::{Add, Mul};
|
||||||
@@ -15,6 +15,17 @@ pub enum Node<V: Var> {
|
|||||||
Prod(NodeId, NodeId),
|
Prod(NodeId, NodeId),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<V: Var> Node<V> {
|
||||||
|
pub fn children(&self) -> impl Iterator<Item = NodeId> {
|
||||||
|
match self {
|
||||||
|
Node::Leaf(_) => [None, None],
|
||||||
|
Node::Sum(l, r) | Node::Prod(l, r) => [Some(*l), Some(*r)],
|
||||||
|
}
|
||||||
|
.into_iter()
|
||||||
|
.flatten()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct Circuit<V: Var> {
|
pub struct Circuit<V: Var> {
|
||||||
nodes: SlotMap<NodeId, Node<V>>,
|
nodes: SlotMap<NodeId, Node<V>>,
|
||||||
intern: HashMap<Node<V>, NodeId>,
|
intern: HashMap<Node<V>, NodeId>,
|
||||||
@@ -41,6 +52,32 @@ impl<V: Var> Circuit<V> {
|
|||||||
self.nodes.get(id)
|
self.nodes.get(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn add(&mut self, left: NodeId, right: NodeId) -> NodeId {
|
||||||
|
let (l, r) = if left <= right {
|
||||||
|
(left, right)
|
||||||
|
} else {
|
||||||
|
(right, left)
|
||||||
|
};
|
||||||
|
self.node(Node::Sum(l, r))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn mul(&mut self, left: NodeId, right: NodeId) -> NodeId {
|
||||||
|
let (l, r) = if left <= right {
|
||||||
|
(left, right)
|
||||||
|
} else {
|
||||||
|
(right, left)
|
||||||
|
};
|
||||||
|
self.node(Node::Prod(l, r))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.nodes.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn children(&self, id: NodeId) -> impl Iterator<Item = NodeId> + '_ {
|
||||||
|
self.nodes.get(id).into_iter().flat_map(Node::children)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn remove(&mut self, id: NodeId) {
|
pub fn remove(&mut self, id: NodeId) {
|
||||||
if let Some(node) = self.nodes.remove(id) {
|
if let Some(node) = self.nodes.remove(id) {
|
||||||
self.intern.remove(&node);
|
self.intern.remove(&node);
|
||||||
@@ -52,15 +89,26 @@ pub struct CircuitNode<V: Var> {
|
|||||||
pub id: NodeId,
|
pub id: NodeId,
|
||||||
circuit: Rc<RefCell<Circuit<V>>>,
|
circuit: Rc<RefCell<Circuit<V>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait CircuitExt<V: Var> {
|
pub trait CircuitExt<V: Var> {
|
||||||
fn leaf(&self, v: V) -> CircuitNode<V>;
|
fn leaf(&self, v: V) -> CircuitNode<V>;
|
||||||
|
fn get_node(&self, id: NodeId) -> Option<CircuitNode<V>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<V: Var> CircuitExt<V> for Rc<RefCell<Circuit<V>>> {
|
impl<V: Var> CircuitExt<V> for Rc<RefCell<Circuit<V>>> {
|
||||||
fn leaf(&self, v: V) -> CircuitNode<V> {
|
fn leaf(&self, v: V) -> CircuitNode<V> {
|
||||||
let id = self.borrow_mut().node(Node::Leaf(v));
|
let id = self.borrow_mut().node(Node::Leaf(v));
|
||||||
CircuitNode { id, circuit: self.clone() }
|
CircuitNode {
|
||||||
|
id,
|
||||||
|
circuit: self.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_node(&self, id: NodeId) -> Option<CircuitNode<V>> {
|
||||||
|
self.borrow().get(id)?;
|
||||||
|
Some(CircuitNode {
|
||||||
|
id,
|
||||||
|
circuit: self.clone(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -68,8 +116,11 @@ impl<V: Var> Add for CircuitNode<V> {
|
|||||||
type Output = Self;
|
type Output = Self;
|
||||||
|
|
||||||
fn add(self, rhs: Self) -> Self {
|
fn add(self, rhs: Self) -> Self {
|
||||||
let id = self.circuit.borrow_mut().node(Node::Sum(self.id, rhs.id));
|
let id = self.circuit.borrow_mut().add(self.id, rhs.id);
|
||||||
CircuitNode { id, circuit: self.circuit }
|
CircuitNode {
|
||||||
|
id,
|
||||||
|
circuit: self.circuit,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,7 +128,10 @@ impl<V: Var> Mul for CircuitNode<V> {
|
|||||||
type Output = Self;
|
type Output = Self;
|
||||||
|
|
||||||
fn mul(self, rhs: Self) -> Self {
|
fn mul(self, rhs: Self) -> Self {
|
||||||
let id = self.circuit.borrow_mut().node(Node::Prod(self.id, rhs.id));
|
let id = self.circuit.borrow_mut().mul(self.id, rhs.id);
|
||||||
CircuitNode { id, circuit: self.circuit }
|
CircuitNode {
|
||||||
|
id,
|
||||||
|
circuit: self.circuit,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1 +1,6 @@
|
|||||||
pub mod dag;
|
pub mod dag;
|
||||||
|
pub mod quotient;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests;
|
||||||
|
|
||||||
|
|||||||
1
src/circuit/quotient.rs
Normal file
1
src/circuit/quotient.rs
Normal file
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
53
src/circuit/tests.rs
Normal file
53
src/circuit/tests.rs
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
|
||||||
|
use std::cell::RefCell;
|
||||||
|
use std::rc::Rc;
|
||||||
|
|
||||||
|
use super::dag::{Circuit, CircuitExt};
|
||||||
|
use crate::poly::var::StaticVar;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_deduplication() {
|
||||||
|
let circuit = Rc::new(RefCell::new(Circuit::new()));
|
||||||
|
|
||||||
|
// Same leaf constructed twice returns the same NodeId
|
||||||
|
let x1 = circuit.leaf(StaticVar::from("x"));
|
||||||
|
let x2 = circuit.leaf(StaticVar::from("x"));
|
||||||
|
assert_eq!(x1.id, x2.id);
|
||||||
|
assert_eq!(circuit.borrow().len(), 1);
|
||||||
|
|
||||||
|
// Same sum constructed twice returns the same NodeId
|
||||||
|
let _y = circuit.leaf(StaticVar::from("y"));
|
||||||
|
let sum1 = circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y"));
|
||||||
|
let sum2 = circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y"));
|
||||||
|
assert_eq!(sum1.id, sum2.id);
|
||||||
|
assert_eq!(circuit.borrow().len(), 3); // x, y, x+y
|
||||||
|
|
||||||
|
// Shared subexpression: (x + y) * (x + y) reuses the x+y node
|
||||||
|
let xy = circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y"));
|
||||||
|
let xy2 = circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y"));
|
||||||
|
let _sq = xy * xy2;
|
||||||
|
assert_eq!(circuit.borrow().len(), 4); // x, y, x+y, (x+y)*(x+y)
|
||||||
|
|
||||||
|
// Commutativity: x+y and y+x are the same node
|
||||||
|
let xy = circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y"));
|
||||||
|
let yx = circuit.leaf(StaticVar::from("y")) + circuit.leaf(StaticVar::from("x"));
|
||||||
|
assert_eq!(xy.id, yx.id);
|
||||||
|
|
||||||
|
// Associativity: (x+y)+z and x+(y+z) are distinct nodes
|
||||||
|
let _z = circuit.leaf(StaticVar::from("z"));
|
||||||
|
let xy_z = (circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y")))
|
||||||
|
+ circuit.leaf(StaticVar::from("z"));
|
||||||
|
let x_yz = circuit.leaf(StaticVar::from("x"))
|
||||||
|
+ (circuit.leaf(StaticVar::from("y")) + circuit.leaf(StaticVar::from("z")));
|
||||||
|
assert_ne!(xy_z.id, x_yz.id);
|
||||||
|
|
||||||
|
// Deep shared structure: (x+y)*z appears twice in ((x+y)*z) + ((x+y)*z)
|
||||||
|
let xyz1 = (circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y")))
|
||||||
|
* circuit.leaf(StaticVar::from("z"));
|
||||||
|
let xyz2 = (circuit.leaf(StaticVar::from("x")) + circuit.leaf(StaticVar::from("y")))
|
||||||
|
* circuit.leaf(StaticVar::from("z"));
|
||||||
|
assert_eq!(xyz1.id, xyz2.id);
|
||||||
|
let _sum = xyz1 + xyz2;
|
||||||
|
// x, y, z, x+y(==y+x), (x+y)*z, (x+y)+z, y+z, x+(y+z), (x+y)*z+(x+y)*z, sq
|
||||||
|
assert_eq!(circuit.borrow().len(), 10);
|
||||||
|
}
|
||||||
63
src/fmt.rs
63
src/fmt.rs
@@ -1,34 +1,35 @@
|
|||||||
|
pub fn num_to_subscript(s: String) -> String {
|
||||||
pub fn num_to_subscript(s:String)->String{
|
s.chars()
|
||||||
s.chars().map(|c| match c{
|
.map(|c| match c {
|
||||||
'0'=>'\u{2080}',
|
'0' => '\u{2080}',
|
||||||
'1'=>'\u{2081}',
|
'1' => '\u{2081}',
|
||||||
'2'=>'\u{2082}',
|
'2' => '\u{2082}',
|
||||||
'3'=>'\u{2083}',
|
'3' => '\u{2083}',
|
||||||
'4'=>'\u{2084}',
|
'4' => '\u{2084}',
|
||||||
'5'=>'\u{2085}',
|
'5' => '\u{2085}',
|
||||||
'6'=>'\u{2086}',
|
'6' => '\u{2086}',
|
||||||
'7'=>'\u{2087}',
|
'7' => '\u{2087}',
|
||||||
'8'=>'\u{2088}',
|
'8' => '\u{2088}',
|
||||||
'9'=>'\u{2089}',
|
'9' => '\u{2089}',
|
||||||
_=>c
|
_ => c,
|
||||||
}).collect()
|
})
|
||||||
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn num_to_superscript(s:String)->String{
|
pub fn num_to_superscript(s: String) -> String {
|
||||||
s.chars().map(|c| match c{
|
s.chars()
|
||||||
'0'=>'\u{2070}',
|
.map(|c| match c {
|
||||||
'1'=>'\u{20B9}',
|
'0' => '\u{2070}',
|
||||||
'2'=>'\u{00B2}',
|
'1' => '\u{20B9}',
|
||||||
'3'=>'\u{00B3}',
|
'2' => '\u{00B2}',
|
||||||
'4'=>'\u{2074}',
|
'3' => '\u{00B3}',
|
||||||
'5'=>'\u{2075}',
|
'4' => '\u{2074}',
|
||||||
'6'=>'\u{2076}',
|
'5' => '\u{2075}',
|
||||||
'7'=>'\u{2077}',
|
'6' => '\u{2076}',
|
||||||
'8'=>'\u{2078}',
|
'7' => '\u{2077}',
|
||||||
'9'=>'\u{2079}',
|
'8' => '\u{2078}',
|
||||||
_=>c
|
'9' => '\u{2079}',
|
||||||
}).collect()
|
_ => c,
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
pub mod poly;
|
|
||||||
pub mod circuit;
|
pub mod circuit;
|
||||||
pub mod fmt;
|
pub mod fmt;
|
||||||
|
pub mod poly;
|
||||||
|
|||||||
249
src/poly/flat.rs
249
src/poly/flat.rs
@@ -1,15 +1,10 @@
|
|||||||
use itertools::Itertools;
|
|
||||||
use std::fmt::{self, Display};
|
|
||||||
|
|
||||||
use std::ops::{Add, BitXor, Mul, Sub};
|
|
||||||
|
|
||||||
use super::var::{StaticVar, Var};
|
|
||||||
use crate::fmt::num_to_superscript;
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use super::var::Var;
|
||||||
|
|
||||||
#[derive(Clone, Debug, PartialEq)]
|
#[derive(Clone, Debug, PartialEq)]
|
||||||
pub struct Poly<V: Var> {
|
pub struct Poly<V: Var> {
|
||||||
mono: HashMap<Mono<V>, i32>,
|
pub(crate) mono: HashMap<Mono<V>, i32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<V: Var> Default for Poly<V> {
|
impl<V: Var> Default for Poly<V> {
|
||||||
@@ -20,22 +15,6 @@ impl<V: Var> Default for Poly<V> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<V: Var> Display for Poly<V> {
|
|
||||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
|
|
||||||
match self.mono.is_empty() {
|
|
||||||
true => write!(fmt, "∅"),
|
|
||||||
false => write!(
|
|
||||||
fmt,
|
|
||||||
"{}",
|
|
||||||
self.mono
|
|
||||||
.iter()
|
|
||||||
.map(|(m, c)| format!("{}{}", c, m))
|
|
||||||
.join(" + ")
|
|
||||||
),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<V: Var, T: IntoIterator> From<T> for Poly<V>
|
impl<V: Var, T: IntoIterator> From<T> for Poly<V>
|
||||||
where
|
where
|
||||||
Poly<V>: FromIterator<<T as IntoIterator>::Item>,
|
Poly<V>: FromIterator<<T as IntoIterator>::Item>,
|
||||||
@@ -53,7 +32,7 @@ impl<V: Var, U: Into<Mono<V>>> FromIterator<(i32, U)> for Poly<V> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
|
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash, PartialOrd, Ord)]
|
||||||
pub struct Mono<V: Var> {
|
pub struct Mono<V: Var> {
|
||||||
pub term: Vec<(V, u32)>,
|
pub term: Vec<(V, u32)>,
|
||||||
}
|
}
|
||||||
@@ -103,224 +82,12 @@ impl<V: Var, U: Into<V>> FromIterator<(U, u32)> for Mono<V> {
|
|||||||
term.sort();
|
term.sort();
|
||||||
|
|
||||||
// Check duplicate variables
|
// Check duplicate variables
|
||||||
assert!((term[..])
|
assert!(
|
||||||
|
(term[..])
|
||||||
.windows(2)
|
.windows(2)
|
||||||
.all(|window| window[0].0 != window[1].0));
|
.all(|window| window[0].0 != window[1].0)
|
||||||
|
);
|
||||||
|
|
||||||
Mono { term }
|
Mono { term }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<V: Var> Display for Mono<V> {
|
|
||||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
|
|
||||||
write!(
|
|
||||||
fmt,
|
|
||||||
"{}",
|
|
||||||
self.term
|
|
||||||
.iter()
|
|
||||||
.map(|(t, p)| match p {
|
|
||||||
1 => format!("{t}"),
|
|
||||||
_ => format!("{t}{}", num_to_superscript(p.to_string())),
|
|
||||||
})
|
|
||||||
.join("")
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl BitXor<u32> for StaticVar {
|
|
||||||
type Output = Mono<StaticVar>;
|
|
||||||
|
|
||||||
fn bitxor(self, exp: u32) -> Self::Output {
|
|
||||||
[(self, exp)].into()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a> BitXor<u32> for &'a StaticVar {
|
|
||||||
type Output = Mono<StaticVar>;
|
|
||||||
|
|
||||||
fn bitxor(self, exp: u32) -> Self::Output {
|
|
||||||
[(self.clone(), exp)].into()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<V: Var> Mul for Mono<V> {
|
|
||||||
type Output = Self;
|
|
||||||
|
|
||||||
fn mul(self, other: Mono<V>) -> Self::Output {
|
|
||||||
let mut a_term = self.term.into_iter().peekable();
|
|
||||||
let mut b_term = other.term.into_iter().peekable();
|
|
||||||
|
|
||||||
let mut result: Vec<(V, u32)> = Default::default();
|
|
||||||
|
|
||||||
loop {
|
|
||||||
match (a_term.peek(), b_term.peek()) {
|
|
||||||
(Some((a_var, _)), Some((b_var, _))) => {
|
|
||||||
if a_var < b_var {
|
|
||||||
result.push(a_term.next().unwrap());
|
|
||||||
} else if a_var > b_var {
|
|
||||||
result.push(b_term.next().unwrap());
|
|
||||||
} else {
|
|
||||||
let (var, a_exp) = a_term.next().unwrap();
|
|
||||||
let (_, b_exp) = b_term.next().unwrap();
|
|
||||||
result.push((var, a_exp + b_exp));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
(Some(a), None) => {
|
|
||||||
result.push(a.clone());
|
|
||||||
a_term.next();
|
|
||||||
}
|
|
||||||
(None, Some(b)) => {
|
|
||||||
result.push(b.clone());
|
|
||||||
b_term.next();
|
|
||||||
}
|
|
||||||
(None, None) => {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Mono { term: result }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<V: Var> Mul<Mono<V>> for i32 {
|
|
||||||
type Output = Poly<V>;
|
|
||||||
|
|
||||||
fn mul(self, mono: Mono<V>) -> Self::Output {
|
|
||||||
let mut poly: HashMap<Mono<V>, i32> = Default::default();
|
|
||||||
|
|
||||||
poly.insert(mono, self);
|
|
||||||
|
|
||||||
Poly { mono: poly }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<V: Var> Add for Poly<V> {
|
|
||||||
type Output = Poly<V>;
|
|
||||||
|
|
||||||
fn add(mut self, other: Poly<V>) -> Self::Output {
|
|
||||||
for (mono, coeff) in other.mono {
|
|
||||||
let entry = self.mono.entry(mono).or_insert(0);
|
|
||||||
*entry += coeff;
|
|
||||||
}
|
|
||||||
self.mono.retain(|_, &mut coeff| coeff != 0);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<V: Var> Sub for Poly<V> {
|
|
||||||
type Output = Poly<V>;
|
|
||||||
|
|
||||||
fn sub(mut self, other: Poly<V>) -> Self::Output {
|
|
||||||
for (mono, coeff) in other.mono {
|
|
||||||
let entry = self.mono.entry(mono).or_insert(0);
|
|
||||||
*entry -= coeff;
|
|
||||||
}
|
|
||||||
self.mono.retain(|_, &mut coeff| coeff != 0);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_mono_contains() {
|
|
||||||
let a: Mono<StaticVar> = [("x", 2), ("y", 1)].into();
|
|
||||||
|
|
||||||
// Lower exponent of same variable is contained
|
|
||||||
assert!(a.contains(&Mono::from([("x", 1)])));
|
|
||||||
|
|
||||||
// Higher exponent of same variable is not contained
|
|
||||||
assert!(!a.contains(&Mono::from([("x", 3)])));
|
|
||||||
|
|
||||||
// Identical monomial is contained
|
|
||||||
assert!(a.contains(&Mono::from([("x", 2), ("y", 1)])));
|
|
||||||
|
|
||||||
// Variable absent from self is not contained
|
|
||||||
assert!(!a.contains(&Mono::from([("x", 2), ("z", 1)])));
|
|
||||||
|
|
||||||
// Subset of variables with lower exponents is contained
|
|
||||||
assert!(a.contains(&Mono::from([("x", 1), ("y", 1)])));
|
|
||||||
|
|
||||||
// Single variable with exact exponent is contained
|
|
||||||
assert!(a.contains(&Mono::from([("x", 2)])));
|
|
||||||
|
|
||||||
// Insufficient exponent in self means not contained
|
|
||||||
assert!(!Mono::<StaticVar>::from([("x", 1)]).contains(&Mono::from([("x", 2)])));
|
|
||||||
|
|
||||||
// Missing variable in self means not contained
|
|
||||||
assert!(!Mono::<StaticVar>::from([("x", 1), ("y", 1)]).contains(&Mono::from([("x", 2)])));
|
|
||||||
assert!(!Mono::<StaticVar>::from([("x", 1)]).contains(&Mono::from([("x", 1), ("y", 1)])));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_mono_mul() {
|
|
||||||
// Same variable: exponents add
|
|
||||||
let a: Mono<StaticVar> = [("x", 2)].into();
|
|
||||||
let b: Mono<StaticVar> = [("x", 3)].into();
|
|
||||||
assert_eq!(a * b, Mono::from([("x", 5)]));
|
|
||||||
|
|
||||||
// Disjoint variables: both appear in result
|
|
||||||
let a: Mono<StaticVar> = [("x", 2)].into();
|
|
||||||
let b: Mono<StaticVar> = [("y", 3)].into();
|
|
||||||
assert_eq!(a * b, Mono::from([("x", 2), ("y", 3)]));
|
|
||||||
|
|
||||||
// Mixed: shared and disjoint variables
|
|
||||||
let a: Mono<StaticVar> = [("x", 1), ("y", 2)].into();
|
|
||||||
let b: Mono<StaticVar> = [("y", 1), ("z", 3)].into();
|
|
||||||
assert_eq!(a * b, Mono::from([("x", 1), ("y", 3), ("z", 3)]));
|
|
||||||
|
|
||||||
// Commutativity
|
|
||||||
let a: Mono<StaticVar> = [("x", 2), ("z", 1)].into();
|
|
||||||
let b: Mono<StaticVar> = [("y", 3)].into();
|
|
||||||
assert_eq!(a.clone() * b.clone(), b * a);
|
|
||||||
|
|
||||||
// Multiply by constant monomial (empty term vec = 1)
|
|
||||||
let a: Mono<StaticVar> = [("x", 4)].into();
|
|
||||||
let one: Mono<StaticVar> = Mono { term: vec![] };
|
|
||||||
assert_eq!(a.clone() * one, a);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_poly_add() {
|
|
||||||
// Distinct monomials are collected as separate terms
|
|
||||||
let a: Poly<StaticVar> = [(1, [("x", 2)]), (2, [("y", 1)])].into();
|
|
||||||
let b: Poly<StaticVar> = [(3, [("z", 1)])].into();
|
|
||||||
let expected: Poly<StaticVar> = [(1, [("x", 2)]), (2, [("y", 1)]), (3, [("z", 1)])].into();
|
|
||||||
assert_eq!(a + b, expected);
|
|
||||||
|
|
||||||
// Coefficients of matching monomials are summed
|
|
||||||
let a: Poly<StaticVar> = [(2, [("x", 1)])].into();
|
|
||||||
let b: Poly<StaticVar> = [(3, [("x", 1)])].into();
|
|
||||||
let expected: Poly<StaticVar> = [(5, [("x", 1)])].into();
|
|
||||||
assert_eq!(a + b, expected);
|
|
||||||
|
|
||||||
// Terms that cancel sum to zero are dropped
|
|
||||||
let a: Poly<StaticVar> = [(1, [("x", 1)])].into();
|
|
||||||
let b: Poly<StaticVar> = [(-1, [("x", 1)])].into();
|
|
||||||
let expected: Poly<StaticVar> = Poly::default();
|
|
||||||
assert_eq!(a + b, expected);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_poly_sub() {
|
|
||||||
// Distinct monomials are collected as separate terms with negated rhs coefficients
|
|
||||||
let a: Poly<StaticVar> = [(3, [("x", 2)])].into();
|
|
||||||
let b: Poly<StaticVar> = [(1, [("y", 1)])].into();
|
|
||||||
let expected: Poly<StaticVar> = [(3, [("x", 2)]), (-1, [("y", 1)])].into();
|
|
||||||
assert_eq!(a - b, expected);
|
|
||||||
|
|
||||||
// Coefficients of matching monomials are subtracted
|
|
||||||
let a: Poly<StaticVar> = [(5, [("x", 1)])].into();
|
|
||||||
let b: Poly<StaticVar> = [(3, [("x", 1)])].into();
|
|
||||||
let expected: Poly<StaticVar> = [(2, [("x", 1)])].into();
|
|
||||||
assert_eq!(a - b, expected);
|
|
||||||
|
|
||||||
// Subtracting equal polynomials yields zero
|
|
||||||
let a: Poly<StaticVar> = [(4, [("x", 2)]), (1, [("y", 1)])].into();
|
|
||||||
let b: Poly<StaticVar> = [(4, [("x", 2)]), (1, [("y", 1)])].into();
|
|
||||||
assert_eq!(a - b, Poly::default());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
57
src/poly/fmt.rs
Normal file
57
src/poly/fmt.rs
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
use std::fmt::{self, Display, Formatter};
|
||||||
|
|
||||||
|
use itertools::Itertools;
|
||||||
|
|
||||||
|
use crate::fmt::{num_to_subscript, num_to_superscript};
|
||||||
|
use crate::poly::flat::{Mono, Poly};
|
||||||
|
use crate::poly::var::{StaticVar, Var};
|
||||||
|
|
||||||
|
impl Display for StaticVar {
|
||||||
|
fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), fmt::Error> {
|
||||||
|
let num_indices = self.indices.len();
|
||||||
|
match num_indices {
|
||||||
|
0 => write!(fmt, "{}", self.name),
|
||||||
|
_ => write!(
|
||||||
|
fmt,
|
||||||
|
"{}{}",
|
||||||
|
self.name,
|
||||||
|
self.indices
|
||||||
|
.iter()
|
||||||
|
.map(|x: &u32| num_to_subscript(x.to_string()))
|
||||||
|
.join(",")
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<V: Var> Display for Poly<V> {
|
||||||
|
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
|
||||||
|
match self.mono.is_empty() {
|
||||||
|
true => write!(fmt, "∅"),
|
||||||
|
false => write!(
|
||||||
|
fmt,
|
||||||
|
"{}",
|
||||||
|
self.mono
|
||||||
|
.iter()
|
||||||
|
.map(|(m, c)| format!("{}{}", c, m))
|
||||||
|
.join(" + ")
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<V: Var> Display for Mono<V> {
|
||||||
|
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
|
||||||
|
write!(
|
||||||
|
fmt,
|
||||||
|
"{}",
|
||||||
|
self.term
|
||||||
|
.iter()
|
||||||
|
.map(|(t, p)| match p {
|
||||||
|
1 => format!("{t}"),
|
||||||
|
_ => format!("{t}{}", num_to_superscript(p.to_string())),
|
||||||
|
})
|
||||||
|
.join("")
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,2 +1,7 @@
|
|||||||
pub mod flat;
|
pub mod flat;
|
||||||
|
pub mod fmt;
|
||||||
|
pub mod ops;
|
||||||
pub mod var;
|
pub mod var;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests;
|
||||||
|
|||||||
115
src/poly/ops.rs
Normal file
115
src/poly/ops.rs
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use std::ops::{Add, BitXor, Mul, Sub};
|
||||||
|
|
||||||
|
use crate::poly::flat::{Mono, Poly};
|
||||||
|
use crate::poly::var::{StaticVar, Var};
|
||||||
|
|
||||||
|
impl BitXor<u32> for StaticVar {
|
||||||
|
type Output = Mono<StaticVar>;
|
||||||
|
|
||||||
|
fn bitxor(self, exp: u32) -> Self::Output {
|
||||||
|
[(self, exp)].into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> BitXor<u32> for &'a StaticVar {
|
||||||
|
type Output = Mono<StaticVar>;
|
||||||
|
|
||||||
|
fn bitxor(self, exp: u32) -> Self::Output {
|
||||||
|
[(self.clone(), exp)].into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<V: Var> Mul for Mono<V> {
|
||||||
|
type Output = Self;
|
||||||
|
|
||||||
|
fn mul(self, other: Mono<V>) -> Self::Output {
|
||||||
|
let mut a_term = self.term.into_iter().peekable();
|
||||||
|
let mut b_term = other.term.into_iter().peekable();
|
||||||
|
|
||||||
|
let mut result: Vec<(V, u32)> = Default::default();
|
||||||
|
|
||||||
|
loop {
|
||||||
|
match (a_term.peek(), b_term.peek()) {
|
||||||
|
(Some((a_var, _)), Some((b_var, _))) => {
|
||||||
|
if a_var < b_var {
|
||||||
|
result.push(a_term.next().unwrap());
|
||||||
|
} else if a_var > b_var {
|
||||||
|
result.push(b_term.next().unwrap());
|
||||||
|
} else {
|
||||||
|
let (var, a_exp) = a_term.next().unwrap();
|
||||||
|
let (_, b_exp) = b_term.next().unwrap();
|
||||||
|
result.push((var, a_exp + b_exp));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(Some(a), None) => {
|
||||||
|
result.push(a.clone());
|
||||||
|
a_term.next();
|
||||||
|
}
|
||||||
|
(None, Some(b)) => {
|
||||||
|
result.push(b.clone());
|
||||||
|
b_term.next();
|
||||||
|
}
|
||||||
|
(None, None) => {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Mono { term: result }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<V: Var> Mul<Mono<V>> for i32 {
|
||||||
|
type Output = Poly<V>;
|
||||||
|
|
||||||
|
fn mul(self, mono: Mono<V>) -> Self::Output {
|
||||||
|
let mut poly: HashMap<Mono<V>, i32> = Default::default();
|
||||||
|
|
||||||
|
poly.insert(mono, self);
|
||||||
|
|
||||||
|
Poly { mono: poly }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<V: Var> Mul for Poly<V> {
|
||||||
|
type Output = Poly<V>;
|
||||||
|
|
||||||
|
fn mul(self, other: Poly<V>) -> Self::Output {
|
||||||
|
let mut result = Poly::default();
|
||||||
|
for (m1, c1) in &self.mono {
|
||||||
|
for (m2, c2) in &other.mono {
|
||||||
|
let entry = result.mono.entry(m1.clone() * m2.clone()).or_insert(0i32);
|
||||||
|
*entry += c1 * c2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result.mono.retain(|_, &mut c| c != 0);
|
||||||
|
result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<V: Var> Add for Poly<V> {
|
||||||
|
type Output = Poly<V>;
|
||||||
|
|
||||||
|
fn add(mut self, other: Poly<V>) -> Self::Output {
|
||||||
|
for (mono, coeff) in other.mono {
|
||||||
|
let entry = self.mono.entry(mono).or_insert(0);
|
||||||
|
*entry += coeff;
|
||||||
|
}
|
||||||
|
self.mono.retain(|_, &mut coeff| coeff != 0);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<V: Var> Sub for Poly<V> {
|
||||||
|
type Output = Poly<V>;
|
||||||
|
|
||||||
|
fn sub(mut self, other: Poly<V>) -> Self::Output {
|
||||||
|
for (mono, coeff) in other.mono {
|
||||||
|
let entry = self.mono.entry(mono).or_insert(0);
|
||||||
|
*entry -= coeff;
|
||||||
|
}
|
||||||
|
self.mono.retain(|_, &mut coeff| coeff != 0);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
101
src/poly/tests.rs
Normal file
101
src/poly/tests.rs
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
use super::flat::{Mono, Poly};
|
||||||
|
use super::var::StaticVar;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_mono_contains() {
|
||||||
|
let a: Mono<StaticVar> = [("x", 2), ("y", 1)].into();
|
||||||
|
|
||||||
|
// Lower exponent of same variable is contained
|
||||||
|
assert!(a.contains(&Mono::from([("x", 1)])));
|
||||||
|
|
||||||
|
// Higher exponent of same variable is not contained
|
||||||
|
assert!(!a.contains(&Mono::from([("x", 3)])));
|
||||||
|
|
||||||
|
// Identical monomial is contained
|
||||||
|
assert!(a.contains(&Mono::from([("x", 2), ("y", 1)])));
|
||||||
|
|
||||||
|
// Variable absent from self is not contained
|
||||||
|
assert!(!a.contains(&Mono::from([("x", 2), ("z", 1)])));
|
||||||
|
|
||||||
|
// Subset of variables with lower exponents is contained
|
||||||
|
assert!(a.contains(&Mono::from([("x", 1), ("y", 1)])));
|
||||||
|
|
||||||
|
// Single variable with exact exponent is contained
|
||||||
|
assert!(a.contains(&Mono::from([("x", 2)])));
|
||||||
|
|
||||||
|
// Insufficient exponent in self means not contained
|
||||||
|
assert!(!Mono::<StaticVar>::from([("x", 1)]).contains(&Mono::from([("x", 2)])));
|
||||||
|
|
||||||
|
// Missing variable in self means not contained
|
||||||
|
assert!(!Mono::<StaticVar>::from([("x", 1), ("y", 1)]).contains(&Mono::from([("x", 2)])));
|
||||||
|
assert!(!Mono::<StaticVar>::from([("x", 1)]).contains(&Mono::from([("x", 1), ("y", 1)])));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_mono_mul() {
|
||||||
|
// Same variable: exponents add
|
||||||
|
let a: Mono<StaticVar> = [("x", 2)].into();
|
||||||
|
let b: Mono<StaticVar> = [("x", 3)].into();
|
||||||
|
assert_eq!(a * b, Mono::from([("x", 5)]));
|
||||||
|
|
||||||
|
// Disjoint variables: both appear in result
|
||||||
|
let a: Mono<StaticVar> = [("x", 2)].into();
|
||||||
|
let b: Mono<StaticVar> = [("y", 3)].into();
|
||||||
|
assert_eq!(a * b, Mono::from([("x", 2), ("y", 3)]));
|
||||||
|
|
||||||
|
// Mixed: shared and disjoint variables
|
||||||
|
let a: Mono<StaticVar> = [("x", 1), ("y", 2)].into();
|
||||||
|
let b: Mono<StaticVar> = [("y", 1), ("z", 3)].into();
|
||||||
|
assert_eq!(a * b, Mono::from([("x", 1), ("y", 3), ("z", 3)]));
|
||||||
|
|
||||||
|
// Commutativity
|
||||||
|
let a: Mono<StaticVar> = [("x", 2), ("z", 1)].into();
|
||||||
|
let b: Mono<StaticVar> = [("y", 3)].into();
|
||||||
|
assert_eq!(a.clone() * b.clone(), b * a);
|
||||||
|
|
||||||
|
// Multiply by constant monomial (empty term vec = 1)
|
||||||
|
let a: Mono<StaticVar> = [("x", 4)].into();
|
||||||
|
let one: Mono<StaticVar> = Mono { term: vec![] };
|
||||||
|
assert_eq!(a.clone() * one, a);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_poly_add() {
|
||||||
|
// Distinct monomials are collected as separate terms
|
||||||
|
let a: Poly<StaticVar> = [(1, [("x", 2)]), (2, [("y", 1)])].into();
|
||||||
|
let b: Poly<StaticVar> = [(3, [("z", 1)])].into();
|
||||||
|
let expected: Poly<StaticVar> = [(1, [("x", 2)]), (2, [("y", 1)]), (3, [("z", 1)])].into();
|
||||||
|
assert_eq!(a + b, expected);
|
||||||
|
|
||||||
|
// Coefficients of matching monomials are summed
|
||||||
|
let a: Poly<StaticVar> = [(2, [("x", 1)])].into();
|
||||||
|
let b: Poly<StaticVar> = [(3, [("x", 1)])].into();
|
||||||
|
let expected: Poly<StaticVar> = [(5, [("x", 1)])].into();
|
||||||
|
assert_eq!(a + b, expected);
|
||||||
|
|
||||||
|
// Terms that cancel sum to zero are dropped
|
||||||
|
let a: Poly<StaticVar> = [(1, [("x", 1)])].into();
|
||||||
|
let b: Poly<StaticVar> = [(-1, [("x", 1)])].into();
|
||||||
|
let expected: Poly<StaticVar> = Poly::default();
|
||||||
|
assert_eq!(a + b, expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_poly_sub() {
|
||||||
|
// Distinct monomials are collected as separate terms with negated rhs coefficients
|
||||||
|
let a: Poly<StaticVar> = [(3, [("x", 2)])].into();
|
||||||
|
let b: Poly<StaticVar> = [(1, [("y", 1)])].into();
|
||||||
|
let expected: Poly<StaticVar> = [(3, [("x", 2)]), (-1, [("y", 1)])].into();
|
||||||
|
assert_eq!(a - b, expected);
|
||||||
|
|
||||||
|
// Coefficients of matching monomials are subtracted
|
||||||
|
let a: Poly<StaticVar> = [(5, [("x", 1)])].into();
|
||||||
|
let b: Poly<StaticVar> = [(3, [("x", 1)])].into();
|
||||||
|
let expected: Poly<StaticVar> = [(2, [("x", 1)])].into();
|
||||||
|
assert_eq!(a - b, expected);
|
||||||
|
|
||||||
|
// Subtracting equal polynomials yields zero
|
||||||
|
let a: Poly<StaticVar> = [(4, [("x", 2)]), (1, [("y", 1)])].into();
|
||||||
|
let b: Poly<StaticVar> = [(4, [("x", 2)]), (1, [("y", 1)])].into();
|
||||||
|
assert_eq!(a - b, Poly::default());
|
||||||
|
}
|
||||||
@@ -1,27 +1,14 @@
|
|||||||
use itertools::Itertools;
|
use std::fmt::{Debug, Display};
|
||||||
use std::fmt::{self, Debug, Display, Formatter};
|
|
||||||
use std::hash::Hash;
|
use std::hash::Hash;
|
||||||
|
|
||||||
use crate::fmt::num_to_subscript;
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
|
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
|
||||||
pub struct StaticVar {
|
pub struct StaticVar {
|
||||||
name: &'static str,
|
pub(crate) name: &'static str,
|
||||||
indices: Vec<u32>,
|
pub(crate) indices: Vec<u32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait Var: PartialEq + Eq + PartialOrd + Ord + Clone + Hash + Debug + Display {}
|
pub trait Var: PartialEq + Eq + PartialOrd + Ord + Clone + Hash + Debug + Display {}
|
||||||
|
|
||||||
impl Display for StaticVar {
|
|
||||||
fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), fmt::Error> {
|
|
||||||
let num_indices = self.indices.len();
|
|
||||||
match num_indices {
|
|
||||||
0 => write!(fmt, "{}", self.name),
|
|
||||||
_ => write!(fmt, "{}{}", self.name, self.indices.iter().map(|x| num_to_subscript(x.to_string())).join(",")),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Var for StaticVar {}
|
impl Var for StaticVar {}
|
||||||
|
|
||||||
impl From<&'static str> for StaticVar {
|
impl From<&'static str> for StaticVar {
|
||||||
|
|||||||
Reference in New Issue
Block a user