remove lr_dropout for plaidml backend

This commit is contained in:
Colombo 2020-01-08 11:11:33 +04:00
parent d3e6b435aa
commit b8182ae42b
2 changed files with 9 additions and 5 deletions

View file

@ -141,8 +141,9 @@ class Quick96Model(ModelBase):
self.CA_conv_weights_list += [layer.weights[0]] #- is Conv2D kernel_weights self.CA_conv_weights_list += [layer.weights[0]] #- is Conv2D kernel_weights
if self.is_training_mode: if self.is_training_mode:
self.src_dst_opt = RMSprop(lr=2e-4, lr_dropout=0.3) lr_dropout = 0.3 if nnlib.device.backend != 'plaidML' else 0.0
self.src_dst_mask_opt = RMSprop(lr=2e-4, lr_dropout=0.3) self.src_dst_opt = RMSprop(lr=2e-4, lr_dropout=lr_dropout)
self.src_dst_mask_opt = RMSprop(lr=2e-4, lr_dropout=lr_dropout)
target_src_masked = self.model.target_src*self.model.target_srcm target_src_masked = self.model.target_src*self.model.target_srcm
target_dst_masked = self.model.target_dst*self.model.target_dstm target_dst_masked = self.model.target_dst*self.model.target_dstm

View file

@ -65,9 +65,12 @@ class SAEHDModel(ModelBase):
default_bg_style_power = self.options.get('bg_style_power', 0.0) default_bg_style_power = self.options.get('bg_style_power', 0.0)
if is_first_run or ask_override: if is_first_run or ask_override:
default_lr_dropout = self.options.get('lr_dropout', False) if nnlib.device.backend != 'plaidML':
self.options['lr_dropout'] = io.input_bool ( f"Use learning rate dropout? (y/n, ?:help skip:{yn_str[default_lr_dropout]} ) : ", default_lr_dropout, help_message="When the face is trained enough, you can enable this option to get extra sharpness for less amount of iterations.") default_lr_dropout = self.options.get('lr_dropout', False)
self.options['lr_dropout'] = io.input_bool ( f"Use learning rate dropout? (y/n, ?:help skip:{yn_str[default_lr_dropout]} ) : ", default_lr_dropout, help_message="When the face is trained enough, you can enable this option to get extra sharpness for less amount of iterations.")
else:
self.options['lr_dropout'] = False
default_random_warp = self.options.get('random_warp', True) default_random_warp = self.options.get('random_warp', True)
self.options['random_warp'] = io.input_bool (f"Enable random warp of samples? ( y/n, ?:help skip:{yn_str[default_random_warp]}) : ", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness for less amount of iterations.") self.options['random_warp'] = io.input_bool (f"Enable random warp of samples? ( y/n, ?:help skip:{yn_str[default_random_warp]}) : ", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness for less amount of iterations.")