mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 13:02:15 -07:00
fix for plaidml
This commit is contained in:
parent
ac7725163d
commit
627df082d7
2 changed files with 7 additions and 5 deletions
|
@ -483,7 +483,8 @@ class SAEv2Model(ModelBase):
|
|||
dst_loss = K.mean( 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)(target_dst_masked_opt, pred_dst_dst_masked_opt) )
|
||||
dst_loss += K.mean( 10*K.square( target_dst_masked_opt - pred_dst_dst_masked_opt ) )
|
||||
|
||||
opt_D_loss = []
|
||||
G_loss = src_loss+dst_loss
|
||||
|
||||
if self.true_face_training:
|
||||
def DLoss(labels,logits):
|
||||
return K.mean(K.binary_crossentropy(labels,logits))
|
||||
|
@ -493,7 +494,7 @@ class SAEv2Model(ModelBase):
|
|||
src_code_d_zeros = K.zeros_like(src_code_d)
|
||||
dst_code_d = self.dis( self.model.dst_code )
|
||||
dst_code_d_ones = K.ones_like(dst_code_d)
|
||||
opt_D_loss = [ 0.01*DLoss(src_code_d_ones, src_code_d) ]
|
||||
G_loss += 0.01*DLoss(src_code_d_ones, src_code_d)
|
||||
|
||||
loss_D = (DLoss(dst_code_d_ones , dst_code_d) + \
|
||||
DLoss(src_code_d_zeros, src_code_d) ) * 0.5
|
||||
|
@ -502,7 +503,7 @@ class SAEv2Model(ModelBase):
|
|||
|
||||
self.src_dst_train = K.function ([self.model.warped_src, self.model.warped_dst, self.model.target_src, self.model.target_srcm, self.model.target_dst, self.model.target_dstm],
|
||||
[src_loss,dst_loss],
|
||||
self.src_dst_opt.get_updates( [src_loss+dst_loss]+opt_D_loss, self.model.src_dst_trainable_weights)
|
||||
self.src_dst_opt.get_updates( G_loss, self.model.src_dst_trainable_weights)
|
||||
)
|
||||
|
||||
if self.options['learn_mask']:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue