Mirror of https://github.com/iperov/DeepFaceLab.git, synced 2025-07-07 05:22:06 -07:00
new optimized training:
for every batch_size*16 samples, the model collects the samples with the highest error and trains on them again, so hard samples are trained more often
This commit is contained in:
parent a5783df546
commit c7ab9653c5
1 changed file with 26 additions and 4 deletions
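In outline, the training loop now keeps a rolling buffer of (sample, mask, per-sample loss) tuples and, once batch_size*16 samples have been seen, runs one extra training step on the batch_size samples with the highest loss. A minimal standalone sketch of that idea, assuming a separate training step that returns one loss value per sample; the class and method names below are illustrative only, not the SAEHD API (the actual change is in the diff that follows):

import operator
import numpy as np

class HardSampleBuffer:
    """Illustrative sketch: gather batch_size*multiplier samples with their
    per-sample losses, then hand back the batch_size hardest ones for one
    extra training step and clear the buffer."""

    def __init__(self, batch_size, multiplier=16):
        self.batch_size = batch_size
        self.capacity = batch_size * multiplier
        self.samples = []                     # list of (target, mask, loss) tuples

    def add_batch(self, targets, masks, losses):
        # losses must be per-sample values, not a batch mean
        for t, m, l in zip(targets, masks, losses):
            self.samples.append((t, m, l))

    def pop_hardest(self):
        # nothing to replay until enough samples have been collected
        if len(self.samples) < self.capacity:
            return None
        # sort by loss, descending, and keep the worst batch_size samples
        ordered = sorted(self.samples, key=operator.itemgetter(2), reverse=True)
        hardest = ordered[:self.batch_size]
        self.samples = []
        targets = np.stack([x[0] for x in hardest])
        masks   = np.stack([x[1] for x in hardest])
        return targets, masks

With roughly this shape, every ordinary iteration feeds its per-sample losses into the buffer, and whenever pop_hardest() yields a batch it is passed through the training step once more; this is also why the commit changes src_dst_train to return per-sample losses instead of their mean.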
@@ -1,4 +1,5 @@
 import multiprocessing
+import operator
 from functools import partial

 import numpy as np
@@ -644,8 +645,6 @@ class SAEHDModel(ModelBase):
                                                        self.target_dst :target_dst,
                                                        self.target_dstm_all:target_dstm_all,
                                                        })
-                s = np.mean(s)
-                d = np.mean(d)
                 return s, d
             self.src_dst_train = src_dst_train

@@ -765,6 +764,9 @@ class SAEHDModel(ModelBase):
                                              generators_count=dst_generators_count )
                              ])

+        self.last_src_samples_loss = []
+        self.last_dst_samples_loss = []
+
         if self.pretrain_just_disabled:
             self.update_sample_for_preview(force_new=True)

@@ -780,18 +782,38 @@ class SAEHDModel(ModelBase):

     #override
     def onTrainOneIter(self):
+        bs = self.get_batch_size()
+
         ( (warped_src, target_src, target_srcm_all), \
           (warped_dst, target_dst, target_dstm_all) ) = self.generate_next_samples()

         src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm_all, warped_dst, target_dst, target_dstm_all)

+        for i in range(bs):
+            self.last_src_samples_loss.append ( (target_src[i], target_srcm_all[i], src_loss[i] ) )
+            self.last_dst_samples_loss.append ( (target_dst[i], target_dstm_all[i], dst_loss[i] ) )
+
+        if len(self.last_src_samples_loss) >= bs*16:
+            src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(2), reverse=True)
+            dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(2), reverse=True)
+
+            target_src      = np.stack( [ x[0] for x in src_samples_loss[:bs] ] )
+            target_srcm_all = np.stack( [ x[1] for x in src_samples_loss[:bs] ] )
+
+            target_dst      = np.stack( [ x[0] for x in dst_samples_loss[:bs] ] )
+            target_dstm_all = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
+
+            src_loss, dst_loss = self.src_dst_train (target_src, target_src, target_srcm_all, target_dst, target_dst, target_dstm_all)
+            self.last_src_samples_loss = []
+            self.last_dst_samples_loss = []
+
         if self.options['true_face_power'] != 0 and not self.pretrain:
             self.D_train (warped_src, warped_dst)

         if self.gan_power != 0:
             self.D_src_dst_train (warped_src, target_src, target_srcm_all, warped_dst, target_dst, target_dstm_all)

-        return ( ('src_loss', src_loss), ('dst_loss', dst_loss), )
+        return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), )

     #override
     def onGetPreview(self, samples):