fix for plaidml

This commit is contained in:
Colombo 2019-10-08 16:55:21 +04:00
parent ac7725163d
commit 627df082d7
2 changed files with 7 additions and 5 deletions

View file

@ -483,7 +483,8 @@ class SAEv2Model(ModelBase):
dst_loss = K.mean( 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)(target_dst_masked_opt, pred_dst_dst_masked_opt) )
dst_loss += K.mean( 10*K.square( target_dst_masked_opt - pred_dst_dst_masked_opt ) )
opt_D_loss = []
G_loss = src_loss+dst_loss
if self.true_face_training:
def DLoss(labels,logits):
return K.mean(K.binary_crossentropy(labels,logits))
@ -493,7 +494,7 @@ class SAEv2Model(ModelBase):
src_code_d_zeros = K.zeros_like(src_code_d)
dst_code_d = self.dis( self.model.dst_code )
dst_code_d_ones = K.ones_like(dst_code_d)
opt_D_loss = [ 0.01*DLoss(src_code_d_ones, src_code_d) ]
G_loss += 0.01*DLoss(src_code_d_ones, src_code_d)
loss_D = (DLoss(dst_code_d_ones , dst_code_d) + \
DLoss(src_code_d_zeros, src_code_d) ) * 0.5
@ -502,7 +503,7 @@ class SAEv2Model(ModelBase):
self.src_dst_train = K.function ([self.model.warped_src, self.model.warped_dst, self.model.target_src, self.model.target_srcm, self.model.target_dst, self.model.target_dstm],
[src_loss,dst_loss],
self.src_dst_opt.get_updates( [src_loss+dst_loss]+opt_D_loss, self.model.src_dst_trainable_weights)
self.src_dst_opt.get_updates( G_loss, self.model.src_dst_trainable_weights)
)
if self.options['learn_mask']: