After a save, the logged loss string now shows the loss values averaged over the iterations since the previous save

This commit is contained in:
iperov 2019-03-25 18:14:04 +04:00
parent bb02cc1d97
commit 131b2b5c79
2 changed files with 30 additions and 18 deletions

View file

@ -41,14 +41,15 @@ def trainerThread (s2c, c2s, args, device_args):
device_args=device_args)
is_reached_goal = model.is_reached_iter_goal()
is_upd_save_time_after_train = False
shared_state = { 'after_save' : False }
loss_string = ""
save_iter = model.get_iter()
def model_save():
if not debug and not is_reached_goal:
io.log_info ("Saving....", end='\r')
model.save()
io.log_info(loss_string)
is_upd_save_time_after_train = True
shared_state['after_save'] = True
def send_preview():
if not debug:
@ -75,12 +76,32 @@ def trainerThread (s2c, c2s, args, device_args):
for i in itertools.count(0,1):
if not debug:
if not is_reached_goal:
loss_string = model.train_one_iter()
if is_upd_save_time_after_train:
#save resets plaidML programs, so update last_save_time only after plaidML rebuilds them
last_save_time = time.time()
iter, iter_time = model.train_one_iter()
loss_history = model.get_loss_history()
time_str = time.strftime("[%H:%M:%S]")
if iter_time >= 10:
loss_string = "{0}[#{1:06d}][{2:.5s}s]".format ( time_str, iter, '{:0.4f}'.format(iter_time) )
else:
loss_string = "{0}[#{1:06d}][{2:04d}ms]".format ( time_str, iter, int(iter_time*1000) )
if shared_state['after_save']:
shared_state['after_save'] = False
last_save_time = time.time() #upd last_save_time only after save+one_iter, because plaidML rebuilds programs after save https://github.com/plaidml/plaidml/issues/274
mean_loss = np.mean ( [ np.array(loss_history[i]) for i in range(save_iter, iter) ], axis=0)
for loss_value in mean_loss:
loss_string += "[%.4f]" % (loss_value)
io.log_info (loss_string)
save_iter = iter
else:
for loss_value in loss_history[-1]:
loss_string += "[%.4f]" % (loss_value)
io.log_info (loss_string, end='\r')
if model.get_target_iter() != 0 and model.is_reached_iter_goal():
io.log_info ('Reached target iteration.')
model_save()
@ -88,7 +109,6 @@ def trainerThread (s2c, c2s, args, device_args):
io.log_info ('You can use preview now.')
if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min*60:
last_save_time = time.time()
model_save()
send_preview()

View file

@ -374,15 +374,7 @@ class ModelBase(object):
self.iter += 1
time_str = time.strftime("[%H:%M:%S]")
if iter_time >= 10:
loss_string = "{0}[#{1:06d}][{2:.5s}s]".format ( time_str, self.iter, '{:0.4f}'.format(iter_time) )
else:
loss_string = "{0}[#{1:06d}][{2:04d}ms]".format ( time_str, self.iter, int(iter_time*1000) )
for (loss_name, loss_value) in losses:
loss_string += " %s:%.3f" % (loss_name, loss_value)
return loss_string
return self.iter, iter_time
def pass_one_iter(self):
self.last_sample = self.generate_next_sample()