Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add 'transforms' argument to plot_cap #594

Merged
merged 2 commits into from
Dec 2, 2022

Conversation

tomicapretto
Copy link
Collaborator

@tomicapretto tomicapretto commented Nov 25, 2022

This PR adds an optional argument called transforms to plot_cap(). This is useful when the model is fitted on transformed data and we want the visualization in the original scale. See the example

import bambi as bmb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from bambi.plots import plot_cap

rng = np.random.default_rng(1234)
x = np.abs(rng.normal(size=100, scale=5)) + 5
y = np.exp(2 + 0.3 * x + rng.normal(size=100, scale=0.5))

data = pd.DataFrame({"x": x, "y": y})
data["log_x"] = np.log(x)
data["log_y"] = np.log(y)
model = bmb.Model("log_y ~ 1 + log_x", data)
idata = model.fit()

# Plot CAP on transformed scale. Result is a straight line
fig, ax = plt.subplots()
ax.scatter(data["log_x"], data["log_y"], color="C4", alpha=0.6);
plot_cap(model, idata, "log_x", ax=ax);

image

# Plot CAP on the original scale, not necessarily a straight line
transforms = {"log_y": np.exp, "log_x": np.exp}
fig, ax = plt.subplots()
ax.scatter(data["x"], data["y"], color="C4", alpha=0.6);
plot_cap(model, idata, "log_x", transforms=transforms, ax=ax);

image

Update We can also use models that contain inline transformations. Notice the plot is created in the untransformed space. If we want the transformed space, we need to ask for it. While it may look counter-intuitive in the beginning, this is the most sensible approach I think because it works well when the predictor on the horizontal scale is included in more than a single term.

fig, ax = plt.subplots()
ax.scatter(np.log(data["x"]), np.log(data["y"]), color="C4", alpha=0.6);
plot_cap(model, idata, "x", transforms={"x": np.log}, ax=ax);

image

transforms = {"log(y)": np.exp}
fig, ax = plt.subplots()
ax.scatter(data["x"], data["y"], color="C4", alpha=0.6);
plot_cap(model, idata, "x", transforms=transforms, ax=ax);

image

Closes #588 and #590

@codecov-commenter
Copy link

codecov-commenter commented Nov 25, 2022

Codecov Report

Merging #594 (763316a) into main (e61f0d5) will decrease coverage by 0.34%.
The diff coverage is 0.00%.

@@            Coverage Diff             @@
##             main     #594      +/-   ##
==========================================
- Coverage   85.74%   85.39%   -0.35%     
==========================================
  Files          38       38              
  Lines        2932     2944      +12     
==========================================
  Hits         2514     2514              
- Misses        418      430      +12     
Impacted Files Coverage Δ
bambi/plots/plot_cap.py 0.00% <0.00%> (ø)

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

@tomicapretto
Copy link
Collaborator Author

I think this can be merged. Since it's not that trivial I would like a second pair of eyes @aloctavodia. If you're OK with it, let's merge.

@canyon289
Copy link
Collaborator

If I understand correctly this PR has been replaced by #596?

@tomicapretto
Copy link
Collaborator Author

If I understand correctly this PR has been replaced by #596?

This is myself still not knowing git very well haha! Being on this branch, I did git checkout -b the_other_branch because I wanted to have these changes applied as well. Let's say this PR is Part 1, and the other PR is meant to be Part 2. But because of how I used git, the other branch is Part 1 + Part 2. Do you know how one can do what I wanted to do in the beginning?

@canyon289
Copy link
Collaborator

canyon289 commented Nov 27, 2022 via email

@tomicapretto tomicapretto merged commit 3af8328 into bambinos:main Dec 2, 2022
@tomicapretto tomicapretto deleted the add_transforms_plot_cap branch December 2, 2022 01:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add transforms argument to plot_cap
3 participants