mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 04:52:13 -07:00
SAEHD: write_preview_history now works faster
The frequency at which the preview is saved now depends on the resolution. For example 64x64 – every 10 iters. 448x448 – every 70 iters.
This commit is contained in:
parent
58670722dc
commit
ce95af9068
2 changed files with 65 additions and 19 deletions
|
@ -1,6 +1,7 @@
|
|||
import colorsys
|
||||
import inspect
|
||||
import json
|
||||
import multiprocessing
|
||||
import operator
|
||||
import os
|
||||
import pickle
|
||||
|
@ -12,12 +13,11 @@ from pathlib import Path
|
|||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from core import imagelib
|
||||
from core import imagelib, pathex
|
||||
from core.cv2ex import *
|
||||
from core.interact import interact as io
|
||||
from core.leras import nn
|
||||
from samplelib import SampleGeneratorBase
|
||||
from core import pathex
|
||||
from core.cv2ex import *
|
||||
|
||||
|
||||
class ModelBase(object):
|
||||
|
@ -157,7 +157,7 @@ class ModelBase(object):
|
|||
else:
|
||||
self.device_config = nn.DeviceConfig.GPUIndexes( force_gpu_idxs or nn.ask_choose_device_idxs(suggest_best_multi_gpu=True)) \
|
||||
if not cpu_only else nn.DeviceConfig.CPU()
|
||||
|
||||
|
||||
nn.initialize(self.device_config)
|
||||
|
||||
####
|
||||
|
@ -188,6 +188,7 @@ class ModelBase(object):
|
|||
self.on_initialize()
|
||||
self.options['batch_size'] = self.batch_size
|
||||
|
||||
self.preview_history_writer = None
|
||||
if self.is_training:
|
||||
self.preview_history_path = self.saved_models_path / ( f'{self.get_model_name()}_history' )
|
||||
self.autobackups_path = self.saved_models_path / ( f'{self.get_model_name()}_autobackups' )
|
||||
|
@ -298,12 +299,12 @@ class ModelBase(object):
|
|||
|
||||
def ask_batch_size(self, suggest_batch_size=None, range=None):
|
||||
default_batch_size = self.load_or_def_option('batch_size', suggest_batch_size or self.batch_size)
|
||||
|
||||
|
||||
batch_size = max(0, io.input_int("Batch_size", default_batch_size, valid_range=range, help_message="Larger batch size is better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually."))
|
||||
|
||||
|
||||
if range is not None:
|
||||
batch_size = np.clip(batch_size, range[0], range[1])
|
||||
|
||||
|
||||
self.options['batch_size'] = self.batch_size = batch_size
|
||||
|
||||
|
||||
|
@ -368,6 +369,11 @@ class ModelBase(object):
|
|||
def get_static_previews(self):
|
||||
return self.onGetPreview (self.sample_for_preview)
|
||||
|
||||
def get_preview_history_writer(self):
|
||||
if self.preview_history_writer is None:
|
||||
self.preview_history_writer = PreviewHistoryWriter()
|
||||
return self.preview_history_writer
|
||||
|
||||
def save(self):
|
||||
Path( self.get_summary_path() ).write_text( self.get_summary_text() )
|
||||
|
||||
|
@ -423,10 +429,8 @@ class ModelBase(object):
|
|||
name, bgr = previews[i]
|
||||
plist += [ (bgr, idx_backup_path / ( ('preview_%s.jpg') % (name)) ) ]
|
||||
|
||||
for preview, filepath in plist:
|
||||
preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2])
|
||||
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
|
||||
cv2_imwrite (filepath, img )
|
||||
if len(plist) != 0:
|
||||
self.get_preview_history_writer().post(plist, self.loss_history, self.iter)
|
||||
|
||||
def debug_one_iter(self):
|
||||
images = []
|
||||
|
@ -447,6 +451,10 @@ class ModelBase(object):
|
|||
self.last_sample = sample
|
||||
return sample
|
||||
|
||||
#overridable
|
||||
def should_save_preview_history(self):
|
||||
return (not io.is_colab() and self.iter % 10 == 0) or (io.is_colab() and self.iter % 100 == 0)
|
||||
|
||||
def train_one_iter(self):
|
||||
|
||||
iter_time = time.time()
|
||||
|
@ -455,8 +463,7 @@ class ModelBase(object):
|
|||
|
||||
self.loss_history.append ( [float(loss[1]) for loss in losses] )
|
||||
|
||||
if (not io.is_colab() and self.iter % 10 == 0) or \
|
||||
(io.is_colab() and self.iter % 100 == 0):
|
||||
if self.should_save_preview_history():
|
||||
plist = []
|
||||
|
||||
if io.is_colab():
|
||||
|
@ -470,15 +477,12 @@ class ModelBase(object):
|
|||
for i in range(len(previews)):
|
||||
name, bgr = previews[i]
|
||||
path = self.preview_history_path / name
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
plist += [ ( bgr, str ( path / ( f'{self.iter:07d}.jpg') ) ) ]
|
||||
if not io.is_colab():
|
||||
plist += [ ( bgr, str ( path / ( '_last.jpg' ) )) ]
|
||||
|
||||
for preview, filepath in plist:
|
||||
preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2])
|
||||
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
|
||||
cv2_imwrite (filepath, img )
|
||||
|
||||
if len(plist) != 0:
|
||||
self.get_preview_history_writer().post(plist, self.loss_history, self.iter)
|
||||
|
||||
self.iter += 1
|
||||
|
||||
|
@ -625,3 +629,41 @@ class ModelBase(object):
|
|||
|
||||
lh_img[last_line_t:last_line_b, 0:w] += imagelib.get_text_image ( (last_line_b-last_line_t,w,c), lh_text, color=[0.8]*c )
|
||||
return lh_img
|
||||
|
||||
class PreviewHistoryWriter():
|
||||
def __init__(self):
|
||||
self.sq = multiprocessing.Queue()
|
||||
self.p = multiprocessing.Process(target=self.process, args=( self.sq, ))
|
||||
self.p.daemon = True
|
||||
self.p.start()
|
||||
|
||||
def process(self, sq):
|
||||
while True:
|
||||
while not sq.empty():
|
||||
plist, loss_history, iter = sq.get()
|
||||
|
||||
preview_lh_cache = {}
|
||||
for preview, filepath in plist:
|
||||
filepath = Path(filepath)
|
||||
i = (preview.shape[1], preview.shape[2])
|
||||
|
||||
preview_lh = preview_lh_cache.get(i, None)
|
||||
if preview_lh is None:
|
||||
preview_lh = ModelBase.get_loss_history_preview(loss_history, iter, preview.shape[1], preview.shape[2])
|
||||
preview_lh_cache[i] = preview_lh
|
||||
|
||||
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
|
||||
|
||||
filepath.parent.mkdir(parents=True, exist_ok=True)
|
||||
cv2_imwrite (filepath, img )
|
||||
|
||||
time.sleep(0.01)
|
||||
|
||||
def post(self, plist, loss_history, iter):
|
||||
self.sq.put ( (plist, loss_history, iter) )
|
||||
|
||||
# disable pickling
|
||||
def __getstate__(self):
|
||||
return dict()
|
||||
def __setstate__(self, d):
|
||||
self.__dict__.update(d)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue