mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-19 13:09:56 -07:00
Fixed retraining_samples
This commit is contained in:
parent
224af63704
commit
248ab19e4b
1 changed files with 4 additions and 4 deletions
|
@ -915,13 +915,13 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
||||||
|
|
||||||
src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
|
src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
|
||||||
|
|
||||||
for i in range(bs):
|
|
||||||
self.last_src_samples_loss.append ( (target_src[i], target_srcm[i], target_srcm_em[i], src_loss[i] ) )
|
|
||||||
self.last_dst_samples_loss.append ( (target_dst[i], target_dstm[i], target_dstm_em[i], dst_loss[i] ) )
|
|
||||||
|
|
||||||
if self.options['retraining_samples']:
|
if self.options['retraining_samples']:
|
||||||
bs = self.get_batch_size()
|
bs = self.get_batch_size()
|
||||||
|
|
||||||
|
for i in range(bs):
|
||||||
|
self.last_src_samples_loss.append ( (target_src[i], target_srcm[i], target_srcm_em[i], src_loss[i] ) )
|
||||||
|
self.last_dst_samples_loss.append ( (target_dst[i], target_dstm[i], target_dstm_em[i], dst_loss[i] ) )
|
||||||
|
|
||||||
if len(self.last_src_samples_loss) >= bs*16:
|
if len(self.last_src_samples_loss) >= bs*16:
|
||||||
src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(3), reverse=True)
|
src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(3), reverse=True)
|
||||||
dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(3), reverse=True)
|
dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(3), reverse=True)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue