Skip to content

Commit

Permalink
update xai module. add deep shap. (#335)
Browse files Browse the repository at this point in the history
* add xai init file

* update xai module. add deep shap.
  • Loading branch information
lunlu authored Dec 28, 2022
1 parent d3de58e commit 08b2b8a
Show file tree
Hide file tree
Showing 6 changed files with 891 additions and 135 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: \n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2022-12-14 15:36+0800\n"
"POT-Creation-Date: 2022-12-27 16:02+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <LL@li.org>\n"
Expand All @@ -18,16 +18,16 @@ msgstr ""
"Generated-By: Babel 2.9.1\n"

#: ../../source/api/paddlets.xai.post_hoc.shap_explainer.rst:2
#: 12f2a447343e478b8188626bf2a11e4c
#: 99a9d5a03d8b44fc944b4571521aad77
msgid "paddlets.xai.post_hoc.shap_explainer"
msgstr ""

#: cabdfa8a577d4eb4b0cf18db73b64dad of
#: 06ba11ac09294dd6995441794a21234f of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer:1
msgid "Bases: :py:class:`~paddlets.xai.post_hoc.explainer_base.BaseExplainer`"
msgstr ""

#: a2629ebff30e4734922a508f4d5db9a1 of
#: 7617ccdfaf2a41819935a97420902244 of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer:1
msgid ""
"Shap explainer. This class only (currently) supports regression model of "
Expand All @@ -38,9 +38,9 @@ msgstr ""
"Shap解释器. 该类仅支持预测回归模型. 利用shap数值提供模型输入对输出的贡献度."
"对于shap, 可以看`https://github.com/slundberg/shap`."

#: 67369d101ae7457695da3fdb0ecc299b 6d7a21d9038c467b9e289d905255bdd5
#: 7e3a0a8045024e949d3689788f913774 921d10c440b74257bc812a63cb45816d
#: a5f4186aabee4e00be75927af5cbc818 c8e40365143f4aaba9c1723bfb0a1ad8 of
#: 17350d939a7042648427ef189338251d 51288abc8dcd4d3e89f81fa78f56f316
#: 7c9ec6f587394e6f9aa4a0b8a77e1a63 961abca7e2d344ba9f030c11891c306d
#: ba5872677da44aaea8fc582dee0aa7a8 fba144a346354092b8154e35d31f5b36 of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.explain
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.force_plot
Expand All @@ -50,84 +50,86 @@ msgstr ""
msgid "Parameters"
msgstr ""

#: 73a29593a591448492a83fb3d5638733 of
#: d44eeca9f4db4b8cbffc0243b7d65627 of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer:5
msgid "A model object that supports `predict` function."
msgstr "支持predict功能的模型对象"

#: ccfa1f09ba7a415aa4dbf633c821b200 of
#: 25e430b833dc4fa4a0529cfc15d58b75 of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer:7
msgid "A TSDataset for training the shap explainer"
msgstr "用于训练shap解释器的TSDataset数据"

#: dce589d8d20340ecbea7ca1301c5f212 of
#: 107a0e8484de4759be0c3773fd2619e8 of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer:9
msgid "number of instances sampled from the background_data"
msgstr "从背景板数据中采样的样本实例数量"

#: 11d58d7f296c4263b820adaca867ac3e of
#: 5fb929ea5e854222af50e38575c724dd of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer:11
msgid "Optionally, the shap method to apply. Now just support `kernel` method."
msgstr "可选,使用的shap方法. 目前仅支持kernel方法"
msgid "The shap method to apply. Optionally, {'kernel', 'deep'}."
msgstr "使用的shap方法. 目前支持{'kernel', 'deep'}."

#: 945accac4cd540cabd1ddd1bfee58ede of
#: c1c36c045a7a440b985657533fcc97c6 of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer:13
msgid "Task type of the model. Now just support the regression task."
msgid "Task type of the model. Only support the regression task."
msgstr "模型任务类型. 目前仅支持回归任务"

#: f946971270ab403f88e6afb8f5bbf6c2 of
#: e7643f37fa9946e89d71fc9c0d99cb94 of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer:15
msgid "random seed."
msgstr "随机数种子"

#: fd312d65b60a47698e7ac79171cfa3bb of
#: 6b4f45f7ec154cadb81a4eb4ece03565 of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer:17
msgid "Only effective when the model is of type PaddleBaseModel."
msgstr "当模型为PaddleBaseModel类型的时候有效,可以加速效率"

#: 44db4cfb935a48009bccb4b0b7e06ee8 of
#: f4291aa05535423ebbc201e50e8c71b7 of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer:19
msgid "Optionally, additional keyword arguments passed to `shap_method`."
msgstr "可选,额外参数对象传入shap_method对象"

#: c11e29aced1848388ca1d77049b577ed of
#: 7135492868464c898010027c9b7d8515 of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.explain:1
msgid "Calculate the explanatory value of the test sample."
msgstr "计算测试样本的解释性数值"

#: ee6646a3aa3a4f7fa4629a5fdeac706d of
#: 8cfbae9ac2e24514a76eb1e981a92abc of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.explain:3
msgid "test data."
msgstr 测试数据"
msgstr "测试数据"

#: 50cea6ade4444f379a2a6421d720a099 of
#: ca686c89d4fe4059801cb0cdda4684c2 of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.explain:5
msgid ""
"Number of times to re-evaluate the model when explaining each prediction."
" More samples lead to lower variance estimates of the SHAP values. "
"Default nsamples=100."
msgstr "解释每个预测时重新评估模型的次数. 数据越大造成shap数值的评估方差越小"
" More samples lead to lower variance estimates of the SHAP values. Only "
"used in `shap_method=kernel`. Default nsamples=100."
msgstr ""
"解释每个预测时重新评估模型的次数. 数据越大造成shap数值的评估方差越小. "
"仅当`shap_method=kernel`时有效. 默认值100"

#: f94f468060a44192898580effbfa0dfe of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.explain:8
#: 65c3718ecfcf479c99bdfe5d432491ce of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.explain:9
msgid "The sample start index of the test data. Default the latest sample."
msgstr "测试样本的开始索引. 模型最后一个样本"

#: cc45c58d93d040c6b604270d7ee940ad of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.explain:10
#: a3015d415e984daf905ca01256f26aa3 of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.explain:11
msgid "The sample number of the test data."
msgstr "测试样本的数量,配合sample_start_index使用"

#: 9c243828c6654c22bf399b3a0caa061f of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.explain:12
#: bfebbe963c1e47a4bc35760000ae7fe4 of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.explain:13
msgid ""
"Optionally, additional keyword arguments passed to "
"`shap.explainer.shap_values`."
msgstr ""
msgstr "可选, 额外参数传给`shap.explainer.shap_values`."

#: 112bcf11b72446cfbe4928db6a0f2667 1829f5a5f27244e485acaa26618a06e7
#: bfb0f6e07e024501874b6ee8c5fbbb22 c9998f9869c146b1a9b93d72bf6903e9
#: e5f01b823d0746ec9ab85556a73d77ff of
#: 0f32d5e06c7c4d1d948149b68d1a21f6 1025175ca442484bb419fc508b89636e
#: 3dd792f0aea84702b80bb576800e3a34 bebb6d48d79048e6a86a45f97bd868f8
#: f6a78f3c154b4fd197281f0be5a0c114 of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.explain
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.force_plot
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.get_explanation
Expand All @@ -136,37 +138,47 @@ msgstr ""
msgid "Returns"
msgstr ""

#: 91103523898e4640a5aabfb5584efdf9 c4cb9f92c31d404aa95cf070afcda641 of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.explain:14
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.get_explanation:8
msgid "np.ndarray object"
msgstr "np.ndarray对象"
#: 2837e1dd881d4b7e9161eae4129c0c58 of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.explain:15
msgid ""
"np.ndarray object(out_chunk_len, samples, in_chunk_len + "
"out_chunk_len(known_cov input), feature dims)"
msgstr ""
"np.ndarray对象. 输出形状: (out_chunk_len, 样本数目, in_chunk_len + "
"out_chunk_len(known_cov部分), 特征数目)"

#: a89d7131e4734a83b379b2ede7005cdf of
#: 7b4391b1ff274aa39dc5c61e33a28ef0 of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.get_explanation:1
msgid ""
"Get the explanatory output of a certain time point in the prediction "
"length."
msgstr "获取预测长度中特定时间点的解释性输出"

#: 4b600bd8e03442f3a503a4fbd3fcc30d 85d6d02d847246fbaddf2187912c6313
#: bc0af9fdbec24084a6013cfd9d44e853 of
#: 217b2b9913f24e9493f3556f3c5fadc0 37362ea91ae8483988e5131c82ac8720
#: 77b998402b714aa8a2822480839009b0 of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.force_plot:3
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.get_explanation:3
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.summary_plot:3
msgid "The certain time point in the prediction length."
msgstr "预测长度中特定时间点"

#: 38c50cdd7bf4443d87566e6a83fd9c33 49625b69a6814a7ab2790f38e64de3a2
#: 6910d109c5d4447a868da5cb783eebc7 a101c4bbfe1a41d19b7098e4a9561d4f of
#: 0ec9e72d33344c2b8d2f80359ed8d0da 4900f5e233284de59a537f467afa9b99
#: b7468442edb44547b3381c4a271a658d dae466e2149a4c07b646b0629c705443 of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.force_plot:5
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.get_explanation:5
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.plot:5
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.summary_plot:5
msgid "The sample index of the explanatory value. Default the first sample."
msgstr "解释性数值样本索引. 默认第一个样本"

#: e8a9d4ade706401ea86fa0b4a1f71a96 of
#: a924185dce85466ab9a9b7e3984d6f5b of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.get_explanation:8
msgid ""
"np.ndarray object(in_chunk_len + out_chunk_len(known_cov input), feature "
"dims)"
msgstr "np.ndarray对象. 输出形状:(in_chunk_len + out_chunk_len, 特征数目)"

#: da6d16a5a6754001acda5c760cb8a47c of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.plot:1
msgid ""
"Display the shap value of different dimensions. Such as 'OI'(output time "
Expand All @@ -177,40 +189,40 @@ msgstr ""
"展示不同维度的shap值. 例如'OI'(输出时间维度vs输入时间维度), 'OV'(输出时间维度vs特征维度), "
"'IV'(输入时间维度vs特征维度), 'I'(输入时间维度), 'V'(特征维度)"

#: 497229751d0c4ccc87726cfc96d62021 of
#: 11e2d3a5e7854684948cab9ab688ec3a of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.plot:3
msgid "display method. Optional, {'OI', 'OV', 'IV', 'I', 'V'}."
msgstr "展示方法. 可选 {'OI', 'OV', 'IV', 'I', 'V'}."

#: 24fd35b16d9c4d27b6710b28baddfae3 of
#: 9a9fe87617d841e4973191072dc8f254 of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.plot:7
msgid "other parameters."
msgstr "其他参数"

#: 2dee07a9428941cdb41727a997d65dd9 425bdeb31c574b7d9bda5407ab6d705e
#: 9de5e644706d44fba2abcc52eb301827 of
#: 6059f2326edf4d18bd764a4c28fb9dad dcd72110167f411ba1ebd2850d3564dc
#: ed2025df38124a008da2ed88eb56cd75 of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.force_plot:9
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.plot:9
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.summary_plot:9
msgid "None"
msgstr ""

#: dd4c8a049a5a43b4932193f4afaa4a46 of
#: 8bfed601eef740b8a03cc0372bc95a45 of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.force_plot:1
msgid "Visualize the given SHAP values with an additive force layout."
msgstr "以加性图层的方式可视化shap数值"

#: 69c19d4786694203ac44ad47b8c05a14 of
#: af6c8949b704404e8ce87de334fc55d4 of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.force_plot:7
msgid "Optionally, additional keyword arguments passed to `shap.force_plot`."
msgstr "可选, 额外的参数传给shap.force_plot."

#: 1ab32f53f2c642349a5097da409811cc of
#: 9a4767a7b91f4375aa832dd130b20af4 of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.summary_plot:1
msgid "Create a SHAP feature importance based on previously interpreted samples."
msgstr "基于解释性样本构建SHAP特征重要性图"

#: e64d787b6acc4cd9bd06b8a252d20d94 of
#: 7b91ef634d8b4153a4558355e7a8578c of
#: paddlets.xai.post_hoc.shap_explainer.ShapExplainer.summary_plot:7
msgid "Optionally, additional keyword arguments passed to `shap.summary_plot`."
msgstr "可选, 其他参数传给shap.summary_plot."
Expand Down
6 changes: 3 additions & 3 deletions examples/xai_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@
"source": [
"`ShapExplainer`: 帮助用户实现PaddleTS模型与shap解释器之间的链接桥梁,更好的帮助用户理解输出结果性质\n",
"\n",
"- `model`: 目前支持输入模型对象包括:PaddleBaseModel及其派生类、Pipeline及其派生类两种\n",
"- `model`: 目前支持输入模型对象包括:PaddleBaseModel及其派生类、Pipeline及其派生类两种(shap_method='deep'时暂时仅支持PaddleBaseModel及其派生类对象)\n",
"\n",
"- `background_data`: 背景板数据,通常为用户训练模型使用的训练集\n",
"\n",
Expand All @@ -339,7 +339,7 @@
"\n",
"- `seed`: 背景板采用随机数种子\n",
"\n",
"- `use_paddleloader`: 是否使用paddle loader的方式加速训练效率,目前仅支持PaddleBaseModel及其派生类\n",
"- `use_paddleloader`: 是否使用paddle loader的方式加速训练效率,目前仅支持PaddleBaseModel及其派生类且当shap_method='kernel'时生效\n",
"\n",
"- `kwargs`: 其他参数信息用于传入shap_method"
]
Expand All @@ -350,7 +350,7 @@
"metadata": {},
"outputs": [],
"source": [
"#base data process. 保证输入参数keep_index = True, 否则会报错.\n",
"#base data process.\n",
"se = ShapExplainer(pipe, train_data, background_sample_number=10, keep_index=True, use_paddleloader=False)"
]
},
Expand Down
Loading

0 comments on commit 08b2b8a

Please sign in to comment.