diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py
index bd28693..1a1925f 100644
--- a/models/Model_SAEHD/Model.py
+++ b/models/Model_SAEHD/Model.py
@@ -483,7 +483,8 @@ class SAEv2Model(ModelBase):
             dst_loss = K.mean( 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)(target_dst_masked_opt, pred_dst_dst_masked_opt) )
             dst_loss += K.mean( 10*K.square( target_dst_masked_opt - pred_dst_dst_masked_opt ) )
 
-            opt_D_loss = []
+            G_loss = src_loss+dst_loss
+
             if self.true_face_training:
                 def DLoss(labels,logits):
                     return K.mean(K.binary_crossentropy(labels,logits))
@@ -493,7 +494,7 @@ class SAEv2Model(ModelBase):
                 src_code_d_zeros = K.zeros_like(src_code_d)
                 dst_code_d = self.dis( self.model.dst_code )
                 dst_code_d_ones = K.ones_like(dst_code_d)
-                opt_D_loss = [ 0.01*DLoss(src_code_d_ones, src_code_d) ]
+                G_loss += 0.01*DLoss(src_code_d_ones, src_code_d)
 
                 loss_D = (DLoss(dst_code_d_ones , dst_code_d) + \
                           DLoss(src_code_d_zeros, src_code_d) ) * 0.5
@@ -502,7 +503,7 @@ class SAEv2Model(ModelBase):
 
             self.src_dst_train = K.function ([self.model.warped_src, self.model.warped_dst, self.model.target_src, self.model.target_srcm, self.model.target_dst, self.model.target_dstm],
                                              [src_loss,dst_loss],
-                                             self.src_dst_opt.get_updates( [src_loss+dst_loss]+opt_D_loss, self.model.src_dst_trainable_weights)
+                                             self.src_dst_opt.get_updates( G_loss, self.model.src_dst_trainable_weights)
                                              )
 
             if self.options['learn_mask']:
diff --git a/nnlib/device.py b/nnlib/device.py
index 82b2883..306c777 100644
--- a/nnlib/device.py
+++ b/nnlib/device.py
@@ -232,6 +232,7 @@ class device:
 
 plaidML_build = os.environ.get("DFL_PLAIDML_BUILD", "0") == "1"
 plaidML_devices = None
+plaidML_devices_count = 0
 cuda_devices = None
 
 if plaidML_build:
@@ -253,8 +254,8 @@ if plaidML_build:
         ctx.shutdown()
     except:
         pass
-
-    if len(plaidML_devices) != 0:
+    plaidML_devices_count = len(plaidML_devices)
+    if plaidML_devices_count != 0:
         device.backend = "plaidML"
 else:
     if cuda_devices is None:
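
Note on the Model.py change: the separate opt_D_loss list is folded into a single generator-loss scalar, G_loss, which is what src_dst_opt.get_updates now receives. The sketch below is a minimal illustration of that pattern, not the project's code; names such as dis, src_code, true_face_training and the helper build_G_loss are illustrative placeholders, while DLoss and the 0.01 weight come from the patch itself.

    import keras.backend as K

    def DLoss(labels, logits):
        # mean binary cross-entropy, as defined in the patch
        return K.mean(K.binary_crossentropy(labels, logits))

    def build_G_loss(src_loss, dst_loss, dis=None, src_code=None, true_face_training=False):
        # Base reconstruction loss for both directions.
        G_loss = src_loss + dst_loss
        if true_face_training and dis is not None:
            # The generator is pushed to make the discriminator label the
            # src latent code as "real" (all ones), weighted by 0.01.
            src_code_d = dis(src_code)
            G_loss += 0.01 * DLoss(K.ones_like(src_code_d), src_code_d)
        return G_loss

    # The optimizer then takes the single scalar, e.g. (hypothetical names):
    # updates = src_dst_opt.get_updates(G_loss, src_dst_trainable_weights)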