mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 21:12:07 -07:00
refactoring
This commit is contained in:
parent
97685ce0ae
commit
a858732b1d
6 changed files with 48 additions and 55 deletions
|
@ -8,10 +8,6 @@ from utils.console_utils import *
|
|||
|
||||
class Model(ModelBase):
|
||||
|
||||
encoderH5 = 'encoder.h5'
|
||||
decoder_srcH5 = 'decoder_src.h5'
|
||||
decoder_dstH5 = 'decoder_dst.h5'
|
||||
|
||||
#override
|
||||
def onInitializeOptions(self, is_first_run, ask_override):
|
||||
if is_first_run:
|
||||
|
@ -35,9 +31,11 @@ class Model(ModelBase):
|
|||
|
||||
bgr_shape, mask_shape, self.encoder, self.decoder_src, self.decoder_dst = self.Build( self.options['lighter_ae'] )
|
||||
if not self.is_first_run():
|
||||
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
|
||||
self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5))
|
||||
self.decoder_dst.load_weights (self.get_strpath_storage_for_file(self.decoder_dstH5))
|
||||
weights_to_load = [ [self.encoder , 'encoder.h5'],
|
||||
[self.decoder_src, 'decoder_src.h5'],
|
||||
[self.decoder_dst, 'decoder_dst.h5']
|
||||
]
|
||||
self.load_weights_safe(weights_to_load)
|
||||
|
||||
input_src_bgr = Input(bgr_shape)
|
||||
input_src_mask = Input(mask_shape)
|
||||
|
@ -74,9 +72,9 @@ class Model(ModelBase):
|
|||
|
||||
#override
|
||||
def onSave(self):
|
||||
self.save_weights_safe( [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)],
|
||||
[self.decoder_src, self.get_strpath_storage_for_file(self.decoder_srcH5)],
|
||||
[self.decoder_dst, self.get_strpath_storage_for_file(self.decoder_dstH5)]])
|
||||
self.save_weights_safe( [[self.encoder, 'encoder.h5'],
|
||||
[self.decoder_src, 'decoder_src.h5'],
|
||||
[self.decoder_dst, 'decoder_dst.h5']] )
|
||||
|
||||
#override
|
||||
def onTrainOneEpoch(self, sample, generators_list):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue