mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 21:12:07 -07:00
added ability to save optimizers states which work with K.function,
added custom Adam that can save 'iterations' param
This commit is contained in:
parent
e50dc0d748
commit
e4637336ef
3 changed files with 133 additions and 26 deletions
|
@ -84,7 +84,6 @@ class SAEModel(ModelBase):
|
|||
def onInitialize(self):
|
||||
exec(nnlib.import_all(), locals(), globals())
|
||||
SAEModel.initialize_nn_functions()
|
||||
|
||||
self.set_vram_batch_requirements({1.5:4})
|
||||
|
||||
resolution = self.options['resolution']
|
||||
|
@ -111,6 +110,7 @@ class SAEModel(ModelBase):
|
|||
target_dst_ar = [ Input ( ( bgr_shape[0] // (2**i) ,)*2 + (bgr_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
|
||||
target_dstm_ar = [ Input ( ( mask_shape[0] // (2**i) ,)*2 + (mask_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
|
||||
|
||||
|
||||
weights_to_load = []
|
||||
if self.options['archi'] == 'liae':
|
||||
self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, self.options['lighter_encoder'], ed_ch_dims=ed_ch_dims) ) (Input(bgr_shape))
|
||||
|
@ -223,8 +223,13 @@ class SAEModel(ModelBase):
|
|||
pred_dst_dstm = self.decoder_dstm(warped_dst_code)
|
||||
pred_src_dstm = self.decoder_srcm(warped_dst_code)
|
||||
|
||||
self.load_weights_safe(weights_to_load)
|
||||
|
||||
self.src_dst_opt, \
|
||||
self.src_dst_mask_opt = self.load_weights_safe(
|
||||
weights_to_load,
|
||||
[ [Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), 'src_dst_opt'],
|
||||
[Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), 'src_dst_mask_opt']
|
||||
])
|
||||
|
||||
pred_src_src, pred_dst_dst, pred_src_dst, = [ [x] if type(x) != list else x for x in [pred_src_src, pred_dst_dst, pred_src_dst, ] ]
|
||||
|
||||
if self.options['learn_mask']:
|
||||
|
@ -262,9 +267,6 @@ class SAEModel(ModelBase):
|
|||
psd_target_dst_anti_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
|
||||
|
||||
if self.is_training_mode:
|
||||
def optimizer():
|
||||
return Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)
|
||||
|
||||
if self.options['archi'] == 'liae':
|
||||
src_dst_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
|
||||
if self.options['learn_mask']:
|
||||
|
@ -307,7 +309,7 @@ class SAEModel(ModelBase):
|
|||
feed += target_dst_ar[::-1]
|
||||
feed += target_dstm_ar[::-1]
|
||||
|
||||
self.src_dst_train = K.function (feed,[src_loss,dst_loss], optimizer().get_updates(src_loss+dst_loss, src_dst_loss_train_weights) )
|
||||
self.src_dst_train = K.function (feed,[src_loss,dst_loss], self.src_dst_opt.get_updates(src_loss+dst_loss, src_dst_loss_train_weights) )
|
||||
|
||||
if self.options['learn_mask']:
|
||||
src_mask_loss = sum([ K.mean(K.square(target_srcm_ar[-1]-pred_src_srcm[-1])) for i in range(len(target_srcm_ar)) ])
|
||||
|
@ -317,7 +319,7 @@ class SAEModel(ModelBase):
|
|||
feed += target_srcm_ar[::-1]
|
||||
feed += target_dstm_ar[::-1]
|
||||
|
||||
self.src_dst_mask_train = K.function (feed,[src_mask_loss, dst_mask_loss], optimizer().get_updates(src_mask_loss+dst_mask_loss, src_dst_mask_loss_train_weights) )
|
||||
self.src_dst_mask_train = K.function (feed,[src_mask_loss, dst_mask_loss], self.src_dst_mask_opt.get_updates(src_mask_loss+dst_mask_loss, src_dst_mask_loss_train_weights) )
|
||||
|
||||
if self.options['learn_mask']:
|
||||
self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src[-1], pred_dst_dst[-1], pred_src_dst[-1], pred_src_dstm[-1]])
|
||||
|
@ -353,8 +355,12 @@ class SAEModel(ModelBase):
|
|||
])
|
||||
#override
|
||||
def onSave(self):
|
||||
opt_ar = [ [self.src_dst_opt, 'src_dst_opt'],
|
||||
[self.src_dst_mask_opt, 'src_dst_mask_opt']
|
||||
]
|
||||
ar = []
|
||||
if self.options['archi'] == 'liae':
|
||||
ar = [[self.encoder, 'encoder.h5'],
|
||||
ar += [[self.encoder, 'encoder.h5'],
|
||||
[self.inter_B, 'inter_B.h5'],
|
||||
[self.inter_AB, 'inter_AB.h5'],
|
||||
[self.decoder, 'decoder.h5']
|
||||
|
@ -362,15 +368,15 @@ class SAEModel(ModelBase):
|
|||
if self.options['learn_mask']:
|
||||
ar += [ [self.decoderm, 'decoderm.h5'] ]
|
||||
elif self.options['archi'] == 'df' or self.options['archi'] == 'vg':
|
||||
ar = [[self.encoder, 'encoder.h5'],
|
||||
ar += [[self.encoder, 'encoder.h5'],
|
||||
[self.decoder_src, 'decoder_src.h5'],
|
||||
[self.decoder_dst, 'decoder_dst.h5']
|
||||
]
|
||||
if self.options['learn_mask']:
|
||||
ar += [ [self.decoder_srcm, 'decoder_srcm.h5'],
|
||||
[self.decoder_dstm, 'decoder_dstm.h5'] ]
|
||||
|
||||
self.save_weights_safe(ar)
|
||||
|
||||
self.save_weights_safe(ar, opt_ar)
|
||||
|
||||
|
||||
#override
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue