mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-22 06:23:20 -07:00
Formatting
This commit is contained in:
parent
fcbc8b125c
commit
ab043da2d9
1 changed files with 341 additions and 231 deletions
|
@ -7,9 +7,9 @@ from facelib import FaceType
|
||||||
from samplelib import *
|
from samplelib import *
|
||||||
from interact import interact as io
|
from interact import interact as io
|
||||||
|
|
||||||
#SAE - Styled AutoEncoder
|
|
||||||
class SAEModel(ModelBase):
|
|
||||||
|
|
||||||
|
# SAE - Styled AutoEncoder
|
||||||
|
class SAEModel(ModelBase):
|
||||||
encoderH5 = 'encoder.h5'
|
encoderH5 = 'encoder.h5'
|
||||||
inter_BH5 = 'inter_B.h5'
|
inter_BH5 = 'inter_B.h5'
|
||||||
inter_ABH5 = 'inter_AB.h5'
|
inter_ABH5 = 'inter_AB.h5'
|
||||||
|
@ -21,37 +21,44 @@ class SAEModel(ModelBase):
|
||||||
decoder_dstH5 = 'decoder_dst.h5'
|
decoder_dstH5 = 'decoder_dst.h5'
|
||||||
decoder_dstmH5 = 'decoder_dstm.h5'
|
decoder_dstmH5 = 'decoder_dstm.h5'
|
||||||
|
|
||||||
#override
|
# override
|
||||||
def onInitializeOptions(self, is_first_run, ask_override):
|
def onInitializeOptions(self, is_first_run, ask_override):
|
||||||
yn_str = {True:'y',False:'n'}
|
yn_str = {True: 'y', False: 'n'}
|
||||||
|
|
||||||
default_resolution = 128
|
default_resolution = 128
|
||||||
default_archi = 'df'
|
default_archi = 'df'
|
||||||
default_face_type = 'f'
|
default_face_type = 'f'
|
||||||
|
|
||||||
if is_first_run:
|
if is_first_run:
|
||||||
resolution = io.input_int("Resolution ( 64-256 ?:help skip:128) : ", default_resolution, help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 16.")
|
resolution = io.input_int("Resolution ( 64-256 ?:help skip:128) : ", default_resolution,
|
||||||
resolution = np.clip (resolution, 64, 256)
|
help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 16.")
|
||||||
|
resolution = np.clip(resolution, 64, 256)
|
||||||
while np.modf(resolution / 16)[0] != 0.0:
|
while np.modf(resolution / 16)[0] != 0.0:
|
||||||
resolution -= 1
|
resolution -= 1
|
||||||
self.options['resolution'] = resolution
|
self.options['resolution'] = resolution
|
||||||
|
|
||||||
self.options['face_type'] = io.input_str ("Half or Full face? (h/f, ?:help skip:f) : ", default_face_type, ['h','f'], help_message="Half face has better resolution, but covers less area of cheeks.").lower()
|
self.options['face_type'] = io.input_str("Half or Full face? (h/f, ?:help skip:f) : ", default_face_type,
|
||||||
self.options['learn_mask'] = io.input_bool ("Learn mask? (y/n, ?:help skip:y) : ", True, help_message="Learning mask can help model to recognize face directions. Learn without mask can reduce model size, in this case converter forced to use 'not predicted mask' that is not smooth as predicted. Model with style values can be learned without mask and produce same quality result.")
|
['h', 'f'],
|
||||||
|
help_message="Half face has better resolution, but covers less area of cheeks.").lower()
|
||||||
|
self.options['learn_mask'] = io.input_bool("Learn mask? (y/n, ?:help skip:y) : ", True,
|
||||||
|
help_message="Learning mask can help model to recognize face directions. Learn without mask can reduce model size, in this case converter forced to use 'not predicted mask' that is not smooth as predicted. Model with style values can be learned without mask and produce same quality result.")
|
||||||
else:
|
else:
|
||||||
self.options['resolution'] = self.options.get('resolution', default_resolution)
|
self.options['resolution'] = self.options.get('resolution', default_resolution)
|
||||||
self.options['face_type'] = self.options.get('face_type', default_face_type)
|
self.options['face_type'] = self.options.get('face_type', default_face_type)
|
||||||
self.options['learn_mask'] = self.options.get('learn_mask', True)
|
self.options['learn_mask'] = self.options.get('learn_mask', True)
|
||||||
|
|
||||||
|
|
||||||
if (is_first_run or ask_override) and 'tensorflow' in self.device_config.backend:
|
if (is_first_run or ask_override) and 'tensorflow' in self.device_config.backend:
|
||||||
def_optimizer_mode = self.options.get('optimizer_mode', 1)
|
def_optimizer_mode = self.options.get('optimizer_mode', 1)
|
||||||
self.options['optimizer_mode'] = io.input_int ("Optimizer mode? ( 1,2,3 ?:help skip:%d) : " % (def_optimizer_mode), def_optimizer_mode, help_message="1 - no changes. 2 - allows you to train x2 bigger network consuming RAM. 3 - allows you to train x3 bigger network consuming huge amount of RAM and slower, depends on CPU power.")
|
self.options['optimizer_mode'] = io.input_int(
|
||||||
|
"Optimizer mode? ( 1,2,3 ?:help skip:%d) : " % (def_optimizer_mode), def_optimizer_mode,
|
||||||
|
help_message="1 - no changes. 2 - allows you to train x2 bigger network consuming RAM. 3 - allows you to train x3 bigger network consuming huge amount of RAM and slower, depends on CPU power.")
|
||||||
else:
|
else:
|
||||||
self.options['optimizer_mode'] = self.options.get('optimizer_mode', 1)
|
self.options['optimizer_mode'] = self.options.get('optimizer_mode', 1)
|
||||||
|
|
||||||
if is_first_run:
|
if is_first_run:
|
||||||
self.options['archi'] = io.input_str ("AE architecture (df, liae ?:help skip:%s) : " % (default_archi) , default_archi, ['df','liae'], help_message="'df' keeps faces more natural. 'liae' can fix overly different face shapes.").lower() #-s version is slower, but has decreased change to collapse.
|
self.options['archi'] = io.input_str("AE architecture (df, liae ?:help skip:%s) : " % (default_archi),
|
||||||
|
default_archi, ['df', 'liae'],
|
||||||
|
help_message="'df' keeps faces more natural. 'liae' can fix overly different face shapes.").lower() # -s version is slower, but has decreased change to collapse.
|
||||||
else:
|
else:
|
||||||
self.options['archi'] = self.options.get('archi', default_archi)
|
self.options['archi'] = self.options.get('archi', default_archi)
|
||||||
|
|
||||||
|
@ -61,12 +68,26 @@ class SAEModel(ModelBase):
|
||||||
def_ca_weights = False
|
def_ca_weights = False
|
||||||
|
|
||||||
if is_first_run:
|
if is_first_run:
|
||||||
self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dims (32-1024 ?:help skip:%d) : " % (default_ae_dims) , default_ae_dims, help_message="All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 )
|
self.options['ae_dims'] = np.clip(
|
||||||
self.options['e_ch_dims'] = np.clip ( io.input_int("Encoder dims per channel (21-85 ?:help skip:%d) : " % (default_e_ch_dims) , default_e_ch_dims, help_message="More encoder dims help to recognize more facial features, but require more VRAM. You can fine-tune model size to fit your GPU." ), 21, 85 )
|
io.input_int("AutoEncoder dims (32-1024 ?:help skip:%d) : " % (default_ae_dims), default_ae_dims,
|
||||||
|
help_message="All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU."),
|
||||||
|
32, 1024)
|
||||||
|
self.options['e_ch_dims'] = np.clip(
|
||||||
|
io.input_int("Encoder dims per channel (21-85 ?:help skip:%d) : " % (default_e_ch_dims),
|
||||||
|
default_e_ch_dims,
|
||||||
|
help_message="More encoder dims help to recognize more facial features, but require more VRAM. You can fine-tune model size to fit your GPU."),
|
||||||
|
21, 85)
|
||||||
default_d_ch_dims = self.options['e_ch_dims'] // 2
|
default_d_ch_dims = self.options['e_ch_dims'] // 2
|
||||||
self.options['d_ch_dims'] = np.clip ( io.input_int("Decoder dims per channel (10-85 ?:help skip:%d) : " % (default_d_ch_dims) , default_d_ch_dims, help_message="More decoder dims help to get better details, but require more VRAM. You can fine-tune model size to fit your GPU." ), 10, 85 )
|
self.options['d_ch_dims'] = np.clip(
|
||||||
self.options['multiscale_decoder'] = io.input_bool ("Use multiscale decoder? (y/n, ?:help skip:n) : ", False, help_message="Multiscale decoder helps to get better details.")
|
io.input_int("Decoder dims per channel (10-85 ?:help skip:%d) : " % (default_d_ch_dims),
|
||||||
self.options['ca_weights'] = io.input_bool ("Use CA weights? (y/n, ?:help skip: %s ) : " % (yn_str[def_ca_weights]), def_ca_weights, help_message="Initialize network with 'Convolution Aware' weights. This may help to achieve a higher accuracy model, but consumes a time at first run.")
|
default_d_ch_dims,
|
||||||
|
help_message="More decoder dims help to get better details, but require more VRAM. You can fine-tune model size to fit your GPU."),
|
||||||
|
10, 85)
|
||||||
|
self.options['multiscale_decoder'] = io.input_bool("Use multiscale decoder? (y/n, ?:help skip:n) : ", False,
|
||||||
|
help_message="Multiscale decoder helps to get better details.")
|
||||||
|
self.options['ca_weights'] = io.input_bool(
|
||||||
|
"Use CA weights? (y/n, ?:help skip: %s ) : " % (yn_str[def_ca_weights]), def_ca_weights,
|
||||||
|
help_message="Initialize network with 'Convolution Aware' weights. This may help to achieve a higher accuracy model, but consumes a time at first run.")
|
||||||
else:
|
else:
|
||||||
self.options['ae_dims'] = self.options.get('ae_dims', default_ae_dims)
|
self.options['ae_dims'] = self.options.get('ae_dims', default_ae_dims)
|
||||||
self.options['e_ch_dims'] = self.options.get('e_ch_dims', default_e_ch_dims)
|
self.options['e_ch_dims'] = self.options.get('e_ch_dims', default_e_ch_dims)
|
||||||
|
@ -78,42 +99,58 @@ class SAEModel(ModelBase):
|
||||||
default_bg_style_power = 0.0
|
default_bg_style_power = 0.0
|
||||||
if is_first_run or ask_override:
|
if is_first_run or ask_override:
|
||||||
def_pixel_loss = self.options.get('pixel_loss', False)
|
def_pixel_loss = self.options.get('pixel_loss', False)
|
||||||
self.options['pixel_loss'] = io.input_bool ("Use pixel loss? (y/n, ?:help skip: %s ) : " % (yn_str[def_pixel_loss]), def_pixel_loss, help_message="Pixel loss may help to enhance fine details and stabilize face color. Use it only if quality does not improve over time. Enabling this option too early increases the chance of model collapse.")
|
self.options['pixel_loss'] = io.input_bool(
|
||||||
|
"Use pixel loss? (y/n, ?:help skip: %s ) : " % (yn_str[def_pixel_loss]), def_pixel_loss,
|
||||||
|
help_message="Pixel loss may help to enhance fine details and stabilize face color. Use it only if quality does not improve over time. Enabling this option too early increases the chance of model collapse.")
|
||||||
|
|
||||||
default_face_style_power = default_face_style_power if is_first_run else self.options.get('face_style_power', default_face_style_power)
|
default_face_style_power = default_face_style_power if is_first_run else self.options.get(
|
||||||
self.options['face_style_power'] = np.clip ( io.input_number("Face style power ( 0.0 .. 100.0 ?:help skip:%.2f) : " % (default_face_style_power), default_face_style_power,
|
'face_style_power', default_face_style_power)
|
||||||
help_message="Learn to transfer face style details such as light and color conditions. Warning: Enable it only after 10k iters, when predicted face is clear enough to start learn style. Start from 0.1 value and check history changes. Enabling this option increases the chance of model collapse."), 0.0, 100.0 )
|
self.options['face_style_power'] = np.clip(
|
||||||
|
io.input_number("Face style power ( 0.0 .. 100.0 ?:help skip:%.2f) : " % (default_face_style_power),
|
||||||
|
default_face_style_power,
|
||||||
|
help_message="Learn to transfer face style details such as light and color conditions. Warning: Enable it only after 10k iters, when predicted face is clear enough to start learn style. Start from 0.1 value and check history changes. Enabling this option increases the chance of model collapse."),
|
||||||
|
0.0, 100.0)
|
||||||
|
|
||||||
default_bg_style_power = default_bg_style_power if is_first_run else self.options.get('bg_style_power', default_bg_style_power)
|
default_bg_style_power = default_bg_style_power if is_first_run else self.options.get('bg_style_power',
|
||||||
self.options['bg_style_power'] = np.clip ( io.input_number("Background style power ( 0.0 .. 100.0 ?:help skip:%.2f) : " % (default_bg_style_power), default_bg_style_power,
|
default_bg_style_power)
|
||||||
help_message="Learn to transfer image around face. This can make face more like dst. Enabling this option increases the chance of model collapse."), 0.0, 100.0 )
|
self.options['bg_style_power'] = np.clip(
|
||||||
|
io.input_number("Background style power ( 0.0 .. 100.0 ?:help skip:%.2f) : " % (default_bg_style_power),
|
||||||
|
default_bg_style_power,
|
||||||
|
help_message="Learn to transfer image around face. This can make face more like dst. Enabling this option increases the chance of model collapse."),
|
||||||
|
0.0, 100.0)
|
||||||
|
|
||||||
default_apply_random_ct = False if is_first_run else self.options.get('apply_random_ct', False)
|
default_apply_random_ct = False if is_first_run else self.options.get('apply_random_ct', False)
|
||||||
self.options['apply_random_ct'] = io.input_bool ("Apply random color transfer to src faceset? (y/n, ?:help skip:%s) : " % (yn_str[default_apply_random_ct]), default_apply_random_ct, help_message="Increase variativity of src samples by apply LCT color transfer from random dst samples. It is like 'face_style' learning, but more precise color transfer and without risk of model collapse, also it does not require additional GPU resources, but the training time may be longer, due to the src faceset is becoming more diverse.")
|
self.options['apply_random_ct'] = io.input_bool(
|
||||||
|
"Apply random color transfer to src faceset? (y/n, ?:help skip:%s) : " % (
|
||||||
if nnlib.device.backend != 'plaidML': # todo https://github.com/plaidml/plaidml/issues/301
|
yn_str[default_apply_random_ct]), default_apply_random_ct,
|
||||||
|
help_message="Increase variativity of src samples by apply LCT color transfer from random dst samples. It is like 'face_style' learning, but more precise color transfer and without risk of model collapse, also it does not require additional GPU resources, but the training time may be longer, due to the src faceset is becoming more diverse.")
|
||||||
|
|
||||||
|
if nnlib.device.backend != 'plaidML': # todo https://github.com/plaidml/plaidml/issues/301
|
||||||
default_clipgrad = False if is_first_run else self.options.get('clipgrad', False)
|
default_clipgrad = False if is_first_run else self.options.get('clipgrad', False)
|
||||||
self.options['clipgrad'] = io.input_bool ("Enable gradient clipping? (y/n, ?:help skip:%s) : " % (yn_str[default_clipgrad]), default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.")
|
self.options['clipgrad'] = io.input_bool(
|
||||||
|
"Enable gradient clipping? (y/n, ?:help skip:%s) : " % (yn_str[default_clipgrad]), default_clipgrad,
|
||||||
|
help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.")
|
||||||
else:
|
else:
|
||||||
self.options['clipgrad'] = False
|
self.options['clipgrad'] = False
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.options['pixel_loss'] = self.options.get('pixel_loss', False)
|
self.options['pixel_loss'] = self.options.get('pixel_loss', False)
|
||||||
self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power)
|
self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power)
|
||||||
self.options['bg_style_power'] = self.options.get('bg_style_power', default_bg_style_power)
|
self.options['bg_style_power'] = self.options.get('bg_style_power', default_bg_style_power)
|
||||||
self.options['apply_random_ct'] = self.options.get('apply_random_ct', False)
|
self.options['apply_random_ct'] = self.options.get('apply_random_ct', False)
|
||||||
self.options['clipgrad'] = self.options.get('clipgrad', False)
|
self.options['clipgrad'] = self.options.get('clipgrad', False)
|
||||||
|
|
||||||
if is_first_run:
|
if is_first_run:
|
||||||
self.options['pretrain'] = io.input_bool ("Pretrain the model? (y/n, ?:help skip:n) : ", False, help_message="Pretrain the model with large amount of various faces. This technique may help to train the fake with overly different face shapes and light conditions of src/dst data. Face will be look more like a morphed. To reduce the morph effect, some model files will be initialized but not be updated after pretrain: LIAE: inter_AB.h5 DF: encoder.h5. The longer you pretrain the model the more morphed face will look. After that, save and run the training again.")
|
self.options['pretrain'] = io.input_bool("Pretrain the model? (y/n, ?:help skip:n) : ", False,
|
||||||
|
help_message="Pretrain the model with large amount of various faces. This technique may help to train the fake with overly different face shapes and light conditions of src/dst data. Face will be look more like a morphed. To reduce the morph effect, some model files will be initialized but not be updated after pretrain: LIAE: inter_AB.h5 DF: encoder.h5. The longer you pretrain the model the more morphed face will look. After that, save and run the training again.")
|
||||||
else:
|
else:
|
||||||
self.options['pretrain'] = False
|
self.options['pretrain'] = False
|
||||||
|
|
||||||
#override
|
# override
|
||||||
def onInitialize(self):
|
def onInitialize(self):
|
||||||
exec(nnlib.import_all(), locals(), globals())
|
exec(nnlib.import_all(), locals(), globals())
|
||||||
SAEModel.initialize_nn_functions()
|
SAEModel.initialize_nn_functions()
|
||||||
self.set_vram_batch_requirements({1.5:4})
|
self.set_vram_batch_requirements({1.5: 4})
|
||||||
|
|
||||||
resolution = self.options['resolution']
|
resolution = self.options['resolution']
|
||||||
ae_dims = self.options['ae_dims']
|
ae_dims = self.options['ae_dims']
|
||||||
|
@ -140,52 +177,60 @@ class SAEModel(ModelBase):
|
||||||
target_dst = Input(bgr_shape)
|
target_dst = Input(bgr_shape)
|
||||||
target_dstm = Input(mask_shape)
|
target_dstm = Input(mask_shape)
|
||||||
|
|
||||||
target_src_ar = [ Input ( ( bgr_shape[0] // (2**i) ,)*2 + (bgr_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
|
target_src_ar = [Input((bgr_shape[0] // (2 ** i),) * 2 + (bgr_shape[-1],)) for i in range(ms_count - 1, -1, -1)]
|
||||||
target_srcm_ar = [ Input ( ( mask_shape[0] // (2**i) ,)*2 + (mask_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
|
target_srcm_ar = [Input((mask_shape[0] // (2 ** i),) * 2 + (mask_shape[-1],)) for i in
|
||||||
target_dst_ar = [ Input ( ( bgr_shape[0] // (2**i) ,)*2 + (bgr_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
|
range(ms_count - 1, -1, -1)]
|
||||||
target_dstm_ar = [ Input ( ( mask_shape[0] // (2**i) ,)*2 + (mask_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
|
target_dst_ar = [Input((bgr_shape[0] // (2 ** i),) * 2 + (bgr_shape[-1],)) for i in range(ms_count - 1, -1, -1)]
|
||||||
|
target_dstm_ar = [Input((mask_shape[0] // (2 ** i),) * 2 + (mask_shape[-1],)) for i in
|
||||||
|
range(ms_count - 1, -1, -1)]
|
||||||
|
|
||||||
common_flow_kwargs = { 'padding': 'zero',
|
common_flow_kwargs = {'padding': 'zero',
|
||||||
'norm': '',
|
'norm': '',
|
||||||
'act':'' }
|
'act': ''}
|
||||||
models_list = []
|
models_list = []
|
||||||
weights_to_load = []
|
weights_to_load = []
|
||||||
if 'liae' in self.options['archi']:
|
if 'liae' in self.options['archi']:
|
||||||
self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape))
|
self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, ch_dims=e_ch_dims, **common_flow_kwargs))(
|
||||||
|
Input(bgr_shape))
|
||||||
|
|
||||||
enc_output_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
|
enc_output_Inputs = [Input(K.int_shape(x)[1:]) for x in self.encoder.outputs]
|
||||||
|
|
||||||
self.inter_B = modelify(SAEModel.LIAEInterFlow(resolution, ae_dims=ae_dims, **common_flow_kwargs)) (enc_output_Inputs)
|
self.inter_B = modelify(SAEModel.LIAEInterFlow(resolution, ae_dims=ae_dims, **common_flow_kwargs))(
|
||||||
self.inter_AB = modelify(SAEModel.LIAEInterFlow(resolution, ae_dims=ae_dims, **common_flow_kwargs)) (enc_output_Inputs)
|
enc_output_Inputs)
|
||||||
|
self.inter_AB = modelify(SAEModel.LIAEInterFlow(resolution, ae_dims=ae_dims, **common_flow_kwargs))(
|
||||||
|
enc_output_Inputs)
|
||||||
|
|
||||||
inter_output_Inputs = [ Input( np.array(K.int_shape(x)[1:])*(1,1,2) ) for x in self.inter_B.outputs ]
|
inter_output_Inputs = [Input(np.array(K.int_shape(x)[1:]) * (1, 1, 2)) for x in self.inter_B.outputs]
|
||||||
|
|
||||||
self.decoder = modelify(SAEModel.LIAEDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs)) (inter_output_Inputs)
|
self.decoder = modelify(
|
||||||
|
SAEModel.LIAEDecFlow(bgr_shape[2], ch_dims=d_ch_dims, multiscale_count=self.ms_count,
|
||||||
|
add_residual_blocks=d_residual_blocks, **common_flow_kwargs))(inter_output_Inputs)
|
||||||
models_list += [self.encoder, self.inter_B, self.inter_AB, self.decoder]
|
models_list += [self.encoder, self.inter_B, self.inter_AB, self.decoder]
|
||||||
|
|
||||||
if self.options['learn_mask']:
|
if self.options['learn_mask']:
|
||||||
self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs)) (inter_output_Inputs)
|
self.decoderm = modelify(SAEModel.LIAEDecFlow(mask_shape[2], ch_dims=d_ch_dims, **common_flow_kwargs))(
|
||||||
|
inter_output_Inputs)
|
||||||
models_list += [self.decoderm]
|
models_list += [self.decoderm]
|
||||||
|
|
||||||
if not self.is_first_run():
|
if not self.is_first_run():
|
||||||
weights_to_load += [ [self.encoder , 'encoder.h5'],
|
weights_to_load += [[self.encoder, 'encoder.h5'],
|
||||||
[self.inter_B , 'inter_B.h5'],
|
[self.inter_B, 'inter_B.h5'],
|
||||||
[self.inter_AB, 'inter_AB.h5'],
|
[self.inter_AB, 'inter_AB.h5'],
|
||||||
[self.decoder , 'decoder.h5'],
|
[self.decoder, 'decoder.h5'],
|
||||||
]
|
]
|
||||||
if self.options['learn_mask']:
|
if self.options['learn_mask']:
|
||||||
weights_to_load += [ [self.decoderm, 'decoderm.h5'] ]
|
weights_to_load += [[self.decoderm, 'decoderm.h5']]
|
||||||
|
|
||||||
warped_src_code = self.encoder (warped_src)
|
warped_src_code = self.encoder(warped_src)
|
||||||
warped_src_inter_AB_code = self.inter_AB (warped_src_code)
|
warped_src_inter_AB_code = self.inter_AB(warped_src_code)
|
||||||
warped_src_inter_code = Concatenate()([warped_src_inter_AB_code,warped_src_inter_AB_code])
|
warped_src_inter_code = Concatenate()([warped_src_inter_AB_code, warped_src_inter_AB_code])
|
||||||
|
|
||||||
warped_dst_code = self.encoder (warped_dst)
|
warped_dst_code = self.encoder(warped_dst)
|
||||||
warped_dst_inter_B_code = self.inter_B (warped_dst_code)
|
warped_dst_inter_B_code = self.inter_B(warped_dst_code)
|
||||||
warped_dst_inter_AB_code = self.inter_AB (warped_dst_code)
|
warped_dst_inter_AB_code = self.inter_AB(warped_dst_code)
|
||||||
warped_dst_inter_code = Concatenate()([warped_dst_inter_B_code,warped_dst_inter_AB_code])
|
warped_dst_inter_code = Concatenate()([warped_dst_inter_B_code, warped_dst_inter_AB_code])
|
||||||
|
|
||||||
warped_src_dst_inter_code = Concatenate()([warped_dst_inter_AB_code,warped_dst_inter_AB_code])
|
warped_src_dst_inter_code = Concatenate()([warped_dst_inter_AB_code, warped_dst_inter_AB_code])
|
||||||
|
|
||||||
pred_src_src = self.decoder(warped_src_inter_code)
|
pred_src_src = self.decoder(warped_src_inter_code)
|
||||||
pred_dst_dst = self.decoder(warped_dst_inter_code)
|
pred_dst_dst = self.decoder(warped_dst_inter_code)
|
||||||
|
@ -197,31 +242,39 @@ class SAEModel(ModelBase):
|
||||||
pred_src_dstm = self.decoderm(warped_src_dst_inter_code)
|
pred_src_dstm = self.decoderm(warped_src_dst_inter_code)
|
||||||
|
|
||||||
elif 'df' in self.options['archi']:
|
elif 'df' in self.options['archi']:
|
||||||
self.encoder = modelify(SAEModel.DFEncFlow(resolution, ae_dims=ae_dims, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape))
|
self.encoder = modelify(
|
||||||
|
SAEModel.DFEncFlow(resolution, ae_dims=ae_dims, ch_dims=e_ch_dims, **common_flow_kwargs))(
|
||||||
|
Input(bgr_shape))
|
||||||
|
|
||||||
dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
|
dec_Inputs = [Input(K.int_shape(x)[1:]) for x in self.encoder.outputs]
|
||||||
|
|
||||||
self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs )) (dec_Inputs)
|
self.decoder_src = modelify(
|
||||||
self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs )) (dec_Inputs)
|
SAEModel.DFDecFlow(bgr_shape[2], ch_dims=d_ch_dims, multiscale_count=self.ms_count,
|
||||||
|
add_residual_blocks=d_residual_blocks, **common_flow_kwargs))(dec_Inputs)
|
||||||
|
self.decoder_dst = modelify(
|
||||||
|
SAEModel.DFDecFlow(bgr_shape[2], ch_dims=d_ch_dims, multiscale_count=self.ms_count,
|
||||||
|
add_residual_blocks=d_residual_blocks, **common_flow_kwargs))(dec_Inputs)
|
||||||
models_list += [self.encoder, self.decoder_src, self.decoder_dst]
|
models_list += [self.encoder, self.decoder_src, self.decoder_dst]
|
||||||
|
|
||||||
if self.options['learn_mask']:
|
if self.options['learn_mask']:
|
||||||
self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs )) (dec_Inputs)
|
self.decoder_srcm = modelify(
|
||||||
self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs )) (dec_Inputs)
|
SAEModel.DFDecFlow(mask_shape[2], ch_dims=d_ch_dims, **common_flow_kwargs))(dec_Inputs)
|
||||||
|
self.decoder_dstm = modelify(
|
||||||
|
SAEModel.DFDecFlow(mask_shape[2], ch_dims=d_ch_dims, **common_flow_kwargs))(dec_Inputs)
|
||||||
models_list += [self.decoder_srcm, self.decoder_dstm]
|
models_list += [self.decoder_srcm, self.decoder_dstm]
|
||||||
|
|
||||||
if not self.is_first_run():
|
if not self.is_first_run():
|
||||||
weights_to_load += [ [self.encoder , 'encoder.h5'],
|
weights_to_load += [[self.encoder, 'encoder.h5'],
|
||||||
[self.decoder_src, 'decoder_src.h5'],
|
[self.decoder_src, 'decoder_src.h5'],
|
||||||
[self.decoder_dst, 'decoder_dst.h5']
|
[self.decoder_dst, 'decoder_dst.h5']
|
||||||
]
|
]
|
||||||
if self.options['learn_mask']:
|
if self.options['learn_mask']:
|
||||||
weights_to_load += [ [self.decoder_srcm, 'decoder_srcm.h5'],
|
weights_to_load += [[self.decoder_srcm, 'decoder_srcm.h5'],
|
||||||
[self.decoder_dstm, 'decoder_dstm.h5'],
|
[self.decoder_dstm, 'decoder_dstm.h5'],
|
||||||
]
|
]
|
||||||
|
|
||||||
warped_src_code = self.encoder (warped_src)
|
warped_src_code = self.encoder(warped_src)
|
||||||
warped_dst_code = self.encoder (warped_dst)
|
warped_dst_code = self.encoder(warped_dst)
|
||||||
pred_src_src = self.decoder_src(warped_src_code)
|
pred_src_src = self.decoder_src(warped_src_code)
|
||||||
pred_dst_dst = self.decoder_dst(warped_dst_code)
|
pred_dst_dst = self.decoder_dst(warped_dst_code)
|
||||||
pred_src_dst = self.decoder_src(warped_dst_code)
|
pred_src_dst = self.decoder_src(warped_dst_code)
|
||||||
|
@ -232,42 +285,47 @@ class SAEModel(ModelBase):
|
||||||
pred_src_dstm = self.decoder_srcm(warped_dst_code)
|
pred_src_dstm = self.decoder_srcm(warped_dst_code)
|
||||||
|
|
||||||
if self.is_first_run():
|
if self.is_first_run():
|
||||||
if self.options.get('ca_weights',False):
|
if self.options.get('ca_weights', False):
|
||||||
conv_weights_list = []
|
conv_weights_list = []
|
||||||
for model in models_list:
|
for model in models_list:
|
||||||
for layer in model.layers:
|
for layer in model.layers:
|
||||||
if type(layer) == keras.layers.Conv2D:
|
if type(layer) == keras.layers.Conv2D:
|
||||||
conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights
|
conv_weights_list += [layer.weights[0]] # Conv2D kernel_weights
|
||||||
CAInitializerMP ( conv_weights_list )
|
CAInitializerMP(conv_weights_list)
|
||||||
else:
|
else:
|
||||||
self.load_weights_safe(weights_to_load)
|
self.load_weights_safe(weights_to_load)
|
||||||
|
|
||||||
pred_src_src, pred_dst_dst, pred_src_dst, = [ [x] if type(x) != list else x for x in [pred_src_src, pred_dst_dst, pred_src_dst, ] ]
|
pred_src_src, pred_dst_dst, pred_src_dst, = [[x] if type(x) != list else x for x in
|
||||||
|
[pred_src_src, pred_dst_dst, pred_src_dst, ]]
|
||||||
|
|
||||||
if self.options['learn_mask']:
|
if self.options['learn_mask']:
|
||||||
pred_src_srcm, pred_dst_dstm, pred_src_dstm = [ [x] if type(x) != list else x for x in [pred_src_srcm, pred_dst_dstm, pred_src_dstm] ]
|
pred_src_srcm, pred_dst_dstm, pred_src_dstm = [[x] if type(x) != list else x for x in
|
||||||
|
[pred_src_srcm, pred_dst_dstm, pred_src_dstm]]
|
||||||
|
|
||||||
target_srcm_blurred_ar = [ gaussian_blur( max(1, K.int_shape(x)[1] // 32) )(x) for x in target_srcm_ar]
|
target_srcm_blurred_ar = [gaussian_blur(max(1, K.int_shape(x)[1] // 32))(x) for x in target_srcm_ar]
|
||||||
target_srcm_sigm_ar = target_srcm_blurred_ar #[ x / 2.0 + 0.5 for x in target_srcm_blurred_ar]
|
target_srcm_sigm_ar = target_srcm_blurred_ar # [ x / 2.0 + 0.5 for x in target_srcm_blurred_ar]
|
||||||
target_srcm_anti_sigm_ar = [ 1.0 - x for x in target_srcm_sigm_ar]
|
target_srcm_anti_sigm_ar = [1.0 - x for x in target_srcm_sigm_ar]
|
||||||
|
|
||||||
target_dstm_blurred_ar = [ gaussian_blur( max(1, K.int_shape(x)[1] // 32) )(x) for x in target_dstm_ar]
|
target_dstm_blurred_ar = [gaussian_blur(max(1, K.int_shape(x)[1] // 32))(x) for x in target_dstm_ar]
|
||||||
target_dstm_sigm_ar = target_dstm_blurred_ar#[ x / 2.0 + 0.5 for x in target_dstm_blurred_ar]
|
target_dstm_sigm_ar = target_dstm_blurred_ar # [ x / 2.0 + 0.5 for x in target_dstm_blurred_ar]
|
||||||
target_dstm_anti_sigm_ar = [ 1.0 - x for x in target_dstm_sigm_ar]
|
target_dstm_anti_sigm_ar = [1.0 - x for x in target_dstm_sigm_ar]
|
||||||
|
|
||||||
target_src_sigm_ar = target_src_ar#[ x + 1 for x in target_src_ar]
|
target_src_sigm_ar = target_src_ar # [ x + 1 for x in target_src_ar]
|
||||||
target_dst_sigm_ar = target_dst_ar#[ x + 1 for x in target_dst_ar]
|
target_dst_sigm_ar = target_dst_ar # [ x + 1 for x in target_dst_ar]
|
||||||
|
|
||||||
pred_src_src_sigm_ar = pred_src_src#[ x + 1 for x in pred_src_src]
|
pred_src_src_sigm_ar = pred_src_src # [ x + 1 for x in pred_src_src]
|
||||||
pred_dst_dst_sigm_ar = pred_dst_dst#[ x + 1 for x in pred_dst_dst]
|
pred_dst_dst_sigm_ar = pred_dst_dst # [ x + 1 for x in pred_dst_dst]
|
||||||
pred_src_dst_sigm_ar = pred_src_dst#[ x + 1 for x in pred_src_dst]
|
pred_src_dst_sigm_ar = pred_src_dst # [ x + 1 for x in pred_src_dst]
|
||||||
|
|
||||||
target_src_masked_ar = [ target_src_sigm_ar[i]*target_srcm_sigm_ar[i] for i in range(len(target_src_sigm_ar))]
|
target_src_masked_ar = [target_src_sigm_ar[i] * target_srcm_sigm_ar[i] for i in range(len(target_src_sigm_ar))]
|
||||||
target_dst_masked_ar = [ target_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(target_dst_sigm_ar))]
|
target_dst_masked_ar = [target_dst_sigm_ar[i] * target_dstm_sigm_ar[i] for i in range(len(target_dst_sigm_ar))]
|
||||||
target_dst_anti_masked_ar = [ target_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(target_dst_sigm_ar))]
|
target_dst_anti_masked_ar = [target_dst_sigm_ar[i] * target_dstm_anti_sigm_ar[i] for i in
|
||||||
|
range(len(target_dst_sigm_ar))]
|
||||||
|
|
||||||
pred_src_src_masked_ar = [ pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] for i in range(len(pred_src_src_sigm_ar))]
|
pred_src_src_masked_ar = [pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] for i in
|
||||||
pred_dst_dst_masked_ar = [ pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] for i in range(len(pred_dst_dst_sigm_ar))]
|
range(len(pred_src_src_sigm_ar))]
|
||||||
|
pred_dst_dst_masked_ar = [pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] for i in
|
||||||
|
range(len(pred_dst_dst_sigm_ar))]
|
||||||
|
|
||||||
target_src_masked_ar_opt = target_src_masked_ar if masked_training else target_src_sigm_ar
|
target_src_masked_ar_opt = target_src_masked_ar if masked_training else target_src_sigm_ar
|
||||||
target_dst_masked_ar_opt = target_dst_masked_ar if masked_training else target_dst_sigm_ar
|
target_dst_masked_ar_opt = target_dst_masked_ar if masked_training else target_dst_sigm_ar
|
||||||
|
@ -275,12 +333,18 @@ class SAEModel(ModelBase):
|
||||||
pred_src_src_masked_ar_opt = pred_src_src_masked_ar if masked_training else pred_src_src_sigm_ar
|
pred_src_src_masked_ar_opt = pred_src_src_masked_ar if masked_training else pred_src_src_sigm_ar
|
||||||
pred_dst_dst_masked_ar_opt = pred_dst_dst_masked_ar if masked_training else pred_dst_dst_sigm_ar
|
pred_dst_dst_masked_ar_opt = pred_dst_dst_masked_ar if masked_training else pred_dst_dst_sigm_ar
|
||||||
|
|
||||||
psd_target_dst_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
|
psd_target_dst_masked_ar = [pred_src_dst_sigm_ar[i] * target_dstm_sigm_ar[i] for i in
|
||||||
psd_target_dst_anti_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
|
range(len(pred_src_dst_sigm_ar))]
|
||||||
|
psd_target_dst_anti_masked_ar = [pred_src_dst_sigm_ar[i] * target_dstm_anti_sigm_ar[i] for i in
|
||||||
|
range(len(pred_src_dst_sigm_ar))]
|
||||||
|
|
||||||
if self.is_training_mode:
|
if self.is_training_mode:
|
||||||
self.src_dst_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
|
self.src_dst_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999,
|
||||||
self.src_dst_mask_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
|
clipnorm=1.0 if self.options['clipgrad'] else 0.0,
|
||||||
|
tf_cpu_mode=self.options['optimizer_mode'] - 1)
|
||||||
|
self.src_dst_mask_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999,
|
||||||
|
clipnorm=1.0 if self.options['clipgrad'] else 0.0,
|
||||||
|
tf_cpu_mode=self.options['optimizer_mode'] - 1)
|
||||||
|
|
||||||
if 'liae' in self.options['archi']:
|
if 'liae' in self.options['archi']:
|
||||||
src_dst_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
|
src_dst_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
|
||||||
|
@ -292,29 +356,40 @@ class SAEModel(ModelBase):
|
||||||
src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights + self.decoder_dstm.trainable_weights
|
src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights + self.decoder_dstm.trainable_weights
|
||||||
|
|
||||||
if not self.options['pixel_loss']:
|
if not self.options['pixel_loss']:
|
||||||
src_loss_batch = sum([ 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_src_masked_ar_opt[i], pred_src_src_masked_ar_opt[i]) for i in range(len(target_src_masked_ar_opt)) ])
|
src_loss_batch = sum([10 * dssim(kernel_size=int(resolution / 11.6), max_value=1.0)(
|
||||||
|
target_src_masked_ar_opt[i], pred_src_src_masked_ar_opt[i]) for i in
|
||||||
|
range(len(target_src_masked_ar_opt))])
|
||||||
else:
|
else:
|
||||||
src_loss_batch = sum([ K.mean ( 50*K.square( target_src_masked_ar_opt[i] - pred_src_src_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_src_masked_ar_opt)) ])
|
src_loss_batch = sum(
|
||||||
|
[K.mean(50 * K.square(target_src_masked_ar_opt[i] - pred_src_src_masked_ar_opt[i]), axis=[1, 2, 3])
|
||||||
|
for i in range(len(target_src_masked_ar_opt))])
|
||||||
|
|
||||||
src_loss = K.mean(src_loss_batch)
|
src_loss = K.mean(src_loss_batch)
|
||||||
|
|
||||||
face_style_power = self.options['face_style_power'] / 100.0
|
face_style_power = self.options['face_style_power'] / 100.0
|
||||||
|
|
||||||
if face_style_power != 0:
|
if face_style_power != 0:
|
||||||
src_loss += style_loss(gaussian_blur_radius=resolution//16, loss_weight=face_style_power, wnd_size=0)( psd_target_dst_masked_ar[-1], target_dst_masked_ar[-1] )
|
src_loss += style_loss(gaussian_blur_radius=resolution // 16, loss_weight=face_style_power, wnd_size=0)(
|
||||||
|
psd_target_dst_masked_ar[-1], target_dst_masked_ar[-1])
|
||||||
|
|
||||||
bg_style_power = self.options['bg_style_power'] / 100.0
|
bg_style_power = self.options['bg_style_power'] / 100.0
|
||||||
if bg_style_power != 0:
|
if bg_style_power != 0:
|
||||||
if not self.options['pixel_loss']:
|
if not self.options['pixel_loss']:
|
||||||
bg_loss = K.mean( (10*bg_style_power)*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] ))
|
bg_loss = K.mean((10 * bg_style_power) * dssim(kernel_size=int(resolution / 11.6), max_value=1.0)(
|
||||||
|
psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1]))
|
||||||
else:
|
else:
|
||||||
bg_loss = K.mean( (50*bg_style_power)*K.square( psd_target_dst_anti_masked_ar[-1] - target_dst_anti_masked_ar[-1] ))
|
bg_loss = K.mean((50 * bg_style_power) * K.square(
|
||||||
|
psd_target_dst_anti_masked_ar[-1] - target_dst_anti_masked_ar[-1]))
|
||||||
src_loss += bg_loss
|
src_loss += bg_loss
|
||||||
|
|
||||||
if not self.options['pixel_loss']:
|
if not self.options['pixel_loss']:
|
||||||
dst_loss_batch = sum([ 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)(target_dst_masked_ar_opt[i], pred_dst_dst_masked_ar_opt[i]) for i in range(len(target_dst_masked_ar_opt)) ])
|
dst_loss_batch = sum([10 * dssim(kernel_size=int(resolution / 11.6), max_value=1.0)(
|
||||||
|
target_dst_masked_ar_opt[i], pred_dst_dst_masked_ar_opt[i]) for i in
|
||||||
|
range(len(target_dst_masked_ar_opt))])
|
||||||
else:
|
else:
|
||||||
dst_loss_batch = sum([ K.mean ( 50*K.square( target_dst_masked_ar_opt[i] - pred_dst_dst_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_dst_masked_ar_opt)) ])
|
dst_loss_batch = sum(
|
||||||
|
[K.mean(50 * K.square(target_dst_masked_ar_opt[i] - pred_dst_dst_masked_ar_opt[i]), axis=[1, 2, 3])
|
||||||
|
for i in range(len(target_dst_masked_ar_opt))])
|
||||||
|
|
||||||
dst_loss = K.mean(dst_loss_batch)
|
dst_loss = K.mean(dst_loss_batch)
|
||||||
|
|
||||||
|
@ -324,30 +399,38 @@ class SAEModel(ModelBase):
|
||||||
feed += target_dst_ar[::-1]
|
feed += target_dst_ar[::-1]
|
||||||
feed += target_dstm_ar[::-1]
|
feed += target_dstm_ar[::-1]
|
||||||
|
|
||||||
self.src_dst_train = K.function (feed,[src_loss,dst_loss], self.src_dst_opt.get_updates(src_loss+dst_loss, src_dst_loss_train_weights) )
|
self.src_dst_train = K.function(feed, [src_loss, dst_loss],
|
||||||
|
self.src_dst_opt.get_updates(src_loss + dst_loss,
|
||||||
|
src_dst_loss_train_weights))
|
||||||
|
|
||||||
if self.options['learn_mask']:
|
if self.options['learn_mask']:
|
||||||
src_mask_loss = sum([ K.mean(K.square(target_srcm_ar[-1]-pred_src_srcm[-1])) for i in range(len(target_srcm_ar)) ])
|
src_mask_loss = sum(
|
||||||
dst_mask_loss = sum([ K.mean(K.square(target_dstm_ar[-1]-pred_dst_dstm[-1])) for i in range(len(target_dstm_ar)) ])
|
[K.mean(K.square(target_srcm_ar[-1] - pred_src_srcm[-1])) for i in range(len(target_srcm_ar))])
|
||||||
|
dst_mask_loss = sum(
|
||||||
|
[K.mean(K.square(target_dstm_ar[-1] - pred_dst_dstm[-1])) for i in range(len(target_dstm_ar))])
|
||||||
|
|
||||||
feed = [ warped_src, warped_dst]
|
feed = [warped_src, warped_dst]
|
||||||
feed += target_srcm_ar[::-1]
|
feed += target_srcm_ar[::-1]
|
||||||
feed += target_dstm_ar[::-1]
|
feed += target_dstm_ar[::-1]
|
||||||
|
|
||||||
self.src_dst_mask_train = K.function (feed,[src_mask_loss, dst_mask_loss], self.src_dst_mask_opt.get_updates(src_mask_loss+dst_mask_loss, src_dst_mask_loss_train_weights) )
|
self.src_dst_mask_train = K.function(feed, [src_mask_loss, dst_mask_loss],
|
||||||
|
self.src_dst_mask_opt.get_updates(src_mask_loss + dst_mask_loss,
|
||||||
|
src_dst_mask_loss_train_weights))
|
||||||
|
|
||||||
if self.options['learn_mask']:
|
if self.options['learn_mask']:
|
||||||
self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src[-1], pred_dst_dst[-1], pred_dst_dstm[-1], pred_src_dst[-1], pred_src_dstm[-1]])
|
self.AE_view = K.function([warped_src, warped_dst],
|
||||||
|
[pred_src_src[-1], pred_dst_dst[-1], pred_dst_dstm[-1], pred_src_dst[-1],
|
||||||
|
pred_src_dstm[-1]])
|
||||||
else:
|
else:
|
||||||
self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src[-1], pred_dst_dst[-1], pred_src_dst[-1] ] )
|
self.AE_view = K.function([warped_src, warped_dst],
|
||||||
|
[pred_src_src[-1], pred_dst_dst[-1], pred_src_dst[-1]])
|
||||||
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if self.options['learn_mask']:
|
if self.options['learn_mask']:
|
||||||
self.AE_convert = K.function ([warped_dst],[ pred_src_dst[-1], pred_dst_dstm[-1], pred_src_dstm[-1] ])
|
self.AE_convert = K.function([warped_dst], [pred_src_dst[-1], pred_dst_dstm[-1], pred_src_dstm[-1]])
|
||||||
else:
|
else:
|
||||||
self.AE_convert = K.function ([warped_dst],[ pred_src_dst[-1] ])
|
self.AE_convert = K.function([warped_dst], [pred_src_dst[-1]])
|
||||||
|
|
||||||
|
|
||||||
if self.is_training_mode:
|
if self.is_training_mode:
|
||||||
self.src_sample_losses = []
|
self.src_sample_losses = []
|
||||||
|
@ -367,125 +450,138 @@ class SAEModel(ModelBase):
|
||||||
training_data_dst_path = self.pretraining_data_path
|
training_data_dst_path = self.pretraining_data_path
|
||||||
sort_by_yaw = False
|
sort_by_yaw = False
|
||||||
|
|
||||||
self.set_training_data_generators ([
|
self.set_training_data_generators([
|
||||||
SampleGeneratorFace(training_data_src_path, sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
|
SampleGeneratorFace(training_data_src_path,
|
||||||
random_ct_samples_path=training_data_dst_path if apply_random_ct else None,
|
sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
|
||||||
debug=self.is_debug(), batch_size=self.batch_size,
|
random_ct_samples_path=training_data_dst_path if apply_random_ct else None,
|
||||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
|
debug=self.is_debug(), batch_size=self.batch_size,
|
||||||
output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), 'resolution':resolution, 'apply_ct': apply_random_ct} ] + \
|
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip,
|
||||||
[ {'types' : (t.IMG_TRANSFORMED, face_type, t_mode_bgr), 'resolution': resolution // (2**i), 'apply_ct': apply_random_ct } for i in range(ms_count)] + \
|
scale_range=np.array([-0.05,
|
||||||
[ {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'resolution': resolution // (2**i) } for i in range(ms_count)]
|
0.05]) + self.src_scale_mod / 100.0),
|
||||||
),
|
output_sample_types=[{'types': (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),
|
||||||
|
'resolution': resolution, 'apply_ct': apply_random_ct}] + \
|
||||||
|
[{'types': (t.IMG_TRANSFORMED, face_type, t_mode_bgr),
|
||||||
|
'resolution': resolution // (2 ** i),
|
||||||
|
'apply_ct': apply_random_ct} for i in range(ms_count)] + \
|
||||||
|
[{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
|
||||||
|
'resolution': resolution // (2 ** i)} for i in
|
||||||
|
range(ms_count)]
|
||||||
|
),
|
||||||
|
|
||||||
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
|
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ),
|
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ),
|
||||||
output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), 'resolution':resolution} ] + \
|
output_sample_types=[{'types': (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),
|
||||||
[ {'types' : (t.IMG_TRANSFORMED, face_type, t_mode_bgr), 'resolution': resolution // (2**i)} for i in range(ms_count)] + \
|
'resolution': resolution}] + \
|
||||||
[ {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'resolution': resolution // (2**i) } for i in range(ms_count)])
|
[{'types': (t.IMG_TRANSFORMED, face_type, t_mode_bgr),
|
||||||
])
|
'resolution': resolution // (2 ** i)} for i in
|
||||||
|
range(ms_count)] + \
|
||||||
#override
|
[{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
|
||||||
|
'resolution': resolution // (2 ** i)} for i in
|
||||||
|
range(ms_count)])
|
||||||
|
])
|
||||||
|
|
||||||
|
# override
|
||||||
def get_model_filename_list(self):
|
def get_model_filename_list(self):
|
||||||
ar = []
|
ar = []
|
||||||
if 'liae' in self.options['archi']:
|
if 'liae' in self.options['archi']:
|
||||||
ar += [[self.encoder, 'encoder.h5'],
|
ar += [[self.encoder, 'encoder.h5'],
|
||||||
[self.inter_B, 'inter_B.h5'],
|
[self.inter_B, 'inter_B.h5'],
|
||||||
[self.decoder, 'decoder.h5']
|
[self.decoder, 'decoder.h5']
|
||||||
]
|
]
|
||||||
|
|
||||||
if not self.pretrain or self.iter == 0:
|
if not self.pretrain or self.iter == 0:
|
||||||
ar += [ [self.inter_AB, 'inter_AB.h5'],
|
ar += [[self.inter_AB, 'inter_AB.h5'],
|
||||||
]
|
]
|
||||||
|
|
||||||
if self.options['learn_mask']:
|
if self.options['learn_mask']:
|
||||||
ar += [ [self.decoderm, 'decoderm.h5'] ]
|
ar += [[self.decoderm, 'decoderm.h5']]
|
||||||
|
|
||||||
elif 'df' in self.options['archi']:
|
elif 'df' in self.options['archi']:
|
||||||
if not self.pretrain or self.iter == 0:
|
if not self.pretrain or self.iter == 0:
|
||||||
ar += [ [self.encoder, 'encoder.h5'],
|
ar += [[self.encoder, 'encoder.h5'],
|
||||||
]
|
]
|
||||||
|
|
||||||
ar += [ [self.decoder_src, 'decoder_src.h5'],
|
ar += [[self.decoder_src, 'decoder_src.h5'],
|
||||||
[self.decoder_dst, 'decoder_dst.h5']
|
[self.decoder_dst, 'decoder_dst.h5']
|
||||||
]
|
]
|
||||||
|
|
||||||
if self.options['learn_mask']:
|
if self.options['learn_mask']:
|
||||||
ar += [ [self.decoder_srcm, 'decoder_srcm.h5'],
|
ar += [[self.decoder_srcm, 'decoder_srcm.h5'],
|
||||||
[self.decoder_dstm, 'decoder_dstm.h5'] ]
|
[self.decoder_dstm, 'decoder_dstm.h5']]
|
||||||
return ar
|
return ar
|
||||||
|
|
||||||
#override
|
# override
|
||||||
def onSave(self):
|
def onSave(self):
|
||||||
self.save_weights_safe( self.get_model_filename_list() )
|
self.save_weights_safe(self.get_model_filename_list())
|
||||||
|
|
||||||
#override
|
# override
|
||||||
def onTrainOneIter(self, generators_samples, generators_list):
|
def onTrainOneIter(self, generators_samples, generators_list):
|
||||||
src_samples = generators_samples[0]
|
src_samples = generators_samples[0]
|
||||||
dst_samples = generators_samples[1]
|
dst_samples = generators_samples[1]
|
||||||
|
|
||||||
feed = [src_samples[0], dst_samples[0] ] + \
|
feed = [src_samples[0], dst_samples[0]] + \
|
||||||
src_samples[1:1+self.ms_count*2] + \
|
src_samples[1:1 + self.ms_count * 2] + \
|
||||||
dst_samples[1:1+self.ms_count*2]
|
dst_samples[1:1 + self.ms_count * 2]
|
||||||
|
|
||||||
src_loss, dst_loss, = self.src_dst_train (feed)
|
src_loss, dst_loss, = self.src_dst_train(feed)
|
||||||
|
|
||||||
if self.options['learn_mask']:
|
if self.options['learn_mask']:
|
||||||
feed = [ src_samples[0], dst_samples[0] ] + \
|
feed = [src_samples[0], dst_samples[0]] + \
|
||||||
src_samples[1+self.ms_count:1+self.ms_count*2] + \
|
src_samples[1 + self.ms_count:1 + self.ms_count * 2] + \
|
||||||
dst_samples[1+self.ms_count:1+self.ms_count*2]
|
dst_samples[1 + self.ms_count:1 + self.ms_count * 2]
|
||||||
src_mask_loss, dst_mask_loss, = self.src_dst_mask_train (feed)
|
src_mask_loss, dst_mask_loss, = self.src_dst_mask_train(feed)
|
||||||
|
|
||||||
return ( ('src_loss', src_loss), ('dst_loss', dst_loss) )
|
return (('src_loss', src_loss), ('dst_loss', dst_loss))
|
||||||
|
|
||||||
|
# override
|
||||||
#override
|
|
||||||
def onGetPreview(self, sample):
|
def onGetPreview(self, sample):
|
||||||
test_S = sample[0][1][0:4] #first 4 samples
|
test_S = sample[0][1][0:4] # first 4 samples
|
||||||
test_S_m = sample[0][1+self.ms_count][0:4] #first 4 samples
|
test_S_m = sample[0][1 + self.ms_count][0:4] # first 4 samples
|
||||||
test_D = sample[1][1][0:4]
|
test_D = sample[1][1][0:4]
|
||||||
test_D_m = sample[1][1+self.ms_count][0:4]
|
test_D_m = sample[1][1 + self.ms_count][0:4]
|
||||||
|
|
||||||
if self.options['learn_mask']:
|
if self.options['learn_mask']:
|
||||||
S, D, SS, DD, DDM, SD, SDM = [ np.clip(x, 0.0, 1.0) for x in ([test_S,test_D] + self.AE_view ([test_S, test_D]) ) ]
|
S, D, SS, DD, DDM, SD, SDM = [np.clip(x, 0.0, 1.0) for x in
|
||||||
DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [DDM, SDM] ]
|
([test_S, test_D] + self.AE_view([test_S, test_D]))]
|
||||||
|
DDM, SDM, = [np.repeat(x, (3,), -1) for x in [DDM, SDM]]
|
||||||
else:
|
else:
|
||||||
S, D, SS, DD, SD, = [ np.clip(x, 0.0, 1.0) for x in ([test_S,test_D] + self.AE_view ([test_S, test_D]) ) ]
|
S, D, SS, DD, SD, = [np.clip(x, 0.0, 1.0) for x in ([test_S, test_D] + self.AE_view([test_S, test_D]))]
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
st = []
|
st = []
|
||||||
for i in range(0, len(test_S)):
|
for i in range(0, len(test_S)):
|
||||||
ar = S[i], SS[i], D[i], DD[i], SD[i]
|
ar = S[i], SS[i], D[i], DD[i], SD[i]
|
||||||
st.append ( np.concatenate ( ar, axis=1) )
|
st.append(np.concatenate(ar, axis=1))
|
||||||
|
|
||||||
result += [ ('SAE', np.concatenate (st, axis=0 )), ]
|
result += [('SAE', np.concatenate(st, axis=0)), ]
|
||||||
|
|
||||||
if self.options['learn_mask']:
|
if self.options['learn_mask']:
|
||||||
st_m = []
|
st_m = []
|
||||||
for i in range(0, len(test_S)):
|
for i in range(0, len(test_S)):
|
||||||
ar = S[i]*test_S_m[i], SS[i], D[i]*test_D_m[i], DD[i]*DDM[i], SD[i]*(DDM[i]*SDM[i])
|
ar = S[i] * test_S_m[i], SS[i], D[i] * test_D_m[i], DD[i] * DDM[i], SD[i] * (DDM[i] * SDM[i])
|
||||||
st_m.append ( np.concatenate ( ar, axis=1) )
|
st_m.append(np.concatenate(ar, axis=1))
|
||||||
|
|
||||||
result += [ ('SAE masked', np.concatenate (st_m, axis=0 )), ]
|
result += [('SAE masked', np.concatenate(st_m, axis=0)), ]
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def predictor_func (self, face):
|
def predictor_func(self, face):
|
||||||
if self.options['learn_mask']:
|
if self.options['learn_mask']:
|
||||||
bgr, mask_dst_dstm, mask_src_dstm = self.AE_convert ([face[np.newaxis,...]])
|
bgr, mask_dst_dstm, mask_src_dstm = self.AE_convert([face[np.newaxis, ...]])
|
||||||
mask = mask_dst_dstm[0] * mask_src_dstm[0]
|
mask = mask_dst_dstm[0] * mask_src_dstm[0]
|
||||||
return bgr[0], mask[...,0]
|
return bgr[0], mask[..., 0]
|
||||||
else:
|
else:
|
||||||
bgr, = self.AE_convert ([face[np.newaxis,...]])
|
bgr, = self.AE_convert([face[np.newaxis, ...]])
|
||||||
return bgr[0]
|
return bgr[0]
|
||||||
|
|
||||||
#override
|
# override
|
||||||
def get_converter(self):
|
def get_converter(self):
|
||||||
base_erode_mask_modifier = 30 if self.options['face_type'] == 'f' else 100
|
base_erode_mask_modifier = 30 if self.options['face_type'] == 'f' else 100
|
||||||
base_blur_mask_modifier = 0 if self.options['face_type'] == 'f' else 100
|
base_blur_mask_modifier = 0 if self.options['face_type'] == 'f' else 100
|
||||||
|
|
||||||
default_erode_mask_modifier = 0
|
default_erode_mask_modifier = 0
|
||||||
default_blur_mask_modifier = 100 if (self.options['face_style_power'] or self.options['bg_style_power']) and \
|
default_blur_mask_modifier = 100 if (self.options['face_style_power'] or self.options['bg_style_power']) and \
|
||||||
self.options['face_type'] == 'f' else 0
|
self.options['face_type'] == 'f' else 0
|
||||||
|
|
||||||
face_type = FaceType.FULL if self.options['face_type'] == 'f' else FaceType.HALF
|
face_type = FaceType.FULL if self.options['face_type'] == 'f' else FaceType.HALF
|
||||||
|
|
||||||
|
@ -494,7 +590,8 @@ class SAEModel(ModelBase):
|
||||||
predictor_input_size=self.options['resolution'],
|
predictor_input_size=self.options['resolution'],
|
||||||
predictor_masked=self.options['learn_mask'],
|
predictor_masked=self.options['learn_mask'],
|
||||||
face_type=face_type,
|
face_type=face_type,
|
||||||
default_mode = 1 if self.options['apply_random_ct'] or self.options['face_style_power'] or self.options['bg_style_power'] else 4,
|
default_mode=1 if self.options['apply_random_ct'] or self.options['face_style_power'] or
|
||||||
|
self.options['bg_style_power'] else 4,
|
||||||
base_erode_mask_modifier=base_erode_mask_modifier,
|
base_erode_mask_modifier=base_erode_mask_modifier,
|
||||||
base_blur_mask_modifier=base_blur_mask_modifier,
|
base_blur_mask_modifier=base_blur_mask_modifier,
|
||||||
default_erode_mask_modifier=default_erode_mask_modifier,
|
default_erode_mask_modifier=default_erode_mask_modifier,
|
||||||
|
@ -503,7 +600,7 @@ class SAEModel(ModelBase):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def initialize_nn_functions():
|
def initialize_nn_functions():
|
||||||
exec (nnlib.import_all(), locals(), globals())
|
exec(nnlib.import_all(), locals(), globals())
|
||||||
|
|
||||||
def NormPass(x):
|
def NormPass(x):
|
||||||
return x
|
return x
|
||||||
|
@ -538,63 +635,73 @@ class SAEModel(ModelBase):
|
||||||
x = Act(self.act, lrelu_alpha=0.2)(x)
|
x = Act(self.act, lrelu_alpha=0.2)(x)
|
||||||
x = Norm(self.norm)(x)
|
x = Norm(self.norm)(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
SAEModel.ResidualBlock = ResidualBlock
|
SAEModel.ResidualBlock = ResidualBlock
|
||||||
|
|
||||||
def downscale (dim, padding='zero', norm='', act='', **kwargs):
|
def downscale(dim, padding='zero', norm='', act='', **kwargs):
|
||||||
def func(x):
|
def func(x):
|
||||||
return Norm(norm)( Act(act) (Conv2D(dim, kernel_size=5, strides=2, padding=padding)(x)) )
|
return Norm(norm)(Act(act)(Conv2D(dim, kernel_size=5, strides=2, padding=padding)(x)))
|
||||||
|
|
||||||
return func
|
return func
|
||||||
|
|
||||||
SAEModel.downscale = downscale
|
SAEModel.downscale = downscale
|
||||||
|
|
||||||
def upscale (dim, padding='zero', norm='', act='', **kwargs):
|
def upscale(dim, padding='zero', norm='', act='', **kwargs):
|
||||||
def func(x):
|
def func(x):
|
||||||
return SubpixelUpscaler()(Norm(norm)(Act(act)(Conv2D(dim * 4, kernel_size=3, strides=1, padding=padding)(x))))
|
return SubpixelUpscaler()(
|
||||||
|
Norm(norm)(Act(act)(Conv2D(dim * 4, kernel_size=3, strides=1, padding=padding)(x))))
|
||||||
|
|
||||||
return func
|
return func
|
||||||
|
|
||||||
SAEModel.upscale = upscale
|
SAEModel.upscale = upscale
|
||||||
|
|
||||||
def to_bgr (output_nc, padding='zero', **kwargs):
|
def to_bgr(output_nc, padding='zero', **kwargs):
|
||||||
def func(x):
|
def func(x):
|
||||||
return Conv2D(output_nc, kernel_size=5, padding=padding, activation='sigmoid')(x)
|
return Conv2D(output_nc, kernel_size=5, padding=padding, activation='sigmoid')(x)
|
||||||
|
|
||||||
return func
|
return func
|
||||||
|
|
||||||
SAEModel.to_bgr = to_bgr
|
SAEModel.to_bgr = to_bgr
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def LIAEEncFlow(resolution, ch_dims, **kwargs):
|
def LIAEEncFlow(resolution, ch_dims, **kwargs):
|
||||||
exec (nnlib.import_all(), locals(), globals())
|
exec(nnlib.import_all(), locals(), globals())
|
||||||
upscale = partial(SAEModel.upscale, **kwargs)
|
upscale = partial(SAEModel.upscale, **kwargs)
|
||||||
downscale = partial(SAEModel.downscale, **kwargs)
|
downscale = partial(SAEModel.downscale, **kwargs)
|
||||||
|
|
||||||
def func(input):
|
def func(input):
|
||||||
dims = K.int_shape(input)[-1]*ch_dims
|
dims = K.int_shape(input)[-1] * ch_dims
|
||||||
|
|
||||||
x = input
|
x = input
|
||||||
x = downscale(dims)(x)
|
x = downscale(dims)(x)
|
||||||
x = downscale(dims*2)(x)
|
x = downscale(dims * 2)(x)
|
||||||
x = downscale(dims*4)(x)
|
x = downscale(dims * 4)(x)
|
||||||
x = downscale(dims*8)(x)
|
x = downscale(dims * 8)(x)
|
||||||
|
|
||||||
x = Flatten()(x)
|
x = Flatten()(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
return func
|
return func
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def LIAEInterFlow(resolution, ae_dims=256, **kwargs):
|
def LIAEInterFlow(resolution, ae_dims=256, **kwargs):
|
||||||
exec (nnlib.import_all(), locals(), globals())
|
exec(nnlib.import_all(), locals(), globals())
|
||||||
upscale = partial(SAEModel.upscale, **kwargs)
|
upscale = partial(SAEModel.upscale, **kwargs)
|
||||||
lowest_dense_res=resolution // 16
|
lowest_dense_res = resolution // 16
|
||||||
|
|
||||||
def func(input):
|
def func(input):
|
||||||
x = input[0]
|
x = input[0]
|
||||||
x = Dense(ae_dims)(x)
|
x = Dense(ae_dims)(x)
|
||||||
x = Dense(lowest_dense_res * lowest_dense_res * ae_dims*2)(x)
|
x = Dense(lowest_dense_res * lowest_dense_res * ae_dims * 2)(x)
|
||||||
x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims*2))(x)
|
x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims * 2))(x)
|
||||||
x = upscale(ae_dims*2)(x)
|
x = upscale(ae_dims * 2)(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
return func
|
return func
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def LIAEDecFlow(output_nc,ch_dims, multiscale_count=1, add_residual_blocks=False, **kwargs):
|
def LIAEDecFlow(output_nc, ch_dims, multiscale_count=1, add_residual_blocks=False, **kwargs):
|
||||||
exec (nnlib.import_all(), locals(), globals())
|
exec(nnlib.import_all(), locals(), globals())
|
||||||
upscale = partial(SAEModel.upscale, **kwargs)
|
upscale = partial(SAEModel.upscale, **kwargs)
|
||||||
to_bgr = partial(SAEModel.to_bgr, **kwargs)
|
to_bgr = partial(SAEModel.to_bgr, **kwargs)
|
||||||
dims = output_nc * ch_dims
|
dims = output_nc * ch_dims
|
||||||
|
@ -604,61 +711,63 @@ class SAEModel(ModelBase):
|
||||||
x = input[0]
|
x = input[0]
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
x1 = upscale(dims*8)( x )
|
x1 = upscale(dims * 8)(x)
|
||||||
|
|
||||||
if add_residual_blocks:
|
if add_residual_blocks:
|
||||||
x1 = ResidualBlock(dims*8)(x1)
|
x1 = ResidualBlock(dims * 8)(x1)
|
||||||
x1 = ResidualBlock(dims*8)(x1)
|
x1 = ResidualBlock(dims * 8)(x1)
|
||||||
|
|
||||||
if multiscale_count >= 3:
|
if multiscale_count >= 3:
|
||||||
outputs += [ to_bgr(output_nc) ( x1 ) ]
|
outputs += [to_bgr(output_nc)(x1)]
|
||||||
|
|
||||||
x2 = upscale(dims*4)( x1 )
|
x2 = upscale(dims * 4)(x1)
|
||||||
|
|
||||||
if add_residual_blocks:
|
if add_residual_blocks:
|
||||||
x2 = ResidualBlock(dims*4)(x2)
|
x2 = ResidualBlock(dims * 4)(x2)
|
||||||
x2 = ResidualBlock(dims*4)(x2)
|
x2 = ResidualBlock(dims * 4)(x2)
|
||||||
|
|
||||||
if multiscale_count >= 2:
|
if multiscale_count >= 2:
|
||||||
outputs += [ to_bgr(output_nc) ( x2 ) ]
|
outputs += [to_bgr(output_nc)(x2)]
|
||||||
|
|
||||||
x3 = upscale(dims*2)( x2 )
|
x3 = upscale(dims * 2)(x2)
|
||||||
|
|
||||||
if add_residual_blocks:
|
if add_residual_blocks:
|
||||||
x3 = ResidualBlock( dims*2)(x3)
|
x3 = ResidualBlock(dims * 2)(x3)
|
||||||
x3 = ResidualBlock( dims*2)(x3)
|
x3 = ResidualBlock(dims * 2)(x3)
|
||||||
|
|
||||||
outputs += [ to_bgr(output_nc) ( x3 ) ]
|
outputs += [to_bgr(output_nc)(x3)]
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
return func
|
return func
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def DFEncFlow(resolution, ae_dims, ch_dims, **kwargs):
|
def DFEncFlow(resolution, ae_dims, ch_dims, **kwargs):
|
||||||
exec (nnlib.import_all(), locals(), globals())
|
exec(nnlib.import_all(), locals(), globals())
|
||||||
upscale = partial(SAEModel.upscale, **kwargs)
|
upscale = partial(SAEModel.upscale, **kwargs)
|
||||||
downscale = partial(SAEModel.downscale, **kwargs)#, kernel_regularizer=keras.regularizers.l2(0.0),
|
downscale = partial(SAEModel.downscale, **kwargs) # , kernel_regularizer=keras.regularizers.l2(0.0),
|
||||||
lowest_dense_res = resolution // 16
|
lowest_dense_res = resolution // 16
|
||||||
|
|
||||||
def func(input):
|
def func(input):
|
||||||
x = input
|
x = input
|
||||||
|
|
||||||
dims = K.int_shape(input)[-1]*ch_dims
|
dims = K.int_shape(input)[-1] * ch_dims
|
||||||
x = downscale(dims)(x)
|
x = downscale(dims)(x)
|
||||||
x = downscale(dims*2)(x)
|
x = downscale(dims * 2)(x)
|
||||||
x = downscale(dims*4)(x)
|
x = downscale(dims * 4)(x)
|
||||||
x = downscale(dims*8)(x)
|
x = downscale(dims * 8)(x)
|
||||||
|
|
||||||
x = Dense(ae_dims)(Flatten()(x))
|
x = Dense(ae_dims)(Flatten()(x))
|
||||||
x = Dense(lowest_dense_res * lowest_dense_res * ae_dims)(x)
|
x = Dense(lowest_dense_res * lowest_dense_res * ae_dims)(x)
|
||||||
x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims))(x)
|
x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims))(x)
|
||||||
x = upscale(ae_dims)(x)
|
x = upscale(ae_dims)(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
return func
|
return func
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def DFDecFlow(output_nc, ch_dims, multiscale_count=1, add_residual_blocks=False, **kwargs):
|
def DFDecFlow(output_nc, ch_dims, multiscale_count=1, add_residual_blocks=False, **kwargs):
|
||||||
exec (nnlib.import_all(), locals(), globals())
|
exec(nnlib.import_all(), locals(), globals())
|
||||||
upscale = partial(SAEModel.upscale, **kwargs)
|
upscale = partial(SAEModel.upscale, **kwargs)
|
||||||
to_bgr = partial(SAEModel.to_bgr, **kwargs)
|
to_bgr = partial(SAEModel.to_bgr, **kwargs)
|
||||||
dims = output_nc * ch_dims
|
dims = output_nc * ch_dims
|
||||||
|
@ -668,34 +777,35 @@ class SAEModel(ModelBase):
|
||||||
x = input[0]
|
x = input[0]
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
x1 = upscale(dims*8)( x )
|
x1 = upscale(dims * 8)(x)
|
||||||
|
|
||||||
if add_residual_blocks:
|
if add_residual_blocks:
|
||||||
x1 = ResidualBlock( dims*8 )(x1)
|
x1 = ResidualBlock(dims * 8)(x1)
|
||||||
x1 = ResidualBlock( dims*8 )(x1)
|
x1 = ResidualBlock(dims * 8)(x1)
|
||||||
|
|
||||||
if multiscale_count >= 3:
|
if multiscale_count >= 3:
|
||||||
outputs += [ to_bgr(output_nc) ( x1 ) ]
|
outputs += [to_bgr(output_nc)(x1)]
|
||||||
|
|
||||||
x2 = upscale(dims*4)( x1 )
|
x2 = upscale(dims * 4)(x1)
|
||||||
|
|
||||||
if add_residual_blocks:
|
if add_residual_blocks:
|
||||||
x2 = ResidualBlock( dims*4)(x2)
|
x2 = ResidualBlock(dims * 4)(x2)
|
||||||
x2 = ResidualBlock( dims*4)(x2)
|
x2 = ResidualBlock(dims * 4)(x2)
|
||||||
|
|
||||||
if multiscale_count >= 2:
|
if multiscale_count >= 2:
|
||||||
outputs += [ to_bgr(output_nc) ( x2 ) ]
|
outputs += [to_bgr(output_nc)(x2)]
|
||||||
|
|
||||||
x3 = upscale(dims*2)( x2 )
|
x3 = upscale(dims * 2)(x2)
|
||||||
|
|
||||||
if add_residual_blocks:
|
if add_residual_blocks:
|
||||||
x3 = ResidualBlock( dims*2)(x3)
|
x3 = ResidualBlock(dims * 2)(x3)
|
||||||
x3 = ResidualBlock( dims*2)(x3)
|
x3 = ResidualBlock(dims * 2)(x3)
|
||||||
|
|
||||||
outputs += [ to_bgr(output_nc) ( x3 ) ]
|
outputs += [to_bgr(output_nc)(x3)]
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
return func
|
return func
|
||||||
|
|
||||||
|
|
||||||
Model = SAEModel
|
Model = SAEModel
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue