mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-07 05:22:06 -07:00
SAE: forgot to remove normalizing from tanh
This commit is contained in:
parent
4ff88865ce
commit
d71a310fd7
1 changed files with 10 additions and 10 deletions
|
@ -267,19 +267,19 @@ class SAEModel(ModelBase):
|
||||||
pred_src_srcm, pred_dst_dstm, pred_src_dstm = [ [x] if type(x) != list else x for x in [pred_src_srcm, pred_dst_dstm, pred_src_dstm] ]
|
pred_src_srcm, pred_dst_dstm, pred_src_dstm = [ [x] if type(x) != list else x for x in [pred_src_srcm, pred_dst_dstm, pred_src_dstm] ]
|
||||||
|
|
||||||
target_srcm_blurred_ar = [ gaussian_blur( max(1, K.int_shape(x)[1] // 32) )(x) for x in target_srcm_ar]
|
target_srcm_blurred_ar = [ gaussian_blur( max(1, K.int_shape(x)[1] // 32) )(x) for x in target_srcm_ar]
|
||||||
target_srcm_sigm_ar = [ x / 2.0 + 0.5 for x in target_srcm_blurred_ar]
|
target_srcm_sigm_ar = target_srcm_blurred_ar #[ x / 2.0 + 0.5 for x in target_srcm_blurred_ar]
|
||||||
target_srcm_anti_sigm_ar = [ 1.0 - x for x in target_srcm_sigm_ar]
|
target_srcm_anti_sigm_ar = [ 1.0 - x for x in target_srcm_sigm_ar]
|
||||||
|
|
||||||
target_dstm_blurred_ar = [ gaussian_blur( max(1, K.int_shape(x)[1] // 32) )(x) for x in target_dstm_ar]
|
target_dstm_blurred_ar = [ gaussian_blur( max(1, K.int_shape(x)[1] // 32) )(x) for x in target_dstm_ar]
|
||||||
target_dstm_sigm_ar = [ x / 2.0 + 0.5 for x in target_dstm_blurred_ar]
|
target_dstm_sigm_ar = target_dstm_blurred_ar#[ x / 2.0 + 0.5 for x in target_dstm_blurred_ar]
|
||||||
target_dstm_anti_sigm_ar = [ 1.0 - x for x in target_dstm_sigm_ar]
|
target_dstm_anti_sigm_ar = [ 1.0 - x for x in target_dstm_sigm_ar]
|
||||||
|
|
||||||
target_src_sigm_ar = [ x + 1 for x in target_src_ar]
|
target_src_sigm_ar = target_src_ar#[ x + 1 for x in target_src_ar]
|
||||||
target_dst_sigm_ar = [ x + 1 for x in target_dst_ar]
|
target_dst_sigm_ar = target_dst_ar#[ x + 1 for x in target_dst_ar]
|
||||||
|
|
||||||
pred_src_src_sigm_ar = [ x + 1 for x in pred_src_src]
|
pred_src_src_sigm_ar = pred_src_src#[ x + 1 for x in pred_src_src]
|
||||||
pred_dst_dst_sigm_ar = [ x + 1 for x in pred_dst_dst]
|
pred_dst_dst_sigm_ar = pred_dst_dst#[ x + 1 for x in pred_dst_dst]
|
||||||
pred_src_dst_sigm_ar = [ x + 1 for x in pred_src_dst]
|
pred_src_dst_sigm_ar = pred_src_dst#[ x + 1 for x in pred_src_dst]
|
||||||
|
|
||||||
target_src_masked_ar = [ target_src_sigm_ar[i]*target_srcm_sigm_ar[i] for i in range(len(target_src_sigm_ar))]
|
target_src_masked_ar = [ target_src_sigm_ar[i]*target_srcm_sigm_ar[i] for i in range(len(target_src_sigm_ar))]
|
||||||
target_dst_masked_ar = [ target_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(target_dst_sigm_ar))]
|
target_dst_masked_ar = [ target_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(target_dst_sigm_ar))]
|
||||||
|
@ -311,7 +311,7 @@ class SAEModel(ModelBase):
|
||||||
src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights + self.decoder_dstm.trainable_weights
|
src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights + self.decoder_dstm.trainable_weights
|
||||||
|
|
||||||
if not self.options['pixel_loss']:
|
if not self.options['pixel_loss']:
|
||||||
src_loss_batch = sum([ ( 100*K.square( dssim(kernel_size=int(resolution/11.6),max_value=2.0)( target_src_masked_ar_opt[i], pred_src_src_masked_ar_opt[i] ) )) for i in range(len(target_src_masked_ar_opt)) ])
|
src_loss_batch = sum([ ( 100*K.square( dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_src_masked_ar_opt[i], pred_src_src_masked_ar_opt[i] ) )) for i in range(len(target_src_masked_ar_opt)) ])
|
||||||
else:
|
else:
|
||||||
src_loss_batch = sum([ K.mean ( 100*K.square( target_src_masked_ar_opt[i] - pred_src_src_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_src_masked_ar_opt)) ])
|
src_loss_batch = sum([ K.mean ( 100*K.square( target_src_masked_ar_opt[i] - pred_src_src_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_src_masked_ar_opt)) ])
|
||||||
|
|
||||||
|
@ -325,13 +325,13 @@ class SAEModel(ModelBase):
|
||||||
bg_style_power = self.options['bg_style_power'] / 100.0
|
bg_style_power = self.options['bg_style_power'] / 100.0
|
||||||
if bg_style_power != 0:
|
if bg_style_power != 0:
|
||||||
if not self.options['pixel_loss']:
|
if not self.options['pixel_loss']:
|
||||||
bg_loss = K.mean( (100*bg_style_power)*K.square(dssim(kernel_size=int(resolution/11.6),max_value=2.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] )))
|
bg_loss = K.mean( (100*bg_style_power)*K.square(dssim(kernel_size=int(resolution/11.6),max_value=1.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] )))
|
||||||
else:
|
else:
|
||||||
bg_loss = K.mean( (100*bg_style_power)*K.square( psd_target_dst_anti_masked_ar[-1] - target_dst_anti_masked_ar[-1] ))
|
bg_loss = K.mean( (100*bg_style_power)*K.square( psd_target_dst_anti_masked_ar[-1] - target_dst_anti_masked_ar[-1] ))
|
||||||
src_loss += bg_loss
|
src_loss += bg_loss
|
||||||
|
|
||||||
if not self.options['pixel_loss']:
|
if not self.options['pixel_loss']:
|
||||||
dst_loss_batch = sum([ ( 100*K.square(dssim(kernel_size=int(resolution/11.6),max_value=2.0)( target_dst_masked_ar_opt[i], pred_dst_dst_masked_ar_opt[i] ) )) for i in range(len(target_dst_masked_ar_opt)) ])
|
dst_loss_batch = sum([ ( 100*K.square(dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_dst_masked_ar_opt[i], pred_dst_dst_masked_ar_opt[i] ) )) for i in range(len(target_dst_masked_ar_opt)) ])
|
||||||
else:
|
else:
|
||||||
dst_loss_batch = sum([ K.mean ( 100*K.square( target_dst_masked_ar_opt[i] - pred_dst_dst_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_dst_masked_ar_opt)) ])
|
dst_loss_batch = sum([ K.mean ( 100*K.square( target_dst_masked_ar_opt[i] - pred_dst_dst_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_dst_masked_ar_opt)) ])
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue