Feature: wrapper for callbacks #842

Merged
merged 30 commits into from
Jun 28, 2020
c45ebd5
ignore metric callback
ditwoo Jun 10, 2020
458f4eb
wrapper callback
ditwoo Jun 18, 2020
8ca14f9
Merge branch 'master' into feature/ignore-callback
ditwoo Jun 18, 2020
03fcd85
utils for callbacks
ditwoo Jun 19, 2020
f0beb4b
separated imports
ditwoo Jun 19, 2020
7ebc264
fixed cycle import
ditwoo Jun 19, 2020
682a683
main cases of usage
ditwoo Jun 23, 2020
fce7626
Merge branch 'master' into feature/ignore-callback
ditwoo Jun 23, 2020
7f1f244
wrapper callback
ditwoo Jun 23, 2020
7c591c6
fixed link
ditwoo Jun 23, 2020
3623e47
eval
ditwoo Jun 24, 2020
c44434b
fixed conflict
ditwoo Jun 24, 2020
3150936
Merge branch 'master' into feature/ignore-callback
ditwoo Jun 25, 2020
d4ac022
tests
ditwoo Jun 25, 2020
27a7550
simplified example
ditwoo Jun 25, 2020
57f7afc
docs
ditwoo Jun 25, 2020
ca75dd1
removed whitespace
ditwoo Jun 25, 2020
808da30
ignoring eval in codestyle
ditwoo Jun 26, 2020
a284eb2
codestyle fix
ditwoo Jun 26, 2020
055c7d4
WPS400
ditwoo Jun 26, 2020
7fb6e8a
docs: loaders
ditwoo Jun 26, 2020
86130a5
epochs argument & tests
ditwoo Jun 26, 2020
d6fc99b
renamed utils function
ditwoo Jun 26, 2020
47e9e97
epochs test
ditwoo Jun 26, 2020
ba9f715
codestyle fixes
ditwoo Jun 26, 2020
0755480
example fix
ditwoo Jun 27, 2020
11edb5f
ignore arguments & tests
ditwoo Jun 28, 2020
9b6ab21
`ControlFlowCallback` + epochs update + tests
ditwoo Jun 28, 2020
857a977
Merge branch 'master' into feature/ignore-callback
ditwoo Jun 28, 2020
8596284
moved `WrapperCallback` to `Callback`
ditwoo Jun 28, 2020
132 changes: 132 additions & 0 deletions bin/tests/check_dl_core_ignore_metric_callback.sh
@@ -0,0 +1,132 @@
#!/usr/bin/env bash

# Cause the script to exit if a single command fails
set -eo pipefail -v


################################ global variables ################################
rm -rf ./tests/logs ./tests/output.txt

EXPDIR=./tests/_tests_contrib_dl_callbacks
LOGDIR=./tests/logs/_tests_contrib_dl_callbacks
CHECKPOINTS=${LOGDIR}/checkpoints
LOGFILE=${CHECKPOINTS}/_metrics.json
EXP_OUTPUT=./tests/output.txt


function check_file_existence {
    # $1 - path to file
    if [[ ! -f "$1" ]]
    then
        echo "There is no '$1'!"
        exit 1
    fi
}


function check_num_files {
    # $1 - ls directory
    # $2 - expected count
    NFILES=$( ls $1 | wc -l )
    if [[ $NFILES -ne $2 ]]
    then
        echo "Different number of files in '$1' - "`
            `"expected $2 but actual number is $NFILES!"
        exit 1
    fi
}


function check_checkpoints {
    # $1 - file prefix
    # $2 - expected count
    check_num_files "${1}.pth" $2
    check_num_files "${1}_full.pth" $2
}


function check_line_counts {
    # $1 file
    # $2 pattern
    # $3 expected count
    ACTUAL_COUNT=$( grep -c "$2" $1 || true ) # '|| true' for handling pipefail
    if [ $ACTUAL_COUNT -ne $3 ]
    then
        echo "Different number of lines in file '$1' - "`
            `"expected $3 (should match '$2') but actual number is $ACTUAL_COUNT!"
        exit 1
    fi
}

################################ pipeline 00 ################################
# setup: ignore the '_criterion' callback for the 'valid' loader
LOG_MSG='pipeline 00'
echo ${LOG_MSG}

LOGDIR=./tests/logs/_tests_dl_callbacks
CHECKPOINTS=${LOGDIR}/checkpoints
LOGFILE=${CHECKPOINTS}/_metrics.json
EXP_OUTPUT=./tests/output.txt

PYTHONPATH=./examples:./catalyst:${PYTHONPATH} \
python3 -c "
import torch
from torch.utils.data import DataLoader, TensorDataset
from catalyst.dl import (
    SupervisedRunner, Callback, CallbackOrder,
    IgnoreMetricCallback, AccuracyCallback,
)

# experiment_setup
logdir = '${LOGDIR}'
num_epochs = 10

# data
num_samples, num_features = int(1e4), int(1e1)
X = torch.rand(num_samples, num_features)
y = torch.randint(0, 5, size=[num_samples])
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
loaders = {
    'train': loader,
    'valid': loader,
}

# model, criterion, optimizer, scheduler
model = torch.nn.Linear(num_features, 5)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
runner = SupervisedRunner()

# first stage
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    logdir=logdir,
    num_epochs=num_epochs,
    verbose=True,
    main_metric='accuracy01',
    callbacks=[
        AccuracyCallback(accuracy_args=[1, 3, 5]),
        IgnoreMetricCallback(valid=['_criterion'])
    ]
)
" > ${EXP_OUTPUT}

cat ${EXP_OUTPUT}
# check_line_counts ${EXP_OUTPUT} "(train):" 10
# check_line_counts ${EXP_OUTPUT} "(valid):" 3
# check_line_counts ${EXP_OUTPUT} ".*/train\.[[:digit:]]\.pth" 1

check_file_existence ${LOGFILE}
cat ${LOGFILE}
echo ${LOG_MSG}

check_checkpoints "${CHECKPOINTS}/best" 1
check_checkpoints "${CHECKPOINTS}/last" 1
check_checkpoints "${CHECKPOINTS}/train\.[[:digit:]]" 1
check_num_files ${CHECKPOINTS} 7 # 3x2 checkpoints + metrics.json

rm -rf ${LOGDIR} ${EXP_OUTPUT}
1 change: 1 addition & 0 deletions catalyst/contrib/dl/callbacks/__init__.py
@@ -5,6 +5,7 @@

from .cutmix_callback import CutmixCallback
from .gradnorm_logger import GradNormLogger
from .ignore_metric_callback import IgnoreMetricCallback
from .knn_metric import KNNMetricCallback
from .periodic_loader_callback import PeriodicLoaderCallback
from .perplexity_metric import PerplexityMetricCallback
78 changes: 78 additions & 0 deletions catalyst/contrib/dl/callbacks/ignore_metric_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from catalyst.core.callback import Callback, CallbackOrder
from catalyst.core.runner import IRunner


class IgnoreMetricCallback(Callback):
Member

Just to be sure, could you please update the docs with this new callback, and the changelog too? :)
I know it's only a draft for now :)

Member

Speaking of architecture and implementation,
what do you think about a CallbackFilterCallback?
One that filters not only by name, but also by some condition?

callback_filter = CallbackFilterCallback(
  key={"train": "metric-1"},
  # `lambda` is a reserved word in Python, so the argument needs another name, e.g. `lambda_`
  lambda_={"valid": lambda key, value: isinstance(value, MetricCallback)},
  # or we could even add something like
  value={"valid": MetricCallback},
)

Long story short, with callback filtering, most of the time you have 3 options:

  1. filter by name (key field)
  2. filter by class (value field)
  3. some very custom condition -> lambda

What do you think about such a proposal?
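
A minimal sketch of how the three filtering options listed above could be combined into one check. This is an illustration only: `matches` and `DummyMetric` are hypothetical names, not part of this PR or of Catalyst.

```python
from typing import Callable, Optional, Type


class DummyMetric:
    """Hypothetical stand-in for a metric callback class."""


def matches(
    name: str,
    callback: object,
    key: Optional[str] = None,
    value: Optional[Type] = None,
    condition: Optional[Callable[[str, object], bool]] = None,
) -> bool:
    """Return True if the callback is selected by any of the given filters."""
    if key is not None and name == key:  # 1. filter by name
        return True
    if value is not None and isinstance(callback, value):  # 2. filter by class
        return True
    if condition is not None and condition(name, callback):  # 3. custom condition
        return True
    return False


callbacks = {"metric-1": DummyMetric(), "logger": object()}
by_name = [n for n, c in callbacks.items() if matches(n, c, key="metric-1")]
by_class = [n for n, c in callbacks.items() if matches(n, c, value=DummyMetric)]
by_lambda = [
    n for n, c in callbacks.items()
    if matches(n, c, condition=lambda k, v: k.startswith("log"))
]
```

Treating the three filters as alternatives (a match on any one selects the callback) is an assumption of this sketch; the proposal itself does not say how they would combine.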

Member

Moreover, if we want to think about filtering in a broader sense:
why do we filter only on the loader name? During a stage we have two loops, one over epochs and one over loaders.
So I think we could filter not only on the loader name, but also on the epoch index.

Nevertheless, I'm still not sure about a user-friendly API for this :)

Member

maybe, only maybe, something like

callback_filter = CallbackFilterCallback(
  condition=lambda epoch, loader: loader == "train",
  filter=lambda key, value: isinstance(value, MetricCallback),
)
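
A rough sketch of how the `condition` / `filter` pair suggested above might be applied to a dictionary of callbacks. It assumes that callbacks matched by the filter are dropped whenever the condition holds, which the comment does not spell out; `select_callbacks` and `FakeMetric` are hypothetical names.

```python
from typing import Callable, Dict


def select_callbacks(
    callbacks: Dict[str, object],
    epoch: int,
    loader: str,
    condition: Callable[[int, str], bool],
    filter_fn: Callable[[str, object], bool],
) -> Dict[str, object]:
    """Drop callbacks matched by ``filter_fn`` whenever ``condition`` holds."""
    if not condition(epoch, loader):
        return dict(callbacks)  # condition not met: keep everything
    return {
        name: cb for name, cb in callbacks.items() if not filter_fn(name, cb)
    }


class FakeMetric:
    """Hypothetical stand-in for MetricCallback."""


callbacks = {"accuracy": FakeMetric(), "logger": object()}


def is_train(epoch: int, loader: str) -> bool:
    return loader == "train"


def is_metric(key: str, value: object) -> bool:
    return isinstance(value, FakeMetric)


on_train = select_callbacks(callbacks, 1, "train", is_train, is_metric)
on_valid = select_callbacks(callbacks, 1, "valid", is_train, is_metric)
```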

Contributor Author

@ditwoo ditwoo Jun 11, 2020

@Scitator what do you think about FilterCallback, SkipCallback, or FilterMetricCallback as the class name?
(I think we can come up with a better name that doesn't repeat the word "callback".)

Also, what do you think about these use cases for the arguments:

callback = FilterCallback(
    loader0="to_filter",  # base case - ignore callback for loader
    loader1=["to_filter1", "to_filter2"],  # ignore multiple callbacks
    loader2={"to_filter1": [1, 22, 333], "to_filter2": [1, 2, 3]},  # turn off callbacks for epochs
    loader3=lambda epoch, loader_name, loader_obj: loader_obj,  # filter with custom function
)
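
A small sketch of how the first three argument forms above (a single name, a list of names, a mapping from name to epochs) could be interpreted. `should_skip` is a hypothetical helper, and the callable form is left out because its intended semantics are not specified in the comment.

```python
from typing import Dict, List, Sequence, Union

FilterSpec = Union[str, Sequence[str], Dict[str, List[int]]]


def should_skip(spec: FilterSpec, callback_name: str, epoch: int) -> bool:
    """Decide whether a callback is turned off for the given epoch."""
    if isinstance(spec, str):
        # base case: a single callback name, ignored on every epoch
        return callback_name == spec
    if isinstance(spec, dict):
        # mapping: callback name -> epochs on which it is turned off
        return epoch in spec.get(callback_name, [])
    # list/tuple: several callback names, ignored on every epoch
    return callback_name in spec


skip_single = should_skip("to_filter", "to_filter", epoch=1)
skip_listed = should_skip(["to_filter1", "to_filter2"], "to_filter2", epoch=5)
skip_epoch = should_skip({"to_filter1": [1, 22, 333]}, "to_filter1", epoch=22)
keep_epoch = should_skip({"to_filter1": [1, 22, 333]}, "to_filter1", epoch=2)
```

The `dict` check comes before the generic sequence check on purpose: `in` on a dictionary tests its keys, which would ignore the epoch list.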

Member

@ditwoo I think FilterCallback would be a good choice :)
First of all, we should not concentrate on metric-based callbacks only.

    """
    Ignore metric callbacks for specified loaders.
    """

    def __init__(self, **kwargs):
        """

        Args:
            kwargs: loader and callback names to ignore
        """
        super().__init__(order=CallbackOrder.External)
        # contains pointers to callbacks
        self.callbacks = {}
        self.loader_ignore_list = {}
        for loader, ignore_list in kwargs.items():
            if not isinstance(ignore_list, (str, list, tuple)):
                raise TypeError(
                    "Expected ignore list object is str/List[str]/Tuple[str] "
                    f"but got {type(ignore_list)}"
                )
            if isinstance(ignore_list, str):
                to_ignore = [ignore_list]
            else:
                to_ignore = [
                    str(callback_name) for callback_name in ignore_list
                ]
            self.loader_ignore_list[loader] = to_ignore

    def on_stage_start(self, runner: IRunner) -> None:
        """Get information about callbacks used in a stage.

        Args:
            runner (IRunner): current runner
        """
        for name, callback in runner.callbacks.items():
            self.callbacks[name] = callback

    def _is_correct_loader(
        self, loader: str, name: str, callback: Callback
    ) -> bool:
        """
        Check if callback should be used with loader.

        Args:
            loader (str): loader name
            name (str): callback name
            callback (Callback): callback object

        Returns:
            True if callback should be used with passed loader otherwise False
        """
        ignore_list = self.loader_ignore_list.get(loader) or []
        in_ignore_list = name in ignore_list
        is_metric = callback.order in (
            CallbackOrder.Metric,
            CallbackOrder.MetricAggregation,
        )
        return not (in_ignore_list and is_metric)

    def on_loader_start(self, runner: IRunner) -> None:
        """
        Construct list of callbacks for current loader.

        Args:
            runner (IRunner): current runner
        """
        loader = runner.loader_name
        filtered_loader_callbacks = {
            name: callback
            for name, callback in self.callbacks.items()
            if self._is_correct_loader(loader, name, callback)
        }
        runner.callbacks = filtered_loader_callbacks
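
The per-loader filtering in `_is_correct_loader` and `on_loader_start` can be illustrated without Catalyst. The sketch below assumes plain strings as stand-ins for `CallbackOrder` values and uses a hypothetical `filter_for_loader` helper; it mirrors the logic only and is not the code from this PR.

```python
# Stand-ins for CallbackOrder.Metric and CallbackOrder.MetricAggregation
METRIC_ORDERS = {"metric", "metric_aggregation"}


def filter_for_loader(callbacks, loader_ignore_list, loader):
    """Keep a callback unless it is a metric callback ignored for this loader.

    ``callbacks`` maps a callback name to its (stand-in) order.
    """
    ignore = loader_ignore_list.get(loader) or []
    return {
        name: order
        for name, order in callbacks.items()
        if not (name in ignore and order in METRIC_ORDERS)
    }


callbacks = {
    "_criterion": "metric",
    "accuracy": "metric",
    "_optimizer": "optimizer",
}
ignore = {"valid": ["_criterion", "_optimizer"]}

train_callbacks = filter_for_loader(callbacks, ignore, "train")
valid_callbacks = filter_for_loader(callbacks, ignore, "valid")
```

Note that `_optimizer` stays active on the `valid` loader even though it is listed: only callbacks with a metric order are ever dropped, matching the `in_ignore_list and is_metric` condition in the class.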