Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use sampling seed to standardize records subsample in test_pufcsv.py #869

Merged
merged 1 commit into from
Aug 22, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions taxcalc/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,14 +213,15 @@ def __init__(self,
rtol=0.0, atol=0.001):
raise ValueError(msg.format('e02100'))
# read extrapolation blowup factors and sample weights
self.BF = None
self._read_blowup(blowup_factors)
self.WT = None
self._read_weights(weights)
# weights must be same size as tax record data
if not self.WT.empty and self.dim != len(self.WT):
frac = float(self.dim) / len(self.WT)
self.WT = self.WT.iloc[self.index]
self.WT = self.WT / frac

# specify current_year and FLPDYR values
if isinstance(start_year, int):
self._current_year = start_year
Expand All @@ -229,7 +230,7 @@ def __init__(self,
msg = 'start_year is not an integer'
raise ValueError(msg)
# consider applying initial-year blowup factors
if self.BF.empty is False and self.current_year == Records.PUF_YEAR:
if not self.BF.empty and self.current_year == Records.PUF_YEAR:
self._extrapolate_in_puf_year()
# construct sample weights for current_year
wt_colname = 'WT{}'.format(self.current_year)
Expand Down
76 changes: 38 additions & 38 deletions taxcalc/tests/test_pufcsv.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,51 +32,24 @@
import pandas as pd


@pytest.mark.requires_pufcsv
def test_sample():
    """
    Test if reading in a sample of the data produces a reasonable estimate
    relative to the full data set.

    Builds a multi-year diagnostic table from the full puf.csv file and
    another from a two-percent sub-sample of it, then compares the
    "Combined liability ($b)" rows of the two tables.  The test fails if
    the largest absolute difference in any budget year is five percent or
    more of the largest liability value found in either table.
    """
    nyrs = 10  # number of budget years in each diagnostic table

    # Full dataset
    clp = Policy()
    puf = Records(data=PUFCSV_PATH)
    calc = Calculator(policy=clp, records=puf)
    adt = multiyear_diagnostic_table(calc, num_years=nyrs)

    # Sub-sample dataset
    clp2 = Policy()
    tax_data_full = pd.read_csv(PUFCSV_PATH)
    # use a fixed seed so the two-percent sub-sample is identical on every
    # run; an unseeded draw makes the test outcome vary between runs
    rn_seed = 80
    tax_data = tax_data_full.sample(frac=0.02, random_state=rn_seed)
    puf_sample = Records(data=tax_data)
    calc_sample = Calculator(policy=clp2, records=puf_sample)
    adt_sample = multiyear_diagnostic_table(calc_sample, num_years=nyrs)

    # Get the final combined tax liability for the budget period
    # in the sample and the full dataset and make sure they are close
    full_tax_liability = adt.loc["Combined liability ($b)"]
    sample_tax_liability = adt_sample.loc["Combined liability ($b)"]
    max_val = max(full_tax_liability.max(), sample_tax_liability.max())
    rel_diff = max(abs(full_tax_liability - sample_tax_liability)) / max_val

    # Fail on greater than 5% relative difference in any budget year
    assert rel_diff < 0.05


@pytest.mark.requires_pufcsv
def test_agg():
"""
Test Tax-Calculator aggregate taxes with no policy reform using puf.csv
Test Tax-Calculator aggregate taxes with no policy reform using
the full-sample puf.csv and a two-percent sub-sample of puf.csv
"""
# pylint: disable=too-many-locals
# pylint: disable=too-many-locals,too-many-statements
nyrs = 10
# create a Policy object (clp) containing current-law policy parameters
clp = Policy()
# create a Records object (puf) containing puf.csv input records
puf = Records(data=PUFCSV_PATH)
# create a Records object (rec) containing all puf.csv input records
rec = Records(data=PUFCSV_PATH)
# create a Calculator object using clp policy and puf records
calc = Calculator(policy=clp, records=puf)
calc = Calculator(policy=clp, records=rec)
calc_start_year = calc.current_year
# create aggregate diagnostic table (adt) as a Pandas DataFrame object
adt = multiyear_diagnostic_table(calc, 10)
adt = multiyear_diagnostic_table(calc, nyrs)
taxes_fullsample = adt.loc["Combined liability ($b)"]
# convert adt results to a string with a trailing EOL character
adtstr = adt.to_string() + '\n'
# generate differences between actual and expected results
Expand All @@ -96,14 +69,41 @@ def test_agg():
new_filename = '{}{}'.format(AGGRES_PATH[:-10], 'actual.txt')
with open(new_filename, 'w') as new_file:
new_file.write(adtstr)
msg = 'PUFCSV AGG RESULTS DIFFER\n'
msg = 'PUFCSV AGG RESULTS DIFFER FOR FULL-SAMPLE\n'
msg += '-------------------------------------------------\n'
msg += '--- NEW RESULTS IN pufcsv_agg_actual.txt FILE ---\n'
msg += '--- if new OK, copy pufcsv_agg_actual.txt to ---\n'
msg += '--- pufcsv_agg_expect.txt ---\n'
msg += '--- and rerun test. ---\n'
msg += '-------------------------------------------------\n'
raise ValueError(msg)
# create aggregate diagnostic table using sub sample of records
fullsample = pd.read_csv(PUFCSV_PATH)
rn_seed = 80 # to ensure two-percent sub-sample is always the same
subsample = fullsample.sample(frac=0.02, random_state=rn_seed)
rec_subsample = Records(data=subsample)
calc_subsample = Calculator(policy=Policy(), records=rec_subsample)
adt_subsample = multiyear_diagnostic_table(calc_subsample, num_years=nyrs)
# compare combined tax liability from full and sub samples for each year
taxes_subsample = adt_subsample.loc["Combined liability ($b)"]
reltol = 0.01 # maximum allowed relative difference in tax liability
if not np.allclose(taxes_subsample, taxes_fullsample,
atol=0.0, rtol=reltol):
msg = 'PUFCSV AGG RESULTS DIFFER IN SUB-SAMPLE AND FULL-SAMPLE\n'
msg += 'WHEN reltol = {:.4f}\n'.format(reltol)
it_sub = np.nditer(taxes_subsample, flags=['f_index'])
it_all = np.nditer(taxes_fullsample, flags=['f_index'])
while not it_sub.finished:
cyr = it_sub.index + calc_start_year
tax_sub = float(it_sub[0])
tax_all = float(it_all[0])
reldiff = abs(tax_sub - tax_all) / abs(tax_all)
if reldiff > reltol:
msgstr = ' year,sub,full,reldif= {}\t{:.2f}\t{:.2f}\t{:.4f}\n'
msg += msgstr.format(cyr, tax_sub, tax_all, reldiff)
it_sub.iternext()
it_all.iternext()
raise ValueError(msg)


MTR_TAX_YEAR = 2013
Expand Down