From 131b2b5c79bdaa6ca157962219aac390a58c8d20 Mon Sep 17 00:00:00 2001 From: iperov Date: Mon, 25 Mar 2019 18:14:04 +0400 Subject: [PATCH] after save loss string now shows averaged value since last save --- mainscripts/Trainer.py | 38 +++++++++++++++++++++++++++++--------- models/ModelBase.py | 10 +--------- 2 files changed, 30 insertions(+), 18 deletions(-) diff --git a/mainscripts/Trainer.py b/mainscripts/Trainer.py index 2c2e563..ce574d4 100644 --- a/mainscripts/Trainer.py +++ b/mainscripts/Trainer.py @@ -41,14 +41,15 @@ def trainerThread (s2c, c2s, args, device_args): device_args=device_args) is_reached_goal = model.is_reached_iter_goal() - is_upd_save_time_after_train = False + + shared_state = { 'after_save' : False } loss_string = "" + save_iter = model.get_iter() def model_save(): if not debug and not is_reached_goal: io.log_info ("Saving....", end='\r') model.save() - io.log_info(loss_string) - is_upd_save_time_after_train = True + shared_state['after_save'] = True def send_preview(): if not debug: @@ -75,12 +76,32 @@ def trainerThread (s2c, c2s, args, device_args): for i in itertools.count(0,1): if not debug: if not is_reached_goal: - loss_string = model.train_one_iter() - if is_upd_save_time_after_train: - #save resets plaidML programs, so upd last_save_time only after plaidML rebuild them - last_save_time = time.time() + iter, iter_time = model.train_one_iter() + + loss_history = model.get_loss_history() + time_str = time.strftime("[%H:%M:%S]") + if iter_time >= 10: + loss_string = "{0}[#{1:06d}][{2:.5s}s]".format ( time_str, iter, '{:0.4f}'.format(iter_time) ) + else: + loss_string = "{0}[#{1:06d}][{2:04d}ms]".format ( time_str, iter, int(iter_time*1000) ) + + if shared_state['after_save']: + shared_state['after_save'] = False + last_save_time = time.time() #upd last_save_time only after save+one_iter, because plaidML rebuilds programs after save https://github.com/plaidml/plaidml/issues/274 + + mean_loss = np.mean ( [ np.array(loss_history[i]) for i in range(save_iter, iter) ], axis=0) + + for loss_value in mean_loss: + loss_string += "[%.4f]" % (loss_value) + + io.log_info (loss_string) + + save_iter = iter + else: + for loss_value in loss_history[-1]: + loss_string += "[%.4f]" % (loss_value) + io.log_info (loss_string, end='\r') - io.log_info (loss_string, end='\r') if model.get_target_iter() != 0 and model.is_reached_iter_goal(): io.log_info ('Reached target iteration.') model_save() @@ -88,7 +109,6 @@ def trainerThread (s2c, c2s, args, device_args): io.log_info ('You can use preview now.') if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min*60: - last_save_time = time.time() model_save() send_preview() diff --git a/models/ModelBase.py b/models/ModelBase.py index 6b4f5ad..b68c350 100644 --- a/models/ModelBase.py +++ b/models/ModelBase.py @@ -374,15 +374,7 @@ class ModelBase(object): self.iter += 1 - time_str = time.strftime("[%H:%M:%S]") - if iter_time >= 10: - loss_string = "{0}[#{1:06d}][{2:.5s}s]".format ( time_str, self.iter, '{:0.4f}'.format(iter_time) ) - else: - loss_string = "{0}[#{1:06d}][{2:04d}ms]".format ( time_str, self.iter, int(iter_time*1000) ) - for (loss_name, loss_value) in losses: - loss_string += " %s:%.3f" % (loss_name, loss_value) - - return loss_string + return self.iter, iter_time def pass_one_iter(self): self.last_sample = self.generate_next_sample()