SAEHD, AMP: removed the implicit function of periodically retraining last 16 “high-loss” samples

This commit is contained in:
iperov 2021-09-29 16:48:54 +04:00
parent 33ff0be722
commit 9e0079c6a0
3 changed files with 13 additions and 54 deletions

View file

@ -581,9 +581,6 @@ class AMPModel(ModelBase):
generators_count=dst_generators_count )
])
self.last_src_samples_loss = []
self.last_dst_samples_loss = []
def export_dfm (self):
output_path=self.get_strpath_storage_for_file('model.dfm')
@ -653,26 +650,6 @@ class AMPModel(ModelBase):
src_loss, dst_loss = self.train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
for i in range(bs):
self.last_src_samples_loss.append ( (src_loss[i], target_src[i], target_srcm[i], target_srcm_em[i]) )
self.last_dst_samples_loss.append ( (dst_loss[i], target_dst[i], target_dstm[i], target_dstm_em[i]) )
if len(self.last_src_samples_loss) >= bs*16:
src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(0), reverse=True)
dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(0), reverse=True)
target_src = np.stack( [ x[1] for x in src_samples_loss[:bs] ] )
target_srcm = np.stack( [ x[2] for x in src_samples_loss[:bs] ] )
target_srcm_em = np.stack( [ x[3] for x in src_samples_loss[:bs] ] )
target_dst = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
target_dstm = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
target_dstm_em = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] )
src_loss, dst_loss = self.train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em)
self.last_src_samples_loss = []
self.last_dst_samples_loss = []
if self.gan_power != 0:
self.GAN_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)