Upgraded to TF version 1.13.2

Removed the wait at first launch for most graphics cards.

Increased training speed by 10-20%, but all models have to be retrained from scratch.

SAEHD:

Added option 'use float16':
	Experimental option. Reduces the model size by half.
	Increases the speed of training.
	Decreases the accuracy of the model.
	The model may collapse or not train.
	The model may not learn the mask at large resolutions.
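
	For illustration only, a minimal numpy sketch of why half precision
	halves the model size (float16 stores two bytes per weight instead
	of four):

		import numpy as np
		w32 = np.zeros((1024, 1024), dtype=np.float32)
		w16 = w32.astype(np.float16)
		print(w32.nbytes)   # 4194304 bytes
		print(w16.nbytes)   # 2097152 bytes -- exactly half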

The 'true_face_training' option is replaced by
"True face power". 0.0000 .. 1.0
Experimental option. Discriminates the result face to be more like the src face. The higher the value, the stronger the discrimination.
Comparison - https://i.imgur.com/czScS9q.png
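
For illustration only, a numpy sketch of how the power value scales the
discriminator term added to the generator loss (names here are illustrative;
the real wiring is in the diff below):

	import numpy as np

	def d_loss(labels, logits):
		# numerically stable sigmoid cross-entropy, the same formula TF uses
		return np.mean(np.maximum(logits, 0) - logits*labels + np.log1p(np.exp(-np.abs(logits))))

	true_face_power = 0.3                # 0.0 disables the discriminator entirely
	src_code_d = np.random.randn(8, 1)   # discriminator logits for src latent codes
	extra_g_loss = true_face_power * d_loss(np.ones_like(src_code_d), src_code_d)
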
Colombo 2020-01-25 21:58:19 +04:00
parent a3dfcb91b9
commit 76ca79216e
49 changed files with 1320 additions and 1297 deletions


@@ -15,25 +15,17 @@ class SAEHDModel(ModelBase):
#override
def on_initialize_options(self):
device_config = nn.getCurrentDeviceConfig()
lowest_vram = 2
if len(device_config.devices) != 0:
lowest_vram = device_config.devices.get_worst_device().total_mem_gb
if lowest_vram >= 4:
suggest_batch_size = 8
else:
suggest_batch_size = 4
-yn_str = {True:'y',False:'n'}
-ask_override = self.ask_override()
-if self.is_first_run() or ask_override:
-self.ask_enable_autobackup()
-self.ask_write_preview_history()
-self.ask_target_iter()
-self.ask_random_flip()
-self.ask_batch_size(suggest_batch_size)
+yn_str = {True:'y',False:'n'}
default_resolution = self.options['resolution'] = self.load_or_def_option('resolution', 128)
default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'f')
@@ -42,52 +34,63 @@ class SAEHDModel(ModelBase):
default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256)
default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64)
default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', 64)
default_d_mask_dims = default_d_dims // 3
default_d_mask_dims += default_d_mask_dims % 2
default_d_mask_dims = self.options['d_mask_dims'] = self.load_or_def_option('d_mask_dims', default_d_mask_dims)
+default_use_float16 = self.options['use_float16'] = self.load_or_def_option('use_float16', False)
default_learn_mask = self.options['learn_mask'] = self.load_or_def_option('learn_mask', True)
default_lr_dropout = self.options['lr_dropout'] = self.load_or_def_option('lr_dropout', False)
default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True)
-default_true_face_training = self.options['true_face_training'] = self.load_or_def_option('true_face_training', False)
+default_true_face_power = self.options['true_face_power'] = self.load_or_def_option('true_face_power', 0.0)
default_face_style_power = self.options['face_style_power'] = self.load_or_def_option('face_style_power', 0.0)
default_bg_style_power = self.options['bg_style_power'] = self.load_or_def_option('bg_style_power', 0.0)
default_ct_mode = self.options['ct_mode'] = self.load_or_def_option('ct_mode', 'none')
default_clipgrad = self.options['clipgrad'] = self.load_or_def_option('clipgrad', False)
default_pretrain = self.options['pretrain'] = self.load_or_def_option('pretrain', False)
+ask_override = self.ask_override()
+if self.is_first_run() or ask_override:
+self.ask_enable_autobackup()
+self.ask_write_preview_history()
+self.ask_target_iter()
+self.ask_random_flip()
+self.ask_batch_size(suggest_batch_size)
if self.is_first_run():
resolution = io.input_int("Resolution", default_resolution, add_info="64-256", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 16.")
resolution = np.clip ( (resolution // 16) * 16, 64, 256)
self.options['resolution'] = resolution
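# Quick check of the rounding above (illustration only, not part of the diff):
# (150 // 16) * 16 == 144, so a requested 150 becomes 144, and np.clip then
# bounds the result: np.clip((300 // 16) * 16, 64, 256) == 256.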
self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f'], help_message="Half / mid face / full face. Half face has better resolution, but covers less area of cheeks. Mid face is 30% wider than half face.").lower()
-if (self.is_first_run() or ask_override) and len(device_config.devices) == 1:
-self.options['models_opt_on_gpu'] = io.input_bool ("Place models and optimizer on GPU", default_models_opt_on_gpu, help_message="When you train on one GPU, by default model and optimizer weights are placed on GPU to accelerate the process. You can place they on CPU to free up extra VRAM, thus set bigger dimensions.")
if self.is_first_run():
self.options['archi'] = io.input_str ("AE architecture", default_archi, ['dfhd','liaehd','df','liae'], help_message="'df' keeps faces more natural. 'liae' can fix overly different face shapes. 'hd' is a heavyweight version for the best quality.").lower() #-s version is slower, but has a decreased chance of collapse.
self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dimensions", default_ae_dims, add_info="32-1024", help_message="All face information will be packed into AE dims. If the number of AE dims is not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 )
e_dims = np.clip ( io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 )
self.options['e_dims'] = e_dims + e_dims % 2
d_dims = np.clip ( io.input_int("Decoder dimensions", default_d_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 )
self.options['d_dims'] = d_dims + d_dims % 2
d_mask_dims = np.clip ( io.input_int("Decoder mask dimensions", default_d_mask_dims, add_info="16-256", help_message="Typical mask dimensions = decoder dimensions / 3. If you manually cut out obstacles from the dst mask, you can increase this parameter to achieve better quality." ), 16, 256 )
self.options['d_mask_dims'] = d_mask_dims + d_mask_dims % 2
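# Worked example of the even-rounding trick used for e_dims/d_dims/d_mask_dims
# (illustration only): 21 + 21 % 2 == 22, while 22 + 22 % 2 == 22, so odd
# values round up and even values pass through. The default d_mask_dims above
# comes from the same arithmetic: 64 // 3 == 21, rounded up to 22.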
if self.is_first_run() or ask_override:
self.options['learn_mask'] = io.input_bool ("Learn mask", default_learn_mask, help_message="Learning the mask can help the model to recognize face directions. Learning without the mask reduces model size; in that case the merger is forced to use a 'not predicted mask' that is not as smooth as the predicted one.")
if self.is_first_run() or ask_override:
+if len(device_config.devices) == 1:
+self.options['models_opt_on_gpu'] = io.input_bool ("Place models and optimizer on GPU", default_models_opt_on_gpu, help_message="When you train on one GPU, by default the model and optimizer weights are placed on the GPU to accelerate the process. You can place them on the CPU to free up extra VRAM and thus set bigger dimensions.")
+self.options['use_float16'] = io.input_bool ("Use float16", default_use_float16, help_message="Experimental option. Reduces the model size by half. Increases the speed of training. Decreases the accuracy of the model. The model may collapse. The model may not learn the mask at large resolutions.")
self.options['lr_dropout'] = io.input_bool ("Use learning rate dropout", default_lr_dropout, help_message="When the face is trained enough, you can enable this option to get extra sharpness in fewer iterations.")
self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness in fewer iterations.")
if 'df' in self.options['archi']:
-self.options['true_face_training'] = io.input_bool ("Enable 'true face' training", default_true_face_training, help_message="The result face will be more like src and will get extra sharpness. Enable it for last 10-20k iterations before conversion.")
+self.options['true_face_power'] = np.clip ( io.input_number (" 'True face' power.", default_true_face_power, add_info="0.0000 .. 1.0", help_message="Experimental option. Discriminates the result face to be more like the src face. The higher the value, the stronger the discrimination. Comparison - https://i.imgur.com/czScS9q.png"), 0.0, 1.0 )
else:
-self.options['true_face_training'] = False
+self.options['true_face_power'] = 0.0
self.options['face_style_power'] = np.clip ( io.input_number("Face style power", default_face_style_power, add_info="0.0..100.0", help_message="Learn to transfer face style details such as light and color conditions. Warning: Enable it only after 10k iters, when the predicted face is clear enough to start learning style. Start from a 0.1 value and check history changes. Enabling this option increases the chance of model collapse."), 0.0, 100.0 )
self.options['bg_style_power'] = np.clip ( io.input_number("Background style power", default_bg_style_power, add_info="0.0..100.0", help_message="Learn to transfer background around face. This can make face more like dst. Enabling this option increases the chance of model collapse."), 0.0, 100.0 )
@@ -96,20 +99,24 @@ class SAEHDModel(ModelBase):
self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain, help_message="Pretrain the model with a large amount of various faces. After that, the model can be used to train fakes more quickly.")
if self.options['pretrain'] and self.get_pretraining_data_path() is None:
raise Exception("pretraining_data_path is not defined")
self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False)
if self.pretrain_just_disabled:
self.set_iter(1)
#override
def on_initialize(self):
-nn.initialize()
+device_config = nn.getCurrentDeviceConfig()
+self.model_data_format = "NCHW" if len(device_config.devices) != 0 else "NHWC"
+nn.initialize(floatx="float16" if self.options['use_float16'] else "float32",
+data_format=self.model_data_format)
tf = nn.tf
-conv_kernel_initializer = nn.initializers.ca
+conv_kernel_initializer = nn.initializers.ca()
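# Illustration only (not part of the diff): "NCHW" vs "NHWC" is just an axis
# order, and the nn.to_data_format calls further down amount to a transpose:
#     import numpy as np
#     x_nhwc = np.zeros((4, 128, 128, 3))           # batch, height, width, channels
#     x_nchw = np.transpose(x_nhwc, (0, 3, 1, 2))   # (4, 3, 128, 128)
# NCHW is generally the faster layout for GPU convolutions, which is why it is
# chosen whenever a GPU is present.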
class Downscale(nn.ModelBase):
def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *kwargs ):
self.in_ch = in_ch
@@ -120,19 +127,19 @@ class SAEHDModel(ModelBase):
self.use_activator = use_activator
super().__init__(*kwargs)
def on_build(self, *args, **kwargs ):
self.conv1 = nn.Conv2D( self.in_ch,
self.out_ch // (4 if self.subpixel else 1),
kernel_size=self.kernel_size,
strides=1 if self.subpixel else 2,
padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer)
def forward(self, x):
x = self.conv1(x)
if self.subpixel:
-x = tf.nn.space_to_depth(x, 2)
+x = nn.tf_space_to_depth(x, 2)
if self.use_activator:
x = tf.nn.leaky_relu(x, 0.1)
return x
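# Illustration only (not part of the diff): space_to_depth trades resolution
# for channels, which is how strides=1 plus subpixel still halves the image:
#     (N, H, W, C) -> space_to_depth(x, 2) -> (N, H/2, W/2, 4*C)
# The Upscale block below uses depth_to_space, the exact inverse:
#     (N, H, W, 4*C) -> depth_to_space(x, 2) -> (N, 2*H, 2*W, C)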
@@ -143,19 +150,19 @@ class SAEHDModel(ModelBase):
class DownscaleBlock(nn.ModelBase):
def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True):
self.downs = []
last_ch = in_ch
for i in range(n_downscales):
cur_ch = ch*( min(2**i, 8) )
self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) )
last_ch = self.downs[-1].get_out_ch()
def forward(self, inp):
x = inp
for down in self.downs:
x = down(x)
return x
class Upscale(nn.ModelBase):
def on_build(self, in_ch, out_ch, kernel_size=3 ):
self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
@@ -163,7 +170,7 @@ class SAEHDModel(ModelBase):
def forward(self, x):
x = self.conv1(x)
x = tf.nn.leaky_relu(x, 0.1)
-x = tf.nn.depth_to_space(x, 2)
+x = nn.tf_depth_to_space(x, 2)
return x
class ResidualBlock(nn.ModelBase):
@@ -192,9 +199,9 @@ class SAEHDModel(ModelBase):
x = tf.nn.leaky_relu(x, 0.2)
return x, upx
class Encoder(nn.ModelBase):
def on_build(self, in_ch, e_ch, is_hd):
self.is_hd=is_hd
if self.is_hd:
self.down1 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=3, dilations=1)
self.down2 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=5, dilations=1)
@@ -202,7 +209,7 @@ class SAEHDModel(ModelBase):
self.down4 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=7, dilations=2)
else:
self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, dilations=1, subpixel=False)
def forward(self, inp):
if self.is_hd:
x = tf.concat([ nn.tf_flatten(self.down1(inp)),
@@ -211,85 +218,84 @@ class SAEHDModel(ModelBase):
nn.tf_flatten(self.down4(inp)) ], -1 )
else:
x = nn.tf_flatten(self.down1(inp))
return x
class Inter(nn.ModelBase):
def __init__(self, in_ch, lowest_dense_res, ae_ch, ae_out_ch, **kwargs):
self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch = in_ch, lowest_dense_res, ae_ch, ae_out_ch
super().__init__(**kwargs)
def on_build(self):
in_ch, lowest_dense_res, ae_ch, ae_out_ch = self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch
-self.dense1 = nn.Dense( in_ch, ae_ch, kernel_initializer=tf.initializers.orthogonal )
-self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch, kernel_initializer=tf.initializers.orthogonal )
+self.dense1 = nn.Dense( in_ch, ae_ch )
+self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch )
self.upscale1 = Upscale(ae_out_ch, ae_out_ch)
def forward(self, inp):
x = self.dense1(inp)
x = self.dense2(x)
-x = tf.reshape (x, (-1, lowest_dense_res, lowest_dense_res, self.ae_out_ch))
+x = nn.tf_reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch)
x = self.upscale1(x)
return x
def get_out_ch(self):
return self.ae_out_ch
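# Worked shapes for the Inter bottleneck (illustration only): with
# resolution=128 and ae_out_ch=256, lowest_dense_res = 128 // 16 = 8, dense2
# emits 8*8*256 = 16384 values per sample, which are reshaped to an 8x8x256
# feature map and upscaled once to 16x16x256 before entering the decoder.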
class Decoder(nn.ModelBase):
def on_build(self, in_ch, d_ch, d_mask_ch, is_hd ):
self.is_hd = is_hd
self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
if is_hd:
self.res0 = UpdownResidualBlock(in_ch, d_ch*8, kernel_size=3)
self.res1 = UpdownResidualBlock(d_ch*8, d_ch*4, kernel_size=3)
self.res2 = UpdownResidualBlock(d_ch*4, d_ch*2, kernel_size=3)
self.res3 = UpdownResidualBlock(d_ch*2, d_ch, kernel_size=3)
else:
self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
self.res1 = ResidualBlock(d_ch*4, kernel_size=3)
self.res2 = ResidualBlock(d_ch*2, kernel_size=3)
self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer)
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer)
def get_weights_ex(self, include_mask):
# Call internal get_weights in order to initialize inner logic
self.get_weights()
weights = self.upscale0.get_weights() + self.upscale1.get_weights() + self.upscale2.get_weights() \
+ self.res0.get_weights() + self.res1.get_weights() + self.res2.get_weights() + self.out_conv.get_weights()
if include_mask:
weights += self.upscalem0.get_weights() + self.upscalem1.get_weights() + self.upscalem2.get_weights() \
+ self.out_convm.get_weights()
return weights
def forward(self, inp):
z = inp
if self.is_hd:
x, upx = self.res0(z)
x = self.upscale0(x)
x = tf.nn.leaky_relu(x + upx, 0.2)
x, upx = self.res1(x)
x = self.upscale1(x)
x = tf.nn.leaky_relu(x + upx, 0.2)
x, upx = self.res2(x)
x = self.upscale2(x)
x = tf.nn.leaky_relu(x + upx, 0.2)
x, upx = self.res3(x)
else:
x = self.upscale0(z)
x = self.res0(x)
@@ -301,13 +307,13 @@ class SAEHDModel(ModelBase):
m = self.upscalem0(z)
m = self.upscalem1(m)
m = self.upscalem2(m)
return tf.nn.sigmoid(self.out_conv(x)), \
tf.nn.sigmoid(self.out_convm(m))
class CodeDiscriminator(nn.ModelBase):
def on_build(self, in_ch, code_res, ch=256):
-n_downscales = 2 + code_res // 8
+n_downscales = 1 + code_res // 8
self.convs = []
prev_ch = in_ch
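# Illustration only: the discriminator is built with
# code_res = lowest_dense_res * 2 (see the CodeDiscriminator call below), so
# for resolution 128: code_res = 8 * 2 = 16 and the new formula gives
# n_downscales = 1 + 16 // 8 = 3, one downscale fewer than the old 2 + 16 // 8.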
@@ -329,12 +335,12 @@ class SAEHDModel(ModelBase):
resolution = self.options['resolution']
learn_mask = self.options['learn_mask']
archi = self.options['archi']
ae_dims = self.options['ae_dims']
e_dims = self.options['e_dims']
d_dims = self.options['d_dims']
d_mask_dims = self.options['d_mask_dims']
self.pretrain = self.options['pretrain']
masked_training = True
models_opt_on_gpu = False if len(devices) != 1 else self.options['models_opt_on_gpu']
@@ -343,8 +349,8 @@ class SAEHDModel(ModelBase):
input_nc = 3
output_nc = 3
-bgr_shape = (resolution, resolution, output_nc)
-mask_shape = (resolution, resolution, 1)
+bgr_shape = nn.get4Dshape(resolution,resolution,input_nc)
+mask_shape = nn.get4Dshape(resolution,resolution,1)
lowest_dense_res = resolution // 16
self.model_filename_list = []
@@ -352,24 +358,24 @@ class SAEHDModel(ModelBase):
with tf.device ('/CPU:0'):
#Place holders on CPU
-self.warped_src = tf.placeholder (tf.float32, (None,)+bgr_shape)
-self.warped_dst = tf.placeholder (tf.float32, (None,)+bgr_shape)
+self.warped_src = tf.placeholder (nn.tf_floatx, bgr_shape)
+self.warped_dst = tf.placeholder (nn.tf_floatx, bgr_shape)
-self.target_src = tf.placeholder (tf.float32, (None,)+bgr_shape)
-self.target_dst = tf.placeholder (tf.float32, (None,)+bgr_shape)
+self.target_src = tf.placeholder (nn.tf_floatx, bgr_shape)
+self.target_dst = tf.placeholder (nn.tf_floatx, bgr_shape)
-self.target_srcm = tf.placeholder (tf.float32, (None,)+mask_shape)
-self.target_dstm = tf.placeholder (tf.float32, (None,)+mask_shape)
+self.target_srcm = tf.placeholder (nn.tf_floatx, mask_shape)
+self.target_dstm = tf.placeholder (nn.tf_floatx, mask_shape)
# Initializing model classes
with tf.device (models_opt_device):
if 'df' in archi:
-self.encoder = Encoder(in_ch=input_nc, e_ch=e_dims, is_hd='hd' in archi, name='encoder')
-encoder_out_ch = self.encoder.compute_output_shape ( (tf.float32, (None,resolution,resolution,input_nc)))[-1]
+self.encoder = Encoder(in_ch=input_nc, e_ch=e_dims, is_hd='hd' in archi, name='encoder')
+encoder_out_ch = self.encoder.compute_output_channels ( (nn.tf_floatx, bgr_shape))
self.inter = Inter (in_ch=encoder_out_ch, lowest_dense_res=lowest_dense_res, ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter')
-inter_out_ch = self.inter.compute_output_shape ( (tf.float32, (None,encoder_out_ch)))[-1]
+inter_out_ch = self.inter.compute_output_channels ( (nn.tf_floatx, (None,encoder_out_ch)))
self.decoder_src = Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd='hd' in archi, name='decoder_src')
self.decoder_dst = Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd='hd' in archi, name='decoder_dst')
@@ -379,23 +385,22 @@ class SAEHDModel(ModelBase):
[self.decoder_dst, 'decoder_dst.npy'] ]
if self.is_training:
-if self.options['true_face_training']:
+if self.options['true_face_power'] != 0:
self.dis = CodeDiscriminator(ae_dims, code_res=lowest_dense_res*2, name='dis' )
self.model_filename_list += [ [self.dis, 'dis.npy'] ]
elif 'liae' in archi:
self.encoder = Encoder(in_ch=input_nc, e_ch=e_dims, is_hd='hd' in archi, name='encoder')
-encoder_out_ch = self.encoder.compute_output_shape ( (tf.float32, (None,resolution,resolution,input_nc)))[-1]
+encoder_out_ch = self.encoder.compute_output_channels ( (nn.tf_floatx, bgr_shape))
self.inter_AB = Inter(in_ch=encoder_out_ch, lowest_dense_res=lowest_dense_res, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_AB')
self.inter_B = Inter(in_ch=encoder_out_ch, lowest_dense_res=lowest_dense_res, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_B')
-inter_AB_out_ch = self.inter_AB.compute_output_shape ( (tf.float32, (None,encoder_out_ch)))[-1]
-inter_B_out_ch = self.inter_B.compute_output_shape ( (tf.float32, (None,encoder_out_ch)))[-1]
+inter_AB_out_ch = self.inter_AB.compute_output_channels ( (nn.tf_floatx, (None,encoder_out_ch)))
+inter_B_out_ch = self.inter_B.compute_output_channels ( (nn.tf_floatx, (None,encoder_out_ch)))
inters_out_ch = inter_AB_out_ch+inter_B_out_ch
self.decoder = Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd='hd' in archi, name='decoder')
self.model_filename_list += [ [self.encoder, 'encoder.npy'],
[self.inter_AB, 'inter_AB.npy'],
[self.inter_B , 'inter_B.npy'],
@@ -417,8 +422,8 @@ class SAEHDModel(ModelBase):
self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights_ex(learn_mask)
self.src_dst_opt.initialize_variables (self.src_dst_all_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu)
-if self.options['true_face_training']:
+if self.options['true_face_power'] != 0:
self.D_opt = nn.TFRMSpropOptimizer(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_opt')
self.D_opt.initialize_variables ( self.dis.get_weights(), vars_on_cpu=optimizer_vars_on_cpu)
self.model_filename_list += [ (self.D_opt, 'D_opt.npy') ]
@@ -429,7 +434,7 @@ class SAEHDModel(ModelBase):
bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
self.set_batch_size( gpu_count*bs_per_gpu)
# Compute losses per GPU
gpu_pred_src_src_list = []
gpu_pred_dst_dst_list = []
@@ -462,29 +467,29 @@ class SAEHDModel(ModelBase):
gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code)
gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
elif 'liae' in archi:
gpu_src_code = self.encoder (gpu_warped_src)
gpu_src_inter_AB_code = self.inter_AB (gpu_src_code)
-gpu_src_code = tf.concat([gpu_src_inter_AB_code,gpu_src_inter_AB_code],-1)
+gpu_src_code = tf.concat([gpu_src_inter_AB_code,gpu_src_inter_AB_code], nn.conv2d_ch_axis )
gpu_dst_code = self.encoder (gpu_warped_dst)
gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
-gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code],-1)
-gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code],-1)
+gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis )
+gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis )
gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code)
gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
gpu_pred_src_src_list.append(gpu_pred_src_src)
gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
gpu_pred_src_dst_list.append(gpu_pred_src_dst)
gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)
gpu_target_srcm_blur = nn.tf_gaussian_blur(gpu_target_srcm, max(1, resolution // 32) )
gpu_target_dstm_blur = nn.tf_gaussian_blur(gpu_target_dstm, max(1, resolution // 32) )
@@ -503,28 +508,28 @@ class SAEHDModel(ModelBase):
gpu_src_loss = tf.reduce_mean ( 10*nn.tf_dssim(gpu_target_srcmasked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_srcmasked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3])
if learn_mask:
-gpu_src_loss += tf.reduce_mean ( tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] )
+gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] )
face_style_power = self.options['face_style_power'] / 100.0
if face_style_power != 0 and not self.pretrain:
gpu_src_loss += nn.tf_style_loss(gpu_psd_target_dst_masked, gpu_target_dst_masked, gaussian_blur_radius=resolution//16, loss_weight=10000*face_style_power)
bg_style_power = self.options['bg_style_power'] / 100.0
if bg_style_power != 0 and not self.pretrain:
gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*nn.tf_dssim(gpu_psd_target_dst_anti_masked, gpu_target_dst_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*tf.square( gpu_psd_target_dst_anti_masked - gpu_target_dst_anti_masked), axis=[1,2,3] )
gpu_dst_loss = tf.reduce_mean ( 10*nn.tf_dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3])
if learn_mask:
-gpu_dst_loss += tf.reduce_mean ( tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )
+gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )
gpu_src_losses += [gpu_src_loss]
gpu_dst_losses += [gpu_dst_loss]
gpu_src_dst_loss = gpu_src_loss + gpu_dst_loss
-if self.options['true_face_training']:
+if self.options['true_face_power'] != 0:
def DLoss(labels,logits):
return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits), axis=[1,2,3])
@@ -533,8 +538,8 @@ class SAEHDModel(ModelBase):
gpu_src_code_d_zeros = tf.zeros_like(gpu_src_code_d)
gpu_dst_code_d = self.dis( gpu_dst_code )
gpu_dst_code_d_ones = tf.ones_like(gpu_dst_code_d)
-gpu_src_dst_loss += 0.01*DLoss(gpu_src_code_d_ones, gpu_src_code_d)
+gpu_src_dst_loss += self.options['true_face_power']*DLoss(gpu_src_code_d_ones, gpu_src_code_d)
gpu_D_loss = (DLoss(gpu_src_code_d_ones , gpu_dst_code_d) + \
DLoss(gpu_src_code_d_zeros, gpu_src_code_d) ) * 0.5
@@ -546,35 +551,20 @@ class SAEHDModel(ModelBase):
# Average losses and gradients, and create optimizer update ops
with tf.device (models_opt_device):
-if gpu_count == 1:
-pred_src_src = gpu_pred_src_src_list[0]
-pred_dst_dst = gpu_pred_dst_dst_list[0]
-pred_src_dst = gpu_pred_src_dst_list[0]
-pred_src_srcm = gpu_pred_src_srcm_list[0]
-pred_dst_dstm = gpu_pred_dst_dstm_list[0]
-pred_src_dstm = gpu_pred_src_dstm_list[0]
-src_loss = gpu_src_losses[0]
-dst_loss = gpu_dst_losses[0]
-src_dst_loss_gv = gpu_src_dst_loss_gvs[0]
-else:
-pred_src_src = tf.concat(gpu_pred_src_src_list, 0)
-pred_dst_dst = tf.concat(gpu_pred_dst_dst_list, 0)
-pred_src_dst = tf.concat(gpu_pred_src_dst_list, 0)
-pred_src_srcm = tf.concat(gpu_pred_src_srcm_list, 0)
-pred_dst_dstm = tf.concat(gpu_pred_dst_dstm_list, 0)
-pred_src_dstm = tf.concat(gpu_pred_src_dstm_list, 0)
-src_loss = nn.tf_average_tensor_list(gpu_src_losses)
-dst_loss = nn.tf_average_tensor_list(gpu_dst_losses)
-src_dst_loss_gv = nn.tf_average_gv_list (gpu_src_dst_loss_gvs)
+pred_src_src = nn.tf_concat(gpu_pred_src_src_list, 0)
+pred_dst_dst = nn.tf_concat(gpu_pred_dst_dst_list, 0)
+pred_src_dst = nn.tf_concat(gpu_pred_src_dst_list, 0)
+pred_src_srcm = nn.tf_concat(gpu_pred_src_srcm_list, 0)
+pred_dst_dstm = nn.tf_concat(gpu_pred_dst_dstm_list, 0)
+pred_src_dstm = nn.tf_concat(gpu_pred_src_dstm_list, 0)
-if self.options['true_face_training']:
-D_loss_gv = nn.tf_average_gv_list(gpu_D_loss_gvs)
+src_loss = nn.tf_average_tensor_list(gpu_src_losses)
+dst_loss = nn.tf_average_tensor_list(gpu_dst_losses)
+src_dst_loss_gv = nn.tf_average_gv_list (gpu_src_dst_loss_gvs)
src_dst_loss_gv_op = self.src_dst_opt.get_update_op (src_dst_loss_gv )
-if self.options['true_face_training']:
+if self.options['true_face_power'] != 0:
+D_loss_gv = nn.tf_average_gv_list(gpu_D_loss_gvs)
D_loss_gv_op = self.D_opt.get_update_op (D_loss_gv )
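# Illustration only (not part of the diff): averaging per-GPU gradients before
# one optimizer step is, in spirit, what nn.tf_average_gv_list does over its
# (gradient, variable) pairs:
#     import numpy as np
#     gpu_grads = [np.array([0.2, 0.4]), np.array([0.4, 0.8])]  # one entry per GPU
#     avg_grad = np.mean(gpu_grads, axis=0)                     # [0.3, 0.6]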
@@ -594,7 +584,7 @@ class SAEHDModel(ModelBase):
return s, d
self.src_dst_train = src_dst_train
-if self.options['true_face_training']:
+if self.options['true_face_power'] != 0:
def D_train(warped_src, warped_dst):
nn.tf_sess.run ([D_loss_gv_op], feed_dict={self.warped_src: warped_src, self.warped_dst: warped_dst})
self.D_train = D_train
@@ -611,23 +601,23 @@ class SAEHDModel(ModelBase):
self.warped_dst:warped_dst})
self.AE_view = AE_view
else:
# Initializing merge function
with tf.device( f'/GPU:0' if len(devices) != 0 else f'/CPU:0'):
if 'df' in archi:
gpu_dst_code = self.inter(self.encoder(self.warped_dst))
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
_, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
elif 'liae' in archi:
gpu_dst_code = self.encoder (self.warped_dst)
gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
-gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code],-1)
-gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code],-1)
+gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
+gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
_, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
if learn_mask:
def AE_merge( warped_dst):
return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst})
@@ -640,7 +630,7 @@ class SAEHDModel(ModelBase):
# Loading/initializing all models/optimizers weights
for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
do_init = self.is_first_run()
if self.pretrain_just_disabled:
if 'df' in archi:
if model == self.inter:
@@ -648,15 +638,15 @@ class SAEHDModel(ModelBase):
elif 'liae' in archi:
if model == self.inter_AB:
do_init = True
if not do_init:
do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) )
if do_init:
model.init_weights()
# initializing sample generators
if self.is_training:
t = SampleProcessor.Types
if self.options['face_type'] == 'h':
@@ -670,29 +660,29 @@ class SAEHDModel(ModelBase):
training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path()
random_ct_samples_path=training_data_dst_path if self.options['ct_mode'] != 'none' and not self.pretrain else None
t_img_warped = t.IMG_WARPED_TRANSFORMED if self.options['random_warp'] else t.IMG_TRANSFORMED
cpu_count = multiprocessing.cpu_count()
src_generators_count = cpu_count // 2
if self.options['ct_mode'] != 'none':
src_generators_count = int(src_generators_count * 1.5)
dst_generators_count = cpu_count - src_generators_count
self.set_training_data_generators ([
SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
-output_sample_types = [ {'types' : (t_img_warped, face_type, t.MODE_BGR), 'resolution':resolution, 'ct_mode': self.options['ct_mode'] },
-{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'resolution': resolution, 'ct_mode': self.options['ct_mode'] },
-{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'resolution': resolution } ],
+output_sample_types = [ {'types' : (t_img_warped, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution': resolution, 'ct_mode': self.options['ct_mode'] },
+{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution': resolution, 'ct_mode': self.options['ct_mode'] },
+{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'data_format':nn.data_format, 'resolution': resolution } ],
generators_count=src_generators_count ),
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
-output_sample_types = [ {'types' : (t_img_warped, face_type, t.MODE_BGR), 'resolution':resolution},
-{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'resolution': resolution},
-{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'resolution': resolution} ],
+output_sample_types = [ {'types' : (t_img_warped, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution': resolution},
+{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution': resolution},
+{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'data_format':nn.data_format, 'resolution': resolution} ],
generators_count=dst_generators_count )
])
@@ -710,10 +700,10 @@ class SAEHDModel(ModelBase):
def onTrainOneIter(self):
( (warped_src, target_src, target_srcm), \
(warped_dst, target_dst, target_dstm) ) = self.generate_next_samples()
src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm)
-if self.options['true_face_training'] and not self.pretrain:
+if self.options['true_face_power'] != 0 and not self.pretrain:
self.D_train (warped_src, warped_dst)
return ( ('src_loss', src_loss), ('dst_loss', dst_loss), )
@@ -728,10 +718,12 @@ class SAEHDModel(ModelBase):
for sample_list in samples ]
if self.options['learn_mask']:
-S, D, SS, DD, DDM, SD, SDM = [ np.clip(x, 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ]
+S, D, SS, DD, DDM, SD, SDM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ]
DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [DDM, SDM] ]
else:
-S, D, SS, DD, SD, = [ np.clip(x, 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ]
+S, D, SS, DD, SD, = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format) , 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ]
+target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )]
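# Illustration only (not part of the diff): predicted masks have one channel,
# and the np.repeat above expands them to three so they can be tiled next to
# the BGR previews:
#     import numpy as np
#     ddm = np.zeros((1, 128, 128, 1))
#     np.repeat(ddm, (3,), -1).shape   # (1, 128, 128, 3)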
result = []
st = []
@@ -753,12 +745,16 @@ class SAEHDModel(ModelBase):
return result
def predictor_func (self, face=None):
+face = face[None,...]
+face = nn.to_data_format(face, self.model_data_format, "NHWC")
if self.options['learn_mask']:
-bgr, mask_dst_dstm, mask_src_dstm = self.AE_merge (face[np.newaxis,...])
+bgr, mask_dst_dstm, mask_src_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format).astype(np.float32) for x in self.AE_merge (face) ]
mask = mask_dst_dstm[0] * mask_src_dstm[0]
return bgr[0], mask[...,0]
else:
-bgr, = self.AE_merge (face[np.newaxis,...])
+bgr, = [ nn.to_data_format(x,"NHWC", self.model_data_format).astype(np.float32) for x in self.AE_merge (face) ]
return bgr[0]
#override