mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-22 06:23:20 -07:00
SAE: finally collapses are fixed
This commit is contained in:
parent
3c3cfa4500
commit
29fbb9f19f
1 changed files with 7 additions and 9 deletions
|
@ -248,9 +248,7 @@ class SAEModel(ModelBase):
|
||||||
|
|
||||||
psd_target_dst_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
|
psd_target_dst_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
|
||||||
psd_target_dst_anti_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
|
psd_target_dst_anti_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
|
||||||
|
|
||||||
alpha_rec = 10
|
|
||||||
|
|
||||||
if self.is_training_mode:
|
if self.is_training_mode:
|
||||||
self.src_dst_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)
|
self.src_dst_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)
|
||||||
self.src_dst_mask_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)
|
self.src_dst_mask_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)
|
||||||
|
@ -265,9 +263,9 @@ class SAEModel(ModelBase):
|
||||||
src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights + self.decoder_dstm.trainable_weights
|
src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights + self.decoder_dstm.trainable_weights
|
||||||
|
|
||||||
if not self.options['pixel_loss']:
|
if not self.options['pixel_loss']:
|
||||||
src_loss_batch = sum([ alpha_rec*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_src_masked_ar_opt[i], pred_src_src_masked_ar_opt[i]) for i in range(len(target_src_masked_ar_opt)) ])
|
src_loss_batch = sum([ 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_src_masked_ar_opt[i], pred_src_src_masked_ar_opt[i]) for i in range(len(target_src_masked_ar_opt)) ])
|
||||||
else:
|
else:
|
||||||
src_loss_batch = sum([ K.mean ( alpha_rec*K.square( target_src_masked_ar_opt[i] - pred_src_src_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_src_masked_ar_opt)) ])
|
src_loss_batch = sum([ K.mean ( 50*K.square( target_src_masked_ar_opt[i] - pred_src_src_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_src_masked_ar_opt)) ])
|
||||||
|
|
||||||
src_loss = K.mean(src_loss_batch)
|
src_loss = K.mean(src_loss_batch)
|
||||||
|
|
||||||
|
@ -279,15 +277,15 @@ class SAEModel(ModelBase):
|
||||||
bg_style_power = self.options['bg_style_power'] / 100.0
|
bg_style_power = self.options['bg_style_power'] / 100.0
|
||||||
if bg_style_power != 0:
|
if bg_style_power != 0:
|
||||||
if not self.options['pixel_loss']:
|
if not self.options['pixel_loss']:
|
||||||
bg_loss = K.mean( (alpha_rec*bg_style_power)*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] ))
|
bg_loss = K.mean( (10*bg_style_power)*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] ))
|
||||||
else:
|
else:
|
||||||
bg_loss = K.mean( (alpha_rec*bg_style_power)*K.square( psd_target_dst_anti_masked_ar[-1] - target_dst_anti_masked_ar[-1] ))
|
bg_loss = K.mean( (50*bg_style_power)*K.square( psd_target_dst_anti_masked_ar[-1] - target_dst_anti_masked_ar[-1] ))
|
||||||
src_loss += bg_loss
|
src_loss += bg_loss
|
||||||
|
|
||||||
if not self.options['pixel_loss']:
|
if not self.options['pixel_loss']:
|
||||||
dst_loss_batch = sum([ alpha_rec*dssim(kernel_size=int(resolution/11.6),max_value=1.0)(target_dst_masked_ar_opt[i], pred_dst_dst_masked_ar_opt[i]) for i in range(len(target_dst_masked_ar_opt)) ])
|
dst_loss_batch = sum([ 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)(target_dst_masked_ar_opt[i], pred_dst_dst_masked_ar_opt[i]) for i in range(len(target_dst_masked_ar_opt)) ])
|
||||||
else:
|
else:
|
||||||
dst_loss_batch = sum([ K.mean ( alpha_rec*K.square( target_dst_masked_ar_opt[i] - pred_dst_dst_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_dst_masked_ar_opt)) ])
|
dst_loss_batch = sum([ K.mean ( 50*K.square( target_dst_masked_ar_opt[i] - pred_dst_dst_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_dst_masked_ar_opt)) ])
|
||||||
|
|
||||||
dst_loss = K.mean(dst_loss_batch)
|
dst_loss = K.mean(dst_loss_batch)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue