Merge pull request #1236 from green-cabbage/master
feat: add DNN example on mltools on top of GNN
lgray authored Jan 1, 2025
2 parents 6185600 + 6c5e826 commit dae7806
Showing 1 changed file with 163 additions and 4 deletions.
167 changes: 163 additions & 4 deletions binder/mltools.ipynb
@@ -52,7 +52,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example using ParticleNet-like jet variable calculation using PyTorch\n",
"## Example using ParticleNet-like jet variable calculation using PyTorch (GNN)\n",
"\n",
"The example given in this notebook be using [`pytorch`][pytorch] to calculate a\n",
"jet-level discriminant using its constituent particles. An example for how to\n",
@@ -554,6 +554,165 @@
"dask_results.visualize(optimize_graph=True)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example using ParticleNet-like jet variable calculation using PyTorch (DNN)\n",
"\n",
"In this example, we will do the same operation as shown above, but using a simple vanilla DNN"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"Define the DNN\n",
"\"\"\"\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"\n",
"class Net(nn.Module):\n",
" def __init__(self, input_shape):\n",
" super(Net, self).__init__()\n",
" self.fc1 = nn.Linear(input_shape, 128)\n",
" self.bn1 = nn.BatchNorm1d(128)\n",
" self.dropout1 = nn.Dropout(0.2)\n",
" self.fc2 = nn.Linear(128, 64)\n",
" self.bn2 = nn.BatchNorm1d(64)\n",
" self.dropout2 = nn.Dropout(0.2)\n",
" self.fc3 = nn.Linear(64, 32)\n",
" self.bn3 = nn.BatchNorm1d(32)\n",
" self.dropout3 = nn.Dropout(0.2)\n",
" self.output = nn.Linear(32, 1)\n",
"\n",
" def forward(self, features):\n",
" x = features\n",
" x = self.fc1(x)\n",
" x = self.bn1(x)\n",
" x = F.tanh(x)\n",
" x = self.dropout1(x)\n",
"\n",
" x = self.fc2(x)\n",
" x = self.bn2(x)\n",
" x = F.tanh(x)\n",
" x = self.dropout2(x)\n",
"\n",
" x = self.fc3(x)\n",
" x = self.bn3(x)\n",
" x = F.tanh(x)\n",
" x = self.dropout3(x)\n",
"\n",
" x = self.output(x)\n",
" output = F.sigmoid(x)\n",
" return output\n"
]
},
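{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick smoke test (an addition for illustration, not part of the original notebook), we can check that the network maps a batch of five input features to one sigmoid score per row; the batch size of 8 and the name `net` are arbitrary."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical smoke test: one score in (0, 1) per input row\n",
"net = Net(5)\n",
"net.eval()  # eval mode so BatchNorm1d uses its running statistics\n",
"print(net(torch.rand(8, 5)).shape)  # expected: torch.Size([8, 1])"
]
},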
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"Initialize and save the DNN. Normally, we would first train the DNN before saving.\n",
"\"\"\"\n",
"n_feat = 5\n",
"dummy_integer = 100\n",
"model = Net(n_feat)\n",
"model.eval() # put in eval mode for BatchNorm1d and Dropout\n",
"input_arr = torch.rand(dummy_integer, n_feat) # intialize a dummy input to generate torch graph for torch.jit to trace\n",
"torch.jit.trace(model, input_arr).save(\"test_model.pt\")"
]
},
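{
"cell_type": "markdown",
"metadata": {},
"source": [
"For completeness, the sketch below shows what the training step omitted above could look like. Everything in it is an assumption for illustration: random dummy features and labels, a BCE loss, and a throwaway `tmp_model` so the untrained `model` used in the rest of the notebook stays unchanged."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal training sketch (assumptions: random dummy data/labels, BCE loss)\n",
"tmp_model = Net(n_feat)\n",
"tmp_model.train()  # train mode for BatchNorm1d and Dropout\n",
"optimizer = torch.optim.Adam(tmp_model.parameters(), lr=1e-3)\n",
"loss_fn = nn.BCELoss()\n",
"for _ in range(5):  # a handful of dummy steps\n",
"    optimizer.zero_grad()\n",
"    pred = tmp_model(torch.rand(n_batch, n_feat))\n",
"    target = torch.randint(0, 2, (n_batch, 1)).float()\n",
"    loss_fn(pred, target).backward()\n",
"    optimizer.step()\n",
"tmp_model.eval()  # switch back to eval mode before tracing and saving"
]
},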
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"dask.awkward<numpy-call-DNNWrapper, npartitions=1>\n",
"coffea DNN output: [[0.578], [0.578], [0.578], [0.578], ..., [0.578], [0.578], [0.578], [0.578]]\n"
]
}
],
"source": [
"from coffea.ml_tools.torch_wrapper import torch_wrapper\n",
"\n",
"def open_events():\n",
" factory = NanoEventsFactory.from_root(\n",
" {\"file:./pfnano.root\": \"Events\"},\n",
" schemaclass=PFNanoAODSchema,\n",
" )\n",
" return factory.events()\n",
"\n",
"\n",
"\n",
"class DNNWrapper(torch_wrapper):\n",
" def prepare_awkward(self, arr):\n",
" # The input is any awkward array with matching dimension\n",
" # Last time we added our input in a dictionary, but with a simple DNN, just add it to a list\n",
" return [\n",
" ak.values_astype(arr, \"float32\"), #only modification we do is is force float32\n",
" ], {}\n",
"\n",
"\n",
"events = open_events()\n",
"input = ak.concatenate( # Fold 5 event-level variables into a singular array\n",
" [\n",
" events.event[:, np.newaxis],\n",
" events.MET.sumEt[:, np.newaxis],\n",
" events.MET.significance[:, np.newaxis],\n",
" events.event[:, np.newaxis],\n",
" events.event[:, np.newaxis],\n",
" ],\n",
" axis=1,\n",
")\n",
"dwrap = DNNWrapper(\"test_model.pt\")\n",
"output = dwrap(input)\n",
"print(output) # This is the lazy evaluated dask array! Use this directly for histogram filling\n",
"print(f\"coffea DNN output: {output.compute()}\") # Eagerly evaluated resut"
]
},
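{
"cell_type": "markdown",
"metadata": {},
"source": [
"The lazy `output` above can be fed straight into histogram filling. The cell below is a hedged sketch of what that could look like, assuming the `hist` package with dask support (`hist.dask`) is installed; it is not part of the original notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical histogram fill (assumes hist[dask] is available)\n",
"import hist.dask as hda\n",
"\n",
"h = hda.Hist.new.Reg(50, 0, 1, name=\"score\", label=\"DNN score\").Double()\n",
"h.fill(score=ak.flatten(output, axis=1))  # still lazy\n",
"print(h.compute())  # materialize the filled histogram"
]
},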
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"normal DNN output: [[0.57815915]\n",
" [0.57815915]\n",
" [0.57815915]\n",
" [0.57815915]\n",
" [0.57815915]\n",
" [0.57815915]\n",
" [0.57815915]\n",
" [0.57815915]\n",
" [0.57815915]\n",
" [0.57815915]]\n"
]
}
],
"source": [
"\"\"\"\n",
"Sanity check that the DNN wrapper is giving the same outputs\n",
"\"\"\"\n",
"test_input = torch.from_numpy(ak.to_numpy(input.compute())).float()\n",
"print(f\"normal DNN output: {model(test_input).detach().numpy()}\")\n"
]
},
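{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an extra check (an addition for illustration, not in the original notebook), the two outputs can be compared numerically; since the wrapper runs the same traced model, agreement to float32 precision is expected."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical numerical comparison of wrapper vs. direct model output\n",
"np.testing.assert_allclose(\n",
"    ak.to_numpy(output.compute()),\n",
"    model(test_input).detach().numpy(),\n",
"    rtol=1e-5,\n",
")\n",
"print(\"wrapper and direct outputs agree\")"
]
},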
{
"attachments": {},
"cell_type": "markdown",
@@ -648,9 +807,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python [conda env:coffea_latest]",
"language": "python",
"name": "python3"
"name": "conda-env-coffea_latest-py"
},
"language_info": {
"codemirror_mode": {
@@ -662,7 +821,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.11.11"
}
},
"nbformat": 4,