Mirror of https://github.com/iperov/DeepFaceLab.git (synced 2025-07-06 04:52:13 -07:00)
SAE: WARNING, RETRAIN IS REQUIRED!

Fixed the model sizes from the previous update.
Avoided a bug in the ML framework (Keras) that forces the model to train on random noise.
Converter: added blur on the same keys as sharpness.
Added new model 'TrueFace'. This is a GAN model ported from https://github.com/NVlabs/FUNIT
The model produces near-zero morphing and highly detailed faces, but has a higher failure rate than the other models.
Keep the src and dst facesets under the same lighting conditions.
This commit is contained in:
parent 201b762541
commit dc11ec32be
26 changed files with 1308 additions and 250 deletions
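The Converter change binds a blur amount to the same keyboard keys that already control sharpness, so a single control can push the result in either direction. A minimal sketch of that idea, assuming OpenCV and a signed integer amount where positive values sharpen via unsharp masking and negative values blur; the helper name and the fixed unsharp strength are illustrative, not the converter's actual code:

    import cv2
    import numpy as np

    def adjust_sharpness(img, amount):
        # Hypothetical helper: amount > 0 sharpens (unsharp mask), amount < 0 blurs.
        if amount == 0:
            return img
        ksize = 2 * abs(amount) + 1                      # GaussianBlur requires an odd kernel size
        blurred = cv2.GaussianBlur(img, (ksize, ksize), 0)
        if amount < 0:
            return blurred                               # plain Gaussian blur
        # Unsharp masking: img + 0.5 * (img - blurred), computed as a weighted sum
        return cv2.addWeighted(img, 1.5, blurred, -0.5, 0)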
models/ModelBase.py

@@ -23,7 +23,7 @@ You can implement your own model. Check examples.
 class ModelBase(object):
 
-    def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None, pretraining_data_path=None, debug = False, device_args = None,
+    def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None, pretraining_data_path=None, is_training=False, debug = False, device_args = None,
                         ask_enable_autobackup=True,
                         ask_write_preview_history=True,
                         ask_target_iter=True,
@@ -56,14 +56,8 @@ class ModelBase(object):
         self.training_data_dst_path = training_data_dst_path
         self.pretraining_data_path = pretraining_data_path
 
-        self.src_images_paths = None
-        self.dst_images_paths = None
-        self.src_yaw_images_paths = None
-        self.dst_yaw_images_paths = None
-        self.src_data_generator = None
-        self.dst_data_generator = None
         self.debug = debug
-        self.is_training_mode = (training_data_src_path is not None and training_data_dst_path is not None)
+        self.is_training_mode = is_training
 
         self.iter = 0
         self.options = {}
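These two hunks replace path-based inference with an explicit flag: a converter or merger session can now hand dataset paths to the model without accidentally switching it into training mode. A rough usage sketch of the new call site, assuming DeepFaceLab's models.import_model helper and placeholder workspace paths (both are assumptions, not shown in this diff):

    from models import import_model   # assumed DeepFaceLab helper that resolves a model class by name

    ModelClass = import_model('SAE')
    model = ModelClass(model_path='workspace/model',
                       training_data_src_path='workspace/data_src/aligned',
                       training_data_dst_path='workspace/data_dst/aligned',
                       is_training=True)   # explicit now; previously implied by passing both paths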
@@ -412,40 +406,60 @@
         cv2_imwrite (filepath, img )
 
     def load_weights_safe(self, model_filename_list, optimizer_filename_list=[]):
+        exec(nnlib.code_import_all, locals(), globals())
+
         loaded = []
         not_loaded = []
         for mf in model_filename_list:
             model, filename = mf
             filename = self.get_strpath_storage_for_file(filename)
+
             if Path(filename).exists():
                 loaded += [ mf ]
-                model.load_weights(filename)
+
+                if issubclass(model.__class__, keras.optimizers.Optimizer):
+                    opt = model
+
+                    try:
+                        with open(filename, "rb") as f:
+                            fd = pickle.loads(f.read())
+
+                        weights = fd.get('weights', None)
+                        if weights is not None:
+                            opt.set_weights(weights)
+
+                    except Exception as e:
+                        print ("Unable to load ", filename)
+
+                else:
+                    model.load_weights(filename)
             else:
                 not_loaded += [ mf ]
 
-        if len(optimizer_filename_list) != 0:
-            opt_filename = self.get_strpath_storage_for_file('opt.h5')
-            if Path(opt_filename).exists():
-                try:
-                    with open(opt_filename, "rb") as f:
-                        d = pickle.loads(f.read())
-
-                    for x in optimizer_filename_list:
-                        opt, filename = x
-                        if filename in d:
-                            weights = d[filename].get('weights', None)
-                            if weights:
-                                opt.set_weights(weights)
-                                print("set ok")
-                except Exception as e:
-                    print ("Unable to load ", opt_filename)
-
         return loaded, not_loaded
 
     def save_weights_safe(self, model_filename_list):
+        exec(nnlib.code_import_all, locals(), globals())
+
         for model, filename in model_filename_list:
-            filename = self.get_strpath_storage_for_file(filename)
-            model.save_weights( filename + '.tmp' )
+            filename = self.get_strpath_storage_for_file(filename) + '.tmp'
+
+            if issubclass(model.__class__, keras.optimizers.Optimizer):
+                opt = model
+
+                try:
+                    fd = {}
+                    symbolic_weights = getattr(opt, 'weights')
+                    if symbolic_weights:
+                        fd['weights'] = self.K.batch_get_value(symbolic_weights)
+
+                    with open(filename, 'wb') as f:
+                        f.write( pickle.dumps(fd) )
+                except Exception as e:
+                    print ("Unable to save ", filename)
+            else:
+                model.save_weights( filename)
 
         rename_list = model_filename_list
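Both new code paths treat an optimizer like any other checkpointed object: its slot weights are read with K.batch_get_value, pickled next to the model weight files, and restored later with set_weights, replacing the single shared opt.h5 mechanism. A self-contained sketch of that round trip with plain tf.keras, assuming an optimizer API that still exposes .weights and set_weights (as the Keras 2-era API used by DeepFaceLab does); the tiny model, file name, and training step are illustrative only:

    import pickle
    import numpy as np
    from tensorflow import keras
    from tensorflow.keras import backend as K

    # Illustrative model/optimizer; DeepFaceLab builds its own through nnlib.
    model = keras.Sequential([keras.layers.Dense(4, input_shape=(8,)), keras.layers.Dense(1)])
    opt = keras.optimizers.Adam()
    model.compile(optimizer=opt, loss='mse')
    model.train_on_batch(np.zeros((2, 8)), np.zeros((2, 1)))   # creates the optimizer's slot variables

    # Save, mirroring save_weights_safe above: pull the slot values and pickle them.
    fd = {}
    symbolic_weights = getattr(opt, 'weights')
    if symbolic_weights:
        fd['weights'] = K.batch_get_value(symbolic_weights)
    with open('opt_state.pickle', 'wb') as f:
        f.write(pickle.dumps(fd))

    # Load, mirroring load_weights_safe above: unpickle and push the values back.
    with open('opt_state.pickle', 'rb') as f:
        fd = pickle.loads(f.read())
    weights = fd.get('weights', None)
    if weights is not None:
        opt.set_weights(weights)

Note that save_weights_safe still writes everything to filename + '.tmp' first and renames afterwards (the rename_list step that follows), presumably so a half-written file never replaces the last good checkpoint if the process dies mid-save.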