SAEHD: write_preview_history now works faster

The frequency at which the preview is saved now depends on the resolution.
For example 64x64 – every 10 iters. 448x448 – every 70 iters.
This commit is contained in:
Colombo 2020-07-18 11:03:13 +04:00
parent 58670722dc
commit ce95af9068
2 changed files with 65 additions and 19 deletions

View file

@ -1,6 +1,7 @@
import colorsys import colorsys
import inspect import inspect
import json import json
import multiprocessing
import operator import operator
import os import os
import pickle import pickle
@ -12,12 +13,11 @@ from pathlib import Path
import cv2 import cv2
import numpy as np import numpy as np
from core import imagelib from core import imagelib, pathex
from core.cv2ex import *
from core.interact import interact as io from core.interact import interact as io
from core.leras import nn from core.leras import nn
from samplelib import SampleGeneratorBase from samplelib import SampleGeneratorBase
from core import pathex
from core.cv2ex import *
class ModelBase(object): class ModelBase(object):
@ -188,6 +188,7 @@ class ModelBase(object):
self.on_initialize() self.on_initialize()
self.options['batch_size'] = self.batch_size self.options['batch_size'] = self.batch_size
self.preview_history_writer = None
if self.is_training: if self.is_training:
self.preview_history_path = self.saved_models_path / ( f'{self.get_model_name()}_history' ) self.preview_history_path = self.saved_models_path / ( f'{self.get_model_name()}_history' )
self.autobackups_path = self.saved_models_path / ( f'{self.get_model_name()}_autobackups' ) self.autobackups_path = self.saved_models_path / ( f'{self.get_model_name()}_autobackups' )
@ -368,6 +369,11 @@ class ModelBase(object):
def get_static_previews(self): def get_static_previews(self):
return self.onGetPreview (self.sample_for_preview) return self.onGetPreview (self.sample_for_preview)
def get_preview_history_writer(self):
if self.preview_history_writer is None:
self.preview_history_writer = PreviewHistoryWriter()
return self.preview_history_writer
def save(self): def save(self):
Path( self.get_summary_path() ).write_text( self.get_summary_text() ) Path( self.get_summary_path() ).write_text( self.get_summary_text() )
@ -423,10 +429,8 @@ class ModelBase(object):
name, bgr = previews[i] name, bgr = previews[i]
plist += [ (bgr, idx_backup_path / ( ('preview_%s.jpg') % (name)) ) ] plist += [ (bgr, idx_backup_path / ( ('preview_%s.jpg') % (name)) ) ]
for preview, filepath in plist: if len(plist) != 0:
preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2]) self.get_preview_history_writer().post(plist, self.loss_history, self.iter)
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
cv2_imwrite (filepath, img )
def debug_one_iter(self): def debug_one_iter(self):
images = [] images = []
@ -447,6 +451,10 @@ class ModelBase(object):
self.last_sample = sample self.last_sample = sample
return sample return sample
#overridable
def should_save_preview_history(self):
return (not io.is_colab() and self.iter % 10 == 0) or (io.is_colab() and self.iter % 100 == 0)
def train_one_iter(self): def train_one_iter(self):
iter_time = time.time() iter_time = time.time()
@ -455,8 +463,7 @@ class ModelBase(object):
self.loss_history.append ( [float(loss[1]) for loss in losses] ) self.loss_history.append ( [float(loss[1]) for loss in losses] )
if (not io.is_colab() and self.iter % 10 == 0) or \ if self.should_save_preview_history():
(io.is_colab() and self.iter % 100 == 0):
plist = [] plist = []
if io.is_colab(): if io.is_colab():
@ -470,15 +477,12 @@ class ModelBase(object):
for i in range(len(previews)): for i in range(len(previews)):
name, bgr = previews[i] name, bgr = previews[i]
path = self.preview_history_path / name path = self.preview_history_path / name
path.mkdir(parents=True, exist_ok=True)
plist += [ ( bgr, str ( path / ( f'{self.iter:07d}.jpg') ) ) ] plist += [ ( bgr, str ( path / ( f'{self.iter:07d}.jpg') ) ) ]
if not io.is_colab(): if not io.is_colab():
plist += [ ( bgr, str ( path / ( '_last.jpg' ) )) ] plist += [ ( bgr, str ( path / ( '_last.jpg' ) )) ]
for preview, filepath in plist: if len(plist) != 0:
preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2]) self.get_preview_history_writer().post(plist, self.loss_history, self.iter)
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
cv2_imwrite (filepath, img )
self.iter += 1 self.iter += 1
@ -625,3 +629,41 @@ class ModelBase(object):
lh_img[last_line_t:last_line_b, 0:w] += imagelib.get_text_image ( (last_line_b-last_line_t,w,c), lh_text, color=[0.8]*c ) lh_img[last_line_t:last_line_b, 0:w] += imagelib.get_text_image ( (last_line_b-last_line_t,w,c), lh_text, color=[0.8]*c )
return lh_img return lh_img
class PreviewHistoryWriter():
def __init__(self):
self.sq = multiprocessing.Queue()
self.p = multiprocessing.Process(target=self.process, args=( self.sq, ))
self.p.daemon = True
self.p.start()
def process(self, sq):
while True:
while not sq.empty():
plist, loss_history, iter = sq.get()
preview_lh_cache = {}
for preview, filepath in plist:
filepath = Path(filepath)
i = (preview.shape[1], preview.shape[2])
preview_lh = preview_lh_cache.get(i, None)
if preview_lh is None:
preview_lh = ModelBase.get_loss_history_preview(loss_history, iter, preview.shape[1], preview.shape[2])
preview_lh_cache[i] = preview_lh
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
filepath.parent.mkdir(parents=True, exist_ok=True)
cv2_imwrite (filepath, img )
time.sleep(0.01)
def post(self, plist, loss_history, iter):
self.sq.put ( (plist, loss_history, iter) )
# disable pickling
def __getstate__(self):
return dict()
def __setstate__(self, d):
self.__dict__.update(d)

View file

@ -609,6 +609,10 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
for model, filename in io.progress_bar_generator(self.get_model_filename_list(), "Saving", leave=False): for model, filename in io.progress_bar_generator(self.get_model_filename_list(), "Saving", leave=False):
model.save_weights ( self.get_strpath_storage_for_file(filename) ) model.save_weights ( self.get_strpath_storage_for_file(filename) )
#override
def should_save_preview_history(self):
return (not io.is_colab() and self.iter % ( 10*(max(1,self.resolution // 64)) ) == 0) or \
(io.is_colab() and self.iter % 100 == 0)
#override #override
def onTrainOneIter(self): def onTrainOneIter(self):