Skip to content

Commit

Permalink
Merge pull request #116 from mggg/plot_options
Browse files Browse the repository at this point in the history
Plot options
  • Loading branch information
karink520 authored Jul 8, 2024
2 parents 1c07ec2 + 53c9c7c commit 83d3fa6
Show file tree
Hide file tree
Showing 8 changed files with 287 additions and 132 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ jobs:
run: |
./scripts/test.sh
# Auto-publish when version is increased
publish-job:
# Only try to publish if:
Expand All @@ -45,4 +46,4 @@ jobs:
pypi-token: ${{ secrets.PYPI_API_TOKEN }}
gh-token: ${{ secrets.GITHUB_TOKEN }}
parse-changelog: false
pkg-name: pyei
pkg-name: pyei
56 changes: 48 additions & 8 deletions pyei/goodmans_er.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,16 @@ def summary(self):
{self.voting_prefs_complement_est_:.3f}
"""

def plot(self):
"""Plot the linear regression with confidence interval"""
def plot(self, **sns_regplot_args):
"""Plot the linear regression with 95% confidence interval
Notes:
------
Can pass additional plot arguments through to seaborn.regplot, e.g.
to change the line color:
line_kws=dict(color="red")
scatter_kws={"s": 50}
"""
fig, ax = plt.subplots()
ax.axis("square")
ax.grid(visible=True, which="major")
Expand All @@ -111,6 +119,7 @@ def plot(self):
ax=ax,
ci=95,
truncate=False,
**sns_regplot_args,
)
return fig, ax

Expand All @@ -121,7 +130,10 @@ class GoodmansERBayes(TwoByTwoEIBaseBayes):
"""

def __init__(
self, model_name="goodman_er_bayes", weighted_by_pop=False, **additional_model_params
self,
model_name="goodman_er_bayes",
weighted_by_pop=False,
**additional_model_params,
):
"""
Optional arguments:
Expand Down Expand Up @@ -238,19 +250,47 @@ def compute_credible_int_for_line(self, x_vals=np.linspace(0, 1, 100)):

return x_vals, means, lower_bounds, upper_bounds

def plot(self):
def plot(self, scatter_kws=None, line_kws=None):
"""Plot regression line of votes_fraction vs. group_fraction, with scatter plot and
equal-tailed 95% credible interval for the line"""
equal-tailed 95% credible interval for the line"
Parameters:
-----------
scatter_kws : dict (None)
Keyword arguments to be passed to matplotlib.Axes.scatter
line_kws : dict (None)
Keyword arguments to be passed to matplotlib.Axes.plot.
Note that the color of the ccredible interval shading is set to match
the color of the line itself (but the shading has higher alpha)
Notes:
------
Examples of additional kwargs. Scatter_colors is list of colors of length num_precincts
scatter_kws={"c": scatter_colors, "color": None, "s": 20},
line_kws={"color":"black", "lw": 1}
"""
# TODO: consider renaming these plots for goodman, to disambiguate with TwoByTwoEI.plot()
# TODO: accept axis argument
x_vals, means, lower_bounds, upper_bounds = self.compute_credible_int_for_line()
_, ax = plt.subplots()

if scatter_kws is None:
scatter_kws = {}
if line_kws is None:
line_kws = {}
scatter_kws.setdefault("color", "steelblue")
scatter_kws.setdefault("alpha", 0.8)
line_kws.setdefault("color", "steelblue")

ax.axis("square")
ax.set_xlabel(f"Fraction in group {self.demographic_group_name}")
ax.set_ylabel(f"Fraction voting for {self.candidate_name}")
ax.scatter(self.demographic_group_fraction, self.votes_fraction, alpha=0.8)
ax.plot(x_vals, means)
ax.fill_between(x_vals, lower_bounds, upper_bounds, color="steelblue", alpha=0.2)
ax.scatter(
self.demographic_group_fraction,
self.votes_fraction,
**scatter_kws,
)
ax.plot(x_vals, means, **line_kws)
ax.fill_between(x_vals, lower_bounds, upper_bounds, alpha=0.2, color=line_kws["color"])
ax.grid()
ax.set_xlim((0, 1))
ax.set_ylim((0, 1))
Expand Down
189 changes: 105 additions & 84 deletions pyei/intro_notebooks/Plotting_with_PyEI.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyei/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def to_netcdf(ei_object, filepath):
mode = "a"
for attr in ["demographic_group_fractions", "votes_fractions"]: # array atts
data = xr.DataArray(getattr(ei_object, attr), name=attr)
data.load()
data.to_netcdf(filepath, mode=mode, group=attr, engine="netcdf4")
data.close()

Expand All @@ -83,7 +84,7 @@ def from_netcdf(filepath):
with sim_trace and most other atrributes set as they would
be when fit. Note sim_model is not saved/loaded
"""
idata = az.from_netcdf(filepath)
idata = az.from_netcdf(filepath, engine="netcdf4")

attrs_dict = idata.posterior.attrs # pylint: disable=no-member
attr_list = list(idata.posterior.attrs.keys()) # pylint: disable=no-member
Expand Down
45 changes: 35 additions & 10 deletions pyei/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,12 @@ def plot_single_ridgeplot(


def plot_single_histogram(
ax, group_prefs, colors, alpha, z_init, trans # pylint: disable=redefined-outer-name
ax,
group_prefs,
colors, # pylint: disable=redefined-outer-name
alpha,
z_init,
trans, # pylint: disable=redefined-outer-name
):
"""Helper function for plot_precincts that plots a single precinct histogram(s)
(i.e.,for a single precinct for a given candidate.)
Expand Down Expand Up @@ -315,7 +320,7 @@ def plot_boxplots(
legend = group_names
support = "for"
if axes is None:
fig, axes = plt.subplots(num_candidates, figsize=FIGSIZE)
_, axes = plt.subplots(num_candidates, figsize=FIGSIZE)

elif plot_by == "group":
num_plots = num_groups
Expand All @@ -325,10 +330,10 @@ def plot_boxplots(
legend = candidate_names
support = "among"
if axes is None:
fig, axes = plt.subplots(num_groups, figsize=FIGSIZE)
_, axes = plt.subplots(num_groups, figsize=FIGSIZE)
else:
raise ValueError("plot_by must be 'group' or 'candidate' (default: 'candidate')")
fig.subplots_adjust(hspace=1)
plt.gcf().subplots_adjust(hspace=1)

for plot_idx in range(num_plots):
samples_df = pd.DataFrame(
Expand Down Expand Up @@ -454,13 +459,19 @@ def plot_precinct_scatterplot(ei_runs, run_names, candidate, demographic_group="
# Set group names and candidates in case runs are TwoByTwoEI
if not hasattr(ei_runs[0], "demographic_group_names"): # then is TwoByTwoEI
demographic_group_names1 = list(ei_runs[0].group_names_for_display())
candidate_names1 = [ei_runs[0].candidate_name, "not " + ei_runs[0].candidate_name]
candidate_names1 = [
ei_runs[0].candidate_name,
"not " + ei_runs[0].candidate_name,
]
else:
demographic_group_names1 = ei_runs[0].demographic_group_names
candidate_names1 = ei_runs[0].candidate_names
if not hasattr(ei_runs[1], "demographic_group_names"): # then it is TwoByTwoEI
demographic_group_names2 = list(ei_runs[1].group_names_for_display())
candidate_names2 = [ei_runs[1].candidate_name, "not " + ei_runs[1].candidate_name]
candidate_names2 = [
ei_runs[1].candidate_name,
"not " + ei_runs[1].candidate_name,
]
else:
demographic_group_names2 = ei_runs[1].demographic_group_names
candidate_names2 = ei_runs[1].candidate_names
Expand Down Expand Up @@ -493,7 +504,8 @@ def plot_precinct_scatterplot(ei_runs, run_names, candidate, demographic_group="
)
sns.lineplot(x=[0, 1], y=[0, 1], alpha=0.5, color="grey")
ax.set_title(
f"{run_names[0]} vs. {run_names[1]}\n Predicted support for {candidate}", fontsize=TITLESIZE
f"{run_names[0]} vs. {run_names[1]}\n Predicted support for {candidate}",
fontsize=TITLESIZE,
)
ax.set_xlabel(f"Support for {candidate} (from {run_names[0]})", fontsize=FONTSIZE)
ax.set_ylabel(f"Support for {candidate} (from {run_names[1]})", fontsize=FONTSIZE)
Expand Down Expand Up @@ -572,6 +584,7 @@ def plot_polarization_kde(
candidate_name,
show_threshold=False,
ax=None,
color="steelblue",
):
"""
Plots a kde for the differences in voting preferences between two groups
Expand All @@ -591,6 +604,8 @@ def plot_polarization_kde(
show_threshold: bool
if true, add a vertical line at the threshold on the plot and display the associated
tail probability
color: str
specifies a color for matplotlib to be used in the histogram/kde
Returns
-------
Expand All @@ -607,7 +622,7 @@ def plot_polarization_kde(
element="step",
stat="density",
label=groups[0] + " - " + groups[1],
color="steelblue",
color=color,
linewidth=0,
)
ax.set_ylabel("Density", fontsize=FONTSIZE)
Expand Down Expand Up @@ -875,7 +890,13 @@ def plot_intervals_all_precincts(


def tomography_plot(
group_fraction, votes_fraction, demographic_group_name, candidate_name, ax=None
group_fraction,
votes_fraction,
demographic_group_name,
candidate_name,
ax=None,
c="b",
**plot_kwargs,
):
"""Tomography plot (basic), applicable for 2x2 ei
Expand All @@ -893,6 +914,10 @@ def tomography_plot(
Name of candidate or voting outcome of interest
ax : Matplotlib axis object or None, optional
Default=None
c : specifies a color for Matplotlib, optional
Default="b"
**plot_kwargs
Additional keyword arguments to be passed to matplotlib.Axes.plot()
Returns
-------
Expand All @@ -911,5 +936,5 @@ def tomography_plot(
ax.set_ylabel(f"voter pref of non-{demographic_group_name} for {candidate_name}")
for i in range(num_precincts):
b_2 = (votes_fraction[i] - b_1 * group_fraction[i]) / (1 - group_fraction[i])
ax.plot(b_1, b_2, c="b")
ax.plot(b_1, b_2, c=c, **plot_kwargs)
return ax
68 changes: 57 additions & 11 deletions pyei/r_by_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,18 @@ def fit( # pylint: disable=too-many-branches

if self.model_name == "multinomial-dirichlet":
self.sim_model = ei_multinom_dirichlet(
group_fractions, votes_fractions, precinct_pops, **self.additional_model_params
group_fractions,
votes_fractions,
precinct_pops,
**self.additional_model_params,
)

elif self.model_name == "multinomial-dirichlet-modified":
self.sim_model = ei_multinom_dirichlet_modified(
group_fractions, votes_fractions, precinct_pops, **self.additional_model_params
group_fractions,
votes_fractions,
precinct_pops,
**self.additional_model_params,
)

elif self.model_name == "greiner-quinn":
Expand All @@ -199,7 +205,10 @@ def fit( # pylint: disable=too-many-branches
)
elif self.model_name == "greiner-quinn":
self.sim_trace = pyei_greiner_quinn_sample(
group_fractions, votes_fractions, precinct_pops, **other_sampling_args
group_fractions,
votes_fractions,
precinct_pops,
**other_sampling_args,
) #

self.calculate_summary()
Expand Down Expand Up @@ -298,7 +307,11 @@ def calculate_turnout_adjusted_summary(self, non_candidate_names):
# # compute credible intervals
percentiles = [2.5, 97.5]
self.turnout_adjusted_credible_interval_95_mean_voting_prefs = np.zeros(
(self.num_groups_and_num_candidates[0], self.num_groups_and_num_candidates[1] - 1, 2)
(
self.num_groups_and_num_candidates[0],
self.num_groups_and_num_candidates[1] - 1,
2,
)
)
for row in range(self.num_groups_and_num_candidates[0]):
for col in range(self.num_groups_and_num_candidates[1] - 1):
Expand Down Expand Up @@ -350,7 +363,11 @@ def calculate_summary(self):
# compute credible intervals
percentiles = [2.5, 97.5]
self.credible_interval_95_mean_voting_prefs = np.zeros(
(self.num_groups_and_num_candidates[0], self.num_groups_and_num_candidates[1], 2)
(
self.num_groups_and_num_candidates[0],
self.num_groups_and_num_candidates[1],
2,
)
)
for row in range(self.num_groups_and_num_candidates[0]):
for col in range(self.num_groups_and_num_candidates[1]):
Expand Down Expand Up @@ -768,10 +785,16 @@ def candidate_of_choice_polarization_report(self, verbose=True, non_candidate_na
f"{self.demographic_group_names[dem2]} voters differ."
)
candidate_differ_rate_dict[
(self.demographic_group_names[dem1], self.demographic_group_names[dem2])
(
self.demographic_group_names[dem1],
self.demographic_group_names[dem2],
)
] = differ_frac
candidate_differ_rate_dict[
(self.demographic_group_names[dem2], self.demographic_group_names[dem1])
(
self.demographic_group_names[dem2],
self.demographic_group_names[dem1],
)
] = differ_frac
return candidate_differ_rate_dict

Expand Down Expand Up @@ -844,7 +867,13 @@ def plot_kdes(self, plot_by="candidate", non_candidate_names=None, axes=None):
)

def plot_margin_kde(
self, group, candidates, threshold=None, percentile=None, show_threshold=False, ax=None
self,
group,
candidates,
threshold=None,
percentile=None,
show_threshold=False,
ax=None,
):
"""
Plot kde of the margin between two candidates among the given demographic group.
Expand Down Expand Up @@ -886,7 +915,14 @@ def plot_margin_kde(
)

def plot_polarization_kde(
self, groups, candidate, threshold=None, percentile=None, show_threshold=False, ax=None
self,
groups,
candidate,
threshold=None,
percentile=None,
show_threshold=False,
ax=None,
color="steelblue",
):
"""Plot kde of differences between voting preferences
Expand All @@ -907,9 +943,12 @@ def plot_polarization_kde(
must be None
show_threshold: bool
ax : matplotlib Axis object
color : str (optional)
Specifies a color for matplotlib to be used in the histogram/kde.
default="steelblue"
Returns:
--------ß
--------
matplotlib axis object
"""
return_interval = threshold is None
Expand All @@ -931,7 +970,14 @@ def plot_polarization_kde(
thresholds = [threshold]

return plot_polarization_kde(
samples, thresholds, percentile, groups, candidate, show_threshold, ax
samples,
thresholds,
percentile,
groups,
candidate,
show_threshold,
ax,
color=color,
)

def plot_intervals_by_precinct(self, group_name, candidate_name):
Expand Down
Loading

0 comments on commit 83d3fa6

Please sign in to comment.