From b8182ae42bb0d9db1e20c2e976063f460f7db465 Mon Sep 17 00:00:00 2001 From: Colombo Date: Wed, 8 Jan 2020 11:11:33 +0400 Subject: [PATCH] remove lr_dropout for plaidml backend --- models/Model_Quick96/Model.py | 5 +++-- models/Model_SAEHD/Model.py | 9 ++++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/models/Model_Quick96/Model.py b/models/Model_Quick96/Model.py index b9a94c3..6b66041 100644 --- a/models/Model_Quick96/Model.py +++ b/models/Model_Quick96/Model.py @@ -141,8 +141,9 @@ class Quick96Model(ModelBase): self.CA_conv_weights_list += [layer.weights[0]] #- is Conv2D kernel_weights if self.is_training_mode: - self.src_dst_opt = RMSprop(lr=2e-4, lr_dropout=0.3) - self.src_dst_mask_opt = RMSprop(lr=2e-4, lr_dropout=0.3) + lr_dropout = 0.3 if nnlib.device.backend != 'plaidML' else 0.0 + self.src_dst_opt = RMSprop(lr=2e-4, lr_dropout=lr_dropout) + self.src_dst_mask_opt = RMSprop(lr=2e-4, lr_dropout=lr_dropout) target_src_masked = self.model.target_src*self.model.target_srcm target_dst_masked = self.model.target_dst*self.model.target_dstm diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index cc9217c..180b2a1 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -65,9 +65,12 @@ class SAEHDModel(ModelBase): default_bg_style_power = self.options.get('bg_style_power', 0.0) if is_first_run or ask_override: - default_lr_dropout = self.options.get('lr_dropout', False) - self.options['lr_dropout'] = io.input_bool ( f"Use learning rate dropout? (y/n, ?:help skip:{yn_str[default_lr_dropout]} ) : ", default_lr_dropout, help_message="When the face is trained enough, you can enable this option to get extra sharpness for less amount of iterations.") - + if nnlib.device.backend != 'plaidML': + default_lr_dropout = self.options.get('lr_dropout', False) + self.options['lr_dropout'] = io.input_bool ( f"Use learning rate dropout? (y/n, ?:help skip:{yn_str[default_lr_dropout]} ) : ", default_lr_dropout, help_message="When the face is trained enough, you can enable this option to get extra sharpness for less amount of iterations.") + else: + self.options['lr_dropout'] = False + default_random_warp = self.options.get('random_warp', True) self.options['random_warp'] = io.input_bool (f"Enable random warp of samples? ( y/n, ?:help skip:{yn_str[default_random_warp]}) : ", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness for less amount of iterations.")