mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-10 23:33:30 -07:00
SAEHD: write_preview_history now works faster
The frequency at which the preview is saved now depends on the resolution. For example 64x64 – every 10 iters. 448x448 – every 70 iters.
This commit is contained in:
parent
58670722dc
commit
ce95af9068
2 changed files with 65 additions and 19 deletions
|
@ -1,6 +1,7 @@
|
||||||
import colorsys
|
import colorsys
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
|
import multiprocessing
|
||||||
import operator
|
import operator
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
|
@ -12,12 +13,11 @@ from pathlib import Path
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from core import imagelib
|
from core import imagelib, pathex
|
||||||
|
from core.cv2ex import *
|
||||||
from core.interact import interact as io
|
from core.interact import interact as io
|
||||||
from core.leras import nn
|
from core.leras import nn
|
||||||
from samplelib import SampleGeneratorBase
|
from samplelib import SampleGeneratorBase
|
||||||
from core import pathex
|
|
||||||
from core.cv2ex import *
|
|
||||||
|
|
||||||
|
|
||||||
class ModelBase(object):
|
class ModelBase(object):
|
||||||
|
@ -188,6 +188,7 @@ class ModelBase(object):
|
||||||
self.on_initialize()
|
self.on_initialize()
|
||||||
self.options['batch_size'] = self.batch_size
|
self.options['batch_size'] = self.batch_size
|
||||||
|
|
||||||
|
self.preview_history_writer = None
|
||||||
if self.is_training:
|
if self.is_training:
|
||||||
self.preview_history_path = self.saved_models_path / ( f'{self.get_model_name()}_history' )
|
self.preview_history_path = self.saved_models_path / ( f'{self.get_model_name()}_history' )
|
||||||
self.autobackups_path = self.saved_models_path / ( f'{self.get_model_name()}_autobackups' )
|
self.autobackups_path = self.saved_models_path / ( f'{self.get_model_name()}_autobackups' )
|
||||||
|
@ -368,6 +369,11 @@ class ModelBase(object):
|
||||||
def get_static_previews(self):
|
def get_static_previews(self):
|
||||||
return self.onGetPreview (self.sample_for_preview)
|
return self.onGetPreview (self.sample_for_preview)
|
||||||
|
|
||||||
|
def get_preview_history_writer(self):
|
||||||
|
if self.preview_history_writer is None:
|
||||||
|
self.preview_history_writer = PreviewHistoryWriter()
|
||||||
|
return self.preview_history_writer
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
Path( self.get_summary_path() ).write_text( self.get_summary_text() )
|
Path( self.get_summary_path() ).write_text( self.get_summary_text() )
|
||||||
|
|
||||||
|
@ -423,10 +429,8 @@ class ModelBase(object):
|
||||||
name, bgr = previews[i]
|
name, bgr = previews[i]
|
||||||
plist += [ (bgr, idx_backup_path / ( ('preview_%s.jpg') % (name)) ) ]
|
plist += [ (bgr, idx_backup_path / ( ('preview_%s.jpg') % (name)) ) ]
|
||||||
|
|
||||||
for preview, filepath in plist:
|
if len(plist) != 0:
|
||||||
preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2])
|
self.get_preview_history_writer().post(plist, self.loss_history, self.iter)
|
||||||
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
|
|
||||||
cv2_imwrite (filepath, img )
|
|
||||||
|
|
||||||
def debug_one_iter(self):
|
def debug_one_iter(self):
|
||||||
images = []
|
images = []
|
||||||
|
@ -447,6 +451,10 @@ class ModelBase(object):
|
||||||
self.last_sample = sample
|
self.last_sample = sample
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
#overridable
|
||||||
|
def should_save_preview_history(self):
|
||||||
|
return (not io.is_colab() and self.iter % 10 == 0) or (io.is_colab() and self.iter % 100 == 0)
|
||||||
|
|
||||||
def train_one_iter(self):
|
def train_one_iter(self):
|
||||||
|
|
||||||
iter_time = time.time()
|
iter_time = time.time()
|
||||||
|
@ -455,8 +463,7 @@ class ModelBase(object):
|
||||||
|
|
||||||
self.loss_history.append ( [float(loss[1]) for loss in losses] )
|
self.loss_history.append ( [float(loss[1]) for loss in losses] )
|
||||||
|
|
||||||
if (not io.is_colab() and self.iter % 10 == 0) or \
|
if self.should_save_preview_history():
|
||||||
(io.is_colab() and self.iter % 100 == 0):
|
|
||||||
plist = []
|
plist = []
|
||||||
|
|
||||||
if io.is_colab():
|
if io.is_colab():
|
||||||
|
@ -470,15 +477,12 @@ class ModelBase(object):
|
||||||
for i in range(len(previews)):
|
for i in range(len(previews)):
|
||||||
name, bgr = previews[i]
|
name, bgr = previews[i]
|
||||||
path = self.preview_history_path / name
|
path = self.preview_history_path / name
|
||||||
path.mkdir(parents=True, exist_ok=True)
|
|
||||||
plist += [ ( bgr, str ( path / ( f'{self.iter:07d}.jpg') ) ) ]
|
plist += [ ( bgr, str ( path / ( f'{self.iter:07d}.jpg') ) ) ]
|
||||||
if not io.is_colab():
|
if not io.is_colab():
|
||||||
plist += [ ( bgr, str ( path / ( '_last.jpg' ) )) ]
|
plist += [ ( bgr, str ( path / ( '_last.jpg' ) )) ]
|
||||||
|
|
||||||
for preview, filepath in plist:
|
if len(plist) != 0:
|
||||||
preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2])
|
self.get_preview_history_writer().post(plist, self.loss_history, self.iter)
|
||||||
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
|
|
||||||
cv2_imwrite (filepath, img )
|
|
||||||
|
|
||||||
self.iter += 1
|
self.iter += 1
|
||||||
|
|
||||||
|
@ -625,3 +629,41 @@ class ModelBase(object):
|
||||||
|
|
||||||
lh_img[last_line_t:last_line_b, 0:w] += imagelib.get_text_image ( (last_line_b-last_line_t,w,c), lh_text, color=[0.8]*c )
|
lh_img[last_line_t:last_line_b, 0:w] += imagelib.get_text_image ( (last_line_b-last_line_t,w,c), lh_text, color=[0.8]*c )
|
||||||
return lh_img
|
return lh_img
|
||||||
|
|
||||||
|
class PreviewHistoryWriter():
|
||||||
|
def __init__(self):
|
||||||
|
self.sq = multiprocessing.Queue()
|
||||||
|
self.p = multiprocessing.Process(target=self.process, args=( self.sq, ))
|
||||||
|
self.p.daemon = True
|
||||||
|
self.p.start()
|
||||||
|
|
||||||
|
def process(self, sq):
|
||||||
|
while True:
|
||||||
|
while not sq.empty():
|
||||||
|
plist, loss_history, iter = sq.get()
|
||||||
|
|
||||||
|
preview_lh_cache = {}
|
||||||
|
for preview, filepath in plist:
|
||||||
|
filepath = Path(filepath)
|
||||||
|
i = (preview.shape[1], preview.shape[2])
|
||||||
|
|
||||||
|
preview_lh = preview_lh_cache.get(i, None)
|
||||||
|
if preview_lh is None:
|
||||||
|
preview_lh = ModelBase.get_loss_history_preview(loss_history, iter, preview.shape[1], preview.shape[2])
|
||||||
|
preview_lh_cache[i] = preview_lh
|
||||||
|
|
||||||
|
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
|
||||||
|
|
||||||
|
filepath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
cv2_imwrite (filepath, img )
|
||||||
|
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
def post(self, plist, loss_history, iter):
|
||||||
|
self.sq.put ( (plist, loss_history, iter) )
|
||||||
|
|
||||||
|
# disable pickling
|
||||||
|
def __getstate__(self):
|
||||||
|
return dict()
|
||||||
|
def __setstate__(self, d):
|
||||||
|
self.__dict__.update(d)
|
||||||
|
|
|
@ -609,6 +609,10 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
||||||
for model, filename in io.progress_bar_generator(self.get_model_filename_list(), "Saving", leave=False):
|
for model, filename in io.progress_bar_generator(self.get_model_filename_list(), "Saving", leave=False):
|
||||||
model.save_weights ( self.get_strpath_storage_for_file(filename) )
|
model.save_weights ( self.get_strpath_storage_for_file(filename) )
|
||||||
|
|
||||||
|
#override
|
||||||
|
def should_save_preview_history(self):
|
||||||
|
return (not io.is_colab() and self.iter % ( 10*(max(1,self.resolution // 64)) ) == 0) or \
|
||||||
|
(io.is_colab() and self.iter % 100 == 0)
|
||||||
|
|
||||||
#override
|
#override
|
||||||
def onTrainOneIter(self):
|
def onTrainOneIter(self):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue