From d8c5e42a5507ad478d0823f665b4d9af01bccdbc Mon Sep 17 00:00:00 2001 From: Jeremy Hummel Date: Wed, 11 Sep 2019 22:04:15 -0700 Subject: [PATCH] Scale previews --- mainscripts/Trainer.py | 108 ++++++++++++++++++++++++++++------------- models/ModelBase.py | 4 +- 2 files changed, 76 insertions(+), 36 deletions(-) diff --git a/mainscripts/Trainer.py b/mainscripts/Trainer.py index e7f9382..00752e1 100644 --- a/mainscripts/Trainer.py +++ b/mainscripts/Trainer.py @@ -3,6 +3,8 @@ import traceback import queue import threading import time +from enum import Enum + import numpy as np import itertools from pathlib import Path @@ -181,48 +183,88 @@ def trainerThread (s2c, c2s, e, args, device_args): c2s.put ( {'op':'close'} ) -def scale_previews(previews, preview_min_height, preview_max_height): - preview_height = max((h for h, w, c in (im.shape for name, im in previews))) +class Zoom(Enum): + ZOOM_25 = (1/4, '25%') + ZOOM_33 = (1/3, '33%') + ZOOM_50 = (1/2, '50%') + ZOOM_67 = (2/3, '67%') + ZOOM_75 = (3/4, '75%') + ZOOM_80 = (4/5, '80%') + ZOOM_90 = (9/10, '90%') + ZOOM_100 = (1, '100%') + ZOOM_110 = (11/10, '110%') + ZOOM_125 = (5/4, '125%') + ZOOM_150 = (3/2, '150%') + ZOOM_175 = (7/4, '175%') + ZOOM_200 = (2, '200%') + ZOOM_250 = (5/2, '250%') + ZOOM_300 = (3, '300%') + ZOOM_400 = (4, '400%') + ZOOM_500 = (5, '500%') - if preview_height > preview_max_height: - preview_height = preview_max_height - elif preview_height < preview_min_height: - preview_height = preview_min_height + def __init__(self, scale, label): + self.scale = scale + self.label = label - # make all previews size equal + def prev(self): + cls = self.__class__ + members = list(cls) + index = members.index(self) - 1 + if index < 0: + return self + return members[index] + + def next(self): + cls = self.__class__ + members = list(cls) + index = members.index(self) + 1 + if index >= len(members): + return self + return members[index] + + +def scale_previews(previews, zoom=Zoom.ZOOM_100): + # Zoom previews for preview in previews[:]: - (preview_name, preview_rgb) = preview - (h, w, c) = preview_rgb.shape - if h != preview_height: - scale_factor = preview_height / float(h) + preview_name, preview_rgb = preview + h, w, c = preview_rgb.shape + scale_factor = zoom.scale * float(h) + if scale_factor < 1: previews.remove(preview) previews.append((preview_name, cv2.resize(preview_rgb, (0, 0), fx=scale_factor, fy=scale_factor, interpolation=cv2.INTER_AREA))) + elif scale_factor > 1: + previews.remove(preview) + previews.append((preview_name, cv2.resize(preview_rgb, (0, 0), + fx=scale_factor, + fy=scale_factor, + interpolation=cv2.INTER_LANCZOS4))) return previews def create_preview_pane_image(previews, selected_preview, loss_history, - show_last_history_iters_count, batch_size): + show_last_history_iters_count, iteration, batch_size, zoom=Zoom.ZOOM_100): + previews = scale_previews(previews, zoom) selected_preview_name = previews[selected_preview][0] selected_preview_rgb = previews[selected_preview][1] - (h,w,c) = selected_preview_rgb.shape + h, w, c = selected_preview_rgb.shape # HEAD head_lines = [ - '[s]:save [enter]:exit', + '[s]:save [enter]:exit [-/+]:zoom: %s' % zoom.label, '[p]:update [space]:next preview [l]:change history range', 'Preview: "%s" [%d/%d]' % (selected_preview_name,selected_preview+1, len(previews) ) ] - head_line_height = 15 + head_line_height = int(15 * zoom.scale) head_height = len(head_lines) * head_line_height - head = np.ones ( (head_height,w,c) ) * 0.1 + head = np.ones((head_height, w, c)) * 0.1 for i in range(0, len(head_lines)): - t = i*head_line_height - b = (i+1)*head_line_height - head[t:b, 0:w] += imagelib.get_text_image ( (head_line_height,w,c) , head_lines[i], color=[0.8]*c ) + t = i * head_line_height + b = (i+1) * head_line_height + head[t:b, 0:w] += imagelib.get_text_image((head_line_height, w, c), head_lines[i], color=[0.8]*c) final = head @@ -231,11 +273,11 @@ def create_preview_pane_image(previews, selected_preview, loss_history, loss_history_to_show = loss_history else: loss_history_to_show = loss_history[-show_last_history_iters_count:] - - lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, iter, batch_size, w, c) + lh_height = int(100 * zoom.scale) + lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, iteration, batch_size, w, c, lh_height) final = np.concatenate ( [final, lh_img], axis=0 ) - final = np.concatenate ( [final, selected_preview_rgb], axis=0 ) + final = np.concatenate([final, selected_preview_rgb], axis=0) final = np.clip(final, 0, 1) return (final*255).astype(np.uint8) @@ -278,10 +320,9 @@ def main(args, device_args): is_showing = False is_waiting_preview = False show_last_history_iters_count = 0 - iter = 0 + iteration = 0 batch_size = 1 - preview_min_height = 512 - preview_max_height = 1024 + zoom = Zoom.ZOOM_100 while True: if not c2s.empty(): @@ -291,23 +332,24 @@ def main(args, device_args): is_waiting_preview = False loss_history = input['loss_history'] if 'loss_history' in input.keys() else None previews = input['previews'] if 'previews' in input.keys() else None - iter = input['iter'] if 'iter' in input.keys() else 0 + iteration = input['iter'] if 'iter' in input.keys() else 0 #batch_size = input['batch_size'] if 'iter' in input.keys() else 1 if previews is not None: - previews = scale_previews(previews, preview_min_height, preview_max_height) - selected_preview = selected_preview % len(previews) update_preview = True elif op == 'close': break if update_preview: update_preview = False + selected_preview = selected_preview % len(previews) preview_pane_image = create_preview_pane_image(previews, selected_preview, loss_history, show_last_history_iters_count, - batch_size) - io.show_image( wnd_name, preview_pane_image) + iteration, + batch_size, + zoom) + io.show_image(wnd_name, preview_pane_image) is_showing = True key_events = io.get_key_events(wnd_name) @@ -337,12 +379,10 @@ def main(args, device_args): selected_preview = (selected_preview + 1) % len(previews) update_preview = True elif key == ord('-'): - # Decrease zoom + zoom = zoom.prev() pass elif key == ord('+'): - # Increase zoom - pass - + zoom = zoom.next() try: io.process_messages(0.1) except KeyboardInterrupt: diff --git a/models/ModelBase.py b/models/ModelBase.py index 208d408..2281d0c 100644 --- a/models/ModelBase.py +++ b/models/ModelBase.py @@ -641,11 +641,11 @@ class ModelBase(object): self.batch_size = d[ keys[-1] ] @staticmethod - def get_loss_history_preview(loss_history, iter,batch_size, w, c): + def get_loss_history_preview(loss_history, iter, batch_size, w, c, lh_height=100): loss_history = np.array (loss_history.copy()) lh_height = 100 - lh_img = np.ones ( (lh_height,w,c) ) * 0.1 + lh_img = np.ones((lh_height, w, c)) * 0.1 if len(loss_history) != 0: loss_count = len(loss_history[0])