diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py index 14fba1a..73569e6 100644 --- a/models/Model_SAE/Model.py +++ b/models/Model_SAE/Model.py @@ -305,29 +305,6 @@ class SAEModel(ModelBase): dssim_pixel_alpha = np.expand_dims(dssim_pixel_alpha,-1) src_loss, dst_loss, src_sample_losses, dst_sample_losses = self.src_dst_train ([dssim_pixel_alpha, warped_src, target_src, target_src_mask, warped_dst, target_dst, target_dst_mask]) - - # 'worst' sample booster gives no good result, or I dont know how to filter worst samples properly. - # - ##gathering array of sample_losses - #self.src_sample_losses += [[src_sample_idxs[i], src_sample_losses[i]] for i in range(self.batch_size) ] - #self.dst_sample_losses += [[dst_sample_idxs[i], dst_sample_losses[i]] for i in range(self.batch_size) ] - # - #if len(self.src_sample_losses) >= 48: #array is big enough - # #fetching idxs which losses are bigger than average - # x = np.array (self.src_sample_losses) - # self.src_sample_losses = [] - # b = x[:,1] - # idxs = (x[:,0][ np.argwhere ( b [ b > (np.mean(b)+np.std(b)) ] )[:,0] ]).astype(np.uint) - # generators_list[0].repeat_sample_idxs(idxs) #ask generator to repeat these sample idxs - # - # - #if len(self.dst_sample_losses) >= 48: #array is big enough - # #fetching idxs which losses are bigger than average - # x = np.array (self.dst_sample_losses) - # self.dst_sample_losses = [] - # b = x[:,1] - # idxs = (x[:,0][ np.argwhere ( b [ b > (np.mean(b)+np.std(b)) ] )[:,0] ]).astype(np.uint) - # generators_list[1].repeat_sample_idxs(idxs) #ask generator to repeat these sample idxs if self.options['learn_mask']: src_mask_loss, dst_mask_loss, = self.src_dst_mask_train ([warped_src, target_src_mask, warped_dst, target_dst_mask]) @@ -574,4 +551,30 @@ class SAEModel(ModelBase): return func -Model = SAEModel \ No newline at end of file +Model = SAEModel + + + +# 'worst' sample booster gives no good result, or I dont know how to filter worst samples properly. +# +##gathering array of sample_losses +#self.src_sample_losses += [[src_sample_idxs[i], src_sample_losses[i]] for i in range(self.batch_size) ] +#self.dst_sample_losses += [[dst_sample_idxs[i], dst_sample_losses[i]] for i in range(self.batch_size) ] +# +#if len(self.src_sample_losses) >= 128: #array is big enough +# #fetching idxs which losses are bigger than average +# x = np.array (self.src_sample_losses) +# self.src_sample_losses = [] +# b = x[:,1] +# idxs = (x[:,0][ np.argwhere ( b [ b > (np.mean(b)+np.std(b)) ] )[:,0] ]).astype(np.uint) +# generators_list[0].repeat_sample_idxs(idxs) #ask generator to repeat these sample idxs +# print ("src repeated %d" % (len(idxs)) ) +# +#if len(self.dst_sample_losses) >= 128: #array is big enough +# #fetching idxs which losses are bigger than average +# x = np.array (self.dst_sample_losses) +# self.dst_sample_losses = [] +# b = x[:,1] +# idxs = (x[:,0][ np.argwhere ( b [ b > (np.mean(b)+np.std(b)) ] )[:,0] ]).astype(np.uint) +# generators_list[1].repeat_sample_idxs(idxs) #ask generator to repeat these sample idxs +# print ("dst repeated %d" % (len(idxs)) ) \ No newline at end of file diff --git a/samples/SampleGeneratorFace.py b/samples/SampleGeneratorFace.py index 13358fa..43181bf 100644 --- a/samples/SampleGeneratorFace.py +++ b/samples/SampleGeneratorFace.py @@ -55,6 +55,7 @@ class SampleGeneratorFace(SampleGeneratorBase): return next(generator) #forces to repeat these sample idxs as fast as possible + #currently unused def repeat_sample_idxs(self, idxs): # [ idx, ... ] #send idxs list to all sub generators. for gen_sq in self.generators_sq: