Skip to content

Commit

Permalink
Improve random state handling (#801)
Browse files Browse the repository at this point in the history
  • Loading branch information
ClemDoum authored May 20, 2019
1 parent f14af3a commit 6416624
Show file tree
Hide file tree
Showing 29 changed files with 230 additions and 126 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
### Changed
- Re-score ambiguous `DeterministicIntentParser` results based on slots [#791](https://github.com/snipsco/snips-nlu/pull/791)
- Accept ambiguous results from `DeterministicIntentParser` when confidence score is above 0.5 [#797](https://github.com/snipsco/snips-nlu/pull/797)
- Moved the NLU random state from the config to the shared resources [#801](https://github.com/snipsco/snips-nlu/pull/801)
- Bumped `scikit-learn` to `>=0.21,<0.22` for `python>=3.5` and `>=0.20,<0.21` for `python<3.5` [#801](https://github.com/snipsco/snips-nlu/pull/801)

### Fixed
- Fixed a couple of bugs in the data augmentation which were making the NLU training non-deterministic [#801](https://github.com/snipsco/snips-nlu/pull/801)

## [0.19.6]
### Fixed
Expand Down
20 changes: 20 additions & 0 deletions docs/source/tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,26 @@ the dataset we generated earlier:
engine.fit(dataset)
Note that, by default, training of the NLU engine is non-deterministic:
training and testing multiple times on the same data may produce different
outputs.

Reproducible training can be achieved by passing a **random seed** to the
engine:

.. code-block:: python
seed = 42
engine = SnipsNLUEngine(config=CONFIG_EN, random_state=seed)
engine.fit(dataset)
.. note::

Due to a ``scikit-learn`` bug fixed in version ``0.21``, we can't guarantee
any deterministic behavior if you're using a Python version ``<3.5``, since
``scikit-learn>=0.21`` is only available starting from Python ``>=3.5``.


Parsing
-------
Expand Down
7 changes: 4 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,19 @@
"future>=0.16,<0.17",
"numpy>=1.15,<1.16",
"scipy>=1.0,<2.0",
"scikit-learn>=0.19,<0.20",
"scikit-learn>=0.21.1,<0.22; python_version>='3.5'",
"scikit-learn>=0.20,<0.21; python_version<'3.5'",
"sklearn-crfsuite>=0.3.6,<0.4",
"semantic_version>=2.6,<3.0",
"snips-nlu-utils>=0.8,<0.9",
"snips-nlu-parsers>=0.2,<0.3",
"num2words>=0.5.6,<0.6",
"plac>=0.9.6,<1.0",
"requests>=2.0,<3.0",
"pathlib==1.0.1; python_version < '3.4'",
"pathlib==1.0.1; python_version<'3.4'",
"pyaml>=17,<18",
"deprecation>=2,<3",
"funcsigs>=1.0,<2.0; python_version < '3.4'"
"funcsigs>=1.0,<2.0; python_version<'3.4'"
]

extras_require = {
Expand Down
1 change: 1 addition & 0 deletions snips_nlu/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
BUILTIN_ENTITY_PARSER = "builtin_entity_parser"
CUSTOM_ENTITY_PARSER = "custom_entity_parser"
MATCHING_STRICTNESS = "matching_strictness"
RANDOM_STATE = "random_state"

# resources
RESOURCES = "resources"
Expand Down
6 changes: 3 additions & 3 deletions snips_nlu/data_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@ def get_entities_iterators(intent_entities, language,
add_builtin_entities_examples, random_state):
entities_its = dict()
for entity_name, entity in iteritems(intent_entities):
utterance_values = random_state.permutation(list(entity[UTTERANCES]))
utterance_values = random_state.permutation(sorted(entity[UTTERANCES]))
if add_builtin_entities_examples and is_builtin_entity(entity_name):
entity_examples = get_builtin_entity_examples(entity_name,
language)
entity_examples = get_builtin_entity_examples(
entity_name, language)
# Builtin entity examples must be kept first in the iterator to
# ensure that they are used when augmenting data
iterator_values = entity_examples + list(utterance_values)
Expand Down
6 changes: 2 additions & 4 deletions snips_nlu/default_configs/config_de.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,7 @@
"min_utterances": 200,
"capitalization_ratio": 0.2,
"add_builtin_entities_examples": True
},
"random_seed": None
}
},
"intent_classifier_config": {
"unit_name": "log_reg_intent_classifier",
Expand Down Expand Up @@ -140,8 +139,7 @@
"unknown_words_replacement_string": None,
"keep_order": True
}
},
"random_seed": None
}
}
}
]
Expand Down
6 changes: 2 additions & 4 deletions snips_nlu/default_configs/config_en.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@
"min_utterances": 200,
"capitalization_ratio": 0.2,
"add_builtin_entities_examples": True
},
"random_seed": None
}
},
"intent_classifier_config": {
"unit_name": "log_reg_intent_classifier",
Expand Down Expand Up @@ -126,8 +125,7 @@
"unknown_words_replacement_string": None,
"keep_order": True
}
},
"random_seed": None
}
}
}
]
Expand Down
5 changes: 2 additions & 3 deletions snips_nlu/default_configs/config_es.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
"capitalization_ratio": 0.2,
"add_builtin_entities_examples": True
},
"random_seed": None

},
"intent_classifier_config": {
"unit_name": "log_reg_intent_classifier",
Expand Down Expand Up @@ -118,8 +118,7 @@
"unknown_words_replacement_string": None,
"keep_order": True
}
},
"random_seed": None
}
}
}
]
Expand Down
6 changes: 2 additions & 4 deletions snips_nlu/default_configs/config_fr.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,7 @@
"min_utterances": 200,
"capitalization_ratio": 0.2,
"add_builtin_entities_examples": True
},
"random_seed": None
}
},
"intent_classifier_config": {
"unit_name": "log_reg_intent_classifier",
Expand Down Expand Up @@ -118,8 +117,7 @@
"unknown_words_replacement_string": None,
"keep_order": True
}
},
"random_seed": None
}
}
}
]
Expand Down
6 changes: 2 additions & 4 deletions snips_nlu/default_configs/config_it.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,7 @@
"min_utterances": 200,
"capitalization_ratio": 0.2,
"add_builtin_entities_examples": True
},
"random_seed": None
}
},
"intent_classifier_config": {
"unit_name": "log_reg_intent_classifier",
Expand Down Expand Up @@ -118,8 +117,7 @@
"unknown_words_replacement_string": None,
"keep_order": True
}
},
"random_seed": None
}
}
}
]
Expand Down
5 changes: 2 additions & 3 deletions snips_nlu/default_configs/config_ja.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@
"capitalization_ratio": 0.2,
"add_builtin_entities_examples": True
},
"random_seed": None

},
"intent_classifier_config": {
"unit_name": "log_reg_intent_classifier",
Expand Down Expand Up @@ -144,8 +144,7 @@
"unknown_words_replacement_string": None,
"keep_order": True
}
},
"random_seed": None
}
}
}
]
Expand Down
6 changes: 2 additions & 4 deletions snips_nlu/default_configs/config_ko.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,7 @@
"min_utterances": 200,
"capitalization_ratio": 0.2,
"add_builtin_entities_examples": True
},
"random_seed": None
}
},
"intent_classifier_config": {
"unit_name": "log_reg_intent_classifier",
Expand Down Expand Up @@ -136,8 +135,7 @@
"unknown_words_replacement_string": None,
"keep_order": True
}
},
"random_seed": None
}
}
}
]
Expand Down
8 changes: 6 additions & 2 deletions snips_nlu/intent_classifier/featurizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ def _fit_transform_tfidf_vectorizer(self, x, y, dataset):
config=self.config.tfidf_vectorizer_config,
builtin_entity_parser=self.builtin_entity_parser,
custom_entity_parser=self.custom_entity_parser,
resources=self.resources)
resources=self.resources,
random_state=self.random_state,
)
x_tfidf = self.tfidf_vectorizer.fit_transform(x, dataset)

if not self.tfidf_vectorizer.vocabulary:
Expand Down Expand Up @@ -139,7 +141,9 @@ def _fit_cooccurrence_vectorizer(self, x, classes, none_class, dataset):
config=self.config.cooccurrence_vectorizer_config,
builtin_entity_parser=self.builtin_entity_parser,
custom_entity_parser=self.custom_entity_parser,
resources=self.resources)
resources=self.resources,
random_state=self.random_state,
)
x_cooccurrence = self.cooccurrence_vectorizer.fit(
non_null_x, dataset).transform(x)
if not self.cooccurrence_vectorizer.word_pairs:
Expand Down
26 changes: 17 additions & 9 deletions snips_nlu/intent_classifier/log_reg_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@

from snips_nlu.common.log_utils import DifferedLoggingMessage, log_elapsed_time
from snips_nlu.common.utils import (
check_persisted_path, check_random_state,
fitted_required, json_string)
check_persisted_path, fitted_required, json_string)
from snips_nlu.constants import LANGUAGE, RES_PROBA
from snips_nlu.dataset import validate_and_format_dataset
from snips_nlu.exceptions import _EmptyDatasetUtterancesError, LoadingError
from snips_nlu.exceptions import LoadingError, _EmptyDatasetUtterancesError
from snips_nlu.intent_classifier.featurizer import Featurizer
from snips_nlu.intent_classifier.intent_classifier import IntentClassifier
from snips_nlu.intent_classifier.log_reg_classifier_utils import (
Expand All @@ -24,11 +23,20 @@

logger = logging.getLogger(__name__)

# We set tol to 1e-3 to silence the following warning with Python 2
# (scikit-learn 0.20):
#
# FutureWarning: max_iter and tol parameters have been added in SGDClassifier
# in 0.19. If max_iter is set but tol is left unset, the default value for tol
# in 0.19 and 0.20 will be None (which is equivalent to -infinity, so it has no
# effect) but will change in 0.21 to 1e-3. Specify tol to silence this warning.

# Hyper-parameters passed to the underlying SGDClassifier (logistic
# regression via "log" loss). "tol" is set explicitly to 1e-3 to silence a
# scikit-learn 0.20 FutureWarning about the max_iter/tol defaults.
# NOTE: the scraped diff had merged the old ("max_iter": 5) and new
# ("max_iter": 1000) lines into a duplicate dict key; only the post-commit
# value 1000 is kept here.
LOG_REG_ARGS = {
    "loss": "log",
    "penalty": "l2",
    "class_weight": "balanced",
    "max_iter": 1000,
    "tol": 1e-3,
    "n_jobs": -1
}

Expand Down Expand Up @@ -66,12 +74,11 @@ def fit(self, dataset):
self.fit_builtin_entity_parser_if_needed(dataset)
self.fit_custom_entity_parser_if_needed(dataset)
language = dataset[LANGUAGE]
random_state = check_random_state(self.config.random_seed)

data_augmentation_config = self.config.data_augmentation_config
utterances, classes, intent_list = build_training_data(
dataset, language, data_augmentation_config, self.resources,
random_state)
self.random_state)

self.intent_list = intent_list
if len(self.intent_list) <= 1:
Expand All @@ -81,7 +88,8 @@ def fit(self, dataset):
config=self.config.featurizer_config,
builtin_entity_parser=self.builtin_entity_parser,
custom_entity_parser=self.custom_entity_parser,
resources=self.resources
resources=self.resources,
random_state=self.random_state,
)
self.featurizer.language = language

Expand All @@ -94,8 +102,8 @@ def fit(self, dataset):
return self

alpha = get_regularization_factor(dataset)
self.classifier = SGDClassifier(random_state=random_state,
alpha=alpha, **LOG_REG_ARGS)
self.classifier = SGDClassifier(
random_state=self.random_state, alpha=alpha, **LOG_REG_ARGS)
self.classifier.fit(x, classes)
logger.debug("%s", DifferedLoggingMessage(self.log_best_features))
return self
Expand Down
8 changes: 6 additions & 2 deletions snips_nlu/intent_parser/probabilistic_intent_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ def fit(self, dataset, force_retrain=True):
self.config.intent_classifier_config,
builtin_entity_parser=self.builtin_entity_parser,
custom_entity_parser=self.custom_entity_parser,
resources=self.resources)
resources=self.resources,
random_state=self.random_state,
)

if force_retrain or not self.intent_classifier.fitted:
self.intent_classifier.fit(dataset)
Expand All @@ -85,7 +87,9 @@ def fit(self, dataset, force_retrain=True):
slot_filler_config,
builtin_entity_parser=self.builtin_entity_parser,
custom_entity_parser=self.custom_entity_parser,
resources=self.resources)
resources=self.resources,
random_state=self.random_state,
)
if force_retrain or not self.slot_fillers[intent_name].fitted:
self.slot_fillers[intent_name].fit(dataset, intent_name)
logger.debug("Fitted slot fillers in %s",
Expand Down
9 changes: 6 additions & 3 deletions snips_nlu/nlu_engine/nlu_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
from snips_nlu.entity_parser import CustomEntityParser
from snips_nlu.entity_parser.builtin_entity_parser import (
BuiltinEntityParser, is_builtin_entity)
from snips_nlu.exceptions import InvalidInputError, IntentNotFoundError, \
LoadingError, IncompatibleModelError
from snips_nlu.exceptions import (
InvalidInputError, IntentNotFoundError, LoadingError,
IncompatibleModelError)
from snips_nlu.intent_parser import IntentParser
from snips_nlu.pipeline.configs import NLUEngineConfig
from snips_nlu.pipeline.processing_unit import ProcessingUnit
Expand Down Expand Up @@ -117,7 +118,9 @@ def fit(self, dataset, force_retrain=True):
parser_config,
builtin_entity_parser=self.builtin_entity_parser,
custom_entity_parser=self.custom_entity_parser,
resources=self.resources)
resources=self.resources,
random_state=self.random_state,
)

if force_retrain or not recycled_parser.fitted:
recycled_parser.fit(dataset, force_retrain)
Expand Down
9 changes: 2 additions & 7 deletions snips_nlu/pipeline/configs/intent_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,13 @@ class LogRegIntentClassifierConfig(FromDict, ProcessingUnitConfig):
"""Configuration of a :class:`.LogRegIntentClassifier`"""

# pylint: disable=line-too-long
def __init__(self, data_augmentation_config=None, featurizer_config=None,
random_seed=None):
def __init__(self, data_augmentation_config=None, featurizer_config=None):
"""
Args:
data_augmentation_config (:class:`IntentClassifierDataAugmentationConfig`):
Defines the strategy of the underlying data augmentation
featurizer_config (:class:`FeaturizerConfig`): Configuration of the
:class:`.Featurizer` used underneath
random_seed (int, optional): Allows to fix the seed ot have
reproducible trainings
"""
if data_augmentation_config is None:
data_augmentation_config = IntentClassifierDataAugmentationConfig()
Expand All @@ -32,7 +29,6 @@ def __init__(self, data_augmentation_config=None, featurizer_config=None,
self.data_augmentation_config = data_augmentation_config
self._featurizer_config = None
self.featurizer_config = featurizer_config
self.random_seed = random_seed

# pylint: enable=line-too-long

Expand Down Expand Up @@ -83,8 +79,7 @@ def to_dict(self):
"unit_name": self.unit_name,
"data_augmentation_config":
self.data_augmentation_config.to_dict(),
"featurizer_config": self.featurizer_config.to_dict(),
"random_seed": self.random_seed
"featurizer_config": self.featurizer_config.to_dict()
}


Expand Down
Loading

0 comments on commit 6416624

Please sign in to comment.