Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert to passing full path to model in training #873

Merged
merged 1 commit into from
Jun 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions donkeycar/pipeline/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def generate_model_name(self) -> Tuple[str, int]:
else:
this_num = 0
date = time.strftime('%y-%m-%d')
name = 'pilot_' + date + '_' + str(this_num)
return name, this_num
name = f'pilot_{date}_{this_num}.h5'
return os.path.join(self.cfg.MODELS_PATH, name), this_num

def to_df(self) -> pd.DataFrame:
if self.entries:
Expand Down
29 changes: 19 additions & 10 deletions donkeycar/pipeline/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,15 @@ def create_tf_data(self) -> tf.data.Dataset:
def get_model_train_details(cfg: Config, database: PilotDatabase,
model: str = None, model_type: str = None) \
-> Tuple[str, int, str, bool]:
"""
Returns automatic model name if none is given
:param cfg: donkey config
:param database: model database with existing training data
:param model: model path
:param model_type: type of model, like 'linear', 'tflite_linear', etc
:return: tuple of model path, number, training type, and if
tflite is requested
"""
if not model_type:
model_type = cfg.DEFAULT_MODEL_TYPE
train_type = model_type
Expand All @@ -90,12 +99,13 @@ def get_model_train_details(cfg: Config, database: PilotDatabase,
is_tflite = True
model_num = 0
if not model:
model_name, model_num = database.generate_model_name()
model_path, model_num = database.generate_model_name()
else:
model_name, model_ext = os.path.splitext(model)
_, model_ext = os.path.splitext(model)
model_path = model
is_tflite = model_ext == '.tflite'

return model_name, model_num, train_type, is_tflite
return model_path, model_num, train_type, is_tflite


def train(cfg: Config, tub_paths: str, model: str = None,
Expand All @@ -105,10 +115,9 @@ def train(cfg: Config, tub_paths: str, model: str = None,
Train the model
"""
database = PilotDatabase(cfg)
model_name, model_num, train_type, is_tflite = \
model_path, model_num, train_type, is_tflite = \
get_model_train_details(cfg, database, model, model_type)

output_path = os.path.join(cfg.MODELS_PATH, model_name + '.h5')
kl = get_model_by_type(train_type, cfg)
if transfer:
kl.load(transfer)
Expand All @@ -135,7 +144,7 @@ def train(cfg: Config, tub_paths: str, model: str = None,
assert val_size > 0, "Not enough validation data, decrease the batch " \
"size or add more data."

history = kl.train(model_path=output_path,
history = kl.train(model_path=model_path,
train_data=dataset_train,
train_steps=train_size,
batch_size=cfg.BATCH_SIZE,
Expand All @@ -146,14 +155,14 @@ def train(cfg: Config, tub_paths: str, model: str = None,
min_delta=cfg.MIN_DELTA,
patience=cfg.EARLY_STOP_PATIENCE,
show_plot=cfg.SHOW_PLOT)

base_path = os.path.splitext(model_path)[0]
if is_tflite:
tf_lite_model_path = f'{os.path.splitext(output_path)[0]}.tflite'
keras_model_to_tflite(output_path, tf_lite_model_path)
tf_lite_model_path = f'{base_path}.tflite'
keras_model_to_tflite(model_path, tf_lite_model_path)

database_entry = {
'Number': model_num,
'Name': model_name,
'Name': os.path.basename(base_path),
'Type': str(kl),
'Tubs': tub_paths,
'Time': time(),
Expand Down