SAE: df-s liae-s archis

This commit is contained in:
iperov 2019-04-23 17:03:30 +04:00
parent 8969e80275
commit 47f9bad42b

View file

@ -51,11 +51,11 @@ class SAEModel(ModelBase):
self.options['optimizer_mode'] = self.options.get('optimizer_mode', 1) self.options['optimizer_mode'] = self.options.get('optimizer_mode', 1)
if is_first_run: if is_first_run:
self.options['archi'] = io.input_str ("AE architecture (df, liae ?:help skip:%s) : " % (default_archi) , default_archi, ['df','liae'], help_message="'df' keeps faces more natural. 'liae' can fix overly different face shapes.").lower() self.options['archi'] = io.input_str ("AE architecture (df, liae, df-s, liae-s ?:help skip:%s) : " % (default_archi) , default_archi, ['df','df-s','liae','liae-s'], help_message="'df' keeps faces more natural. 'liae' can fix overly different face shapes. -s version is slower, but has decreased change to collapse.").lower()
else: else:
self.options['archi'] = self.options.get('archi', default_archi) self.options['archi'] = self.options.get('archi', default_archi)
default_ae_dims = 256 if self.options['archi'] == 'liae' else 512 default_ae_dims = 256 if 'liae' in self.options['archi'] else 512
default_e_ch_dims = 42 default_e_ch_dims = 42
default_d_ch_dims = default_e_ch_dims // 2 default_d_ch_dims = default_e_ch_dims // 2
@ -125,13 +125,18 @@ class SAEModel(ModelBase):
target_dst_ar = [ Input ( ( bgr_shape[0] // (2**i) ,)*2 + (bgr_shape[-1],) ) for i in range(ms_count-1, -1, -1)] target_dst_ar = [ Input ( ( bgr_shape[0] // (2**i) ,)*2 + (bgr_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
target_dstm_ar = [ Input ( ( mask_shape[0] // (2**i) ,)*2 + (mask_shape[-1],) ) for i in range(ms_count-1, -1, -1)] target_dstm_ar = [ Input ( ( mask_shape[0] // (2**i) ,)*2 + (mask_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
padding = 'reflect' if self.options['remove_gray_border'] else 'zero' padding = 'zero'
norm = ''
if '-s' in self.options['archi']:
norm = 'bn'
common_flow_kwargs = { 'padding': padding, common_flow_kwargs = { 'padding': padding,
'norm': 'bn', 'norm': norm,
'act':'' } 'act':'' }
weights_to_load = [] weights_to_load = []
if self.options['archi'] == 'liae': if 'liae' in self.options['archi']:
self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape)) self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape))
enc_output_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ] enc_output_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
@ -175,7 +180,7 @@ class SAEModel(ModelBase):
pred_dst_dstm = self.decoderm(warped_dst_inter_code) pred_dst_dstm = self.decoderm(warped_dst_inter_code)
pred_src_dstm = self.decoderm(warped_src_dst_inter_code) pred_src_dstm = self.decoderm(warped_src_dst_inter_code)
elif self.options['archi'] == 'df': elif 'df' in self.options['archi']:
self.encoder = modelify(SAEModel.DFEncFlow(resolution, ae_dims=ae_dims, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape)) self.encoder = modelify(SAEModel.DFEncFlow(resolution, ae_dims=ae_dims, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape))
dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ] dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
@ -248,7 +253,7 @@ class SAEModel(ModelBase):
self.src_dst_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1) self.src_dst_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)
self.src_dst_mask_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1) self.src_dst_mask_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)
if self.options['archi'] == 'liae': if 'liae' in self.options['archi']:
src_dst_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights src_dst_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
if self.options['learn_mask']: if self.options['learn_mask']:
src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
@ -344,7 +349,7 @@ class SAEModel(ModelBase):
[self.src_dst_mask_opt, 'src_dst_mask_opt'] [self.src_dst_mask_opt, 'src_dst_mask_opt']
] ]
ar = [] ar = []
if self.options['archi'] == 'liae': if 'liae' in self.options['archi']:
ar += [[self.encoder, 'encoder.h5'], ar += [[self.encoder, 'encoder.h5'],
[self.inter_B, 'inter_B.h5'], [self.inter_B, 'inter_B.h5'],
[self.inter_AB, 'inter_AB.h5'], [self.inter_AB, 'inter_AB.h5'],
@ -352,7 +357,7 @@ class SAEModel(ModelBase):
] ]
if self.options['learn_mask']: if self.options['learn_mask']:
ar += [ [self.decoderm, 'decoderm.h5'] ] ar += [ [self.decoderm, 'decoderm.h5'] ]
elif self.options['archi'] == 'df': elif 'df' in self.options['archi']:
ar += [[self.encoder, 'encoder.h5'], ar += [[self.encoder, 'encoder.h5'],
[self.decoder_src, 'decoder_src.h5'], [self.decoder_src, 'decoder_src.h5'],
[self.decoder_dst, 'decoder_dst.h5'] [self.decoder_dst, 'decoder_dst.h5']