mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-08 05:51:40 -07:00
after save loss string now shows averaged value since last save
This commit is contained in:
parent
bb02cc1d97
commit
131b2b5c79
2 changed files with 30 additions and 18 deletions
|
@ -41,14 +41,15 @@ def trainerThread (s2c, c2s, args, device_args):
|
||||||
device_args=device_args)
|
device_args=device_args)
|
||||||
|
|
||||||
is_reached_goal = model.is_reached_iter_goal()
|
is_reached_goal = model.is_reached_iter_goal()
|
||||||
is_upd_save_time_after_train = False
|
|
||||||
|
shared_state = { 'after_save' : False }
|
||||||
loss_string = ""
|
loss_string = ""
|
||||||
|
save_iter = model.get_iter()
|
||||||
def model_save():
|
def model_save():
|
||||||
if not debug and not is_reached_goal:
|
if not debug and not is_reached_goal:
|
||||||
io.log_info ("Saving....", end='\r')
|
io.log_info ("Saving....", end='\r')
|
||||||
model.save()
|
model.save()
|
||||||
io.log_info(loss_string)
|
shared_state['after_save'] = True
|
||||||
is_upd_save_time_after_train = True
|
|
||||||
|
|
||||||
def send_preview():
|
def send_preview():
|
||||||
if not debug:
|
if not debug:
|
||||||
|
@ -75,12 +76,32 @@ def trainerThread (s2c, c2s, args, device_args):
|
||||||
for i in itertools.count(0,1):
|
for i in itertools.count(0,1):
|
||||||
if not debug:
|
if not debug:
|
||||||
if not is_reached_goal:
|
if not is_reached_goal:
|
||||||
loss_string = model.train_one_iter()
|
iter, iter_time = model.train_one_iter()
|
||||||
if is_upd_save_time_after_train:
|
|
||||||
#save resets plaidML programs, so upd last_save_time only after plaidML rebuild them
|
|
||||||
last_save_time = time.time()
|
|
||||||
|
|
||||||
|
loss_history = model.get_loss_history()
|
||||||
|
time_str = time.strftime("[%H:%M:%S]")
|
||||||
|
if iter_time >= 10:
|
||||||
|
loss_string = "{0}[#{1:06d}][{2:.5s}s]".format ( time_str, iter, '{:0.4f}'.format(iter_time) )
|
||||||
|
else:
|
||||||
|
loss_string = "{0}[#{1:06d}][{2:04d}ms]".format ( time_str, iter, int(iter_time*1000) )
|
||||||
|
|
||||||
|
if shared_state['after_save']:
|
||||||
|
shared_state['after_save'] = False
|
||||||
|
last_save_time = time.time() #upd last_save_time only after save+one_iter, because plaidML rebuilds programs after save https://github.com/plaidml/plaidml/issues/274
|
||||||
|
|
||||||
|
mean_loss = np.mean ( [ np.array(loss_history[i]) for i in range(save_iter, iter) ], axis=0)
|
||||||
|
|
||||||
|
for loss_value in mean_loss:
|
||||||
|
loss_string += "[%.4f]" % (loss_value)
|
||||||
|
|
||||||
|
io.log_info (loss_string)
|
||||||
|
|
||||||
|
save_iter = iter
|
||||||
|
else:
|
||||||
|
for loss_value in loss_history[-1]:
|
||||||
|
loss_string += "[%.4f]" % (loss_value)
|
||||||
io.log_info (loss_string, end='\r')
|
io.log_info (loss_string, end='\r')
|
||||||
|
|
||||||
if model.get_target_iter() != 0 and model.is_reached_iter_goal():
|
if model.get_target_iter() != 0 and model.is_reached_iter_goal():
|
||||||
io.log_info ('Reached target iteration.')
|
io.log_info ('Reached target iteration.')
|
||||||
model_save()
|
model_save()
|
||||||
|
@ -88,7 +109,6 @@ def trainerThread (s2c, c2s, args, device_args):
|
||||||
io.log_info ('You can use preview now.')
|
io.log_info ('You can use preview now.')
|
||||||
|
|
||||||
if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min*60:
|
if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min*60:
|
||||||
last_save_time = time.time()
|
|
||||||
model_save()
|
model_save()
|
||||||
send_preview()
|
send_preview()
|
||||||
|
|
||||||
|
|
|
@ -374,15 +374,7 @@ class ModelBase(object):
|
||||||
|
|
||||||
self.iter += 1
|
self.iter += 1
|
||||||
|
|
||||||
time_str = time.strftime("[%H:%M:%S]")
|
return self.iter, iter_time
|
||||||
if iter_time >= 10:
|
|
||||||
loss_string = "{0}[#{1:06d}][{2:.5s}s]".format ( time_str, self.iter, '{:0.4f}'.format(iter_time) )
|
|
||||||
else:
|
|
||||||
loss_string = "{0}[#{1:06d}][{2:04d}ms]".format ( time_str, self.iter, int(iter_time*1000) )
|
|
||||||
for (loss_name, loss_value) in losses:
|
|
||||||
loss_string += " %s:%.3f" % (loss_name, loss_value)
|
|
||||||
|
|
||||||
return loss_string
|
|
||||||
|
|
||||||
def pass_one_iter(self):
|
def pass_one_iter(self):
|
||||||
self.last_sample = self.generate_next_sample()
|
self.last_sample = self.generate_next_sample()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue