From 8c646fd9209b8a6f35d0d6181e0de2b3e3cb788b Mon Sep 17 00:00:00 2001 From: asteri Date: Wed, 22 Apr 2026 16:12:56 +0200 Subject: [PATCH] refactor poly --- src/poly/flat.rs | 151 +---------------------------------------------- src/poly/fmt.rs | 59 ++++++++++++++++++ src/poly/mod.rs | 3 + src/poly/ops.rs | 115 ++++++++++++++++++++++++++++++++++++ src/poly/var.rs | 27 ++------- 5 files changed, 184 insertions(+), 171 deletions(-) create mode 100644 src/poly/fmt.rs create mode 100644 src/poly/ops.rs diff --git a/src/poly/flat.rs b/src/poly/flat.rs index d3fce96..818efb9 100644 --- a/src/poly/flat.rs +++ b/src/poly/flat.rs @@ -1,15 +1,10 @@ -use itertools::Itertools; -use std::fmt::{self, Display}; - -use std::ops::{Add, BitXor, Mul, Sub}; - -use super::var::{StaticVar, Var}; -use crate::fmt::num_to_superscript; use std::collections::HashMap; +use super::var::Var; + #[derive(Clone, Debug, PartialEq)] pub struct Poly { - mono: HashMap, i32>, + pub(crate) mono: HashMap, i32>, } impl Default for Poly { @@ -20,21 +15,6 @@ impl Default for Poly { } } -impl Display for Poly { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - match self.mono.is_empty() { - true => write!(fmt, "∅"), - false => write!( - fmt, - "{}", - self.mono - .iter() - .map(|(m, c)| format!("{}{}", c, m)) - .join(" + ") - ), - } - } -} impl From for Poly where @@ -113,128 +93,3 @@ impl> FromIterator<(U, u32)> for Mono { } } -impl Display for Mono { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - write!( - fmt, - "{}", - self.term - .iter() - .map(|(t, p)| match p { - 1 => format!("{t}"), - _ => format!("{t}{}", num_to_superscript(p.to_string())), - }) - .join("") - ) - } -} - -impl BitXor for StaticVar { - type Output = Mono; - - fn bitxor(self, exp: u32) -> Self::Output { - [(self, exp)].into() - } -} - -impl<'a> BitXor for &'a StaticVar { - type Output = Mono; - - fn bitxor(self, exp: u32) -> Self::Output { - [(self.clone(), exp)].into() - } -} - -impl Mul for Mono { - type Output = Self; - - fn mul(self, other: Mono) -> Self::Output { - let mut a_term = self.term.into_iter().peekable(); - let mut b_term = other.term.into_iter().peekable(); - - let mut result: Vec<(V, u32)> = Default::default(); - - loop { - match (a_term.peek(), b_term.peek()) { - (Some((a_var, _)), Some((b_var, _))) => { - if a_var < b_var { - result.push(a_term.next().unwrap()); - } else if a_var > b_var { - result.push(b_term.next().unwrap()); - } else { - let (var, a_exp) = a_term.next().unwrap(); - let (_, b_exp) = b_term.next().unwrap(); - result.push((var, a_exp + b_exp)); - } - } - (Some(a), None) => { - result.push(a.clone()); - a_term.next(); - } - (None, Some(b)) => { - result.push(b.clone()); - b_term.next(); - } - (None, None) => { - break; - } - } - } - - Mono { term: result } - } -} - -impl Mul> for i32 { - type Output = Poly; - - fn mul(self, mono: Mono) -> Self::Output { - let mut poly: HashMap, i32> = Default::default(); - - poly.insert(mono, self); - - Poly { mono: poly } - } -} - -impl Mul for Poly { - type Output = Poly; - - fn mul(self, other: Poly) -> Self::Output { - let mut result = Poly::default(); - for (m1, c1) in &self.mono { - for (m2, c2) in &other.mono { - let entry = result.mono.entry(m1.clone() * m2.clone()).or_insert(0); - *entry += c1 * c2; - } - } - result.mono.retain(|_, &mut c| c != 0); - result - } -} - -impl Add for Poly { - type Output = Poly; - - fn add(mut self, other: Poly) -> Self::Output { - for (mono, coeff) in other.mono { - let entry = self.mono.entry(mono).or_insert(0); - *entry += coeff; - } - self.mono.retain(|_, &mut coeff| coeff != 0); - self - } -} - -impl Sub for Poly { - type Output = Poly; - - fn sub(mut self, other: Poly) -> Self::Output { - for (mono, coeff) in other.mono { - let entry = self.mono.entry(mono).or_insert(0); - *entry -= coeff; - } - self.mono.retain(|_, &mut coeff| coeff != 0); - self - } -} diff --git a/src/poly/fmt.rs b/src/poly/fmt.rs new file mode 100644 index 0000000..e16742f --- /dev/null +++ b/src/poly/fmt.rs @@ -0,0 +1,59 @@ +use std::fmt::{self, Display, Formatter}; + +use itertools::Itertools; + +use crate::fmt::{num_to_subscript, num_to_superscript}; +use crate::poly::flat::{Mono, Poly}; +use crate::poly::var::{StaticVar, Var}; + +impl Display for StaticVar { + fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), fmt::Error> { + let num_indices = self.indices.len(); + match num_indices { + 0 => write!(fmt, "{}", self.name), + _ => write!( + fmt, + "{}{}", + self.name, + self.indices + .iter() + .map(|x: &u32| num_to_subscript(x.to_string())) + .join(",") + ), + } + } +} + +impl Display for Poly { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + match self.mono.is_empty() { + true => write!(fmt, "∅"), + false => write!( + fmt, + "{}", + self.mono + .iter() + .map(|(m, c)| format!("{}{}", c, m)) + .join(" + ") + ), + } + } +} + + +impl Display for Mono { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + write!( + fmt, + "{}", + self.term + .iter() + .map(|(t, p)| match p { + 1 => format!("{t}"), + _ => format!("{t}{}", num_to_superscript(p.to_string())), + }) + .join("") + ) + } +} + diff --git a/src/poly/mod.rs b/src/poly/mod.rs index 1105870..8b38f99 100644 --- a/src/poly/mod.rs +++ b/src/poly/mod.rs @@ -1,5 +1,8 @@ pub mod flat; pub mod var; +pub mod ops; +pub mod fmt; #[cfg(test)] mod tests; + diff --git a/src/poly/ops.rs b/src/poly/ops.rs new file mode 100644 index 0000000..2e1ec74 --- /dev/null +++ b/src/poly/ops.rs @@ -0,0 +1,115 @@ +use std::collections::HashMap; +use std::ops::{Add, BitXor, Mul, Sub}; + +use crate::poly::flat::{Mono, Poly}; +use crate::poly::var::{StaticVar, Var}; + +impl BitXor for StaticVar { + type Output = Mono; + + fn bitxor(self, exp: u32) -> Self::Output { + [(self, exp)].into() + } +} + +impl<'a> BitXor for &'a StaticVar { + type Output = Mono; + + fn bitxor(self, exp: u32) -> Self::Output { + [(self.clone(), exp)].into() + } +} + +impl Mul for Mono { + type Output = Self; + + fn mul(self, other: Mono) -> Self::Output { + let mut a_term = self.term.into_iter().peekable(); + let mut b_term = other.term.into_iter().peekable(); + + let mut result: Vec<(V, u32)> = Default::default(); + + loop { + match (a_term.peek(), b_term.peek()) { + (Some((a_var, _)), Some((b_var, _))) => { + if a_var < b_var { + result.push(a_term.next().unwrap()); + } else if a_var > b_var { + result.push(b_term.next().unwrap()); + } else { + let (var, a_exp) = a_term.next().unwrap(); + let (_, b_exp) = b_term.next().unwrap(); + result.push((var, a_exp + b_exp)); + } + } + (Some(a), None) => { + result.push(a.clone()); + a_term.next(); + } + (None, Some(b)) => { + result.push(b.clone()); + b_term.next(); + } + (None, None) => { + break; + } + } + } + + Mono { term: result } + } +} + +impl Mul> for i32 { + type Output = Poly; + + fn mul(self, mono: Mono) -> Self::Output { + let mut poly: HashMap, i32> = Default::default(); + + poly.insert(mono, self); + + Poly { mono: poly } + } +} + +impl Mul for Poly { + type Output = Poly; + + fn mul(self, other: Poly) -> Self::Output { + let mut result = Poly::default(); + for (m1, c1) in &self.mono { + for (m2, c2) in &other.mono { + let entry = result.mono.entry(m1.clone() * m2.clone()).or_insert(0i32); + *entry += c1 * c2; + } + } + result.mono.retain(|_, &mut c| c != 0); + result + } +} + +impl Add for Poly { + type Output = Poly; + + fn add(mut self, other: Poly) -> Self::Output { + for (mono, coeff) in other.mono { + let entry = self.mono.entry(mono).or_insert(0); + *entry += coeff; + } + self.mono.retain(|_, &mut coeff| coeff != 0); + self + } +} + +impl Sub for Poly { + type Output = Poly; + + fn sub(mut self, other: Poly) -> Self::Output { + for (mono, coeff) in other.mono { + let entry = self.mono.entry(mono).or_insert(0); + *entry -= coeff; + } + self.mono.retain(|_, &mut coeff| coeff != 0); + self + } +} diff --git a/src/poly/var.rs b/src/poly/var.rs index f345fb4..b2daf72 100644 --- a/src/poly/var.rs +++ b/src/poly/var.rs @@ -1,34 +1,14 @@ -use itertools::Itertools; -use std::fmt::{self, Debug, Display, Formatter}; +use std::fmt::{Debug, Display}; use std::hash::Hash; -use crate::fmt::num_to_subscript; - #[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct StaticVar { - name: &'static str, - indices: Vec, + pub(crate) name: &'static str, + pub(crate) indices: Vec, } pub trait Var: PartialEq + Eq + PartialOrd + Ord + Clone + Hash + Debug + Display {} -impl Display for StaticVar { - fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), fmt::Error> { - let num_indices = self.indices.len(); - match num_indices { - 0 => write!(fmt, "{}", self.name), - _ => write!( - fmt, - "{}{}", - self.name, - self.indices - .iter() - .map(|x| num_to_subscript(x.to_string())) - .join(",") - ), - } - } -} impl Var for StaticVar {} @@ -71,3 +51,4 @@ macro_rules! var { ::circuit_cas::poly::var::StaticVar::from(($name, $idx1, $idx2)) }; } +