diff --git a/models/ModelBase.py b/models/ModelBase.py index bf401f9..5c0c129 100644 --- a/models/ModelBase.py +++ b/models/ModelBase.py @@ -1,6 +1,7 @@ import colorsys import inspect import json +import multiprocessing import operator import os import pickle @@ -12,12 +13,11 @@ from pathlib import Path import cv2 import numpy as np -from core import imagelib +from core import imagelib, pathex +from core.cv2ex import * from core.interact import interact as io from core.leras import nn from samplelib import SampleGeneratorBase -from core import pathex -from core.cv2ex import * class ModelBase(object): @@ -157,7 +157,7 @@ class ModelBase(object): else: self.device_config = nn.DeviceConfig.GPUIndexes( force_gpu_idxs or nn.ask_choose_device_idxs(suggest_best_multi_gpu=True)) \ if not cpu_only else nn.DeviceConfig.CPU() - + nn.initialize(self.device_config) #### @@ -188,6 +188,7 @@ class ModelBase(object): self.on_initialize() self.options['batch_size'] = self.batch_size + self.preview_history_writer = None if self.is_training: self.preview_history_path = self.saved_models_path / ( f'{self.get_model_name()}_history' ) self.autobackups_path = self.saved_models_path / ( f'{self.get_model_name()}_autobackups' ) @@ -298,12 +299,12 @@ class ModelBase(object): def ask_batch_size(self, suggest_batch_size=None, range=None): default_batch_size = self.load_or_def_option('batch_size', suggest_batch_size or self.batch_size) - + batch_size = max(0, io.input_int("Batch_size", default_batch_size, valid_range=range, help_message="Larger batch size is better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually.")) - + if range is not None: batch_size = np.clip(batch_size, range[0], range[1]) - + self.options['batch_size'] = self.batch_size = batch_size @@ -368,6 +369,11 @@ class ModelBase(object): def get_static_previews(self): return self.onGetPreview (self.sample_for_preview) + def get_preview_history_writer(self): + if self.preview_history_writer is None: + self.preview_history_writer = PreviewHistoryWriter() + return self.preview_history_writer + def save(self): Path( self.get_summary_path() ).write_text( self.get_summary_text() ) @@ -423,10 +429,8 @@ class ModelBase(object): name, bgr = previews[i] plist += [ (bgr, idx_backup_path / ( ('preview_%s.jpg') % (name)) ) ] - for preview, filepath in plist: - preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2]) - img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8) - cv2_imwrite (filepath, img ) + if len(plist) != 0: + self.get_preview_history_writer().post(plist, self.loss_history, self.iter) def debug_one_iter(self): images = [] @@ -447,6 +451,10 @@ class ModelBase(object): self.last_sample = sample return sample + #overridable + def should_save_preview_history(self): + return (not io.is_colab() and self.iter % 10 == 0) or (io.is_colab() and self.iter % 100 == 0) + def train_one_iter(self): iter_time = time.time() @@ -455,8 +463,7 @@ class ModelBase(object): self.loss_history.append ( [float(loss[1]) for loss in losses] ) - if (not io.is_colab() and self.iter % 10 == 0) or \ - (io.is_colab() and self.iter % 100 == 0): + if self.should_save_preview_history(): plist = [] if io.is_colab(): @@ -470,15 +477,12 @@ class ModelBase(object): for i in range(len(previews)): name, bgr = previews[i] path = self.preview_history_path / name - path.mkdir(parents=True, exist_ok=True) plist += [ ( bgr, str ( path / ( f'{self.iter:07d}.jpg') ) ) ] if not io.is_colab(): plist += [ ( bgr, str ( path / ( '_last.jpg' ) )) ] - - for preview, filepath in plist: - preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2]) - img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8) - cv2_imwrite (filepath, img ) + + if len(plist) != 0: + self.get_preview_history_writer().post(plist, self.loss_history, self.iter) self.iter += 1 @@ -625,3 +629,41 @@ class ModelBase(object): lh_img[last_line_t:last_line_b, 0:w] += imagelib.get_text_image ( (last_line_b-last_line_t,w,c), lh_text, color=[0.8]*c ) return lh_img + +class PreviewHistoryWriter(): + def __init__(self): + self.sq = multiprocessing.Queue() + self.p = multiprocessing.Process(target=self.process, args=( self.sq, )) + self.p.daemon = True + self.p.start() + + def process(self, sq): + while True: + while not sq.empty(): + plist, loss_history, iter = sq.get() + + preview_lh_cache = {} + for preview, filepath in plist: + filepath = Path(filepath) + i = (preview.shape[1], preview.shape[2]) + + preview_lh = preview_lh_cache.get(i, None) + if preview_lh is None: + preview_lh = ModelBase.get_loss_history_preview(loss_history, iter, preview.shape[1], preview.shape[2]) + preview_lh_cache[i] = preview_lh + + img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8) + + filepath.parent.mkdir(parents=True, exist_ok=True) + cv2_imwrite (filepath, img ) + + time.sleep(0.01) + + def post(self, plist, loss_history, iter): + self.sq.put ( (plist, loss_history, iter) ) + + # disable pickling + def __getstate__(self): + return dict() + def __setstate__(self, d): + self.__dict__.update(d) diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index cba32d7..9164b57 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -609,6 +609,10 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... for model, filename in io.progress_bar_generator(self.get_model_filename_list(), "Saving", leave=False): model.save_weights ( self.get_strpath_storage_for_file(filename) ) + #override + def should_save_preview_history(self): + return (not io.is_colab() and self.iter % ( 10*(max(1,self.resolution // 64)) ) == 0) or \ + (io.is_colab() and self.iter % 100 == 0) #override def onTrainOneIter(self):