mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 13:02:15 -07:00
SAEHD, AMP: removed the implicit function of periodically retraining last 16 “high-loss” samples
This commit is contained in:
parent
33ff0be722
commit
9e0079c6a0
3 changed files with 13 additions and 54 deletions
|
@ -104,7 +104,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
if archi_opts is not None:
|
||||
if len(archi_opts) == 0:
|
||||
continue
|
||||
if len([ 1 for opt in archi_opts if opt not in ['u','d','t'] ]) != 0:
|
||||
if len([ 1 for opt in archi_opts if opt not in ['u','d','t','c'] ]) != 0:
|
||||
continue
|
||||
|
||||
if 'd' in archi_opts:
|
||||
|
@ -688,9 +688,6 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
generators_count=dst_generators_count )
|
||||
])
|
||||
|
||||
self.last_src_samples_loss = []
|
||||
self.last_dst_samples_loss = []
|
||||
|
||||
if self.pretrain_just_disabled:
|
||||
self.update_sample_for_preview(force_new=True)
|
||||
|
||||
|
@ -765,33 +762,11 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
if self.get_iter() == 0 and not self.pretrain and not self.pretrain_just_disabled:
|
||||
io.log_info('You are training the model from scratch. It is strongly recommended to use a pretrained model to speed up the training and improve the quality.\n')
|
||||
|
||||
bs = self.get_batch_size()
|
||||
|
||||
( (warped_src, target_src, target_srcm, target_srcm_em), \
|
||||
(warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples()
|
||||
|
||||
src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
|
||||
|
||||
for i in range(bs):
|
||||
self.last_src_samples_loss.append ( (src_loss[i], target_src[i], target_srcm[i], target_srcm_em[i],) )
|
||||
self.last_dst_samples_loss.append ( (dst_loss[i], target_dst[i], target_dstm[i], target_dstm_em[i],) )
|
||||
|
||||
if len(self.last_src_samples_loss) >= bs*16:
|
||||
src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(0), reverse=True)
|
||||
dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(0), reverse=True)
|
||||
|
||||
target_src = np.stack( [ x[1] for x in src_samples_loss[:bs] ] )
|
||||
target_srcm = np.stack( [ x[2] for x in src_samples_loss[:bs] ] )
|
||||
target_srcm_em = np.stack( [ x[3] for x in src_samples_loss[:bs] ] )
|
||||
|
||||
target_dst = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
|
||||
target_dstm = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
|
||||
target_dstm_em = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] )
|
||||
|
||||
src_loss, dst_loss = self.src_dst_train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em)
|
||||
self.last_src_samples_loss = []
|
||||
self.last_dst_samples_loss = []
|
||||
|
||||
if self.options['true_face_power'] != 0 and not self.pretrain:
|
||||
self.D_train (warped_src, warped_dst)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue