mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 21:12:07 -07:00
revert back Adam
This commit is contained in:
parent
e4637336ef
commit
ee8dbcbc35
3 changed files with 56 additions and 66 deletions
|
@ -223,13 +223,6 @@ class SAEModel(ModelBase):
|
|||
pred_dst_dstm = self.decoder_dstm(warped_dst_code)
|
||||
pred_src_dstm = self.decoder_srcm(warped_dst_code)
|
||||
|
||||
self.src_dst_opt, \
|
||||
self.src_dst_mask_opt = self.load_weights_safe(
|
||||
weights_to_load,
|
||||
[ [Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), 'src_dst_opt'],
|
||||
[Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), 'src_dst_mask_opt']
|
||||
])
|
||||
|
||||
pred_src_src, pred_dst_dst, pred_src_dst, = [ [x] if type(x) != list else x for x in [pred_src_src, pred_dst_dst, pred_src_dst, ] ]
|
||||
|
||||
if self.options['learn_mask']:
|
||||
|
@ -267,6 +260,9 @@ class SAEModel(ModelBase):
|
|||
psd_target_dst_anti_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
|
||||
|
||||
if self.is_training_mode:
|
||||
self.src_dst_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)
|
||||
self.src_dst_mask_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)
|
||||
|
||||
if self.options['archi'] == 'liae':
|
||||
src_dst_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
|
||||
if self.options['learn_mask']:
|
||||
|
@ -325,14 +321,17 @@ class SAEModel(ModelBase):
|
|||
self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src[-1], pred_dst_dst[-1], pred_src_dst[-1], pred_src_dstm[-1]])
|
||||
else:
|
||||
self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src[-1], pred_dst_dst[-1], pred_src_dst[-1] ] )
|
||||
|
||||
|
||||
self.load_weights_safe(weights_to_load)#, [ [self.src_dst_opt, 'src_dst_opt'], [self.src_dst_mask_opt, 'src_dst_mask_opt']])
|
||||
else:
|
||||
self.load_weights_safe(weights_to_load)
|
||||
if self.options['learn_mask']:
|
||||
self.AE_convert = K.function ([warped_dst],[ pred_src_dst[-1], pred_src_dstm[-1] ])
|
||||
else:
|
||||
self.AE_convert = K.function ([warped_dst],[ pred_src_dst[-1] ])
|
||||
|
||||
if self.is_training_mode:
|
||||
|
||||
|
||||
if self.is_training_mode:
|
||||
self.src_sample_losses = []
|
||||
self.dst_sample_losses = []
|
||||
|
||||
|
@ -353,6 +352,7 @@ class SAEModel(ModelBase):
|
|||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, normalize_tanh = True),
|
||||
output_sample_types=output_sample_types )
|
||||
])
|
||||
|
||||
#override
|
||||
def onSave(self):
|
||||
opt_ar = [ [self.src_dst_opt, 'src_dst_opt'],
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue