H64, H128, DF, LIAEF128: added pixel loss option.

This commit is contained in:
iperov 2019-02-11 12:05:54 +04:00
parent af3dd59f67
commit f8e63970d2
5 changed files with 52 additions and 34 deletions

View file

@ -21,7 +21,12 @@ class Model(ModelBase):
if 'created_vram_gb' in self.options.keys():
self.options.pop ('created_vram_gb')
self.options['lighter_ae'] = self.options.get('lighter_ae', default_lighter_ae)
if is_first_run or ask_override:
self.options['pixel_loss'] = self.options['pixel_loss'] = input_bool ("Use pixel loss? (y/n, ?:help skip: n/default ) : ", False, help_message="Default DSSIM loss good for initial understanding structure of faces. Use pixel loss after 30-40k epochs to enhance fine details and remove face jitter.")
else:
self.options['pixel_loss'] = self.options.get('pixel_loss', False)
#override
def onInitialize(self, **in_options):
exec(nnlib.import_all(), locals(), globals())
@ -44,9 +49,8 @@ class Model(ModelBase):
rec_dst_bgr, rec_dst_mask = self.decoder_dst( self.encoder(input_dst_bgr) )
self.ae = Model([input_src_bgr,input_src_mask,input_dst_bgr,input_dst_mask], [rec_src_bgr, rec_src_mask, rec_dst_bgr, rec_dst_mask] )
self.ae.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999),
loss=[ DSSIMMaskLoss([input_src_mask]), 'mae', DSSIMMaskLoss([input_dst_mask]), 'mae' ] )
self.ae.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), loss=[ DSSIMMSEMaskLoss(input_src_mask, is_mse=self.options['pixel_loss']), 'mae', DSSIMMSEMaskLoss(input_dst_mask, is_mse=self.options['pixel_loss']), 'mae' ] )
self.src_view = K.function([input_src_bgr],[rec_src_bgr, rec_src_mask])
self.dst_view = K.function([input_dst_bgr],[rec_dst_bgr, rec_dst_mask])