mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 04:52:13 -07:00
manual extractor: increased FPS,
sort by final : now you can specify target number of images, converter: fix seamless mask and exception, huge refactoring
This commit is contained in:
parent
7db469a1da
commit
438213e97c
30 changed files with 1834 additions and 1718 deletions
|
@ -7,31 +7,34 @@ from pathlib import Path
|
|||
from utils import Path_utils
|
||||
from utils import std_utils
|
||||
from utils import image_utils
|
||||
from utils.console_utils import *
|
||||
from utils.cv2_utils import *
|
||||
import numpy as np
|
||||
import cv2
|
||||
from samples import SampleGeneratorBase
|
||||
from nnlib import nnlib
|
||||
from interact import interact as io
|
||||
'''
|
||||
You can implement your own model. Check examples.
|
||||
'''
|
||||
class ModelBase(object):
|
||||
|
||||
#DONT OVERRIDE
|
||||
def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None, debug = False, force_gpu_idx=-1, **in_options):
|
||||
|
||||
if force_gpu_idx == -1:
|
||||
def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None, debug = False, device_args = None):
|
||||
|
||||
device_args['force_gpu_idx'] = device_args.get('force_gpu_idx',-1)
|
||||
|
||||
if device_args['force_gpu_idx'] == -1:
|
||||
idxs_names_list = nnlib.device.getValidDevicesIdxsWithNamesList()
|
||||
if len(idxs_names_list) > 1:
|
||||
print ("You have multi GPUs in a system: ")
|
||||
io.log_info ("You have multi GPUs in a system: ")
|
||||
for idx, name in idxs_names_list:
|
||||
print ("[%d] : %s" % (idx, name) )
|
||||
io.log_info ("[%d] : %s" % (idx, name) )
|
||||
|
||||
force_gpu_idx = input_int("Which GPU idx to choose? ( skip: best GPU ) : ", -1, [ x[0] for x in idxs_names_list] )
|
||||
self.force_gpu_idx = force_gpu_idx
|
||||
device_args['force_gpu_idx'] = io.input_int("Which GPU idx to choose? ( skip: best GPU ) : ", -1, [ x[0] for x in idxs_names_list] )
|
||||
self.device_args = device_args
|
||||
|
||||
print ("Loading model...")
|
||||
io.log_info ("Loading model...")
|
||||
|
||||
self.model_path = model_path
|
||||
self.model_data_path = Path( self.get_strpath_storage_for_file('data.dat') )
|
||||
|
||||
|
@ -46,9 +49,7 @@ class ModelBase(object):
|
|||
self.dst_data_generator = None
|
||||
self.debug = debug
|
||||
self.is_training_mode = (training_data_src_path is not None and training_data_dst_path is not None)
|
||||
|
||||
self.supress_std_once = os.environ.get('TF_SUPPRESS_STD', '0') == '1'
|
||||
|
||||
|
||||
self.epoch = 0
|
||||
self.options = {}
|
||||
self.loss_history = []
|
||||
|
@ -61,40 +62,40 @@ class ModelBase(object):
|
|||
self.loss_history = model_data['loss_history'] if 'loss_history' in model_data.keys() else []
|
||||
self.sample_for_preview = model_data['sample_for_preview'] if 'sample_for_preview' in model_data.keys() else None
|
||||
|
||||
ask_override = self.is_training_mode and self.epoch != 0 and input_in_time ("Press enter in 2 seconds to override model settings.", 2)
|
||||
ask_override = self.is_training_mode and self.epoch != 0 and io.input_in_time ("Press enter in 2 seconds to override model settings.", 2)
|
||||
|
||||
if self.epoch == 0:
|
||||
print ("\nModel first run. Enter model options as default for each run.")
|
||||
io.log_info ("\nModel first run. Enter model options as default for each run.")
|
||||
|
||||
if self.epoch == 0 or ask_override:
|
||||
default_write_preview_history = False if self.epoch == 0 else self.options.get('write_preview_history',False)
|
||||
self.options['write_preview_history'] = input_bool("Write preview history? (y/n ?:help skip:n/default) : ", default_write_preview_history, help_message="Preview history will be writed to <ModelName>_history folder.")
|
||||
self.options['write_preview_history'] = io.input_bool("Write preview history? (y/n ?:help skip:n/default) : ", default_write_preview_history, help_message="Preview history will be writed to <ModelName>_history folder.")
|
||||
else:
|
||||
self.options['write_preview_history'] = self.options.get('write_preview_history', False)
|
||||
|
||||
if self.epoch == 0 or ask_override:
|
||||
self.options['target_epoch'] = max(0, input_int("Target epoch (skip:unlimited/default) : ", 0))
|
||||
self.options['target_epoch'] = max(0, io.input_int("Target epoch (skip:unlimited/default) : ", 0))
|
||||
else:
|
||||
self.options['target_epoch'] = self.options.get('target_epoch', 0)
|
||||
|
||||
if self.epoch == 0 or ask_override:
|
||||
default_batch_size = 0 if self.epoch == 0 else self.options.get('batch_size',0)
|
||||
self.options['batch_size'] = max(0, input_int("Batch_size (?:help skip:0/default) : ", default_batch_size, help_message="Larger batch size is always better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually."))
|
||||
self.options['batch_size'] = max(0, io.input_int("Batch_size (?:help skip:0/default) : ", default_batch_size, help_message="Larger batch size is always better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually."))
|
||||
else:
|
||||
self.options['batch_size'] = self.options.get('batch_size', 0)
|
||||
|
||||
if self.epoch == 0:
|
||||
self.options['sort_by_yaw'] = input_bool("Feed faces to network sorted by yaw? (y/n ?:help skip:n) : ", False, help_message="NN will not learn src face directions that don't match dst face directions." )
|
||||
self.options['sort_by_yaw'] = io.input_bool("Feed faces to network sorted by yaw? (y/n ?:help skip:n) : ", False, help_message="NN will not learn src face directions that don't match dst face directions." )
|
||||
else:
|
||||
self.options['sort_by_yaw'] = self.options.get('sort_by_yaw', False)
|
||||
|
||||
if self.epoch == 0:
|
||||
self.options['random_flip'] = input_bool("Flip faces randomly? (y/n ?:help skip:y) : ", True, help_message="Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset.")
|
||||
self.options['random_flip'] = io.input_bool("Flip faces randomly? (y/n ?:help skip:y) : ", True, help_message="Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset.")
|
||||
else:
|
||||
self.options['random_flip'] = self.options.get('random_flip', True)
|
||||
|
||||
if self.epoch == 0:
|
||||
self.options['src_scale_mod'] = np.clip( input_int("Src face scale modifier % ( -30...30, ?:help skip:0) : ", 0, help_message="If src face shape is wider than dst, try to decrease this value to get a better result."), -30, 30)
|
||||
self.options['src_scale_mod'] = np.clip( io.input_int("Src face scale modifier % ( -30...30, ?:help skip:0) : ", 0, help_message="If src face shape is wider than dst, try to decrease this value to get a better result."), -30, 30)
|
||||
else:
|
||||
self.options['src_scale_mod'] = self.options.get('src_scale_mod', 0)
|
||||
|
||||
|
@ -116,10 +117,10 @@ class ModelBase(object):
|
|||
|
||||
self.onInitializeOptions(self.epoch == 0, ask_override)
|
||||
|
||||
nnlib.import_all ( nnlib.DeviceConfig(allow_growth=False, force_gpu_idx=self.force_gpu_idx, **in_options) )
|
||||
nnlib.import_all ( nnlib.DeviceConfig(allow_growth=False, **self.device_args) )
|
||||
self.device_config = nnlib.active_DeviceConfig
|
||||
|
||||
self.onInitialize(**in_options)
|
||||
self.onInitialize()
|
||||
|
||||
self.options['batch_size'] = self.batch_size
|
||||
|
||||
|
@ -128,10 +129,10 @@ class ModelBase(object):
|
|||
|
||||
if self.is_training_mode:
|
||||
if self.write_preview_history:
|
||||
if self.force_gpu_idx == -1:
|
||||
if self.device_args['force_gpu_idx'] == -1:
|
||||
self.preview_history_path = self.model_path / ( '%s_history' % (self.get_model_name()) )
|
||||
else:
|
||||
self.preview_history_path = self.model_path / ( '%d_%s_history' % (self.force_gpu_idx, self.get_model_name()) )
|
||||
self.preview_history_path = self.model_path / ( '%d_%s_history' % (self.device_args['force_gpu_idx'], self.get_model_name()) )
|
||||
|
||||
if not self.preview_history_path.exists():
|
||||
self.preview_history_path.mkdir(exist_ok=True)
|
||||
|
@ -141,11 +142,11 @@ class ModelBase(object):
|
|||
Path(filename).unlink()
|
||||
|
||||
if self.generator_list is None:
|
||||
raise Exception( 'You didnt set_training_data_generators()')
|
||||
raise ValueError( 'You didnt set_training_data_generators()')
|
||||
else:
|
||||
for i, generator in enumerate(self.generator_list):
|
||||
if not isinstance(generator, SampleGeneratorBase):
|
||||
raise Exception('training data generator is not subclass of SampleGeneratorBase')
|
||||
raise ValueError('training data generator is not subclass of SampleGeneratorBase')
|
||||
|
||||
if (self.sample_for_preview is None) or (self.epoch == 0):
|
||||
self.sample_for_preview = self.generate_next_sample()
|
||||
|
@ -181,14 +182,14 @@ class ModelBase(object):
|
|||
model_summary_text += ["========================="]
|
||||
model_summary_text = "\r\n".join (model_summary_text)
|
||||
self.model_summary_text = model_summary_text
|
||||
print(model_summary_text)
|
||||
io.log_info(model_summary_text)
|
||||
|
||||
#overridable
|
||||
def onInitializeOptions(self, is_first_run, ask_override):
|
||||
pass
|
||||
|
||||
#overridable
|
||||
def onInitialize(self, **in_options):
|
||||
def onInitialize(self):
|
||||
'''
|
||||
initialize your keras models
|
||||
|
||||
|
@ -221,10 +222,9 @@ class ModelBase(object):
|
|||
return Path(inspect.getmodule(self).__file__).parent.name.rsplit("_", 1)[1]
|
||||
|
||||
#overridable
|
||||
def get_converter(self, **in_options):
|
||||
#return existing or your own converter which derived from base
|
||||
from .ConverterBase import ConverterBase
|
||||
return ConverterBase(self, **in_options)
|
||||
def get_converter(self):
|
||||
raise NotImplementeError
|
||||
#return existing or your own converter which derived from base
|
||||
|
||||
def get_target_epoch(self):
|
||||
return self.target_epoch
|
||||
|
@ -258,17 +258,10 @@ class ModelBase(object):
|
|||
return self.onGetPreview (self.sample_for_preview)[0][1] #first preview, and bgr
|
||||
|
||||
def save(self):
|
||||
print ("Saving...")
|
||||
io.log_info ("Saving...")
|
||||
|
||||
if self.supress_std_once:
|
||||
supressor = std_utils.suppress_stdout_stderr()
|
||||
supressor.__enter__()
|
||||
|
||||
Path( self.get_strpath_storage_for_file('summary.txt') ).write_text(self.model_summary_text)
|
||||
self.onSave()
|
||||
|
||||
if self.supress_std_once:
|
||||
supressor.__exit__()
|
||||
|
||||
model_data = {
|
||||
'epoch': self.epoch,
|
||||
|
@ -310,11 +303,7 @@ class ModelBase(object):
|
|||
def generate_next_sample(self):
|
||||
return [next(generator) for generator in self.generator_list]
|
||||
|
||||
def train_one_epoch(self):
|
||||
if self.supress_std_once:
|
||||
supressor = std_utils.suppress_stdout_stderr()
|
||||
supressor.__enter__()
|
||||
|
||||
def train_one_epoch(self):
|
||||
sample = self.generate_next_sample()
|
||||
epoch_time = time.time()
|
||||
losses = self.onTrainOneEpoch(sample, self.generator_list)
|
||||
|
@ -322,11 +311,7 @@ class ModelBase(object):
|
|||
self.last_sample = sample
|
||||
|
||||
self.loss_history.append ( [float(loss[1]) for loss in losses] )
|
||||
|
||||
if self.supress_std_once:
|
||||
supressor.__exit__()
|
||||
self.supress_std_once = False
|
||||
|
||||
|
||||
if self.write_preview_history:
|
||||
if self.epoch % 10 == 0:
|
||||
preview = self.get_static_preview()
|
||||
|
@ -377,10 +362,10 @@ class ModelBase(object):
|
|||
return self.generator_list
|
||||
|
||||
def get_strpath_storage_for_file(self, filename):
|
||||
if self.force_gpu_idx == -1:
|
||||
if self.device_args['force_gpu_idx'] == -1:
|
||||
return str( self.model_path / ( self.get_model_name() + '_' + filename) )
|
||||
else:
|
||||
return str( self.model_path / ( str(self.force_gpu_idx) + '_' + self.get_model_name() + '_' + filename) )
|
||||
return str( self.model_path / ( str(self.device_args['force_gpu_idx']) + '_' + self.get_model_name() + '_' + filename) )
|
||||
|
||||
def set_vram_batch_requirements (self, d):
|
||||
#example d = {2:2,3:4,4:8,5:16,6:32,7:32,8:32,9:48}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue