SAEHD: write_preview_history now works faster

The frequency at which the preview is saved now depends on the resolution.
For example, 64x64 — every 10 iters; 448x448 — every 70 iters.
This commit is contained in:
Colombo 2020-07-18 11:03:13 +04:00
parent 58670722dc
commit ce95af9068
2 changed files with 65 additions and 19 deletions

View file

@ -1,6 +1,7 @@
import colorsys
import inspect
import json
import multiprocessing
import operator
import os
import pickle
@ -12,12 +13,11 @@ from pathlib import Path
import cv2
import numpy as np
from core import imagelib
from core import imagelib, pathex
from core.cv2ex import *
from core.interact import interact as io
from core.leras import nn
from samplelib import SampleGeneratorBase
from core import pathex
from core.cv2ex import *
class ModelBase(object):
@ -157,7 +157,7 @@ class ModelBase(object):
else:
self.device_config = nn.DeviceConfig.GPUIndexes( force_gpu_idxs or nn.ask_choose_device_idxs(suggest_best_multi_gpu=True)) \
if not cpu_only else nn.DeviceConfig.CPU()
nn.initialize(self.device_config)
####
@ -188,6 +188,7 @@ class ModelBase(object):
self.on_initialize()
self.options['batch_size'] = self.batch_size
self.preview_history_writer = None
if self.is_training:
self.preview_history_path = self.saved_models_path / ( f'{self.get_model_name()}_history' )
self.autobackups_path = self.saved_models_path / ( f'{self.get_model_name()}_autobackups' )
@ -298,12 +299,12 @@ class ModelBase(object):
def ask_batch_size(self, suggest_batch_size=None, range=None):
    """Interactively ask the user for the batch size and store it.

    The previously saved value (or *suggest_batch_size*, or the current
    ``self.batch_size``) is offered as the default. The answer is floored
    at 0 and, when *range* is given, clipped into ``[range[0], range[1]]``.

    NOTE(review): the ``range`` parameter shadows the builtin, but renaming
    it would break keyword callers, so the name is kept.
    """
    fallback = suggest_batch_size or self.batch_size
    default_batch_size = self.load_or_def_option('batch_size', fallback)

    answer = io.input_int("Batch_size", default_batch_size, valid_range=range, help_message="Larger batch size is better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually.")
    batch_size = max(0, answer)
    if range is not None:
        batch_size = np.clip(batch_size, range[0], range[1])

    self.options['batch_size'] = self.batch_size = batch_size
@ -368,6 +369,11 @@ class ModelBase(object):
def get_static_previews(self):
    """Render previews from the fixed, saved sample batch.

    Uses ``self.sample_for_preview`` so the preview shows the same faces
    every time, making training progress visually comparable.
    """
    fixed_sample = self.sample_for_preview
    return self.onGetPreview(fixed_sample)
def get_preview_history_writer(self):
if self.preview_history_writer is None:
self.preview_history_writer = PreviewHistoryWriter()
return self.preview_history_writer
def save(self):
Path( self.get_summary_path() ).write_text( self.get_summary_text() )
@ -423,10 +429,8 @@ class ModelBase(object):
name, bgr = previews[i]
plist += [ (bgr, idx_backup_path / ( ('preview_%s.jpg') % (name)) ) ]
for preview, filepath in plist:
preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2])
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
cv2_imwrite (filepath, img )
if len(plist) != 0:
self.get_preview_history_writer().post(plist, self.loss_history, self.iter)
def debug_one_iter(self):
images = []
@ -447,6 +451,10 @@ class ModelBase(object):
self.last_sample = sample
return sample
#overridable
def should_save_preview_history(self):
    """Decide whether this iteration should write preview history.

    Overridable. Base policy: every 10 iters locally, every 100 on
    Colab (Colab I/O is slower, so previews are saved less often).
    """
    period = 100 if io.is_colab() else 10
    return self.iter % period == 0
def train_one_iter(self):
iter_time = time.time()
@ -455,8 +463,7 @@ class ModelBase(object):
self.loss_history.append ( [float(loss[1]) for loss in losses] )
if (not io.is_colab() and self.iter % 10 == 0) or \
(io.is_colab() and self.iter % 100 == 0):
if self.should_save_preview_history():
plist = []
if io.is_colab():
@ -470,15 +477,12 @@ class ModelBase(object):
for i in range(len(previews)):
name, bgr = previews[i]
path = self.preview_history_path / name
path.mkdir(parents=True, exist_ok=True)
plist += [ ( bgr, str ( path / ( f'{self.iter:07d}.jpg') ) ) ]
if not io.is_colab():
plist += [ ( bgr, str ( path / ( '_last.jpg' ) )) ]
for preview, filepath in plist:
preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2])
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
cv2_imwrite (filepath, img )
if len(plist) != 0:
self.get_preview_history_writer().post(plist, self.loss_history, self.iter)
self.iter += 1
@ -625,3 +629,41 @@ class ModelBase(object):
lh_img[last_line_t:last_line_b, 0:w] += imagelib.get_text_image ( (last_line_b-last_line_t,w,c), lh_text, color=[0.8]*c )
return lh_img
class PreviewHistoryWriter():
    """Background writer that saves preview images off the training thread.

    Work items are posted via :meth:`post` and consumed by a daemon
    ``multiprocessing.Process``, so encoding/writing JPEGs does not stall
    the training loop.
    """

    def __init__(self):
        self.sq = multiprocessing.Queue()
        self.p = multiprocessing.Process(target=self.process, args=( self.sq, ))
        # Daemon: the writer dies with the main process; queued items may be lost.
        self.p.daemon = True
        self.p.start()

    def process(self, sq):
        """Worker loop: for each posted item, render the loss-history strip,
        stack it above each preview and write the result to disk.

        Fix vs. original: a blocking ``sq.get()`` replaces the
        ``sq.empty()`` poll + ``time.sleep(0.01)`` busy-wait, so the worker
        sleeps in the queue instead of spinning every 10 ms.
        """
        while True:
            # Blocks until an item arrives — no CPU spin while idle.
            plist, loss_history, it = sq.get()

            # Loss-history strip depends only on (h, w); cache per size so
            # multiple previews of equal size render it once.
            preview_lh_cache = {}
            for preview, filepath in plist:
                filepath = Path(filepath)
                size_key = (preview.shape[1], preview.shape[2])

                preview_lh = preview_lh_cache.get(size_key, None)
                if preview_lh is None:
                    preview_lh = ModelBase.get_loss_history_preview(loss_history, it, preview.shape[1], preview.shape[2])
                    preview_lh_cache[size_key] = preview_lh

                img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
                filepath.parent.mkdir(parents=True, exist_ok=True)
                cv2_imwrite (filepath, img )

    def post(self, plist, loss_history, iter):
        """Queue (preview, filepath) pairs for asynchronous writing.

        NOTE: ``iter`` shadows the builtin but is kept for interface
        compatibility with existing keyword callers.
        """
        self.sq.put ( (plist, loss_history, iter) )

    # Disable pickling of the writer itself: when the Process is spawned,
    # ``self`` would otherwise be pickled with the queue/process handles.
    # The worker gets the queue explicitly through ``args``.
    def __getstate__(self):
        return dict()
    def __setstate__(self, d):
        self.__dict__.update(d)