SAE: removed simple_optimizer . Added optimizer mode for tensorflow only (NVIDIA cards), allows to train x2-x3 bigger networks with normal Adam optimizer, consuming VRAM and CPU power.

This commit is contained in:
iperov 2019-03-13 11:54:17 +04:00
parent 7d6ca32250
commit 58763756f5
3 changed files with 100 additions and 37 deletions

View file

@ -42,11 +42,12 @@ class SAEModel(ModelBase):
self.options['face_type'] = self.options.get('face_type', default_face_type)
self.options['learn_mask'] = self.options.get('learn_mask', True)
if is_first_run or ask_override:
def_simple_optimizer = self.options.get('simple_optimizer', False)
self.options['simple_optimizer'] = io.input_bool ("Use simple optimizer? (y/n, ?:help skip:%s) : " % ( yn_str[def_simple_optimizer] ), def_simple_optimizer, help_message="Simple optimizer allows you to train bigger network or more batch size, sacrificing training accuracy.")
if (is_first_run or ask_override) and 'tensorflow' in self.device_config.backend:
def_optimizer_mode = self.options.get('optimizer_mode', 1)
self.options['optimizer_mode'] = io.input_int ("Optimizer mode? ( 1,2,3 ?:help skip:%d) : " % (def_optimizer_mode), def_optimizer_mode, help_message="1 - no changes. 2 - allows you to train x2 bigger network consuming RAM. 3 - allows you to train x3 bigger network consuming huge amount of RAM and slower, depends on CPU power.")
else:
self.options['simple_optimizer'] = self.options.get('simple_optimizer', False)
self.options['optimizer_mode'] = self.options.get('optimizer_mode', 1)
if is_first_run:
self.options['archi'] = io.input_str ("AE architecture (df, liae, vg ?:help skip:%s) : " % (default_archi) , default_archi, ['df','liae','vg'], help_message="'df' keeps faces more natural. 'liae' can fix overly different face shapes. 'vg' - currently testing.").lower()
@ -269,14 +270,10 @@ class SAEModel(ModelBase):
psd_target_dst_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
psd_target_dst_anti_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
if self.is_training_mode:
if self.options['simple_optimizer']:
self.src_dst_opt = DFLOptimizer(lr=5e-5)
self.src_dst_mask_opt = DFLOptimizer(lr=5e-5)
else:
self.src_dst_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)
self.src_dst_mask_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)
if self.is_training_mode:
self.src_dst_opt = AdamCPU(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)
self.src_dst_mask_opt = AdamCPU(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)
if self.options['archi'] == 'liae':
src_dst_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
if self.options['learn_mask']: