mirror of https://github.com/iperov/DeepFaceLab.git
new optimized training:
for every batch_size*16 samples, the model collects the samples with the highest error and trains on them again, so hard samples are trained more often
This commit is contained in:
parent a5783df546
commit c7ab9653c5
1 changed file with 26 additions and 4 deletions
@@ -1,4 +1,5 @@
 import multiprocessing
+import operator
 from functools import partial
 
 import numpy as np
@@ -644,8 +645,6 @@ class SAEHDModel(ModelBase):
                                                        self.target_dst :target_dst,
                                                        self.target_dstm_all:target_dstm_all,
                                                        })
-                s = np.mean(s)
-                d = np.mean(d)
                 return s, d
             self.src_dst_train = src_dst_train
 
@@ -765,6 +764,9 @@ class SAEHDModel(ModelBase):
                                     generators_count=dst_generators_count )
                              ])
 
+            self.last_src_samples_loss = []
+            self.last_dst_samples_loss = []
+
             if self.pretrain_just_disabled:
                 self.update_sample_for_preview(force_new=True)
 
@@ -780,18 +782,38 @@ class SAEHDModel(ModelBase):
 
     #override
     def onTrainOneIter(self):
+        bs = self.get_batch_size()
+
         ( (warped_src, target_src, target_srcm_all), \
           (warped_dst, target_dst, target_dstm_all) ) = self.generate_next_samples()
 
         src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm_all, warped_dst, target_dst, target_dstm_all)
 
+        for i in range(bs):
+            self.last_src_samples_loss.append ( (target_src[i], target_srcm_all[i], src_loss[i] ) )
+            self.last_dst_samples_loss.append ( (target_dst[i], target_dstm_all[i], dst_loss[i] ) )
+
+        if len(self.last_src_samples_loss) >= bs*16:
+            src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(2), reverse=True)
+            dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(2), reverse=True)
+
+            target_src      = np.stack( [ x[0] for x in src_samples_loss[:bs] ] )
+            target_srcm_all = np.stack( [ x[1] for x in src_samples_loss[:bs] ] )
+
+            target_dst      = np.stack( [ x[0] for x in dst_samples_loss[:bs] ] )
+            target_dstm_all = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
+
+            src_loss, dst_loss = self.src_dst_train (target_src, target_src, target_srcm_all, target_dst, target_dst, target_dstm_all)
+            self.last_src_samples_loss = []
+            self.last_dst_samples_loss = []
+
         if self.options['true_face_power'] != 0 and not self.pretrain:
             self.D_train (warped_src, warped_dst)
 
         if self.gan_power != 0:
             self.D_src_dst_train (warped_src, target_src, target_srcm_all, warped_dst, target_dst, target_dstm_all)
 
-        return ( ('src_loss', src_loss), ('dst_loss', dst_loss), )
+        return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), )
 
     #override
     def onGetPreview(self, samples):
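For context, here is a minimal, framework-agnostic sketch of the hard-sample replay pattern that onTrainOneIter now implements. The names hard_sample_replay, train_step and sample_batch are illustrative placeholders, not DeepFaceLab APIs; train_step is assumed to return one loss value per sample, which is what removing the np.mean calls from src_dst_train makes available.

# Minimal sketch of the hard-sample replay idea (illustrative names, not DeepFaceLab code).
# train_step(batch) is assumed to return a per-sample loss array of shape (batch_size,).
import operator
import numpy as np

def hard_sample_replay(train_step, sample_batch, batch_size, iters, window=16):
    last_samples_loss = []                                   # (sample, loss) pairs since the last replay
    for _ in range(iters):
        batch = sample_batch(batch_size)                     # regular training batch
        per_sample_loss = train_step(batch)                  # one loss value per sample
        for sample, loss in zip(batch, per_sample_loss):
            last_samples_loss.append((sample, loss))
        # After batch_size*window samples, retrain once on the batch_size hardest ones.
        if len(last_samples_loss) >= batch_size * window:
            hardest = sorted(last_samples_loss, key=operator.itemgetter(1), reverse=True)
            replay_batch = np.stack([s for s, _ in hardest[:batch_size]])
            train_step(replay_batch)
            last_samples_loss = []

# Toy usage: samples are scalars and the "loss" is their magnitude, so every
# replay batch is simply the largest values seen in the current window.
rng = np.random.default_rng(0)
hard_sample_replay(train_step=lambda b: np.abs(b).reshape(-1),
                   sample_batch=lambda bs: rng.normal(size=(bs, 1)),
                   batch_size=4, iters=64)

In the commit itself, the replay pass reuses src_dst_train with the stored target images standing in for the warped inputs as well, since the warped versions of the hard samples are not kept, and both loss buffers are cleared after each replay.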