mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-07 21:42:08 -07:00
after save loss string now shows averaged value since last save
This commit is contained in:
parent
bb02cc1d97
commit
131b2b5c79
2 changed files with 30 additions and 18 deletions
|
@ -41,14 +41,15 @@ def trainerThread (s2c, c2s, args, device_args):
|
|||
device_args=device_args)
|
||||
|
||||
is_reached_goal = model.is_reached_iter_goal()
|
||||
is_upd_save_time_after_train = False
|
||||
|
||||
shared_state = { 'after_save' : False }
|
||||
loss_string = ""
|
||||
save_iter = model.get_iter()
|
||||
def model_save():
|
||||
if not debug and not is_reached_goal:
|
||||
io.log_info ("Saving....", end='\r')
|
||||
model.save()
|
||||
io.log_info(loss_string)
|
||||
is_upd_save_time_after_train = True
|
||||
shared_state['after_save'] = True
|
||||
|
||||
def send_preview():
|
||||
if not debug:
|
||||
|
@ -75,12 +76,32 @@ def trainerThread (s2c, c2s, args, device_args):
|
|||
for i in itertools.count(0,1):
|
||||
if not debug:
|
||||
if not is_reached_goal:
|
||||
loss_string = model.train_one_iter()
|
||||
if is_upd_save_time_after_train:
|
||||
#save resets plaidML programs, so upd last_save_time only after plaidML rebuild them
|
||||
last_save_time = time.time()
|
||||
iter, iter_time = model.train_one_iter()
|
||||
|
||||
loss_history = model.get_loss_history()
|
||||
time_str = time.strftime("[%H:%M:%S]")
|
||||
if iter_time >= 10:
|
||||
loss_string = "{0}[#{1:06d}][{2:.5s}s]".format ( time_str, iter, '{:0.4f}'.format(iter_time) )
|
||||
else:
|
||||
loss_string = "{0}[#{1:06d}][{2:04d}ms]".format ( time_str, iter, int(iter_time*1000) )
|
||||
|
||||
if shared_state['after_save']:
|
||||
shared_state['after_save'] = False
|
||||
last_save_time = time.time() #upd last_save_time only after save+one_iter, because plaidML rebuilds programs after save https://github.com/plaidml/plaidml/issues/274
|
||||
|
||||
mean_loss = np.mean ( [ np.array(loss_history[i]) for i in range(save_iter, iter) ], axis=0)
|
||||
|
||||
for loss_value in mean_loss:
|
||||
loss_string += "[%.4f]" % (loss_value)
|
||||
|
||||
io.log_info (loss_string)
|
||||
|
||||
save_iter = iter
|
||||
else:
|
||||
for loss_value in loss_history[-1]:
|
||||
loss_string += "[%.4f]" % (loss_value)
|
||||
io.log_info (loss_string, end='\r')
|
||||
|
||||
if model.get_target_iter() != 0 and model.is_reached_iter_goal():
|
||||
io.log_info ('Reached target iteration.')
|
||||
model_save()
|
||||
|
@ -88,7 +109,6 @@ def trainerThread (s2c, c2s, args, device_args):
|
|||
io.log_info ('You can use preview now.')
|
||||
|
||||
if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min*60:
|
||||
last_save_time = time.time()
|
||||
model_save()
|
||||
send_preview()
|
||||
|
||||
|
|
|
@ -374,15 +374,7 @@ class ModelBase(object):
|
|||
|
||||
self.iter += 1
|
||||
|
||||
time_str = time.strftime("[%H:%M:%S]")
|
||||
if iter_time >= 10:
|
||||
loss_string = "{0}[#{1:06d}][{2:.5s}s]".format ( time_str, self.iter, '{:0.4f}'.format(iter_time) )
|
||||
else:
|
||||
loss_string = "{0}[#{1:06d}][{2:04d}ms]".format ( time_str, self.iter, int(iter_time*1000) )
|
||||
for (loss_name, loss_value) in losses:
|
||||
loss_string += " %s:%.3f" % (loss_name, loss_value)
|
||||
|
||||
return loss_string
|
||||
return self.iter, iter_time
|
||||
|
||||
def pass_one_iter(self):
|
||||
self.last_sample = self.generate_next_sample()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue