mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-15 01:23:44 -07:00
DFL-2.0 initial branch commit
This commit is contained in:
parent
52a67a61b3
commit
38b85108b3
154 changed files with 5251 additions and 9414 deletions
|
@ -6,29 +6,31 @@ import time
|
|||
import numpy as np
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
from utils import Path_utils
|
||||
import imagelib
|
||||
from core import pathex
|
||||
from core import imagelib
|
||||
import cv2
|
||||
import models
|
||||
from interact import interact as io
|
||||
from core.interact import interact as io
|
||||
|
||||
def trainerThread (s2c, c2s, e, args, device_args):
|
||||
def trainerThread (s2c, c2s, e,
|
||||
model_class_name = None,
|
||||
saved_models_path = None,
|
||||
training_data_src_path = None,
|
||||
training_data_dst_path = None,
|
||||
pretraining_data_path = None,
|
||||
pretrained_model_path = None,
|
||||
no_preview=False,
|
||||
force_model_name=None,
|
||||
force_gpu_idxs=None,
|
||||
cpu_only=None,
|
||||
execute_programs = None,
|
||||
debug=False,
|
||||
**kwargs):
|
||||
while True:
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
training_data_src_path = Path( args.get('training_data_src_dir', '') )
|
||||
training_data_dst_path = Path( args.get('training_data_dst_dir', '') )
|
||||
|
||||
pretraining_data_path = args.get('pretraining_data_dir', '')
|
||||
pretraining_data_path = Path(pretraining_data_path) if pretraining_data_path is not None else None
|
||||
|
||||
model_path = Path( args.get('model_path', '') )
|
||||
model_name = args.get('model_name', '')
|
||||
save_interval_min = 15
|
||||
debug = args.get('debug', '')
|
||||
execute_programs = args.get('execute_programs', [])
|
||||
no_preview = args.get('no_preview', False)
|
||||
|
||||
if not training_data_src_path.exists():
|
||||
io.log_err('Training data src directory does not exist.')
|
||||
|
@ -38,18 +40,22 @@ def trainerThread (s2c, c2s, e, args, device_args):
|
|||
io.log_err('Training data dst directory does not exist.')
|
||||
break
|
||||
|
||||
if not model_path.exists():
|
||||
model_path.mkdir(exist_ok=True)
|
||||
if not saved_models_path.exists():
|
||||
saved_models_path.mkdir(exist_ok=True)
|
||||
|
||||
model = models.import_model(model_name)(
|
||||
model_path,
|
||||
model = models.import_model(model_class_name)(
|
||||
is_training=True,
|
||||
saved_models_path=saved_models_path,
|
||||
training_data_src_path=training_data_src_path,
|
||||
training_data_dst_path=training_data_dst_path,
|
||||
pretraining_data_path=pretraining_data_path,
|
||||
is_training=True,
|
||||
debug=debug,
|
||||
pretrained_model_path=pretrained_model_path,
|
||||
no_preview=no_preview,
|
||||
device_args=device_args)
|
||||
force_model_name=force_model_name,
|
||||
force_gpu_idxs=force_gpu_idxs,
|
||||
cpu_only=cpu_only,
|
||||
debug=debug,
|
||||
)
|
||||
|
||||
is_reached_goal = model.is_reached_iter_goal()
|
||||
|
||||
|
@ -71,10 +77,6 @@ def trainerThread (s2c, c2s, e, args, device_args):
|
|||
c2s.put ( {'op':'show', 'previews': previews} )
|
||||
e.set() #Set the GUI Thread as Ready
|
||||
|
||||
|
||||
if model.is_first_run():
|
||||
model_save()
|
||||
|
||||
if model.get_target_iter() != 0:
|
||||
if is_reached_goal:
|
||||
io.log_info('Model already trained to target iteration. You can use preview.')
|
||||
|
@ -108,6 +110,12 @@ def trainerThread (s2c, c2s, e, args, device_args):
|
|||
print("Unable to execute program: %s" % (prog) )
|
||||
|
||||
if not is_reached_goal:
|
||||
|
||||
if model.get_iter() == 0:
|
||||
io.log_info("")
|
||||
io.log_info("Trying to do the first iteration. If an error occurs, reduce the model parameters.")
|
||||
io.log_info("")
|
||||
|
||||
iter, iter_time = model.train_one_iter()
|
||||
|
||||
loss_history = model.get_loss_history()
|
||||
|
@ -119,8 +127,8 @@ def trainerThread (s2c, c2s, e, args, device_args):
|
|||
|
||||
if shared_state['after_save']:
|
||||
shared_state['after_save'] = False
|
||||
last_save_time = time.time() #upd last_save_time only after save+one_iter, because plaidML rebuilds programs after save https://github.com/plaidml/plaidml/issues/274
|
||||
|
||||
last_save_time = time.time()
|
||||
|
||||
mean_loss = np.mean ( [ np.array(loss_history[i]) for i in range(save_iter, iter) ], axis=0)
|
||||
|
||||
for loss_value in mean_loss:
|
||||
|
@ -137,7 +145,10 @@ def trainerThread (s2c, c2s, e, args, device_args):
|
|||
io.log_info ('\r' + loss_string, end='')
|
||||
else:
|
||||
io.log_info (loss_string, end='\r')
|
||||
|
||||
|
||||
if model.get_iter() == 1:
|
||||
model_save()
|
||||
|
||||
if model.get_target_iter() != 0 and model.is_reached_iter_goal():
|
||||
io.log_info ('Reached target iteration.')
|
||||
model_save()
|
||||
|
@ -185,16 +196,16 @@ def trainerThread (s2c, c2s, e, args, device_args):
|
|||
|
||||
|
||||
|
||||
def main(args, device_args):
|
||||
def main(**kwargs):
|
||||
io.log_info ("Running trainer.\r\n")
|
||||
|
||||
no_preview = args.get('no_preview', False)
|
||||
no_preview = kwargs.get('no_preview', False)
|
||||
|
||||
s2c = queue.Queue()
|
||||
c2s = queue.Queue()
|
||||
|
||||
e = threading.Event()
|
||||
thread = threading.Thread(target=trainerThread, args=(s2c, c2s, e, args, device_args) )
|
||||
thread = threading.Thread(target=trainerThread, args=(s2c, c2s, e), kwargs=kwargs )
|
||||
thread.start()
|
||||
|
||||
e.wait() #Wait for inital load to occur.
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue