diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py
index efa5a0f..aa854de 100644
--- a/models/Model_SAE/Model.py
+++ b/models/Model_SAE/Model.py
@@ -302,32 +302,34 @@ class SAEModel(ModelBase):
         warped_src, target_src, target_src_mask, src_sample_idxs = generators_samples[0]
         warped_dst, target_dst, target_dst_mask, dst_sample_idxs = generators_samples[1]
 
-        dssim_pixel_alpha = np.clip ( self.epoch / 15000.0, 0.0, 1.0 ) #smooth transition between DSSIM and MSE in 15k epochs
+        dssim_pixel_alpha = np.clip ( (self.epoch - 5000) / 15000.0, 0.0, 1.0 ) #smooth transition between DSSIM and MSE in 5-20k epochs
         dssim_pixel_alpha = np.repeat( dssim_pixel_alpha, (self.batch_size,) )
         dssim_pixel_alpha = np.expand_dims(dssim_pixel_alpha,-1)
 
         src_loss, dst_loss, src_sample_losses, dst_sample_losses = self.src_dst_train ([dssim_pixel_alpha, warped_src, target_src, target_src_mask, warped_dst, target_dst, target_dst_mask])
 
-        #gathering array of sample_losses
-        self.src_sample_losses += [[src_sample_idxs[i], src_sample_losses[i]] for i in range(self.batch_size) ]
-        self.dst_sample_losses += [[dst_sample_idxs[i], dst_sample_losses[i]] for i in range(self.batch_size) ]
-
-        if len(self.src_sample_losses) >= 48: #array is big enough
-            #fetching idxs which losses are bigger than average
-            x = np.array (self.src_sample_losses)
-            self.src_sample_losses = []
-            b = x[:,1]
-            idxs = (x[:,0][ np.argwhere ( b [ b > np.mean(b) ] )[:,0] ]).astype(np.uint)
-            generators_list[0].repeat_sample_idxs(idxs) #ask generator to repeat these sample idxs
-
-
-        if len(self.dst_sample_losses) >= 48: #array is big enough
-            #fetching idxs which losses are bigger than average
-            x = np.array (self.dst_sample_losses)
-            self.dst_sample_losses = []
-            b = x[:,1]
-            idxs = (x[:,0][ np.argwhere ( b [ b > np.mean(b) ] )[:,0] ]).astype(np.uint)
-            generators_list[1].repeat_sample_idxs(idxs) #ask generator to repeat these sample idxs
+        # 'worst' sample booster gives no good result, or I don't know how to filter worst samples properly.
+        #
+        ##gathering array of sample_losses
+        #self.src_sample_losses += [[src_sample_idxs[i], src_sample_losses[i]] for i in range(self.batch_size) ]
+        #self.dst_sample_losses += [[dst_sample_idxs[i], dst_sample_losses[i]] for i in range(self.batch_size) ]
+        #
+        #if len(self.src_sample_losses) >= 48: #array is big enough
+        #    #fetching idxs which losses are bigger than average
+        #    x = np.array (self.src_sample_losses)
+        #    self.src_sample_losses = []
+        #    b = x[:,1]
+        #    idxs = (x[:,0][ np.argwhere ( b [ b > np.mean(b) ] )[:,0] ]).astype(np.uint)
+        #    generators_list[0].repeat_sample_idxs(idxs) #ask generator to repeat these sample idxs
+        #
+        #
+        #if len(self.dst_sample_losses) >= 48: #array is big enough
+        #    #fetching idxs which losses are bigger than average
+        #    x = np.array (self.dst_sample_losses)
+        #    self.dst_sample_losses = []
+        #    b = x[:,1]
+        #    idxs = (x[:,0][ np.argwhere ( b [ b > np.mean(b) ] )[:,0] ]).astype(np.uint)
+        #    generators_list[1].repeat_sample_idxs(idxs) #ask generator to repeat these sample idxs
 
         if self.options['learn_mask']:
             src_mask_loss, dst_mask_loss, = self.src_dst_mask_train ([warped_src, target_src_mask, warped_dst, target_dst_mask])
diff --git a/samples/SampleGeneratorFace.py b/samples/SampleGeneratorFace.py
index 92788d8..13358fa 100644
--- a/samples/SampleGeneratorFace.py
+++ b/samples/SampleGeneratorFace.py
@@ -54,6 +54,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
         generator = self.generators[self.generator_counter % len(self.generators) ]
         return next(generator)
 
+    #forces these sample idxs to be repeated as soon as possible
     def repeat_sample_idxs(self, idxs): # [ idx, ... ]
         #send idxs list to all sub generators.
         for gen_sq in self.generators_sq:
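
Reference sketch (not part of the patch): the reworked dssim_pixel_alpha schedule stays at 0.0 until epoch 5000, then ramps linearly to 1.0 at epoch 20000. The snippet below reproduces that schedule standalone in NumPy; blended_loss and the example loss values are illustrative only, and the assumption that alpha weights the MSE term (i.e. the transition runs DSSIM -> MSE as training progresses) follows the comment in the patch, not the actual Keras loss graph.

import numpy as np

def dssim_pixel_alpha(epoch):
    # 0.0 up to epoch 5000, linear ramp to 1.0 at epoch 20000, clipped afterwards
    return float(np.clip((epoch - 5000) / 15000.0, 0.0, 1.0))

def blended_loss(dssim_loss, mse_loss, epoch):
    # assumed blend direction: alpha == 0 -> pure DSSIM, alpha == 1 -> pure MSE
    a = dssim_pixel_alpha(epoch)
    return (1.0 - a) * dssim_loss + a * mse_loss

if __name__ == '__main__':
    for epoch in (0, 5000, 12500, 20000, 30000):
        print(epoch, dssim_pixel_alpha(epoch), blended_loss(0.3, 0.1, epoch))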
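
On the disabled 'worst' sample booster: one reason it may have filtered poorly is that np.argwhere ( b [ b > np.mean(b) ] ) returns positions within the already-filtered array, so x[:,0][...] ends up picking the first N sample idxs rather than the idxs whose loss was above average. Below is a standalone sketch of the intended selection using a plain boolean mask; worst_sample_idxs and min_count are illustrative names, while repeat_sample_idxs is the real SampleGeneratorFace method touched by the patch.

import numpy as np

def worst_sample_idxs(sample_losses, min_count=48):
    # sample_losses: accumulated [sample_idx, loss] pairs from recent batches
    if len(sample_losses) < min_count:
        return None  # not enough history yet
    x = np.array(sample_losses, dtype=np.float64)
    idxs, losses = x[:, 0], x[:, 1]
    # keep only the idxs whose loss is above the running average
    return idxs[losses > losses.mean()].astype(np.uint32)

# usage sketch inside onTrainOneEpoch (illustrative, mirrors the disabled code):
# idxs = worst_sample_idxs(self.src_sample_losses)
# if idxs is not None:
#     self.src_sample_losses = []
#     generators_list[0].repeat_sample_idxs(idxs)  # repeat these samples as soon as possible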