add iterator in children nodes
This commit is contained in:
@@ -15,6 +15,17 @@ pub enum Node<V: Var> {
|
||||
Prod(NodeId, NodeId),
|
||||
}
|
||||
|
||||
impl<V: Var> Node<V> {
|
||||
pub fn children(&self) -> impl Iterator<Item = NodeId> {
|
||||
match self {
|
||||
Node::Leaf(_) => [None, None],
|
||||
Node::Sum(l, r) | Node::Prod(l, r) => [Some(*l), Some(*r)],
|
||||
}
|
||||
.into_iter()
|
||||
.flatten()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Circuit<V: Var> {
|
||||
nodes: SlotMap<NodeId, Node<V>>,
|
||||
intern: HashMap<Node<V>, NodeId>,
|
||||
@@ -41,6 +52,18 @@ impl<V: Var> Circuit<V> {
|
||||
self.nodes.get(id)
|
||||
}
|
||||
|
||||
pub fn add(&mut self, left: NodeId, right: NodeId) -> NodeId {
|
||||
self.node(Node::Sum(left, right))
|
||||
}
|
||||
|
||||
pub fn mul(&mut self, left: NodeId, right: NodeId) -> NodeId {
|
||||
self.node(Node::Prod(left, right))
|
||||
}
|
||||
|
||||
pub fn children(&self, id: NodeId) -> impl Iterator<Item = NodeId> + '_ {
|
||||
self.nodes.get(id).into_iter().flat_map(Node::children)
|
||||
}
|
||||
|
||||
pub fn remove(&mut self, id: NodeId) {
|
||||
if let Some(node) = self.nodes.remove(id) {
|
||||
self.intern.remove(&node);
|
||||
@@ -55,6 +78,7 @@ pub struct CircuitNode<V: Var> {
|
||||
|
||||
pub trait CircuitExt<V: Var> {
|
||||
fn leaf(&self, v: V) -> CircuitNode<V>;
|
||||
fn get_node(&self, id: NodeId) -> Option<CircuitNode<V>>;
|
||||
}
|
||||
|
||||
impl<V: Var> CircuitExt<V> for Rc<RefCell<Circuit<V>>> {
|
||||
@@ -62,6 +86,11 @@ impl<V: Var> CircuitExt<V> for Rc<RefCell<Circuit<V>>> {
|
||||
let id = self.borrow_mut().node(Node::Leaf(v));
|
||||
CircuitNode { id, circuit: self.clone() }
|
||||
}
|
||||
|
||||
fn get_node(&self, id: NodeId) -> Option<CircuitNode<V>> {
|
||||
self.borrow().get(id)?;
|
||||
Some(CircuitNode { id, circuit: self.clone() })
|
||||
}
|
||||
}
|
||||
|
||||
impl<V: Var> Add for CircuitNode<V> {
|
||||
|
||||
Reference in New Issue
Block a user