This commit is contained in:
Colombo 2019-09-19 11:16:35 +04:00
parent dc11ec32be
commit c06d073936

View file

@ -47,8 +47,8 @@ class FANSegmentator(object):
self.model.get_layer (s).set_weights ( d[s] )
except:
io.log_err("Unable to load VGG11 pretrained weights from vgg11_enc_weights.npy")
conv_weights_list = []
conv_weights_list = []
for layer in self.model.layers:
if 'CA.' in layer.name:
conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights
@ -58,8 +58,8 @@ class FANSegmentator(object):
inp_t = Input ( (resolution, resolution, 3) )
real_t = Input ( (resolution, resolution, 1) )
out_t = self.model(inp_t)
#loss = K.mean(10*K.square(out_t-real_t))
#loss = K.mean(10*K.square(out_t-real_t))
loss = K.mean(10*K.binary_crossentropy(real_t,out_t) )
out_t_diff1 = out_t[:, 1:, :, :] - out_t[:, :-1, :, :]
@ -68,10 +68,10 @@ class FANSegmentator(object):
total_var_loss = K.mean( 0.1*K.abs(out_t_diff1), axis=[1, 2, 3] ) + K.mean( 0.1*K.abs(out_t_diff2), axis=[1, 2, 3] )
opt = Adam(lr=0.0001, beta_1=0.5, beta_2=0.999, tf_cpu_mode=2)
self.train_func = K.function ( [inp_t, real_t], [K.mean(loss)], opt.get_updates( [loss,total_var_loss], self.model.trainable_weights) )
def __enter__(self):
return self
@ -118,19 +118,19 @@ class FANSegmentator(object):
x = BlurPool(filt_size=3)(x) #x = MaxPooling2D()(x)
x1 = x = Conv2D(ngf*2, kernel_size=3, strides=1, padding='same', activation='relu', name='features.3')(x)
x = BlurPool(filt_size=3)(x)
x = BlurPool(filt_size=3)(x)
x = Conv2D(ngf*4, kernel_size=3, strides=1, padding='same', activation='relu', name='features.6')(x)
x2 = x = Conv2D(ngf*4, kernel_size=3, strides=1, padding='same', activation='relu', name='features.8')(x)
x = BlurPool(filt_size=3)(x)
x = BlurPool(filt_size=3)(x)
x = Conv2D(ngf*8, kernel_size=3, strides=1, padding='same', activation='relu', name='features.11')(x)
x3 = x = Conv2D(ngf*8, kernel_size=3, strides=1, padding='same', activation='relu', name='features.13')(x)
x = BlurPool(filt_size=3)(x)
x = BlurPool(filt_size=3)(x)
x = Conv2D(ngf*8, kernel_size=3, strides=1, padding='same', activation='relu', name='features.16')(x)
x4 = x = Conv2D(ngf*8, kernel_size=3, strides=1, padding='same', activation='relu', name='features.18')(x)
x = BlurPool(filt_size=3)(x)
x = BlurPool(filt_size=3)(x)
x = Conv2D(ngf*8, kernel_size=3, strides=1, padding='same', name='CA.1')(x)