add iterator in children nodes

This commit is contained in:
2026-04-22 11:59:40 +02:00
parent 5ef4893f03
commit 99fee298c7

View File

@@ -15,6 +15,17 @@ pub enum Node<V: Var> {
Prod(NodeId, NodeId), Prod(NodeId, NodeId),
} }
impl<V: Var> Node<V> {
pub fn children(&self) -> impl Iterator<Item = NodeId> {
match self {
Node::Leaf(_) => [None, None],
Node::Sum(l, r) | Node::Prod(l, r) => [Some(*l), Some(*r)],
}
.into_iter()
.flatten()
}
}
pub struct Circuit<V: Var> { pub struct Circuit<V: Var> {
nodes: SlotMap<NodeId, Node<V>>, nodes: SlotMap<NodeId, Node<V>>,
intern: HashMap<Node<V>, NodeId>, intern: HashMap<Node<V>, NodeId>,
@@ -41,6 +52,18 @@ impl<V: Var> Circuit<V> {
self.nodes.get(id) self.nodes.get(id)
} }
pub fn add(&mut self, left: NodeId, right: NodeId) -> NodeId {
self.node(Node::Sum(left, right))
}
pub fn mul(&mut self, left: NodeId, right: NodeId) -> NodeId {
self.node(Node::Prod(left, right))
}
pub fn children(&self, id: NodeId) -> impl Iterator<Item = NodeId> + '_ {
self.nodes.get(id).into_iter().flat_map(Node::children)
}
pub fn remove(&mut self, id: NodeId) { pub fn remove(&mut self, id: NodeId) {
if let Some(node) = self.nodes.remove(id) { if let Some(node) = self.nodes.remove(id) {
self.intern.remove(&node); self.intern.remove(&node);
@@ -55,6 +78,7 @@ pub struct CircuitNode<V: Var> {
pub trait CircuitExt<V: Var> { pub trait CircuitExt<V: Var> {
fn leaf(&self, v: V) -> CircuitNode<V>; fn leaf(&self, v: V) -> CircuitNode<V>;
fn get_node(&self, id: NodeId) -> Option<CircuitNode<V>>;
} }
impl<V: Var> CircuitExt<V> for Rc<RefCell<Circuit<V>>> { impl<V: Var> CircuitExt<V> for Rc<RefCell<Circuit<V>>> {
@@ -62,6 +86,11 @@ impl<V: Var> CircuitExt<V> for Rc<RefCell<Circuit<V>>> {
let id = self.borrow_mut().node(Node::Leaf(v)); let id = self.borrow_mut().node(Node::Leaf(v));
CircuitNode { id, circuit: self.clone() } CircuitNode { id, circuit: self.clone() }
} }
fn get_node(&self, id: NodeId) -> Option<CircuitNode<V>> {
self.borrow().get(id)?;
Some(CircuitNode { id, circuit: self.clone() })
}
} }
impl<V: Var> Add for CircuitNode<V> { impl<V: Var> Add for CircuitNode<V> {