Commit

Add changes

nikhilwoodruff committed Sep 6, 2024
1 parent 1aac85e commit bd8173d
Showing 7 changed files with 53 additions and 60 deletions.
32 changes: 0 additions & 32 deletions docs/Home.py
@@ -1,38 +1,6 @@
import streamlit as st
from download_prerequisites import download_data

STYLE = """
<style>
header {
display: none !important;
}
footer {
display: none !important;
}
section > div.block-container {
padding-top: 0px !important;
padding-bottom: 0px !important;
}
html, body, [class*="css"] {
font-family: "Roboto Serif", !important;
font-weight: 500;
}
[data-baseweb="slider"] {
padding-left: 10px !important;
}
#MainMenu {
visibility: hidden;
}
footer {
visibility: hidden;
}
.modebar{
display: none !important;
}
</style>
"""
st.write(STYLE, unsafe_allow_html=True)

download_data()

st.title("PolicyEngine-US-Data")
2 changes: 2 additions & 0 deletions docs/pages/Benchmarks.py
@@ -36,6 +36,8 @@ def compare_datasets():
comparison["Dataset"] = dataset.label
comparison_combined = pd.concat([comparison_combined, comparison])

comparison_combined.to_csv("comparisons.csv", index=False)

return comparison_combined


6 changes: 0 additions & 6 deletions docs/pages/Distributions.py
@@ -74,9 +74,3 @@ def get_bar_chart(variable, filing_status, taxable, count):


st.plotly_chart(get_bar_chart(variable, filing_status, taxable, count))

st.subheader("Household incomes")

st.write(
"PolicyEngine's calibration process alters the distribution of household incomes. In the chart below, you can see the distribution of household incomes in both the Current Population Survey and the Enhanced CPS."
)
2 changes: 1 addition & 1 deletion policyengine_us_data/datasets/cps/cps.py
@@ -67,7 +67,7 @@ def generate(self):

def add_rent(cps: h5py.File, person: DataFrame, household: DataFrame):
is_renting = household.H_TENURE == 2
AVERAGE_RENT = 1_700 * 12
AVERAGE_RENT = 1_300 * 12
# Project down to the first person in the household
person_is_renting = (
household.set_index("H_SEQ").loc[person.PH_SEQ].H_TENURE.values == 2
38 changes: 32 additions & 6 deletions policyengine_us_data/datasets/cps/enhanced_cps.py
@@ -48,10 +48,10 @@ def loss(weights):
worst_val = rel_error[torch.argmax(rel_error)].item()
return rel_error.mean(), worst_name, worst_val

optimizer = torch.optim.Adam([weights], lr=1)
optimizer = torch.optim.Adam([weights], lr=1e-2)
from tqdm import trange

iterator = trange(1_000)
iterator = trange(10_000)
for i in iterator:
optimizer.zero_grad()
l, worst_name, worst_val = loss(torch.exp(weights))
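
A minimal, self-contained sketch of the calibration pattern this hunk implies, assuming loss_matrix is a households-by-targets array whose weighted column sums are compared with targets_array; the repository's actual reweight function and loss definition may differ in detail.

import numpy as np
import torch

def reweight_sketch(original_weights, loss_matrix, targets_array):
    # Optimise log-weights so the weights themselves stay positive.
    log_weights = torch.tensor(
        np.log(original_weights), dtype=torch.float32, requires_grad=True
    )
    matrix = torch.tensor(loss_matrix, dtype=torch.float32)
    targets = torch.tensor(targets_array, dtype=torch.float32)

    def loss(weights):
        estimates = weights @ matrix  # weighted total for each target
        rel_error = ((estimates - targets) / targets) ** 2
        return rel_error.mean()

    optimizer = torch.optim.Adam([log_weights], lr=1e-2)
    for _ in range(10_000):
        optimizer.zero_grad()
        l = loss(torch.exp(log_weights))
        l.backward()
        optimizer.step()

    return torch.exp(log_weights).detach().numpy()

Lowering the learning rate from 1 to 1e-2 while raising the iteration count from 1,000 to 10,000, as this hunk does, takes smaller steps for longer, which usually gives a smoother descent on this kind of relative-error objective.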
@@ -129,10 +129,36 @@ def generate(self):
optimised_weights = reweight(
original_weights, loss_matrix, targets_array
)
if self.input_dataset.data_format == Dataset.TIME_PERIOD_ARRAYS:
data["household_weight"][year] = optimised_weights
else:
data["household_weight"] = optimised_weights
data["household_weight"][year] = optimised_weights

self.save_dataset(data)


class ReweightedCPS_2024(Dataset):
data_format = Dataset.ARRAYS
file_path = STORAGE_FOLDER / "reweighted_cps_2024.h5"
name = "reweighted_cps_2024"
label = "Reweighted CPS 2024"
input_dataset = CPS_2024
time_period = 2024

def generate(self):
from policyengine_us import Microsimulation

sim = Microsimulation(dataset=self.input_dataset)
data = sim.dataset.load_dataset()
original_weights = sim.calculate("household_weight")
original_weights = original_weights.values + np.random.normal(
1, 0.1, len(original_weights)
)
for year in [2024]:
loss_matrix, targets_array = build_loss_matrix(
self.input_dataset, year
)
optimised_weights = reweight(
original_weights, loss_matrix, targets_array
)
data["household_weight"] = optimised_weights

self.save_dataset(data)
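
A usage sketch for the new class, relying only on methods already visible in this diff (generate, load_dataset, save_dataset) and the Microsimulation entry point used above:

from policyengine_us import Microsimulation

dataset = ReweightedCPS_2024()
dataset.generate()  # builds and saves reweighted_cps_2024.h5 with noised, recalibrated weights
sim = Microsimulation(dataset=ReweightedCPS_2024)  # use the reweighted data in a simulation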

23 changes: 13 additions & 10 deletions policyengine_us_data/utils/github.py
@@ -72,16 +72,19 @@ def upload(
**auth_headers,
}

with open(file_path, "rb") as f, tqdm(
total=file_size, unit="B", unit_scale=True, desc=file_name
) as pbar:

def progress_callback(monitor):
pbar.update(monitor.bytes_read - pbar.n)

response = requests.post(
url, headers=headers, data=f, hooks={"response": progress_callback}
)
with open(file_path, "rb") as f:
with tqdm(total=file_size, unit="B", unit_scale=True) as pbar:
response = requests.post(
url,
headers=headers,
data=f,
stream=True,
hooks=dict(
response=lambda r, *args, **kwargs: pbar.update(
len(r.content)
)
),
)

if response.status_code != 201:
raise ValueError(
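
For reference, requests fires its response hook once, after the body has been sent, so the bar in the new code advances only at the end of the upload. One common alternative for live progress (a sketch under that assumption, not necessarily what this project adopts) wraps the file object so every chunk read by requests advances the bar:

import os
import requests
from tqdm import tqdm

class ProgressReader:
    """Wraps a binary file so each chunk read by requests updates a tqdm bar."""

    def __init__(self, path, pbar):
        self._file = open(path, "rb")
        self._size = os.path.getsize(path)
        self._pbar = pbar

    def __len__(self):
        # Lets requests set a correct Content-Length header.
        return self._size

    def read(self, size=-1):
        chunk = self._file.read(size)
        self._pbar.update(len(chunk))
        return chunk

    def close(self):
        self._file.close()

def upload_with_progress(url, file_path, headers):
    # Hypothetical helper; url, headers and the 201 check mirror the surrounding code.
    with tqdm(
        total=os.path.getsize(file_path), unit="B", unit_scale=True
    ) as pbar:
        body = ProgressReader(file_path, pbar)
        try:
            response = requests.post(url, headers=headers, data=body)
        finally:
            body.close()
    if response.status_code != 201:
        raise ValueError(f"Upload failed with status {response.status_code}.")
    return response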
10 changes: 5 additions & 5 deletions policyengine_us_data/utils/loss.py
@@ -238,17 +238,17 @@ def build_loss_matrix(dataset: type, time_period):
spm_threshold_agi = pd.read_csv(STORAGE_FOLDER / "spm_threshold_agi.csv")

for _, row in spm_threshold_agi.iterrows():
spm_unit_agi = sim.calculate("adjusted_gross_income", map_to="spm_unit").values
spm_unit_agi = sim.calculate(
"adjusted_gross_income", map_to="spm_unit"
).values
spm_threshold = sim.calculate("spm_unit_spm_threshold").values
in_threshold_range = (
(spm_threshold >= row["lower_spm_threshold"])
* (spm_threshold < row["upper_spm_threshold"])
in_threshold_range = (spm_threshold >= row["lower_spm_threshold"]) * (
spm_threshold < row["upper_spm_threshold"]
)
label = f"census/agi_in_spm_threshold_decile_{int(row['decile'])}"
loss_matrix[label] = sim.map_result(
in_threshold_range * spm_unit_agi, "spm_unit", "household"
)
loss_matrix[[label]].to_csv("test.csv")
targets_array.append(row["adjusted_gross_income"])

label = f"census/count_in_spm_threshold_decile_{int(row['decile'])}"