Skip to content

Add Bayesian instrumental variable estimation #213

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 34 commits into from
Aug 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
23eff06
[IV 212] Adding base classes required for bayesian instrumental varia…
NathanielF Jun 30, 2023
0a519d8
[IV 212] loading data set properly, tidied notebook example and estim…
NathanielF Jul 2, 2023
28e9d9f
[IV 212] adding the class and example notebook to the documentation
NathanielF Jul 2, 2023
3c70dfe
[IV 212] adding integration test
NathanielF Jul 2, 2023
f97cd82
[IV 212] updating index rst
NathanielF Jul 4, 2023
d30da22
Merge branch 'latest' into feature_instrumental_variables
NathanielF Jul 4, 2023
61d5592
[IV 212] experimenting with nicer plot
NathanielF Jul 4, 2023
c70c3df
[IV 212] changed colors
NathanielF Jul 4, 2023
1bf4a74
[IV 212] added reference
NathanielF Jul 4, 2023
31cf8d7
[IV 212] added more write up and improved plot
NathanielF Jul 6, 2023
c50db34
get bibtex references working
drbenvincent Jul 2, 2023
fe76082
[IV 212] added parameter recovery example
NathanielF Jul 22, 2023
6a6739b
improvements to current banks
drbenvincent Jul 6, 2023
70b6571
minor edits
drbenvincent Jul 7, 2023
7529eaa
attempt to fix failing test
drbenvincent Jul 7, 2023
3e9b475
update to proper sphinx glossary
drbenvincent Jul 13, 2023
1f4e030
add wilkinson notation reference
drbenvincent Jul 13, 2023
df594ca
fix an admonition in brexit notebook
drbenvincent Jul 13, 2023
d76d295
add in some glossary terms to example notebooks
drbenvincent Jul 13, 2023
306bb1e
Update README.md
drbenvincent Jul 13, 2023
5a14848
bump to version 0.0.14
drbenvincent Jul 14, 2023
73e2959
improvements to current banks
drbenvincent Jul 6, 2023
85ae1b5
update to proper sphinx glossary
drbenvincent Jul 13, 2023
916804e
add wilkinson notation reference
drbenvincent Jul 13, 2023
f279a80
add in some glossary terms to example notebooks
drbenvincent Jul 13, 2023
12334bf
trying to resolve merge conflicts
NathanielF Jul 22, 2023
dd22016
[IV 212] addressing some of Ben's comments with notebook
NathanielF Jul 22, 2023
7a28dab
Merge branch 'main' into feature_instrumental_variables
NathanielF Jul 22, 2023
f08b02d
[IV 212] adding an input validation test
NathanielF Jul 22, 2023
04584d0
[IV 212] adding user warning and tidying params
NathanielF Aug 9, 2023
efa3f6d
Merge branch 'main' into feature_instrumental_variables
NathanielF Aug 9, 2023
f737115
[IV 212] pretty print user warning
NathanielF Aug 9, 2023
2e1d692
[IV 212] adding axis labels with Ben's comments
NathanielF Aug 23, 2023
cb68c78
[IV 212] fixing axis label
NathanielF Aug 23, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions causalpy/data/AJR2001.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
longname,shortnam,logmort0,risk,loggdp,campaign,source0,slave,latitude,neoeuro,asia,africa,other,edes1975,campaignsj,campaignsj2,mortnaval1,logmortnaval1,mortnaval2,logmortnaval2,mortjam,logmortjam,logmortcap250,logmortjam250,wandcafrica,malfal94,wacacontested,mortnaval2250,logmortnaval2250,mortnaval1250,logmortnaval1250
Angola,AGO,5.6347895,5.3600001,7.77,1,0,0,0.1367,0,0,1,0,0.0,1,1,280.0,5.6347895,280.0,5.6347895,280.0,5.6347895,5.521461,5.521461,1,0.94999999,1,250.0,5.521461,250.0,5.521461
Argentina,ARG,4.232656,6.3899999,9.1300001,1,0,0,0.37779999,0,0,0,0,90.0,1,1,15.07,2.7127061,30.5,3.4177268,56.5,4.0342407,4.232656,4.0342407,0,0.0,0,30.5,3.4177268,15.07,2.7127061
Australia,AUS,2.1459312,9.3199997,9.8999996,0,0,0,0.30000001,1,0,0,1,99.0,0,1,8.5500002,2.1459312,8.5500002,2.1459312,8.5500002,2.1459312,2.1459312,2.1459312,0,0.0,0,8.5500002,2.1459312,8.5500002,2.1459312
Burkina Faso,BFA,5.6347895,4.4499998,6.8499999,1,0,0,0.1444,0,0,1,0,0.0,1,1,280.0,5.6347895,280.0,5.6347895,280.0,5.6347895,5.521461,5.521461,1,0.94999999,1,250.0,5.521461,250.0,5.521461
Bangladesh,BGD,4.2684379,5.1399999,6.8800001,1,1,0,0.2667,0,1,0,0,0.0,1,1,71.410004,4.2684379,71.410004,4.2684379,71.410004,4.2684379,4.2684379,4.2684379,0,0.12008,0,71.410004,4.2684379,71.410004,4.2684379
Bahamas,BHS,4.4426513,7.5,9.29,0,0,0,0.2683,0,0,0,0,10.0,0,0,85.0,4.4426513,85.0,4.4426513,85.0,4.4426513,4.4426513,4.4426513,0,,0,85.0,4.4426513,85.0,4.4426513
Bolivia,BOL,4.2626801,5.6399999,7.9299998,1,0,0,0.18889999,0,0,0,0,30.000002,1,1,,,93.25,4.535284,56.5,4.0342407,4.2626801,4.0342407,0,0.00165,0,93.25,4.535284,,
Brazil,BRA,4.2626801,7.9099998,8.7299995,1,0,0,0.1111,0,0,0,0,55.0,1,1,15.07,2.7127061,30.5,3.4177268,56.5,4.0342407,4.2626801,4.0342407,0,0.035999998,0,30.5,3.4177268,15.07,2.7127061
Canada,CAN,2.7788193,9.7299995,9.9899998,0,1,0,0.66670001,1,0,0,0,98.0,0,0,16.1,2.7788193,16.1,2.7788193,16.1,2.7788193,2.7788193,2.7788193,0,0.0,0,16.1,2.7788193,16.1,2.7788193
Chile,CHL,4.232656,7.8200002,9.3400002,1,0,0,0.33329999,0,0,0,0,50.0,1,1,15.07,2.7127061,30.5,3.4177268,56.5,4.0342407,4.232656,4.0342407,0,0.0,0,30.5,3.4177268,15.07,2.7127061
Cote d'Ivoire,CIV,6.5042882,7.0,7.4400001,1,0,0,0.0889,0,0,1,0,0.0,1,1,668.0,6.5042882,668.0,6.5042882,668.0,6.5042882,5.521461,5.521461,1,0.94999999,1,250.0,5.521461,250.0,5.521461
Cameroon,CMR,5.6347895,6.4499998,7.5,1,0,0,0.066699997,0,0,1,0,0.0,1,1,280.0,5.6347895,280.0,5.6347895,280.0,5.6347895,5.521461,5.521461,1,0.94999999,1,250.0,5.521461,250.0,5.521461
Congo,COG,5.480639,4.6799998,7.4200001,0,1,1,0.0111,0,0,1,0,0.0,0,0,240.0,5.480639,240.0,5.480639,240.0,5.480639,5.480639,5.480639,1,0.94999999,0,240.0,5.480639,240.0,5.480639
Colombia,COL,4.2626801,7.3200002,8.8100004,1,0,0,0.044399999,0,0,0,0,25.0,1,1,,,93.25,4.535284,56.5,4.0342407,4.2626801,4.0342407,0,0.14637001,0,93.25,4.535284,,
Costa Rica,CRI,4.3579903,7.0500002,8.79,1,0,0,0.1111,0,0,0,0,20.0,1,1,,,93.25,4.535284,62.200001,4.1303549,4.3579903,4.1303549,0,0.0,0,93.25,4.535284,,
Dominican Re,DOM,4.8675346,6.1799998,8.3599997,0,0,0,0.2111,0,0,0,0,25.0,0,0,130.0,4.8675346,130.0,4.8675346,130.0,4.8675346,4.8675346,4.8675346,0,0.0,0,130.0,4.8675346,130.0,4.8675346
Algeria,DZA,4.3592696,6.5,8.3900003,1,1,0,0.31110001,0,0,1,0,0.0,1,1,78.199997,4.3592696,78.199997,4.3592696,78.199997,4.3592696,4.3592696,4.3592696,0,0.0,0,78.199997,4.3592696,78.199997,4.3592696
Ecuador,ECU,4.2626801,6.5500002,8.4700003,1,0,0,0.0222,0,0,0,0,30.000002,1,1,,,93.25,4.535284,56.5,4.0342407,4.2626801,4.0342407,0,0.11894999,0,93.25,4.535284,,
Egypt,EGY,4.2165623,6.77,7.9499998,1,1,0,0.30000001,0,0,1,0,0.0,1,1,67.800003,4.2165623,67.800003,4.2165623,67.800003,4.2165623,4.2165623,4.2165623,0,0.0,0,67.800003,4.2165623,67.800003,4.2165623
Ethiopia,ETH,3.2580965,5.73,6.1100001,1,1,0,0.0889,0,0,1,0,0.0,1,1,26.0,3.2580965,26.0,3.2580965,26.0,3.2580965,3.2580965,3.2580965,1,0.551,0,26.0,3.2580965,26.0,3.2580965
Gabon,GAB,5.6347895,7.8200002,8.8999996,1,0,0,0.0111,0,0,1,0,0.0,1,1,280.0,5.6347895,280.0,5.6347895,280.0,5.6347895,5.521461,5.521461,1,0.94050002,1,250.0,5.521461,250.0,5.521461
Ghana,GHA,6.5042882,6.27,7.3699999,1,1,0,0.0889,0,0,1,0,0.0,1,1,668.0,6.5042882,668.0,6.5042882,668.0,6.5042882,5.521461,5.521461,1,0.94999999,0,250.0,5.521461,250.0,5.521461
Guinea,GIN,6.1800165,6.5500002,7.4899998,1,0,0,0.1222,0,0,1,0,0.0,1,1,483.0,6.1800165,483.0,6.1800165,483.0,6.1800165,5.521461,5.521461,1,0.94999999,1,250.0,5.521461,250.0,5.521461
Gambia,GMB,7.2930179,8.2700005,7.27,1,1,0,0.1476,0,0,1,0,0.0,1,1,1470.0,7.2930179,1470.0,7.2930179,1470.0,7.2930179,5.521461,5.521461,1,0.94999999,0,250.0,5.521461,250.0,5.521461
Guatemala,GTM,4.2626801,5.1399999,8.29,1,0,0,0.17,0,0,0,0,20.0,1,1,,,93.25,4.535284,56.5,4.0342407,4.2626801,4.0342407,0,0.0036000002,0,93.25,4.535284,,
Guyana,GUY,3.4713452,5.8899999,7.9000001,0,0,0,0.055599999,0,0,0,0,2.0,0,0,32.18,3.4713452,32.18,3.4713452,32.18,3.4713452,3.4713452,3.4713452,0,0.49503002,0,32.18,3.4713452,32.18,3.4713452
Hong Kong,HKG,2.7013612,8.1400003,10.05,0,0,0,0.24609999,0,1,0,0,0.0,1,1,14.9,2.7013612,14.9,2.7013612,14.9,2.7013612,2.7013612,2.7013612,0,0.0,0,14.9,2.7013612,14.9,2.7013612
Honduras,HND,4.3579903,5.3200002,7.6900001,1,0,0,0.16670001,0,0,0,0,20.0,1,1,,,93.25,4.535284,62.200001,4.1303549,4.3579903,4.1303549,0,0.012,0,93.25,4.535284,,
Haiti,HTI,4.8675346,3.73,7.1500001,0,0,0,0.2111,0,0,0,0,0.0,0,0,130.0,4.8675346,130.0,4.8675346,130.0,4.8675346,4.8675346,4.8675346,0,1.0,0,130.0,4.8675346,130.0,4.8675346
Indonesia,IDN,5.1357985,7.5900002,7.3299999,1,1,0,0.055599999,0,1,0,0,0.0,1,1,170.0,5.1357985,170.0,5.1357985,170.0,5.1357985,5.1357985,5.1357985,0,0.17873,0,170.0,5.1357985,170.0,5.1357985
India,IND,3.8842406,8.2700005,7.3299999,0,1,0,0.22220001,0,1,0,0,0.0,0,0,48.630001,3.8842406,48.630001,3.8842406,48.630001,3.8842406,3.8842406,3.8842406,0,0.23596001,0,48.630001,3.8842406,48.630001,3.8842406
Jamaica,JAM,4.8675346,7.0900002,8.1899996,0,1,0,0.2017,0,0,0,0,10.0,0,1,130.0,4.8675346,130.0,4.8675346,130.0,4.8675346,4.8675346,4.8675346,0,0.0,0,130.0,4.8675346,130.0,4.8675346
Kenya,KEN,4.9767337,6.0500002,7.0599999,0,1,1,0.0111,0,0,1,0,0.0,0,0,145.0,4.9767337,145.0,4.9767337,145.0,4.9767337,4.9767337,4.9767337,1,0.79799998,0,145.0,4.9767337,145.0,4.9767337
Sri Lanka,LKA,4.2456341,6.0500002,7.73,0,1,0,0.077799998,0,1,0,0,0.0,0,1,69.800003,4.2456341,69.800003,4.2456341,69.800003,4.2456341,4.2456341,4.2456341,0,0.138,0,69.800003,4.2456341,69.800003,4.2456341
Morocco,MAR,4.3592696,7.0900002,8.04,1,0,0,0.3556,0,0,1,0,1.0,1,1,78.199997,4.3592696,78.199997,4.3592696,78.199997,4.3592696,4.3592696,4.3592696,0,0.0,0,78.199997,4.3592696,78.199997,4.3592696
Madagascar,MDG,6.2842088,4.4499998,6.8400002,1,1,0,0.22220001,0,0,1,0,0.0,1,1,536.03998,6.2842088,536.03998,6.2842088,536.03998,6.2842088,5.521461,5.521461,1,0.94999999,0,250.0,5.521461,250.0,5.521461
Mexico,MEX,4.2626801,7.5,8.9399996,1,1,0,0.25560001,0,0,0,0,15.000001,1,1,71.0,4.2626801,71.0,4.2626801,71.0,4.2626801,4.2626801,4.2626801,0,0.00042,0,71.0,4.2626801,71.0,4.2626801
Mali,MLI,7.986165,4.0,6.5700002,1,1,0,0.18889999,0,0,1,0,0.0,1,1,2940.0,7.986165,2940.0,7.986165,2940.0,7.986165,5.521461,5.521461,1,0.94050002,0,250.0,5.521461,250.0,5.521461
Malta,MLT,2.7911651,7.23,9.4300003,0,1,0,0.3944,0,0,0,1,100.0,0,0,16.299999,2.7911651,16.299999,2.7911651,16.299999,2.7911651,2.7911651,2.7911651,0,,0,16.299999,2.7911651,16.299999,2.7911651
Malaysia,MYS,2.8735647,7.9499998,8.8900003,0,1,0,0.025599999,0,1,0,0,0.0,0,1,17.700001,2.8735647,17.700001,2.8735647,17.700001,2.8735647,2.8735647,2.8735647,0,0.23331,0,17.700001,2.8735647,17.700001,2.8735647
Niger,NER,5.9914646,5.0,6.73,1,0,0,0.1778,0,0,1,0,0.0,1,1,400.0,5.9914646,400.0,5.9914646,400.0,5.9914646,5.521461,5.521461,1,0.94050002,1,250.0,5.521461,250.0,5.521461
Nigeria,NGA,7.6029005,5.5500002,6.8099999,1,1,0,0.1111,0,0,1,0,0.0,1,1,2004.0,7.6029005,2004.0,7.6029005,2004.0,7.6029005,5.521461,5.521461,1,0.94999999,0,250.0,5.521461,250.0,5.521461
Nicaragua,NIC,5.0955892,5.23,7.54,1,0,0,0.1444,0,0,0,0,20.0,1,1,,,93.25,4.535284,130.0,4.8675346,5.0955892,4.8675346,0,0.044,0,93.25,4.535284,,
New Zealand,NZL,2.1459312,9.7299995,9.7600002,0,1,0,0.45559999,1,0,0,1,91.699997,1,1,8.5500002,2.1459312,8.5500002,2.1459312,8.5500002,2.1459312,2.1459312,2.1459312,0,0.0,0,8.5500002,2.1459312,8.5500002,2.1459312
Pakistan,PAK,3.6106477,6.0500002,7.3499999,1,0,0,0.33329999,0,1,0,0,0.0,1,1,36.990002,3.6106477,36.990002,3.6106477,36.990002,3.6106477,3.6106477,3.6106477,0,0.53757,0,36.990002,3.6106477,36.990002,3.6106477
Panama,PAN,5.0955892,5.9099998,8.8400002,1,0,0,0.1,0,0,0,0,20.0,1,1,15.07,2.7127061,30.5,3.4177268,130.0,4.8675346,5.0955892,4.8675346,0,0.08004,0,30.5,3.4177268,15.07,2.7127061
Peru,PER,4.2626801,5.77,8.3999996,1,0,0,0.1111,0,0,0,0,30.000002,1,1,15.07,2.7127061,30.5,3.4177268,56.5,4.0342407,4.2626801,4.0342407,0,0.00050000002,0,30.5,3.4177268,15.07,2.7127061
Paraguay,PRY,4.3579903,6.9499998,8.21,1,0,0,0.25560001,0,0,0,0,25.0,1,1,,,93.25,4.535284,62.200001,4.1303549,4.3579903,4.1303549,0,0.0,0,93.25,4.535284,,
Sudan,SDN,4.4796071,4.0,7.3099999,1,1,0,0.16670001,0,0,1,0,0.0,1,1,88.199997,4.4796071,88.199997,4.4796071,88.199997,4.4796071,4.4796071,4.4796071,1,0.93099999,0,88.199997,4.4796071,88.199997,4.4796071
Senegal,SEN,5.1038828,6.0,7.4000001,0,1,0,0.1556,0,0,1,0,0.0,0,1,164.66,5.1038828,164.66,5.1038828,164.66,5.1038828,5.1038828,5.1038828,1,0.94999999,0,164.66,5.1038828,164.66,5.1038828
Singapore,SGP,2.8735647,9.3199997,10.15,0,0,0,0.0136,0,1,0,0,0.0,0,1,17.700001,2.8735647,17.700001,2.8735647,17.700001,2.8735647,2.8735647,2.8735647,0,0.0,0,17.700001,2.8735647,17.700001,2.8735647
Sierra Leone,SLE,6.1800165,5.8200002,6.25,1,1,0,0.092200004,0,0,1,0,0.0,1,1,483.0,6.1800165,483.0,6.1800165,483.0,6.1800165,5.521461,5.521461,1,0.94999999,0,250.0,5.521461,250.0,5.521461
El Salvador,SLV,4.3579903,5.0,7.9499998,1,0,0,0.15000001,0,0,0,0,20.0,1,1,,,93.25,4.535284,62.200001,4.1303549,4.3579903,4.1303549,0,0.0,0,93.25,4.535284,,
Togo,TGO,6.5042882,6.9099998,7.2199998,1,0,0,0.0889,0,0,1,0,0.0,1,1,668.0,6.5042882,668.0,6.5042882,668.0,6.5042882,5.521461,5.521461,1,0.94999999,1,250.0,5.521461,250.0,5.521461
Trinidad and Tobago,TTO,4.4426513,7.4499998,8.7700005,0,1,0,0.1222,0,0,0,0,40.0,0,1,85.0,4.4426513,85.0,4.4426513,85.0,4.4426513,4.4426513,4.4426513,0,0.0,0,85.0,4.4426513,85.0,4.4426513
Tunisia,TUN,4.1431346,6.4499998,8.4799995,1,1,0,0.37779999,0,0,1,0,0.0,1,1,63.0,4.1431346,63.0,4.1431346,63.0,4.1431346,4.1431346,4.1431346,0,0.0,0,63.0,4.1431346,63.0,4.1431346
Tanzania,TZA,4.9767337,6.6399999,6.25,0,0,1,0.066699997,0,0,1,0,0.0,0,0,145.0,4.9767337,145.0,4.9767337,145.0,4.9767337,4.9767337,4.9767337,1,0.92150003,1,145.0,4.9767337,145.0,4.9767337
Uganda,UGA,5.6347895,4.4499998,6.9699998,1,0,0,0.0111,0,0,1,0,0.0,1,1,280.0,5.6347895,280.0,5.6347895,280.0,5.6347895,5.521461,5.521461,1,0.94999999,1,250.0,5.521461,250.0,5.521461
Uruguary,URY,4.2626801,7.0,9.0299997,1,0,0,0.36669999,0,0,0,0,90.0,1,1,,,93.25,4.535284,56.5,4.0342407,4.2626801,4.0342407,0,0.0,0,93.25,4.535284,,
USA,USA,2.7080503,10.0,10.22,0,1,0,0.42219999,1,0,0,0,83.600006,0,1,15.0,2.7080503,15.0,2.7080503,15.0,2.7080503,2.7080503,2.7080503,0,0.0,0,15.0,2.7080503,15.0,2.7080503
Venezuela,VEN,4.3579903,7.1399999,9.0699997,1,0,0,0.0889,0,0,0,0,20.0,1,1,,,93.25,4.535284,62.200001,4.1303549,4.3579903,4.1303549,0,0.0070400001,0,93.25,4.535284,,
Vietnam,VNM,4.9416423,6.4099998,7.2800002,1,1,0,0.1778,0,1,0,0,0.0,1,1,140.0,4.9416423,140.0,4.9416423,140.0,4.9416423,4.9416423,4.9416423,0,0.70109999,0,140.0,4.9416423,140.0,4.9416423
South Africa,ZAF,2.74084,6.8600001,8.8900003,0,1,0,0.3222,0,0,1,0,16.0,0,1,15.5,2.74084,15.5,2.74084,15.5,2.74084,2.74084,2.74084,0,0.1045,0,15.5,2.74084,15.5,2.74084
Zaire,ZAR,5.480639,3.5,6.8699999,0,0,1,0.0,0,0,1,0,0.0,0,0,240.0,5.480639,240.0,5.480639,240.0,5.480639,5.480639,5.480639,1,0.94999999,1,240.0,5.480639,240.0,5.480639
1 change: 1 addition & 0 deletions causalpy/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"sc": {"filename": "synthetic_control.csv"},
"anova1": {"filename": "ancova_generated.csv"},
"geolift1": {"filename": "geolift1.csv"},
"risk": {"filename": "AJR2001.csv"},
}


Expand Down
128 changes: 128 additions & 0 deletions causalpy/pymc_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import seaborn as sns
import xarray as xr
from patsy import build_design_matrices, dmatrices
from sklearn.linear_model import LinearRegression as sk_lin_reg

from causalpy.custom_exceptions import BadIndexException # NOQA
from causalpy.custom_exceptions import DataException, FormulaException
Expand Down Expand Up @@ -883,3 +884,130 @@ def _get_treatment_effect_coeff(self) -> str:
return label

raise NameError("Unable to find coefficient name for the treatment effect")


class InstrumentalVariable(ExperimentalDesign):
"""
A class to analyse instrumental variable style experiments.

:param instruments_data: A pandas dataframe of instruments
for our treatment variable. Should contain
instruments Z, and treatment t
:param data: A pandas dataframe of covariates for fitting
the focal regression of interest. Should contain covariates X
including treatment t and outcome y
:param instruments_formula: A statistical model formula for
the instrumental stage regression
e.g. t ~ 1 + z1 + z2 + z3
:param formula: A statistical model formula for the \n
focal regression e.g. y ~ 1 + t + x1 + x2 + x3
:param model: A PyMC model
:param priors: An optional dictionary of priors for the
mus and sigmas of both regressions. If priors are not
specified we will substitue MLE estimates for the beta
coefficients. Greater control can be achieved
by specifying the priors directly e.g. priors = {
"mus": [0, 0],
"sigmas": [1, 1],
"eta": 2,
"lkj_sd": 2,
}

"""

def __init__(
self,
instruments_data: pd.DataFrame,
data: pd.DataFrame,
instruments_formula: str,
formula: str,
model=None,
priors=None,
**kwargs,
):
super().__init__(model=model, **kwargs)
self.expt_type = "Instrumental Variable Regression"
self.data = data
self.instruments_data = instruments_data
self.formula = formula
self.instruments_formula = instruments_formula
self.model = model
self._input_validation()

y, X = dmatrices(formula, self.data)
self._y_design_info = y.design_info
self._x_design_info = X.design_info
self.labels = X.design_info.column_names
self.y, self.X = np.asarray(y), np.asarray(X)
self.outcome_variable_name = y.design_info.column_names[0]

t, Z = dmatrices(instruments_formula, self.instruments_data)
self._t_design_info = t.design_info
self._z_design_info = Z.design_info
self.labels_instruments = Z.design_info.column_names
self.t, self.Z = np.asarray(t), np.asarray(Z)
self.instrument_variable_name = t.design_info.column_names[0]

self.get_naive_OLS_fit()
self.get_2SLS_fit()

# fit the model to the data
COORDS = {"instruments": self.labels_instruments, "covariates": self.labels}
self.coords = COORDS
if priors is None:
priors = {
"mus": [self.ols_beta_first_params, self.ols_beta_second_params],
"sigmas": [1, 1],
"eta": 2,
"lkj_sd": 2,
}
self.priors = priors
self.model.fit(
X=self.X, Z=self.Z, y=self.y, t=self.t, coords=COORDS, priors=self.priors
)

def get_2SLS_fit(self):
first_stage_reg = sk_lin_reg().fit(self.Z, self.t)
fitted_Z_values = first_stage_reg.predict(self.Z)
X2 = self.data.copy(deep=True)
X2[self.instrument_variable_name] = fitted_Z_values
_, X2 = dmatrices(self.formula, X2)
second_stage_reg = sk_lin_reg().fit(X=X2, y=self.y)
betas_first = list(first_stage_reg.coef_[0][1:])
betas_first.insert(0, first_stage_reg.intercept_[0])
betas_second = list(second_stage_reg.coef_[0][1:])
betas_second.insert(0, second_stage_reg.intercept_[0])
self.ols_beta_first_params = betas_first
self.ols_beta_second_params = betas_second
self.first_stage_reg = first_stage_reg
self.second_stage_reg = second_stage_reg

def get_naive_OLS_fit(self):
ols_reg = sk_lin_reg().fit(self.X, self.y)
beta_params = list(ols_reg.coef_[0][1:])
beta_params.insert(0, ols_reg.intercept_[0])
self.ols_beta_params = dict(zip(self._x_design_info.column_names, beta_params))
self.ols_reg = ols_reg

def _input_validation(self):
"""Validate the input data and model formula for correctness"""
treatment = self.instruments_formula.split("~")[0]
test = treatment.strip() in self.instruments_data.columns
test = test & (treatment.strip() in self.data.columns)
if not test:
raise DataException(
f"""
The treatment variable:
{treatment} must appear in the instrument_data to be used
as an outcome variable and in the data object to be used as a covariate.
"""
)
Z = self.data[treatment.strip()]
check_binary = len(np.unique(Z)) > 2
if check_binary:
warnings.warn(
"""Warning. The treatment variable is not Binary.
This is not necessarily a problem but it violates
the assumption of a simple IV experiment.
The coefficients should be interpreted appropriately."""
)
Loading