diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index 02c0581..9bd3789 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -1,4 +1,5 @@ import multiprocessing +import operator from functools import partial import numpy as np @@ -644,8 +645,6 @@ class SAEHDModel(ModelBase): self.target_dst :target_dst, self.target_dstm_all:target_dstm_all, }) - s = np.mean(s) - d = np.mean(d) return s, d self.src_dst_train = src_dst_train @@ -764,7 +763,10 @@ class SAEHDModel(ModelBase): ], generators_count=dst_generators_count ) ]) - + + self.last_src_samples_loss = [] + self.last_dst_samples_loss = [] + if self.pretrain_just_disabled: self.update_sample_for_preview(force_new=True) @@ -780,10 +782,30 @@ class SAEHDModel(ModelBase): #override def onTrainOneIter(self): + bs = self.get_batch_size() + ( (warped_src, target_src, target_srcm_all), \ (warped_dst, target_dst, target_dstm_all) ) = self.generate_next_samples() src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm_all, warped_dst, target_dst, target_dstm_all) + + for i in range(bs): + self.last_src_samples_loss.append ( (target_src[i], target_srcm_all[i], src_loss[i] ) ) + self.last_dst_samples_loss.append ( (target_dst[i], target_dstm_all[i], dst_loss[i] ) ) + + if len(self.last_src_samples_loss) >= bs*16: + src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(2), reverse=True) + dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(2), reverse=True) + + target_src = np.stack( [ x[0] for x in src_samples_loss[:bs] ] ) + target_srcm_all = np.stack( [ x[1] for x in src_samples_loss[:bs] ] ) + + target_dst = np.stack( [ x[0] for x in dst_samples_loss[:bs] ] ) + target_dstm_all = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] ) + + src_loss, dst_loss = self.src_dst_train (target_src, target_src, target_srcm_all, target_dst, target_dst, target_dstm_all) + self.last_src_samples_loss = [] + self.last_dst_samples_loss = [] if self.options['true_face_power'] != 0 and not self.pretrain: self.D_train (warped_src, warped_dst) @@ -791,7 +813,7 @@ class SAEHDModel(ModelBase): if self.gan_power != 0: self.D_src_dst_train (warped_src, target_src, target_srcm_all, warped_dst, target_dst, target_dstm_all) - return ( ('src_loss', src_loss), ('dst_loss', dst_loss), ) + return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), ) #override def onGetPreview(self, samples):