diff --git a/facelib/FANSegmentator.py b/facelib/FANSegmentator.py index 1bdc2f3..21b4f07 100644 --- a/facelib/FANSegmentator.py +++ b/facelib/FANSegmentator.py @@ -47,8 +47,8 @@ class FANSegmentator(object): self.model.get_layer (s).set_weights ( d[s] ) except: io.log_err("Unable to load VGG11 pretrained weights from vgg11_enc_weights.npy") - - conv_weights_list = [] + + conv_weights_list = [] for layer in self.model.layers: if 'CA.' in layer.name: conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights @@ -58,8 +58,8 @@ class FANSegmentator(object): inp_t = Input ( (resolution, resolution, 3) ) real_t = Input ( (resolution, resolution, 1) ) out_t = self.model(inp_t) - - #loss = K.mean(10*K.square(out_t-real_t)) + + #loss = K.mean(10*K.square(out_t-real_t)) loss = K.mean(10*K.binary_crossentropy(real_t,out_t) ) out_t_diff1 = out_t[:, 1:, :, :] - out_t[:, :-1, :, :] @@ -68,10 +68,10 @@ class FANSegmentator(object): total_var_loss = K.mean( 0.1*K.abs(out_t_diff1), axis=[1, 2, 3] ) + K.mean( 0.1*K.abs(out_t_diff2), axis=[1, 2, 3] ) opt = Adam(lr=0.0001, beta_1=0.5, beta_2=0.999, tf_cpu_mode=2) - + self.train_func = K.function ( [inp_t, real_t], [K.mean(loss)], opt.get_updates( [loss,total_var_loss], self.model.trainable_weights) ) - + def __enter__(self): return self @@ -118,19 +118,19 @@ class FANSegmentator(object): x = BlurPool(filt_size=3)(x) #x = MaxPooling2D()(x) x1 = x = Conv2D(ngf*2, kernel_size=3, strides=1, padding='same', activation='relu', name='features.3')(x) - x = BlurPool(filt_size=3)(x) + x = BlurPool(filt_size=3)(x) x = Conv2D(ngf*4, kernel_size=3, strides=1, padding='same', activation='relu', name='features.6')(x) x2 = x = Conv2D(ngf*4, kernel_size=3, strides=1, padding='same', activation='relu', name='features.8')(x) - x = BlurPool(filt_size=3)(x) + x = BlurPool(filt_size=3)(x) x = Conv2D(ngf*8, kernel_size=3, strides=1, padding='same', activation='relu', name='features.11')(x) x3 = x = Conv2D(ngf*8, kernel_size=3, strides=1, padding='same', activation='relu', name='features.13')(x) - x = BlurPool(filt_size=3)(x) + x = BlurPool(filt_size=3)(x) x = Conv2D(ngf*8, kernel_size=3, strides=1, padding='same', activation='relu', name='features.16')(x) x4 = x = Conv2D(ngf*8, kernel_size=3, strides=1, padding='same', activation='relu', name='features.18')(x) - x = BlurPool(filt_size=3)(x) + x = BlurPool(filt_size=3)(x) x = Conv2D(ngf*8, kernel_size=3, strides=1, padding='same', name='CA.1')(x)