This commit is contained in:
iperov 2019-02-12 10:55:41 +04:00
parent 429e7e6aee
commit 535041f7bb
2 changed files with 28 additions and 24 deletions

View file

@ -305,29 +305,6 @@ class SAEModel(ModelBase):
dssim_pixel_alpha = np.expand_dims(dssim_pixel_alpha,-1)
src_loss, dst_loss, src_sample_losses, dst_sample_losses = self.src_dst_train ([dssim_pixel_alpha, warped_src, target_src, target_src_mask, warped_dst, target_dst, target_dst_mask])
# 'worst' sample booster gives no good result, or I dont know how to filter worst samples properly.
#
##gathering array of sample_losses
#self.src_sample_losses += [[src_sample_idxs[i], src_sample_losses[i]] for i in range(self.batch_size) ]
#self.dst_sample_losses += [[dst_sample_idxs[i], dst_sample_losses[i]] for i in range(self.batch_size) ]
#
#if len(self.src_sample_losses) >= 48: #array is big enough
# #fetching idxs which losses are bigger than average
# x = np.array (self.src_sample_losses)
# self.src_sample_losses = []
# b = x[:,1]
# idxs = (x[:,0][ np.argwhere ( b [ b > (np.mean(b)+np.std(b)) ] )[:,0] ]).astype(np.uint)
# generators_list[0].repeat_sample_idxs(idxs) #ask generator to repeat these sample idxs
#
#
#if len(self.dst_sample_losses) >= 48: #array is big enough
# #fetching idxs which losses are bigger than average
# x = np.array (self.dst_sample_losses)
# self.dst_sample_losses = []
# b = x[:,1]
# idxs = (x[:,0][ np.argwhere ( b [ b > (np.mean(b)+np.std(b)) ] )[:,0] ]).astype(np.uint)
# generators_list[1].repeat_sample_idxs(idxs) #ask generator to repeat these sample idxs
if self.options['learn_mask']:
src_mask_loss, dst_mask_loss, = self.src_dst_mask_train ([warped_src, target_src_mask, warped_dst, target_dst_mask])
@ -574,4 +551,30 @@ class SAEModel(ModelBase):
return func
Model = SAEModel
Model = SAEModel
# 'worst' sample booster gives no good result, or I dont know how to filter worst samples properly.
#
##gathering array of sample_losses
#self.src_sample_losses += [[src_sample_idxs[i], src_sample_losses[i]] for i in range(self.batch_size) ]
#self.dst_sample_losses += [[dst_sample_idxs[i], dst_sample_losses[i]] for i in range(self.batch_size) ]
#
#if len(self.src_sample_losses) >= 128: #array is big enough
# #fetching idxs which losses are bigger than average
# x = np.array (self.src_sample_losses)
# self.src_sample_losses = []
# b = x[:,1]
# idxs = (x[:,0][ np.argwhere ( b [ b > (np.mean(b)+np.std(b)) ] )[:,0] ]).astype(np.uint)
# generators_list[0].repeat_sample_idxs(idxs) #ask generator to repeat these sample idxs
# print ("src repeated %d" % (len(idxs)) )
#
#if len(self.dst_sample_losses) >= 128: #array is big enough
# #fetching idxs which losses are bigger than average
# x = np.array (self.dst_sample_losses)
# self.dst_sample_losses = []
# b = x[:,1]
# idxs = (x[:,0][ np.argwhere ( b [ b > (np.mean(b)+np.std(b)) ] )[:,0] ]).astype(np.uint)
# generators_list[1].repeat_sample_idxs(idxs) #ask generator to repeat these sample idxs
# print ("dst repeated %d" % (len(idxs)) )