From 637bf69eb7e7905084f2afd7e1af18e1be28ae69 Mon Sep 17 00:00:00 2001 From: jh Date: Fri, 13 Sep 2019 16:23:07 -0700 Subject: [PATCH] better flask previews --- mainscripts/FlaskTrainer.py | 458 ++++++++++++++++++++---------------- 1 file changed, 255 insertions(+), 203 deletions(-) diff --git a/mainscripts/FlaskTrainer.py b/mainscripts/FlaskTrainer.py index c4fa732..7df8904 100644 --- a/mainscripts/FlaskTrainer.py +++ b/mainscripts/FlaskTrainer.py @@ -3,8 +3,9 @@ import traceback import queue import threading import time -from io import BytesIO -import base64 +from enum import Enum +from os.path import getmtime + import numpy as np import itertools from pathlib import Path @@ -14,20 +15,20 @@ import cv2 import models from interact import interact as io from flask import Flask, send_file, Response, render_template, render_template_string, request, g -from flask_caching import Cache +# from flask_socketio import SocketIO -def trainerThread(s2c, c2s, e, args, device_args): +def trainerThread (s2c, c2s, e, args, device_args): while True: try: start_time = time.time() - training_data_src_path = Path(args.get('training_data_src_dir', '')) - training_data_dst_path = Path(args.get('training_data_dst_dir', '')) + training_data_src_path = Path( args.get('training_data_src_dir', '') ) + training_data_dst_path = Path( args.get('training_data_dst_dir', '') ) pretraining_data_path = args.get('pretraining_data_dir', '') pretraining_data_path = Path(pretraining_data_path) if pretraining_data_path is not None else None - model_path = Path(args.get('model_path', '')) + model_path = Path( args.get('model_path', '') ) model_name = args.get('model_name', '') save_interval_min = 15 debug = args.get('debug', '') @@ -54,25 +55,24 @@ def trainerThread(s2c, c2s, e, args, device_args): is_reached_goal = model.is_reached_iter_goal() - shared_state = {'after_save': False} + shared_state = { 'after_save' : False } loss_string = "" - save_iter = model.get_iter() - + save_iter = model.get_iter() def model_save(): if not debug and not is_reached_goal: - io.log_info("Saving....", end='\r') + io.log_info ("Saving....", end='\r') model.save() shared_state['after_save'] = True def send_preview(): if not debug: previews = model.get_previews() - c2s.put({'op': 'show', 'previews': previews, 'iter': model.get_iter(), - 'loss_history': model.get_loss_history().copy()}) + c2s.put ( {'op':'show', 'previews': previews, 'iter':model.get_iter(), 'loss_history': model.get_loss_history().copy() } ) else: - previews = [('debug, press update for new', model.debug_one_iter())] - c2s.put({'op': 'show', 'previews': previews}) - e.set() # Set the GUI Thread as Ready + previews = [( 'debug, press update for new', model.debug_one_iter())] + c2s.put ( {'op':'show', 'previews': previews} ) + e.set() #Set the GUI Thread as Ready + if model.is_first_run(): model_save() @@ -81,16 +81,15 @@ def trainerThread(s2c, c2s, e, args, device_args): if is_reached_goal: io.log_info('Model already trained to target iteration. You can use preview.') else: - io.log_info('Starting. Target iteration: %d. Press "Enter" to stop training and save model.' % ( - model.get_target_iter())) + io.log_info('Starting. Target iteration: %d. Press "Enter" to stop training and save model.' % ( model.get_target_iter() ) ) else: io.log_info('Starting. Press "Enter" to stop training and save model.') last_save_time = time.time() - execute_programs = [[x[0], x[1], time.time()] for x in execute_programs] + execute_programs = [ [x[0], x[1], time.time() ] for x in execute_programs ] - for i in itertools.count(0, 1): + for i in itertools.count(0,1): if not debug: cur_time = time.time() @@ -100,7 +99,7 @@ def trainerThread(s2c, c2s, e, args, device_args): if prog_time > 0 and (cur_time - start_time) >= prog_time: x[0] = 0 exec_prog = True - elif prog_time < 0 and (cur_time - last_time) >= -prog_time: + elif prog_time < 0 and (cur_time - last_time) >= -prog_time: x[2] = cur_time exec_prog = True @@ -108,7 +107,7 @@ def trainerThread(s2c, c2s, e, args, device_args): try: exec(prog) except Exception as e: - print("Unable to execute program: %s" % (prog)) + print("Unable to execute program: %s" % (prog) ) if not is_reached_goal: iter, iter_time, batch_size = model.train_one_iter() @@ -116,23 +115,20 @@ def trainerThread(s2c, c2s, e, args, device_args): loss_history = model.get_loss_history() time_str = time.strftime("[%H:%M:%S]") if iter_time >= 10: - loss_string = "{0}[#{1:06d}][{2:.5s}s][bs: {3}]".format(time_str, iter, - '{:0.4f}'.format(iter_time), - batch_size) + loss_string = "{0}[#{1:06d}][{2:.5s}s][bs: {3}]".format ( time_str, iter, '{:0.4f}'.format(iter_time), batch_size ) else: - loss_string = "{0}[#{1:06d}][{2:04d}ms][bs: {3}]".format(time_str, iter, - int(iter_time * 1000), batch_size) + loss_string = "{0}[#{1:06d}][{2:04d}ms][bs: {3}]".format ( time_str, iter, int(iter_time*1000), batch_size) if shared_state['after_save']: shared_state['after_save'] = False - last_save_time = time.time() # upd last_save_time only after save+one_iter, because plaidML rebuilds programs after save https://github.com/plaidml/plaidml/issues/274 + last_save_time = time.time() #upd last_save_time only after save+one_iter, because plaidML rebuilds programs after save https://github.com/plaidml/plaidml/issues/274 - mean_loss = np.mean([np.array(loss_history[i]) for i in range(save_iter, iter)], axis=0) + mean_loss = np.mean ( [ np.array(loss_history[i]) for i in range(save_iter, iter) ], axis=0) for loss_value in mean_loss: loss_string += "[%.4f]" % (loss_value) - io.log_info(loss_string) + io.log_info (loss_string) save_iter = iter else: @@ -140,21 +136,21 @@ def trainerThread(s2c, c2s, e, args, device_args): loss_string += "[%.4f]" % (loss_value) if io.is_colab(): - io.log_info('\r' + loss_string, end='') + io.log_info ('\r' + loss_string, end='') else: - io.log_info(loss_string, end='\r') + io.log_info (loss_string, end='\r') if model.get_target_iter() != 0 and model.is_reached_iter_goal(): - io.log_info('Reached target iteration.') + io.log_info ('Reached target iteration.') model_save() is_reached_goal = True - io.log_info('You can use preview now.') + io.log_info ('You can use preview now.') - if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min * 60: + if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min*60: model_save() send_preview() - if i == 0: + if i==0: if is_reached_goal: model.pass_one_iter() send_preview() @@ -179,169 +175,127 @@ def trainerThread(s2c, c2s, e, args, device_args): if i == -1: break + + model.finalize() except Exception as e: - print('Error: %s' % (str(e))) + print ('Error: %s' % (str(e))) traceback.print_exc() break - c2s.put({'op': 'close'}) + c2s.put ( {'op':'close'} ) -class Preview: - def __init__(self, c2s, s2c, preview_queue): - self.c2s = c2s - self.s2c = s2c - self.preview_queue = preview_queue - # self.wnd_name = "Training preview" - # io.named_window(wnd_name) - # io.capture_keys(wnd_name) +class Zoom(Enum): + ZOOM_25 = (1/4, '25%') + ZOOM_33 = (1/3, '33%') + ZOOM_50 = (1/2, '50%') + ZOOM_67 = (2/3, '67%') + ZOOM_75 = (3/4, '75%') + ZOOM_80 = (4/5, '80%') + ZOOM_90 = (9/10, '90%') + ZOOM_100 = (1, '100%') + ZOOM_110 = (11/10, '110%') + ZOOM_125 = (5/4, '125%') + ZOOM_150 = (3/2, '150%') + ZOOM_175 = (7/4, '175%') + ZOOM_200 = (2, '200%') + ZOOM_250 = (5/2, '250%') + ZOOM_300 = (3, '300%') + ZOOM_400 = (4, '400%') + ZOOM_500 = (5, '500%') - self.previews = None - self.loss_history = None - self.selected_preview = 0 - self.update_preview = False - self.is_showing = False - self.is_waiting_preview = False - self.show_last_history_iters_count = 0 - self.iter = 0 - self.batch_size = 1 - self.preview_min_height = 512 - self.preview_max_height = 1024 - self.close = False + def __init__(self, scale, label): + self.scale = scale + self.label = label - def get_preview(self): - while not self.close: - self.process_queue_items() - self.update_preview_frame() + def prev(self): + cls = self.__class__ + members = list(cls) + index = members.index(self) - 1 + if index < 0: + return self + return members[index] - def process_queue_items(self): - if not self.c2s.empty(): - input = self.c2s.get() - op = input['op'] - if op == 'show': - self.is_waiting_preview = False - self.loss_history = input['loss_history'] if 'loss_history' in input.keys() else None - self.previews = input['previews'] if 'previews' in input.keys() else None - self.iter = input['iter'] if 'iter' in input.keys() else 0 - - if self.previews is not None: - self.resize_previews() - self.selected_preview = self.selected_preview % len(self.previews) - self.update_preview = True - elif op == 'close': - self.close = True - elif op == 'update': - self.update() - elif op == 'next_preview': - self.next_preview() - elif op == 'change_history_range': - self.change_history_range() - - def update_preview_frame(self): - if self.update_preview: - self.update_preview = False - - selected_preview_name = self.previews[self.selected_preview][0] - selected_preview_rgb = self.previews[self.selected_preview][1] - (h, w, c) = selected_preview_rgb.shape - - # HEAD - head_lines = [ - '[s]:save [enter]:exit', - '[p]:update [space]:next preview [l]:change history range', - 'Preview: "%s" [%d/%d]' % (selected_preview_name, self.selected_preview + 1, len(self.previews)) - ] - head_line_height = 15 - head_height = len(head_lines) * head_line_height - head = np.ones((head_height, w, c)) * 0.1 - - for i in range(0, len(head_lines)): - t = i * head_line_height - b = (i + 1) * head_line_height - head[t:b, 0:w] += imagelib.get_text_image((head_line_height, w, c), head_lines[i], color=[0.8] * c) - - final = head - - if self.loss_history is not None: - if self.show_last_history_iters_count == 0: - loss_history_to_show = self.loss_history - else: - loss_history_to_show = self.loss_history[-self.show_last_history_iters_count:] - - lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, self.iter, self.batch_size, w, - c) - final = np.concatenate([final, lh_img], axis=0) - - final = np.concatenate([final, selected_preview_rgb], axis=0) - final = np.clip(final, 0, 1) - preview_pane = (final * 255).astype(np.uint8) - retval, buffer = cv2.imencode('.jpg', preview_pane) - # jpg_as_text = base64.b64encode(buffer) - jpg_as_text = buffer.tostring() - self.preview_queue.put(jpg_as_text) - - def resize_previews(self): - preview_height = max((h for h, w, c in (im.shape for name, im in self.previews))) - if preview_height > self.preview_max_height: - preview_height = self.preview_max_height - elif preview_height < self.preview_min_height: - preview_height = self.preview_min_height - - # make all previews size equal - for p in self.previews[:]: - (preview_name, preview_rgb) = p - (h, w, c) = preview_rgb.shape - if h != preview_height: - scale_factor = preview_height / float(h) - self.previews.remove(p) - self.previews.append((preview_name, cv2.resize(preview_rgb, (0, 0), - fx=scale_factor, - fy=scale_factor, - interpolation=cv2.INTER_AREA))) - self.selected_preview = self.selected_preview % len(self.previews) - - def save(self): - self.s2c.put({'op': 'save'}) - - def exit(self): - self.s2c.put({'op': 'close'}) - - def update(self): - if not self.is_waiting_preview: - self.is_waiting_preview = True - self.s2c.put({'op': 'preview'}) - - def next_preview(self): - self.selected_preview = (self.selected_preview + 1) % len(self.previews) - self.update_preview = True - - def change_history_range(self): - if self.show_last_history_iters_count == 0: - self.show_last_history_iters_count = 5000 - elif self.show_last_history_iters_count == 5000: - self.show_last_history_iters_count = 10000 - elif self.show_last_history_iters_count == 10000: - self.show_last_history_iters_count = 50000 - elif self.show_last_history_iters_count == 50000: - self.show_last_history_iters_count = 100000 - elif self.show_last_history_iters_count == 100000: - self.show_last_history_iters_count = 0 - self.update_preview = True + def next(self): + cls = self.__class__ + members = list(cls) + index = members.index(self) + 1 + if index >= len(members): + return self + return members[index] -def flask_thread(s2c, c2s, preview_queue): - config = { - "DEBUG": True, # some Flask specific configs - "CACHE_TYPE": "simple", # Flask-Caching related configs - "CACHE_DEFAULT_TIMEOUT": 300 - } +def scale_previews(previews, zoom=Zoom.ZOOM_100): + scaled = [] + for preview in previews: + preview_name, preview_rgb = preview + scale_factor = zoom.scale + if scale_factor < 1: + scaled.append((preview_name, cv2.resize(preview_rgb, (0, 0), + fx=scale_factor, + fy=scale_factor, + interpolation=cv2.INTER_AREA))) + elif scale_factor > 1: + scaled.append((preview_name, cv2.resize(preview_rgb, (0, 0), + fx=scale_factor, + fy=scale_factor, + interpolation=cv2.INTER_LANCZOS4))) + else: + scaled.append((preview_name, preview_rgb)) + return scaled + + +def create_preview_pane_image(previews, selected_preview, loss_history, + show_last_history_iters_count, iteration, batch_size, zoom=Zoom.ZOOM_100): + scaled_previews = scale_previews(previews, zoom) + selected_preview_name = scaled_previews[selected_preview][0] + selected_preview_rgb = scaled_previews[selected_preview][1] + h, w, c = selected_preview_rgb.shape + + # HEAD + head_lines = [ + '[s]:save [enter]:exit [-/+]:zoom: %s' % zoom.label, + '[p]:update [space]:next preview [l]:change history range', + 'Preview: "%s" [%d/%d]' % (selected_preview_name,selected_preview+1, len(previews)) + ] + head_line_height = int(15 * zoom.scale) + head_height = len(head_lines) * head_line_height + head = np.ones((head_height, w, c)) * 0.1 + + for i in range(0, len(head_lines)): + t = i * head_line_height + b = (i+1) * head_line_height + head[t:b, 0:w] += imagelib.get_text_image((head_line_height, w, c), head_lines[i], color=[0.8]*c) + + final = head + + if loss_history is not None: + if show_last_history_iters_count == 0: + loss_history_to_show = loss_history + else: + loss_history_to_show = loss_history[-show_last_history_iters_count:] + lh_height = int(100 * zoom.scale) + lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, iteration, batch_size, w, c, lh_height) + final = np.concatenate ( [final, lh_img], axis=0 ) + + final = np.concatenate([final, selected_preview_rgb], axis=0) + final = np.clip(final, 0, 1) + return (final*255).astype(np.uint8) + + +def flask_thread(s2c, c2s, s2flask, args): + # config = { + # "DEBUG": True, # some Flask specific configs + # "CACHE_TYPE": "simple", # Flask-Caching related configs + # "CACHE_DEFAULT_TIMEOUT": 300 + # } app = Flask(__name__) - app.config.from_mapping(config) - cache = Cache(app) + # app.config.from_mapping(config) + # cache = Cache(app) template = """ - Video Streaming Demonstration + Flask Server Demonstration

Video Streaming Demonstration

@@ -357,8 +311,17 @@ def flask_thread(s2c, c2s, preview_queue): """ def gen(): - if not preview_queue.empty(): - frame = preview_queue.get() + model_path = Path(args.get('model_path', '')) + print('[MainThread]', 'model_path:', model_path) + filename = 'preview.jpg' + preview_file = str(model_path / filename) + print('[MainThread]', 'preview_file:', preview_file) + frame = open(preview_file, 'rb').read() + while True: + try: + frame = open(preview_file, 'rb').read() + except: + pass yield b'--frame\r\nContent-Type: image/jpeg\r\n\r\n' yield frame yield b'\r\n\r\n' @@ -371,46 +334,135 @@ def flask_thread(s2c, c2s, preview_queue): elif 'exit' in request.form: s2c.put({'op': 'close'}) elif 'update' in request.form: + while not s2flask.empty(): + input = s2flask.get() c2s.put({'op': 'update'}) + while s2flask.empty(): + pass + input = s2flask.get() elif 'next_preview' in request.form: - c2s.put({'op': 'preview'}) + while not s2flask.empty(): + input = s2flask.get() + c2s.put({'op': 'next_preview'}) + while s2flask.empty(): + pass + input = s2flask.get() elif 'change_history_range' in request.form: + while not s2flask.empty(): + input = s2flask.get() c2s.put({'op': 'change_history_range'}) + while s2flask.empty(): + pass + input = s2flask.get() + # return '', 204 return render_template_string(template) - def queue_not_empty(): - return not preview_queue.empty() - # @app.route('/preview_image') - # @cache.cached(timeout=300, unless=queue_not_empty) # def preview_image(): - # yield Response(preview_queue.get(), - # mimetype='multipart/x-mixed-replace;boundary=frame') + # return Response(gen(), mimetype='multipart/x-mixed-replace;boundary=frame') @app.route('/preview_image') - @cache.cached(timeout=300, unless=queue_not_empty) def preview_image(): - return Response(preview_queue.get(), mimetype='image/jpeg') + model_path = Path(args.get('model_path', '')) + filename = 'preview.jpg' + preview_file = str(model_path / filename) + return send_file(preview_file, mimetype='image/jpeg', cache_timeout=-1) - app.run(debug=True, use_reloader=False) + app.run(debug=False, use_reloader=False) def main(args, device_args): - io.log_info("Running trainer.\r\n") + io.log_info ("Running trainer.\r\n") + + no_preview = args.get('no_preview', False) + s2c = queue.Queue() c2s = queue.Queue() - preview_queue = queue.Queue() + s2flask = queue.Queue() e = threading.Event() - thread = threading.Thread(target=trainerThread, args=(s2c, c2s, e, args, device_args)) + thread = threading.Thread(target=trainerThread, args=(s2c, c2s, e, args, device_args) ) thread.start() - e.wait() # Wait for inital load to occur. + e.wait() #Wait for inital load to occur. - flask_t = threading.Thread(target=flask_thread, args=(s2c, c2s, preview_queue)) + flask_t = threading.Thread(target=flask_thread, args=(s2c, c2s, s2flask, args)) flask_t.start() - preview = Preview(c2s, s2c, preview_queue) - preview.get_preview() + wnd_name = "Training preview" + io.named_window(wnd_name) + io.capture_keys(wnd_name) + + previews = None + loss_history = None + selected_preview = 0 + update_preview = False + is_showing = False + is_waiting_preview = False + show_last_history_iters_count = 0 + iteration = 0 + batch_size = 1 + zoom = Zoom.ZOOM_100 + + while True: + if not c2s.empty(): + input = c2s.get() + op = input['op'] + if op == 'show': + is_waiting_preview = False + loss_history = input['loss_history'] if 'loss_history' in input.keys() else None + previews = input['previews'] if 'previews' in input.keys() else None + iteration = input['iter'] if 'iter' in input.keys() else 0 + #batch_size = input['batch_size'] if 'iter' in input.keys() else 1 + if previews is not None: + update_preview = True + elif op == 'update': + if not is_waiting_preview: + is_waiting_preview = True + s2c.put({'op': 'preview'}) + elif op == 'next_preview': + selected_preview = (selected_preview + 1) % len(previews) + update_preview = True + elif op == 'change_history_range': + if show_last_history_iters_count == 0: + show_last_history_iters_count = 5000 + elif show_last_history_iters_count == 5000: + show_last_history_iters_count = 10000 + elif show_last_history_iters_count == 10000: + show_last_history_iters_count = 50000 + elif show_last_history_iters_count == 50000: + show_last_history_iters_count = 100000 + elif show_last_history_iters_count == 100000: + show_last_history_iters_count = 0 + update_preview = True + + if update_preview: + update_preview = False + selected_preview = selected_preview % len(previews) + preview_pane_image = create_preview_pane_image(previews, + selected_preview, + loss_history, + show_last_history_iters_count, + iteration, + batch_size, + zoom) + # io.show_image(wnd_name, preview_pane_image) + model_path = Path(args.get('model_path', '')) + filename = 'preview.jpg' + preview_file = str(model_path / filename) + cv2.imwrite(preview_file, preview_pane_image) + s2flask.put({'op': 'show'}) + # socketio.emit('some event', {'data': 42}) + + # cv2.imshow(wnd_name, preview_pane_image) + is_showing = True + try: + io.process_messages(0.01) + except KeyboardInterrupt: + s2c.put({'op': 'close'}) + + io.destroy_all_windows() + +