mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-07 05:22:06 -07:00
upd
This commit is contained in:
parent
429e7e6aee
commit
535041f7bb
2 changed files with 28 additions and 24 deletions
|
@ -306,29 +306,6 @@ class SAEModel(ModelBase):
|
||||||
|
|
||||||
src_loss, dst_loss, src_sample_losses, dst_sample_losses = self.src_dst_train ([dssim_pixel_alpha, warped_src, target_src, target_src_mask, warped_dst, target_dst, target_dst_mask])
|
src_loss, dst_loss, src_sample_losses, dst_sample_losses = self.src_dst_train ([dssim_pixel_alpha, warped_src, target_src, target_src_mask, warped_dst, target_dst, target_dst_mask])
|
||||||
|
|
||||||
# 'worst' sample booster gives no good result, or I dont know how to filter worst samples properly.
|
|
||||||
#
|
|
||||||
##gathering array of sample_losses
|
|
||||||
#self.src_sample_losses += [[src_sample_idxs[i], src_sample_losses[i]] for i in range(self.batch_size) ]
|
|
||||||
#self.dst_sample_losses += [[dst_sample_idxs[i], dst_sample_losses[i]] for i in range(self.batch_size) ]
|
|
||||||
#
|
|
||||||
#if len(self.src_sample_losses) >= 48: #array is big enough
|
|
||||||
# #fetching idxs which losses are bigger than average
|
|
||||||
# x = np.array (self.src_sample_losses)
|
|
||||||
# self.src_sample_losses = []
|
|
||||||
# b = x[:,1]
|
|
||||||
# idxs = (x[:,0][ np.argwhere ( b [ b > (np.mean(b)+np.std(b)) ] )[:,0] ]).astype(np.uint)
|
|
||||||
# generators_list[0].repeat_sample_idxs(idxs) #ask generator to repeat these sample idxs
|
|
||||||
#
|
|
||||||
#
|
|
||||||
#if len(self.dst_sample_losses) >= 48: #array is big enough
|
|
||||||
# #fetching idxs which losses are bigger than average
|
|
||||||
# x = np.array (self.dst_sample_losses)
|
|
||||||
# self.dst_sample_losses = []
|
|
||||||
# b = x[:,1]
|
|
||||||
# idxs = (x[:,0][ np.argwhere ( b [ b > (np.mean(b)+np.std(b)) ] )[:,0] ]).astype(np.uint)
|
|
||||||
# generators_list[1].repeat_sample_idxs(idxs) #ask generator to repeat these sample idxs
|
|
||||||
|
|
||||||
if self.options['learn_mask']:
|
if self.options['learn_mask']:
|
||||||
src_mask_loss, dst_mask_loss, = self.src_dst_mask_train ([warped_src, target_src_mask, warped_dst, target_dst_mask])
|
src_mask_loss, dst_mask_loss, = self.src_dst_mask_train ([warped_src, target_src_mask, warped_dst, target_dst_mask])
|
||||||
|
|
||||||
|
@ -575,3 +552,29 @@ class SAEModel(ModelBase):
|
||||||
return func
|
return func
|
||||||
|
|
||||||
Model = SAEModel
|
Model = SAEModel
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# 'worst' sample booster gives no good result, or I dont know how to filter worst samples properly.
|
||||||
|
#
|
||||||
|
##gathering array of sample_losses
|
||||||
|
#self.src_sample_losses += [[src_sample_idxs[i], src_sample_losses[i]] for i in range(self.batch_size) ]
|
||||||
|
#self.dst_sample_losses += [[dst_sample_idxs[i], dst_sample_losses[i]] for i in range(self.batch_size) ]
|
||||||
|
#
|
||||||
|
#if len(self.src_sample_losses) >= 128: #array is big enough
|
||||||
|
# #fetching idxs which losses are bigger than average
|
||||||
|
# x = np.array (self.src_sample_losses)
|
||||||
|
# self.src_sample_losses = []
|
||||||
|
# b = x[:,1]
|
||||||
|
# idxs = (x[:,0][ np.argwhere ( b [ b > (np.mean(b)+np.std(b)) ] )[:,0] ]).astype(np.uint)
|
||||||
|
# generators_list[0].repeat_sample_idxs(idxs) #ask generator to repeat these sample idxs
|
||||||
|
# print ("src repeated %d" % (len(idxs)) )
|
||||||
|
#
|
||||||
|
#if len(self.dst_sample_losses) >= 128: #array is big enough
|
||||||
|
# #fetching idxs which losses are bigger than average
|
||||||
|
# x = np.array (self.dst_sample_losses)
|
||||||
|
# self.dst_sample_losses = []
|
||||||
|
# b = x[:,1]
|
||||||
|
# idxs = (x[:,0][ np.argwhere ( b [ b > (np.mean(b)+np.std(b)) ] )[:,0] ]).astype(np.uint)
|
||||||
|
# generators_list[1].repeat_sample_idxs(idxs) #ask generator to repeat these sample idxs
|
||||||
|
# print ("dst repeated %d" % (len(idxs)) )
|
|
@ -55,6 +55,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
||||||
return next(generator)
|
return next(generator)
|
||||||
|
|
||||||
#forces to repeat these sample idxs as fast as possible
|
#forces to repeat these sample idxs as fast as possible
|
||||||
|
#currently unused
|
||||||
def repeat_sample_idxs(self, idxs): # [ idx, ... ]
|
def repeat_sample_idxs(self, idxs): # [ idx, ... ]
|
||||||
#send idxs list to all sub generators.
|
#send idxs list to all sub generators.
|
||||||
for gen_sq in self.generators_sq:
|
for gen_sq in self.generators_sq:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue