dump_ckpt

This commit is contained in:
iperov 2021-03-23 15:00:24 +04:00
parent 3d0e18b0ad
commit b333fcea4b
3 changed files with 62 additions and 13 deletions

View file

@ -27,6 +27,7 @@ def trainerThread (s2c, c2s, e,
silent_start=False,
execute_programs = None,
debug=False,
dump_ckpt=False,
**kwargs):
while True:
try:
@ -44,7 +45,7 @@ def trainerThread (s2c, c2s, e,
saved_models_path.mkdir(exist_ok=True, parents=True)
model = models.import_model(model_class_name)(
is_training=True,
is_training=not dump_ckpt,
saved_models_path=saved_models_path,
training_data_src_path=training_data_src_path,
training_data_dst_path=training_data_dst_path,
@ -55,9 +56,13 @@ def trainerThread (s2c, c2s, e,
force_gpu_idxs=force_gpu_idxs,
cpu_only=cpu_only,
silent_start=silent_start,
debug=debug,
)
debug=debug)
if dump_ckpt:
e.set()
model.dump_ckpt()
break
is_reached_goal = model.is_reached_iter_goal()
shared_state = { 'after_save' : False }