mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-07 21:42:08 -07:00
Converter:
Session is now saved to the model folder. blur and erode ranges are increased to -400+400 hist-match-bw is now replaced with seamless2 mode. Added 'ebs' color transfer mode (works only on Windows). FANSEG model (used in FAN-x mask modes) is retrained with new model configuration and now produces better precision and less jitter
This commit is contained in:
parent
70dada42ea
commit
7ed38a8097
29 changed files with 768 additions and 314 deletions
|
@ -47,10 +47,30 @@ class FANSegmentator(object):
|
|||
self.model.get_layer (s).set_weights ( d[s] )
|
||||
except:
|
||||
io.log_err("Unable to load VGG11 pretrained weights from vgg11_enc_weights.npy")
|
||||
|
||||
conv_weights_list = []
|
||||
for layer in self.model.layers:
|
||||
if 'CA.' in layer.name:
|
||||
conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights
|
||||
CAInitializerMP ( conv_weights_list )
|
||||
|
||||
if training:
|
||||
#self.model.compile(loss='mse', optimizer=Adam(tf_cpu_mode=2))
|
||||
self.model.compile(loss='binary_crossentropy', optimizer=Adam(tf_cpu_mode=2) )
|
||||
inp_t = Input ( (resolution, resolution, 3) )
|
||||
real_t = Input ( (resolution, resolution, 1) )
|
||||
out_t = self.model(inp_t)
|
||||
|
||||
#loss = K.mean(10*K.square(out_t-real_t))
|
||||
loss = K.mean(10*K.binary_crossentropy(real_t,out_t) )
|
||||
|
||||
out_t_diff1 = out_t[:, 1:, :, :] - out_t[:, :-1, :, :]
|
||||
out_t_diff2 = out_t[:, :, 1:, :] - out_t[:, :, :-1, :]
|
||||
|
||||
total_var_loss = K.mean( 0.1*K.abs(out_t_diff1), axis=[1, 2, 3] ) + K.mean( 0.1*K.abs(out_t_diff2), axis=[1, 2, 3] )
|
||||
|
||||
opt = Adam(lr=0.0001, beta_1=0.5, beta_2=0.999, tf_cpu_mode=2)
|
||||
|
||||
self.train_func = K.function ( [inp_t, real_t], [K.mean(loss)], opt.get_updates( [loss,total_var_loss], self.model.trainable_weights) )
|
||||
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
@ -61,8 +81,9 @@ class FANSegmentator(object):
|
|||
def save_weights(self):
|
||||
self.model.save_weights (str(self.weights_path))
|
||||
|
||||
def train_on_batch(self, inp, outp):
|
||||
return self.model.train_on_batch(inp, outp)
|
||||
def train(self, inp, real):
|
||||
loss, = self.train_func ([inp, real])
|
||||
return loss
|
||||
|
||||
def extract (self, input_image, is_input_tanh=False):
|
||||
input_shape_len = len(input_image.shape)
|
||||
|
@ -78,62 +99,62 @@ class FANSegmentator(object):
|
|||
return result
|
||||
|
||||
@staticmethod
|
||||
def BuildModel ( resolution, ngf=64, norm='', act='lrelu'):
|
||||
def BuildModel ( resolution, ngf=64):
|
||||
exec( nnlib.import_all(), locals(), globals() )
|
||||
inp = Input ( (resolution,resolution,3) )
|
||||
x = inp
|
||||
x = FANSegmentator.Flow(ngf=ngf, norm=norm, act=act)(x)
|
||||
x = FANSegmentator.Flow(ngf=ngf)(x)
|
||||
model = Model(inp,x)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def Flow(ngf=64, num_downs=4, norm='', act='lrelu'):
|
||||
def Flow(ngf=64):
|
||||
exec( nnlib.import_all(), locals(), globals() )
|
||||
|
||||
def func(input):
|
||||
x = input
|
||||
|
||||
x0 = x = Conv2D(ngf, kernel_size=3, strides=1, padding='same', activation='relu', name='features.0')(x)
|
||||
x = MaxPooling2D()(x)
|
||||
x = BlurPool(filt_size=3)(x) #x = MaxPooling2D()(x)
|
||||
|
||||
x1 = x = Conv2D(ngf*2, kernel_size=3, strides=1, padding='same', activation='relu', name='features.3')(x)
|
||||
x = MaxPooling2D()(x)
|
||||
x = BlurPool(filt_size=3)(x)
|
||||
|
||||
x = Conv2D(ngf*4, kernel_size=3, strides=1, padding='same', activation='relu', name='features.6')(x)
|
||||
x2 = x = Conv2D(ngf*4, kernel_size=3, strides=1, padding='same', activation='relu', name='features.8')(x)
|
||||
x = MaxPooling2D()(x)
|
||||
x = BlurPool(filt_size=3)(x)
|
||||
|
||||
x = Conv2D(ngf*8, kernel_size=3, strides=1, padding='same', activation='relu', name='features.11')(x)
|
||||
x3 = x = Conv2D(ngf*8, kernel_size=3, strides=1, padding='same', activation='relu', name='features.13')(x)
|
||||
x = MaxPooling2D()(x)
|
||||
x = BlurPool(filt_size=3)(x)
|
||||
|
||||
x = Conv2D(ngf*8, kernel_size=3, strides=1, padding='same', activation='relu', name='features.16')(x)
|
||||
x4 = x = Conv2D(ngf*8, kernel_size=3, strides=1, padding='same', activation='relu', name='features.18')(x)
|
||||
x = MaxPooling2D()(x)
|
||||
x = BlurPool(filt_size=3)(x)
|
||||
|
||||
x = Conv2D(ngf*8, kernel_size=3, strides=1, padding='same')(x)
|
||||
x = Conv2D(ngf*8, kernel_size=3, strides=1, padding='same', name='CA.1')(x)
|
||||
|
||||
x = Conv2DTranspose (ngf*4, 3, strides=2, padding='same', activation='relu') (x)
|
||||
x = Conv2DTranspose (ngf*4, 3, strides=2, padding='same', activation='relu', name='CA.2') (x)
|
||||
x = Concatenate(axis=3)([ x, x4])
|
||||
x = Conv2D (ngf*8, 3, strides=1, padding='same', activation='relu') (x)
|
||||
x = Conv2D (ngf*8, 3, strides=1, padding='same', activation='relu', name='CA.3') (x)
|
||||
|
||||
x = Conv2DTranspose (ngf*4, 3, strides=2, padding='same', activation='relu') (x)
|
||||
x = Conv2DTranspose (ngf*4, 3, strides=2, padding='same', activation='relu', name='CA.4') (x)
|
||||
x = Concatenate(axis=3)([ x, x3])
|
||||
x = Conv2D (ngf*8, 3, strides=1, padding='same', activation='relu') (x)
|
||||
x = Conv2D (ngf*8, 3, strides=1, padding='same', activation='relu', name='CA.5') (x)
|
||||
|
||||
x = Conv2DTranspose (ngf*2, 3, strides=2, padding='same', activation='relu') (x)
|
||||
x = Conv2DTranspose (ngf*2, 3, strides=2, padding='same', activation='relu', name='CA.6') (x)
|
||||
x = Concatenate(axis=3)([ x, x2])
|
||||
x = Conv2D (ngf*4, 3, strides=1, padding='same', activation='relu') (x)
|
||||
x = Conv2D (ngf*4, 3, strides=1, padding='same', activation='relu', name='CA.7') (x)
|
||||
|
||||
x = Conv2DTranspose (ngf, 3, strides=2, padding='same', activation='relu') (x)
|
||||
x = Conv2DTranspose (ngf, 3, strides=2, padding='same', activation='relu', name='CA.8') (x)
|
||||
x = Concatenate(axis=3)([ x, x1])
|
||||
x = Conv2D (ngf*2, 3, strides=1, padding='same', activation='relu') (x)
|
||||
x = Conv2D (ngf*2, 3, strides=1, padding='same', activation='relu', name='CA.9') (x)
|
||||
|
||||
x = Conv2DTranspose (ngf // 2, 3, strides=2, padding='same', activation='relu') (x)
|
||||
x = Conv2DTranspose (ngf // 2, 3, strides=2, padding='same', activation='relu', name='CA.10') (x)
|
||||
x = Concatenate(axis=3)([ x, x0])
|
||||
x = Conv2D (ngf, 3, strides=1, padding='same', activation='relu') (x)
|
||||
x = Conv2D (ngf, 3, strides=1, padding='same', activation='relu', name='CA.11') (x)
|
||||
|
||||
return Conv2D(1, 3, strides=1, padding='same', activation='sigmoid')(x)
|
||||
return Conv2D(1, 3, strides=1, padding='same', activation='sigmoid', name='CA.12')(x)
|
||||
|
||||
|
||||
return func
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue