SAEHD, AMP: removed the implicit function of periodically retraining last 16 “high-loss” samples

This commit is contained in:
iperov 2021-09-29 16:48:54 +04:00
parent 33ff0be722
commit 9e0079c6a0
3 changed files with 13 additions and 54 deletions

View file

@ -104,7 +104,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
if archi_opts is not None:
if len(archi_opts) == 0:
continue
if len([ 1 for opt in archi_opts if opt not in ['u','d','t'] ]) != 0:
if len([ 1 for opt in archi_opts if opt not in ['u','d','t','c'] ]) != 0:
continue
if 'd' in archi_opts:
@ -688,9 +688,6 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
generators_count=dst_generators_count )
])
self.last_src_samples_loss = []
self.last_dst_samples_loss = []
if self.pretrain_just_disabled:
self.update_sample_for_preview(force_new=True)
@ -765,33 +762,11 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
if self.get_iter() == 0 and not self.pretrain and not self.pretrain_just_disabled:
io.log_info('You are training the model from scratch. It is strongly recommended to use a pretrained model to speed up the training and improve the quality.\n')
bs = self.get_batch_size()
( (warped_src, target_src, target_srcm, target_srcm_em), \
(warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples()
src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
for i in range(bs):
self.last_src_samples_loss.append ( (src_loss[i], target_src[i], target_srcm[i], target_srcm_em[i],) )
self.last_dst_samples_loss.append ( (dst_loss[i], target_dst[i], target_dstm[i], target_dstm_em[i],) )
if len(self.last_src_samples_loss) >= bs*16:
src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(0), reverse=True)
dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(0), reverse=True)
target_src = np.stack( [ x[1] for x in src_samples_loss[:bs] ] )
target_srcm = np.stack( [ x[2] for x in src_samples_loss[:bs] ] )
target_srcm_em = np.stack( [ x[3] for x in src_samples_loss[:bs] ] )
target_dst = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
target_dstm = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
target_dstm_em = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] )
src_loss, dst_loss = self.src_dst_train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em)
self.last_src_samples_loss = []
self.last_dst_samples_loss = []
if self.options['true_face_power'] != 0 and not self.pretrain:
self.D_train (warped_src, warped_dst)