Propagated BN implementation, no CLI support yet.

This commit is contained in:
Jose 2023-02-07 13:19:45 +01:00 committed by GitHub
parent fcd398707f
commit 7439d5003e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 1133 additions and 3 deletions

View file

@ -27,6 +27,7 @@ def trainerThread (s2c, c2s, e,
silent_start=False,
execute_programs = None,
debug=False,
use_bn=False,
**kwargs):
while True:
try:
@ -43,7 +44,22 @@ def trainerThread (s2c, c2s, e,
if not saved_models_path.exists():
saved_models_path.mkdir(exist_ok=True, parents=True)
model = models.import_model(model_class_name)(
if model_class_name != 'SAEHD':
model = models.import_model(model_class_name)(
is_training=True,
saved_models_path=saved_models_path,
training_data_src_path=training_data_src_path,
training_data_dst_path=training_data_dst_path,
pretraining_data_path=pretraining_data_path,
pretrained_model_path=pretrained_model_path,
no_preview=no_preview,
force_model_name=force_model_name,
force_gpu_idxs=force_gpu_idxs,
cpu_only=cpu_only,
silent_start=silent_start,
debug=debug)
else:
model = models.import_model(model_class_name, use_bn=use_bn)(
is_training=True,
saved_models_path=saved_models_path,
training_data_src_path=training_data_src_path,