Mirror of https://github.com/iperov/DeepFaceLab.git (synced 2025-08-22 14:24:40 -07:00)

Commit 06e682a59e — Merge remote-tracking branch 'upstream/master'
15 changed files with 215 additions and 83 deletions

4 changed binary files are not shown.
@ -291,7 +291,7 @@ def nms(boxes, threshold, method):
         w = np.maximum(0.0, xx2-xx1+1)
         h = np.maximum(0.0, yy2-yy1+1)
         inter = w * h
-        if method is 'Min':
+        if method == 'Min':
             o = inter / np.minimum(area[i], area[idx])
         else:
             o = inter / (area[i] + area[idx] - inter)
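The only functional change in this hunk swaps an identity test for an equality test. `method is 'Min'` compares object identity and only happens to work when CPython interns the string literal; `==` compares values, which is what nms() actually needs. A minimal illustration of the difference (hypothetical strings, not repository code):

    a = "Min"
    b = "".join(["M", "in"])   # equal value, usually a distinct object

    print(a == b)   # True  - value comparison, the behaviour nms() relies on
    print(a is b)   # typically False - identity comparison, implementation-dependent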
@ -1,19 +1,22 @@
-import os
-import json
-import time
-import inspect
-import pickle
 import colorsys
-import imagelib
+import inspect
+import json
+import os
+import pickle
+import shutil
+import time
 from pathlib import Path
-from utils import Path_utils
-from utils import std_utils
-from utils.cv2_utils import *
-import numpy as np
 import cv2
-from samplelib import SampleGeneratorBase
-from nnlib import nnlib
+import numpy as np
+
+import imagelib
 from interact import interact as io
+from nnlib import nnlib
+from samplelib import SampleGeneratorBase
+from utils import Path_utils, std_utils
+from utils.cv2_utils import *
 
 '''
 You can implement your own model. Check examples.
 '''
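The import block is rewritten into the conventional layout: alphabetized within groups, standard library first, then third-party packages, then project modules, with `shutil` newly added for the autobackup file copies introduced below. Roughly the grouping convention being applied (illustrative excerpt only, not the file's full import list):

    # standard library
    import os
    import shutil

    # third-party
    import cv2
    import numpy as np

    # project-local
    from nnlib import nnlib
    from utils import Path_utils, std_utils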
@ -21,6 +24,7 @@ class ModelBase(object):
 
     def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None, pretraining_data_path=None, debug = False, device_args = None,
+                        ask_enable_autobackup=True,
                         ask_write_preview_history=True,
                         ask_target_iter=True,
                         ask_batch_size=True,
@ -41,7 +45,7 @@ class ModelBase(object):
             device_args['force_gpu_idx'] = io.input_int("Which GPU idx to choose? ( skip: best GPU ) : ", -1, [ x[0] for x in idxs_names_list] )
         self.device_args = device_args
 
-        self.device_config = nnlib.DeviceConfig(allow_growth=False, **self.device_args)
+        self.device_config = nnlib.DeviceConfig(allow_growth=True, **self.device_args)
 
         io.log_info ("Loading model...")
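Switching `allow_growth` to True makes the TensorFlow session claim GPU memory on demand instead of reserving the whole device up front; the usual trade-off is a smaller initial reservation versus possible fragmentation. nnlib.DeviceConfig is the repository's own wrapper, so the following is only a sketch of the underlying TensorFlow 1.x flag it forwards, not the wrapper's actual implementation:

    import tensorflow as tf  # TensorFlow 1.x API

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True   # grow allocations as needed
    sess = tf.Session(config=config)         # rather than grabbing all GPU memory at start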
@ -84,6 +88,12 @@ class ModelBase(object):
         if self.iter == 0:
             io.log_info ("\nModel first run. Enter model options as default for each run.")
 
+        if ask_enable_autobackup and (self.iter == 0 or ask_override):
+            default_autobackup = False if self.iter == 0 else self.options.get('autobackup',False)
+            self.options['autobackup'] = io.input_bool("Enable autobackup? (y/n ?:help skip:%s) : " % (yn_str[default_autobackup]) , default_autobackup, help_message="Autobackup model files with preview every hour for last 15 hours. Latest backup located in model/<>_autobackups/01")
+        else:
+            self.options['autobackup'] = self.options.get('autobackup', False)
+
         if ask_write_preview_history and (self.iter == 0 or ask_override):
             default_write_preview_history = False if self.iter == 0 else self.options.get('write_preview_history',False)
             self.options['write_preview_history'] = io.input_bool("Write preview history? (y/n ?:help skip:%s) : " % (yn_str[default_write_preview_history]) , default_write_preview_history, help_message="Preview history will be writed to <ModelName>_history folder.")
@ -92,6 +102,8 @@ class ModelBase(object):
 
         if (self.iter == 0 or ask_override) and self.options['write_preview_history'] and io.is_support_windows():
             choose_preview_history = io.input_bool("Choose image for the preview history? (y/n skip:%s) : " % (yn_str[False]) , False)
+        elif (self.iter == 0 or ask_override) and self.options['write_preview_history'] and io.is_colab():
+            choose_preview_history = io.input_bool("Randomly choose new image for preview history? (y/n ?:help skip:%s) : " % (yn_str[False]), False, help_message="Preview image history will stay stuck with old faces if you reuse the same model on different celebs. Choose no unless you are changing src/dst to a new person")
         else:
             choose_preview_history = False
@ -128,6 +140,10 @@ class ModelBase(object):
         else:
             self.options['src_scale_mod'] = self.options.get('src_scale_mod', 0)
 
+        self.autobackup = self.options.get('autobackup', False)
+        if not self.autobackup and 'autobackup' in self.options:
+            self.options.pop('autobackup')
+
         self.write_preview_history = self.options.get('write_preview_history', False)
         if not self.write_preview_history and 'write_preview_history' in self.options:
             self.options.pop('write_preview_history')
@ -160,8 +176,16 @@ class ModelBase(object):
         if self.is_training_mode:
             if self.device_args['force_gpu_idx'] == -1:
                 self.preview_history_path = self.model_path / ( '%s_history' % (self.get_model_name()) )
+                self.autobackups_path = self.model_path / ( '%s_autobackups' % (self.get_model_name()) )
             else:
                 self.preview_history_path = self.model_path / ( '%d_%s_history' % (self.device_args['force_gpu_idx'], self.get_model_name()) )
+                self.autobackups_path = self.model_path / ( '%d_%s_autobackups' % (self.device_args['force_gpu_idx'], self.get_model_name()) )
+
+            if self.autobackup:
+                self.autobackup_current_hour = time.localtime().tm_hour
+
+                if not self.autobackups_path.exists():
+                    self.autobackups_path.mkdir(exist_ok=True)
 
             if self.write_preview_history or io.is_colab():
                 if not self.preview_history_path.exists():
@ -206,7 +230,7 @@ class ModelBase(object):
                     io.destroy_window(wnd_name)
             else:
                 self.sample_for_preview = self.generate_next_sample()
+                self.last_sample = self.sample_for_preview
         model_summary_text = []
 
         model_summary_text += ["===== Model summary ====="]
@ -277,9 +301,13 @@ class ModelBase(object):
     def get_model_name(self):
         return Path(inspect.getmodule(self).__file__).parent.name.rsplit("_", 1)[1]
 
+    #overridable , return [ [model, filename],... ] list
+    def get_model_filename_list(self):
+        return []
+
     #overridable
     def get_converter(self):
-        raise NotImplementeError
+        raise NotImplementedError
         #return existing or your own converter which derived from base
 
     def get_target_iter(self):
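Besides adding the overridable get_model_filename_list() stub, this hunk fixes a long-standing typo: `NotImplementeError` is not a defined name, so the old line would raise NameError rather than the intended NotImplementedError. A tiny illustration (hypothetical class, not repository code):

    class ConverterlessModel:
        def get_converter(self):
            # corrected: NotImplementedError is the built-in meant for unimplemented hooks
            raise NotImplementedError

    # With the old misspelling, calling get_converter() would surface as
    #   NameError: name 'NotImplementeError' is not defined
    # instead of the intended NotImplementedError.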
@ -314,7 +342,8 @@ class ModelBase(object):
         return self.onGetPreview (self.sample_for_preview)[0][1] #first preview, and bgr
 
     def save(self):
-        Path( self.get_strpath_storage_for_file('summary.txt') ).write_text(self.model_summary_text)
+        summary_path = self.get_strpath_storage_for_file('summary.txt')
+        Path( summary_path ).write_text(self.model_summary_text)
 
         self.onSave()
 
         model_data = {
@ -325,6 +354,44 @@ class ModelBase(object):
         }
         self.model_data_path.write_bytes( pickle.dumps(model_data) )
 
+        bckp_filename_list = [ self.get_strpath_storage_for_file(filename) for _, filename in self.get_model_filename_list() ]
+        bckp_filename_list += [ str(summary_path), str(self.model_data_path) ]
+
+        if self.autobackup:
+            current_hour = time.localtime().tm_hour
+            if self.autobackup_current_hour != current_hour:
+                self.autobackup_current_hour = current_hour
+
+                for i in range(15,0,-1):
+                    idx_str = '%.2d' % i
+                    next_idx_str = '%.2d' % (i+1)
+
+                    idx_backup_path = self.autobackups_path / idx_str
+                    next_idx_packup_path = self.autobackups_path / next_idx_str
+
+                    if idx_backup_path.exists():
+                        if i == 15:
+                            Path_utils.delete_all_files(idx_backup_path)
+                        else:
+                            next_idx_packup_path.mkdir(exist_ok=True)
+                            Path_utils.move_all_files (idx_backup_path, next_idx_packup_path)
+
+                    if i == 1:
+                        idx_backup_path.mkdir(exist_ok=True)
+                        for filename in bckp_filename_list:
+                            shutil.copy ( str(filename), str(idx_backup_path / Path(filename).name) )
+
+                        previews = self.get_previews()
+                        plist = []
+                        for i in range(len(previews)):
+                            name, bgr = previews[i]
+                            plist += [ (bgr, idx_backup_path / ( ('preview_%s.jpg') % (name)) ) ]
+
+                        for preview, filepath in plist:
+                            preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2])
+                            img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
+                            cv2_imwrite (filepath, img )
+
     def load_weights_safe(self, model_filename_list, optimizer_filename_list=[]):
         for model, filename in model_filename_list:
             filename = self.get_strpath_storage_for_file(filename)
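The new autobackup branch in save() implements a 15-slot hourly rotation: when the wall-clock hour changes, the oldest slot (15) is emptied, every other slot is shifted up by one, and a fresh copy of the model files, summary, model data, and preview JPEGs (with the loss-history strip attached) lands in slot 01. A standalone sketch of just the rotation logic, with hypothetical paths and without the preview rendering:

    from pathlib import Path
    import shutil

    def rotate_backups(backups_root, files_to_copy, slots=15):
        # Shift slot N into N+1 (dropping the oldest) and copy fresh files into slot 01.
        backups_root = Path(backups_root)
        for i in range(slots, 0, -1):
            slot = backups_root / ('%.2d' % i)
            nxt  = backups_root / ('%.2d' % (i + 1))
            if slot.exists():
                if i == slots:
                    for f in slot.iterdir():        # oldest slot is simply emptied
                        f.unlink()
                else:
                    nxt.mkdir(parents=True, exist_ok=True)
                    for f in slot.iterdir():        # shift contents one slot up
                        f.rename(nxt / f.name)
            if i == 1:
                slot.mkdir(parents=True, exist_ok=True)
                for f in files_to_copy:             # newest backup always lands in 01
                    shutil.copy(str(f), str(slot / Path(f).name))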
@ -349,12 +416,16 @@ class ModelBase(object):
                 print ("Unable to load ", opt_filename)
 
-    def save_weights_safe(self, model_filename_list, optimizer_filename_list=[]):
+    def save_weights_safe(self, model_filename_list):
         for model, filename in model_filename_list:
             filename = self.get_strpath_storage_for_file(filename)
             model.save_weights( filename + '.tmp' )
 
         rename_list = model_filename_list
+
+        """
+        #unused
+        , optimizer_filename_list=[]
         if len(optimizer_filename_list) != 0:
             opt_filename = self.get_strpath_storage_for_file('opt.h5')
@ -374,6 +445,7 @@ class ModelBase(object):
                 rename_list += [('', 'opt.h5')]
             except Exception as e:
                 print ("Unable to save ", opt_filename)
+        """
 
         for _, filename in rename_list:
             filename = self.get_strpath_storage_for_file(filename)
@ -384,7 +456,6 @@ class ModelBase(object):
                 target_filename.unlink()
             source_filename.rename ( str(target_filename) )
 
-
     def debug_one_iter(self):
         images = []
         for generator in self.generator_list:
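save_weights_safe() relies on a write-then-rename pattern: each weights file is written to a `.tmp` sibling first and only replaces the real file once the write has finished, so a crash or out-of-disk error mid-save leaves the previous weights usable. The same idiom in isolation, for generic file contents rather than Keras weights (sketch):

    from pathlib import Path

    def write_atomically(path, data: bytes):
        # Write to <path>.tmp, then rename over <path>; the rename is the commit point.
        path = Path(path)
        tmp = path.with_suffix(path.suffix + '.tmp')
        tmp.write_bytes(data)
        if path.exists():
            path.unlink()      # mirrors the unlink-then-rename done in the diff above
        tmp.rename(path)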
@ -490,6 +561,8 @@ class ModelBase(object):
 
         lh_height = 100
         lh_img = np.ones ( (lh_height,w,c) ) * 0.1
+
+        if len(loss_history) != 0:
             loss_count = len(loss_history[0])
             lh_len = len(loss_history)
@ -11,6 +11,7 @@ class Model(ModelBase):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs,
+                            ask_enable_autobackup=False,
                             ask_write_preview_history=False,
                             ask_target_iter=False,
                             ask_sort_by_yaw=False,
@ -12,6 +12,7 @@ class Model(ModelBase):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs,
+                            ask_enable_autobackup=False,
                             ask_write_preview_history=False,
                             ask_target_iter=False,
                             ask_sort_by_yaw=False,
@ -59,11 +59,16 @@ class Model(ModelBase):
                 sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
                 output_sample_types=output_sample_types)
             ])
 
+    #override
+    def get_model_filename_list(self):
+        return [[self.encoder, 'encoder.h5'],
+                [self.decoder_src, 'decoder_src.h5'],
+                [self.decoder_dst, 'decoder_dst.h5']]
+
     #override
     def onSave(self):
-        self.save_weights_safe( [[self.encoder, 'encoder.h5'],
-                                 [self.decoder_src, 'decoder_src.h5'],
-                                 [self.decoder_dst, 'decoder_dst.h5']] )
+        self.save_weights_safe( self.get_model_filename_list() )
 
     #override
     def onTrainOneIter(self, sample, generators_list):
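The same refactor is repeated across every trainer model below: the model-to-filename mapping that used to live inline in onSave() moves into an overridable get_model_filename_list(), so both saving and the new autobackup code in ModelBase.save() consume one authoritative list. The pattern, reduced to a sketch with hypothetical attribute and file names:

    class MyModel(ModelBase):
        #override - single source of truth for (model, filename) pairs
        def get_model_filename_list(self):
            return [[self.encoder, 'encoder.h5'],
                    [self.decoder, 'decoder.h5']]

        #override - saving (and autobackup) reuse the same list
        def onSave(self):
            self.save_weights_safe( self.get_model_filename_list() )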
@ -71,10 +71,14 @@ class Model(ModelBase):
             ])
 
+    #override
+    def get_model_filename_list(self):
+        return [[self.encoder, 'encoder.h5'],
+                [self.decoder_src, 'decoder_src.h5'],
+                [self.decoder_dst, 'decoder_dst.h5']]
+
     #override
     def onSave(self):
-        self.save_weights_safe( [[self.encoder, 'encoder.h5'],
-                                 [self.decoder_src, 'decoder_src.h5'],
-                                 [self.decoder_dst, 'decoder_dst.h5']] )
+        self.save_weights_safe( self.get_model_filename_list() )
 
     #override
    def onTrainOneIter(self, sample, generators_list):
@ -72,10 +72,14 @@ class Model(ModelBase):
             ])
 
+    #override
+    def get_model_filename_list(self):
+        return [[self.encoder, 'encoder.h5'],
+                [self.decoder_src, 'decoder_src.h5'],
+                [self.decoder_dst, 'decoder_dst.h5']]
+
     #override
     def onSave(self):
-        self.save_weights_safe( [[self.encoder, 'encoder.h5'],
-                                 [self.decoder_src, 'decoder_src.h5'],
-                                 [self.decoder_dst, 'decoder_dst.h5']] )
+        self.save_weights_safe( self.get_model_filename_list() )
 
     #override
     def onTrainOneIter(self, sample, generators_list):
@ -66,11 +66,15 @@ class Model(ModelBase):
             ])
 
+    #override
+    def get_model_filename_list(self):
+        return [[self.encoder, 'encoder.h5'],
+                [self.decoder, 'decoder.h5'],
+                [self.inter_B, 'inter_B.h5'],
+                [self.inter_AB, 'inter_AB.h5']]
+
     #override
     def onSave(self):
-        self.save_weights_safe( [[self.encoder, 'encoder.h5'],
-                                 [self.decoder, 'decoder.h5'],
-                                 [self.inter_B, 'inter_B.h5'],
-                                 [self.inter_AB, 'inter_AB.h5']] )
+        self.save_weights_safe( self.get_model_filename_list() )
 
     #override
     def onTrainOneIter(self, sample, generators_list):
@ -202,13 +202,17 @@ class RecycleGANModel(ModelBase):
             self.G_convert = K.function([real_B0],[fake_A0])
 
+    #override
+    def get_model_filename_list(self):
+        return [ [self.GA, 'GA.h5'],
+                 [self.GB, 'GB.h5'],
+                 [self.DA, 'DA.h5'],
+                 [self.DB, 'DB.h5'],
+                 [self.PA, 'PA.h5'],
+                 [self.PB, 'PB.h5'] ]
+
     #override
     def onSave(self):
-        self.save_weights_safe( [[self.GA, 'GA.h5'],
-                                 [self.GB, 'GB.h5'],
-                                 [self.DA, 'DA.h5'],
-                                 [self.DB, 'DB.h5'],
-                                 [self.PA, 'PA.h5'],
-                                 [self.PB, 'PB.h5'] ])
+        self.save_weights_safe( self.get_model_filename_list() )
 
     #override
     def onTrainOneIter(self, generators_samples, generators_list):
@ -90,11 +90,19 @@ class SAEModel(ModelBase):
 
             default_apply_random_ct = False if is_first_run else self.options.get('apply_random_ct', False)
             self.options['apply_random_ct'] = io.input_bool ("Apply random color transfer to src faceset? (y/n, ?:help skip:%s) : " % (yn_str[default_apply_random_ct]), default_apply_random_ct, help_message="Increase variativity of src samples by apply LCT color transfer from random dst samples. It is like 'face_style' learning, but more precise color transfer and without risk of model collapse, also it does not require additional GPU resources, but the training time may be longer, due to the src faceset is becoming more diverse.")
 
+            if nnlib.device.backend != 'plaidML': # todo https://github.com/plaidml/plaidml/issues/301
+                default_clipgrad = False if is_first_run else self.options.get('clipgrad', False)
+                self.options['clipgrad'] = io.input_bool ("Enable gradient clipping? (y/n, ?:help skip:%s) : " % (yn_str[default_clipgrad]), default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.")
+            else:
+                self.options['clipgrad'] = False
+
         else:
             self.options['pixel_loss'] = self.options.get('pixel_loss', False)
             self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power)
             self.options['bg_style_power'] = self.options.get('bg_style_power', default_bg_style_power)
             self.options['apply_random_ct'] = self.options.get('apply_random_ct', False)
+            self.options['clipgrad'] = self.options.get('clipgrad', False)
 
         if is_first_run:
             self.options['pretrain'] = io.input_bool ("Pretrain the model? (y/n, ?:help skip:n) : ", False, help_message="Pretrain the model with large amount of various faces. This technique may help to train the fake with overly different face shapes and light conditions of src/dst data. Face will be look more like a morphed. To reduce the morph effect, some model files will be initialized but not be updated after pretrain: LIAE: inter_AB.h5 DF: encoder.h5. The longer you pretrain the model the more morphed face will look. After that, save and run the training again.")
@ -271,8 +279,8 @@ class SAEModel(ModelBase):
         psd_target_dst_anti_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
 
         if self.is_training_mode:
-            self.src_dst_opt      = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)
-            self.src_dst_mask_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)
+            self.src_dst_opt      = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
+            self.src_dst_mask_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
 
             if 'liae' in self.options['archi']:
                 src_dst_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
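The new 'clipgrad' option feeds a clipnorm value into the Adam optimizer: clipping each gradient tensor's L2 norm bounds the size of a single update and reduces the chance that one bad batch collapses the model, at some cost in training speed. Stripped of the repo-specific tf_cpu_mode argument, the same knob on a stock Keras Adam looks roughly like this (sketch; `clipgrad` is a hypothetical flag standing in for self.options['clipgrad']):

    from keras.optimizers import Adam

    clipgrad = True   # hypothetical flag

    # clipnorm rescales any gradient whose L2 norm exceeds 1.0;
    # when clipping is disabled the argument is simply omitted
    opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999,
               **({'clipnorm': 1.0} if clipgrad else {}))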
@ -377,10 +385,7 @@ class SAEModel(ModelBase):
             ])
 
     #override
-    def onSave(self):
-        opt_ar = [ [self.src_dst_opt, 'src_dst_opt'],
-                   [self.src_dst_mask_opt, 'src_dst_mask_opt']
-                 ]
+    def get_model_filename_list(self):
         ar = []
         if 'liae' in self.options['archi']:
             ar += [[self.encoder, 'encoder.h5'],
@ -407,9 +412,11 @@ class SAEModel(ModelBase):
         if self.options['learn_mask']:
             ar += [ [self.decoder_srcm, 'decoder_srcm.h5'],
                     [self.decoder_dstm, 'decoder_dstm.h5'] ]
+        return ar
 
-        self.save_weights_safe(ar)
+    #override
+    def onSave(self):
+        self.save_weights_safe( self.get_model_filename_list() )
 
     #override
     def onTrainOneIter(self, generators_samples, generators_list):
@ -28,6 +28,23 @@ def get_image_unique_filestem_paths(dir_path, verbose_print_func=None):
 
     return result
 
+def get_file_paths(dir_path):
+    dir_path = Path (dir_path)
+
+    result = []
+    if dir_path.exists():
+        return [ x.path for x in list(scandir(str(dir_path))) if x.is_file() ]
+    return result
+
+def get_all_dir_names (dir_path):
+    dir_path = Path (dir_path)
+
+    result = []
+    if dir_path.exists():
+        return [ x.name for x in list(scandir(str(dir_path))) if x.is_dir() ]
+
+    return result
+
 def get_all_dir_names_startswith (dir_path, startswith):
     dir_path = Path (dir_path)
     startswith = startswith.lower()
@ -52,3 +69,15 @@ def get_first_file_by_stem (dir_path, stem, exts=None):
                 return xp
 
     return None
+
+def move_all_files (src_dir_path, dst_dir_path):
+    paths = get_file_paths(src_dir_path)
+    for p in paths:
+        p = Path(p)
+        p.rename ( Path(dst_dir_path) / p.name )
+
+def delete_all_files (dir_path):
+    paths = get_file_paths(dir_path)
+    for p in paths:
+        p = Path(p)
+        p.unlink()
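get_file_paths(), move_all_files() and delete_all_files() are the Path_utils helpers the autobackup rotation depends on; they operate only on regular files directly inside the given directory and leave subdirectories untouched. A short usage sketch with hypothetical backup-slot paths:

    from pathlib import Path
    from utils import Path_utils   # DeepFaceLab's helper module, as used above

    slot_01 = Path('model/SAE_autobackups/01')   # hypothetical slot directories
    slot_02 = Path('model/SAE_autobackups/02')

    slot_02.mkdir(parents=True, exist_ok=True)
    Path_utils.move_all_files(slot_01, slot_02)  # shift every file from slot 01 into 02
    Path_utils.delete_all_files(slot_02)         # or wipe a slot's files entirely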