mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-07 13:32:09 -07:00
refactoring. Added RecycleGAN for testing.
This commit is contained in:
parent
8686309417
commit
f8824f9601
24 changed files with 1661 additions and 1505 deletions
|
@ -1,10 +1,7 @@
|
|||
from models import ModelBase
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from nnlib import DSSIMMaskLossClass
|
||||
from nnlib import conv
|
||||
from nnlib import upscale
|
||||
from nnlib import nnlib
|
||||
from models import ModelBase
|
||||
from facelib import FaceType
|
||||
from samples import *
|
||||
|
||||
|
@ -17,16 +14,14 @@ class Model(ModelBase):
|
|||
|
||||
#override
|
||||
def onInitialize(self, **in_options):
|
||||
exec(nnlib.import_all(), locals(), globals())
|
||||
self.set_vram_batch_requirements( {4.5:4,5:4,6:8,7:12,8:16,9:20,10:24,11:24,12:32,13:48} )
|
||||
|
||||
ae_input_layer = self.keras.layers.Input(shape=(128, 128, 3))
|
||||
mask_layer = self.keras.layers.Input(shape=(128, 128, 1)) #same as output
|
||||
ae_input_layer = Input(shape=(128, 128, 3))
|
||||
mask_layer = Input(shape=(128, 128, 1)) #same as output
|
||||
|
||||
self.encoder = self.Encoder(ae_input_layer)
|
||||
self.decoder = self.Decoder()
|
||||
self.inter_B = self.Intermediate ()
|
||||
self.inter_AB = self.Intermediate ()
|
||||
|
||||
self.encoder, self.decoder, self.inter_B, self.inter_AB = self.Build(ae_input_layer)
|
||||
|
||||
if not self.is_first_run():
|
||||
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
|
||||
self.decoder.load_weights (self.get_strpath_storage_for_file(self.decoderH5))
|
||||
|
@ -36,16 +31,14 @@ class Model(ModelBase):
|
|||
code = self.encoder(ae_input_layer)
|
||||
AB = self.inter_AB(code)
|
||||
B = self.inter_B(code)
|
||||
self.autoencoder_src = self.keras.models.Model([ae_input_layer,mask_layer], self.decoder(self.keras.layers.Concatenate()([AB, AB])) )
|
||||
self.autoencoder_dst = self.keras.models.Model([ae_input_layer,mask_layer], self.decoder(self.keras.layers.Concatenate()([B, AB])) )
|
||||
self.autoencoder_src = Model([ae_input_layer,mask_layer], self.decoder(Concatenate()([AB, AB])) )
|
||||
self.autoencoder_dst = Model([ae_input_layer,mask_layer], self.decoder(Concatenate()([B, AB])) )
|
||||
|
||||
if self.is_training_mode:
|
||||
self.autoencoder_src, self.autoencoder_dst = self.to_multi_gpu_model_if_possible ( [self.autoencoder_src, self.autoencoder_dst] )
|
||||
|
||||
optimizer = self.keras.optimizers.Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)
|
||||
dssimloss = DSSIMMaskLossClass(self.tf)([mask_layer])
|
||||
self.autoencoder_src.compile(optimizer=optimizer, loss=[dssimloss, 'mse'] )
|
||||
self.autoencoder_dst.compile(optimizer=optimizer, loss=[dssimloss, 'mse'] )
|
||||
|
||||
self.autoencoder_src.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), loss=[DSSIMMaskLoss([mask_layer]), 'mse'] )
|
||||
self.autoencoder_dst.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), loss=[DSSIMMaskLoss([mask_layer]), 'mse'] )
|
||||
|
||||
if self.is_training_mode:
|
||||
f = SampleProcessor.TypeFlags
|
||||
|
@ -131,37 +124,52 @@ class Model(ModelBase):
|
|||
in_options['blur_mask_modifier'] = 0
|
||||
|
||||
return ConverterMasked(self.predictor_func, predictor_input_size=128, output_size=128, face_type=FaceType.FULL, clip_border_mask_per=0.046875, **in_options)
|
||||
|
||||
def Build(self, input_layer):
|
||||
exec(nnlib.code_import_all, locals(), globals())
|
||||
|
||||
def Encoder(self, input_layer,):
|
||||
x = input_layer
|
||||
x = conv(self.keras, x, 128)
|
||||
x = conv(self.keras, x, 256)
|
||||
x = conv(self.keras, x, 512)
|
||||
x = conv(self.keras, x, 1024)
|
||||
x = self.keras.layers.Flatten()(x)
|
||||
return self.keras.models.Model(input_layer, x)
|
||||
def downscale (dim):
|
||||
def func(x):
|
||||
return LeakyReLU(0.1)(Conv2D(dim, 5, strides=2, padding='same')(x))
|
||||
return func
|
||||
|
||||
def upscale (dim):
|
||||
def func(x):
|
||||
return PixelShuffler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
|
||||
return func
|
||||
|
||||
def Encoder():
|
||||
x = input_layer
|
||||
x = downscale(128)(x)
|
||||
x = downscale(256)(x)
|
||||
x = downscale(512)(x)
|
||||
x = downscale(1024)(x)
|
||||
x = Flatten()(x)
|
||||
return Model(input_layer, x)
|
||||
|
||||
def Intermediate(self):
|
||||
input_layer = self.keras.layers.Input(shape=(None, 8 * 8 * 1024))
|
||||
x = input_layer
|
||||
x = self.keras.layers.Dense(256)(x)
|
||||
x = self.keras.layers.Dense(8 * 8 * 512)(x)
|
||||
x = self.keras.layers.Reshape((8, 8, 512))(x)
|
||||
x = upscale(self.keras, x, 512)
|
||||
return self.keras.models.Model(input_layer, x)
|
||||
def Intermediate():
|
||||
input_layer = Input(shape=(None, 8 * 8 * 1024))
|
||||
x = input_layer
|
||||
x = Dense(256)(x)
|
||||
x = Dense(8 * 8 * 512)(x)
|
||||
x = Reshape((8, 8, 512))(x)
|
||||
x = upscale(512)(x)
|
||||
return Model(input_layer, x)
|
||||
|
||||
def Decoder(self):
|
||||
input_ = self.keras.layers.Input(shape=(16, 16, 1024))
|
||||
x = input_
|
||||
x = upscale(self.keras, x, 512)
|
||||
x = upscale(self.keras, x, 256)
|
||||
x = upscale(self.keras, x, 128)
|
||||
x = self.keras.layers.convolutional.Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)
|
||||
|
||||
y = input_ #mask decoder
|
||||
y = upscale(self.keras, y, 512)
|
||||
y = upscale(self.keras, y, 256)
|
||||
y = upscale(self.keras, y, 128)
|
||||
y = self.keras.layers.convolutional.Conv2D(1, kernel_size=5, padding='same', activation='sigmoid' )(y)
|
||||
|
||||
return self.keras.models.Model(input_, [x,y])
|
||||
def Decoder():
|
||||
input_ = Input(shape=(16, 16, 1024))
|
||||
x = input_
|
||||
x = upscale(512)(x)
|
||||
x = upscale(256)(x)
|
||||
x = upscale(128)(x)
|
||||
x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)
|
||||
|
||||
y = input_ #mask decoder
|
||||
y = upscale(512)(y)
|
||||
y = upscale(256)(y)
|
||||
y = upscale(128)(y)
|
||||
y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid' )(y)
|
||||
|
||||
return Model(input_, [x,y])
|
||||
|
||||
return Encoder(), Decoder(), Intermediate(), Intermediate()
|
Loading…
Add table
Add a link
Reference in a new issue