mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 21:12:07 -07:00
SAE: new options face_style_power and bg_style_power replace style_power. A value of zero means don't use style.
SAE: new option 'lighter_encoder'. Model settings can now be overridden by pressing Enter within 2 seconds while the model is loading (works on Windows). Removed all MultiGPU models, because Keras multi_gpu does not in fact work.
This commit is contained in:
parent
c3f175862a
commit
48d0123f0b
9 changed files with 158 additions and 139 deletions
12
main.py
12
main.py
|
@ -81,16 +81,10 @@ if __name__ == "__main__":
|
||||||
training_data_dst_dir=arguments.training_data_dst_dir,
|
training_data_dst_dir=arguments.training_data_dst_dir,
|
||||||
model_path=arguments.model_dir,
|
model_path=arguments.model_dir,
|
||||||
model_name=arguments.model_name,
|
model_name=arguments.model_name,
|
||||||
ask_for_session_options = arguments.ask_for_session_options,
|
|
||||||
debug = arguments.debug,
|
debug = arguments.debug,
|
||||||
#**options
|
#**options
|
||||||
session_write_preview_history = arguments.session_write_preview_history,
|
|
||||||
session_target_epoch = arguments.session_target_epoch,
|
|
||||||
session_batch_size = arguments.session_batch_size,
|
|
||||||
save_interval_min = arguments.save_interval_min,
|
|
||||||
choose_worst_gpu = arguments.choose_worst_gpu,
|
choose_worst_gpu = arguments.choose_worst_gpu,
|
||||||
force_best_gpu_idx = arguments.force_best_gpu_idx,
|
force_best_gpu_idx = arguments.force_best_gpu_idx,
|
||||||
multi_gpu = arguments.multi_gpu,
|
|
||||||
force_gpu_idxs = arguments.force_gpu_idxs,
|
force_gpu_idxs = arguments.force_gpu_idxs,
|
||||||
cpu_only = arguments.cpu_only
|
cpu_only = arguments.cpu_only
|
||||||
)
|
)
|
||||||
|
@ -101,14 +95,8 @@ if __name__ == "__main__":
|
||||||
train_parser.add_argument('--model-dir', required=True, action=fixPathAction, dest="model_dir", help="Model dir.")
|
train_parser.add_argument('--model-dir', required=True, action=fixPathAction, dest="model_dir", help="Model dir.")
|
||||||
train_parser.add_argument('--model', required=True, dest="model_name", choices=Path_utils.get_all_dir_names_startswith ( Path(__file__).parent / 'models' , 'Model_'), help="Type of model")
|
train_parser.add_argument('--model', required=True, dest="model_name", choices=Path_utils.get_all_dir_names_startswith ( Path(__file__).parent / 'models' , 'Model_'), help="Type of model")
|
||||||
train_parser.add_argument('--debug', action="store_true", dest="debug", default=False, help="Debug samples.")
|
train_parser.add_argument('--debug', action="store_true", dest="debug", default=False, help="Debug samples.")
|
||||||
train_parser.add_argument('--ask-for-session-options', action="store_true", dest="ask_for_session_options", default=False, help="Ask to override session options.")
|
|
||||||
train_parser.add_argument('--session-write-preview-history', action="store_true", dest="session_write_preview_history", default=None, help="Enable write preview history for this session.")
|
|
||||||
train_parser.add_argument('--session-target-epoch', type=int, dest="session_target_epoch", default=0, help="Train until target epoch for this session. Default - unlimited. Environment variable to override: DFL_TARGET_EPOCH.")
|
|
||||||
train_parser.add_argument('--session-batch-size', type=int, dest="session_batch_size", default=0, help="Model batch size for this session. Default - auto. Environment variable to override: DFL_BATCH_SIZE.")
|
|
||||||
train_parser.add_argument('--save-interval-min', type=int, dest="save_interval_min", default=10, help="Save interval in minutes. Default 10.")
|
|
||||||
train_parser.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Train on CPU.")
|
train_parser.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Train on CPU.")
|
||||||
train_parser.add_argument('--force-gpu-idxs', type=str, dest="force_gpu_idxs", default=None, help="Override final GPU idxs. Example: 0,1,2.")
|
train_parser.add_argument('--force-gpu-idxs', type=str, dest="force_gpu_idxs", default=None, help="Override final GPU idxs. Example: 0,1,2.")
|
||||||
train_parser.add_argument('--multi-gpu', action="store_true", dest="multi_gpu", default=False, help="MultiGPU option (if model supports it). It will select only same best(worst) GPU models.")
|
|
||||||
train_parser.add_argument('--choose-worst-gpu', action="store_true", dest="choose_worst_gpu", default=False, help="Choose worst GPU instead of best. Environment variable to force True: DFL_WORST_GPU")
|
train_parser.add_argument('--choose-worst-gpu', action="store_true", dest="choose_worst_gpu", default=False, help="Choose worst GPU instead of best. Environment variable to force True: DFL_WORST_GPU")
|
||||||
train_parser.add_argument('--force-best-gpu-idx', type=int, dest="force_best_gpu_idx", default=-1, help="Force to choose this GPU idx as best(worst).")
|
train_parser.add_argument('--force-best-gpu-idx', type=int, dest="force_best_gpu_idx", default=-1, help="Force to choose this GPU idx as best(worst).")
|
||||||
|
|
||||||
|
|
|
@ -18,14 +18,7 @@ You can implement your own model. Check examples.
|
||||||
class ModelBase(object):
|
class ModelBase(object):
|
||||||
|
|
||||||
#DONT OVERRIDE
|
#DONT OVERRIDE
|
||||||
def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None,
|
def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None, debug = False, **in_options):
|
||||||
ask_for_session_options=False,
|
|
||||||
session_write_preview_history = None,
|
|
||||||
session_target_epoch=0,
|
|
||||||
session_batch_size=0,
|
|
||||||
|
|
||||||
debug = False, **in_options
|
|
||||||
):
|
|
||||||
print ("Loading model...")
|
print ("Loading model...")
|
||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
self.model_data_path = Path( self.get_strpath_storage_for_file('data.dat') )
|
self.model_data_path = Path( self.get_strpath_storage_for_file('data.dat') )
|
||||||
|
@ -56,50 +49,52 @@ class ModelBase(object):
|
||||||
self.loss_history = model_data['loss_history'] if 'loss_history' in model_data.keys() else []
|
self.loss_history = model_data['loss_history'] if 'loss_history' in model_data.keys() else []
|
||||||
self.sample_for_preview = model_data['sample_for_preview'] if 'sample_for_preview' in model_data.keys() else None
|
self.sample_for_preview = model_data['sample_for_preview'] if 'sample_for_preview' in model_data.keys() else None
|
||||||
|
|
||||||
|
ask_override = self.epoch != 0 and input_in_time ("Press enter during 2 seconds to override some model settings.", 2)
|
||||||
|
|
||||||
if self.epoch == 0:
|
if self.epoch == 0:
|
||||||
print ("\nModel first run. Enter model options as default for each run.")
|
print ("\nModel first run. Enter model options as default for each run.")
|
||||||
|
|
||||||
|
if self.epoch == 0 or ask_override:
|
||||||
self.options['write_preview_history'] = input_bool("Write preview history? (y/n ?:help skip:n) : ", False, help_message="Preview history will be writed to <ModelName>_history folder.")
|
self.options['write_preview_history'] = input_bool("Write preview history? (y/n ?:help skip:n) : ", False, help_message="Preview history will be writed to <ModelName>_history folder.")
|
||||||
self.options['target_epoch'] = max(0, input_int("Target epoch (skip:unlimited) : ", 0))
|
|
||||||
self.options['batch_size'] = max(0, input_int("Batch_size (?:help skip:model choice) : ", 0, help_message="Larger batch size is always better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually."))
|
|
||||||
self.options['sort_by_yaw'] = input_bool("Feed faces to network sorted by yaw? (y/n ?:help skip:n) : ", False, help_message="NN will not learn src face directions that don't match dst face directions." )
|
|
||||||
self.options['random_flip'] = input_bool("Flip faces randomly? (y/n ?:help skip:y) : ", True, help_message="Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset.")
|
|
||||||
self.options['src_scale_mod'] = np.clip( input_int("Src face scale modifier % ( -30...30, ?:help skip:0) : ", 0, help_message="If src face shape is wider than dst, try to decrease this value to get a better result."), -30, 30)
|
|
||||||
#self.options['use_fp16'] = use_fp16 = input_bool("Use float16? (y/n skip:n) : ", False)
|
|
||||||
else:
|
|
||||||
self.options['write_preview_history'] = self.options.get('write_preview_history', False)
|
|
||||||
self.options['target_epoch'] = self.options.get('target_epoch', 0)
|
|
||||||
self.options['batch_size'] = self.options.get('batch_size', 0)
|
|
||||||
self.options['sort_by_yaw'] = self.options.get('sort_by_yaw', False)
|
|
||||||
self.options['random_flip'] = self.options.get('random_flip', True)
|
|
||||||
self.options['src_scale_mod'] = self.options.get('src_scale_mod', 0)
|
|
||||||
#self.options['use_fp16'] = use_fp16 = self.options['use_fp16'] if 'use_fp16' in self.options.keys() else False
|
|
||||||
|
|
||||||
use_fp16 = False #currently models fails with fp16
|
|
||||||
|
|
||||||
if ask_for_session_options:
|
|
||||||
print ("Override options for current session:")
|
|
||||||
session_write_preview_history = input_bool("Write preview history? (y/n skip:default) : ", None )
|
|
||||||
session_target_epoch = input_int("Target epoch (skip:default) : ", 0)
|
|
||||||
session_batch_size = input_int("Batch_size (skip:default) : ", 0)
|
|
||||||
|
|
||||||
if self.options['write_preview_history']:
|
|
||||||
if session_write_preview_history is None:
|
|
||||||
session_write_preview_history = self.options['write_preview_history']
|
|
||||||
else:
|
else:
|
||||||
|
self.options['write_preview_history'] = self.options.get('write_preview_history', False)
|
||||||
|
|
||||||
|
if self.epoch == 0 or ask_override:
|
||||||
|
self.options['target_epoch'] = max(0, input_int("Target epoch (skip:unlimited) : ", 0))
|
||||||
|
else:
|
||||||
|
self.options['target_epoch'] = self.options.get('target_epoch', 0)
|
||||||
|
|
||||||
|
if self.epoch == 0 or ask_override:
|
||||||
|
default_batch_size = 0 if self.epoch == 0 else self.options['batch_size']
|
||||||
|
self.options['batch_size'] = max(0, input_int("Batch_size (?:help skip:default) : ", default_batch_size, help_message="Larger batch size is always better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually."))
|
||||||
|
else:
|
||||||
|
self.options['batch_size'] = self.options.get('batch_size', 0)
|
||||||
|
|
||||||
|
if self.epoch == 0:
|
||||||
|
self.options['sort_by_yaw'] = input_bool("Feed faces to network sorted by yaw? (y/n ?:help skip:n) : ", False, help_message="NN will not learn src face directions that don't match dst face directions." )
|
||||||
|
else:
|
||||||
|
self.options['sort_by_yaw'] = self.options.get('sort_by_yaw', False)
|
||||||
|
|
||||||
|
if self.epoch == 0:
|
||||||
|
self.options['random_flip'] = input_bool("Flip faces randomly? (y/n ?:help skip:y) : ", True, help_message="Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset.")
|
||||||
|
else:
|
||||||
|
self.options['random_flip'] = self.options.get('random_flip', True)
|
||||||
|
|
||||||
|
if self.epoch == 0:
|
||||||
|
self.options['src_scale_mod'] = np.clip( input_int("Src face scale modifier % ( -30...30, ?:help skip:0) : ", 0, help_message="If src face shape is wider than dst, try to decrease this value to get a better result."), -30, 30)
|
||||||
|
else:
|
||||||
|
self.options['src_scale_mod'] = self.options.get('src_scale_mod', 0)
|
||||||
|
|
||||||
|
self.write_preview_history = self.options['write_preview_history']
|
||||||
|
if not self.options['write_preview_history']:
|
||||||
self.options.pop('write_preview_history')
|
self.options.pop('write_preview_history')
|
||||||
|
|
||||||
if self.options['target_epoch'] != 0:
|
self.target_epoch = self.options['target_epoch']
|
||||||
if session_target_epoch == 0:
|
if self.options['target_epoch'] == 0:
|
||||||
session_target_epoch = self.options['target_epoch']
|
|
||||||
else:
|
|
||||||
self.options.pop('target_epoch')
|
self.options.pop('target_epoch')
|
||||||
|
|
||||||
if self.options['batch_size'] != 0:
|
self.batch_size = self.options['batch_size']
|
||||||
if session_batch_size == 0:
|
|
||||||
session_batch_size = self.options['batch_size']
|
|
||||||
else:
|
|
||||||
self.options.pop('batch_size')
|
|
||||||
|
|
||||||
self.sort_by_yaw = self.options['sort_by_yaw']
|
self.sort_by_yaw = self.options['sort_by_yaw']
|
||||||
if not self.sort_by_yaw:
|
if not self.sort_by_yaw:
|
||||||
self.options.pop('sort_by_yaw')
|
self.options.pop('sort_by_yaw')
|
||||||
|
@ -112,18 +107,17 @@ class ModelBase(object):
|
||||||
if self.src_scale_mod == 0:
|
if self.src_scale_mod == 0:
|
||||||
self.options.pop('src_scale_mod')
|
self.options.pop('src_scale_mod')
|
||||||
|
|
||||||
self.write_preview_history = session_write_preview_history
|
self.onInitializeOptions(self.epoch == 0, ask_override)
|
||||||
self.target_epoch = session_target_epoch
|
|
||||||
self.batch_size = session_batch_size
|
|
||||||
self.onInitializeOptions(self.epoch == 0, ask_for_session_options)
|
|
||||||
|
|
||||||
nnlib.import_all ( nnlib.DeviceConfig(allow_growth=False, use_fp16=use_fp16, **in_options) )
|
nnlib.import_all ( nnlib.DeviceConfig(allow_growth=False, **in_options) )
|
||||||
self.device_config = nnlib.active_DeviceConfig
|
self.device_config = nnlib.active_DeviceConfig
|
||||||
|
|
||||||
self.created_vram_gb = self.options['created_vram_gb'] if 'created_vram_gb' in self.options.keys() else self.device_config.gpu_total_vram_gb
|
self.created_vram_gb = self.options['created_vram_gb'] if 'created_vram_gb' in self.options.keys() else self.device_config.gpu_total_vram_gb
|
||||||
|
|
||||||
self.onInitialize(**in_options)
|
self.onInitialize(**in_options)
|
||||||
|
|
||||||
|
self.options['batch_size'] = self.batch_size
|
||||||
|
|
||||||
if self.debug or self.batch_size == 0:
|
if self.debug or self.batch_size == 0:
|
||||||
self.batch_size = 1
|
self.batch_size = 1
|
||||||
|
|
||||||
|
@ -155,16 +149,10 @@ class ModelBase(object):
|
||||||
print ("==")
|
print ("==")
|
||||||
print ("== Model options:")
|
print ("== Model options:")
|
||||||
for key in self.options.keys():
|
for key in self.options.keys():
|
||||||
print ("== |== %s : %s" % (key, self.options[key]) )
|
print ("== |== %s : %s" % (key, self.options[key]) )
|
||||||
print ("== Session options:")
|
|
||||||
if self.write_preview_history:
|
|
||||||
print ("== |== write_preview_history : True ")
|
|
||||||
if self.target_epoch != 0:
|
|
||||||
print ("== |== target_epoch : %s " % (self.target_epoch) )
|
|
||||||
print ("== |== batch_size : %s " % (self.batch_size) )
|
|
||||||
if self.device_config.multi_gpu:
|
if self.device_config.multi_gpu:
|
||||||
print ("== |== multi_gpu : True ")
|
print ("== |== multi_gpu : True ")
|
||||||
|
|
||||||
|
|
||||||
print ("== Running on:")
|
print ("== Running on:")
|
||||||
if self.device_config.cpu_only:
|
if self.device_config.cpu_only:
|
||||||
|
@ -183,7 +171,7 @@ class ModelBase(object):
|
||||||
print ("=========================")
|
print ("=========================")
|
||||||
|
|
||||||
#overridable
|
#overridable
|
||||||
def onInitializeOptions(self, is_first_run, ask_for_session_options):
|
def onInitializeOptions(self, is_first_run, ask_override):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
#overridable
|
#overridable
|
||||||
|
@ -231,23 +219,24 @@ class ModelBase(object):
|
||||||
def is_reached_epoch_goal(self):
|
def is_reached_epoch_goal(self):
|
||||||
return self.target_epoch != 0 and self.epoch >= self.target_epoch
|
return self.target_epoch != 0 and self.epoch >= self.target_epoch
|
||||||
|
|
||||||
def to_multi_gpu_model_if_possible (self, models_list):
|
#multi gpu in keras actually is fake and doesn't work for training https://github.com/keras-team/keras/issues/11976
|
||||||
if len(self.device_config.gpu_idxs) > 1:
|
#def to_multi_gpu_model_if_possible (self, models_list):
|
||||||
#make batch_size to divide on GPU count without remainder
|
# if len(self.device_config.gpu_idxs) > 1:
|
||||||
self.batch_size = int( self.batch_size / len(self.device_config.gpu_idxs) )
|
# #make batch_size to divide on GPU count without remainder
|
||||||
if self.batch_size == 0:
|
# self.batch_size = int( self.batch_size / len(self.device_config.gpu_idxs) )
|
||||||
self.batch_size = 1
|
# if self.batch_size == 0:
|
||||||
self.batch_size *= len(self.device_config.gpu_idxs)
|
# self.batch_size = 1
|
||||||
|
# self.batch_size *= len(self.device_config.gpu_idxs)
|
||||||
result = []
|
#
|
||||||
for model in models_list:
|
# result = []
|
||||||
for i in range( len(model.output_names) ):
|
# for model in models_list:
|
||||||
model.output_names = 'output_%d' % (i)
|
# for i in range( len(model.output_names) ):
|
||||||
result += [ nnlib.keras.utils.multi_gpu_model( model, self.device_config.gpu_idxs ) ]
|
# model.output_names = 'output_%d' % (i)
|
||||||
|
# result += [ nnlib.keras.utils.multi_gpu_model( model, self.device_config.gpu_idxs ) ]
|
||||||
return result
|
#
|
||||||
else:
|
# return result
|
||||||
return models_list
|
# else:
|
||||||
|
# return models_list
|
||||||
|
|
||||||
def get_previews(self):
|
def get_previews(self):
|
||||||
return self.onGetPreview ( self.last_sample )
|
return self.onGetPreview ( self.last_sample )
|
||||||
|
|
|
@ -29,9 +29,6 @@ class Model(ModelBase):
|
||||||
self.autoencoder_src = Model([ae_input_layer,mask_layer], self.decoder_src(self.encoder(ae_input_layer)))
|
self.autoencoder_src = Model([ae_input_layer,mask_layer], self.decoder_src(self.encoder(ae_input_layer)))
|
||||||
self.autoencoder_dst = Model([ae_input_layer,mask_layer], self.decoder_dst(self.encoder(ae_input_layer)))
|
self.autoencoder_dst = Model([ae_input_layer,mask_layer], self.decoder_dst(self.encoder(ae_input_layer)))
|
||||||
|
|
||||||
if self.is_training_mode:
|
|
||||||
self.autoencoder_src, self.autoencoder_dst = self.to_multi_gpu_model_if_possible ( [self.autoencoder_src, self.autoencoder_dst] )
|
|
||||||
|
|
||||||
self.autoencoder_src.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), loss=[DSSIMMaskLoss([mask_layer]), 'mse'] )
|
self.autoencoder_src.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), loss=[DSSIMMaskLoss([mask_layer]), 'mse'] )
|
||||||
self.autoencoder_dst.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), loss=[DSSIMMaskLoss([mask_layer]), 'mse'] )
|
self.autoencoder_dst.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), loss=[DSSIMMaskLoss([mask_layer]), 'mse'] )
|
||||||
|
|
||||||
|
|
|
@ -32,9 +32,6 @@ class Model(ModelBase):
|
||||||
|
|
||||||
self.ae = Model([input_src_bgr,input_src_mask,input_dst_bgr,input_dst_mask], [rec_src_bgr, rec_src_mask, rec_dst_bgr, rec_dst_mask] )
|
self.ae = Model([input_src_bgr,input_src_mask,input_dst_bgr,input_dst_mask], [rec_src_bgr, rec_src_mask, rec_dst_bgr, rec_dst_mask] )
|
||||||
|
|
||||||
if self.is_training_mode:
|
|
||||||
self.ae, = self.to_multi_gpu_model_if_possible ( [self.ae,] )
|
|
||||||
|
|
||||||
self.ae.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999),
|
self.ae.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999),
|
||||||
loss=[ DSSIMMaskLoss([input_src_mask]), 'mae', DSSIMMaskLoss([input_dst_mask]), 'mae' ] )
|
loss=[ DSSIMMaskLoss([input_src_mask]), 'mae', DSSIMMaskLoss([input_dst_mask]), 'mae' ] )
|
||||||
|
|
||||||
|
|
|
@ -33,9 +33,6 @@ class Model(ModelBase):
|
||||||
|
|
||||||
self.ae = Model([input_src_bgr,input_src_mask,input_dst_bgr,input_dst_mask], [rec_src_bgr, rec_src_mask, rec_dst_bgr, rec_dst_mask] )
|
self.ae = Model([input_src_bgr,input_src_mask,input_dst_bgr,input_dst_mask], [rec_src_bgr, rec_src_mask, rec_dst_bgr, rec_dst_mask] )
|
||||||
|
|
||||||
if self.is_training_mode:
|
|
||||||
self.ae, = self.to_multi_gpu_model_if_possible ( [self.ae,] )
|
|
||||||
|
|
||||||
self.ae.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999),
|
self.ae.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999),
|
||||||
loss=[ DSSIMMaskLoss([input_src_mask]), 'mae', DSSIMMaskLoss([input_dst_mask]), 'mae' ] )
|
loss=[ DSSIMMaskLoss([input_src_mask]), 'mae', DSSIMMaskLoss([input_dst_mask]), 'mae' ] )
|
||||||
|
|
||||||
|
|
|
@ -34,9 +34,6 @@ class Model(ModelBase):
|
||||||
self.autoencoder_src = Model([ae_input_layer,mask_layer], self.decoder(Concatenate()([AB, AB])) )
|
self.autoencoder_src = Model([ae_input_layer,mask_layer], self.decoder(Concatenate()([AB, AB])) )
|
||||||
self.autoencoder_dst = Model([ae_input_layer,mask_layer], self.decoder(Concatenate()([B, AB])) )
|
self.autoencoder_dst = Model([ae_input_layer,mask_layer], self.decoder(Concatenate()([B, AB])) )
|
||||||
|
|
||||||
if self.is_training_mode:
|
|
||||||
self.autoencoder_src, self.autoencoder_dst = self.to_multi_gpu_model_if_possible ( [self.autoencoder_src, self.autoencoder_dst] )
|
|
||||||
|
|
||||||
self.autoencoder_src.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), loss=[DSSIMMaskLoss([mask_layer]), 'mse'] )
|
self.autoencoder_src.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), loss=[DSSIMMaskLoss([mask_layer]), 'mse'] )
|
||||||
self.autoencoder_dst.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), loss=[DSSIMMaskLoss([mask_layer]), 'mse'] )
|
self.autoencoder_dst.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), loss=[DSSIMMaskLoss([mask_layer]), 'mse'] )
|
||||||
|
|
||||||
|
|
|
@ -21,32 +21,40 @@ class SAEModel(ModelBase):
|
||||||
decoder_dstmH5 = 'decoder_dstm.h5'
|
decoder_dstmH5 = 'decoder_dstm.h5'
|
||||||
|
|
||||||
#override
|
#override
|
||||||
def onInitializeOptions(self, is_first_run, ask_for_session_options):
|
def onInitializeOptions(self, is_first_run, ask_override):
|
||||||
default_resolution = 128
|
default_resolution = 128
|
||||||
default_archi = 'liae'
|
default_archi = 'liae'
|
||||||
default_style_power = 100
|
default_style_power = 100
|
||||||
default_face_type = 'f'
|
default_face_type = 'f'
|
||||||
|
|
||||||
if is_first_run:
|
if is_first_run:
|
||||||
#first run
|
|
||||||
self.options['resolution'] = input_int("Resolution (64,128, ?:help skip:128) : ", default_resolution, [64,128], help_message="More resolution requires more VRAM.")
|
self.options['resolution'] = input_int("Resolution (64,128, ?:help skip:128) : ", default_resolution, [64,128], help_message="More resolution requires more VRAM.")
|
||||||
self.options['archi'] = input_str ("AE architecture (df, liae, ?:help skip:liae) : ", default_archi, ['df','liae'], help_message="DF keeps faces more natural, while LIAE can fix overly different face shapes.").lower()
|
self.options['archi'] = input_str ("AE architecture (df, liae, ?:help skip:liae) : ", default_archi, ['df','liae'], help_message="DF keeps faces more natural, while LIAE can fix overly different face shapes.").lower()
|
||||||
self.options['learn_face_style'] = input_bool("Learn face style? (y/n skip:y) : ", True)
|
self.options['lighter_encoder'] = input_bool ("Use lightweight encoder? (y/n, ?:help skip:n) : ", False, help_message="Lightweight encoder is 35% faster, but it is not tested on various scenes.").lower()
|
||||||
self.options['learn_bg_style'] = input_bool("Learn background style? (y/n skip:y) : ", True)
|
else:
|
||||||
self.options['style_power'] = np.clip ( input_int("Style power (1..100 ?:help skip:100) : ", default_style_power, help_message="How fast NN will learn dst style during generalization of src and dst faces."), 1, 100 )
|
self.options['resolution'] = self.options.get('resolution', default_resolution)
|
||||||
default_ae_dims = 256 if self.options['archi'] == 'liae' else 512
|
self.options['archi'] = self.options.get('archi', default_archi)
|
||||||
|
self.options['lighter_encoder'] = self.options.get('lighter_encoder', False)
|
||||||
|
|
||||||
|
if is_first_run or ask_override:
|
||||||
|
self.options['face_style_power'] = np.clip ( input_int("Face style power (0..100 ?:help skip:100) : ", default_style_power, help_message="How fast NN will learn dst face style during generalization of src and dst faces."), 0, 100 )
|
||||||
|
else:
|
||||||
|
self.options['face_style_power'] = self.options.get('face_style_power', default_style_power)
|
||||||
|
|
||||||
|
if is_first_run or ask_override:
|
||||||
|
self.options['bg_style_power'] = np.clip ( input_int("Background style power (0..100 ?:help skip:100) : ", default_style_power, help_message="How fast NN will learn dst background style during generalization of src and dst faces."), 0, 100 )
|
||||||
|
else:
|
||||||
|
self.options['bg_style_power'] = self.options.get('bg_style_power', default_style_power)
|
||||||
|
|
||||||
|
default_ae_dims = 256 if self.options['archi'] == 'liae' else 512
|
||||||
|
|
||||||
|
if is_first_run:
|
||||||
self.options['ae_dims'] = input_int("AutoEncoder dims (128,256,512 ?:help skip:%d) : " % (default_ae_dims) , default_ae_dims, [128,256,512], help_message="More dims are better, but requires more VRAM." )
|
self.options['ae_dims'] = input_int("AutoEncoder dims (128,256,512 ?:help skip:%d) : " % (default_ae_dims) , default_ae_dims, [128,256,512], help_message="More dims are better, but requires more VRAM." )
|
||||||
self.options['face_type'] = input_str ("Half or Full face? (h/f, ?:help skip:f) : ", default_face_type, ['h','f'], help_message="Half face has better resolution, but covers less area of cheeks.").lower()
|
self.options['face_type'] = input_str ("Half or Full face? (h/f, ?:help skip:f) : ", default_face_type, ['h','f'], help_message="Half face has better resolution, but covers less area of cheeks.").lower()
|
||||||
else:
|
else:
|
||||||
#not first run
|
|
||||||
self.options['resolution'] = self.options.get('resolution', default_resolution)
|
|
||||||
self.options['learn_face_style'] = self.options.get('learn_face_style', True)
|
|
||||||
self.options['learn_bg_style'] = self.options.get('learn_bg_style', True)
|
|
||||||
self.options['archi'] = self.options.get('archi', default_archi)
|
|
||||||
self.options['style_power'] = self.options.get('style_power', default_style_power)
|
|
||||||
default_ae_dims = 256 if self.options['archi'] == 'liae' else 512
|
|
||||||
self.options['ae_dims'] = self.options.get('ae_dims', default_ae_dims)
|
self.options['ae_dims'] = self.options.get('ae_dims', default_ae_dims)
|
||||||
self.options['face_type'] = self.options.get('face_type', default_face_type)
|
self.options['face_type'] = self.options.get('face_type', default_face_type)
|
||||||
|
|
||||||
|
|
||||||
#override
|
#override
|
||||||
def onInitialize(self, **in_options):
|
def onInitialize(self, **in_options):
|
||||||
|
@ -68,7 +76,7 @@ class SAEModel(ModelBase):
|
||||||
target_dstm = Input(mask_shape)
|
target_dstm = Input(mask_shape)
|
||||||
|
|
||||||
if self.options['archi'] == 'liae':
|
if self.options['archi'] == 'liae':
|
||||||
self.encoder = modelify(SAEModel.EncFlow() ) (Input(bgr_shape))
|
self.encoder = modelify(SAEModel.EncFlow(self.options['lighter_encoder']) ) (Input(bgr_shape))
|
||||||
|
|
||||||
enc_output_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
|
enc_output_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
|
||||||
|
|
||||||
|
@ -107,7 +115,7 @@ class SAEModel(ModelBase):
|
||||||
pred_src_dst = self.decoder(warped_src_dst_inter_code)
|
pred_src_dst = self.decoder(warped_src_dst_inter_code)
|
||||||
pred_src_dstm = self.decoderm(warped_src_dst_inter_code)
|
pred_src_dstm = self.decoderm(warped_src_dst_inter_code)
|
||||||
else:
|
else:
|
||||||
self.encoder = modelify(SAEModel.DFEncFlow(dims=ae_dims,lowest_dense_res=resolution // 16) ) (Input(bgr_shape))
|
self.encoder = modelify(SAEModel.DFEncFlow(self.options['lighter_encoder'], dims=ae_dims,lowest_dense_res=resolution // 16) ) (Input(bgr_shape))
|
||||||
|
|
||||||
dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
|
dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
|
||||||
|
|
||||||
|
@ -162,17 +170,16 @@ class SAEModel(ModelBase):
|
||||||
|
|
||||||
psd_target_dst_masked = pred_src_dst_sigm * target_dstm_sigm
|
psd_target_dst_masked = pred_src_dst_sigm * target_dstm_sigm
|
||||||
psd_target_dst_anti_masked = pred_src_dst_sigm * target_dstm_anti_sigm
|
psd_target_dst_anti_masked = pred_src_dst_sigm * target_dstm_anti_sigm
|
||||||
|
|
||||||
|
|
||||||
style_power = self.options['style_power'] / 100.0
|
|
||||||
|
|
||||||
src_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked, pred_src_src_masked )) )
|
src_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked, pred_src_src_masked )) )
|
||||||
|
|
||||||
if self.options['learn_face_style']:
|
if self.options['face_style_power'] != 0:
|
||||||
src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*style_power)(psd_target_dst_masked, target_dst_masked)
|
face_style_power = self.options['face_style_power'] / 100.0
|
||||||
|
src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*face_style_power)(psd_target_dst_masked, target_dst_masked)
|
||||||
|
|
||||||
if self.options['learn_bg_style']:
|
if self.options['bg_style_power'] != 0:
|
||||||
src_loss += K.mean( (100*style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked, target_dst_anti_masked )))
|
bg_style_power = self.options['bg_style_power'] / 100.0
|
||||||
|
src_loss += K.mean( (100*bg_style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked, target_dst_anti_masked )))
|
||||||
|
|
||||||
if self.options['archi'] == 'liae':
|
if self.options['archi'] == 'liae':
|
||||||
src_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
|
src_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
|
||||||
|
@ -180,8 +187,7 @@ class SAEModel(ModelBase):
|
||||||
src_train_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights
|
src_train_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights
|
||||||
self.src_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss],
|
self.src_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss],
|
||||||
Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(src_loss, src_train_weights) )
|
Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(src_loss, src_train_weights) )
|
||||||
|
|
||||||
|
|
||||||
dst_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked, pred_dst_dst_masked )) )
|
dst_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked, pred_dst_dst_masked )) )
|
||||||
|
|
||||||
if self.options['archi'] == 'liae':
|
if self.options['archi'] == 'liae':
|
||||||
|
@ -190,8 +196,7 @@ class SAEModel(ModelBase):
|
||||||
dst_train_weights = self.encoder.trainable_weights + self.decoder_dst.trainable_weights
|
dst_train_weights = self.encoder.trainable_weights + self.decoder_dst.trainable_weights
|
||||||
self.dst_train = K.function ([warped_dst, target_dst, target_dstm],[dst_loss],
|
self.dst_train = K.function ([warped_dst, target_dst, target_dstm],[dst_loss],
|
||||||
Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(dst_loss, dst_train_weights) )
|
Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(dst_loss, dst_train_weights) )
|
||||||
|
|
||||||
|
|
||||||
src_mask_loss = K.mean(K.square(target_srcm-pred_src_srcm))
|
src_mask_loss = K.mean(K.square(target_srcm-pred_src_srcm))
|
||||||
|
|
||||||
if self.options['archi'] == 'liae':
|
if self.options['archi'] == 'liae':
|
||||||
|
@ -317,13 +322,18 @@ class SAEModel(ModelBase):
|
||||||
**in_options)
|
**in_options)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def EncFlow():
|
def EncFlow(light_enc):
|
||||||
exec (nnlib.import_all(), locals(), globals())
|
exec (nnlib.import_all(), locals(), globals())
|
||||||
|
|
||||||
def downscale (dim):
|
def downscale (dim):
|
||||||
def func(x):
|
def func(x):
|
||||||
return LeakyReLU(0.1)(Conv2D(dim, 5, strides=2, padding='same')(x))
|
return LeakyReLU(0.1)(Conv2D(dim, 5, strides=2, padding='same')(x))
|
||||||
return func
|
return func
|
||||||
|
|
||||||
|
def downscale_sep (dim):
|
||||||
|
def func(x):
|
||||||
|
return LeakyReLU(0.1)(SeparableConv2D(dim, 5, strides=2, padding='same')(x))
|
||||||
|
return func
|
||||||
|
|
||||||
def upscale (dim):
|
def upscale (dim):
|
||||||
def func(x):
|
def func(x):
|
||||||
|
@ -334,9 +344,15 @@ class SAEModel(ModelBase):
|
||||||
x = input
|
x = input
|
||||||
|
|
||||||
x = downscale(128)(x)
|
x = downscale(128)(x)
|
||||||
x = downscale(256)(x)
|
if not light_enc:
|
||||||
x = downscale(512)(x)
|
x = downscale(256)(x)
|
||||||
x = downscale(1024)(x)
|
x = downscale(512)(x)
|
||||||
|
x = downscale(1024)(x)
|
||||||
|
else:
|
||||||
|
x = downscale_sep(256)(x)
|
||||||
|
x = downscale_sep(512)(x)
|
||||||
|
x = downscale_sep(1024)(x)
|
||||||
|
|
||||||
x = Flatten()(x)
|
x = Flatten()(x)
|
||||||
return x
|
return x
|
||||||
return func
|
return func
|
||||||
|
@ -380,7 +396,7 @@ class SAEModel(ModelBase):
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def DFEncFlow(dims=512, lowest_dense_res=8):
|
def DFEncFlow(light_enc, dims=512, lowest_dense_res=8):
|
||||||
exec (nnlib.import_all(), locals(), globals())
|
exec (nnlib.import_all(), locals(), globals())
|
||||||
|
|
||||||
def downscale (dim):
|
def downscale (dim):
|
||||||
|
@ -388,6 +404,11 @@ class SAEModel(ModelBase):
|
||||||
return LeakyReLU(0.1)(Conv2D(dim, 5, strides=2, padding='same')(x))
|
return LeakyReLU(0.1)(Conv2D(dim, 5, strides=2, padding='same')(x))
|
||||||
return func
|
return func
|
||||||
|
|
||||||
|
def downscale_sep (dim):
|
||||||
|
def func(x):
|
||||||
|
return LeakyReLU(0.1)(SeparableConv2D(dim, 5, strides=2, padding='same')(x))
|
||||||
|
return func
|
||||||
|
|
||||||
def upscale (dim):
|
def upscale (dim):
|
||||||
def func(x):
|
def func(x):
|
||||||
return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
|
return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
|
||||||
|
@ -397,9 +418,14 @@ class SAEModel(ModelBase):
|
||||||
x = input
|
x = input
|
||||||
|
|
||||||
x = downscale(128)(x)
|
x = downscale(128)(x)
|
||||||
x = downscale(256)(x)
|
if not light_enc:
|
||||||
x = downscale(512)(x)
|
x = downscale(256)(x)
|
||||||
x = downscale(1024)(x)
|
x = downscale(512)(x)
|
||||||
|
x = downscale(1024)(x)
|
||||||
|
else:
|
||||||
|
x = downscale_sep(256)(x)
|
||||||
|
x = downscale_sep(512)(x)
|
||||||
|
x = downscale_sep(1024)(x)
|
||||||
|
|
||||||
x = Dense(dims)(Flatten()(x))
|
x = Dense(dims)(Flatten()(x))
|
||||||
x = Dense(lowest_dense_res * lowest_dense_res * dims)(x)
|
x = Dense(lowest_dense_res * lowest_dense_res * dims)(x)
|
||||||
|
|
|
@ -72,6 +72,7 @@ Input = keras.layers.Input
|
||||||
Dense = keras.layers.Dense
|
Dense = keras.layers.Dense
|
||||||
Conv2D = keras.layers.Conv2D
|
Conv2D = keras.layers.Conv2D
|
||||||
Conv2DTranspose = keras.layers.Conv2DTranspose
|
Conv2DTranspose = keras.layers.Conv2DTranspose
|
||||||
|
SeparableConv2D = keras.layers.SeparableConv2D
|
||||||
MaxPooling2D = keras.layers.MaxPooling2D
|
MaxPooling2D = keras.layers.MaxPooling2D
|
||||||
BatchNormalization = keras.layers.BatchNormalization
|
BatchNormalization = keras.layers.BatchNormalization
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,7 @@
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import multiprocessing
|
||||||
|
|
||||||
def input_int(s, default_value, valid_list=None, help_message=None):
|
def input_int(s, default_value, valid_list=None, help_message=None):
|
||||||
while True:
|
while True:
|
||||||
|
@ -51,4 +54,28 @@ def input_str(s, default_value, valid_list=None, help_message=None):
|
||||||
return inp
|
return inp
|
||||||
except:
|
except:
|
||||||
print (default_value)
|
print (default_value)
|
||||||
return default_value
|
return default_value
|
||||||
|
|
||||||
|
def input_process(stdin_fd, sq, str):
    """Child-process worker: show a prompt and report whether a line was read.

    Used as the ``target`` of the ``multiprocessing.Process`` started by
    ``input_in_time``.

    Args:
        stdin_fd: File descriptor to read the answer from; it is wrapped
            with ``os.fdopen`` and installed as ``sys.stdin``.
        sq: Queue-like object with a ``put`` method. Exactly one boolean is
            put on it: ``True`` if ``input`` returned, ``False`` if it raised.
        str: Prompt text passed to ``input``. (The name shadows the builtin;
            it is kept so existing callers are unaffected.)
    """
    sys.stdin = os.fdopen(stdin_fd)

    try:
        # The text that was typed is irrelevant - only the fact that the
        # user answered matters, so the return value is discarded.
        input(str)
    except BaseException:
        # Deliberate best-effort: EOF, a closed descriptor or an interrupt
        # all simply mean "no answer was given".
        answered = False
    else:
        answered = True

    # Reported outside the try block so that a failing queue is not
    # silently turned into a "no answer" result.
    sq.put(answered)
|
||||||
|
|
||||||
|
def input_in_time (str, max_time_sec):
    """Show a prompt and report whether the user answered within a time limit.

    ``input_process`` is run in a separate process so that the blocking
    ``input`` call can be abandoned once the limit has passed.

    Args:
        str: Prompt text forwarded to ``input_process``. (The name shadows
            the builtin; it is kept so existing callers are unaffected.)
        max_time_sec: Maximum number of seconds to wait for an answer.

    Returns:
        The boolean reported by ``input_process`` if it arrived in time,
        otherwise ``False``.
    """
    # Function-scope import: only needed for the queue.Empty timeout signal.
    import queue

    sq = multiprocessing.Queue()
    p = multiprocessing.Process(target=input_process, args=(sys.stdin.fileno(), sq, str))
    p.start()

    try:
        # Block on the queue instead of spinning on sq.empty(): same outcome,
        # but the parent does not burn a CPU core while it waits.
        inp = sq.get(timeout=max(0, max_time_sec))
    except queue.Empty:
        inp = False
    finally:
        # The child may still be blocked inside input(); stop it and reap it
        # so that no zombie process is left behind.
        p.terminate()
        p.join()

    # Re-wrap stdin after the child has been killed. closefd=False keeps a
    # discarded wrapper from closing the shared descriptor, which would
    # otherwise break stdin the next time this function is called.
    sys.stdin = os.fdopen(sys.stdin.fileno(), closefd=False)
    return inp
|
Loading…
Add table
Add a link
Reference in a new issue