-
-
Notifications
You must be signed in to change notification settings - Fork 390
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature: wrapper for callbacks #842
Merged
Scitator
merged 30 commits into
catalyst-team:master
from
ditwoo:feature/ignore-callback
Jun 28, 2020
Merged
Changes from 1 commit
Commits
Show all changes
30 commits
Select commit
Hold shift + click to select a range
c45ebd5
ignore metric callback
ditwoo 458f4eb
wrapper callback
ditwoo 8ca14f9
Merge branch 'master' into feature/ignore-callback
ditwoo 03fcd85
utils for callbacks
ditwoo f0beb4b
separated imports
ditwoo 7ebc264
fixed cycle import
ditwoo 682a683
main cases of usage
ditwoo fce7626
Merge branch 'master' into feature/ignore-callback
ditwoo 7f1f244
wrapper callback
ditwoo 7c591c6
fixed link
ditwoo 3623e47
eval
ditwoo c44434b
fixed conflict
ditwoo 3150936
Merge branch 'master' into feature/ignore-callback
ditwoo d4ac022
tests
ditwoo 27a7550
simplified example
ditwoo 57f7afc
docs
ditwoo ca75dd1
removed whitespace
ditwoo 808da30
ignoring eval in codestyle
ditwoo a284eb2
codestyle fix
ditwoo 055c7d4
WPS400
ditwoo 7fb6e8a
docs: loaders
ditwoo 86130a5
epochs argument & tests
ditwoo d6fc99b
renamed utils function
ditwoo 47e9e97
epochs test
ditwoo ba9f715
codestyle fixes
ditwoo 0755480
example fix
ditwoo 11edb5f
ignore arguments & tests
ditwoo 9b6ab21
`ControlFlowCallback` + epochs update + tests
ditwoo 857a977
Merge branch 'master' into feature/ignore-callback
ditwoo 8596284
moved `WrapperCallback` to `Callback`
ditwoo File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
#!/usr/bin/env bash | ||
|
||
# Cause the script to exit if a single command fails | ||
set -eo pipefail -v | ||
|
||
|
||
################################ global variables ################################ | ||
rm -rf ./tests/logs ./tests/output.txt | ||
|
||
EXPDIR=./tests/_tests_contrib_dl_callbacks | ||
LOGDIR=./tests/logs/_tests_contrib_dl_callbacks | ||
CHECKPOINTS=${LOGDIR}/checkpoints | ||
LOGFILE=${CHECKPOINTS}/_metrics.json | ||
EXP_OUTPUT=./tests/output.txt | ||
|
||
|
||
function check_file_existence { | ||
# $1 - path to file | ||
if [[ ! -f "$1" ]] | ||
then | ||
echo "There is no '$1'!" | ||
exit 1 | ||
fi | ||
} | ||
|
||
|
||
function check_num_files { | ||
# $1 - ls directory | ||
# $2 - expected count | ||
NFILES=$( ls $1 | wc -l ) | ||
if [[ $NFILES -ne $2 ]] | ||
then | ||
echo "Different number of files in '$1' - "` | ||
`"expected $2 but actual number is $NFILES!" | ||
exit 1 | ||
fi | ||
} | ||
|
||
|
||
function check_checkpoints { | ||
# $1 - file prefix | ||
# $2 - expected count | ||
check_num_files "${1}.pth" $2 | ||
check_num_files "${1}_full.pth" $2 | ||
} | ||
|
||
|
||
function check_line_counts { | ||
# $1 file | ||
# $2 pattern | ||
# $3 expected count | ||
ACTUAL_COUNT=$( grep -c "$2" $1 || true ) # '|| true' for handling pipefail | ||
if [ $ACTUAL_COUNT -ne $3 ] | ||
then | ||
echo "Different number of lines in file '$1' - "` | ||
`"expected $3 (should match '$2') but actual number is $ACTUAL_COUNT!" | ||
exit 1 | ||
fi | ||
} | ||
|
||
################################ pipeline 00 ################################ | ||
# setup: run validation once in 3 epoch | ||
LOG_MSG='pipeline 00' | ||
echo ${LOG_MSG} | ||
|
||
LOGDIR=./tests/logs/_tests_dl_callbacks | ||
CHECKPOINTS=${LOGDIR}/checkpoints | ||
LOGFILE=${CHECKPOINTS}/_metrics.json | ||
EXP_OUTPUT=./tests/output.txt | ||
|
||
PYTHONPATH=./examples:./catalyst:${PYTHONPATH} \ | ||
python3 -c " | ||
import torch | ||
from torch.utils.data import DataLoader, TensorDataset | ||
from catalyst.dl import ( | ||
SupervisedRunner, Callback, CallbackOrder, | ||
IgnoreMetricCallback, AccuracyCallback, | ||
) | ||
|
||
# experiment_setup | ||
logdir = '${LOGDIR}' | ||
num_epochs = 10 | ||
|
||
# data | ||
num_samples, num_features = int(1e4), int(1e1) | ||
X = torch.rand(num_samples, num_features) | ||
y = torch.randint(0, 5, size=[num_samples]) | ||
dataset = TensorDataset(X, y) | ||
loader = DataLoader(dataset, batch_size=32, num_workers=1) | ||
loaders = { | ||
'train': loader, | ||
'valid': loader, | ||
} | ||
|
||
# model, criterion, optimizer, scheduler | ||
model = torch.nn.Linear(num_features, 5) | ||
criterion = torch.nn.CrossEntropyLoss() | ||
optimizer = torch.optim.Adam(model.parameters()) | ||
runner = SupervisedRunner() | ||
|
||
# first stage | ||
runner.train( | ||
model=model, | ||
criterion=criterion, | ||
optimizer=optimizer, | ||
loaders=loaders, | ||
logdir=logdir, | ||
num_epochs=num_epochs, | ||
verbose=True, | ||
main_metric='accuracy01', | ||
callbacks=[ | ||
AccuracyCallback(accuracy_args=[1, 3, 5]), | ||
IgnoreMetricCallback(valid=['_criterion']) | ||
] | ||
) | ||
" > ${EXP_OUTPUT} | ||
|
||
cat ${EXP_OUTPUT} | ||
# check_line_counts ${EXP_OUTPUT} "(train):" 10 | ||
# check_line_counts ${EXP_OUTPUT} "(valid):" 3 | ||
# check_line_counts ${EXP_OUTPUT} ".*/train\.[[:digit:]]\.pth" 1 | ||
|
||
check_file_existence ${LOGFILE} | ||
cat ${LOGFILE} | ||
echo ${LOG_MSG} | ||
|
||
check_checkpoints "${CHECKPOINTS}/best" 1 | ||
check_checkpoints "${CHECKPOINTS}/last" 1 | ||
check_checkpoints "${CHECKPOINTS}/train\.[[:digit:]]" 1 | ||
check_num_files ${CHECKPOINTS} 7 # 3x2 checkpoints + metrics.json | ||
|
||
rm -rf ${LOGDIR} ${EXP_OUTPUT} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
from catalyst.core.callback import Callback, CallbackOrder | ||
from catalyst.core.runner import IRunner | ||
|
||
|
||
class IgnoreMetricCallback(Callback): | ||
""" | ||
Ignore metric callbacks for specified loaders. | ||
""" | ||
|
||
def __init__(self, **kwargs): | ||
""" | ||
|
||
Args: | ||
kwargs: loader and callback names to ignore | ||
""" | ||
super().__init__(order=CallbackOrder.External) | ||
# contains pointers to callbacks | ||
self.callbacks = {} | ||
self.loader_ignore_list = {} | ||
for loader, ignore_list in kwargs.items(): | ||
if not isinstance(ignore_list, (str, list, tuple)): | ||
raise TypeError( | ||
"Expected ignore list object is str/List[str]/Tuple[str] " | ||
f"but got {type(ignore_list)}" | ||
) | ||
if isinstance(ignore_list, str): | ||
to_ignore = [ignore_list] | ||
else: | ||
to_ignore = [ | ||
str(callback_name) for callback_name in ignore_list | ||
] | ||
self.loader_ignore_list[loader] = to_ignore | ||
|
||
def on_stage_start(self, runner: IRunner) -> None: | ||
"""Get information about callbacks used in a stage. | ||
|
||
Args: | ||
runner (IRunner): current runner | ||
""" | ||
for name, callback in runner.callbacks.items(): | ||
self.callbacks[name] = callback | ||
|
||
def _is_correct_loader( | ||
self, loader: str, name: str, callback: Callback | ||
) -> bool: | ||
""" | ||
Check if callback should be used with loader. | ||
|
||
Args: | ||
loader (str): loader name | ||
name (str): callback name | ||
callback (Callback): callback object | ||
|
||
Returns: | ||
True if callback should be used with passed loader otherwise False | ||
""" | ||
ignore_list = self.loader_ignore_list.get(loader) or [] | ||
in_ignore_list = name in ignore_list | ||
is_metric = callback.order in ( | ||
CallbackOrder.Metric, | ||
CallbackOrder.MetricAggregation, | ||
) | ||
return not (in_ignore_list and is_metric) | ||
|
||
def on_loader_start(self, runner: IRunner) -> None: | ||
""" | ||
Construct list of callbacks for current loader. | ||
|
||
Args: | ||
runner (IRunner): current runner | ||
""" | ||
loader = runner.loader_name | ||
filtered_loader_callbacks = { | ||
name: callback | ||
for name, callback in self.callbacks.items() | ||
if self._is_correct_loader(loader, name, callback) | ||
} | ||
runner.callbacks = filtered_loader_callbacks |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just to be sure, could you please update docs with this new callbacks & changelog :)
I now it only draft yet :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
speaking about architecture and implementation,
what do you think about
CallbackFilterCallback
?not only with name filtering, but also with some condition one?
long story shot, with callback filtering, most of the time you have 2 options:
key
field)value
field)what do you think about such proposal?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
moreover, if we would like to think about filtering in a broader sense....
why do we filter only on loader name? during stage, we have 2 fors - epoch and loader one.
In this case, I think, we could filter not only on loader name, but also on epoch index.
nevertheless, still now sure about user friendly API for this :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe, only maybe, something like
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Scitator what do you think about
FilterCallback
orSkipCallback
orFilterMetricCallback
as a name for a class?(I think that we can generate a better name without repeated callback word)
Also what do you think about use cases for arguments:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ditwoo I think,
FilterCallback
would be good choice :)First of all, we should not concentrate on Metric-based callbacks only.