converter: now writes the filename of the current frame into its config output,

SAE: removed the multiscale decoder, because it is not effective
Colombo 2019-09-13 08:59:00 +04:00
parent bef4e5d33c
commit b6b92bded0
6 changed files with 325 additions and 450 deletions

View file

@@ -74,14 +74,14 @@ class ConverterConfig(object):
         return False

     #overridable
-    def __str__(self):
+    def to_string(self, filename):
         r = ""
         r += f"sharpen_mode : {self.sharpen_dict[self.sharpen_mode]}\n"
         if self.sharpen_mode != 0:
             r += f"sharpen_amount : {self.sharpen_amount}\n"
         r += f"super_resolution_mode : {self.super_res_dict[self.super_resolution_mode]}\n"
         return r

 mode_dict = {0:'original',
              1:'overlay',
              2:'hist-match',
@@ -249,9 +249,9 @@ class ConverterConfigMasked(ConverterConfig):
         return False

-    def __str__(self):
+    def to_string(self, filename):
         r = (
-            """ConverterConfig:\n"""
+            f"""ConverterConfig {filename}:\n"""
             f"""Mode: {self.mode}\n"""
             )
@@ -276,7 +276,7 @@ class ConverterConfigMasked(ConverterConfig):
         if 'raw' not in self.mode:
             r += f"""color_transfer_mode: { ctm_dict[self.color_transfer_mode]}\n"""

-        r += super().__str__()
+        r += super().to_string(filename)

         if 'raw' not in self.mode:
             r += (f"""color_degrade_power: {self.color_degrade_power}\n"""
@@ -318,8 +318,8 @@ class ConverterConfigFaceAvatar(ConverterConfig):
         return False

     #override
-    def __str__(self):
-        return ("ConverterConfig: \n"
+    def to_string(self, filename):
+        return (f"ConverterConfig {filename}:\n"
                 f"add_source_image : {self.add_source_image}\n") + \
-                super().__str__() + "================"
+                super().to_string(filename) + "================"
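
The net effect of the __str__ -> to_string(filename) rename is that per-frame context can be threaded into the config dump. A minimal self-contained sketch of the pattern (toy classes standing in for the real ConverterConfig hierarchy; the path is hypothetical):

    from pathlib import Path

    class BaseConfig:
        #overridable
        def to_string(self, filename):
            return "sharpen_mode : None\n"

    class MaskedConfig(BaseConfig):
        def to_string(self, filename):
            # the frame's short filename now appears in the header, so the
            # interactive converter shows which frame the settings apply to
            return f"ConverterConfig {filename}:\n" + super().to_string(filename)

    filename_short = str(Path('/data/frames/00001.png').name)
    print( MaskedConfig().to_string(filename_short) )
    # ConverterConfig 00001.png:
    # sharpen_mode : None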

View file

@@ -1,6 +1,9 @@
+from pathlib import Path
+
 class FrameInfo(object):
     def __init__(self, filename=None, landmarks_list=None):
         self.filename = filename
+        self.filename_short = str(Path(filename).name)
         self.landmarks_list = landmarks_list or []
         self.motion_deg = 0
         self.motion_power = 0
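
For reference, Path(filename).name keeps only the final path component, which is what the converter now prints; a quick sketch with a hypothetical path:

    from pathlib import Path

    filename_short = str(Path('/data/frames/00001.png').name)
    print(filename_short)   # 00001.png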

View file

@@ -369,7 +369,7 @@ class ConvertSubprocessor(Subprocessor):
                 if not cur_frame.is_shown:
                     if cur_frame.is_done:
                         cur_frame.is_shown = True
-                        io.log_info (cur_frame.cfg)
+                        io.log_info (cur_frame.cfg.to_string( cur_frame.frame_info.filename_short) )

                         if cur_frame.image is None:
                             cur_frame.image = cv2_imread ( cur_frame.output_filename)
@@ -464,7 +464,7 @@ class ConvertSubprocessor(Subprocessor):
                     cfg.toggle_sharpen_mode()

                 if prev_cfg != cfg:
-                    io.log_info (cfg)
+                    io.log_info ( cfg.to_string(cur_frame.frame_info.filename_short) )
                     cur_frame.is_done = False
                     cur_frame.is_shown = False
             else:

View file

@@ -204,6 +204,7 @@ class ModelBase(object):
         if self.sample_for_preview is None or choose_preview_history:
             if choose_preview_history and io.is_support_windows():
+                io.log_info ("Choose image for the preview history. [p] - next. [enter] - confirm.")
                 wnd_name = "[p] - next. [enter] - confirm."
                 io.named_window(wnd_name)
                 io.capture_keys(wnd_name)
@@ -411,11 +412,17 @@ class ModelBase(object):
             cv2_imwrite (filepath, img )

     def load_weights_safe(self, model_filename_list, optimizer_filename_list=[]):
-        for model, filename in model_filename_list:
+        loaded = []
+        not_loaded = []
+        for mf in model_filename_list:
+            model, filename = mf
             filename = self.get_strpath_storage_for_file(filename)
             if Path(filename).exists():
+                loaded += [ mf ]
                 model.load_weights(filename)
+            else:
+                not_loaded += [ mf ]

         if len(optimizer_filename_list) != 0:
             opt_filename = self.get_strpath_storage_for_file('opt.h5')
             if Path(opt_filename).exists():
@@ -432,7 +439,8 @@ class ModelBase(object):
                     print("set ok")
                 except Exception as e:
                     print ("Unable to load ", opt_filename)
+
+        return loaded, not_loaded

     def save_weights_safe(self, model_filename_list):
         for model, filename in model_filename_list:
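
load_weights_safe now reports which submodels it could restore, so callers can initialize only the ones whose weight files were missing. A minimal self-contained sketch of the new contract (toy model object; the storage directory is hypothetical):

    from pathlib import Path

    class ToyModel:
        def load_weights(self, path):
            print("loaded", path)

    def load_weights_safe(model_filename_list, storage_dir='model'):
        # returns (loaded, not_loaded); each entry is the original
        # [model, filename] pair, mirroring the ModelBase change above
        loaded, not_loaded = [], []
        for mf in model_filename_list:
            model, filename = mf
            path = Path(storage_dir) / filename
            if path.exists():
                model.load_weights(str(path))
                loaded += [mf]
            else:
                not_loaded += [mf]
        return loaded, not_loaded

    loaded, not_loaded = load_weights_safe([ [ToyModel(), 'encoder.h5'] ])
    print([f for _, f in not_loaded])   # ['encoder.h5'] on a fresh run

SAEModel uses the not_loaded half to apply CA initialization only to the Conv2D kernels of submodels that had no saved weights.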

View file

@@ -1,26 +1,18 @@
 from functools import partial

 import numpy as np

-from nnlib import nnlib
-from models import ModelBase
+import mathlib
 from facelib import FaceType
-from samplelib import *
 from interact import interact as io
+from models import ModelBase
+from nnlib import nnlib
+from samplelib import *

 #SAE - Styled AutoEncoder
 class SAEModel(ModelBase):
-    encoderH5 = 'encoder.h5'
-    inter_BH5 = 'inter_B.h5'
-    inter_ABH5 = 'inter_AB.h5'
-    decoderH5 = 'decoder.h5'
-    decodermH5 = 'decoderm.h5'
-    decoder_srcH5 = 'decoder_src.h5'
-    decoder_srcmH5 = 'decoder_srcm.h5'
-    decoder_dstH5 = 'decoder_dst.h5'
-    decoder_dstmH5 = 'decoder_dstm.h5'

     #override
     def onInitializeOptions(self, is_first_run, ask_override):
         yn_str = {True:'y',False:'n'}
@@ -28,6 +20,7 @@ class SAEModel(ModelBase):
         default_resolution = 128
         default_archi = 'df'
         default_face_type = 'f'
+        default_learn_mask = True

         if is_first_run:
             resolution = io.input_int("Resolution ( 64-256 ?:help skip:128) : ", default_resolution, help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 16.")
@@ -37,12 +30,11 @@ class SAEModel(ModelBase):
             self.options['resolution'] = resolution
             self.options['face_type'] = io.input_str ("Half or Full face? (h/f, ?:help skip:f) : ", default_face_type, ['h','f'], help_message="Half face has better resolution, but covers less area of cheeks.").lower()
-            self.options['learn_mask'] = io.input_bool ("Learn mask? (y/n, ?:help skip:y) : ", True, help_message="Learning mask can help model to recognize face directions. Learn without mask can reduce model size, in this case converter forced to use 'not predicted mask' that is not smooth as predicted. Model with style values can be learned without mask and produce same quality result.")
+            self.options['learn_mask'] = io.input_bool ( f"Learn mask? (y/n, ?:help skip:{yn_str[default_learn_mask]} ) : " , default_learn_mask, help_message="Learning mask can help model to recognize face directions. Learn without mask can reduce model size, in this case converter forced to use 'not predicted mask' that is not smooth as predicted. Model with style values can be learned without mask and produce same quality result.")
         else:
             self.options['resolution'] = self.options.get('resolution', default_resolution)
             self.options['face_type'] = self.options.get('face_type', default_face_type)
-            self.options['learn_mask'] = self.options.get('learn_mask', True)
+            self.options['learn_mask'] = self.options.get('learn_mask', default_learn_mask)

         if (is_first_run or ask_override) and 'tensorflow' in self.device_config.backend:
             def_optimizer_mode = self.options.get('optimizer_mode', 1)
@@ -59,26 +51,24 @@ class SAEModel(ModelBase):
             default_e_ch_dims = 42
             default_d_ch_dims = default_e_ch_dims // 2
             def_ca_weights = False

         if is_first_run:
             self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dims (32-1024 ?:help skip:%d) : " % (default_ae_dims) , default_ae_dims, help_message="All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 )
             self.options['e_ch_dims'] = np.clip ( io.input_int("Encoder dims per channel (21-85 ?:help skip:%d) : " % (default_e_ch_dims) , default_e_ch_dims, help_message="More encoder dims help to recognize more facial features, but require more VRAM. You can fine-tune model size to fit your GPU." ), 21, 85 )
             default_d_ch_dims = self.options['e_ch_dims'] // 2
             self.options['d_ch_dims'] = np.clip ( io.input_int("Decoder dims per channel (10-85 ?:help skip:%d) : " % (default_d_ch_dims) , default_d_ch_dims, help_message="More decoder dims help to get better details, but require more VRAM. You can fine-tune model size to fit your GPU." ), 10, 85 )
-            self.options['multiscale_decoder'] = io.input_bool ("Use multiscale decoder? (y/n, ?:help skip:n) : ", False, help_message="Multiscale decoder helps to get better details.")
-            self.options['ca_weights'] = io.input_bool ("Use CA weights? (y/n, ?:help skip: %s ) : " % (yn_str[def_ca_weights]), def_ca_weights, help_message="Initialize network with 'Convolution Aware' weights. This may help to achieve a higher accuracy model, but consumes a time at first run.")
+            self.options['ca_weights'] = io.input_bool (f"Use CA weights? (y/n, ?:help skip:{yn_str[def_ca_weights]} ) : ", def_ca_weights, help_message="Initialize network with 'Convolution Aware' weights. This may help to achieve a higher accuracy model, but consumes a time at first run.")
         else:
             self.options['ae_dims'] = self.options.get('ae_dims', default_ae_dims)
             self.options['e_ch_dims'] = self.options.get('e_ch_dims', default_e_ch_dims)
             self.options['d_ch_dims'] = self.options.get('d_ch_dims', default_d_ch_dims)
-            self.options['multiscale_decoder'] = self.options.get('multiscale_decoder', False)
             self.options['ca_weights'] = self.options.get('ca_weights', def_ca_weights)

         default_face_style_power = 0.0
         default_bg_style_power = 0.0
         if is_first_run or ask_override:
             def_pixel_loss = self.options.get('pixel_loss', False)
-            self.options['pixel_loss'] = io.input_bool ("Use pixel loss? (y/n, ?:help skip: %s ) : " % (yn_str[def_pixel_loss]), def_pixel_loss, help_message="Pixel loss may help to enhance fine details and stabilize face color. Use it only if quality does not improve over time. Enabling this option too early increases the chance of model collapse.")
+            self.options['pixel_loss'] = io.input_bool (f"Use pixel loss? (y/n, ?:help skip:{yn_str[def_pixel_loss]} ) : ", def_pixel_loss, help_message="Pixel loss may help to enhance fine details and stabilize face color. Use it only if quality does not improve over time. Enabling this option too early increases the chance of model collapse.")

             default_face_style_power = default_face_style_power if is_first_run else self.options.get('face_style_power', default_face_style_power)
             self.options['face_style_power'] = np.clip ( io.input_number("Face style power ( 0.0 .. 100.0 ?:help skip:%.2f) : " % (default_face_style_power), default_face_style_power,
@@ -89,11 +79,11 @@ class SAEModel(ModelBase):
                                                                help_message="Learn to transfer image around face. This can make face more like dst. Enabling this option increases the chance of model collapse."), 0.0, 100.0 )

             default_apply_random_ct = False if is_first_run else self.options.get('apply_random_ct', False)
-            self.options['apply_random_ct'] = io.input_bool ("Apply random color transfer to src faceset? (y/n, ?:help skip:%s) : " % (yn_str[default_apply_random_ct]), default_apply_random_ct, help_message="Increase variativity of src samples by apply LCT color transfer from random dst samples. It is like 'face_style' learning, but more precise color transfer and without risk of model collapse, also it does not require additional GPU resources, but the training time may be longer, due to the src faceset is becoming more diverse.")
+            self.options['apply_random_ct'] = io.input_bool (f"Apply random color transfer to src faceset? (y/n, ?:help skip:{yn_str[default_apply_random_ct]}) : ", default_apply_random_ct, help_message="Increase variativity of src samples by apply LCT color transfer from random dst samples. It is like 'face_style' learning, but more precise color transfer and without risk of model collapse, also it does not require additional GPU resources, but the training time may be longer, due to the src faceset is becoming more diverse.")

             if nnlib.device.backend != 'plaidML': # todo https://github.com/plaidml/plaidml/issues/301
                 default_clipgrad = False if is_first_run else self.options.get('clipgrad', False)
-                self.options['clipgrad'] = io.input_bool ("Enable gradient clipping? (y/n, ?:help skip:%s) : " % (yn_str[default_clipgrad]), default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.")
+                self.options['clipgrad'] = io.input_bool (f"Enable gradient clipping? (y/n, ?:help skip:{yn_str[default_clipgrad]}) : ", default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.")
             else:
                 self.options['clipgrad'] = False
@@ -112,10 +102,11 @@ class SAEModel(ModelBase):
     #override
     def onInitialize(self):
         exec(nnlib.import_all(), locals(), globals())
-        SAEModel.initialize_nn_functions()
         self.set_vram_batch_requirements({1.5:4})

         resolution = self.options['resolution']
+        learn_mask = self.options['learn_mask']
+
         ae_dims = self.options['ae_dims']
         e_ch_dims = self.options['e_ch_dims']
         d_ch_dims = self.options['d_ch_dims']
@@ -127,232 +118,333 @@ class SAEModel(ModelBase):
         bgr_shape = (resolution, resolution, 3)
         mask_shape = (resolution, resolution, 1)

-        self.ms_count = ms_count = 3 if (self.options['multiscale_decoder']) else 1

         apply_random_ct = self.options.get('apply_random_ct', False)
         masked_training = True

-        warped_src = Input(bgr_shape)
-        target_src = Input(bgr_shape)
-        target_srcm = Input(mask_shape)
-
-        warped_dst = Input(bgr_shape)
-        target_dst = Input(bgr_shape)
-        target_dstm = Input(mask_shape)
-
-        target_src_ar = [ Input ( ( bgr_shape[0] // (2**i) ,)*2 + (bgr_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
-        target_srcm_ar = [ Input ( ( mask_shape[0] // (2**i) ,)*2 + (mask_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
-        target_dst_ar = [ Input ( ( bgr_shape[0] // (2**i) ,)*2 + (bgr_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
-        target_dstm_ar = [ Input ( ( mask_shape[0] // (2**i) ,)*2 + (mask_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
-
-        common_flow_kwargs = { 'padding': 'zero',
-                               'norm': '',
-                               'act':'' }
-        models_list = []
-        weights_to_load = []
-        if 'liae' in self.options['archi']:
-            self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape))
-
-            enc_output_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
-
-            self.inter_B = modelify(SAEModel.LIAEInterFlow(resolution, ae_dims=ae_dims, **common_flow_kwargs)) (enc_output_Inputs)
-            self.inter_AB = modelify(SAEModel.LIAEInterFlow(resolution, ae_dims=ae_dims, **common_flow_kwargs)) (enc_output_Inputs)
-
-            inter_output_Inputs = [ Input( np.array(K.int_shape(x)[1:])*(1,1,2) ) for x in self.inter_B.outputs ]
-
-            self.decoder = modelify(SAEModel.LIAEDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs)) (inter_output_Inputs)
-            models_list += [self.encoder, self.inter_B, self.inter_AB, self.decoder]
-
-            if self.options['learn_mask']:
-                self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs)) (inter_output_Inputs)
-                models_list += [self.decoderm]
-
-            if not self.is_first_run():
-                weights_to_load += [ [self.encoder , 'encoder.h5'],
-                                     [self.inter_B , 'inter_B.h5'],
-                                     [self.inter_AB, 'inter_AB.h5'],
-                                     [self.decoder , 'decoder.h5'],
-                                   ]
-                if self.options['learn_mask']:
-                    weights_to_load += [ [self.decoderm, 'decoderm.h5'] ]
-
-            warped_src_code = self.encoder (warped_src)
-            warped_src_inter_AB_code = self.inter_AB (warped_src_code)
-            warped_src_inter_code = Concatenate()([warped_src_inter_AB_code,warped_src_inter_AB_code])
-
-            warped_dst_code = self.encoder (warped_dst)
-            warped_dst_inter_B_code = self.inter_B (warped_dst_code)
-            warped_dst_inter_AB_code = self.inter_AB (warped_dst_code)
-            warped_dst_inter_code = Concatenate()([warped_dst_inter_B_code,warped_dst_inter_AB_code])
-
-            warped_src_dst_inter_code = Concatenate()([warped_dst_inter_AB_code,warped_dst_inter_AB_code])
-
-            pred_src_src = self.decoder(warped_src_inter_code)
-            pred_dst_dst = self.decoder(warped_dst_inter_code)
-            pred_src_dst = self.decoder(warped_src_dst_inter_code)
-
-            if self.options['learn_mask']:
-                pred_src_srcm = self.decoderm(warped_src_inter_code)
-                pred_dst_dstm = self.decoderm(warped_dst_inter_code)
-                pred_src_dstm = self.decoderm(warped_src_dst_inter_code)
-
-        elif 'df' in self.options['archi']:
-            self.encoder = modelify(SAEModel.DFEncFlow(resolution, ae_dims=ae_dims, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape))
-
-            dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
-
-            self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs )) (dec_Inputs)
-            self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs )) (dec_Inputs)
-            models_list += [self.encoder, self.decoder_src, self.decoder_dst]
-
-            if self.options['learn_mask']:
-                self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs )) (dec_Inputs)
-                self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs )) (dec_Inputs)
-                models_list += [self.decoder_srcm, self.decoder_dstm]
-
-            if not self.is_first_run():
-                weights_to_load += [ [self.encoder , 'encoder.h5'],
-                                     [self.decoder_src, 'decoder_src.h5'],
-                                     [self.decoder_dst, 'decoder_dst.h5']
-                                   ]
-                if self.options['learn_mask']:
-                    weights_to_load += [ [self.decoder_srcm, 'decoder_srcm.h5'],
-                                         [self.decoder_dstm, 'decoder_dstm.h5'],
-                                       ]
-
-            warped_src_code = self.encoder (warped_src)
-            warped_dst_code = self.encoder (warped_dst)
-            pred_src_src = self.decoder_src(warped_src_code)
-            pred_dst_dst = self.decoder_dst(warped_dst_code)
-            pred_src_dst = self.decoder_src(warped_dst_code)
-
-            if self.options['learn_mask']:
-                pred_src_srcm = self.decoder_srcm(warped_src_code)
-                pred_dst_dstm = self.decoder_dstm(warped_dst_code)
-                pred_src_dstm = self.decoder_srcm(warped_dst_code)
-
-        if self.is_first_run():
-            if self.options.get('ca_weights',False):
-                conv_weights_list = []
-                for model in models_list:
-                    for layer in model.layers:
-                        if type(layer) == keras.layers.Conv2D:
-                            conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights
-                CAInitializerMP ( conv_weights_list )
-        else:
-            self.load_weights_safe(weights_to_load)
-
-        pred_src_src, pred_dst_dst, pred_src_dst, = [ [x] if type(x) != list else x for x in [pred_src_src, pred_dst_dst, pred_src_dst, ] ]
-
-        if self.options['learn_mask']:
-            pred_src_srcm, pred_dst_dstm, pred_src_dstm = [ [x] if type(x) != list else x for x in [pred_src_srcm, pred_dst_dstm, pred_src_dstm] ]
-
-        target_srcm_blurred_ar = [ gaussian_blur( max(1, K.int_shape(x)[1] // 32) )(x) for x in target_srcm_ar]
-        target_srcm_sigm_ar = target_srcm_blurred_ar #[ x / 2.0 + 0.5 for x in target_srcm_blurred_ar]
-        target_srcm_anti_sigm_ar = [ 1.0 - x for x in target_srcm_sigm_ar]
-
-        target_dstm_blurred_ar = [ gaussian_blur( max(1, K.int_shape(x)[1] // 32) )(x) for x in target_dstm_ar]
-        target_dstm_sigm_ar = target_dstm_blurred_ar#[ x / 2.0 + 0.5 for x in target_dstm_blurred_ar]
-        target_dstm_anti_sigm_ar = [ 1.0 - x for x in target_dstm_sigm_ar]
-
-        target_src_sigm_ar = target_src_ar#[ x + 1 for x in target_src_ar]
-        target_dst_sigm_ar = target_dst_ar#[ x + 1 for x in target_dst_ar]
-
-        pred_src_src_sigm_ar = pred_src_src#[ x + 1 for x in pred_src_src]
-        pred_dst_dst_sigm_ar = pred_dst_dst#[ x + 1 for x in pred_dst_dst]
-        pred_src_dst_sigm_ar = pred_src_dst#[ x + 1 for x in pred_src_dst]
-
-        target_src_masked_ar = [ target_src_sigm_ar[i]*target_srcm_sigm_ar[i] for i in range(len(target_src_sigm_ar))]
-        target_dst_masked_ar = [ target_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(target_dst_sigm_ar))]
-        target_dst_anti_masked_ar = [ target_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(target_dst_sigm_ar))]
-
-        pred_src_src_masked_ar = [ pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] for i in range(len(pred_src_src_sigm_ar))]
-        pred_dst_dst_masked_ar = [ pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] for i in range(len(pred_dst_dst_sigm_ar))]
-
-        target_src_masked_ar_opt = target_src_masked_ar if masked_training else target_src_sigm_ar
-        target_dst_masked_ar_opt = target_dst_masked_ar if masked_training else target_dst_sigm_ar
-
-        pred_src_src_masked_ar_opt = pred_src_src_masked_ar if masked_training else pred_src_src_sigm_ar
-        pred_dst_dst_masked_ar_opt = pred_dst_dst_masked_ar if masked_training else pred_dst_dst_sigm_ar
-
-        psd_target_dst_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
-        psd_target_dst_anti_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
+        class SAEDFModel(object):
+            def __init__(self, resolution, ae_dims, e_ch_dims, d_ch_dims, learn_mask):
+                super().__init__()
+                self.learn_mask = learn_mask
+
+                output_nc = 3
+                bgr_shape = (resolution, resolution, output_nc)
+                mask_shape = (resolution, resolution, 1)
+                lowest_dense_res = resolution // 16
+                e_dims = output_nc*e_ch_dims
+
+                def upscale (dim):
+                    def func(x):
+                        return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, kernel_size=3, strides=1, padding='same')(x)))
+                    return func
+
+                def enc_flow(e_dims, ae_dims, lowest_dense_res):
+                    def func(x):
+                        x = LeakyReLU(0.1)(Conv2D(e_dims, kernel_size=5, strides=2, padding='same')(x))
+                        x = LeakyReLU(0.1)(Conv2D(e_dims*2, kernel_size=5, strides=2, padding='same')(x))
+                        x = LeakyReLU(0.1)(Conv2D(e_dims*4, kernel_size=5, strides=2, padding='same')(x))
+                        x = LeakyReLU(0.1)(Conv2D(e_dims*8, kernel_size=5, strides=2, padding='same')(x))
+
+                        x = Dense(ae_dims)(Flatten()(x))
+                        x = Dense(lowest_dense_res * lowest_dense_res * ae_dims)(x)
+                        x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims))(x)
+                        x = upscale(ae_dims)(x)
+                        return x
+                    return func
+
+                def dec_flow(output_nc, d_ch_dims):
+                    def ResidualBlock(dim):
+                        def func(inp):
+                            x = Conv2D(dim, kernel_size=3, padding='same')(inp)
+                            x = LeakyReLU(0.2)(x)
+                            x = Conv2D(dim, kernel_size=3, padding='same')(x)
+                            x = Add()([x, inp])
+                            x = LeakyReLU(0.2)(x)
+                            return x
+                        return func
+
+                    def func(x):
+                        dims = output_nc * d_ch_dims
+                        x = upscale(dims*8)(x)
+                        x = ResidualBlock(dims*8)(x)
+                        x = ResidualBlock(dims*8)(x)
+
+                        x = upscale(dims*4)(x)
+                        x = ResidualBlock(dims*4)(x)
+                        x = ResidualBlock(dims*4)(x)
+
+                        x = upscale(dims*2)(x)
+                        x = ResidualBlock(dims*2)(x)
+                        x = ResidualBlock(dims*2)(x)
+
+                        return Conv2D(output_nc, kernel_size=5, padding='same', activation='sigmoid')(x)
+                    return func
+
+                self.encoder = modelify(enc_flow(e_dims, ae_dims, lowest_dense_res)) ( Input(bgr_shape) )
+
+                sh = K.int_shape( self.encoder.outputs[0] )[1:]
+                self.decoder_src = modelify(dec_flow(output_nc, d_ch_dims)) ( Input(sh) )
+                self.decoder_dst = modelify(dec_flow(output_nc, d_ch_dims)) ( Input(sh) )
+
+                if learn_mask:
+                    self.decoder_srcm = modelify(dec_flow(1, d_ch_dims)) ( Input(sh) )
+                    self.decoder_dstm = modelify(dec_flow(1, d_ch_dims)) ( Input(sh) )
+
+                self.src_dst_trainable_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights + self.decoder_dst.trainable_weights
+
+                if learn_mask:
+                    self.src_dst_mask_trainable_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights + self.decoder_dstm.trainable_weights
+
+                self.warped_src, self.warped_dst = Input(bgr_shape), Input(bgr_shape)
+                src_code, dst_code = self.encoder(self.warped_src), self.encoder(self.warped_dst)
+
+                self.pred_src_src = self.decoder_src(src_code)
+                self.pred_dst_dst = self.decoder_dst(dst_code)
+                self.pred_src_dst = self.decoder_src(dst_code)
+
+                if learn_mask:
+                    self.pred_src_srcm = self.decoder_srcm(src_code)
+                    self.pred_dst_dstm = self.decoder_dstm(dst_code)
+                    self.pred_src_dstm = self.decoder_srcm(dst_code)
+
+            def get_model_filename_list(self, exclude_for_pretrain=False):
+                ar = []
+                if not exclude_for_pretrain:
+                    ar += [ [self.encoder, 'encoder.h5'] ]
+                ar += [ [self.decoder_src, 'decoder_src.h5'],
+                        [self.decoder_dst, 'decoder_dst.h5'] ]
+                if self.learn_mask:
+                    ar += [ [self.decoder_srcm, 'decoder_srcm.h5'],
+                            [self.decoder_dstm, 'decoder_dstm.h5'] ]
+                return ar
+
+        class SAELIAEModel(object):
+            def __init__(self, resolution, ae_dims, e_ch_dims, d_ch_dims, learn_mask):
+                super().__init__()
+                self.learn_mask = learn_mask
+
+                output_nc = 3
+                bgr_shape = (resolution, resolution, output_nc)
+                mask_shape = (resolution, resolution, 1)
+
+                e_dims = output_nc*e_ch_dims
+                d_dims = output_nc*d_ch_dims
+                lowest_dense_res = resolution // 16
+
+                def upscale (dim):
+                    def func(x):
+                        return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, kernel_size=3, strides=1, padding='same')(x)))
+                    return func
+
+                def enc_flow(e_dims):
+                    def func(x):
+                        x = LeakyReLU(0.1)(Conv2D(e_dims, kernel_size=5, strides=2, padding='same')(x))
+                        x = LeakyReLU(0.1)(Conv2D(e_dims*2, kernel_size=5, strides=2, padding='same')(x))
+                        x = LeakyReLU(0.1)(Conv2D(e_dims*4, kernel_size=5, strides=2, padding='same')(x))
+                        x = LeakyReLU(0.1)(Conv2D(e_dims*8, kernel_size=5, strides=2, padding='same')(x))
+                        x = Flatten()(x)
+                        return x
+                    return func
+
+                def inter_flow(lowest_dense_res, ae_dims):
+                    def func(x):
+                        x = Dense(ae_dims)(x)
+                        x = Dense(lowest_dense_res * lowest_dense_res * ae_dims*2)(x)
+                        x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims*2))(x)
+                        x = upscale(ae_dims*2)(x)
+                        return x
+                    return func
+
+                def dec_flow(output_nc, d_dims):
+                    def ResidualBlock(dim):
+                        def func(inp):
+                            x = Conv2D(dim, kernel_size=3, padding='same')(inp)
+                            x = LeakyReLU(0.2)(x)
+                            x = Conv2D(dim, kernel_size=3, padding='same')(x)
+                            x = Add()([x, inp])
+                            x = LeakyReLU(0.2)(x)
+                            return x
+                        return func
+
+                    def func(x):
+                        x = upscale(d_dims*8)(x)
+                        x = ResidualBlock(d_dims*8)(x)
+                        x = ResidualBlock(d_dims*8)(x)
+
+                        x = upscale(d_dims*4)(x)
+                        x = ResidualBlock(d_dims*4)(x)
+                        x = ResidualBlock(d_dims*4)(x)
+
+                        x = upscale(d_dims*2)(x)
+                        x = ResidualBlock(d_dims*2)(x)
+                        x = ResidualBlock(d_dims*2)(x)
+
+                        return Conv2D(output_nc, kernel_size=5, padding='same', activation='sigmoid')(x)
+                    return func
+
+                self.encoder = modelify(enc_flow(e_dims)) ( Input(bgr_shape) )
+
+                sh = K.int_shape( self.encoder.outputs[0] )[1:]
+                self.inter_B = modelify(inter_flow(lowest_dense_res, ae_dims)) ( Input(sh) )
+                self.inter_AB = modelify(inter_flow(lowest_dense_res, ae_dims)) ( Input(sh) )
+
+                sh = np.array(K.int_shape( self.inter_B.outputs[0] )[1:])*(1,1,2)
+                self.decoder = modelify(dec_flow(output_nc, d_dims)) ( Input(sh) )
+
+                if learn_mask:
+                    self.decoderm = modelify(dec_flow(1, d_dims)) ( Input(sh) )
+
+                self.src_dst_trainable_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
+
+                if learn_mask:
+                    self.src_dst_mask_trainable_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
+
+                self.warped_src, self.warped_dst = Input(bgr_shape), Input(bgr_shape)
+
+                warped_src_code = self.encoder (self.warped_src)
+                warped_src_inter_AB_code = self.inter_AB (warped_src_code)
+                warped_src_inter_code = Concatenate()([warped_src_inter_AB_code,warped_src_inter_AB_code])
+
+                warped_dst_code = self.encoder (self.warped_dst)
+                warped_dst_inter_B_code = self.inter_B (warped_dst_code)
+                warped_dst_inter_AB_code = self.inter_AB (warped_dst_code)
+                warped_dst_inter_code = Concatenate()([warped_dst_inter_B_code,warped_dst_inter_AB_code])
+
+                warped_src_dst_inter_code = Concatenate()([warped_dst_inter_AB_code,warped_dst_inter_AB_code])
+
+                self.pred_src_src = self.decoder(warped_src_inter_code)
+                self.pred_dst_dst = self.decoder(warped_dst_inter_code)
+                self.pred_src_dst = self.decoder(warped_src_dst_inter_code)
+
+                if learn_mask:
+                    self.pred_src_srcm = self.decoderm(warped_src_inter_code)
+                    self.pred_dst_dstm = self.decoderm(warped_dst_inter_code)
+                    self.pred_src_dstm = self.decoderm(warped_src_dst_inter_code)
+
+            def get_model_filename_list(self, exclude_for_pretrain=False):
+                ar = [ [self.encoder, 'encoder.h5'],
+                       [self.inter_B, 'inter_B.h5'] ]
+                if not exclude_for_pretrain:
+                    ar += [ [self.inter_AB, 'inter_AB.h5'] ]
+                ar += [ [self.decoder, 'decoder.h5'] ]
+                if self.learn_mask:
+                    ar += [ [self.decoderm, 'decoderm.h5'] ]
+                return ar
+
+        if 'df' in self.options['archi']:
+            self.model = SAEDFModel (resolution, ae_dims, e_ch_dims, d_ch_dims, learn_mask)
+        elif 'liae' in self.options['archi']:
+            self.model = SAELIAEModel (resolution, ae_dims, e_ch_dims, d_ch_dims, learn_mask)
+
+        loaded, not_loaded = [], self.model.get_model_filename_list()
+        if not self.is_first_run():
+            loaded, not_loaded = self.load_weights_safe(not_loaded)
+
+        CA_models = []
+        if self.options.get('ca_weights', False):
+            CA_models += [ model for model, _ in not_loaded ]
+
+        CA_conv_weights_list = []
+        for model in CA_models:
+            for layer in model.layers:
+                if type(layer) == keras.layers.Conv2D:
+                    CA_conv_weights_list += [layer.weights[0]] #- is Conv2D kernel_weights
+
+        if len(CA_conv_weights_list) != 0:
+            CAInitializerMP ( CA_conv_weights_list )
+
+        warped_src = self.model.warped_src
+        target_src = Input ( (resolution, resolution, 3) )
+        target_srcm = Input ( (resolution, resolution, 1) )
+
+        warped_dst = self.model.warped_dst
+        target_dst = Input ( (resolution, resolution, 3) )
+        target_dstm = Input ( (resolution, resolution, 1) )
+
+        target_src_sigm = target_src
+        target_dst_sigm = target_dst
+
+        target_srcm_sigm = gaussian_blur( max(1, K.int_shape(target_srcm)[1] // 32) )(target_srcm)
+        target_dstm_sigm = gaussian_blur( max(1, K.int_shape(target_dstm)[1] // 32) )(target_dstm)
+        target_dstm_anti_sigm = 1.0 - target_dstm_sigm
+
+        target_src_masked = target_src_sigm*target_srcm_sigm
+        target_dst_masked = target_dst_sigm*target_dstm_sigm
+        target_dst_anti_masked = target_dst_sigm*target_dstm_anti_sigm
+
+        target_src_masked_opt = target_src_masked if masked_training else target_src_sigm
+        target_dst_masked_opt = target_dst_masked if masked_training else target_dst_sigm
+
+        pred_src_src = self.model.pred_src_src
+        pred_dst_dst = self.model.pred_dst_dst
+        pred_src_dst = self.model.pred_src_dst
+        if learn_mask:
+            pred_src_srcm = self.model.pred_src_srcm
+            pred_dst_dstm = self.model.pred_dst_dstm
+            pred_src_dstm = self.model.pred_src_dstm
+
+        pred_src_src_sigm = self.model.pred_src_src
+        pred_dst_dst_sigm = self.model.pred_dst_dst
+        pred_src_dst_sigm = self.model.pred_src_dst
+
+        pred_src_src_masked = pred_src_src_sigm*target_srcm_sigm
+        pred_dst_dst_masked = pred_dst_dst_sigm*target_dstm_sigm
+
+        pred_src_src_masked_opt = pred_src_src_masked if masked_training else pred_src_src_sigm
+        pred_dst_dst_masked_opt = pred_dst_dst_masked if masked_training else pred_dst_dst_sigm
+
+        psd_target_dst_masked = pred_src_dst_sigm*target_dstm_sigm
+        psd_target_dst_anti_masked = pred_src_dst_sigm*target_dstm_anti_sigm

         if self.is_training_mode:
             self.src_dst_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
             self.src_dst_mask_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
-            self.sr_opt = Adam(lr=5e-5, beta_1=0.9, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)
-
-            if 'liae' in self.options['archi']:
-                src_dst_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
-                if self.options['learn_mask']:
-                    src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
-            else:
-                src_dst_loss_train_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights + self.decoder_dst.trainable_weights
-                if self.options['learn_mask']:
-                    src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights + self.decoder_dstm.trainable_weights

             if not self.options['pixel_loss']:
-                src_loss_batch = sum([ 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_src_masked_ar_opt[i], pred_src_src_masked_ar_opt[i]) for i in range(len(target_src_masked_ar_opt)) ])
+                src_loss = K.mean ( 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_src_masked_opt, pred_src_src_masked_opt) )
             else:
-                src_loss_batch = sum([ K.mean ( 50*K.square( target_src_masked_ar_opt[i] - pred_src_src_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_src_masked_ar_opt)) ])
-
-            src_loss = K.mean(src_loss_batch)
+                src_loss = K.mean ( 50*K.square( target_src_masked_opt - pred_src_src_masked_opt ) )

             face_style_power = self.options['face_style_power'] / 100.0
-
             if face_style_power != 0:
-                src_loss += style_loss(gaussian_blur_radius=resolution//16, loss_weight=face_style_power, wnd_size=0)( psd_target_dst_masked_ar[-1], target_dst_masked_ar[-1] )
+                src_loss += style_loss(gaussian_blur_radius=resolution//16, loss_weight=face_style_power, wnd_size=0)( psd_target_dst_masked, target_dst_masked )

             bg_style_power = self.options['bg_style_power'] / 100.0
             if bg_style_power != 0:
                 if not self.options['pixel_loss']:
-                    bg_loss = K.mean( (10*bg_style_power)*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] ))
+                    src_loss += K.mean( (10*bg_style_power)*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( psd_target_dst_anti_masked, target_dst_anti_masked ))
                 else:
-                    bg_loss = K.mean( (50*bg_style_power)*K.square( psd_target_dst_anti_masked_ar[-1] - target_dst_anti_masked_ar[-1] ))
-                src_loss += bg_loss
+                    src_loss += K.mean( (50*bg_style_power)*K.square( psd_target_dst_anti_masked - target_dst_anti_masked ))

             if not self.options['pixel_loss']:
-                dst_loss_batch = sum([ 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)(target_dst_masked_ar_opt[i], pred_dst_dst_masked_ar_opt[i]) for i in range(len(target_dst_masked_ar_opt)) ])
+                dst_loss = K.mean( 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)(target_dst_masked_opt, pred_dst_dst_masked_opt) )
             else:
-                dst_loss_batch = sum([ K.mean ( 50*K.square( target_dst_masked_ar_opt[i] - pred_dst_dst_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_dst_masked_ar_opt)) ])
-
-            dst_loss = K.mean(dst_loss_batch)
-
-            feed = [warped_src, warped_dst]
-            feed += target_src_ar[::-1]
-            feed += target_srcm_ar[::-1]
-            feed += target_dst_ar[::-1]
-            feed += target_dstm_ar[::-1]
-
-            self.src_dst_train = K.function (feed,[src_loss,dst_loss], self.src_dst_opt.get_updates(src_loss+dst_loss, src_dst_loss_train_weights) )
+                dst_loss = K.mean( 50*K.square( target_dst_masked_opt - pred_dst_dst_masked_opt ) )
+
+            self.src_dst_train = K.function ([warped_src, warped_dst, target_src, target_srcm, target_dst, target_dstm],[src_loss,dst_loss], self.src_dst_opt.get_updates(src_loss+dst_loss, self.model.src_dst_trainable_weights) )

             if self.options['learn_mask']:
-                src_mask_loss = sum([ K.mean(K.square(target_srcm_ar[-1]-pred_src_srcm[-1])) for i in range(len(target_srcm_ar)) ])
-                dst_mask_loss = sum([ K.mean(K.square(target_dstm_ar[-1]-pred_dst_dstm[-1])) for i in range(len(target_dstm_ar)) ])
-
-                feed = [ warped_src, warped_dst]
-                feed += target_srcm_ar[::-1]
-                feed += target_dstm_ar[::-1]
-
-                self.src_dst_mask_train = K.function (feed,[src_mask_loss, dst_mask_loss], self.src_dst_mask_opt.get_updates(src_mask_loss+dst_mask_loss, src_dst_mask_loss_train_weights) )
+                src_mask_loss = K.mean(K.square(target_srcm-pred_src_srcm))
+                dst_mask_loss = K.mean(K.square(target_dstm-pred_dst_dstm))
+                self.src_dst_mask_train = K.function ([warped_src, warped_dst, target_srcm, target_dstm],[src_mask_loss, dst_mask_loss], self.src_dst_mask_opt.get_updates(src_mask_loss+dst_mask_loss, self.model.src_dst_mask_trainable_weights ) )

             if self.options['learn_mask']:
-                self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src[-1], pred_dst_dst[-1], pred_dst_dstm[-1], pred_src_dst[-1], pred_src_dstm[-1]])
+                self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm])
             else:
-                self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src[-1], pred_dst_dst[-1], pred_src_dst[-1] ] )
+                self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src, pred_dst_dst, pred_src_dst ])

         else:
             if self.options['learn_mask']:
-                self.AE_convert = K.function ([warped_dst],[ pred_src_dst[-1], pred_dst_dstm[-1], pred_src_dstm[-1] ])
+                self.AE_convert = K.function ([warped_dst],[ pred_src_dst, pred_dst_dstm, pred_src_dstm ])
             else:
-                self.AE_convert = K.function ([warped_dst],[ pred_src_dst[-1] ])
+                self.AE_convert = K.function ([warped_dst],[ pred_src_dst ])

         if self.is_training_mode:
-            self.src_sample_losses = []
-            self.dst_sample_losses = []
-
             t = SampleProcessor.Types
             face_type = t.FACE_TYPE_FULL if self.options['face_type'] == 'f' else t.FACE_TYPE_HALF
@@ -372,46 +464,21 @@ class SAEModel(ModelBase):
                         random_ct_samples_path=training_data_dst_path if apply_random_ct else None,
                         debug=self.is_debug(), batch_size=self.batch_size,
                         sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
-                        output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), 'resolution':resolution, 'apply_ct': apply_random_ct} ] + \
-                                              [ {'types' : (t.IMG_TRANSFORMED, face_type, t_mode_bgr), 'resolution': resolution // (2**i), 'apply_ct': apply_random_ct } for i in range(ms_count)] + \
-                                              [ {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'resolution': resolution // (2**i) } for i in range(ms_count)]
+                        output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), 'resolution':resolution, 'apply_ct': apply_random_ct},
+                                                {'types' : (t.IMG_TRANSFORMED, face_type, t_mode_bgr), 'resolution': resolution, 'apply_ct': apply_random_ct },
+                                                {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'resolution': resolution } ]
                       ),

                    SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
                         sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ),
-                        output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), 'resolution':resolution} ] + \
-                                              [ {'types' : (t.IMG_TRANSFORMED, face_type, t_mode_bgr), 'resolution': resolution // (2**i)} for i in range(ms_count)] + \
-                                              [ {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'resolution': resolution // (2**i) } for i in range(ms_count)])
+                        output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), 'resolution':resolution},
+                                                {'types' : (t.IMG_TRANSFORMED, face_type, t_mode_bgr), 'resolution': resolution},
+                                                {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'resolution': resolution} ])
                 ])

     #override
     def get_model_filename_list(self):
-        ar = []
-        if 'liae' in self.options['archi']:
-            ar += [[self.encoder, 'encoder.h5'],
-                   [self.inter_B, 'inter_B.h5'],
-                   [self.decoder, 'decoder.h5']
-                  ]
-
-            if not self.pretrain or self.iter == 0:
-                ar += [ [self.inter_AB, 'inter_AB.h5'],
-                      ]
-
-            if self.options['learn_mask']:
-                ar += [ [self.decoderm, 'decoderm.h5'] ]
-
-        elif 'df' in self.options['archi']:
-            if not self.pretrain or self.iter == 0:
-                ar += [ [self.encoder, 'encoder.h5'],
-                      ]
-
-            ar += [ [self.decoder_src, 'decoder_src.h5'],
-                    [self.decoder_dst, 'decoder_dst.h5']
-                  ]
-
-            if self.options['learn_mask']:
-                ar += [ [self.decoder_srcm, 'decoder_srcm.h5'],
-                        [self.decoder_dstm, 'decoder_dstm.h5'] ]
-
+        ar = self.model.get_model_filename_list ( exclude_for_pretrain=(self.pretrain and self.iter != 0) )
         return ar

     #override
@@ -420,30 +487,25 @@ class SAEModel(ModelBase):

     #override
     def onTrainOneIter(self, generators_samples, generators_list):
-        src_samples = generators_samples[0]
-        dst_samples = generators_samples[1]
+        warped_src, target_src, target_srcm = generators_samples[0]
+        warped_dst, target_dst, target_dstm = generators_samples[1]

-        feed = [src_samples[0], dst_samples[0] ] + \
-               src_samples[1:1+self.ms_count*2] + \
-               dst_samples[1:1+self.ms_count*2]
+        feed = [warped_src, warped_dst, target_src, target_srcm, target_dst, target_dstm]

         src_loss, dst_loss, = self.src_dst_train (feed)

         if self.options['learn_mask']:
-            feed = [ src_samples[0], dst_samples[0] ] + \
-                   src_samples[1+self.ms_count:1+self.ms_count*2] + \
-                   dst_samples[1+self.ms_count:1+self.ms_count*2]
+            feed = [ warped_src, warped_dst, target_srcm, target_dstm ]
             src_mask_loss, dst_mask_loss, = self.src_dst_mask_train (feed)

-        return ( ('src_loss', src_loss), ('dst_loss', dst_loss) )
+        return ( ('src_loss', src_loss), ('dst_loss', dst_loss), )

     #override
     def onGetPreview(self, sample):
         test_S = sample[0][1][0:4] #first 4 samples
-        test_S_m = sample[0][1+self.ms_count][0:4] #first 4 samples
+        test_S_m = sample[0][2][0:4] #first 4 samples
         test_D = sample[1][1][0:4]
-        test_D_m = sample[1][1+self.ms_count][0:4]
+        test_D_m = sample[1][2][0:4]

         if self.options['learn_mask']:
             S, D, SS, DD, DDM, SD, SDM = [ np.clip(x, 0.0, 1.0) for x in ([test_S,test_D] + self.AE_view ([test_S, test_D]) ) ]
@@ -453,15 +515,16 @@ class SAEModel(ModelBase):
         result = []
         st = []
-        for i in range(0, len(test_S)):
+        for i in range(len(test_S)):
             ar = S[i], SS[i], D[i], DD[i], SD[i]
             st.append ( np.concatenate ( ar, axis=1) )

         result += [ ('SAE', np.concatenate (st, axis=0 )), ]

         if self.options['learn_mask']:
             st_m = []
-            for i in range(0, len(test_S)):
+            for i in range(len(test_S)):
                 ar = S[i]*test_S_m[i], SS[i], D[i]*test_D_m[i], DD[i]*DDM[i], SD[i]*(DDM[i]*SDM[i])
                 st_m.append ( np.concatenate ( ar, axis=1) )
@@ -491,207 +554,4 @@ class SAEModel(ModelBase):
                                        clip_hborder_mask_per=0.0625 if (self.options['face_type'] == 'f') else 0,
                                       )

-    @staticmethod
-    def initialize_nn_functions():
-        exec (nnlib.import_all(), locals(), globals())
-
-        def NormPass(x):
-            return x
-
-        def Norm(norm=''):
-            if norm == 'bn':
-                return BatchNormalization(axis=-1)
-            else:
-                return NormPass
-
-        def Act(act='', lrelu_alpha=0.1):
-            if act == 'prelu':
-                return PReLU()
-            else:
-                return LeakyReLU(alpha=lrelu_alpha)
-
-        class ResidualBlock(object):
-            def __init__(self, filters, kernel_size=3, padding='zero', norm='', act='', **kwargs):
-                self.filters = filters
-                self.kernel_size = kernel_size
-                self.padding = padding
-                self.norm = norm
-                self.act = act
-
-            def __call__(self, inp):
-                x = inp
-                x = Conv2D(self.filters, kernel_size=self.kernel_size, padding=self.padding)(x)
-                x = Act(self.act, lrelu_alpha=0.2)(x)
-                x = Norm(self.norm)(x)
-                x = Conv2D(self.filters, kernel_size=self.kernel_size, padding=self.padding)(x)
-                x = Add()([x, inp])
-                x = Act(self.act, lrelu_alpha=0.2)(x)
-                x = Norm(self.norm)(x)
-                return x
-        SAEModel.ResidualBlock = ResidualBlock
-
-        def downscale (dim, padding='zero', norm='', act='', **kwargs):
-            def func(x):
-                return Norm(norm)( Act(act) (Conv2D(dim, kernel_size=5, strides=2, padding=padding)(x)) )
-            return func
-        SAEModel.downscale = downscale
-
-        #def downscale (dim, padding='zero', norm='', act='', **kwargs):
-        #    def func(x):
-        #        return BlurPool()( Norm(norm)( Act(act) (Conv2D(dim, kernel_size=5, strides=1, padding=padding)(x)) ) )
-        #    return func
-        #SAEModel.downscale = downscale
-
-        def upscale (dim, padding='zero', norm='', act='', **kwargs):
-            def func(x):
-                return SubpixelUpscaler()(Norm(norm)(Act(act)(Conv2D(dim * 4, kernel_size=3, strides=1, padding=padding)(x))))
-            return func
-        SAEModel.upscale = upscale
-
-        def to_bgr (output_nc, padding='zero', **kwargs):
-            def func(x):
-                return Conv2D(output_nc, kernel_size=5, padding=padding, activation='sigmoid')(x)
-            return func
-        SAEModel.to_bgr = to_bgr
-
-    @staticmethod
-    def LIAEEncFlow(resolution, ch_dims, **kwargs):
-        exec (nnlib.import_all(), locals(), globals())
-        upscale = partial(SAEModel.upscale, **kwargs)
-        downscale = partial(SAEModel.downscale, **kwargs)
-
-        def func(input):
-            dims = K.int_shape(input)[-1]*ch_dims
-
-            x = input
-            x = downscale(dims)(x)
-            x = downscale(dims*2)(x)
-            x = downscale(dims*4)(x)
-            x = downscale(dims*8)(x)
-
-            x = Flatten()(x)
-            return x
-        return func
-
-    @staticmethod
-    def LIAEInterFlow(resolution, ae_dims=256, **kwargs):
-        exec (nnlib.import_all(), locals(), globals())
-        upscale = partial(SAEModel.upscale, **kwargs)
-        lowest_dense_res=resolution // 16
-
-        def func(input):
-            x = input[0]
-            x = Dense(ae_dims)(x)
-            x = Dense(lowest_dense_res * lowest_dense_res * ae_dims*2)(x)
-            x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims*2))(x)
-            x = upscale(ae_dims*2)(x)
-            return x
-        return func
-
-    @staticmethod
-    def LIAEDecFlow(output_nc,ch_dims, multiscale_count=1, add_residual_blocks=False, **kwargs):
-        exec (nnlib.import_all(), locals(), globals())
-        upscale = partial(SAEModel.upscale, **kwargs)
-        to_bgr = partial(SAEModel.to_bgr, **kwargs)
-        dims = output_nc * ch_dims
-        ResidualBlock = partial(SAEModel.ResidualBlock, **kwargs)
-
-        def func(input):
-            x = input[0]
-
-            outputs = []
-            x1 = upscale(dims*8)( x )
-
-            if add_residual_blocks:
-                x1 = ResidualBlock(dims*8)(x1)
-                x1 = ResidualBlock(dims*8)(x1)
-
-            if multiscale_count >= 3:
-                outputs += [ to_bgr(output_nc) ( x1 ) ]
-
-            x2 = upscale(dims*4)( x1 )
-
-            if add_residual_blocks:
-                x2 = ResidualBlock(dims*4)(x2)
-                x2 = ResidualBlock(dims*4)(x2)
-
-            if multiscale_count >= 2:
-                outputs += [ to_bgr(output_nc) ( x2 ) ]
-
-            x3 = upscale(dims*2)( x2 )
-
-            if add_residual_blocks:
-                x3 = ResidualBlock( dims*2)(x3)
-                x3 = ResidualBlock( dims*2)(x3)
-
-            outputs += [ to_bgr(output_nc) ( x3 ) ]
-
-            return outputs
-        return func
-
-    @staticmethod
-    def DFEncFlow(resolution, ae_dims, ch_dims, **kwargs):
-        exec (nnlib.import_all(), locals(), globals())
-        upscale = partial(SAEModel.upscale, **kwargs)
-        downscale = partial(SAEModel.downscale, **kwargs)#, kernel_regularizer=keras.regularizers.l2(0.0),
-        lowest_dense_res = resolution // 16
-
-        def func(input):
-            x = input
-
-            dims = K.int_shape(input)[-1]*ch_dims
-            x = downscale(dims)(x)
-            x = downscale(dims*2)(x)
-            x = downscale(dims*4)(x)
-            x = downscale(dims*8)(x)
-
-            x = Dense(ae_dims)(Flatten()(x))
-            x = Dense(lowest_dense_res * lowest_dense_res * ae_dims)(x)
-            x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims))(x)
-            x = upscale(ae_dims)(x)
-            return x
-        return func
-
-    @staticmethod
-    def DFDecFlow(output_nc, ch_dims, multiscale_count=1, add_residual_blocks=False, **kwargs):
-        exec (nnlib.import_all(), locals(), globals())
-        upscale = partial(SAEModel.upscale, **kwargs)
-        to_bgr = partial(SAEModel.to_bgr, **kwargs)
-        dims = output_nc * ch_dims
-        ResidualBlock = partial(SAEModel.ResidualBlock, **kwargs)
-
-        def func(input):
-            x = input[0]
-
-            outputs = []
-            x1 = upscale(dims*8)( x )
-
-            if add_residual_blocks:
-                x1 = ResidualBlock( dims*8 )(x1)
-                x1 = ResidualBlock( dims*8 )(x1)
-
-            if multiscale_count >= 3:
-                outputs += [ to_bgr(output_nc) ( x1 ) ]
-
-            x2 = upscale(dims*4)( x1 )
-
-            if add_residual_blocks:
-                x2 = ResidualBlock( dims*4)(x2)
-                x2 = ResidualBlock( dims*4)(x2)
-
-            if multiscale_count >= 2:
-                outputs += [ to_bgr(output_nc) ( x2 ) ]
-
-            x3 = upscale(dims*2)( x2 )
-
-            if add_residual_blocks:
-                x3 = ResidualBlock( dims*2)(x3)
-                x3 = ResidualBlock( dims*2)(x3)
-
-            outputs += [ to_bgr(output_nc) ( x3 ) ]
-
-            return outputs
-        return func
-
 Model = SAEModel
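
With the multiscale decoder removed, each sample generator yields exactly three full-resolution tensors (warped, target, target mask), so the training feed is a fixed six-tensor list rather than a variable-length one indexed by ms_count. A sketch of the new feed layout with toy arrays (batch size and resolution here are assumptions):

    import numpy as np

    resolution, batch = 128, 4
    bgr = (batch, resolution, resolution, 3)
    msk = (batch, resolution, resolution, 1)

    # per-side sample layout: [warped, target, target_mask] -> the mask sits
    # at index 2, which is why onGetPreview now reads sample[0][2] / sample[1][2]
    warped_src, target_src, target_srcm = (np.zeros(s) for s in (bgr, bgr, msk))
    warped_dst, target_dst, target_dstm = (np.zeros(s) for s in (bgr, bgr, msk))

    # fixed argument order of the new src_dst_train K.function
    feed = [warped_src, warped_dst, target_src, target_srcm, target_dst, target_dstm]
    print([a.shape for a in feed])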

View file

@@ -56,11 +56,14 @@ Conv2DTranspose = nnlib.Conv2DTranspose
 EqualConv2D = nnlib.EqualConv2D
 SeparableConv2D = KL.SeparableConv2D
 MaxPooling2D = KL.MaxPooling2D
+AveragePooling2D = KL.AveragePooling2D
+GlobalAveragePooling2D = KL.GlobalAveragePooling2D
 UpSampling2D = KL.UpSampling2D
 BatchNormalization = KL.BatchNormalization
 PixelNormalization = nnlib.PixelNormalization

 LeakyReLU = KL.LeakyReLU
+ELU = KL.ELU
 ReLU = KL.ReLU
 PReLU = KL.PReLU
 tanh = KL.Activation('tanh')
@@ -70,6 +73,7 @@ Softmax = KL.Softmax
 Lambda = KL.Lambda
 Add = KL.Add
+Multiply = KL.Multiply
 Concatenate = KL.Concatenate