fix issues in logic circuits by n28div · Pull Request #333 · april-tools/cirkit · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

fix issues in logic circuits #333

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cirkit/templates/logic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from .graph import ConjunctionNode as ConjunctionNode
from .graph import DisjunctionNode as DisjunctionNode
from .graph import LiteralNode as LiteralNode
from .graph import LogicCircuitNode as LogicCircuitNode
from .graph import LogicGraph as LogicGraph
from .graph import LogicalCircuit as LogicalCircuit
from .graph import LogicalCircuitNode as LogicalCircuitNode
from .graph import NegatedLiteralNode as NegatedLiteralNode
from .graph import TopNode as TopNode
from .sdd import SDD as SDD
167 changes: 93 additions & 74 deletions cirkit/templates/logic/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,23 @@
from cirkit.symbolic.parameters import Parameter, ParameterFactory, TensorParameter
from cirkit.templates.logic.utils import default_literal_input_factory
from cirkit.templates.utils import InputLayerFactory
from cirkit.utils.algorithms import RootedDiAcyclicGraph
from cirkit.utils.algorithms import RootedDiAcyclicGraph, graph_nodes_outgoings
from cirkit.utils.scope import Scope


# NOTE(review): rendered diff with +/- markers lost — presumably the first header is the
# removed (old) name and the second the added (renamed) one.
class LogicCircuitNode(ABC):
class LogicalCircuitNode(ABC):
"""The abstract base class for every node (gates, literals, constants) of a logical circuit."""


# NOTE(review): rendered diff — presumably old header first, renamed header second.
class TopNode(LogicCircuitNode):
class TopNode(LogicalCircuitNode):
"""The constant node representing logical True (top) in a logical circuit."""


# NOTE(review): rendered diff — presumably old header first, renamed header second.
class BottomNode(LogicCircuitNode):
class BottomNode(LogicalCircuitNode):
"""The constant node representing logical False (bottom) in a logical circuit."""


class LogicInputNode(LogicCircuitNode):
class LogicalInputNode(LogicalCircuitNode):
"""The abstract base class for input nodes in logic circuits."""

def __init__(self, literal: int) -> None:
Expand Down Expand Up @@ -56,85 +56,106 @@ def __repr__(self) -> str:
return f"{type(self).__name__}@0x{id(self):x}({self.literal})"


# NOTE(review): rendered diff — presumably old header first, renamed header second.
class LiteralNode(LogicInputNode):
class LiteralNode(LogicalInputNode):
"""A positive literal, i.e. a variable appearing un-negated, in the logical circuit."""


# NOTE(review): rendered diff — presumably old header first, renamed header second.
class NegatedLiteralNode(LogicInputNode):
class NegatedLiteralNode(LogicalInputNode):
"""A negated literal, i.e. the negation of a variable, in the logical circuit."""


# NOTE(review): rendered diff — presumably old header first, renamed header second.
class ConjunctionNode(LogicCircuitNode):
class ConjunctionNode(LogicalCircuitNode):
"""A conjunction (logical AND) of its input nodes in the logical circuit."""


# NOTE(review): rendered diff — presumably old header first, renamed header second.
class DisjunctionNode(LogicCircuitNode):
class DisjunctionNode(LogicalCircuitNode):
"""A disjunction (logical OR) of its input nodes in the logical circuit."""


class LogicGraph(RootedDiAcyclicGraph[LogicCircuitNode]):
class LogicalCircuit(RootedDiAcyclicGraph[LogicalCircuitNode]):
def __init__(
    self,
    nodes: Sequence[LogicalCircuitNode],
    in_nodes: dict[LogicalCircuitNode, Sequence[LogicalCircuitNode]],
    outputs: Sequence[LogicalCircuitNode],
) -> None:
    """A logical circuit represented as a rooted directed acyclic graph.

    Args:
        nodes (Sequence[LogicalCircuitNode]): The list of nodes in the logical circuit.
        in_nodes (dict[LogicalCircuitNode, Sequence[LogicalCircuitNode]]):
            A dictionary mapping each node to the list of its input nodes.
        outputs (Sequence[LogicalCircuitNode]):
            The output node of the circuit, given as a sequence of exactly one node.

    Raises:
        ValueError: If ``outputs`` does not contain exactly one node.
    """
    # The circuit is rooted: consumers such as build_circuit read a single `output`.
    # This must be a `raise`: an exception instance is truthy, so asserting it never fails.
    if len(outputs) != 1:
        raise ValueError("A logic graphs can only have one output!")
    super().__init__(nodes, in_nodes, outputs)

def prune(self) -> None:
    """Prune the current graph in place by applying unit propagation.

    Unit propagation (see https://en.wikipedia.org/wiki/Unit_propagation) is
    applied to conjunctions and disjunctions, from the inputs towards the outputs:

    - a gate having an absorbing input (bottom for a conjunction, top for a
      disjunction) is replaced by that constant;
    - null inputs (top for a conjunction, bottom for a disjunction) are removed;
    - a gate left without any input is replaced by its null element, as an
      empty conjunction is true and an empty disjunction is false.

    Nodes that are no longer reachable from the output nodes are removed too,
    and the graph is re-initialized with the remaining nodes.

    Raises:
        ValueError: If an output node simplifies to a constant, as the pruned
            graph would then have no valid output node.
    """

    def absorbing_element(node: LogicalCircuitNode) -> type[LogicalCircuitNode]:
        # x AND False = False, x OR True = True
        return BottomNode if isinstance(node, ConjunctionNode) else TopNode

    def null_element(node: LogicalCircuitNode) -> type[LogicalCircuitNode]:
        # x AND True = x, x OR False = x
        return TopNode if isinstance(node, ConjunctionNode) else BottomNode

    # Inputs are visited before the nodes using them (build_circuit relies on the
    # same ordering), so every node is simplified exactly once instead of once per
    # path, which is exponential on DAGs with shared sub-circuits.
    ordering = list(self.topological_ordering())

    # simplified[n] is n itself, or the fresh constant node n reduces to;
    # pruned_inputs[n] holds the surviving inputs of each non-constant gate n.
    simplified: dict[LogicalCircuitNode, LogicalCircuitNode] = {}
    pruned_inputs: dict[LogicalCircuitNode, list[LogicalCircuitNode]] = {}
    for node in ordering:
        if not isinstance(node, (ConjunctionNode, DisjunctionNode)):
            simplified[node] = node
            continue
        absorbing, null = absorbing_element(node), null_element(node)
        children = [simplified[c] for c in self.node_inputs(node)]
        if any(isinstance(c, absorbing) for c in children):
            simplified[node] = absorbing()
            continue
        kept = [c for c in children if not isinstance(c, null)]
        if kept:
            pruned_inputs[node] = kept
            simplified[node] = node
        else:
            # every input was a null element: the gate itself is the null element
            simplified[node] = null()

    outputs = list(self.outputs)
    if any(simplified[o] is not o for o in outputs):
        raise ValueError(
            "The logical circuit simplifies to a constant and cannot be pruned"
        )

    # keep only the nodes reachable from the outputs after simplification
    reachable: set[LogicalCircuitNode] = set()
    frontier = list(outputs)
    while frontier:
        node = frontier.pop()
        if node not in reachable:
            reachable.add(node)
            frontier.extend(pruned_inputs.get(node, []))

    # listing nodes in topological order keeps the rebuilt graph deterministic
    nodes = [n for n in ordering if n in reachable]
    in_nodes = {
        n: pruned_inputs.get(n, [])
        for n in nodes
        if n in pruned_inputs or n in self._in_nodes
    }

    # num_variables is a functools.cached_property (stored in the instance __dict__):
    # drop any value computed before pruning, since literals may have been removed.
    self.__dict__.pop("num_variables", None)
    # re-initialize the graph in place
    self.__init__(nodes, in_nodes, outputs)

# NOTE(review): rendered diff — presumably the first definition is the removed (old)
# version and the second the added (renamed-type) version.
@property
def inputs(self) -> Iterator[LogicCircuitNode]:
"""Iterate over the input nodes of the graph, typed as logical circuit nodes."""
return (cast(LogicCircuitNode, node) for node in super().inputs)
def inputs(self) -> Iterator[LogicalCircuitNode]:
"""Iterate over the input nodes of the graph, typed as logical circuit nodes."""
# cast only narrows the static type; the nodes are yielded unchanged
return (cast(LogicalCircuitNode, node) for node in super().inputs)

# NOTE(review): rendered diff — presumably the first definition is the removed (old)
# version and the second the added (renamed-type) version.
@property
def outputs(self) -> Iterator[LogicCircuitNode]:
"""Iterate over the output nodes of the graph, typed as logical circuit nodes."""
return (cast(LogicCircuitNode, node) for node in super().outputs)
def outputs(self) -> Iterator[LogicalCircuitNode]:
"""Iterate over the output nodes of the graph, typed as logical circuit nodes."""
# cast only narrows the static type; the nodes are yielded unchanged
return (cast(LogicalCircuitNode, node) for node in super().outputs)

@cached_property
def num_variables(self) -> int:
"""Number of distinct variables appearing as (possibly negated) literals among the inputs."""
# NOTE(review): cached on first access, while prune()/smooth() re-run __init__ in place;
# a value computed before them may be stale unless the base __init__ clears it — verify.
# rendered diff: presumably the first return is the removed line, the second the added one
return len({i.literal for i in self.inputs if isinstance(i, LogicInputNode)})
return len({i.literal for i in self.inputs if isinstance(i, LogicalInputNode)})

def node_scope(self, node: LogicCircuitNode) -> Scope:
def node_scope(self, node: LogicalCircuitNode) -> Scope:
"""Compute the scope of a node.

Args:
node (LogicCircuitNode): The node for which the scope is computed.
node (LogicalCircuitNode): The node for which the scope is computed.

Returns:
Scope: The scope of the node.
Expand All @@ -153,31 +174,30 @@ def node_scope(self, node: LogicCircuitNode) -> Scope:

return scope

def smooth(self) -> "LogicGraph":
"""Construct a new smooth graph from this current graph.
def smooth(self):
"""Convert the current graph to a smooth graph in place.
see https://yoojungchoi.github.io/files/ProbCirc20.pdf and
https://proceedings.neurips.cc/paper/2019/file/940392f5f32a7ade1cc201767cf83e31-Paper.pdf
for more information.

Returns:
LogicGraph: A new logic graph that is smooth.
LogicalCircuit: A new logic graph that is smooth.
"""
literal_map: dict[tuple[int, bool], LogicCircuitNode] = {
literal_map: dict[tuple[int, bool], LogicalCircuitNode] = {
(node.literal, isinstance(node, LiteralNode)): node
for node in self.nodes
if isinstance(node, (LiteralNode, NegatedLiteralNode))
}
# smoothing map keeps track of the disjunctions created for smoothing purposes
smoothing_map: dict[int, DisjunctionNode] = {}
disjunctions = filter(lambda x: isinstance(x, DisjunctionNode), self.nodes)

in_nodes: dict[LogicCircuitNode, list[LogicCircuitNode]] = dict(self.nodes_inputs)
disjunctions = [n for n in self.nodes if isinstance(n, DisjunctionNode)]

in_nodes = self._in_nodes
for d in disjunctions:
d_scope = self.node_scope(d)

for input_to_d in in_nodes[d]:
to_add_for_smoothing: list[LogicCircuitNode] = []
for input_to_d in self.node_inputs(d):
to_add_for_smoothing: list[LogicalCircuitNode] = []
missing_literals = d_scope.difference(self.node_scope(input_to_d))

if len(missing_literals) > 0:
Expand Down Expand Up @@ -209,7 +229,7 @@ def smooth(self) -> "LogicGraph":
in_nodes[d].insert(0, ad_hoc)

nodes = list(set(itertools.chain(*in_nodes.values())).union(in_nodes.keys()))
return LogicGraph(nodes, in_nodes, self._outputs)
self.__init__(nodes, in_nodes, self._outputs)

def build_circuit(
self,
Expand All @@ -226,27 +246,30 @@ def build_circuit(

Args:
literal_input_factory: A factory that builds an input layer for literals.
negated_literal_input_factory: A factory that builds an input layer for negated literals.
weight_factory: The factory to construct the weight of sum layers. It can be None,
or a parameter factory, i.e., a map from a shape to a symbolic parameter.
If None is used, the default weight factory uses non-trainable unitary parameters,
which instantiate a regular boolean logic graph.
negated_literal_input_factory:
A factory that builds an input layer for negated literals.
weight_factory: The factory to construct the weight of sum layers.
It can be None, or a parameter factory, i.e., a map from a shape to
a symbolic parameter.
If None is used, the default weight factory uses non-trainable unitary
parameters, which instantiate a regular boolean logic graph.
num_channels: The number of channels for each variable.
enforce_smoothness: Enforces smoothness of the circuit to support efficient marginalization.
enforce_smoothness:
Enforces smoothness of the circuit to support efficient marginalization.

Returns:
Circuit: A symbolic circuit.

Raises:
ValueError: If only one of literal_input_factory and negated_literal_input_factory is specified.
ValueError: If only one of literal_input_factory and
negated_literal_input_factory are specified.
"""
if enforce_smoothness:
simplified_graph = self.smooth().simplify()
else:
simplified_graph = self.simplify()
self.smooth()
self.prune()

in_layers: dict[Layer, Sequence[Layer]] = {}
node_to_layer: dict[LogicCircuitNode, Layer] = {}
node_to_layer: dict[LogicalCircuitNode, Layer] = {}

if (literal_input_factory is None) ^ (negated_literal_input_factory is None):
raise ValueError(
Expand All @@ -267,7 +290,7 @@ def weight_factory(n: tuple[int]) -> Parameter:
return Parameter.from_input(TensorParameter(*n, initializer=initializer))

# map each input literal to a symbolic input layer
for i in simplified_graph.inputs:
for i in self.inputs:
match i:
case LiteralNode():
node_to_layer[i] = literal_input_factory(
Expand All @@ -278,25 +301,21 @@ def weight_factory(n: tuple[int]) -> Parameter:
Scope([i.literal]), num_units=1, num_channels=num_channels
)

for node in simplified_graph.topological_ordering():
for node in self.topological_ordering():
match node:
case ConjunctionNode():
product_node = HadamardLayer(1, arity=len(simplified_graph.node_inputs(node)))
in_layers[product_node] = [
node_to_layer[i] for i in simplified_graph.node_inputs(node)
]
product_node = HadamardLayer(1, arity=len(self.node_inputs(node)))
in_layers[product_node] = [node_to_layer[i] for i in self.node_inputs(node)]
node_to_layer[node] = product_node
case DisjunctionNode():
sum_node = SumLayer(
1,
1,
arity=len(simplified_graph.node_inputs(node)),
arity=len(self.node_inputs(node)),
weight_factory=weight_factory,
)
in_layers[sum_node] = [
node_to_layer[i] for i in simplified_graph.node_inputs(node)
]
in_layers[sum_node] = [node_to_layer[i] for i in self.node_inputs(node)]
node_to_layer[node] = sum_node

layers = list(set(itertools.chain(*in_layers.values())).union(in_layers.keys()))
return Circuit(num_channels, layers, in_layers, [node_to_layer[simplified_graph.output]])
return Circuit(num_channels, layers, in_layers, [node_to_layer[self.output]])
14 changes: 7 additions & 7 deletions cirkit/templates/logic/sdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
ConjunctionNode,
DisjunctionNode,
LiteralNode,
LogicCircuitNode,
LogicGraph,
LogicalCircuit,
LogicalCircuitNode,
NegatedLiteralNode,
TopNode,
)
Expand All @@ -27,7 +27,7 @@ def sliding_window(iterable, n):
yield tuple(window)


class SDD(LogicGraph):
class SDD(LogicalCircuit):
@staticmethod
def load(filename: str) -> "SDD":
"""Load the SDD from a file.
Expand All @@ -46,14 +46,14 @@ def load(filename: str) -> "SDD":
filename (str): The file name for loading.

Returns:
LogicGraph: The loaded logic graph.
LogicalCircuit: The loaded logic graph.
"""
tag_re = re.compile(r"^(c|sdd|F|T|L|D)")
line_re = re.compile(r"(-?\d+)")

nodes_map: dict[int, LogicCircuitNode] = {}
literal_map: dict[tuple[int, bool], LogicCircuitNode] = {}
in_nodes: dict[LogicCircuitNode, list[LogicCircuitNode]] = defaultdict(list)
nodes_map: dict[int, LogicalCircuitNode] = {}
literal_map: dict[tuple[int, bool], LogicalCircuitNode] = {}
in_nodes: dict[LogicalCircuitNode, list[LogicalCircuitNode]] = defaultdict(list)

with open(filename, encoding="utf-8") as f:
for line in f.readlines():
Expand Down
Loading
Loading
0