mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 13:02:15 -07:00
refactoring
This commit is contained in:
parent
97685ce0ae
commit
a858732b1d
6 changed files with 48 additions and 55 deletions
|
@ -286,9 +286,11 @@ class ModelBase(object):
|
|||
|
||||
def save_weights_safe(self, model_filename_list):
|
||||
for model, filename in model_filename_list:
|
||||
filename = self.get_strpath_storage_for_file(filename)
|
||||
model.save_weights( filename + '.tmp' )
|
||||
|
||||
for model, filename in model_filename_list:
|
||||
filename = self.get_strpath_storage_for_file(filename)
|
||||
source_filename = Path(filename+'.tmp')
|
||||
target_filename = Path(filename)
|
||||
if target_filename.exists():
|
||||
|
|
|
@ -8,10 +8,6 @@ from utils.console_utils import *
|
|||
|
||||
class Model(ModelBase):
|
||||
|
||||
encoderH5 = 'encoder.h5'
|
||||
decoder_srcH5 = 'decoder_src.h5'
|
||||
decoder_dstH5 = 'decoder_dst.h5'
|
||||
|
||||
#override
|
||||
def onInitializeOptions(self, is_first_run, ask_override):
|
||||
if is_first_run or ask_override:
|
||||
|
@ -31,9 +27,11 @@ class Model(ModelBase):
|
|||
self.encoder, self.decoder_src, self.decoder_dst = self.Build(ae_input_layer)
|
||||
|
||||
if not self.is_first_run():
|
||||
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
|
||||
self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5))
|
||||
self.decoder_dst.load_weights (self.get_strpath_storage_for_file(self.decoder_dstH5))
|
||||
weights_to_load = [ [self.encoder , 'encoder.h5'],
|
||||
[self.decoder_src, 'decoder_src.h5'],
|
||||
[self.decoder_dst, 'decoder_dst.h5']
|
||||
]
|
||||
self.load_weights_safe(weights_to_load)
|
||||
|
||||
self.autoencoder_src = Model([ae_input_layer,mask_layer], self.decoder_src(self.encoder(ae_input_layer)))
|
||||
self.autoencoder_dst = Model([ae_input_layer,mask_layer], self.decoder_dst(self.encoder(ae_input_layer)))
|
||||
|
@ -59,9 +57,9 @@ class Model(ModelBase):
|
|||
])
|
||||
#override
|
||||
def onSave(self):
|
||||
self.save_weights_safe( [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)],
|
||||
[self.decoder_src, self.get_strpath_storage_for_file(self.decoder_srcH5)],
|
||||
[self.decoder_dst, self.get_strpath_storage_for_file(self.decoder_dstH5)]] )
|
||||
self.save_weights_safe( [[self.encoder, 'encoder.h5'],
|
||||
[self.decoder_src, 'decoder_src.h5'],
|
||||
[self.decoder_dst, 'decoder_dst.h5']] )
|
||||
|
||||
#override
|
||||
def onTrainOneEpoch(self, sample, generators_list):
|
||||
|
|
|
@ -8,10 +8,6 @@ from utils.console_utils import *
|
|||
|
||||
class Model(ModelBase):
|
||||
|
||||
encoderH5 = 'encoder.h5'
|
||||
decoder_srcH5 = 'decoder_src.h5'
|
||||
decoder_dstH5 = 'decoder_dst.h5'
|
||||
|
||||
#override
|
||||
def onInitializeOptions(self, is_first_run, ask_override):
|
||||
if is_first_run:
|
||||
|
@ -35,9 +31,11 @@ class Model(ModelBase):
|
|||
|
||||
bgr_shape, mask_shape, self.encoder, self.decoder_src, self.decoder_dst = self.Build( self.options['lighter_ae'] )
|
||||
if not self.is_first_run():
|
||||
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
|
||||
self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5))
|
||||
self.decoder_dst.load_weights (self.get_strpath_storage_for_file(self.decoder_dstH5))
|
||||
weights_to_load = [ [self.encoder , 'encoder.h5'],
|
||||
[self.decoder_src, 'decoder_src.h5'],
|
||||
[self.decoder_dst, 'decoder_dst.h5']
|
||||
]
|
||||
self.load_weights_safe(weights_to_load)
|
||||
|
||||
input_src_bgr = Input(bgr_shape)
|
||||
input_src_mask = Input(mask_shape)
|
||||
|
@ -74,9 +72,9 @@ class Model(ModelBase):
|
|||
|
||||
#override
|
||||
def onSave(self):
|
||||
self.save_weights_safe( [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)],
|
||||
[self.decoder_src, self.get_strpath_storage_for_file(self.decoder_srcH5)],
|
||||
[self.decoder_dst, self.get_strpath_storage_for_file(self.decoder_dstH5)]])
|
||||
self.save_weights_safe( [[self.encoder, 'encoder.h5'],
|
||||
[self.decoder_src, 'decoder_src.h5'],
|
||||
[self.decoder_dst, 'decoder_dst.h5']] )
|
||||
|
||||
#override
|
||||
def onTrainOneEpoch(self, sample, generators_list):
|
||||
|
|
|
@ -8,10 +8,6 @@ from utils.console_utils import *
|
|||
|
||||
class Model(ModelBase):
|
||||
|
||||
encoderH5 = 'encoder.h5'
|
||||
decoder_srcH5 = 'decoder_src.h5'
|
||||
decoder_dstH5 = 'decoder_dst.h5'
|
||||
|
||||
#override
|
||||
def onInitializeOptions(self, is_first_run, ask_override):
|
||||
if is_first_run:
|
||||
|
@ -37,9 +33,11 @@ class Model(ModelBase):
|
|||
bgr_shape, mask_shape, self.encoder, self.decoder_src, self.decoder_dst = self.Build(self.options['lighter_ae'])
|
||||
|
||||
if not self.is_first_run():
|
||||
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
|
||||
self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5))
|
||||
self.decoder_dst.load_weights (self.get_strpath_storage_for_file(self.decoder_dstH5))
|
||||
weights_to_load = [ [self.encoder , 'encoder.h5'],
|
||||
[self.decoder_src, 'decoder_src.h5'],
|
||||
[self.decoder_dst, 'decoder_dst.h5']
|
||||
]
|
||||
self.load_weights_safe(weights_to_load)
|
||||
|
||||
input_src_bgr = Input(bgr_shape)
|
||||
input_src_mask = Input(mask_shape)
|
||||
|
@ -75,9 +73,9 @@ class Model(ModelBase):
|
|||
|
||||
#override
|
||||
def onSave(self):
|
||||
self.save_weights_safe( [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)],
|
||||
[self.decoder_src, self.get_strpath_storage_for_file(self.decoder_srcH5)],
|
||||
[self.decoder_dst, self.get_strpath_storage_for_file(self.decoder_dstH5)]] )
|
||||
self.save_weights_safe( [[self.encoder, 'encoder.h5'],
|
||||
[self.decoder_src, 'decoder_src.h5'],
|
||||
[self.decoder_dst, 'decoder_dst.h5']] )
|
||||
|
||||
#override
|
||||
def onTrainOneEpoch(self, sample, generators_list):
|
||||
|
|
|
@ -8,11 +8,6 @@ from utils.console_utils import *
|
|||
|
||||
class Model(ModelBase):
|
||||
|
||||
encoderH5 = 'encoder.h5'
|
||||
decoderH5 = 'decoder.h5'
|
||||
inter_BH5 = 'inter_B.h5'
|
||||
inter_ABH5 = 'inter_AB.h5'
|
||||
|
||||
#override
|
||||
def onInitializeOptions(self, is_first_run, ask_override):
|
||||
if is_first_run or ask_override:
|
||||
|
@ -32,10 +27,12 @@ class Model(ModelBase):
|
|||
self.encoder, self.decoder, self.inter_B, self.inter_AB = self.Build(ae_input_layer)
|
||||
|
||||
if not self.is_first_run():
|
||||
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
|
||||
self.decoder.load_weights (self.get_strpath_storage_for_file(self.decoderH5))
|
||||
self.inter_B.load_weights (self.get_strpath_storage_for_file(self.inter_BH5))
|
||||
self.inter_AB.load_weights (self.get_strpath_storage_for_file(self.inter_ABH5))
|
||||
weights_to_load = [ [self.encoder, 'encoder.h5'],
|
||||
[self.decoder, 'decoder.h5'],
|
||||
[self.inter_B, 'inter_B.h5'],
|
||||
[self.inter_AB, 'inter_AB.h5']
|
||||
]
|
||||
self.load_weights_safe(weights_to_load)
|
||||
|
||||
code = self.encoder(ae_input_layer)
|
||||
AB = self.inter_AB(code)
|
||||
|
@ -66,11 +63,11 @@ class Model(ModelBase):
|
|||
])
|
||||
|
||||
#override
|
||||
def onSave(self):
|
||||
self.save_weights_safe( [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)],
|
||||
[self.decoder, self.get_strpath_storage_for_file(self.decoderH5)],
|
||||
[self.inter_B, self.get_strpath_storage_for_file(self.inter_BH5)],
|
||||
[self.inter_AB, self.get_strpath_storage_for_file(self.inter_ABH5)]] )
|
||||
def onSave(self):
|
||||
self.save_weights_safe( [[self.encoder, 'encoder.h5'],
|
||||
[self.decoder, 'decoder.h5'],
|
||||
[self.inter_B, 'inter_B.h5'],
|
||||
[self.inter_AB, 'inter_AB.h5']] )
|
||||
|
||||
#override
|
||||
def onTrainOneEpoch(self, sample, generators_list):
|
||||
|
|
|
@ -349,21 +349,21 @@ class SAEModel(ModelBase):
|
|||
#override
|
||||
def onSave(self):
|
||||
if self.options['archi'] == 'liae':
|
||||
ar = [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)],
|
||||
[self.inter_B, self.get_strpath_storage_for_file(self.inter_BH5)],
|
||||
[self.inter_AB, self.get_strpath_storage_for_file(self.inter_ABH5)],
|
||||
[self.decoder, self.get_strpath_storage_for_file(self.decoderH5)]
|
||||
ar = [[self.encoder, 'encoder.h5'],
|
||||
[self.inter_B, 'inter_B.h5'],
|
||||
[self.inter_AB, 'inter_AB.h5'],
|
||||
[self.decoder, 'decoder.h5']
|
||||
]
|
||||
if self.options['learn_mask']:
|
||||
ar += [ [self.decoderm, self.get_strpath_storage_for_file(self.decodermH5)] ]
|
||||
else:
|
||||
ar = [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)],
|
||||
[self.decoder_src, self.get_strpath_storage_for_file(self.decoder_srcH5)],
|
||||
[self.decoder_dst, self.get_strpath_storage_for_file(self.decoder_dstH5)]
|
||||
ar += [ [self.decoderm, 'decoderm.h5'] ]
|
||||
elif self.options['archi'] == 'df' or self.options['archi'] == 'vg':
|
||||
ar = [[self.encoder, 'encoder.h5'],
|
||||
[self.decoder_src, 'decoder_src.h5'],
|
||||
[self.decoder_dst, 'decoder_dst.h5']
|
||||
]
|
||||
if self.options['learn_mask']:
|
||||
ar += [ [self.decoder_srcm, self.get_strpath_storage_for_file(self.decoder_srcmH5)],
|
||||
[self.decoder_dstm, self.get_strpath_storage_for_file(self.decoder_dstmH5)] ]
|
||||
ar += [ [self.decoder_srcm, 'decoder_srcm.h5'],
|
||||
[self.decoder_dstm, 'decoder_dstm.h5'] ]
|
||||
|
||||
self.save_weights_safe(ar)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue