-
-
Notifications
You must be signed in to change notification settings - Fork 69
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #464 from stan-dev/external-fns
Allow user to supply a c++ header file
- Loading branch information
Showing
17 changed files
with
354 additions
and
87 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ dependencies: | |
- python=3.7 | ||
- ipykernel | ||
- ipython | ||
- ipywidgets | ||
- numpy>=1.15 | ||
- pandas | ||
- xarray | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,4 +7,5 @@ | |
*.hpp | ||
*.exe | ||
*.csv | ||
.ipynb_checkpoints | ||
.ipynb_checkpoints | ||
!make_odds.hpp |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Advanced Topic: Using External C++ Functions\n", | ||
"\n", | ||
"This is based on the relevant portion of the CmdStan documentation [here](https://mc-stan.org/docs/cmdstan-guide/using-external-cpp-code.html)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Consider the following Stan model, based on the bernoulli example." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {"nbsphinx": "hidden"}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"try:\n", | ||
" os.remove('bernoulli_external')\n", | ||
"except:\n", | ||
" pass" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from cmdstanpy import CmdStanModel\n", | ||
"model_external = CmdStanModel(stan_file='bernoulli_external.stan', compile=False)\n", | ||
"print(model_external.code())" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"As you can see, it features a function declaration for `make_odds`, but no definition. If we try to compile this, we will get an error. " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"model_external.compile()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Even enabling the `--allow_undefined` flag to stanc3 will not allow this model to be compiled quite yet." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"model_external.compile(stanc_options={'allow_undefined':True})" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"To resolve this, we need to both tell the Stan compiler an undefined function is okay **and** let C++ know what it should be. \n", | ||
"\n", | ||
"We can provide a definition in a C++ header file by using the `user_header` argument to either the CmdStanModel constructor or the `compile` method. \n", | ||
"\n", | ||
"This will enables the `allow_undefined` flag automatically." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"model_external.compile(user_header='make_odds.hpp')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"We can then run this model and inspect the output" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"fit = model_external.sample(data={'N':10, 'y':[0,1,0,0,0,0,0,0,0,1]})\n", | ||
"fit.stan_variable('odds')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"The contents of this header file are a bit complicated unless you are familiar with the C++ internals of Stan, so they are presented without comment:\n", | ||
"\n", | ||
"```c++\n", | ||
"#include <boost/math/tools/promotion.hpp>\n", | ||
"#include <ostream>\n", | ||
"\n", | ||
"namespace bernoulli_model_namespace {\n", | ||
" template <typename T0__> inline typename\n", | ||
" boost::math::tools::promote_args<T0__>::type \n", | ||
" make_odds(const T0__& theta, std::ostream* pstream__) {\n", | ||
" return theta / (1 - theta); \n", | ||
" }\n", | ||
"}\n", | ||
"```" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"interpreter": { | ||
"hash": "8765ce46b013071999fc1966b52035a7309a0da7551e066cc0f0fa23e83d4f60" | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3.9.5 64-bit ('stan': conda)", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.5" | ||
}, | ||
"orig_nbformat": 4 | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.