SAEHD,Quick96:

improved model generalization, overall accuracy and sharpness
by using new 'Learning rate dropout' technique from paper https://arxiv.org/abs/1912.00144
An example of a loss histogram where this function is enabled after the red arrow:
https://i.imgur.com/3olskOd.jpg
This commit is contained in:
Colombo 2019-12-15 15:53:06 +04:00
parent c866448645
commit 71ebf06c89
2 changed files with 5 additions and 5 deletions

View file

@ -143,8 +143,8 @@ class Quick96Model(ModelBase):
self.CA_conv_weights_list += [layer.weights[0]] #- is Conv2D kernel_weights
if self.is_training_mode:
self.src_dst_opt = RMSprop(lr=2e-4)
self.src_dst_mask_opt = RMSprop(lr=2e-4)
self.src_dst_opt = RMSprop(lr=2e-4, lr_dropout=0.3)
self.src_dst_mask_opt = RMSprop(lr=2e-4, lr_dropout=0.3)
target_src_masked = self.model.target_src*self.model.target_srcm
target_dst_masked = self.model.target_dst*self.model.target_dstm