DeepFaceLab/mainscripts/Trainer.py
iperov 5ac7e5d7f1 changed help message for pixel loss:
Pixel loss may help to enhance fine details and stabilize face color. Use it only if quality does not improve over time.

SAE:
Previous SAE models will not work with this update.
Greatly decreased chance of model collapse.
Increased model accuracy.
Residual blocks now default and this option has been removed.
Improved 'learn mask'.
Added masked preview (switch by space key)

Converter:
fixed rct/lct in seamless mode
added mask mode (6) learned*FAN-prd*FAN-dst

added mask editor; it's created for refining the dataset for the FANSeg model and not for production, but you can spend your time and test it on regular fakes with face obstructions
2019-04-04 10:22:53 +04:00

307 lines
12 KiB
Python

import sys
import traceback
import queue
import threading
import time
import numpy as np
import itertools
from pathlib import Path
from utils import Path_utils
import imagelib
import cv2
import models
from interact import interact as io
def _loss_values_str(loss_values):
    """Format loss values as consecutive ``[x.xxxx]`` groups for the console log."""
    return "".join("[%.4f]" % (loss_value) for loss_value in loss_values)


def _mean_loss_since(loss_history, save_iter, cur_iter):
    """Return the element-wise mean of the losses trained since ``save_iter``.

    Takes the last ``cur_iter - save_iter`` entries of ``loss_history``. Like the
    per-iteration log line, this assumes ``loss_history[-1]`` holds the most recent
    iteration's losses. Slicing from the end (instead of indexing with absolute
    iteration numbers) cannot raise IndexError when the history is shorter than the
    iteration counter. If no iterations are in range, falls back to the latest entry,
    and returns an empty list when there is no history at all.
    """
    count = cur_iter - save_iter
    recent = loss_history[-count:] if count > 0 else loss_history[-1:]
    if len(recent) == 0:
        return []
    return np.mean([np.array(losses) for losses in recent], axis=0)


def trainerThread (s2c, c2s, args, device_args):
    """Build the requested model and train it; meant to run as a worker-thread target.

    s2c         -- queue of requests from the UI thread: dicts whose 'op' is
                   'save', 'preview' or 'close'.
    c2s         -- queue of messages to the UI thread: {'op': 'show', 'previews': ...}
                   dicts (plus 'iter' and 'loss_history' outside debug mode), and
                   exactly one final {'op': 'close'}.
    args        -- dict of options: 'training_data_src_dir', 'training_data_dst_dir',
                   'model_path', 'model_name', 'debug', 'execute_programs'.
    device_args -- passed unchanged to the model constructor.

    The model is saved every 15 minutes, when the target iteration is reached, and
    on 'save'/'close' requests (saving is skipped in debug mode and once the goal is
    reached). Errors are printed with a traceback. The final {'op': 'close'} is
    posted in every case so the UI loop waiting on ``c2s`` can exit.
    """
    try:
        start_time = time.time()

        training_data_src_path = Path( args.get('training_data_src_dir', '') )
        training_data_dst_path = Path( args.get('training_data_dst_dir', '') )
        model_path = Path( args.get('model_path', '') )
        model_name = args.get('model_name', '')
        save_interval_min = 15
        debug = args.get('debug', '')
        execute_programs = args.get('execute_programs', [])

        if not training_data_src_path.exists():
            io.log_err('Training data src directory does not exist.')
            return

        if not training_data_dst_path.exists():
            io.log_err('Training data dst directory does not exist.')
            return

        if not model_path.exists():
            model_path.mkdir(exist_ok=True)

        model = models.import_model(model_name)(
                    model_path,
                    training_data_src_path=training_data_src_path,
                    training_data_dst_path=training_data_dst_path,
                    debug=debug,
                    device_args=device_args)

        is_reached_goal = model.is_reached_iter_goal()
        # Mutable holder so the nested model_save() can signal the loop below.
        shared_state = { 'after_save' : False }
        loss_string = ""
        save_iter = model.get_iter()

        def model_save():
            # Saving is skipped in debug mode and after the target iteration was reached.
            if not debug and not is_reached_goal:
                io.log_info ("Saving....", end='\r')
                model.save()
                shared_state['after_save'] = True

        def send_preview():
            if not debug:
                previews = model.get_previews()
                c2s.put ( {'op':'show', 'previews': previews, 'iter':model.get_iter(), 'loss_history': model.get_loss_history().copy() } )
            else:
                previews = [( 'debug, press update for new', model.debug_one_iter())]
                c2s.put ( {'op':'show', 'previews': previews} )

        if model.is_first_run():
            model_save()

        if model.get_target_iter() != 0:
            if is_reached_goal:
                io.log_info('Model already trained to target iteration. You can use preview.')
            else:
                io.log_info('Starting. Target iteration: %d. Press "Enter" to stop training and save model.' % ( model.get_target_iter() ) )
        else:
            io.log_info('Starting. Press "Enter" to stop training and save model.')

        last_save_time = time.time()

        for i in itertools.count(0,1):
            if not debug:
                cur_time = time.time()

                for x in execute_programs:
                    prog_time, prog = x
                    if prog_time != 0 and (cur_time - start_time) >= prog_time:
                        x[0] = 0  # mark as done so each program runs only once
                        try:
                            # NOTE(review): exec() runs arbitrary code passed in args;
                            # only ever feed it trusted, user-supplied programs.
                            exec(prog)
                        except Exception as e:
                            print("Unable to execute program: %s (%s)" % (prog, e) )

                if not is_reached_goal:
                    iter, iter_time = model.train_one_iter()

                    loss_history = model.get_loss_history()
                    time_str = time.strftime("[%H:%M:%S]")
                    if iter_time >= 10:
                        loss_string = "{0}[#{1:06d}][{2:.5s}s]".format ( time_str, iter, '{:0.4f}'.format(iter_time) )
                    else:
                        loss_string = "{0}[#{1:06d}][{2:04d}ms]".format ( time_str, iter, int(iter_time*1000) )

                    if shared_state['after_save']:
                        shared_state['after_save'] = False
                        last_save_time = time.time() #upd last_save_time only after save+one_iter, because plaidML rebuilds programs after save https://github.com/plaidml/plaidml/issues/274

                        # After a save, log one persistent line with the mean losses since the previous save.
                        loss_string += _loss_values_str(_mean_loss_since(loss_history, save_iter, iter))
                        io.log_info (loss_string)

                        save_iter = iter
                    else:
                        loss_string += _loss_values_str(loss_history[-1])

                        if io.is_colab():
                            io.log_info ('\r' + loss_string, end='')
                        else:
                            io.log_info (loss_string, end='\r')

                    if model.get_target_iter() != 0 and model.is_reached_iter_goal():
                        io.log_info ('Reached target iteration.')
                        model_save()
                        is_reached_goal = True
                        io.log_info ('You can use preview now.')

            if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min*60:
                model_save()
                send_preview()

            if i==0:
                if is_reached_goal:
                    model.pass_one_iter()
                send_preview()

            if debug:
                time.sleep(0.005)

            stop_requested = False
            while not s2c.empty():
                request = s2c.get()
                op = request['op']
                if op == 'save':
                    model_save()
                elif op == 'preview':
                    if is_reached_goal:
                        model.pass_one_iter()
                    send_preview()
                elif op == 'close':
                    model_save()
                    stop_requested = True
                    break

            if stop_requested:
                break

        model.finalize()

    except Exception as e:
        print ('Error: %s' % (str(e)))
        traceback.print_exc()
    finally:
        # Always tell the UI thread we are done, otherwise its wait loop never ends.
        c2s.put ( {'op':'close'} )
# Loss-history window sizes cycled by the [l] key; 0 means "show the whole history".
_HISTORY_RANGES = (0, 5000, 10000, 50000, 100000)


def _fit_previews(previews, max_size=800):
    """Return a new list with every preview resized to one common size, order preserved.

    The common size is the largest height and width among ``previews``; if that
    height exceeds ``max_size`` it is capped and the width is scaled by the same factor.
    Images already at the common size are kept as-is.

    previews -- list of (name, rgb) pairs; rgb is indexed (h, w, ...).
    """
    if not previews:
        return []
    max_h = max(rgb.shape[0] for _, rgb in previews)
    max_w = max(rgb.shape[1] for _, rgb in previews)
    if max_h > max_size:
        max_w = int( max_w / (max_h / max_size) )
        max_h = max_size
    return [ (name, rgb if rgb.shape[:2] == (max_h, max_w) else cv2.resize(rgb, (max_w, max_h)))
             for name, rgb in previews ]


def main(args, device_args):
    """Run trainerThread in a background thread and drive its UI until it closes.

    With ``args['no_preview']`` set, this just pumps messages until the trainer posts
    {'op': 'close'}. Otherwise it opens a "Training preview" window showing a help
    header, the loss-history graph (when available) and the selected preview. Keys:
    [s] save, [enter] exit (save and stop), [p] request a fresh preview,
    [space] next preview, [l] cycle the loss-history range. Ctrl+C also requests a stop.
    """
    io.log_info ("Running trainer.\r\n")

    no_preview = args.get('no_preview', False)

    s2c = queue.Queue()
    c2s = queue.Queue()

    thread = threading.Thread(target=trainerThread, args=(s2c, c2s, args, device_args) )
    thread.start()

    if no_preview:
        while True:
            if not c2s.empty():
                message = c2s.get()
                if message.get('op','') == 'close':
                    break
            io.process_messages(0.1)
    else:
        wnd_name = "Training preview"
        io.named_window(wnd_name)
        io.capture_keys(wnd_name)

        previews = None
        loss_history = None
        selected_preview = 0
        update_preview = False
        is_waiting_preview = False
        show_last_history_iters_count = 0
        cur_iter = 0
        while True:
            if not c2s.empty():
                message = c2s.get()
                op = message['op']
                if op == 'show':
                    is_waiting_preview = False
                    loss_history = message.get('loss_history', None)
                    previews = message.get('previews', None)
                    cur_iter = message.get('iter', 0)
                    if previews:
                        # Make all previews the same size without reordering them.
                        previews = _fit_previews(previews)
                        selected_preview = selected_preview % len(previews)
                        update_preview = True
                elif op == 'close':
                    break

            # Nothing can be drawn until the first non-empty preview list has arrived.
            if update_preview and previews:
                update_preview = False

                selected_preview_name, selected_preview_rgb = previews[selected_preview]
                (h,w,c) = selected_preview_rgb.shape

                # HEAD
                head_lines = [
                    '[s]:save [enter]:exit',
                    '[p]:update [space]:next preview [l]:change history range',
                    'Preview: "%s" [%d/%d]' % (selected_preview_name,selected_preview+1, len(previews) )
                    ]
                head_line_height = 15
                head_height = len(head_lines) * head_line_height
                head = np.ones ( (head_height,w,c) ) * 0.1

                for line_idx, head_line in enumerate(head_lines):
                    t = line_idx*head_line_height
                    b = (line_idx+1)*head_line_height
                    head[t:b, 0:w] += imagelib.get_text_image ( (head_line_height,w,c) , head_line, color=[0.8]*c )

                final = head

                if loss_history is not None:
                    if show_last_history_iters_count == 0:
                        loss_history_to_show = loss_history
                    else:
                        loss_history_to_show = loss_history[-show_last_history_iters_count:]

                    lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, cur_iter, w, c)
                    final = np.concatenate ( [final, lh_img], axis=0 )

                final = np.concatenate ( [final, selected_preview_rgb], axis=0 )
                final = np.clip(final, 0, 1)

                io.show_image( wnd_name, (final*255).astype(np.uint8) )

            key_events = io.get_key_events(wnd_name)
            key, = key_events[-1] if len(key_events) > 0 else (0,)

            if key == ord('\n') or key == ord('\r'):
                s2c.put ( {'op': 'close'} )
            elif key == ord('s'):
                s2c.put ( {'op': 'save'} )
            elif key == ord('p'):
                if not is_waiting_preview:
                    is_waiting_preview = True
                    s2c.put ( {'op': 'preview'} )
            elif key == ord('l'):
                next_idx = (_HISTORY_RANGES.index(show_last_history_iters_count) + 1) % len(_HISTORY_RANGES)
                show_last_history_iters_count = _HISTORY_RANGES[next_idx]
                update_preview = True
            elif key == ord(' '):
                # Guard: before the first preview arrives there is nothing to cycle through.
                if previews:
                    selected_preview = (selected_preview + 1) % len(previews)
                    update_preview = True

            try:
                io.process_messages(0.1)
            except KeyboardInterrupt:
                s2c.put ( {'op': 'close'} )

    io.destroy_all_windows()