DFL-2.0 initial branch commit

This commit is contained in:
Colombo 2020-01-21 18:43:39 +04:00
parent 52a67a61b3
commit 38b85108b3
154 changed files with 5251 additions and 9414 deletions

View file

@ -6,29 +6,31 @@ import time
import numpy as np
import itertools
from pathlib import Path
from utils import Path_utils
import imagelib
from core import pathex
from core import imagelib
import cv2
import models
from interact import interact as io
from core.interact import interact as io
def trainerThread (s2c, c2s, e, args, device_args):
def trainerThread (s2c, c2s, e,
model_class_name = None,
saved_models_path = None,
training_data_src_path = None,
training_data_dst_path = None,
pretraining_data_path = None,
pretrained_model_path = None,
no_preview=False,
force_model_name=None,
force_gpu_idxs=None,
cpu_only=None,
execute_programs = None,
debug=False,
**kwargs):
while True:
try:
start_time = time.time()
training_data_src_path = Path( args.get('training_data_src_dir', '') )
training_data_dst_path = Path( args.get('training_data_dst_dir', '') )
pretraining_data_path = args.get('pretraining_data_dir', '')
pretraining_data_path = Path(pretraining_data_path) if pretraining_data_path is not None else None
model_path = Path( args.get('model_path', '') )
model_name = args.get('model_name', '')
save_interval_min = 15
debug = args.get('debug', '')
execute_programs = args.get('execute_programs', [])
no_preview = args.get('no_preview', False)
if not training_data_src_path.exists():
io.log_err('Training data src directory does not exist.')
@ -38,18 +40,22 @@ def trainerThread (s2c, c2s, e, args, device_args):
io.log_err('Training data dst directory does not exist.')
break
if not model_path.exists():
model_path.mkdir(exist_ok=True)
if not saved_models_path.exists():
saved_models_path.mkdir(exist_ok=True)
model = models.import_model(model_name)(
model_path,
model = models.import_model(model_class_name)(
is_training=True,
saved_models_path=saved_models_path,
training_data_src_path=training_data_src_path,
training_data_dst_path=training_data_dst_path,
pretraining_data_path=pretraining_data_path,
is_training=True,
debug=debug,
pretrained_model_path=pretrained_model_path,
no_preview=no_preview,
device_args=device_args)
force_model_name=force_model_name,
force_gpu_idxs=force_gpu_idxs,
cpu_only=cpu_only,
debug=debug,
)
is_reached_goal = model.is_reached_iter_goal()
@ -71,10 +77,6 @@ def trainerThread (s2c, c2s, e, args, device_args):
c2s.put ( {'op':'show', 'previews': previews} )
e.set() #Set the GUI Thread as Ready
if model.is_first_run():
model_save()
if model.get_target_iter() != 0:
if is_reached_goal:
io.log_info('Model already trained to target iteration. You can use preview.')
@ -108,6 +110,12 @@ def trainerThread (s2c, c2s, e, args, device_args):
print("Unable to execute program: %s" % (prog) )
if not is_reached_goal:
if model.get_iter() == 0:
io.log_info("")
io.log_info("Trying to do the first iteration. If an error occurs, reduce the model parameters.")
io.log_info("")
iter, iter_time = model.train_one_iter()
loss_history = model.get_loss_history()
@ -119,8 +127,8 @@ def trainerThread (s2c, c2s, e, args, device_args):
if shared_state['after_save']:
shared_state['after_save'] = False
last_save_time = time.time() #upd last_save_time only after save+one_iter, because plaidML rebuilds programs after save https://github.com/plaidml/plaidml/issues/274
last_save_time = time.time()
mean_loss = np.mean ( [ np.array(loss_history[i]) for i in range(save_iter, iter) ], axis=0)
for loss_value in mean_loss:
@ -137,7 +145,10 @@ def trainerThread (s2c, c2s, e, args, device_args):
io.log_info ('\r' + loss_string, end='')
else:
io.log_info (loss_string, end='\r')
if model.get_iter() == 1:
model_save()
if model.get_target_iter() != 0 and model.is_reached_iter_goal():
io.log_info ('Reached target iteration.')
model_save()
@ -185,16 +196,16 @@ def trainerThread (s2c, c2s, e, args, device_args):
def main(args, device_args):
def main(**kwargs):
io.log_info ("Running trainer.\r\n")
no_preview = args.get('no_preview', False)
no_preview = kwargs.get('no_preview', False)
s2c = queue.Queue()
c2s = queue.Queue()
e = threading.Event()
thread = threading.Thread(target=trainerThread, args=(s2c, c2s, e, args, device_args) )
thread = threading.Thread(target=trainerThread, args=(s2c, c2s, e), kwargs=kwargs )
thread.start()
e.wait() #Wait for inital load to occur.