New optimized training:

For every batch_size*16 samples, the model collects the samples with the highest error and trains on them again, so hard samples are trained more often.
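In outline: keep a rolling buffer of (sample, loss) pairs and, once the buffer holds batch_size*16 entries, re-train on the batch_size samples with the highest loss, then clear the buffer. A minimal self-contained sketch of that idea, in which train_step and the toy data are hypothetical placeholders, not the model's real API:

    import operator
    import numpy as np

    BATCH_SIZE = 4
    BUFFER_FACTOR = 16  # retrain after batch_size*16 samples, as in the commit

    def train_step(samples):
        # Hypothetical stand-in for a training call that returns per-sample losses.
        return np.random.rand(len(samples))

    loss_buffer = []  # rolling list of (sample, loss) tuples

    for _ in range(BUFFER_FACTOR):
        batch = [np.random.rand(8) for _ in range(BATCH_SIZE)]  # toy samples
        losses = train_step(batch)
        loss_buffer += list(zip(batch, losses))

        if len(loss_buffer) >= BATCH_SIZE * BUFFER_FACTOR:
            # Sort by loss, descending, and retrain on the hardest batch.
            loss_buffer.sort(key=operator.itemgetter(1), reverse=True)
            hard_batch = np.stack([s for s, _ in loss_buffer[:BATCH_SIZE]])
            train_step(hard_batch)
            loss_buffer = []  # start collecting the next window

The actual change does this separately for the src and dst sides, buffering (target, mask, loss) triples, as the diff below shows.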
Colombo 2020-02-27 11:19:53 +04:00
parent a5783df546
commit c7ab9653c5


@@ -1,4 +1,5 @@
 import multiprocessing
+import operator
 from functools import partial
 import numpy as np
@@ -644,8 +645,6 @@ class SAEHDModel(ModelBase):
                                           self.target_dst :target_dst,
                                           self.target_dstm_all:target_dstm_all,
                                           })
-            s = np.mean(s)
-            d = np.mean(d)
             return s, d
         self.src_dst_train = src_dst_train
@@ -764,7 +763,10 @@ class SAEHDModel(ModelBase):
                             ],
                             generators_count=dst_generators_count )
                  ])

+        self.last_src_samples_loss = []
+        self.last_dst_samples_loss = []
+
         if self.pretrain_just_disabled:
             self.update_sample_for_preview(force_new=True)
@@ -780,10 +782,30 @@ class SAEHDModel(ModelBase):
     #override
     def onTrainOneIter(self):
+        bs = self.get_batch_size()
+
         ( (warped_src, target_src, target_srcm_all), \
           (warped_dst, target_dst, target_dstm_all) ) = self.generate_next_samples()

         src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm_all, warped_dst, target_dst, target_dstm_all)

+        for i in range(bs):
+            self.last_src_samples_loss.append ( (target_src[i], target_srcm_all[i], src_loss[i]) )
+            self.last_dst_samples_loss.append ( (target_dst[i], target_dstm_all[i], dst_loss[i]) )
+
+        if len(self.last_src_samples_loss) >= bs*16:
+            src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(2), reverse=True)
+            dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(2), reverse=True)
+
+            target_src      = np.stack( [ x[0] for x in src_samples_loss[:bs] ] )
+            target_srcm_all = np.stack( [ x[1] for x in src_samples_loss[:bs] ] )
+            target_dst      = np.stack( [ x[0] for x in dst_samples_loss[:bs] ] )
+            target_dstm_all = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
+
+            src_loss, dst_loss = self.src_dst_train (target_src, target_src, target_srcm_all, target_dst, target_dst, target_dstm_all)
+            self.last_src_samples_loss = []
+            self.last_dst_samples_loss = []
+
         if self.options['true_face_power'] != 0 and not self.pretrain:
             self.D_train (warped_src, warped_dst)
@@ -791,7 +813,7 @@ class SAEHDModel(ModelBase):
         if self.gan_power != 0:
             self.D_src_dst_train (warped_src, target_src, target_srcm_all, warped_dst, target_dst, target_dstm_all)

-        return ( ('src_loss', src_loss), ('dst_loss', dst_loss), )
+        return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), )

     #override
     def onGetPreview(self, samples):
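A supporting detail in the diff above: the np.mean reduction on s and d inside src_dst_train is removed, so the training call returns one loss value per sample, which is what makes the per-sample sorting in onTrainOneIter possible; the mean is now applied only in the final return, for reporting. A tiny sketch of that split, with src_dst_train_stub as a hypothetical stand-in for the real session call:

    import numpy as np

    def src_dst_train_stub(batch):
        # Hypothetical stand-in: returns one loss per sample (no np.mean inside),
        # so callers can rank and replay individual samples.
        return np.random.rand(len(batch))  # placeholder for the real losses

    per_sample = src_dst_train_stub(np.zeros((4, 8)))
    hardest = int(np.argmax(per_sample))   # ranking needs per-sample values
    report = float(np.mean(per_sample))    # reduce to a scalar only for reporting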