Develop generalization training #2232
@@ -0,0 +1,19 @@
episode-length: 1000

mass:
  sampler-type: "uniform"
  min_value: 0.5
  max_value: 10
  seed: 5

gravity:
  sampler-type: "uniform"
  min_value: 7
  max_value: 12
  seed: 5

scale:
  sampler-type: "uniform"
  min_value: 0.75
  max_value: 3
  seed: 5
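For orientation, here is a minimal sketch of how a reset-parameter config like the one above can be consumed. It assumes PyYAML and the SamplerManager introduced later in this diff; the file name sampler_config.yaml is only illustrative.

# Hypothetical illustration, not part of the PR: load the sampler config above
# and draw one value per reset parameter.
import yaml

from mlagents.envs.sampler_class import SamplerManager  # module path taken from the learn.py imports below

with open("sampler_config.yaml") as f:  # illustrative file name
    config = yaml.safe_load(f)

episode_length = config.pop("episode-length")  # steps between forced resets
manager = SamplerManager(config)               # every remaining top-level key becomes a sampler
print(episode_length, manager.sample_all())
# e.g. 1000 {'mass': 3.17, 'gravity': 9.42, 'scale': 1.88}  (values are illustrative)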
@@ -0,0 +1,99 @@
import numpy as np
from functools import *
from abc import ABC, abstractmethod

from .exception import SamplerException

class SamplerException(Exception):
    pass

class Sampler(ABC):

    @abstractmethod
    def sample_parameter(self, *args, **kwargs):
        pass

Review comment: sample_parameter shouldn't take args and kwargs (or should maybe take a context arg if you're designing for the future).
Review comment: Add a type annotation for the return type, and the same annotation for all the implementations.

class UniformSampler(Sampler):
    # kwargs acts as a sink for extra unneeded args
    def __init__(self, min_value, max_value, **kwargs):
        self.min_value = min_value
        self.max_value = max_value

    def sample_parameter(self):
        return np.random.uniform(self.min_value, self.max_value)

class MultiRangeUniformSampler(Sampler):
    def __init__(self, intervals, **kwargs):
        self.intervals = intervals
        # Measure the length of the intervals
        self.interval_lengths = list(map(lambda x: abs(x[1] - x[0]), self.intervals))
        # Cumulative size of the intervals
        self.cum_interval_length = reduce(lambda x, y: x + y, self.interval_lengths, 0)
        # Assign weights to an interval proportionate to the interval size
        self.interval_weights = list(map(lambda x: x / self.cum_interval_length, self.interval_lengths))

Review comment: I think you should add more comments on the classes. You need to describe what each sampler does (uniform vs. normal, for instance) as well as the type and purpose of the parameters. For example, intervals is a list of float pairs; this needs to be specified somewhere.
Review comment: I find list comprehensions more readable than maps and reduces (and reduce was dropped from the builtins in Python 3). I think a comprehension works here, and I also think you can keep interval_lengths and cum_interval_length as local variables.
Reply: Converted them to local variables.
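The reviewer's suggested snippet is not preserved in this view; a plausible reconstruction of the comprehension-based version with local variables, as agreed above, might look like this (a sketch, not the code that was merged):

# Hypothetical rewrite of MultiRangeUniformSampler.__init__ along the lines of the review.
def __init__(self, intervals, **kwargs):
    self.intervals = intervals
    # Weight each interval in proportion to its length so wider ranges are sampled more often.
    interval_lengths = [abs(high - low) for low, high in intervals]
    cum_interval_length = sum(interval_lengths)
    self.interval_weights = [length / cum_interval_length for length in interval_lengths]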

    def sample_parameter(self):
        cur_min, cur_max = self.intervals[np.random.choice(len(self.intervals), p=self.interval_weights)]
        return np.random.uniform(cur_min, cur_max)

class GaussianSampler(Sampler):
    def __init__(self, mean, var, **kwargs):
        self.mean = mean
        self.var = var

    def sample_parameter(self):
        return np.random.normal(self.mean, self.var)

Review comment: variance =/= standard deviation.
Reply: Good catch, let me fix the naming.
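To make the reviewer's point concrete: numpy's np.random.normal takes a standard deviation (its scale argument), not a variance, so passing a variance here would widen or narrow the distribution incorrectly. A renamed sketch follows; the parameter name eventually used in the PR may differ.

# Hypothetical renaming only; the behavior of np.random.normal(loc, scale) is unchanged.
class GaussianSampler(Sampler):
    def __init__(self, mean, st_dev, **kwargs):
        self.mean = mean
        self.st_dev = st_dev  # scale argument of np.random.normal is a standard deviation

    def sample_parameter(self):
        return np.random.normal(self.mean, self.st_dev)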


# To introduce new sampling methods, just need to 'register' them to this sampler factory
class SamplerFactory:
    NAME_TO_CLASS = {
        "uniform": UniformSampler,
        "gaussian": GaussianSampler,
        "multirange_uniform": MultiRangeUniformSampler,
    }

Review comment: This comment should be formatted consistently with the other comments of the repo.
Review comment: Damn I like this.

    @staticmethod
    def register_sampler(name, sampler_cls):
        SamplerFactory.NAME_TO_CLASS[name] = sampler_cls

    @staticmethod
    def init_sampler_class(name, param_dict):
        if name not in SamplerFactory.NAME_TO_CLASS:
            raise SamplerException(
                name + " sampler is not registered in the SamplerFactory."
                " Use the register_sampler method to register the string"
                " associated to your sampler in the SamplerFactory."
            )
        sampler_cls = SamplerFactory.NAME_TO_CLASS[name]
        return sampler_cls(**param_dict)
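As a usage illustration of the factory (a sketch; the sampler below is invented for the example and is not part of the PR), registering a custom sampler only requires a name and a class whose constructor accepts the config keys:

# Hypothetical example of extending the factory with a new sampler type.
class ConstantSampler(Sampler):
    def __init__(self, value, **kwargs):
        self.value = value

    def sample_parameter(self):
        return self.value

SamplerFactory.register_sampler("constant", ConstantSampler)
sampler = SamplerFactory.init_sampler_class("constant", {"value": 9.81, "seed": 5})
print(sampler.sample_parameter())  # -> 9.81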

class SamplerManager:
    def __init__(self, reset_param_dict):
        self.reset_param_dict = reset_param_dict
        self.samplers = {}
        if reset_param_dict == None:
            return
        for param_name, cur_param_dict in self.reset_param_dict.items():
            if "sampler-type" not in cur_param_dict:
                raise SamplerException(
                    "'sampler-type' argument hasn't been supplied for the {0} parameter".format(param_name)
                )
            sampler_name = cur_param_dict.pop("sampler-type")
            param_sampler = SamplerFactory.init_sampler_class(sampler_name, cur_param_dict)

            self.samplers[param_name] = param_sampler

    def sample_all(self):
        res = {}
        if self.samplers == {}:
            pass
        else:
            for param_name, param_sampler in list(self.samplers.items()):
                res[param_name] = param_sampler.sample_parameter()
        return res

Review comment: I'm not sure this is what you want here; it's definitely more pythonic to check emptiness with "if not self.samplers:" than to compare against an empty dict.
Review comment: Don't need the list() around self.samplers.items() here.
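End to end, the manager is just a name-to-sampler mapping. A small usage sketch (the values printed are illustrative, and the dict mirrors the YAML at the top of the diff):

# Hypothetical usage of SamplerManager with an in-memory config.
reset_params = {
    "mass":    {"sampler-type": "uniform", "min_value": 0.5, "max_value": 10, "seed": 5},
    "gravity": {"sampler-type": "uniform", "min_value": 7,   "max_value": 12, "seed": 5},
}
manager = SamplerManager(reset_params)
for _ in range(3):
    print(manager.sample_all())
# Each call returns a fresh dict such as {'mass': 4.61, 'gravity': 8.03}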
@@ -23,3 +23,6 @@ class MetaCurriculumError(TrainerError):
    """
    Any error related to the configuration of a metacurriculum.
    """
+
+
+    pass

Review comment: Extra line?
@@ -16,7 +16,8 @@
from mlagents.trainers.exception import TrainerError
from mlagents.trainers import MetaCurriculumError, MetaCurriculum
from mlagents.envs import UnityEnvironment
-from mlagents.envs.exception import UnityEnvironmentException
+from mlagents.envs.sampler_class import SamplerManager
+from mlagents.envs.exception import UnityEnvironmentException, SamplerException
from mlagents.envs.base_unity_environment import BaseUnityEnvironment
from mlagents.envs.subprocess_environment import SubprocessUnityEnvironment

@@ -52,6 +53,8 @@ def run_training(sub_id: int, run_seed: int, run_options, process_queue):
    fast_simulation = not bool(run_options["--slow"])
    no_graphics = run_options["--no-graphics"]
    trainer_config_path = run_options["<trainer-config-path>"]
+    sampler_file_path = run_options["--sampler"] if run_options["--sampler"] != "None" else None
+
    # Recognize and use docker volume if one is passed as an argument
    if not docker_target_name:
        model_path = "./models/{run_id}-{sub_id}".format(run_id=run_id, sub_id=sub_id)
@@ -73,6 +76,21 @@ def run_training(sub_id: int, run_seed: int, run_options, process_queue):
        docker_target_name=docker_target_name
    )

+    sampler = None
+    lesson_length = None
+    if sampler_file_path is not None:
+        sampler = load_config(sampler_file_path)
+        if ("episode-length") in sampler:
+            lesson_length = sampler["episode-length"]
+            del sampler["episode-length"]
+        else:
+            raise SamplerException(
+                "Episode Length was not specified in the sampler file."
+                " Please specify it with the 'episode-length' key in the sampler config file."
+            )
+    sampler_manager = SamplerManager(sampler)
+
+
    trainer_config = load_config(trainer_config_path)
    env_factory = create_environment_factory(
        env_path,
@@ -84,6 +102,7 @@ def run_training(sub_id: int, run_seed: int, run_options, process_queue):
    env = SubprocessUnityEnvironment(env_factory, num_envs)
    maybe_meta_curriculum = try_create_meta_curriculum(curriculum_folder, env)
+

    # Create controller and begin training.
    tc = TrainerController(
        model_path,
@@ -98,6 +117,8 @@ def run_training(sub_id: int, run_seed: int, run_options, process_queue):
        env.external_brains,
        run_seed,
        fast_simulation,
+        sampler_manager,
+        lesson_length,
    )

    # Signal that environment has been launched.
@@ -242,22 +263,23 @@ def main():
        mlagents-learn --help

    Options:
-      --env=<file>               Name of the Unity executable [default: None].
-      --curriculum=<directory>   Curriculum json directory for environment [default: None].
-      --keep-checkpoints=<n>     How many model checkpoints to keep [default: 5].
-      --lesson=<n>               Start learning from this lesson [default: 0].
-      --load                     Whether to load the model or randomly initialize [default: False].
-      --run-id=<path>            The directory name for model and summary statistics [default: ppo].
-      --num-runs=<n>             Number of concurrent training sessions [default: 1].
-      --save-freq=<n>            Frequency at which to save model [default: 50000].
-      --seed=<n>                 Random seed used for training [default: -1].
-      --slow                     Whether to run the game at training speed [default: False].
-      --train                    Whether to train model, or only run inference [default: False].
-      --base-port=<n>            Base port for environment communication [default: 5005].
-      --num-envs=<n>             Number of parallel environments to use for training [default: 1]
-      --docker-target-name=<dt>  Docker volume to store training-specific files [default: None].
-      --no-graphics              Whether to run the environment in no-graphics mode [default: False].
-      --debug                    Whether to run ML-Agents in debug mode with detailed logging [default: False].
+      --env=<file>               Name of the Unity executable [default: None].
+      --curriculum=<directory>   Curriculum json directory for environment [default: None].
+      --keep-checkpoints=<n>     How many model checkpoints to keep [default: 5].
+      --lesson=<n>               Start learning from this lesson [default: 0].
+      --load                     Whether to load the model or randomly initialize [default: False].
+      --run-id=<path>            The directory name for model and summary statistics [default: ppo].
+      --num-runs=<n>             Number of concurrent training sessions [default: 1].
+      --save-freq=<n>            Frequency at which to save model [default: 50000].
+      --seed=<n>                 Random seed used for training [default: -1].
+      --slow                     Whether to run the game at training speed [default: False].
+      --train                    Whether to train model, or only run inference [default: False].
+      --base-port=<n>            Base port for environment communication [default: 5005].
+      --num-envs=<n>             Number of parallel environments to use for training [default: 1]
+      --docker-target-name=<dt>  Docker volume to store training-specific files [default: None].
+      --no-graphics              Whether to run the environment in no-graphics mode [default: False].
+      --sampler=<directory>      Reset parameter yaml directory for sampling of environment reset parameters [default: None].
+      --debug                    Whether to run ML-Agents in debug mode with detailed logging [default: False].
    """

Review comment: Is this a directory or a file? In the code, it looks like a file.

    options = docopt(_USAGE)

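For reference, with this change a generalization run would be launched with something like the following; the config paths and run id are placeholders, and the flag spellings are taken from the usage text above:

    mlagents-learn config/trainer_config.yaml --sampler=config/sampler_config.yaml --run-id=mass-gen --train

docopt parses these flags into run_options, from which sampler_file_path is read in run_training above.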
@@ -22,7 +22,9 @@
from mlagents.trainers.meta_curriculum import MetaCurriculum



class TrainerController(object):
+    # Type of reset_param_dict unspecified as typing library does not support heterogeneous dictionary types
    def __init__(
        self,
        model_path: str,

Review comment: You should be able to use the type defined at line 9 in c025086.
Reply: Trailing comment from old code, no longer relevant here, but type information is good to know.

@@ -37,6 +39,8 @@ def __init__(
        external_brains: Dict[str, BrainParameters],
        training_seed: int,
        fast_simulation: bool,
+        sampler_manager,
+        lesson_length: Optional[int],
    ):
        """
        :param model_path: Path to save the model.

Review comment: Add a type annotation here for sampler_manager.
Reply: Added type annotation.

@@ -50,6 +54,8 @@ def __init__(
        :param lesson: Start learning from this lesson.
        :param external_brains: dictionary of external brain names to BrainInfo objects.
        :param training_seed: Seed to use for Numpy and Tensorflow random number generation.
+        :param sampler_manager: SamplerManager object which stores information about samplers to use for the reset parameters.
+        :param lesson_length: Specifies number of steps after which reset parameters are resampled.
        """

        self.model_path = model_path
@@ -72,6 +78,8 @@ def __init__(
        self.fast_simulation = fast_simulation
        np.random.seed(self.seed)
        tf.set_random_seed(self.seed)
+        self.sampler_manager = sampler_manager
+        self.lesson_length = lesson_length

Review comment: A bit uncomfortable with multiple things called lesson_length.
Reply: This is the same as the episode-length in the yaml file; not sure if episode_length would be appropriate either. Any thoughts about lesson_duration instead?
Review comment: Would use something that doesn't involve the word lesson. Whatever you choose, it should be the same in the yaml and in the code.

    def _get_measure_vals(self):
        if self.meta_curriculum:
@@ -90,6 +98,7 @@ def _get_measure_vals(self):
                measure_val = np.mean(self.trainers[brain_name].reward_buffer)
                brain_names_to_measure_vals[brain_name] = measure_val
            return brain_names_to_measure_vals
+
        else:
            return None

@@ -167,13 +176,16 @@ def initialize_trainers(self, trainer_config: Dict[str, Dict[str, str]]):
                    self.run_id,
                )
            elif trainer_parameters_dict[brain_name]["trainer"] == "ppo":
+                # Find lesson length based on the form of learning
+                if self.meta_curriculum:
+                    lesson_length = self.meta_curriculum.brains_to_curriculums[
+                        brain_name].min_lesson_length
+                else:
+                    lesson_length = 0
+
                self.trainers[brain_name] = PPOTrainer(
                    self.external_brains[brain_name],
-                    self.meta_curriculum.brains_to_curriculums[
-                        brain_name
-                    ].min_lesson_length
-                    if self.meta_curriculum
-                    else 0,
+                    lesson_length,
                    trainer_parameters_dict[brain_name],
                    self.train_model,
                    self.load_model,
@@ -203,20 +215,30 @@ def _create_model_path(model_path):
                "permissions are set correctly.".format(model_path)
            )

+    @staticmethod
+    def _check_reset_params(reset_params, new_config):
+        for k in new_config:
+            if (k in reset_params) and (isinstance(config[k], (int, float))):
+                continue
+            elif not isinstance(new_config[k], (int, float)):
+                raise UnityEnvironmentException(
+                    "The parameter '{0}' generated by the sampler doesn't exist in this environment.".format(
+                        k
+                    )
+                )
+
Review comment: Just checking, where is this method called?
Reply: Code trailing from use of lesson_controller; not needed anymore, so removed in the next update.

    def _reset_env(self, env: BaseUnityEnvironment):
        """Resets the environment.

        Returns:
            A Data structure corresponding to the initial reset state of the
            environment.
        """
-        if self.meta_curriculum is not None:
-            return env.reset(
-                train_mode=self.fast_simulation,
-                config=self.meta_curriculum.get_config(),
-            )
-        else:
-            return env.reset(train_mode=self.fast_simulation)
+        sampled_reset_param = self.sampler_manager.sample_all()
+        new_meta_curriculum_config = (self.meta_curriculum.get_config()
+                                      if self.meta_curriculum else {})
+        sampled_reset_param.update(new_meta_curriculum_config)
+        return env.reset(train_mode=self.fast_simulation, config=sampled_reset_param)

Review comment: Does this mean meta_curriculum would override sampler_manager if provided? I'd just make sure this is clear in the documentation, or say that you have to use one or the other.
Reply: Only overrides if they conflict on a key. One of them has to take priority; not sure which one it should be, though.
Reply: meta_curriculum does take priority over generalization training, as this made the most sense.
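To make the override order concrete: dict.update gives the values passed to it the last word, so any key produced by both sources ends up with the meta-curriculum value. A tiny illustration with made-up numbers and keys:

# Illustrative only: sampled values are overwritten by meta-curriculum values on key conflicts.
sampled_reset_param = {"mass": 3.2, "gravity": 9.1}          # as returned by SamplerManager.sample_all()
new_meta_curriculum_config = {"gravity": 9.8, "lesson": 2}   # stand-in for MetaCurriculum.get_config()
sampled_reset_param.update(new_meta_curriculum_config)
print(sampled_reset_param)  # {'mass': 3.2, 'gravity': 9.8, 'lesson': 2}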

    def start_learning(self, env: BaseUnityEnvironment, trainer_config):
        # TODO: Should be able to start learning at different lesson numbers
@@ -263,6 +285,22 @@ def start_learning(self, env: BaseUnityEnvironment, trainer_config):
        self._write_training_metrics()
        self._export_graph()

+    def end_trainer_episodes(self, env, lessons_incremented):
+        curr_info = self._reset_env(env)
+        for brain_name, trainer in self.trainers.items():
+            trainer.end_episode()
+        for brain_name, changed in lessons_incremented.items():
+            if changed:
+                self.trainers[brain_name].reward_buffer.clear()
+        return curr_info
+
+    def check_empty_sampler_manager(self):
+        """
+        If self.samplers is empty, then bool of it returns false, indicating
+        there is no sampler manager.
+        """
+        return not bool(self.sampler_manager.samplers)

Review comment: Type annotations for end_trainer_episodes, please.
Review comment: If we're doing a generalization reset, do we need to clear the reward buffers? I don't think that will currently happen.
Reply: Confirming, no need to clear reward buffers in the case of a generalization reset, as generalization doesn't make use of reward buffers to decide on the reset.
Review comment: As discussed offline, please add a comment to this effect; it's not obvious from the code.
Review comment: How about making this a method or property on SamplerManager instead?
Review comment: +1 on making it part of SamplerManager.
Reply: Made the check part of the SamplerManager class.
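Following that resolution, the emptiness check presumably moved onto SamplerManager itself. A sketch of what that might look like (the method name is a guess, not taken from the PR):

# Hypothetical SamplerManager method replacing TrainerController.check_empty_sampler_manager.
class SamplerManager:
    # ... __init__ and sample_all as in sampler_class.py above ...

    def is_empty(self) -> bool:
        """Whether the manager has no samplers configured."""
        return not bool(self.samplers)

# The TrainerController call site would then read:
#     if not self.sampler_manager.is_empty() and ...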

    def take_step(self, env: BaseUnityEnvironment, curr_info: AllBrainInfo):
        if self.meta_curriculum:
            # Get the sizes of the reward buffers.
@@ -279,14 +317,12 @@ def take_step(self, env: BaseUnityEnvironment, curr_info: AllBrainInfo):

        # If any lessons were incremented or the environment is
        # ready to be reset
-        if self.meta_curriculum and any(lessons_incremented.values()):
-            curr_info = self._reset_env(env)
-            for brain_name, trainer in self.trainers.items():
-                trainer.end_episode()
-            for brain_name, changed in lessons_incremented.items():
-                if changed:
-                    self.trainers[brain_name].reward_buffer.clear()
+        if ( ((self.meta_curriculum) and any(lessons_incremented.values()))
+            or ( (not self.check_empty_sampler_manager()) and (self.global_step % self.lesson_length == 0)
+            and (self.global_step != 0)) ):
+            curr_info = self.end_trainer_episodes(env, lessons_incremented)


        # Decide and take an action
        take_action_vector = {}
        take_action_memories = {}

Review comment: This is way too complicated. If you're actually checking what the comment says, then split it up into something like named intermediate conditions (also watch the modulo by zero).
Reply: Added a check to ensure that global_step isn't 0, to ensure modulo safety.
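A hedged reconstruction of the split-up condition the reviewer is asking for (the variable names are invented, and the code actually merged may differ):

# Hypothetical refactor of the reset condition in take_step.
meta_curriculum_reset = bool(self.meta_curriculum) and any(lessons_incremented.values())
generalization_reset = (
    not self.check_empty_sampler_manager()
    and self.global_step != 0
    and self.global_step % self.lesson_length == 0
)
if meta_curriculum_reset or generalization_reset:
    curr_info = self.end_trainer_episodes(env, lessons_incremented)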

Review comment (on the SamplerException class in sampler_class.py above): BTW this should inherit from UnityException, as with all of the other exception classes.
Reply: Instance of dead code I was using for testing; removed in the next update.
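For completeness, a sketch of the exception as the reviewer suggests it should look, assuming UnityException is the shared base class in mlagents.envs.exception (as it is for the other env-side exceptions); the docstring is illustrative:

# Hypothetical final form of the sampler exception, mirroring the repo's other exception classes.
from mlagents.envs.exception import UnityException

class SamplerException(UnityException):
    """
    Any error related to the sampler configuration.
    """
    pass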