Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve random state handling #801

Merged
merged 13 commits into from
May 20, 2019
Merged
20 changes: 20 additions & 0 deletions docs/source/tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,26 @@ the dataset we generated earlier:

engine.fit(dataset)

Note that, by default, the training of the NLU engine is a non-deterministic
process: training and testing multiple times on the same data may produce
different outputs.
Copy link
Contributor

@adrienball adrienball May 16, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would be a bit more optimistic in the formulation:

Note that, by default, the training of the NLU engine is a non-deterministic process: 
training and testing multiple times on the same data may produce different outputs.


Reproducible training and testing can be achieved by passing a
**random seed** to the engine:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer using a more impersonal form in the documentation, but that's just a suggestion. That would be something like:

Reproducible training and testing can be achieved by passing a 
**random seed** to the engine:


.. code-block:: python

seed = 42
engine = SnipsNLUEngine(config=CONFIG_EN, random_state=seed)
engine.fit(dataset)


.. note::

   Due to a ``scikit-learn`` bug fixed in version ``0.21``, deterministic
   behavior cannot be guaranteed when using a Python version ``<3.5``, since
   ``scikit-learn>=0.21`` is only available starting from Python ``>=3.5``.


Parsing
-------
Expand Down
7 changes: 4 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,19 @@
"future>=0.16,<0.17",
"numpy>=1.15,<1.16",
"scipy>=1.0,<2.0",
"scikit-learn>=0.19,<0.20",
"scikit-learn>=0.21.1,<0.22; python_version>='3.5'",
"scikit-learn>=0.20,<0.21; python_version<'3.5'",
"sklearn-crfsuite>=0.3.6,<0.4",
"semantic_version>=2.6,<3.0",
"snips-nlu-utils>=0.8,<0.9",
"snips-nlu-parsers>=0.2,<0.3",
"num2words>=0.5.6,<0.6",
"plac>=0.9.6,<1.0",
"requests>=2.0,<3.0",
"pathlib==1.0.1; python_version < '3.4'",
"pathlib==1.0.1; python_version<'3.4'",
"pyaml>=17,<18",
"deprecation>=2,<3",
"funcsigs>=1.0,<2.0; python_version < '3.4'"
"funcsigs>=1.0,<2.0; python_version<'3.4'"
]

extras_require = {
Expand Down
1 change: 1 addition & 0 deletions snips_nlu/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
BUILTIN_ENTITY_PARSER = "builtin_entity_parser"
CUSTOM_ENTITY_PARSER = "custom_entity_parser"
MATCHING_STRICTNESS = "matching_strictness"
RANDOM_STATE = "random_state"

# resources
RESOURCES = "resources"
Expand Down
6 changes: 3 additions & 3 deletions snips_nlu/data_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@ def get_entities_iterators(intent_entities, language,
add_builtin_entities_examples, random_state):
entities_its = dict()
for entity_name, entity in iteritems(intent_entities):
utterance_values = random_state.permutation(list(entity[UTTERANCES]))
utterance_values = random_state.permutation(sorted(entity[UTTERANCES]))
if add_builtin_entities_examples and is_builtin_entity(entity_name):
entity_examples = get_builtin_entity_examples(entity_name,
language)
entity_examples = get_builtin_entity_examples(
entity_name, language)
# Builtin entity examples must be kept first in the iterator to
# ensure that they are used when augmenting data
iterator_values = entity_examples + list(utterance_values)
Expand Down
6 changes: 2 additions & 4 deletions snips_nlu/default_configs/config_de.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,7 @@
"min_utterances": 200,
"capitalization_ratio": 0.2,
"add_builtin_entities_examples": True
},
"random_seed": None
}
},
"intent_classifier_config": {
"unit_name": "log_reg_intent_classifier",
Expand Down Expand Up @@ -140,8 +139,7 @@
"unknown_words_replacement_string": None,
"keep_order": True
}
},
"random_seed": None
}
}
}
]
Expand Down
6 changes: 2 additions & 4 deletions snips_nlu/default_configs/config_en.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@
"min_utterances": 200,
"capitalization_ratio": 0.2,
"add_builtin_entities_examples": True
},
"random_seed": None
}
},
"intent_classifier_config": {
"unit_name": "log_reg_intent_classifier",
Expand Down Expand Up @@ -126,8 +125,7 @@
"unknown_words_replacement_string": None,
"keep_order": True
}
},
"random_seed": None
}
}
}
]
Expand Down
5 changes: 2 additions & 3 deletions snips_nlu/default_configs/config_es.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
"capitalization_ratio": 0.2,
"add_builtin_entities_examples": True
},
"random_seed": None

},
"intent_classifier_config": {
"unit_name": "log_reg_intent_classifier",
Expand Down Expand Up @@ -118,8 +118,7 @@
"unknown_words_replacement_string": None,
"keep_order": True
}
},
"random_seed": None
}
}
}
]
Expand Down
6 changes: 2 additions & 4 deletions snips_nlu/default_configs/config_fr.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,7 @@
"min_utterances": 200,
"capitalization_ratio": 0.2,
"add_builtin_entities_examples": True
},
"random_seed": None
}
},
"intent_classifier_config": {
"unit_name": "log_reg_intent_classifier",
Expand Down Expand Up @@ -118,8 +117,7 @@
"unknown_words_replacement_string": None,
"keep_order": True
}
},
"random_seed": None
}
}
}
]
Expand Down
6 changes: 2 additions & 4 deletions snips_nlu/default_configs/config_it.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,7 @@
"min_utterances": 200,
"capitalization_ratio": 0.2,
"add_builtin_entities_examples": True
},
"random_seed": None
}
},
"intent_classifier_config": {
"unit_name": "log_reg_intent_classifier",
Expand Down Expand Up @@ -118,8 +117,7 @@
"unknown_words_replacement_string": None,
"keep_order": True
}
},
"random_seed": None
}
}
}
]
Expand Down
5 changes: 2 additions & 3 deletions snips_nlu/default_configs/config_ja.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@
"capitalization_ratio": 0.2,
"add_builtin_entities_examples": True
},
"random_seed": None

},
"intent_classifier_config": {
"unit_name": "log_reg_intent_classifier",
Expand Down Expand Up @@ -144,8 +144,7 @@
"unknown_words_replacement_string": None,
"keep_order": True
}
},
"random_seed": None
}
}
}
]
Expand Down
6 changes: 2 additions & 4 deletions snips_nlu/default_configs/config_ko.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,7 @@
"min_utterances": 200,
"capitalization_ratio": 0.2,
"add_builtin_entities_examples": True
},
"random_seed": None
}
},
"intent_classifier_config": {
"unit_name": "log_reg_intent_classifier",
Expand Down Expand Up @@ -136,8 +135,7 @@
"unknown_words_replacement_string": None,
"keep_order": True
}
},
"random_seed": None
}
}
}
]
Expand Down
8 changes: 6 additions & 2 deletions snips_nlu/intent_classifier/featurizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ def _fit_transform_tfidf_vectorizer(self, x, y, dataset):
config=self.config.tfidf_vectorizer_config,
builtin_entity_parser=self.builtin_entity_parser,
custom_entity_parser=self.custom_entity_parser,
resources=self.resources)
resources=self.resources,
random_state=self.random_state,
)
x_tfidf = self.tfidf_vectorizer.fit_transform(x, dataset)

if not self.tfidf_vectorizer.vocabulary:
Expand Down Expand Up @@ -139,7 +141,9 @@ def _fit_cooccurrence_vectorizer(self, x, classes, none_class, dataset):
config=self.config.cooccurrence_vectorizer_config,
builtin_entity_parser=self.builtin_entity_parser,
custom_entity_parser=self.custom_entity_parser,
resources=self.resources)
resources=self.resources,
random_state=self.random_state,
)
x_cooccurrence = self.cooccurrence_vectorizer.fit(
non_null_x, dataset).transform(x)
if not self.cooccurrence_vectorizer.word_pairs:
Expand Down
26 changes: 17 additions & 9 deletions snips_nlu/intent_classifier/log_reg_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@

from snips_nlu.common.log_utils import DifferedLoggingMessage, log_elapsed_time
from snips_nlu.common.utils import (
check_persisted_path, check_random_state,
fitted_required, json_string)
check_persisted_path, fitted_required, json_string)
from snips_nlu.constants import LANGUAGE, RES_PROBA
from snips_nlu.dataset import validate_and_format_dataset
from snips_nlu.exceptions import _EmptyDatasetUtterancesError, LoadingError
from snips_nlu.exceptions import LoadingError, _EmptyDatasetUtterancesError
from snips_nlu.intent_classifier.featurizer import Featurizer
from snips_nlu.intent_classifier.intent_classifier import IntentClassifier
from snips_nlu.intent_classifier.log_reg_classifier_utils import (
Expand All @@ -24,11 +23,20 @@

logger = logging.getLogger(__name__)

# We set tol to 1e-3 to silence the following warning with Python 2 (
# scikit-learn 0.20):
#
# FutureWarning: max_iter and tol parameters have been added in SGDClassifier
# in 0.19. If max_iter is set but tol is left unset, the default value for tol
# in 0.19 and 0.20 will be None (which is equivalent to -infinity, so it has no
# effect) but will change in 0.21 to 1e-3. Specify tol to silence this warning.

LOG_REG_ARGS = {
"loss": "log",
"penalty": "l2",
"class_weight": "balanced",
"max_iter": 5,
"max_iter": 1000,
"tol": 1e-3,
"n_jobs": -1
}

Expand Down Expand Up @@ -66,12 +74,11 @@ def fit(self, dataset):
self.fit_builtin_entity_parser_if_needed(dataset)
self.fit_custom_entity_parser_if_needed(dataset)
language = dataset[LANGUAGE]
random_state = check_random_state(self.config.random_seed)

data_augmentation_config = self.config.data_augmentation_config
utterances, classes, intent_list = build_training_data(
dataset, language, data_augmentation_config, self.resources,
random_state)
self.random_state)

self.intent_list = intent_list
if len(self.intent_list) <= 1:
Expand All @@ -81,7 +88,8 @@ def fit(self, dataset):
config=self.config.featurizer_config,
builtin_entity_parser=self.builtin_entity_parser,
custom_entity_parser=self.custom_entity_parser,
resources=self.resources
resources=self.resources,
random_state=self.random_state,
)
self.featurizer.language = language

Expand All @@ -94,8 +102,8 @@ def fit(self, dataset):
return self

alpha = get_regularization_factor(dataset)
self.classifier = SGDClassifier(random_state=random_state,
alpha=alpha, **LOG_REG_ARGS)
self.classifier = SGDClassifier(
random_state=self.random_state, alpha=alpha, **LOG_REG_ARGS)
self.classifier.fit(x, classes)
logger.debug("%s", DifferedLoggingMessage(self.log_best_features))
return self
Expand Down
2 changes: 2 additions & 0 deletions snips_nlu/intent_classifier/log_reg_classifier_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@ def get_regularization_factor(dataset):

def get_noise_it(noise, mean_length, std_length, random_state):
    """Yield noise utterances of random length, indefinitely.

    Each yielded string is built by drawing a length from a normal
    distribution ``N(mean_length, std_length)``, truncating it to an int,
    and joining that many consecutive tokens taken from ``noise``, which is
    cycled endlessly.

    Args:
        noise (list of str): pool of noise tokens to draw from
        mean_length (float): mean of the utterance length distribution
        std_length (float): standard deviation of the length distribution
        random_state (numpy.random.RandomState): source of randomness,
            allowing reproducible generation

    Yields:
        str: a whitespace-joined noise utterance (empty string when the
        drawn length truncates to 0 or less)
    """
    # Fixed: removed the unused loop counter ``i`` flagged in review.
    it = itertools.cycle(noise)
    while True:
        noise_length = int(random_state.normal(mean_length, std_length))
        # pylint: disable=stop-iteration-return
        yield " ".join(next(it) for _ in range(noise_length))
        # pylint: enable=stop-iteration-return
Expand Down
8 changes: 6 additions & 2 deletions snips_nlu/intent_parser/probabilistic_intent_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ def fit(self, dataset, force_retrain=True):
self.config.intent_classifier_config,
builtin_entity_parser=self.builtin_entity_parser,
custom_entity_parser=self.custom_entity_parser,
resources=self.resources)
resources=self.resources,
random_state=self.random_state,
)

if force_retrain or not self.intent_classifier.fitted:
self.intent_classifier.fit(dataset)
Expand All @@ -85,7 +87,9 @@ def fit(self, dataset, force_retrain=True):
slot_filler_config,
builtin_entity_parser=self.builtin_entity_parser,
custom_entity_parser=self.custom_entity_parser,
resources=self.resources)
resources=self.resources,
random_state=self.random_state,
)
if force_retrain or not self.slot_fillers[intent_name].fitted:
self.slot_fillers[intent_name].fit(dataset, intent_name)
logger.debug("Fitted slot fillers in %s",
Expand Down
9 changes: 6 additions & 3 deletions snips_nlu/nlu_engine/nlu_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
from snips_nlu.entity_parser import CustomEntityParser
from snips_nlu.entity_parser.builtin_entity_parser import (
BuiltinEntityParser, is_builtin_entity)
from snips_nlu.exceptions import InvalidInputError, IntentNotFoundError, \
LoadingError, IncompatibleModelError
from snips_nlu.exceptions import (
InvalidInputError, IntentNotFoundError, LoadingError,
IncompatibleModelError)
from snips_nlu.intent_parser import IntentParser
from snips_nlu.pipeline.configs import NLUEngineConfig
from snips_nlu.pipeline.processing_unit import ProcessingUnit
Expand Down Expand Up @@ -117,7 +118,9 @@ def fit(self, dataset, force_retrain=True):
parser_config,
builtin_entity_parser=self.builtin_entity_parser,
custom_entity_parser=self.custom_entity_parser,
resources=self.resources)
resources=self.resources,
random_state=self.random_state,
)

if force_retrain or not recycled_parser.fitted:
recycled_parser.fit(dataset, force_retrain)
Expand Down
9 changes: 2 additions & 7 deletions snips_nlu/pipeline/configs/intent_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,13 @@ class LogRegIntentClassifierConfig(FromDict, ProcessingUnitConfig):
"""Configuration of a :class:`.LogRegIntentClassifier`"""

# pylint: disable=line-too-long
def __init__(self, data_augmentation_config=None, featurizer_config=None,
random_seed=None):
def __init__(self, data_augmentation_config=None, featurizer_config=None):
"""
Args:
data_augmentation_config (:class:`IntentClassifierDataAugmentationConfig`):
Defines the strategy of the underlying data augmentation
featurizer_config (:class:`FeaturizerConfig`): Configuration of the
:class:`.Featurizer` used underneath
random_seed (int, optional): Allows to fix the seed ot have
reproducible trainings
"""
if data_augmentation_config is None:
data_augmentation_config = IntentClassifierDataAugmentationConfig()
Expand All @@ -32,7 +29,6 @@ def __init__(self, data_augmentation_config=None, featurizer_config=None,
self.data_augmentation_config = data_augmentation_config
self._featurizer_config = None
self.featurizer_config = featurizer_config
self.random_seed = random_seed

# pylint: enable=line-too-long

Expand Down Expand Up @@ -83,8 +79,7 @@ def to_dict(self):
"unit_name": self.unit_name,
"data_augmentation_config":
self.data_augmentation_config.to_dict(),
"featurizer_config": self.featurizer_config.to_dict(),
"random_seed": self.random_seed
"featurizer_config": self.featurizer_config.to_dict()
}


Expand Down
Loading