Skip to content

Commit

Permalink
improve PCFG examples
Browse files Browse the repository at this point in the history
  • Loading branch information
robert-lieck committed Nov 24, 2024
1 parent a0a8307 commit 4fb1c8d
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 24 deletions.
23 changes: 20 additions & 3 deletions examples/plot_PCFG.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,31 @@
Abstracted PCFGs
================
This is an example of defining a simple discrete RBN equivalent to a PCFG.
This is an example of defining a simple discrete RBN equivalent to a PCFG using the :class:`AbstractedPCFG
<rbnet.pcfg.AbstractedPCFG>` class.
"""

from rbnet.pcfg import AbstractedPCFG

# %%
# Minimal Example
# ---------------
# We start with a minimal example (also used in :doc:`/auto_examples/plot_discrete_RBN`):

pcfg = AbstractedPCFG(non_terminals="SAB", terminals="ab", start="S", rules=[
("S --> A B", 1), ("S --> B A", 1), # prior + first transition
("A --> B A", 1), ("B --> A B", 1), # non-terminal transitions
("A --> a", 1), ("B --> b", 1) # terminal transition
])

print(pcfg.inside(sequence="aaaa"))
print(pcfg.inside(sequence="bbbb"))
print(pcfg.inside(sequence="aaab"))
print(pcfg.inside_chart[0].pretty())

# %%
# Defining the PCFG
# -----------------
# We use an :class:`AbstractedPCFG <rbnet.pcfg.AbstractedPCFG>`
from rbnet.pcfg import AbstractedPCFG

# %%
# First we define a number of words (terminal symbols) of different categories that sentences can be composed of:
Expand Down
59 changes: 41 additions & 18 deletions examples/plot_discrete_RBN.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
Discrete RBNs are equivalent to PCFGs and, compared to continuous RBNs, easier to inspect and instructive as a first
step before moving towards continuous variables.
.. note::
If you are only interested in defining and using a PCFG, see :doc:`/auto_examples/plot_PCFG` for a more
realistic example using the :class:`~rbnet.pcfg.AbstractedPCFG` class.
"""

# %%
Expand Down Expand Up @@ -34,10 +39,12 @@
import numpy as np

# %%
# Defining Variables and Transitions
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# We first define a discrete binary non-terminal variable (:class:`~rbnet.pcfg.DiscreteNonTermVar`) and a corresponding
# :class:`~rbnet.pcfg.DiscretePrior` that always generates this one variable with a uniform distribution over its value
non_term_var = DiscreteNonTermVar(cardinality=2)
prior = DiscretePrior(struc_weights=[1], prior_weights=[[0.5, 0.5]])
prior = DiscretePrior(struc_weights=[1.], prior_weights=[[0.5, 0.5]])

# %%
# For the transitions, we use a :class:`~rbnet.pcfg.DiscreteBinaryNonTerminalTransition` ``p(a, b | c)`` where the
Expand All @@ -54,40 +61,30 @@
term_transition = DiscreteTerminalTransition(weights=weights)

# %%
# Defining the Cell and RBN
# ^^^^^^^^^^^^^^^^^^^^^^^^^
# We can now create a :class:`~rbnet.pcfg.StaticCell` for the non-terminal variable, which chooses the terminal
# transition 50% of the time, and define our :class:`~rbnet.base.SequentialRBN`
cell = StaticCell(variable=non_term_var,
weights=[0.5, 0.5],
transitions=[non_term_transition, term_transition])
rbn = SequentialRBN(cells=[cell], prior=prior)

# %%
# An equivalent RBN can be defined using the :class:`~rbnet.pcfg.AbstractedPCFG` class (see
# :doc:`/auto_examples/plot_PCFG` for a more realistic example using :class:`~rbnet.pcfg.AbstractedPCFG`)
from rbnet.pcfg import AbstractedPCFG

pcfg = AbstractedPCFG(non_terminals="SAB", terminals="ab", start="S", rules=[
("S --> A B", 1), ("S --> B A", 1), # prior + first transition
("A --> B A", 1), ("B --> A B", 1), # non-terminal transitions
("A --> a", 2), ("B --> b", 2) # terminal transition
])


# %%
# Parsing Sequences
# ^^^^^^^^^^^^^^^^^
# It is impossible to generate sequences with all zeros or all ones, because children never have the same value and the
# terminal transition does not change the value. Thus, the marginal likelihood for these sequences, returned by the
# :meth:`~rbnet.base.RBN.inside` method, is always zero
print(rbn.inside(sequence=[[0], [0], [0], [0]]), pcfg.inside(sequence="aaaa"))
print(rbn.inside(sequence=[[1], [1], [1], [1]]), pcfg.inside(sequence="bbbb"))

print(rbn.inside(sequence=[[0], [0], [0], [0]]))
print(rbn.inside(sequence=[[1], [1], [1], [1]]))

# %%
# For other sequences, we see that the marginal likelihood is non-zero, and we can also inspect the parse chart, which
# contains the inside probabilities for the values of the non-terminal variable
print(rbn.inside(sequence=[[0], [0], [0], [1]]), pcfg.inside(sequence="aaab"))
print(rbn.inside(sequence=[[0], [0], [0], [1]]))
print(rbn.inside_chart[0].pretty())
print(pcfg.inside_chart[0].pretty())

# %%
# Note how
Expand All @@ -98,9 +95,35 @@
# of (not) terminating and another factor of ``0.5`` comes from the inside probability of the left child.
# - the marginal likelihood is ``1/2`` of the top-level inside probability, because the prior is uniform over values.

# %%
# Using the :class:`~rbnet.pcfg.AbstractedPCFG` Class
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# An equivalent RBN can be defined using the :class:`~rbnet.pcfg.AbstractedPCFG` class (see
# :doc:`/auto_examples/plot_PCFG` for a more realistic example). Note that the prior transition of an RBN provides
# slightly more freedom than a PCFG, because it defines a distribution over the first symbol, whereas a PCFG always
# starts with the start symbol. Therefore, we need to combine the prior and the first non-terminal transition into the
# definition for the start symbol (essentially marginalising out the first symbol generated by the prior in the RBN).
# Internally, the :class:`~rbnet.pcfg.AbstractedPCFG` class defines a deterministic prior that generates the start
# symbol. We get identical inside probabilities to the RBN case above (the first value corresponding to the start
# symbol), but the marginal likelihood is a factor of 2 larger because of the deterministic prior.

from rbnet.pcfg import AbstractedPCFG

pcfg = AbstractedPCFG(non_terminals="SAB", terminals="ab", start="S", rules=[
("S --> A B", 1), ("S --> B A", 1), # prior + first transition
("A --> B A", 1), ("B --> A B", 1), # non-terminal transitions
("A --> a", 1), ("B --> b", 1) # terminal transition
])


print(pcfg.inside(sequence="aaaa"))
print(pcfg.inside(sequence="bbbb"))
print(pcfg.inside(sequence="aaab"))
print(pcfg.inside_chart[0].pretty())

# %%
# Expanded PCFG
# ^^^^^^^^^^^^^
# -------------
# We will now define an RBN by `expanding` the same PCFG used in the example above. When expanding a PCFG to an
# RBN, each non-terminal symbol becomes a separate non-terminal variable in the RBN. The PCFG thus acts as an outer
# skeleton when being expanded, and we are required to additionally define the domain and transitions for the variables.
Expand Down
7 changes: 4 additions & 3 deletions rbnet/pcfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def map_inside_chart(self, precision=None):
field = []
for idx, p in enumerate(probs):
if p > 0:
field.append(f"{self._non_terminals[idx]}|{np.format_float_scientific(p, precision=precision)}")
field.append(f"{self._non_terminals[idx]}:{np.format_float_scientific(p, precision=precision)}")
if len(field) == 1:
new_arr.append(field[0])
else:
Expand All @@ -44,8 +44,9 @@ class AbstractedPCFG(PCFG, pl.LightningModule, ConstrainedModuleMixin):

def __init__(self, non_terminals, terminals, rules, start, prob_rep=LogProb, *args, **kwargs):
"""
An AbstractedPCFG defines an RBN that has only one non-terminal and one terminal variable
with the cardinality of the non-terminal and terminal symbols, respectively, of the PCFG.
An :class:`~AbstractedPCFG` defines an :class:`~rbnet.base.RBN` that has only one non-terminal and one
terminal variable, both being discrete with a cardinality corresponding to the number of non-terminal and
terminal symbols of the PCFG, respectively.
:param non_terminals: list or array of non-terminal symbols
:param terminals: list or array of terminal symbols
Expand Down

0 comments on commit 4fb1c8d

Please sign in to comment.