SAEHD: write_preview_history now works faster

The frequency at which the preview is saved now depends on the resolution.
For example 64x64 – every 10 iters. 448x448 – every 70 iters.
This commit is contained in:
Colombo 2020-07-18 11:03:13 +04:00
parent 58670722dc
commit ce95af9068
2 changed files with 65 additions and 19 deletions

View file

@ -1,6 +1,7 @@
import colorsys
import inspect
import json
import multiprocessing
import operator
import os
import pickle
@ -12,12 +13,11 @@ from pathlib import Path
import cv2
import numpy as np
from core import imagelib
from core import imagelib, pathex
from core.cv2ex import *
from core.interact import interact as io
from core.leras import nn
from samplelib import SampleGeneratorBase
from core import pathex
from core.cv2ex import *
class ModelBase(object):
@ -188,6 +188,7 @@ class ModelBase(object):
self.on_initialize()
self.options['batch_size'] = self.batch_size
self.preview_history_writer = None
if self.is_training:
self.preview_history_path = self.saved_models_path / ( f'{self.get_model_name()}_history' )
self.autobackups_path = self.saved_models_path / ( f'{self.get_model_name()}_autobackups' )
@ -368,6 +369,11 @@ class ModelBase(object):
def get_static_previews(self):
return self.onGetPreview (self.sample_for_preview)
def get_preview_history_writer(self):
if self.preview_history_writer is None:
self.preview_history_writer = PreviewHistoryWriter()
return self.preview_history_writer
def save(self):
Path( self.get_summary_path() ).write_text( self.get_summary_text() )
@ -423,10 +429,8 @@ class ModelBase(object):
name, bgr = previews[i]
plist += [ (bgr, idx_backup_path / ( ('preview_%s.jpg') % (name)) ) ]
for preview, filepath in plist:
preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2])
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
cv2_imwrite (filepath, img )
if len(plist) != 0:
self.get_preview_history_writer().post(plist, self.loss_history, self.iter)
def debug_one_iter(self):
images = []
@ -447,6 +451,10 @@ class ModelBase(object):
self.last_sample = sample
return sample
#overridable
def should_save_preview_history(self):
return (not io.is_colab() and self.iter % 10 == 0) or (io.is_colab() and self.iter % 100 == 0)
def train_one_iter(self):
iter_time = time.time()
@ -455,8 +463,7 @@ class ModelBase(object):
self.loss_history.append ( [float(loss[1]) for loss in losses] )
if (not io.is_colab() and self.iter % 10 == 0) or \
(io.is_colab() and self.iter % 100 == 0):
if self.should_save_preview_history():
plist = []
if io.is_colab():
@ -470,15 +477,12 @@ class ModelBase(object):
for i in range(len(previews)):
name, bgr = previews[i]
path = self.preview_history_path / name
path.mkdir(parents=True, exist_ok=True)
plist += [ ( bgr, str ( path / ( f'{self.iter:07d}.jpg') ) ) ]
if not io.is_colab():
plist += [ ( bgr, str ( path / ( '_last.jpg' ) )) ]
for preview, filepath in plist:
preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2])
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
cv2_imwrite (filepath, img )
if len(plist) != 0:
self.get_preview_history_writer().post(plist, self.loss_history, self.iter)
self.iter += 1
@ -625,3 +629,41 @@ class ModelBase(object):
lh_img[last_line_t:last_line_b, 0:w] += imagelib.get_text_image ( (last_line_b-last_line_t,w,c), lh_text, color=[0.8]*c )
return lh_img
class PreviewHistoryWriter():
def __init__(self):
self.sq = multiprocessing.Queue()
self.p = multiprocessing.Process(target=self.process, args=( self.sq, ))
self.p.daemon = True
self.p.start()
def process(self, sq):
while True:
while not sq.empty():
plist, loss_history, iter = sq.get()
preview_lh_cache = {}
for preview, filepath in plist:
filepath = Path(filepath)
i = (preview.shape[1], preview.shape[2])
preview_lh = preview_lh_cache.get(i, None)
if preview_lh is None:
preview_lh = ModelBase.get_loss_history_preview(loss_history, iter, preview.shape[1], preview.shape[2])
preview_lh_cache[i] = preview_lh
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
filepath.parent.mkdir(parents=True, exist_ok=True)
cv2_imwrite (filepath, img )
time.sleep(0.01)
def post(self, plist, loss_history, iter):
self.sq.put ( (plist, loss_history, iter) )
# disable pickling
def __getstate__(self):
return dict()
def __setstate__(self, d):
self.__dict__.update(d)

View file

@ -609,6 +609,10 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
for model, filename in io.progress_bar_generator(self.get_model_filename_list(), "Saving", leave=False):
model.save_weights ( self.get_strpath_storage_for_file(filename) )
#override
def should_save_preview_history(self):
return (not io.is_colab() and self.iter % ( 10*(max(1,self.resolution // 64)) ) == 0) or \
(io.is_colab() and self.iter % 100 == 0)
#override
def onTrainOneIter(self):