Mirror of https://github.com/iperov/DeepFaceLab.git, synced 2025-07-06 13:02:15 -07:00
converter: now writes the filename of the current frame's config;
SAE: removed the multiscale decoder, because it is not effective
parent bef4e5d33c
commit b6b92bded0
6 changed files with 325 additions and 450 deletions
@@ -74,14 +74,14 @@ class ConverterConfig(object):
         return False

     #overridable
-    def __str__(self):
+    def to_string(self, filename):
         r = ""
         r += f"sharpen_mode : {self.sharpen_dict[self.sharpen_mode]}\n"
         if self.sharpen_mode != 0:
             r += f"sharpen_amount : {self.sharpen_amount}\n"
         r += f"super_resolution_mode : {self.super_res_dict[self.super_resolution_mode]}\n"
         return r

 mode_dict = {0:'original',
              1:'overlay',
              2:'hist-match',
@@ -249,9 +249,9 @@ class ConverterConfigMasked(ConverterConfig):

         return False

-    def __str__(self):
+    def to_string(self, filename):
         r = (
-            """ConverterConfig:\n"""
+            f"""ConverterConfig {filename}:\n"""
             f"""Mode: {self.mode}\n"""
             )

@@ -276,7 +276,7 @@ class ConverterConfigMasked(ConverterConfig):
         if 'raw' not in self.mode:
             r += f"""color_transfer_mode: { ctm_dict[self.color_transfer_mode]}\n"""

-        r += super().__str__()
+        r += super().to_string(filename)

         if 'raw' not in self.mode:
             r += (f"""color_degrade_power: {self.color_degrade_power}\n"""
@@ -318,8 +318,8 @@ class ConverterConfigFaceAvatar(ConverterConfig):
         return False

     #override
-    def __str__(self):
-        return ("ConverterConfig: \n"
+    def to_string(self, filename):
+        return (f"ConverterConfig {filename}:\n"
                 f"add_source_image : {self.add_source_image}\n") + \
-                super().__str__() + "================"
+                super().to_string(filename) + "================"

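The __str__ -> to_string(filename) rename threads the current frame's filename through every config class, so the converter can label the printed config with the frame it belongs to. A minimal standalone sketch of the pattern (class and field names here are illustrative, not part of the commit):

    # Each subclass appends its own fields after the base header.
    class BaseConfig:
        def to_string(self, filename):
            return f"ConverterConfig {filename}:\n"

    class SharpenConfig(BaseConfig):
        def __init__(self, sharpen_amount=50):
            self.sharpen_amount = sharpen_amount

        def to_string(self, filename):
            return super().to_string(filename) + \
                   f"sharpen_amount : {self.sharpen_amount}\n"

    print(SharpenConfig().to_string("00001.png"))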
@@ -1,6 +1,9 @@
+from pathlib import Path
+
 class FrameInfo(object):
     def __init__(self, filename=None, landmarks_list=None):
         self.filename = filename
+        self.filename_short = str(Path(filename).name)
         self.landmarks_list = landmarks_list or []
         self.motion_deg = 0
         self.motion_power = 0
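filename_short is computed once at construction so per-frame logging does not have to re-split the path each time. Note that the signature still defaults filename to None, and Path(None) raises a TypeError, so callers are expected to always pass a real path. Standalone sketch of the derivation:

    from pathlib import Path

    filename = "/workspace/data_dst/00001.png"   # illustrative path
    filename_short = str(Path(filename).name)    # -> "00001.png"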
@@ -369,7 +369,7 @@ class ConvertSubprocessor(Subprocessor):
             if not cur_frame.is_shown:
                 if cur_frame.is_done:
                     cur_frame.is_shown = True
-                    io.log_info (cur_frame.cfg)
+                    io.log_info (cur_frame.cfg.to_string( cur_frame.frame_info.filename_short) )

                     if cur_frame.image is None:
                         cur_frame.image = cv2_imread ( cur_frame.output_filename)
@@ -464,7 +464,7 @@ class ConvertSubprocessor(Subprocessor):
                 cfg.toggle_sharpen_mode()

             if prev_cfg != cfg:
-                io.log_info (cfg)
+                io.log_info ( cfg.to_string(cur_frame.frame_info.filename_short) )
                 cur_frame.is_done = False
                 cur_frame.is_shown = False
         else:
@@ -204,6 +204,7 @@ class ModelBase(object):

         if self.sample_for_preview is None or choose_preview_history:
             if choose_preview_history and io.is_support_windows():
+                io.log_info ("Choose image for the preview history. [p] - next. [enter] - confirm.")
                 wnd_name = "[p] - next. [enter] - confirm."
                 io.named_window(wnd_name)
                 io.capture_keys(wnd_name)
@@ -411,11 +412,17 @@ class ModelBase(object):
         cv2_imwrite (filepath, img )

     def load_weights_safe(self, model_filename_list, optimizer_filename_list=[]):
-        for model, filename in model_filename_list:
+        loaded = []
+        not_loaded = []
+        for mf in model_filename_list:
+            model, filename = mf
             filename = self.get_strpath_storage_for_file(filename)
             if Path(filename).exists():
+                loaded += [ mf ]
                 model.load_weights(filename)
+            else:
+                not_loaded += [ mf ]

         if len(optimizer_filename_list) != 0:
             opt_filename = self.get_strpath_storage_for_file('opt.h5')
             if Path(opt_filename).exists():
@@ -432,7 +439,8 @@ class ModelBase(object):
                         print("set ok")
                 except Exception as e:
                     print ("Unable to load ", opt_filename)

+        return loaded, not_loaded

     def save_weights_safe(self, model_filename_list):
         for model, filename in model_filename_list:
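load_weights_safe now reports which (model, filename) pairs were actually restored from disk, so callers can post-process only the sub-models that are still at their fresh initialization — the SAE changes below use not_loaded to pick the Conv2D kernels that get CA initialization. A minimal sketch of the contract (stand-in storage path, not the real ModelBase internals):

    from pathlib import Path

    def load_weights_safe(model_filename_list, weights_dir="model"):
        loaded, not_loaded = [], []
        for mf in model_filename_list:
            model, filename = mf
            path = Path(weights_dir) / filename
            if path.exists():
                model.load_weights(str(path))
                loaded += [mf]
            else:
                not_loaded += [mf]   # stays freshly initialized
        return loaded, not_loaded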
@@ -1,26 +1,18 @@
 from functools import partial

 import numpy as np

-from nnlib import nnlib
+import mathlib
-from models import ModelBase
 from facelib import FaceType
-from samplelib import *
 from interact import interact as io
+from models import ModelBase
+from nnlib import nnlib
+from samplelib import *


 #SAE - Styled AutoEncoder
 class SAEModel(ModelBase):

-    encoderH5 = 'encoder.h5'
-    inter_BH5 = 'inter_B.h5'
-    inter_ABH5 = 'inter_AB.h5'
-    decoderH5 = 'decoder.h5'
-    decodermH5 = 'decoderm.h5'
-
-    decoder_srcH5 = 'decoder_src.h5'
-    decoder_srcmH5 = 'decoder_srcm.h5'
-    decoder_dstH5 = 'decoder_dst.h5'
-    decoder_dstmH5 = 'decoder_dstm.h5'
-
     #override
     def onInitializeOptions(self, is_first_run, ask_override):
         yn_str = {True:'y',False:'n'}
@@ -28,6 +20,7 @@ class SAEModel(ModelBase):
         default_resolution = 128
         default_archi = 'df'
         default_face_type = 'f'
+        default_learn_mask = True

         if is_first_run:
             resolution = io.input_int("Resolution ( 64-256 ?:help skip:128) : ", default_resolution, help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 16.")
@@ -37,12 +30,11 @@ class SAEModel(ModelBase):
             self.options['resolution'] = resolution

             self.options['face_type'] = io.input_str ("Half or Full face? (h/f, ?:help skip:f) : ", default_face_type, ['h','f'], help_message="Half face has better resolution, but covers less area of cheeks.").lower()
-            self.options['learn_mask'] = io.input_bool ("Learn mask? (y/n, ?:help skip:y) : ", True, help_message="Learning mask can help model to recognize face directions. Learn without mask can reduce model size, in this case converter forced to use 'not predicted mask' that is not smooth as predicted. Model with style values can be learned without mask and produce same quality result.")
+            self.options['learn_mask'] = io.input_bool ( f"Learn mask? (y/n, ?:help skip:{yn_str[default_learn_mask]} ) : " , default_learn_mask, help_message="Learning mask can help model to recognize face directions. Learn without mask can reduce model size, in this case converter forced to use 'not predicted mask' that is not smooth as predicted. Model with style values can be learned without mask and produce same quality result.")
         else:
             self.options['resolution'] = self.options.get('resolution', default_resolution)
             self.options['face_type'] = self.options.get('face_type', default_face_type)
-            self.options['learn_mask'] = self.options.get('learn_mask', True)
+            self.options['learn_mask'] = self.options.get('learn_mask', default_learn_mask)


         if (is_first_run or ask_override) and 'tensorflow' in self.device_config.backend:
             def_optimizer_mode = self.options.get('optimizer_mode', 1)
@@ -59,26 +51,24 @@ class SAEModel(ModelBase):
         default_e_ch_dims = 42
         default_d_ch_dims = default_e_ch_dims // 2
         def_ca_weights = False

         if is_first_run:
             self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dims (32-1024 ?:help skip:%d) : " % (default_ae_dims) , default_ae_dims, help_message="All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 )
             self.options['e_ch_dims'] = np.clip ( io.input_int("Encoder dims per channel (21-85 ?:help skip:%d) : " % (default_e_ch_dims) , default_e_ch_dims, help_message="More encoder dims help to recognize more facial features, but require more VRAM. You can fine-tune model size to fit your GPU." ), 21, 85 )
             default_d_ch_dims = self.options['e_ch_dims'] // 2
             self.options['d_ch_dims'] = np.clip ( io.input_int("Decoder dims per channel (10-85 ?:help skip:%d) : " % (default_d_ch_dims) , default_d_ch_dims, help_message="More decoder dims help to get better details, but require more VRAM. You can fine-tune model size to fit your GPU." ), 10, 85 )
-            self.options['multiscale_decoder'] = io.input_bool ("Use multiscale decoder? (y/n, ?:help skip:n) : ", False, help_message="Multiscale decoder helps to get better details.")
-            self.options['ca_weights'] = io.input_bool ("Use CA weights? (y/n, ?:help skip: %s ) : " % (yn_str[def_ca_weights]), def_ca_weights, help_message="Initialize network with 'Convolution Aware' weights. This may help to achieve a higher accuracy model, but consumes a time at first run.")
+            self.options['ca_weights'] = io.input_bool (f"Use CA weights? (y/n, ?:help skip:{yn_str[def_ca_weights]} ) : ", def_ca_weights, help_message="Initialize network with 'Convolution Aware' weights. This may help to achieve a higher accuracy model, but consumes a time at first run.")
         else:
             self.options['ae_dims'] = self.options.get('ae_dims', default_ae_dims)
             self.options['e_ch_dims'] = self.options.get('e_ch_dims', default_e_ch_dims)
             self.options['d_ch_dims'] = self.options.get('d_ch_dims', default_d_ch_dims)
-            self.options['multiscale_decoder'] = self.options.get('multiscale_decoder', False)
             self.options['ca_weights'] = self.options.get('ca_weights', def_ca_weights)

         default_face_style_power = 0.0
         default_bg_style_power = 0.0
         if is_first_run or ask_override:
             def_pixel_loss = self.options.get('pixel_loss', False)
-            self.options['pixel_loss'] = io.input_bool ("Use pixel loss? (y/n, ?:help skip: %s ) : " % (yn_str[def_pixel_loss]), def_pixel_loss, help_message="Pixel loss may help to enhance fine details and stabilize face color. Use it only if quality does not improve over time. Enabling this option too early increases the chance of model collapse.")
+            self.options['pixel_loss'] = io.input_bool (f"Use pixel loss? (y/n, ?:help skip:{yn_str[def_pixel_loss]} ) : ", def_pixel_loss, help_message="Pixel loss may help to enhance fine details and stabilize face color. Use it only if quality does not improve over time. Enabling this option too early increases the chance of model collapse.")

             default_face_style_power = default_face_style_power if is_first_run else self.options.get('face_style_power', default_face_style_power)
             self.options['face_style_power'] = np.clip ( io.input_number("Face style power ( 0.0 .. 100.0 ?:help skip:%.2f) : " % (default_face_style_power), default_face_style_power,
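The prompts in this hunk move from %-formatting to f-strings, with yn_str rendering the boolean default that pressing enter ("skip") selects. The pattern in isolation (input_bool below is a stand-in for DeepFaceLab's interact.io.input_bool, reduced to the default-on-empty behaviour):

    yn_str = {True: 'y', False: 'n'}

    def input_bool(prompt, default, help_message=None):
        # empty input keeps the default shown after "skip:"
        s = input(prompt).strip().lower()
        return default if s == '' else s.startswith('y')

    def_ca_weights = False
    ca_weights = input_bool(
        f"Use CA weights? (y/n, ?:help skip:{yn_str[def_ca_weights]} ) : ",
        def_ca_weights)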
@@ -89,11 +79,11 @@ class SAEModel(ModelBase):
                                                                help_message="Learn to transfer image around face. This can make face more like dst. Enabling this option increases the chance of model collapse."), 0.0, 100.0 )

             default_apply_random_ct = False if is_first_run else self.options.get('apply_random_ct', False)
-            self.options['apply_random_ct'] = io.input_bool ("Apply random color transfer to src faceset? (y/n, ?:help skip:%s) : " % (yn_str[default_apply_random_ct]), default_apply_random_ct, help_message="Increase variativity of src samples by apply LCT color transfer from random dst samples. It is like 'face_style' learning, but more precise color transfer and without risk of model collapse, also it does not require additional GPU resources, but the training time may be longer, due to the src faceset is becoming more diverse.")
+            self.options['apply_random_ct'] = io.input_bool (f"Apply random color transfer to src faceset? (y/n, ?:help skip:{yn_str[default_apply_random_ct]}) : ", default_apply_random_ct, help_message="Increase variativity of src samples by apply LCT color transfer from random dst samples. It is like 'face_style' learning, but more precise color transfer and without risk of model collapse, also it does not require additional GPU resources, but the training time may be longer, due to the src faceset is becoming more diverse.")

             if nnlib.device.backend != 'plaidML': # todo https://github.com/plaidml/plaidml/issues/301
                 default_clipgrad = False if is_first_run else self.options.get('clipgrad', False)
-                self.options['clipgrad'] = io.input_bool ("Enable gradient clipping? (y/n, ?:help skip:%s) : " % (yn_str[default_clipgrad]), default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.")
+                self.options['clipgrad'] = io.input_bool (f"Enable gradient clipping? (y/n, ?:help skip:{yn_str[default_clipgrad]}) : ", default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.")
             else:
                 self.options['clipgrad'] = False

@@ -112,10 +102,11 @@ class SAEModel(ModelBase):
     #override
     def onInitialize(self):
        exec(nnlib.import_all(), locals(), globals())
-        SAEModel.initialize_nn_functions()
        self.set_vram_batch_requirements({1.5:4})

        resolution = self.options['resolution']
+        learn_mask = self.options['learn_mask']
+
        ae_dims = self.options['ae_dims']
        e_ch_dims = self.options['e_ch_dims']
        d_ch_dims = self.options['d_ch_dims']
@@ -127,232 +118,333 @@ class SAEModel(ModelBase):
        bgr_shape = (resolution, resolution, 3)
        mask_shape = (resolution, resolution, 1)

-        self.ms_count = ms_count = 3 if (self.options['multiscale_decoder']) else 1

        apply_random_ct = self.options.get('apply_random_ct', False)
        masked_training = True

-        warped_src = Input(bgr_shape)
-        target_src = Input(bgr_shape)
-        target_srcm = Input(mask_shape)
+        class SAEDFModel(object):
+            def __init__(self, resolution, ae_dims, e_ch_dims, d_ch_dims, learn_mask):
+                super().__init__()
+                self.learn_mask = learn_mask

-        warped_dst = Input(bgr_shape)
-        target_dst = Input(bgr_shape)
-        target_dstm = Input(mask_shape)
+                output_nc = 3
+                bgr_shape = (resolution, resolution, output_nc)
+                mask_shape = (resolution, resolution, 1)
+                lowest_dense_res = resolution // 16
+                e_dims = output_nc*e_ch_dims

-        target_src_ar = [ Input ( ( bgr_shape[0] // (2**i) ,)*2 + (bgr_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
-        target_srcm_ar = [ Input ( ( mask_shape[0] // (2**i) ,)*2 + (mask_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
-        target_dst_ar = [ Input ( ( bgr_shape[0] // (2**i) ,)*2 + (bgr_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
-        target_dstm_ar = [ Input ( ( mask_shape[0] // (2**i) ,)*2 + (mask_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
+                def upscale (dim):
+                    def func(x):
+                        return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, kernel_size=3, strides=1, padding='same')(x)))
+                    return func

-        common_flow_kwargs = { 'padding': 'zero',
-                               'norm': '',
-                               'act':'' }
-        models_list = []
-        weights_to_load = []
-        if 'liae' in self.options['archi']:
-            self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape))
+                def enc_flow(e_dims, ae_dims, lowest_dense_res):
+                    def func(x):
+                        x = LeakyReLU(0.1)(Conv2D(e_dims, kernel_size=5, strides=2, padding='same')(x))
+                        x = LeakyReLU(0.1)(Conv2D(e_dims*2, kernel_size=5, strides=2, padding='same')(x))
+                        x = LeakyReLU(0.1)(Conv2D(e_dims*4, kernel_size=5, strides=2, padding='same')(x))
+                        x = LeakyReLU(0.1)(Conv2D(e_dims*8, kernel_size=5, strides=2, padding='same')(x))

-            enc_output_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
+                        x = Dense(ae_dims)(Flatten()(x))
+                        x = Dense(lowest_dense_res * lowest_dense_res * ae_dims)(x)
+                        x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims))(x)
+                        x = upscale(ae_dims)(x)
+                        return x
+                    return func

-            self.inter_B = modelify(SAEModel.LIAEInterFlow(resolution, ae_dims=ae_dims, **common_flow_kwargs)) (enc_output_Inputs)
-            self.inter_AB = modelify(SAEModel.LIAEInterFlow(resolution, ae_dims=ae_dims, **common_flow_kwargs)) (enc_output_Inputs)
+                def dec_flow(output_nc, d_ch_dims):
+                    def ResidualBlock(dim):
+                        def func(inp):
+                            x = Conv2D(dim, kernel_size=3, padding='same')(inp)
+                            x = LeakyReLU(0.2)(x)
+                            x = Conv2D(dim, kernel_size=3, padding='same')(x)
+                            x = Add()([x, inp])
+                            x = LeakyReLU(0.2)(x)
+                            return x
+                        return func

-            inter_output_Inputs = [ Input( np.array(K.int_shape(x)[1:])*(1,1,2) ) for x in self.inter_B.outputs ]
+                    def func(x):
+                        dims = output_nc * d_ch_dims
+                        x = upscale(dims*8)(x)
+                        x = ResidualBlock(dims*8)(x)
+                        x = ResidualBlock(dims*8)(x)

-            self.decoder = modelify(SAEModel.LIAEDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs)) (inter_output_Inputs)
-            models_list += [self.encoder, self.inter_B, self.inter_AB, self.decoder]
+                        x = upscale(dims*4)(x)
+                        x = ResidualBlock(dims*4)(x)
+                        x = ResidualBlock(dims*4)(x)

-            if self.options['learn_mask']:
-                self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs)) (inter_output_Inputs)
-                models_list += [self.decoderm]
+                        x = upscale(dims*2)(x)
+                        x = ResidualBlock(dims*2)(x)
+                        x = ResidualBlock(dims*2)(x)

-            if not self.is_first_run():
-                weights_to_load += [ [self.encoder , 'encoder.h5'],
-                                     [self.inter_B , 'inter_B.h5'],
-                                     [self.inter_AB, 'inter_AB.h5'],
-                                     [self.decoder , 'decoder.h5'],
-                                   ]
-                if self.options['learn_mask']:
-                    weights_to_load += [ [self.decoderm, 'decoderm.h5'] ]
+                        return Conv2D(output_nc, kernel_size=5, padding='same', activation='sigmoid')(x)
+                    return func

-            warped_src_code = self.encoder (warped_src)
-            warped_src_inter_AB_code = self.inter_AB (warped_src_code)
-            warped_src_inter_code = Concatenate()([warped_src_inter_AB_code,warped_src_inter_AB_code])
+                self.encoder = modelify(enc_flow(e_dims, ae_dims, lowest_dense_res)) ( Input(bgr_shape) )

-            warped_dst_code = self.encoder (warped_dst)
-            warped_dst_inter_B_code = self.inter_B (warped_dst_code)
-            warped_dst_inter_AB_code = self.inter_AB (warped_dst_code)
-            warped_dst_inter_code = Concatenate()([warped_dst_inter_B_code,warped_dst_inter_AB_code])
+                sh = K.int_shape( self.encoder.outputs[0] )[1:]
+                self.decoder_src = modelify(dec_flow(output_nc, d_ch_dims)) ( Input(sh) )
+                self.decoder_dst = modelify(dec_flow(output_nc, d_ch_dims)) ( Input(sh) )

-            warped_src_dst_inter_code = Concatenate()([warped_dst_inter_AB_code,warped_dst_inter_AB_code])
+                if learn_mask:
+                    self.decoder_srcm = modelify(dec_flow(1, d_ch_dims)) ( Input(sh) )
+                    self.decoder_dstm = modelify(dec_flow(1, d_ch_dims)) ( Input(sh) )

-            pred_src_src = self.decoder(warped_src_inter_code)
-            pred_dst_dst = self.decoder(warped_dst_inter_code)
-            pred_src_dst = self.decoder(warped_src_dst_inter_code)
+                self.src_dst_trainable_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights + self.decoder_dst.trainable_weights

-            if self.options['learn_mask']:
-                pred_src_srcm = self.decoderm(warped_src_inter_code)
-                pred_dst_dstm = self.decoderm(warped_dst_inter_code)
-                pred_src_dstm = self.decoderm(warped_src_dst_inter_code)
+                if learn_mask:
+                    self.src_dst_mask_trainable_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights + self.decoder_dstm.trainable_weights

-        elif 'df' in self.options['archi']:
-            self.encoder = modelify(SAEModel.DFEncFlow(resolution, ae_dims=ae_dims, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape))
+                self.warped_src, self.warped_dst = Input(bgr_shape), Input(bgr_shape)
+                src_code, dst_code = self.encoder(self.warped_src), self.encoder(self.warped_dst)

-            dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
+                self.pred_src_src = self.decoder_src(src_code)
+                self.pred_dst_dst = self.decoder_dst(dst_code)
+                self.pred_src_dst = self.decoder_src(dst_code)

-            self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs )) (dec_Inputs)
-            self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs )) (dec_Inputs)
-            models_list += [self.encoder, self.decoder_src, self.decoder_dst]
+                if learn_mask:
+                    self.pred_src_srcm = self.decoder_srcm(src_code)
+                    self.pred_dst_dstm = self.decoder_dstm(dst_code)
+                    self.pred_src_dstm = self.decoder_srcm(dst_code)

-            if self.options['learn_mask']:
-                self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs )) (dec_Inputs)
-                self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs )) (dec_Inputs)
-                models_list += [self.decoder_srcm, self.decoder_dstm]
+            def get_model_filename_list(self, exclude_for_pretrain=False):
+                ar = []
+                if not exclude_for_pretrain:
+                    ar += [ [self.encoder, 'encoder.h5'] ]
+                ar += [ [self.decoder_src, 'decoder_src.h5'],
+                        [self.decoder_dst, 'decoder_dst.h5'] ]
+                if self.learn_mask:
+                    ar += [ [self.decoder_srcm, 'decoder_srcm.h5'],
+                            [self.decoder_dstm, 'decoder_dstm.h5'] ]
+                return ar

-            if not self.is_first_run():
-                weights_to_load += [ [self.encoder , 'encoder.h5'],
-                                     [self.decoder_src, 'decoder_src.h5'],
-                                     [self.decoder_dst, 'decoder_dst.h5']
-                                   ]
-                if self.options['learn_mask']:
-                    weights_to_load += [ [self.decoder_srcm, 'decoder_srcm.h5'],
-                                         [self.decoder_dstm, 'decoder_dstm.h5'],
-                                       ]
+        class SAELIAEModel(object):
+            def __init__(self, resolution, ae_dims, e_ch_dims, d_ch_dims, learn_mask):
+                super().__init__()
+                self.learn_mask = learn_mask

-            warped_src_code = self.encoder (warped_src)
-            warped_dst_code = self.encoder (warped_dst)
-            pred_src_src = self.decoder_src(warped_src_code)
-            pred_dst_dst = self.decoder_dst(warped_dst_code)
-            pred_src_dst = self.decoder_src(warped_dst_code)
+                output_nc = 3
+                bgr_shape = (resolution, resolution, output_nc)
+                mask_shape = (resolution, resolution, 1)

-            if self.options['learn_mask']:
-                pred_src_srcm = self.decoder_srcm(warped_src_code)
-                pred_dst_dstm = self.decoder_dstm(warped_dst_code)
-                pred_src_dstm = self.decoder_srcm(warped_dst_code)
+                e_dims = output_nc*e_ch_dims
+                d_dims = output_nc*d_ch_dims
+                lowest_dense_res = resolution // 16

-        if self.is_first_run():
-            if self.options.get('ca_weights',False):
-                conv_weights_list = []
-                for model in models_list:
-                    for layer in model.layers:
-                        if type(layer) == keras.layers.Conv2D:
-                            conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights
-                CAInitializerMP ( conv_weights_list )
-        else:
-            self.load_weights_safe(weights_to_load)
+                def upscale (dim):
+                    def func(x):
+                        return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, kernel_size=3, strides=1, padding='same')(x)))
+                    return func

-        pred_src_src, pred_dst_dst, pred_src_dst, = [ [x] if type(x) != list else x for x in [pred_src_src, pred_dst_dst, pred_src_dst, ] ]
+                def enc_flow(e_dims):
+                    def func(x):
+                        x = LeakyReLU(0.1)(Conv2D(e_dims, kernel_size=5, strides=2, padding='same')(x))
+                        x = LeakyReLU(0.1)(Conv2D(e_dims*2, kernel_size=5, strides=2, padding='same')(x))
+                        x = LeakyReLU(0.1)(Conv2D(e_dims*4, kernel_size=5, strides=2, padding='same')(x))
+                        x = LeakyReLU(0.1)(Conv2D(e_dims*8, kernel_size=5, strides=2, padding='same')(x))
+                        x = Flatten()(x)
+                        return x
+                    return func

-        if self.options['learn_mask']:
-            pred_src_srcm, pred_dst_dstm, pred_src_dstm = [ [x] if type(x) != list else x for x in [pred_src_srcm, pred_dst_dstm, pred_src_dstm] ]
+                def inter_flow(lowest_dense_res, ae_dims):
+                    def func(x):
+                        x = Dense(ae_dims)(x)
+                        x = Dense(lowest_dense_res * lowest_dense_res * ae_dims*2)(x)
+                        x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims*2))(x)
+                        x = upscale(ae_dims*2)(x)
+                        return x
+                    return func

-        target_srcm_blurred_ar = [ gaussian_blur( max(1, K.int_shape(x)[1] // 32) )(x) for x in target_srcm_ar]
-        target_srcm_sigm_ar = target_srcm_blurred_ar #[ x / 2.0 + 0.5 for x in target_srcm_blurred_ar]
-        target_srcm_anti_sigm_ar = [ 1.0 - x for x in target_srcm_sigm_ar]
+                def dec_flow(output_nc, d_dims):
+                    def ResidualBlock(dim):
+                        def func(inp):
+                            x = Conv2D(dim, kernel_size=3, padding='same')(inp)
+                            x = LeakyReLU(0.2)(x)
+                            x = Conv2D(dim, kernel_size=3, padding='same')(x)
+                            x = Add()([x, inp])
+                            x = LeakyReLU(0.2)(x)
+                            return x
+                        return func

-        target_dstm_blurred_ar = [ gaussian_blur( max(1, K.int_shape(x)[1] // 32) )(x) for x in target_dstm_ar]
-        target_dstm_sigm_ar = target_dstm_blurred_ar#[ x / 2.0 + 0.5 for x in target_dstm_blurred_ar]
-        target_dstm_anti_sigm_ar = [ 1.0 - x for x in target_dstm_sigm_ar]
+                    def func(x):
+                        x = upscale(d_dims*8)(x)
+                        x = ResidualBlock(d_dims*8)(x)
+                        x = ResidualBlock(d_dims*8)(x)

-        target_src_sigm_ar = target_src_ar#[ x + 1 for x in target_src_ar]
-        target_dst_sigm_ar = target_dst_ar#[ x + 1 for x in target_dst_ar]
+                        x = upscale(d_dims*4)(x)
+                        x = ResidualBlock(d_dims*4)(x)
+                        x = ResidualBlock(d_dims*4)(x)

-        pred_src_src_sigm_ar = pred_src_src#[ x + 1 for x in pred_src_src]
-        pred_dst_dst_sigm_ar = pred_dst_dst#[ x + 1 for x in pred_dst_dst]
-        pred_src_dst_sigm_ar = pred_src_dst#[ x + 1 for x in pred_src_dst]
+                        x = upscale(d_dims*2)(x)
+                        x = ResidualBlock(d_dims*2)(x)
+                        x = ResidualBlock(d_dims*2)(x)

-        target_src_masked_ar = [ target_src_sigm_ar[i]*target_srcm_sigm_ar[i] for i in range(len(target_src_sigm_ar))]
-        target_dst_masked_ar = [ target_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(target_dst_sigm_ar))]
-        target_dst_anti_masked_ar = [ target_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(target_dst_sigm_ar))]
+                        return Conv2D(output_nc, kernel_size=5, padding='same', activation='sigmoid')(x)
+                    return func

-        pred_src_src_masked_ar = [ pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] for i in range(len(pred_src_src_sigm_ar))]
-        pred_dst_dst_masked_ar = [ pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] for i in range(len(pred_dst_dst_sigm_ar))]
+                self.encoder = modelify(enc_flow(e_dims)) ( Input(bgr_shape) )

-        target_src_masked_ar_opt = target_src_masked_ar if masked_training else target_src_sigm_ar
-        target_dst_masked_ar_opt = target_dst_masked_ar if masked_training else target_dst_sigm_ar
+                sh = K.int_shape( self.encoder.outputs[0] )[1:]
+                self.inter_B = modelify(inter_flow(lowest_dense_res, ae_dims)) ( Input(sh) )
+                self.inter_AB = modelify(inter_flow(lowest_dense_res, ae_dims)) ( Input(sh) )

-        pred_src_src_masked_ar_opt = pred_src_src_masked_ar if masked_training else pred_src_src_sigm_ar
-        pred_dst_dst_masked_ar_opt = pred_dst_dst_masked_ar if masked_training else pred_dst_dst_sigm_ar
+                sh = np.array(K.int_shape( self.inter_B.outputs[0] )[1:])*(1,1,2)
+                self.decoder = modelify(dec_flow(output_nc, d_dims)) ( Input(sh) )

-        psd_target_dst_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
-        psd_target_dst_anti_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
+                if learn_mask:
+                    self.decoderm = modelify(dec_flow(1, d_dims)) ( Input(sh) )

+                self.src_dst_trainable_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
+
+                if learn_mask:
+                    self.src_dst_mask_trainable_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
+
+                self.warped_src, self.warped_dst = Input(bgr_shape), Input(bgr_shape)
+
+                warped_src_code = self.encoder (self.warped_src)
+                warped_src_inter_AB_code = self.inter_AB (warped_src_code)
+                warped_src_inter_code = Concatenate()([warped_src_inter_AB_code,warped_src_inter_AB_code])
+
+                warped_dst_code = self.encoder (self.warped_dst)
+                warped_dst_inter_B_code = self.inter_B (warped_dst_code)
+                warped_dst_inter_AB_code = self.inter_AB (warped_dst_code)
+                warped_dst_inter_code = Concatenate()([warped_dst_inter_B_code,warped_dst_inter_AB_code])
+
+                warped_src_dst_inter_code = Concatenate()([warped_dst_inter_AB_code,warped_dst_inter_AB_code])
+
+                self.pred_src_src = self.decoder(warped_src_inter_code)
+                self.pred_dst_dst = self.decoder(warped_dst_inter_code)
+                self.pred_src_dst = self.decoder(warped_src_dst_inter_code)
+
+                if learn_mask:
+                    self.pred_src_srcm = self.decoderm(warped_src_inter_code)
+                    self.pred_dst_dstm = self.decoderm(warped_dst_inter_code)
+                    self.pred_src_dstm = self.decoderm(warped_src_dst_inter_code)
+
+            def get_model_filename_list(self, exclude_for_pretrain=False):
+                ar = [ [self.encoder, 'encoder.h5'],
+                       [self.inter_B, 'inter_B.h5'] ]
+
+                if not exclude_for_pretrain:
+                    ar += [ [self.inter_AB, 'inter_AB.h5'] ]
+
+                ar += [ [self.decoder, 'decoder.h5'] ]
+
+                if self.learn_mask:
+                    ar += [ [self.decoderm, 'decoderm.h5'] ]
+
+                return ar
+
+        if 'df' in self.options['archi']:
+            self.model = SAEDFModel (resolution, ae_dims, e_ch_dims, d_ch_dims, learn_mask)
+        elif 'liae' in self.options['archi']:
+            self.model = SAELIAEModel (resolution, ae_dims, e_ch_dims, d_ch_dims, learn_mask)
+
+        loaded, not_loaded = [], self.model.get_model_filename_list()
+        if not self.is_first_run():
+            loaded, not_loaded = self.load_weights_safe(not_loaded)
+
+        CA_models = []
+        if self.options.get('ca_weights', False):
+            CA_models += [ model for model, _ in not_loaded ]
+
+        CA_conv_weights_list = []
+        for model in CA_models:
+            for layer in model.layers:
+                if type(layer) == keras.layers.Conv2D:
+                    CA_conv_weights_list += [layer.weights[0]] #- is Conv2D kernel_weights
+
+        if len(CA_conv_weights_list) != 0:
+            CAInitializerMP ( CA_conv_weights_list )
+
+        warped_src = self.model.warped_src
+        target_src = Input ( (resolution, resolution, 3) )
+        target_srcm = Input ( (resolution, resolution, 1) )
+
+        warped_dst = self.model.warped_dst
+        target_dst = Input ( (resolution, resolution, 3) )
+        target_dstm = Input ( (resolution, resolution, 1) )
+
+        target_src_sigm = target_src
+        target_dst_sigm = target_dst
+
+        target_srcm_sigm = gaussian_blur( max(1, K.int_shape(target_srcm)[1] // 32) )(target_srcm)
+        target_dstm_sigm = gaussian_blur( max(1, K.int_shape(target_dstm)[1] // 32) )(target_dstm)
+        target_dstm_anti_sigm = 1.0 - target_dstm_sigm
+
+        target_src_masked = target_src_sigm*target_srcm_sigm
+        target_dst_masked = target_dst_sigm*target_dstm_sigm
+        target_dst_anti_masked = target_dst_sigm*target_dstm_anti_sigm
+
+        target_src_masked_opt = target_src_masked if masked_training else target_src_sigm
+        target_dst_masked_opt = target_dst_masked if masked_training else target_dst_sigm
+
+        pred_src_src = self.model.pred_src_src
+        pred_dst_dst = self.model.pred_dst_dst
+        pred_src_dst = self.model.pred_src_dst
+        if learn_mask:
+            pred_src_srcm = self.model.pred_src_srcm
+            pred_dst_dstm = self.model.pred_dst_dstm
+            pred_src_dstm = self.model.pred_src_dstm
+
+        pred_src_src_sigm = self.model.pred_src_src
+        pred_dst_dst_sigm = self.model.pred_dst_dst
+        pred_src_dst_sigm = self.model.pred_src_dst
+
+        pred_src_src_masked = pred_src_src_sigm*target_srcm_sigm
+        pred_dst_dst_masked = pred_dst_dst_sigm*target_dstm_sigm
+
+        pred_src_src_masked_opt = pred_src_src_masked if masked_training else pred_src_src_sigm
+        pred_dst_dst_masked_opt = pred_dst_dst_masked if masked_training else pred_dst_dst_sigm
+
+        psd_target_dst_masked = pred_src_dst_sigm*target_dstm_sigm
+        psd_target_dst_anti_masked = pred_src_dst_sigm*target_dstm_anti_sigm
+
        if self.is_training_mode:
            self.src_dst_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
            self.src_dst_mask_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
+            self.sr_opt = Adam(lr=5e-5, beta_1=0.9, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)

-            if 'liae' in self.options['archi']:
-                src_dst_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
-                if self.options['learn_mask']:
-                    src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
-            else:
-                src_dst_loss_train_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights + self.decoder_dst.trainable_weights
-                if self.options['learn_mask']:
-                    src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights + self.decoder_dstm.trainable_weights

            if not self.options['pixel_loss']:
-                src_loss_batch = sum([ 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_src_masked_ar_opt[i], pred_src_src_masked_ar_opt[i]) for i in range(len(target_src_masked_ar_opt)) ])
+                src_loss = K.mean ( 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_src_masked_opt, pred_src_src_masked_opt) )
            else:
-                src_loss_batch = sum([ K.mean ( 50*K.square( target_src_masked_ar_opt[i] - pred_src_src_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_src_masked_ar_opt)) ])
+                src_loss = K.mean ( 50*K.square( target_src_masked_opt - pred_src_src_masked_opt ) )

-            src_loss = K.mean(src_loss_batch)
-
-            face_style_power = self.options['face_style_power'] / 100.0

+            face_style_power = self.options['face_style_power'] / 100.0
            if face_style_power != 0:
-                src_loss += style_loss(gaussian_blur_radius=resolution//16, loss_weight=face_style_power, wnd_size=0)( psd_target_dst_masked_ar[-1], target_dst_masked_ar[-1] )
+                src_loss += style_loss(gaussian_blur_radius=resolution//16, loss_weight=face_style_power, wnd_size=0)( psd_target_dst_masked, target_dst_masked )

            bg_style_power = self.options['bg_style_power'] / 100.0
            if bg_style_power != 0:
                if not self.options['pixel_loss']:
-                    bg_loss = K.mean( (10*bg_style_power)*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] ))
+                    src_loss += K.mean( (10*bg_style_power)*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( psd_target_dst_anti_masked, target_dst_anti_masked ))
                else:
-                    bg_loss = K.mean( (50*bg_style_power)*K.square( psd_target_dst_anti_masked_ar[-1] - target_dst_anti_masked_ar[-1] ))
-                src_loss += bg_loss
+                    src_loss += K.mean( (50*bg_style_power)*K.square( psd_target_dst_anti_masked - target_dst_anti_masked ))

            if not self.options['pixel_loss']:
-                dst_loss_batch = sum([ 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)(target_dst_masked_ar_opt[i], pred_dst_dst_masked_ar_opt[i]) for i in range(len(target_dst_masked_ar_opt)) ])
+                dst_loss = K.mean( 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)(target_dst_masked_opt, pred_dst_dst_masked_opt) )
            else:
-                dst_loss_batch = sum([ K.mean ( 50*K.square( target_dst_masked_ar_opt[i] - pred_dst_dst_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_dst_masked_ar_opt)) ])
+                dst_loss = K.mean( 50*K.square( target_dst_masked_opt - pred_dst_dst_masked_opt ) )

-            dst_loss = K.mean(dst_loss_batch)
-
-            feed = [warped_src, warped_dst]
-            feed += target_src_ar[::-1]
-            feed += target_srcm_ar[::-1]
-            feed += target_dst_ar[::-1]
-            feed += target_dstm_ar[::-1]
-
-            self.src_dst_train = K.function (feed,[src_loss,dst_loss], self.src_dst_opt.get_updates(src_loss+dst_loss, src_dst_loss_train_weights) )
+            self.src_dst_train = K.function ([warped_src, warped_dst, target_src, target_srcm, target_dst, target_dstm],[src_loss,dst_loss], self.src_dst_opt.get_updates(src_loss+dst_loss, self.model.src_dst_trainable_weights) )

            if self.options['learn_mask']:
-                src_mask_loss = sum([ K.mean(K.square(target_srcm_ar[-1]-pred_src_srcm[-1])) for i in range(len(target_srcm_ar)) ])
-                dst_mask_loss = sum([ K.mean(K.square(target_dstm_ar[-1]-pred_dst_dstm[-1])) for i in range(len(target_dstm_ar)) ])
-
-                feed = [ warped_src, warped_dst]
-                feed += target_srcm_ar[::-1]
-                feed += target_dstm_ar[::-1]
-
-                self.src_dst_mask_train = K.function (feed,[src_mask_loss, dst_mask_loss], self.src_dst_mask_opt.get_updates(src_mask_loss+dst_mask_loss, src_dst_mask_loss_train_weights) )
+                src_mask_loss = K.mean(K.square(target_srcm-pred_src_srcm))
+                dst_mask_loss = K.mean(K.square(target_dstm-pred_dst_dstm))
+                self.src_dst_mask_train = K.function ([warped_src, warped_dst, target_srcm, target_dstm],[src_mask_loss, dst_mask_loss], self.src_dst_mask_opt.get_updates(src_mask_loss+dst_mask_loss, self.model.src_dst_mask_trainable_weights ) )

            if self.options['learn_mask']:
-                self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src[-1], pred_dst_dst[-1], pred_dst_dstm[-1], pred_src_dst[-1], pred_src_dstm[-1]])
+                self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm])
            else:
-                self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src[-1], pred_dst_dst[-1], pred_src_dst[-1] ] )
+                self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src, pred_dst_dst, pred_src_dst ])

        else:
            if self.options['learn_mask']:
-                self.AE_convert = K.function ([warped_dst],[ pred_src_dst[-1], pred_dst_dstm[-1], pred_src_dstm[-1] ])
+                self.AE_convert = K.function ([warped_dst],[ pred_src_dst, pred_dst_dstm, pred_src_dstm ])
            else:
-                self.AE_convert = K.function ([warped_dst],[ pred_src_dst[-1] ])
+                self.AE_convert = K.function ([warped_dst],[ pred_src_dst ])

        if self.is_training_mode:
-            self.src_sample_losses = []
-            self.dst_sample_losses = []
-
            t = SampleProcessor.Types
            face_type = t.FACE_TYPE_FULL if self.options['face_type'] == 'f' else t.FACE_TYPE_HALF

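Two things to note in the hunk above. First, every multiscale *_ar list (indexed by ms_count) collapses to a single full-resolution tensor, which is what actually removes the multiscale decoder. Second, weight loading and CA initialization now compose through the new load_weights_safe contract: only sub-models whose .h5 file was missing get 'Convolution Aware' initialization. A runnable mini-version of that flow (stand-in classes; keras and CAInitializerMP are replaced by a print to stay self-contained):

    class Sub:                                   # stand-in for a Keras sub-model
        def __init__(self, name): self.name = name

    def load_weights_safe(pairs, on_disk):
        loaded     = [p for p in pairs if p[1] in on_disk]
        not_loaded = [p for p in pairs if p[1] not in on_disk]
        return loaded, not_loaded

    pairs = [(Sub('enc'), 'encoder.h5'), (Sub('dec_src'), 'decoder_src.h5')]
    loaded, not_loaded = load_weights_safe(pairs, on_disk={'encoder.h5'})
    for model, fn in not_loaded:                 # CA init touches only fresh models
        print('CA-initializing', fn)             # -> decoder_src.h5 only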
@@ -372,46 +464,21 @@ class SAEModel(ModelBase):
                        random_ct_samples_path=training_data_dst_path if apply_random_ct else None,
                        debug=self.is_debug(), batch_size=self.batch_size,
                        sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
-                        output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), 'resolution':resolution, 'apply_ct': apply_random_ct} ] + \
-                                              [ {'types' : (t.IMG_TRANSFORMED, face_type, t_mode_bgr), 'resolution': resolution // (2**i), 'apply_ct': apply_random_ct } for i in range(ms_count)] + \
-                                              [ {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'resolution': resolution // (2**i) } for i in range(ms_count)]
+                        output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), 'resolution':resolution, 'apply_ct': apply_random_ct},
+                                                {'types' : (t.IMG_TRANSFORMED, face_type, t_mode_bgr), 'resolution': resolution, 'apply_ct': apply_random_ct },
+                                                {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'resolution': resolution } ]
                        ),

                SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
                        sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ),
-                        output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), 'resolution':resolution} ] + \
-                                              [ {'types' : (t.IMG_TRANSFORMED, face_type, t_mode_bgr), 'resolution': resolution // (2**i)} for i in range(ms_count)] + \
-                                              [ {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'resolution': resolution // (2**i) } for i in range(ms_count)])
+                        output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), 'resolution':resolution},
+                                                {'types' : (t.IMG_TRANSFORMED, face_type, t_mode_bgr), 'resolution': resolution},
+                                                {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'resolution': resolution} ])
                ])

    #override
    def get_model_filename_list(self):
-        ar = []
-        if 'liae' in self.options['archi']:
-            ar += [[self.encoder, 'encoder.h5'],
-                   [self.inter_B, 'inter_B.h5'],
-                   [self.decoder, 'decoder.h5']
-                  ]
-
-            if not self.pretrain or self.iter == 0:
-                ar += [ [self.inter_AB, 'inter_AB.h5'],
-                      ]
-
-            if self.options['learn_mask']:
-                ar += [ [self.decoderm, 'decoderm.h5'] ]
-
-        elif 'df' in self.options['archi']:
-            if not self.pretrain or self.iter == 0:
-                ar += [ [self.encoder, 'encoder.h5'],
-                      ]
-
-            ar += [ [self.decoder_src, 'decoder_src.h5'],
-                    [self.decoder_dst, 'decoder_dst.h5']
-                  ]
-
-            if self.options['learn_mask']:
-                ar += [ [self.decoder_srcm, 'decoder_srcm.h5'],
-                        [self.decoder_dstm, 'decoder_dstm.h5'] ]
+        ar = self.model.get_model_filename_list ( exclude_for_pretrain=(self.pretrain and self.iter != 0) )
        return ar

    #override
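With ms_count gone, each generator side yields exactly three arrays — warped input, full-resolution target, full-resolution target mask — so the sample indices used by onTrainOneIter and onGetPreview below become fixed constants. Illustrative layout (names are descriptive placeholders, not the generator API):

    # Fixed per-side layout after this commit:
    WARPED, TARGET, TARGET_MASK = 0, 1, 2        # was 1 + 2*ms_count entries before
    src_side = ['warped_src', 'target_src', 'target_srcm']    # placeholder arrays
    warped_src, target_src, target_srcm = src_side            # tuple-unpack, as in the diff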
@@ -420,30 +487,25 @@ class SAEModel(ModelBase):

    #override
    def onTrainOneIter(self, generators_samples, generators_list):
-        src_samples = generators_samples[0]
-        dst_samples = generators_samples[1]
+        warped_src, target_src, target_srcm = generators_samples[0]
+        warped_dst, target_dst, target_dstm = generators_samples[1]

-        feed = [src_samples[0], dst_samples[0] ] + \
-                src_samples[1:1+self.ms_count*2] + \
-                dst_samples[1:1+self.ms_count*2]
+        feed = [warped_src, warped_dst, target_src, target_srcm, target_dst, target_dstm]

        src_loss, dst_loss, = self.src_dst_train (feed)

        if self.options['learn_mask']:
-            feed = [ src_samples[0], dst_samples[0] ] + \
-                    src_samples[1+self.ms_count:1+self.ms_count*2] + \
-                    dst_samples[1+self.ms_count:1+self.ms_count*2]
+            feed = [ warped_src, warped_dst, target_srcm, target_dstm ]
            src_mask_loss, dst_mask_loss, = self.src_dst_mask_train (feed)

-        return ( ('src_loss', src_loss), ('dst_loss', dst_loss) )
+        return ( ('src_loss', src_loss), ('dst_loss', dst_loss), )


    #override
    def onGetPreview(self, sample):
        test_S = sample[0][1][0:4] #first 4 samples
-        test_S_m = sample[0][1+self.ms_count][0:4] #first 4 samples
+        test_S_m = sample[0][2][0:4] #first 4 samples
        test_D = sample[1][1][0:4]
-        test_D_m = sample[1][1+self.ms_count][0:4]
+        test_D_m = sample[1][2][0:4]

        if self.options['learn_mask']:
            S, D, SS, DD, DDM, SD, SDM = [ np.clip(x, 0.0, 1.0) for x in ([test_S,test_D] + self.AE_view ([test_S, test_D]) ) ]
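The feed list now pairs positionally, one-to-one, with the placeholder list passed to K.function in onInitialize ([warped_src, warped_dst, target_src, target_srcm, target_dst, target_dstm]); the two lists must stay in the same order. A small visualization of the correspondence (batch names are illustrative):

    inputs = ['warped_src', 'warped_dst', 'target_src',
              'target_srcm', 'target_dst', 'target_dstm']
    feed = ['WS', 'WD', 'TS', 'TSM', 'TD', 'TDM']   # illustrative batch stand-ins
    for name, batch in zip(inputs, feed):
        print(f'{name:12s} <- batch {batch}')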
@@ -453,15 +515,16 @@ class SAEModel(ModelBase):

        result = []
        st = []
-        for i in range(0, len(test_S)):
+        for i in range(len(test_S)):
            ar = S[i], SS[i], D[i], DD[i], SD[i]

            st.append ( np.concatenate ( ar, axis=1) )

        result += [ ('SAE', np.concatenate (st, axis=0 )), ]

        if self.options['learn_mask']:
            st_m = []
-            for i in range(0, len(test_S)):
+            for i in range(len(test_S)):
                ar = S[i]*test_S_m[i], SS[i], D[i]*test_D_m[i], DD[i]*DDM[i], SD[i]*(DDM[i]*SDM[i])
                st_m.append ( np.concatenate ( ar, axis=1) )

@ -491,207 +554,4 @@ class SAEModel(ModelBase):
|
||||||
clip_hborder_mask_per=0.0625 if (self.options['face_type'] == 'f') else 0,
|
clip_hborder_mask_per=0.0625 if (self.options['face_type'] == 'f') else 0,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
Model = SAEModel
|
||||||
def initialize_nn_functions():
|
|
||||||
exec (nnlib.import_all(), locals(), globals())
|
|
||||||
|
|
||||||
def NormPass(x):
|
|
||||||
return x
|
|
||||||
|
|
||||||
def Norm(norm=''):
|
|
||||||
if norm == 'bn':
|
|
||||||
return BatchNormalization(axis=-1)
|
|
||||||
else:
|
|
||||||
return NormPass
|
|
||||||
|
|
||||||
def Act(act='', lrelu_alpha=0.1):
|
|
||||||
if act == 'prelu':
|
|
||||||
return PReLU()
|
|
||||||
else:
|
|
||||||
return LeakyReLU(alpha=lrelu_alpha)
|
|
||||||
|
|
||||||
-        class ResidualBlock(object):
-            def __init__(self, filters, kernel_size=3, padding='zero', norm='', act='', **kwargs):
-                self.filters = filters
-                self.kernel_size = kernel_size
-                self.padding = padding
-                self.norm = norm
-                self.act = act
-
-            def __call__(self, inp):
-                x = inp
-                x = Conv2D(self.filters, kernel_size=self.kernel_size, padding=self.padding)(x)
-                x = Act(self.act, lrelu_alpha=0.2)(x)
-                x = Norm(self.norm)(x)
-                x = Conv2D(self.filters, kernel_size=self.kernel_size, padding=self.padding)(x)
-                x = Add()([x, inp])
-                x = Act(self.act, lrelu_alpha=0.2)(x)
-                x = Norm(self.norm)(x)
-                return x
-        SAEModel.ResidualBlock = ResidualBlock
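ResidualBlock is a standard two-convolution residual unit: the input is added back after the second convolution, so the block learns a correction F(x) rather than a full mapping. A minimal standalone sketch in plain Keras 2 (padding='same' stands in for nnlib's 'zero' padding; the input depth must equal filters for the Add to be valid):

from keras.layers import Input, Conv2D, Add, LeakyReLU
from keras.models import Model

def residual_block(filters, kernel_size=3):
    def func(inp):
        x = Conv2D(filters, kernel_size, padding='same')(inp)
        x = LeakyReLU(0.2)(x)
        x = Conv2D(filters, kernel_size, padding='same')(x)
        x = Add()([x, inp])       # skip connection: output = F(x) + x
        return LeakyReLU(0.2)(x)
    return func

inp = Input((16, 16, 64))
m = Model(inp, residual_block(64)(inp))   # 64 input channels match filters=64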
-        def downscale (dim, padding='zero', norm='', act='', **kwargs):
-            def func(x):
-                return Norm(norm)( Act(act) (Conv2D(dim, kernel_size=5, strides=2, padding=padding)(x)) )
-            return func
-        SAEModel.downscale = downscale
-
-        #def downscale (dim, padding='zero', norm='', act='', **kwargs):
-        #    def func(x):
-        #        return BlurPool()( Norm(norm)( Act(act) (Conv2D(dim, kernel_size=5, strides=1, padding=padding)(x)) ) )
-        #    return func
-        #SAEModel.downscale = downscale
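downscale halves both spatial dimensions with a stride-2 5x5 convolution; the commented-out variant would instead keep the convolution at stride 1 and delegate downsampling to BlurPool (anti-aliased pooling). A quick shape check of the stride-2 version, as a sketch in plain Keras 2:

from keras.layers import Input, Conv2D
from keras.models import Model

inp = Input((128, 128, 3))
out = Conv2D(42, kernel_size=5, strides=2, padding='same')(inp)
print(Model(inp, out).output_shape)   # (None, 64, 64, 42): both spatial dims halved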
-        def upscale (dim, padding='zero', norm='', act='', **kwargs):
-            def func(x):
-                return SubpixelUpscaler()(Norm(norm)(Act(act)(Conv2D(dim * 4, kernel_size=3, strides=1, padding=padding)(x))))
-            return func
-        SAEModel.upscale = upscale
-
-        def to_bgr (output_nc, padding='zero', **kwargs):
-            def func(x):
-                return Conv2D(output_nc, kernel_size=5, padding=padding, activation='sigmoid')(x)
-            return func
-        SAEModel.to_bgr = to_bgr
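upscale is the subpixel (pixel-shuffle) trick: a stride-1 convolution produces dim*4 channels, and SubpixelUpscaler rearranges each 4-channel group into a 2x2 spatial block, doubling height and width. A sketch of the same rearrangement with TensorFlow's depth_to_space, assuming TF 2.x eager mode and that SubpixelUpscaler performs depth-to-space with block size 2:

import numpy as np
import tensorflow as tf

x = np.random.rand(1, 8, 8, 64 * 4).astype(np.float32)  # conv output with dim*4 channels
y = tf.nn.depth_to_space(x, block_size=2)                # pixel shuffle
print(y.shape)                                           # (1, 16, 16, 64): H and W doubled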
-    @staticmethod
-    def LIAEEncFlow(resolution, ch_dims, **kwargs):
-        exec (nnlib.import_all(), locals(), globals())
-        upscale = partial(SAEModel.upscale, **kwargs)
-        downscale = partial(SAEModel.downscale, **kwargs)
-
-        def func(input):
-            dims = K.int_shape(input)[-1]*ch_dims
-
-            x = input
-            x = downscale(dims)(x)
-            x = downscale(dims*2)(x)
-            x = downscale(dims*4)(x)
-            x = downscale(dims*8)(x)
-
-            x = Flatten()(x)
-            return x
-        return func
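LIAEEncFlow applies four stride-2 downscales (so resolution/16) and flattens, meaning the encoder output is a single vector. Back-of-envelope shapes, assuming the illustrative values resolution=128, a 3-channel BGR input, and ch_dims=42:

resolution, in_ch, ch_dims = 128, 3, 42      # assumed values
dims = in_ch * ch_dims                       # 126
# four stride-2 convs: 128 -> 64 -> 32 -> 16 -> 8
feat = (resolution // 16, resolution // 16, dims * 8)
print(feat, feat[0] * feat[1] * feat[2])     # (8, 8, 1008) -> 64512-unit vector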
-    @staticmethod
-    def LIAEInterFlow(resolution, ae_dims=256, **kwargs):
-        exec (nnlib.import_all(), locals(), globals())
-        upscale = partial(SAEModel.upscale, **kwargs)
-        lowest_dense_res=resolution // 16
-
-        def func(input):
-            x = input[0]
-            x = Dense(ae_dims)(x)
-            x = Dense(lowest_dense_res * lowest_dense_res * ae_dims*2)(x)
-            x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims*2))(x)
-            x = upscale(ae_dims*2)(x)
-            return x
-        return func
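LIAEInterFlow is the dense bottleneck between the shared encoder and decoder: compress to ae_dims, re-expand to an 8x8 grid of ae_dims*2 channels, then upscale once to 16x16. A rough parameter count, assuming the illustrative 64512-unit encoder output from the sketch above:

ae_dims, lowest_dense_res, enc_out = 256, 128 // 16, 64512   # assumed values
p1 = (enc_out + 1) * ae_dims                                 # Dense(ae_dims)
p2 = (ae_dims + 1) * (lowest_dense_res**2 * ae_dims * 2)     # Dense(8*8*512)
print(p1 + p2)   # ~25M weights: the dense bottleneck dominates the parameter count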
-    @staticmethod
-    def LIAEDecFlow(output_nc,ch_dims, multiscale_count=1, add_residual_blocks=False, **kwargs):
-        exec (nnlib.import_all(), locals(), globals())
-        upscale = partial(SAEModel.upscale, **kwargs)
-        to_bgr = partial(SAEModel.to_bgr, **kwargs)
-        dims = output_nc * ch_dims
-        ResidualBlock = partial(SAEModel.ResidualBlock, **kwargs)
-
-        def func(input):
-            x = input[0]
-
-            outputs = []
-            x1 = upscale(dims*8)( x )
-
-            if add_residual_blocks:
-                x1 = ResidualBlock(dims*8)(x1)
-                x1 = ResidualBlock(dims*8)(x1)
-
-            if multiscale_count >= 3:
-                outputs += [ to_bgr(output_nc) ( x1 ) ]
-
-            x2 = upscale(dims*4)( x1 )
-
-            if add_residual_blocks:
-                x2 = ResidualBlock(dims*4)(x2)
-                x2 = ResidualBlock(dims*4)(x2)
-
-            if multiscale_count >= 2:
-                outputs += [ to_bgr(output_nc) ( x2 ) ]
-
-            x3 = upscale(dims*2)( x2 )
-
-            if add_residual_blocks:
-                x3 = ResidualBlock( dims*2)(x3)
-                x3 = ResidualBlock( dims*2)(x3)
-
-            outputs += [ to_bgr(output_nc) ( x3 ) ]
-
-            return outputs
-        return func
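LIAEDecFlow is the multiscale decoder this commit removes: with multiscale_count >= 2 the intermediate tensors x1 and x2 get their own to_bgr heads, so the decoder emits a pyramid of lower-resolution predictions alongside the full-size one; with the default count of 1 only the final head survives. A sketch of how such extra outputs are typically consumed, each penalized against a strided-downsampled target (illustrative only, not this repository's loss code):

import numpy as np

# Sum per-scale L2 terms between each decoder output and a downsampled target.
def pyramid_l2(outputs, target):
    total = 0.0
    for out in outputs:                        # e.g. 32, 64, 128 px outputs
        step = target.shape[1] // out.shape[1]
        total += np.mean((out - target[:, ::step, ::step, :]) ** 2)
    return total

outputs = [np.zeros((2, s, s, 3), np.float32) for s in (32, 64, 128)]
target  = np.ones((2, 128, 128, 3), np.float32)
print(pyramid_l2(outputs, target))             # 3.0: one unit of error per scale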
-    @staticmethod
-    def DFEncFlow(resolution, ae_dims, ch_dims, **kwargs):
-        exec (nnlib.import_all(), locals(), globals())
-        upscale = partial(SAEModel.upscale, **kwargs)
-        downscale = partial(SAEModel.downscale, **kwargs)#, kernel_regularizer=keras.regularizers.l2(0.0),
-        lowest_dense_res = resolution // 16
-
-        def func(input):
-            x = input
-
-            dims = K.int_shape(input)[-1]*ch_dims
-            x = downscale(dims)(x)
-            x = downscale(dims*2)(x)
-            x = downscale(dims*4)(x)
-            x = downscale(dims*8)(x)
-
-            x = Dense(ae_dims)(Flatten()(x))
-            x = Dense(lowest_dense_res * lowest_dense_res * ae_dims)(x)
-            x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims))(x)
-            x = upscale(ae_dims)(x)
-            return x
-        return func
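DFEncFlow differs from LIAEEncFlow in that the dense bottleneck and the first upscale live inside the encoder, so it already returns a spatial tensor and the DF variant can attach two fully independent decoders, one per identity. A shape sketch with assumed numbers:

resolution, ae_dims = 128, 512                 # assumed values, for illustration
lowest_dense_res = resolution // 16            # 8
# conv stack -> Dense(ae_dims) -> Dense(8*8*ae_dims) -> Reshape -> one upscale
print((lowest_dense_res * 2, lowest_dense_res * 2, ae_dims))   # (16, 16, 512)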
-    @staticmethod
-    def DFDecFlow(output_nc, ch_dims, multiscale_count=1, add_residual_blocks=False, **kwargs):
-        exec (nnlib.import_all(), locals(), globals())
-        upscale = partial(SAEModel.upscale, **kwargs)
-        to_bgr = partial(SAEModel.to_bgr, **kwargs)
-        dims = output_nc * ch_dims
-        ResidualBlock = partial(SAEModel.ResidualBlock, **kwargs)
-
-        def func(input):
-            x = input[0]
-
-            outputs = []
-            x1 = upscale(dims*8)( x )
-
-            if add_residual_blocks:
-                x1 = ResidualBlock( dims*8 )(x1)
-                x1 = ResidualBlock( dims*8 )(x1)
-
-            if multiscale_count >= 3:
-                outputs += [ to_bgr(output_nc) ( x1 ) ]
-
-            x2 = upscale(dims*4)( x1 )
-
-            if add_residual_blocks:
-                x2 = ResidualBlock( dims*4)(x2)
-                x2 = ResidualBlock( dims*4)(x2)
-
-            if multiscale_count >= 2:
-                outputs += [ to_bgr(output_nc) ( x2 ) ]
-
-            x3 = upscale(dims*2)( x2 )
-
-            if add_residual_blocks:
-                x3 = ResidualBlock( dims*2)(x3)
-                x3 = ResidualBlock( dims*2)(x3)
-
-            outputs += [ to_bgr(output_nc) ( x3 ) ]
-
-            return outputs
-        return func

 Model = SAEModel
@@ -56,11 +56,14 @@ Conv2DTranspose = nnlib.Conv2DTranspose
 EqualConv2D = nnlib.EqualConv2D
 SeparableConv2D = KL.SeparableConv2D
 MaxPooling2D = KL.MaxPooling2D
+AveragePooling2D = KL.AveragePooling2D
+GlobalAveragePooling2D = KL.GlobalAveragePooling2D
 UpSampling2D = KL.UpSampling2D
 BatchNormalization = KL.BatchNormalization
 PixelNormalization = nnlib.PixelNormalization

 LeakyReLU = KL.LeakyReLU
+ELU = KL.ELU
 ReLU = KL.ReLU
 PReLU = KL.PReLU
 tanh = KL.Activation('tanh')

@@ -70,6 +73,7 @@ Softmax = KL.Softmax

 Lambda = KL.Lambda
 Add = KL.Add
+Multiply = KL.Multiply
 Concatenate = KL.Concatenate