fix for plaidml

This commit is contained in:
Colombo 2019-10-08 16:55:21 +04:00
parent ac7725163d
commit 627df082d7
2 changed files with 7 additions and 5 deletions

View file

@ -483,7 +483,8 @@ class SAEv2Model(ModelBase):
dst_loss = K.mean( 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)(target_dst_masked_opt, pred_dst_dst_masked_opt) ) dst_loss = K.mean( 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)(target_dst_masked_opt, pred_dst_dst_masked_opt) )
dst_loss += K.mean( 10*K.square( target_dst_masked_opt - pred_dst_dst_masked_opt ) ) dst_loss += K.mean( 10*K.square( target_dst_masked_opt - pred_dst_dst_masked_opt ) )
opt_D_loss = [] G_loss = src_loss+dst_loss
if self.true_face_training: if self.true_face_training:
def DLoss(labels,logits): def DLoss(labels,logits):
return K.mean(K.binary_crossentropy(labels,logits)) return K.mean(K.binary_crossentropy(labels,logits))
@ -493,7 +494,7 @@ class SAEv2Model(ModelBase):
src_code_d_zeros = K.zeros_like(src_code_d) src_code_d_zeros = K.zeros_like(src_code_d)
dst_code_d = self.dis( self.model.dst_code ) dst_code_d = self.dis( self.model.dst_code )
dst_code_d_ones = K.ones_like(dst_code_d) dst_code_d_ones = K.ones_like(dst_code_d)
opt_D_loss = [ 0.01*DLoss(src_code_d_ones, src_code_d) ] G_loss += 0.01*DLoss(src_code_d_ones, src_code_d)
loss_D = (DLoss(dst_code_d_ones , dst_code_d) + \ loss_D = (DLoss(dst_code_d_ones , dst_code_d) + \
DLoss(src_code_d_zeros, src_code_d) ) * 0.5 DLoss(src_code_d_zeros, src_code_d) ) * 0.5
@ -502,7 +503,7 @@ class SAEv2Model(ModelBase):
self.src_dst_train = K.function ([self.model.warped_src, self.model.warped_dst, self.model.target_src, self.model.target_srcm, self.model.target_dst, self.model.target_dstm], self.src_dst_train = K.function ([self.model.warped_src, self.model.warped_dst, self.model.target_src, self.model.target_srcm, self.model.target_dst, self.model.target_dstm],
[src_loss,dst_loss], [src_loss,dst_loss],
self.src_dst_opt.get_updates( [src_loss+dst_loss]+opt_D_loss, self.model.src_dst_trainable_weights) self.src_dst_opt.get_updates( G_loss, self.model.src_dst_trainable_weights)
) )
if self.options['learn_mask']: if self.options['learn_mask']:

View file

@ -232,6 +232,7 @@ class device:
plaidML_build = os.environ.get("DFL_PLAIDML_BUILD", "0") == "1" plaidML_build = os.environ.get("DFL_PLAIDML_BUILD", "0") == "1"
plaidML_devices = None plaidML_devices = None
plaidML_devices_count = 0
cuda_devices = None cuda_devices = None
if plaidML_build: if plaidML_build:
@ -253,8 +254,8 @@ if plaidML_build:
ctx.shutdown() ctx.shutdown()
except: except:
pass pass
plaidML_devices_count = len(plaidML_devices)
if len(plaidML_devices) != 0: if plaidML_devices_count != 0:
device.backend = "plaidML" device.backend = "plaidML"
else: else:
if cuda_devices is None: if cuda_devices is None: