mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 04:52:13 -07:00
1
This commit is contained in:
parent
dc11ec32be
commit
c06d073936
1 changed files with 10 additions and 10 deletions
|
@ -47,8 +47,8 @@ class FANSegmentator(object):
|
|||
self.model.get_layer (s).set_weights ( d[s] )
|
||||
except:
|
||||
io.log_err("Unable to load VGG11 pretrained weights from vgg11_enc_weights.npy")
|
||||
|
||||
conv_weights_list = []
|
||||
|
||||
conv_weights_list = []
|
||||
for layer in self.model.layers:
|
||||
if 'CA.' in layer.name:
|
||||
conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights
|
||||
|
@ -58,8 +58,8 @@ class FANSegmentator(object):
|
|||
inp_t = Input ( (resolution, resolution, 3) )
|
||||
real_t = Input ( (resolution, resolution, 1) )
|
||||
out_t = self.model(inp_t)
|
||||
|
||||
#loss = K.mean(10*K.square(out_t-real_t))
|
||||
|
||||
#loss = K.mean(10*K.square(out_t-real_t))
|
||||
loss = K.mean(10*K.binary_crossentropy(real_t,out_t) )
|
||||
|
||||
out_t_diff1 = out_t[:, 1:, :, :] - out_t[:, :-1, :, :]
|
||||
|
@ -68,10 +68,10 @@ class FANSegmentator(object):
|
|||
total_var_loss = K.mean( 0.1*K.abs(out_t_diff1), axis=[1, 2, 3] ) + K.mean( 0.1*K.abs(out_t_diff2), axis=[1, 2, 3] )
|
||||
|
||||
opt = Adam(lr=0.0001, beta_1=0.5, beta_2=0.999, tf_cpu_mode=2)
|
||||
|
||||
|
||||
self.train_func = K.function ( [inp_t, real_t], [K.mean(loss)], opt.get_updates( [loss,total_var_loss], self.model.trainable_weights) )
|
||||
|
||||
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
|
@ -118,19 +118,19 @@ class FANSegmentator(object):
|
|||
x = BlurPool(filt_size=3)(x) #x = MaxPooling2D()(x)
|
||||
|
||||
x1 = x = Conv2D(ngf*2, kernel_size=3, strides=1, padding='same', activation='relu', name='features.3')(x)
|
||||
x = BlurPool(filt_size=3)(x)
|
||||
x = BlurPool(filt_size=3)(x)
|
||||
|
||||
x = Conv2D(ngf*4, kernel_size=3, strides=1, padding='same', activation='relu', name='features.6')(x)
|
||||
x2 = x = Conv2D(ngf*4, kernel_size=3, strides=1, padding='same', activation='relu', name='features.8')(x)
|
||||
x = BlurPool(filt_size=3)(x)
|
||||
x = BlurPool(filt_size=3)(x)
|
||||
|
||||
x = Conv2D(ngf*8, kernel_size=3, strides=1, padding='same', activation='relu', name='features.11')(x)
|
||||
x3 = x = Conv2D(ngf*8, kernel_size=3, strides=1, padding='same', activation='relu', name='features.13')(x)
|
||||
x = BlurPool(filt_size=3)(x)
|
||||
x = BlurPool(filt_size=3)(x)
|
||||
|
||||
x = Conv2D(ngf*8, kernel_size=3, strides=1, padding='same', activation='relu', name='features.16')(x)
|
||||
x4 = x = Conv2D(ngf*8, kernel_size=3, strides=1, padding='same', activation='relu', name='features.18')(x)
|
||||
x = BlurPool(filt_size=3)(x)
|
||||
x = BlurPool(filt_size=3)(x)
|
||||
|
||||
x = Conv2D(ngf*8, kernel_size=3, strides=1, padding='same', name='CA.1')(x)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue