mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 04:52:13 -07:00
Pretrain the model with a large number of varied faces. This technique may help train the fake when the src/dst data have very different face shapes and lighting conditions. The face will look more morphed. To reduce the morph effect, some model files are initialized but not updated after pretraining: LIAE: inter_AB.h5; DF: both decoders.h5. The longer you pretrain the model, the more morphed the face will look. After pretraining, save and run the training again.
312 lines
12 KiB
Python
312 lines
12 KiB
Python
import sys
|
|
import traceback
|
|
import queue
|
|
import threading
|
|
import time
|
|
import numpy as np
|
|
import itertools
|
|
from pathlib import Path
|
|
from utils import Path_utils
|
|
import imagelib
|
|
import cv2
|
|
import models
|
|
from interact import interact as io
|
|
|
|
def _format_iter_header(time_str, iteration, iter_time):
    """Return the ``[HH:MM:SS][#iter][duration]`` prefix of a training log line.

    Durations of 10 s or more are printed in seconds (truncated to 5
    characters); shorter ones are printed in whole milliseconds.
    """
    if iter_time >= 10:
        return "{0}[#{1:06d}][{2:.5s}s]".format ( time_str, iteration, '{:0.4f}'.format(iter_time) )
    return "{0}[#{1:06d}][{2:04d}ms]".format ( time_str, iteration, int(iter_time*1000) )


def _format_losses(losses):
    """Return one ``[x.xxxx]`` field for every value in ``losses``."""
    return "".join("[%.4f]" % (loss_value) for loss_value in losses)


def trainerThread (s2c, c2s, args, device_args):
    """Build a model and run its training loop (intended as a thread target).

    The function talks to the UI side through two queues:

    * ``s2c`` (UI -> trainer) is polled for ``{'op': 'save' | 'preview' | 'close'}``.
    * ``c2s`` (trainer -> UI) receives ``{'op': 'show', ...}`` preview payloads
      and, exactly once when this function exits (normally, on an early
      validation failure, or after an exception), ``{'op': 'close'}``.

    Args:
        s2c: queue the UI thread writes commands into.
        c2s: queue this thread writes previews / the final close message into.
        args: dict of options. Keys read here: 'training_data_src_dir',
            'training_data_dst_dir', 'pretraining_data_dir', 'model_path',
            'model_name', 'debug', 'execute_programs'.
        device_args: forwarded unchanged to the model constructor.

    Exceptions raised while building or training the model are printed with
    a traceback rather than propagated.
    """
    try:
        start_time = time.time()

        training_data_src_path = Path( args.get('training_data_src_dir', '') )
        training_data_dst_path = Path( args.get('training_data_dst_dir', '') )
        pretraining_data_path = Path( args.get('pretraining_data_dir', '') )
        model_path = Path( args.get('model_path', '') )
        model_name = args.get('model_name', '')
        save_interval_min = 15
        debug = args.get('debug', '')
        # Each entry is a mutable [delay_seconds, python_source] pair; see the
        # scheduling loop below.
        execute_programs = args.get('execute_programs', [])

        if not training_data_src_path.exists():
            io.log_err('Training data src directory does not exist.')
            return

        if not training_data_dst_path.exists():
            io.log_err('Training data dst directory does not exist.')
            return

        if not model_path.exists():
            model_path.mkdir(parents=True, exist_ok=True)

        model = models.import_model(model_name)(
                    model_path,
                    training_data_src_path=training_data_src_path,
                    training_data_dst_path=training_data_dst_path,
                    pretraining_data_path=pretraining_data_path,
                    debug=debug,
                    device_args=device_args)

        is_reached_goal = model.is_reached_iter_goal()

        # Set by model_save(); consumed after the next trained iteration.
        after_save = False
        # Iteration at the last save; start of the window averaged in the log.
        save_iter = model.get_iter()

        def model_save():
            nonlocal after_save
            # Nothing is written in debug mode or once the target iteration
            # has been reached (the model is then only used for previews).
            if not debug and not is_reached_goal:
                io.log_info ("Saving....", end='\r')
                model.save()
                after_save = True

        def send_preview():
            if not debug:
                previews = model.get_previews()
                c2s.put ( {'op':'show', 'previews': previews, 'iter':model.get_iter(), 'loss_history': model.get_loss_history().copy() } )
            else:
                previews = [( 'debug, press update for new', model.debug_one_iter())]
                c2s.put ( {'op':'show', 'previews': previews} )

        if model.is_first_run():
            model_save()

        if model.get_target_iter() != 0:
            if is_reached_goal:
                io.log_info('Model already trained to target iteration. You can use preview.')
            else:
                io.log_info('Starting. Target iteration: %d. Press "Enter" to stop training and save model.' % ( model.get_target_iter() ) )
        else:
            io.log_info('Starting. Press "Enter" to stop training and save model.')

        last_save_time = time.time()
        stop_requested = False

        for i in itertools.count():
            if not debug:
                cur_time = time.time()

                for x in execute_programs:
                    prog_time, prog = x
                    if prog_time != 0 and (cur_time - start_time) >= prog_time:
                        x[0] = 0  # mark as done so each program runs only once
                        try:
                            # SECURITY: exec() runs arbitrary Python taken from
                            # args['execute_programs'] with this function's
                            # globals and (a snapshot of) its locals. Only pass
                            # trusted code. Kept inline (not in a helper) so the
                            # program still sees locals such as `model`.
                            exec(prog)
                        except Exception as e:
                            print("Unable to execute program: %s (%s)" % (prog, e) )

                if not is_reached_goal:
                    iteration, iter_time = model.train_one_iter()
                    loss_history = model.get_loss_history()
                    loss_string = _format_iter_header(time.strftime("[%H:%M:%S]"), iteration, iter_time)

                    if after_save:
                        after_save = False
                        last_save_time = time.time() #upd last_save_time only after save+one_iter, because plaidML rebuilds programs after save https://github.com/plaidml/plaidml/issues/274

                        # Mean of every loss since the previous save, printed on
                        # its own (non-overwritten) line.
                        window = [ np.array(loss_history[k]) for k in range(save_iter, iteration) ]
                        if window:
                            # An empty window would make np.mean return a scalar
                            # NaN, which cannot be iterated.
                            loss_string += _format_losses(np.mean(window, axis=0))

                        io.log_info (loss_string)

                        save_iter = iteration
                    else:
                        loss_string += _format_losses(loss_history[-1])

                        if io.is_colab():
                            io.log_info ('\r' + loss_string, end='')
                        else:
                            io.log_info (loss_string, end='\r')

                    if model.get_target_iter() != 0 and model.is_reached_iter_goal():
                        io.log_info ('Reached target iteration.')
                        model_save()
                        is_reached_goal = True
                        io.log_info ('You can use preview now.')

            # Periodic autosave. Skipped in debug mode: model_save() is a no-op
            # there, so last_save_time would never advance and a debug preview
            # would be pushed to the UI on every loop.
            if not debug and not is_reached_goal and (time.time() - last_save_time) >= save_interval_min*60:
                model_save()
                send_preview()

            if i == 0:
                if is_reached_goal:
                    model.pass_one_iter()
                send_preview()

            if debug:
                time.sleep(0.005)

            while not s2c.empty():
                msg = s2c.get()
                op = msg['op']
                if op == 'save':
                    model_save()
                elif op == 'preview':
                    if is_reached_goal:
                        model.pass_one_iter()
                    send_preview()
                elif op == 'close':
                    model_save()
                    stop_requested = True
                    break

            if stop_requested:
                break

        model.finalize()

    except Exception as e:
        print ('Error: %s' % (str(e)))
        traceback.print_exc()
    finally:
        # The UI loop in main() waits for this message on every exit path.
        c2s.put ( {'op':'close'} )
|
|
|
|
|
|
|
|
def _next_history_range(count):
    """Cycle the loss-history window: all -> 5k -> 10k -> 50k -> 100k -> all.

    A value outside the cycle is returned unchanged.
    """
    ranges = (0, 5000, 10000, 50000, 100000)
    if count not in ranges:
        return count
    return ranges[(ranges.index(count) + 1) % len(ranges)]


def _equalize_preview_sizes(previews, max_size=800):
    """Return a new list of (name, image) previews that all share one size.

    The common size is the largest height and the largest width among the
    previews. If that height exceeds ``max_size``, the width is scaled by the
    same factor and the height is clamped to ``max_size``. The original order
    is preserved; images already at the common size are reused as-is.
    """
    max_h = max((rgb.shape[0] for _, rgb in previews), default=0)
    max_w = max((rgb.shape[1] for _, rgb in previews), default=0)

    if max_h > max_size:
        max_w = int( max_w / (max_h / max_size) )
        max_h = max_size

    # cv2.resize takes dsize as (width, height).
    return [ (name, rgb if rgb.shape[:2] == (max_h, max_w) else cv2.resize(rgb, (max_w, max_h)))
             for name, rgb in previews ]


def _compose_preview_image(previews, selected_preview, loss_history, show_last_history_iters_count, iteration):
    """Stack help text, optional loss graph and the selected preview; clip to [0, 1]."""
    selected_preview_name, selected_preview_rgb = previews[selected_preview]
    (h, w, c) = selected_preview_rgb.shape

    head_lines = [
        '[s]:save [enter]:exit',
        '[p]:update [space]:next preview [l]:change history range',
        'Preview: "%s" [%d/%d]' % (selected_preview_name,selected_preview+1, len(previews) )
        ]
    head_line_height = 15
    head = np.ones ( (len(head_lines) * head_line_height, w, c) ) * 0.1
    for line_no, text in enumerate(head_lines):
        top = line_no * head_line_height
        head[top:top + head_line_height, 0:w] += imagelib.get_text_image ( (head_line_height,w,c) , text, color=[0.8]*c )

    parts = [head]
    if loss_history is not None:
        if show_last_history_iters_count == 0:
            loss_history_to_show = loss_history
        else:
            loss_history_to_show = loss_history[-show_last_history_iters_count:]
        parts.append(models.ModelBase.get_loss_history_preview(loss_history_to_show, iteration, w, c))
    parts.append(selected_preview_rgb)

    return np.clip(np.concatenate(parts, axis=0), 0, 1)


def main(args, device_args):
    """Start trainerThread in a background thread and drive its UI.

    With ``args['no_preview']`` set, the loop only waits for the trainer's
    ``'close'`` message. Otherwise it opens a "Training preview" window,
    renders the previews the trainer sends, and maps keys to trainer commands:
    Enter -> close, s -> save, p -> request preview, l -> cycle loss-history
    range, space -> next preview. Ctrl+C while processing window messages asks
    the trainer to close. The function returns once the trainer reports
    ``'close'``.
    """
    io.log_info ("Running trainer.\r\n")

    no_preview = args.get('no_preview', False)

    s2c = queue.Queue()
    c2s = queue.Queue()

    thread = threading.Thread(target=trainerThread, args=(s2c, c2s, args, device_args) )
    thread.start()

    if no_preview:
        while True:
            if not c2s.empty():
                msg = c2s.get()
                if msg.get('op','') == 'close':
                    break
            try:
                io.process_messages(0.1)
            except KeyboardInterrupt:
                s2c.put ( {'op': 'close'} )
    else:
        wnd_name = "Training preview"
        io.named_window(wnd_name)
        io.capture_keys(wnd_name)

        previews = None
        loss_history = None
        selected_preview = 0
        update_preview = False
        is_waiting_preview = False
        show_last_history_iters_count = 0
        iteration = 0
        while True:
            if not c2s.empty():
                msg = c2s.get()
                op = msg['op']
                if op == 'show':
                    is_waiting_preview = False
                    loss_history = msg.get('loss_history', None)
                    previews = msg.get('previews', None)
                    iteration = msg.get('iter', 0)
                    if previews is not None:
                        # Rebuilt in order; resizing in place used to move
                        # resized previews to the end of the list.
                        previews = _equalize_preview_sizes(previews)
                        if previews:
                            selected_preview = selected_preview % len(previews)
                            update_preview = True
                elif op == 'close':
                    break

            if update_preview:
                update_preview = False
                # A key press may request a redraw before any preview has
                # arrived (or after a 'show' without previews).
                if previews:
                    final = _compose_preview_image(previews, selected_preview, loss_history,
                                                   show_last_history_iters_count, iteration)
                    io.show_image( wnd_name, (final*255).astype(np.uint8) )

            key_events = io.get_key_events(wnd_name)
            # Only the most recent key event is acted on; its first field is the key code.
            key = key_events[-1][0] if len(key_events) > 0 else 0

            if key in (ord('\n'), ord('\r')):
                s2c.put ( {'op': 'close'} )
            elif key == ord('s'):
                s2c.put ( {'op': 'save'} )
            elif key == ord('p'):
                if not is_waiting_preview:
                    is_waiting_preview = True
                    s2c.put ( {'op': 'preview'} )
            elif key == ord('l'):
                show_last_history_iters_count = _next_history_range(show_last_history_iters_count)
                update_preview = True
            elif key == ord(' '):
                if previews:
                    selected_preview = (selected_preview + 1) % len(previews)
                    update_preview = True

            try:
                io.process_messages(0.1)
            except KeyboardInterrupt:
                s2c.put ( {'op': 'close'} )

        io.destroy_all_windows()
|