refactoring. Added RecycleGAN for testing.

This commit is contained in:
iperov 2018-12-28 19:38:52 +04:00
parent 8686309417
commit f8824f9601
24 changed files with 1661 additions and 1505 deletions

View file

@ -1,11 +1,9 @@
from models import ModelBase
import numpy as np
from samples import *
from nnlib import DSSIMMaskLossClass
from nnlib import conv
from nnlib import upscale
from nnlib import nnlib
from models import ModelBase
from facelib import FaceType
from samples import *
class Model(ModelBase):
@ -15,9 +13,7 @@ class Model(ModelBase):
#override
def onInitialize(self, **in_options):
tf = self.tf
keras = self.keras
K = keras.backend
exec(nnlib.import_all(), locals(), globals())
self.set_vram_batch_requirements( {1.5:2,2:2,3:8,4:16,5:24,6:32,7:40,8:48} )
bgr_shape, mask_shape, self.encoder, self.decoder_src, self.decoder_dst = self.Build(self.created_vram_gb)
@ -26,21 +22,21 @@ class Model(ModelBase):
self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5))
self.decoder_dst.load_weights (self.get_strpath_storage_for_file(self.decoder_dstH5))
input_src_bgr = self.keras.layers.Input(bgr_shape)
input_src_mask = self.keras.layers.Input(mask_shape)
input_dst_bgr = self.keras.layers.Input(bgr_shape)
input_dst_mask = self.keras.layers.Input(mask_shape)
input_src_bgr = Input(bgr_shape)
input_src_mask = Input(mask_shape)
input_dst_bgr = Input(bgr_shape)
input_dst_mask = Input(mask_shape)
rec_src_bgr, rec_src_mask = self.decoder_src( self.encoder(input_src_bgr) )
rec_dst_bgr, rec_dst_mask = self.decoder_dst( self.encoder(input_dst_bgr) )
self.ae = self.keras.models.Model([input_src_bgr,input_src_mask,input_dst_bgr,input_dst_mask], [rec_src_bgr, rec_src_mask, rec_dst_bgr, rec_dst_mask] )
self.ae = Model([input_src_bgr,input_src_mask,input_dst_bgr,input_dst_mask], [rec_src_bgr, rec_src_mask, rec_dst_bgr, rec_dst_mask] )
if self.is_training_mode:
self.ae, = self.to_multi_gpu_model_if_possible ( [self.ae,] )
self.ae.compile(optimizer=self.keras.optimizers.Adam(lr=5e-5, beta_1=0.5, beta_2=0.999),
loss=[ DSSIMMaskLossClass(self.tf)([input_src_mask]), 'mae', DSSIMMaskLossClass(self.tf)([input_dst_mask]), 'mae' ] )
self.ae.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999),
loss=[ DSSIMMaskLoss([input_src_mask]), 'mae', DSSIMMaskLoss([input_dst_mask]), 'mae' ] )
self.src_view = K.function([input_src_bgr],[rec_src_bgr, rec_src_mask])
self.dst_view = K.function([input_dst_bgr],[rec_dst_bgr, rec_dst_mask])
@ -130,58 +126,69 @@ class Model(ModelBase):
return ConverterMasked(self.predictor_func, predictor_input_size=64, output_size=64, face_type=FaceType.HALF, **in_options)
def Build(self, created_vram_gb):
exec(nnlib.code_import_all, locals(), globals())
bgr_shape = (64, 64, 3)
mask_shape = (64, 64, 1)
def downscale (dim):
def func(x):
return LeakyReLU(0.1)(Conv2D(dim, 5, strides=2, padding='same')(x))
return func
def upscale (dim):
def func(x):
return PixelShuffler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
return func
def Encoder(input_shape):
input_layer = self.keras.layers.Input(input_shape)
input_layer = Input(input_shape)
x = input_layer
if created_vram_gb >= 4:
x = conv(self.keras, x, 128)
x = conv(self.keras, x, 256)
x = conv(self.keras, x, 512)
x = conv(self.keras, x, 1024)
x = self.keras.layers.Dense(1024)(self.keras.layers.Flatten()(x))
x = self.keras.layers.Dense(4 * 4 * 1024)(x)
x = self.keras.layers.Reshape((4, 4, 1024))(x)
x = upscale(self.keras, x, 512)
x = downscale(128)(x)
x = downscale(256)(x)
x = downscale(512)(x)
x = downscale(1024)(x)
x = Dense(1024)(Flatten()(x))
x = Dense(4 * 4 * 1024)(x)
x = Reshape((4, 4, 1024))(x)
x = upscale(512)(x)
else:
x = conv(self.keras, x, 128 )
x = conv(self.keras, x, 256 )
x = conv(self.keras, x, 512 )
x = conv(self.keras, x, 768 )
x = self.keras.layers.Dense(512)(self.keras.layers.Flatten()(x))
x = self.keras.layers.Dense(4 * 4 * 512)(x)
x = self.keras.layers.Reshape((4, 4, 512))(x)
x = upscale(self.keras, x, 256)
return self.keras.models.Model(input_layer, x)
x = downscale(128)(x)
x = downscale(256)(x)
x = downscale(512)(x)
x = downscale(768)(x)
x = Dense(512)(Flatten()(x))
x = Dense(4 * 4 * 512)(x)
x = Reshape((4, 4, 512))(x)
x = upscale(256)(x)
return Model(input_layer, x)
def Decoder():
if created_vram_gb >= 4:
input_ = self.keras.layers.Input(shape=(8, 8, 512))
input_ = Input(shape=(8, 8, 512))
x = input_
x = upscale(self.keras, x, 512)
x = upscale(self.keras, x, 256)
x = upscale(self.keras, x, 128)
else:
input_ = self.keras.layers.Input(shape=(8, 8, 256))
x = upscale(512)(x)
x = upscale(256)(x)
x = upscale(128)(x)
x = input_
x = upscale(self.keras, x, 256)
x = upscale(self.keras, x, 128)
x = upscale(self.keras, x, 64)
else:
input_ = Input(shape=(8, 8, 256))
x = input_
x = upscale(256)(x)
x = upscale(128)(x)
x = upscale(64)(x)
y = input_ #mask decoder
y = upscale(self.keras, y, 256)
y = upscale(self.keras, y, 128)
y = upscale(self.keras, y, 64)
y = upscale(256)(y)
y = upscale(128)(y)
y = upscale(64)(y)
x = self.keras.layers.convolutional.Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)
y = self.keras.layers.convolutional.Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(y)
x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)
y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(y)
return self.keras.models.Model(input_, [x,y])
return Model(input_, [x,y])
return bgr_shape, mask_shape, Encoder(bgr_shape), Decoder(), Decoder()