Commit

Merge branch 'master' into BAAL-288/experiment_api_v2
Dref360 authored May 31, 2024
2 parents d5b44b0 + 5970ad1 commit 53074e7
Showing 5 changed files with 62 additions and 84 deletions.
3 changes: 2 additions & 1 deletion docs/api/index.md
@@ -6,12 +6,13 @@
* [baal.bayesian](./bayesian.md)
* [baal.active](./dataset_management.md)
* [baal.active.heuristics](./heuristics.md)
* [baal.active.stopping_criteria](./stopping_criteria.md)
* [baal.calibration](./calibration.md)
* [baal.utils](./utils.md)

### :material-file-tree: Compatibility

* [baal.utils.pytorch_lightning] (./compatibility/pytorch-lightning)
* [baal.utils.pytorch_lightning](./compatibility/pytorch-lightning)
* [baal.transformers_trainer_wrapper](./compatibility/huggingface)


35 changes: 35 additions & 0 deletions docs/api/stopping_criteria.md
@@ -0,0 +1,35 @@
# Stopping Criteria

Stopping criteria are used to determine when to stop your active learning experiment.

They are simple to use, and best put in practice with `ActiveExperiment`.

**Example**
```python
from baal.active.stopping_criteria import LabellingBudgetStoppingCriterion
from baal.active.dataset import ActiveLearningDataset

al_dataset: ActiveLearningDataset = ... # len(al_dataset) == 10
criterion = LabellingBudgetStoppingCriterion(al_dataset, labelling_budget=100)

assert not criterion.should_stop({}, [])

# len(al_dataset) == 60
al_dataset.label_randomly(50)
assert not criterion.should_stop({}, [])

# len(al_dataset) == 110, budget exhausted! We've labelled 100 items.
al_dataset.label_randomly(50)
assert criterion.should_stop({}, [])
```
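The two other criteria listed in the API section below follow the same `should_stop(metrics, uncertainty)` interface. A minimal sketch of how they might be instantiated is shown here; the constructor arguments are assumptions and may differ from the released API:

```python
from baal.active.stopping_criteria import (
    EarlyStoppingCriterion,
    LowAverageUncertaintyStoppingCriterion,
)

# Continuing the example above, where `al_dataset` is an ActiveLearningDataset.

# Stop once the average uncertainty over the unlabelled pool drops below a threshold
# (assumed arguments: active dataset + threshold).
low_avg = LowAverageUncertaintyStoppingCriterion(al_dataset, avg_uncertainty_thresh=0.1)
low_avg.should_stop({}, [0.3, 0.2, 0.4])  # Average 0.3 is above 0.1, so keep labelling.

# Stop once a tracked metric stops improving between active-learning steps
# (assumed arguments: active dataset + metric name + patience).
early_stop = EarlyStoppingCriterion(al_dataset, metric_name="test_loss", patience=3)
early_stop.should_stop({"test_loss": 2.07}, [])
```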


### API

### baal.active.stopping_criteria

::: baal.active.stopping_criteria.LabellingBudgetStoppingCriterion

::: baal.active.stopping_criteria.LowAverageUncertaintyStoppingCriterion

::: baal.active.stopping_criteria.EarlyStoppingCriterion
3 changes: 2 additions & 1 deletion docs/requirements.txt
@@ -2,4 +2,5 @@ mkdocs==1.4.0
mkdocs-exclude-search==0.6.4
mkdocs-jupyter==0.21.0
mkdocstrings[python]==0.23.0
Pygments==2.13.0
Pygments==2.13.0
lxml[html_clean]
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -78,6 +78,7 @@ nav:
- api/calibration.md
- api/dataset_management.md
- api/heuristics.md
- api/stopping_criteria.md
- api/modelwrapper.md
- api/utils.md
- Compatibility:
104 changes: 22 additions & 82 deletions notebooks/production/baal_prod_cls.ipynb
@@ -35,15 +35,6 @@
"is_executing": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train: 5174, Valid: 1725, Num. classes : 8\n"
]
}
],
"source": [
"from glob import glob\n",
"import os\n",
@@ -52,7 +43,8 @@
"classes = os.listdir('/tmp/natural_images')\n",
"train, test = train_test_split(files, random_state=1337) # Split 75% train, 25% validation\n",
"print(f\"Train: {len(train)}, Valid: {len(test)}, Num. classes : {len(classes)}\")\n"
]
],
"outputs": []
},
{
"cell_type": "markdown",
@@ -79,7 +71,6 @@
"is_executing": false
}
},
"outputs": [],
"source": [
"from baal.active import FileDataset, ActiveLearningDataset\n",
"from torchvision import transforms\n",
@@ -101,7 +92,8 @@
"# We use -1 to specify that the data is unlabeled.\n",
"test_dataset = FileDataset(test, [-1] * len(test), test_transform)\n",
"active_learning_ds = ActiveLearningDataset(train_dataset, pool_specifics={'transform': test_transform})\n"
]
],
"outputs": []
},
{
"cell_type": "markdown",
@@ -129,7 +121,6 @@
"is_executing": false
}
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn, optim\n",
@@ -149,7 +140,8 @@
"# ModelWrapper is an object similar to keras.Model.\n",
"baal_model = ModelWrapper(model, criterion)\n",
"\n"
]
],
"outputs": []
},
{
"cell_type": "markdown",
@@ -170,11 +162,11 @@
"is_executing": false
}
},
"outputs": [],
"source": [
"from baal.active.heuristics import BALD\n",
"heuristic = BALD(shuffle_prop=0.1)\n"
]
],
"outputs": []
},
{
"cell_type": "markdown",
@@ -193,13 +185,13 @@
"is_executing": false
}
},
"outputs": [],
"source": [
"# This function would do the work that a human would do.\n",
"def get_label(img_path):\n",
" return classes.index(img_path.split('/')[-2])\n",
"\n"
]
],
"outputs": []
},
{
"cell_type": "markdown",
@@ -223,15 +215,6 @@
"is_executing": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Num. labeled: 100/5174\n"
]
}
],
"source": [
"import numpy as np\n",
"# 1. Label all the test set and some samples from the training set.\n",
@@ -246,7 +229,8 @@
"active_learning_ds.label(train_idxs, labels)\n",
"\n",
"print(f\"Num. labeled: {len(active_learning_ds)}/{len(train_dataset)}\")\n"
]
],
"outputs": []
},
{
"cell_type": "code",
@@ -256,56 +240,19 @@
"is_executing": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[103-MainThread ] [baal.modelwrapper:train_on_dataset:109] 2021-07-28T14:47:48.133213Z [\u001B[32minfo ] Starting training dataset=100 epoch=5\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.9/site-packages/torch/utils/data/dataloader.py:478: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 1, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n",
" warnings.warn(_create_warning_msg(\n",
"/opt/conda/lib/python3.9/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /pytorch/c10/core/TensorImpl.h:1156.)\n",
" return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[103-MainThread ] [baal.modelwrapper:train_on_dataset:119] 2021-07-28T14:48:07.477011Z [\u001B[32minfo ] Training complete train_loss=2.058176279067993\n",
"[103-MainThread ] [baal.modelwrapper:test_on_dataset:147] 2021-07-28T14:48:07.479793Z [\u001B[32minfo ] Starting evaluating dataset=1725\n",
"[103-MainThread ] [baal.modelwrapper:test_on_dataset:156] 2021-07-28T14:48:21.277716Z [\u001B[32minfo ] Evaluation complete test_loss=2.0671451091766357\n",
"Metrics: {'test_loss': 2.0671451091766357, 'train_loss': 2.058176279067993}\n"
]
}
],
"source": [
"# 2. Train the model for a few epoch on the training set.\n",
"baal_model.train_on_dataset(active_learning_ds, optimizer, batch_size=16, epoch=5, use_cuda=USE_CUDA)\n",
"baal_model.test_on_dataset(test_dataset, batch_size=16, use_cuda=USE_CUDA)\n",
"\n",
"print(\"Metrics:\", {k:v.avg for k,v in baal_model.metrics.items()})\n"
]
],
"outputs": []
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[103-MainThread ] [baal.modelwrapper:predict_on_dataset_generator:241] 2021-07-28T14:48:21.291851Z [\u001B[32minfo ] Start Predict dataset=5074\n"
]
}
],
"source": [
"# 3. Select the K-top uncertain samples according to the heuristic.\n",
"pool = active_learning_ds.pool\n",
@@ -316,29 +263,22 @@
"predictions = baal_model.predict_on_dataset(pool, batch_size=16, iterations=15, use_cuda=USE_CUDA, verbose=False)\n",
"# We will label the 10 most uncertain samples.\n",
"top_uncertainty = heuristic(predictions)[:10]\n"
]
],
"outputs": []
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[(3, 1429), (4, 2971), (2, 1309), (4, 5), (3, 3761), (4, 2708), (6, 4679), (7, 160), (7, 1638), (6, 73)]\n"
]
}
],
"source": [
"# 4. Label those samples.\n",
"oracle_indices = active_learning_ds._pool_to_oracle_index(top_uncertainty)\n",
"labels = [get_label(train_dataset.files[idx]) for idx in oracle_indices]\n",
"print(list(zip(labels, oracle_indices)))\n",
"active_learning_ds.label(top_uncertainty, labels)\n",
"\n"
]
],
"outputs": []
},
{
"cell_type": "code",
@@ -348,7 +288,6 @@
"is_executing": true
}
},
"outputs": [],
"source": [
"# 5. If not done, go back to 2.\n",
"for step in range(5): # 5 Active Learning step!\n",
@@ -372,7 +311,8 @@
" active_learning_ds.label(top_uncertainty, labels)\n",
" \n",
" "
]
],
"outputs": []
},
{
"cell_type": "markdown",
@@ -386,14 +326,14 @@
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"torch.save({\n",
" 'active_dataset': active_learning_ds.state_dict(),\n",
" 'model': baal_model.state_dict(),\n",
" 'metrics': {k:v.avg for k,v in baal_model.metrics.items()}\n",
"}, '/tmp/baal_output.pth')\n"
]
],
"outputs": []
},
{
"cell_type": "markdown",
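Taken together, the notebook cells above walk through the standard Baal loop: train on the labelled set, score the unlabelled pool with MC-Dropout predictions, pick the most uncertain samples with the heuristic, and label them. The condensed sketch below shows how that loop could be combined with the stopping criterion documented in this commit; it reuses the names defined in the notebook (`baal_model`, `active_learning_ds`, `heuristic`, `optimizer`, `test_dataset`, `train_dataset`, `get_label`, `USE_CUDA`) and is an illustration rather than a verbatim cell from the notebook:

```python
from baal.active.stopping_criteria import LabellingBudgetStoppingCriterion

# Stop after 50 additional samples have been labelled.
criterion = LabellingBudgetStoppingCriterion(active_learning_ds, labelling_budget=50)

while True:
    # Train on the currently labelled samples and evaluate on the test set.
    baal_model.train_on_dataset(active_learning_ds, optimizer, batch_size=16, epoch=5, use_cuda=USE_CUDA)
    baal_model.test_on_dataset(test_dataset, batch_size=16, use_cuda=USE_CUDA)
    metrics = {k: v.avg for k, v in baal_model.metrics.items()}

    pool = active_learning_ds.pool
    if len(pool) == 0:
        break

    # Score the unlabelled pool with MC-Dropout and keep the 10 most uncertain samples.
    predictions = baal_model.predict_on_dataset(pool, batch_size=16, iterations=15,
                                                use_cuda=USE_CUDA, verbose=False)
    top_uncertainty = heuristic(predictions)[:10]

    # Simulated oracle: map pool indices back to the dataset and read labels from the file paths.
    oracle_indices = active_learning_ds._pool_to_oracle_index(top_uncertainty)
    labels = [get_label(train_dataset.files[idx]) for idx in oracle_indices]
    active_learning_ds.label(top_uncertainty, labels)

    # The labelling budget is the only stopping signal used here; this criterion
    # does not inspect the metrics or uncertainty arguments.
    if criterion.should_stop(metrics, []):
        break
```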
