diff --git a/src/circuit/quotient.rs b/src/circuit/quotient.rs index c676cb2..1328ee5 100644 --- a/src/circuit/quotient.rs +++ b/src/circuit/quotient.rs @@ -1,4 +1,4 @@ -use super::dag::{Circuit, NodeId}; +use super::dag::{Circuit, Node, NodeId}; use crate::poly::{flat::Poly, var::Var}; use itertools::Itertools; @@ -35,10 +35,10 @@ impl Quotient { self.circuit.node(n) } - pub fn add(left: NodeId, right: NodeId) { + pub fn add(&mut self, left: NodeId, right: NodeId) -> NodeId { self.circuit.add(left, right) } - pub fn mul(left: NodeId, right: NodeId) { + pub fn mul(&mut self, left: NodeId, right: NodeId) -> NodeId { self.circuit.mul(left, right) } } diff --git a/src/poly/flat.rs b/src/poly/flat.rs index ceedea7..ceb036e 100644 --- a/src/poly/flat.rs +++ b/src/poly/flat.rs @@ -1,7 +1,46 @@ +use std::cmp::Ordering; use std::collections::HashMap; use super::var::Var; +pub fn lex_cmp(a: &Mono, b: &Mono) -> Ordering { + let mut a_it = a.term.iter().peekable(); + let mut b_it = b.term.iter().peekable(); + + loop { + match (a_it.peek(), b_it.peek()) { + (None, None) => return Ordering::Equal, + (Some((_, a_exp)), None) => { + return if *a_exp > 0 { Ordering::Greater } else { Ordering::Equal }; + } + (None, Some((_, b_exp))) => { + return if *b_exp > 0 { Ordering::Less } else { Ordering::Equal }; + } + (Some((a_var, a_exp)), Some((b_var, b_exp))) => { + if a_var < b_var { + if *a_exp > 0 { + return Ordering::Greater; + } + a_it.next(); + } else if a_var > b_var { + if *b_exp > 0 { + return Ordering::Less; + } + b_it.next(); + } else { + match a_exp.cmp(b_exp) { + Ordering::Equal => { + a_it.next(); + b_it.next(); + } + ord => return ord, + } + } + } + } + } +} + #[derive(Clone, Debug, PartialEq)] pub struct Poly { pub(crate) mono: HashMap, i32>, @@ -15,6 +54,20 @@ impl Default for Poly { } } +impl Poly { + pub fn is_zero(&self) -> bool { + self.mono.is_empty() + } + + /// Returns (leading monomial, leading coefficient) under lex order, or None if zero. + pub fn leading_term_lex(&self) -> Option<(Mono, i32)> { + self.mono + .iter() + .max_by(|(m1, _), (m2, _)| lex_cmp(m1, m2)) + .map(|(m, &c)| (m.clone(), c)) + } +} + impl From for Poly where Poly: FromIterator<::Item>, @@ -62,6 +115,36 @@ impl Mono { true } + + /// Divides self by other. Assumes `self.contains(other)`. + pub fn div(self, other: &Mono) -> Mono { + let mut self_it = self.term.into_iter().peekable(); + let mut other_it = other.term.iter().peekable(); + let mut result: Vec<(V, u32)> = vec![]; + + loop { + match (self_it.peek(), other_it.peek()) { + (None, None) => break, + (Some(_), None) => result.push(self_it.next().unwrap()), + (None, Some(_)) => unreachable!("divisor not contained in dividend"), + (Some((s_var, _)), Some((o_var, _))) => { + if s_var < o_var { + result.push(self_it.next().unwrap()); + } else if s_var > o_var { + unreachable!("divisor not contained in dividend"); + } else { + let (var, s_exp) = self_it.next().unwrap(); + let (_, o_exp) = other_it.next().unwrap(); + if s_exp > *o_exp { + result.push((var, s_exp - o_exp)); + } + } + } + } + } + + Mono { term: result } + } } impl From for Mono diff --git a/src/poly/ops.rs b/src/poly/ops.rs index 2e1ec74..3273c4b 100644 --- a/src/poly/ops.rs +++ b/src/poly/ops.rs @@ -72,6 +72,20 @@ impl Mul> for i32 { } } +impl Mul> for i32 { + type Output = Poly; + + fn mul(self, mut poly: Poly) -> Self::Output { + if self == 0 { + return Poly::default(); + } + for coeff in poly.mono.values_mut() { + *coeff *= self; + } + poly + } +} + impl Mul for Poly { type Output = Poly; @@ -113,3 +127,49 @@ impl Sub for Poly { self } } + +impl Poly { + /// Pseudo-division with remainder. + /// + /// Returns `(d, q, r)` satisfying `lc(divisor)^d * self = q * divisor + r`, + /// where no term of `r` has its monomial divisible by `LM(divisor)` under lex order. + /// + /// Panics if `divisor` is zero. + pub fn div_rem(self, divisor: &Poly) -> (u32, Poly, Poly) { + let (lt_g_mono, lt_g_coeff) = divisor + .leading_term_lex() + .expect("divisor must be nonzero"); + + let mut p = self; + let mut q = Poly::default(); + let mut r = Poly::default(); + let mut d = 0u32; + + while !p.is_zero() { + let (lt_p_mono, lt_p_coeff) = p.leading_term_lex().unwrap(); + + if lt_p_mono.contains(<_g_mono) { + let t_mono = lt_p_mono.div(<_g_mono); + if lt_p_coeff % lt_g_coeff == 0 { + // Exact division: no need to multiply through by lc(g) + let t_poly: Poly = (lt_p_coeff / lt_g_coeff) * t_mono; + p = p - t_poly.clone() * divisor.clone(); + q = q + t_poly; + } else { + // Pseudo-division: multiply through by lc(g) to stay in ℤ + let t_poly: Poly = lt_p_coeff * t_mono; + p = lt_g_coeff * p - t_poly.clone() * divisor.clone(); + q = lt_g_coeff * q + t_poly; + r = lt_g_coeff * r; + d += 1; + } + } else { + let lt_poly: Poly = lt_p_coeff * lt_p_mono; + p = p - lt_poly.clone(); + r = r + lt_poly; + } + } + + (d, q, r) + } +} diff --git a/src/poly/tests.rs b/src/poly/tests.rs index 870ebde..a31b5a8 100644 --- a/src/poly/tests.rs +++ b/src/poly/tests.rs @@ -1,4 +1,4 @@ -use super::flat::{Mono, Poly}; +use super::flat::{lex_cmp, Mono, Poly}; use super::var::StaticVar; #[test] @@ -99,3 +99,130 @@ fn test_poly_sub() { let b: Poly = [(4, [("x", 2)]), (1, [("y", 1)])].into(); assert_eq!(a - b, Poly::default()); } + +#[test] +fn test_lex_cmp() { + use std::cmp::Ordering; + + let x2: Mono = [("x", 2)].into(); + let xy: Mono = [("x", 1), ("y", 1)].into(); + let y2: Mono = [("y", 2)].into(); + let x: Mono = [("x", 1)].into(); + let one: Mono = Mono { term: vec![] }; + + // x² > xy (x exponent 2 vs 1) + assert_eq!(lex_cmp(&x2, &xy), Ordering::Greater); + // x > y² (x has higher priority) + assert_eq!(lex_cmp(&x, &y2), Ordering::Greater); + // xy > y² (x present in xy but not y²) + assert_eq!(lex_cmp(&xy, &y2), Ordering::Greater); + // 1 < x + assert_eq!(lex_cmp(&one, &x), Ordering::Less); + // reflexive + assert_eq!(lex_cmp(&x2, &x2), Ordering::Equal); +} + +#[test] +fn test_mono_div() { + // x² / x = x + let a: Mono = [("x", 2)].into(); + let b: Mono = [("x", 1)].into(); + assert_eq!(a.div(&b), Mono::from([("x", 1)])); + + // x²y / xy = x + let a: Mono = [("x", 2), ("y", 1)].into(); + let b: Mono = [("x", 1), ("y", 1)].into(); + assert_eq!(a.div(&b), Mono::from([("x", 1)])); + + // x²y / y = x² + let a: Mono = [("x", 2), ("y", 1)].into(); + let b: Mono = [("y", 1)].into(); + assert_eq!(a.div(&b), Mono::from([("x", 2)])); + + // x / x = 1 + let a: Mono = [("x", 1)].into(); + let b: Mono = [("x", 1)].into(); + assert_eq!(a.div(&b), Mono { term: vec![] }); +} + +fn make_const_poly(c: i32) -> Poly { + Poly { mono: [(Mono { term: vec![] }, c)].into_iter().collect() } +} + +fn verify_div_rem(f: Poly, g: &Poly, d: u32, q: Poly, r: Poly) { + // lc(g)^d * f == q * g + r + let (_, lc_g) = g.leading_term_lex().unwrap(); + let lhs = lc_g.pow(d) * f; + let rhs = q * g.clone() + r; + assert_eq!(lhs, rhs); +} + +#[test] +fn test_div_rem() { + // x³ / x² = x, r=0, d=0 + let f: Poly = [(1, [("x", 3)])].into(); + let g: Poly = [(1, [("x", 2)])].into(); + let expected_q: Poly = [(1, [("x", 1)])].into(); + let (d, q, r) = f.clone().div_rem(&g); + assert_eq!(q, expected_q); + assert!(r.is_zero()); + verify_div_rem(f, &g, d, q, r); + + // (x³ + x²y) / x² = x + y, r=0 + // f = x²(x + y), g = x² + let f: Poly = [ + (1i32, Mono::from([("x", 3u32)])), + (1i32, Mono::from([("x", 2u32), ("y", 1u32)])), + ].into_iter().collect(); + let g: Poly = [(1, [("x", 2)])].into(); + let expected_q: Poly = [(1, [("x", 1)]), (1, [("y", 1)])].into(); + let (d, q, r) = f.clone().div_rem(&g); + assert_eq!(q, expected_q); + assert!(r.is_zero()); + verify_div_rem(f, &g, d, q, r); + + // (x³ + y) / x² = x, r=y + // LT(x³) divisible by x², LT(y) is not + let f: Poly = [(1, [("x", 3)]), (1, [("y", 1)])].into(); + let g: Poly = [(1, [("x", 2)])].into(); + let expected_q: Poly = [(1, [("x", 1)])].into(); + let expected_r: Poly = [(1, [("y", 1)])].into(); + let (d, q, r) = f.clone().div_rem(&g); + assert_eq!(q, expected_q); + assert_eq!(r, expected_r); + verify_div_rem(f, &g, d, q, r); + + // 3x² / (2x): 2 ∤ 3, needs pseudo-division + // 2¹ · 3x² = 3x · 2x => d=1, q=3x, r=0 + let f: Poly = [(3, [("x", 2)])].into(); + let g: Poly = [(2, [("x", 1)])].into(); + let expected_q: Poly = [(3, [("x", 1)])].into(); + let (d, q, r) = f.clone().div_rem(&g); + assert_eq!(d, 1); + assert_eq!(q, expected_q); + assert!(r.is_zero()); + verify_div_rem(f, &g, d, q, r); + + // 6xy / 2 = 3xy, r=0, d=0 + let f: Poly = [(6, [("x", 1), ("y", 1)])].into(); + let g = make_const_poly(2); + let expected_q: Poly = [(3, [("x", 1), ("y", 1)])].into(); + let (d, q, r) = f.clone().div_rem(&g); + assert_eq!(d, 0); + assert_eq!(q, expected_q); + assert!(r.is_zero()); + verify_div_rem(f, &g, d, q, r); + + // (x² + xy) / (x + y) = x, r=0 + // f = x(x + y), g = x + y + let f: Poly = [ + (1i32, Mono::from([("x", 2u32)])), + (1i32, Mono::from([("x", 1u32), ("y", 1u32)])), + ].into_iter().collect(); + let g: Poly = [(1, [("x", 1)]), (1, [("y", 1)])].into(); + let expected_q: Poly = [(1, [("x", 1)])].into(); + let (d, q, r) = f.clone().div_rem(&g); + assert_eq!(q, expected_q); + assert!(r.is_zero()); + verify_div_rem(f, &g, d, q, r); +}