diff --git a/flaskr/__init__.py b/flaskr/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/flaskr/app.py b/flaskr/app.py
new file mode 100644
index 0000000..32a4f30
--- /dev/null
+++ b/flaskr/app.py
@@ -0,0 +1,102 @@
+from pathlib import Path
+
+from flask import Flask, send_file, Response, render_template, render_template_string, request, g
+from flask_socketio import SocketIO, emit
+import logging
+
+
+def create_flask_app(s2c, c2s, s2flask, kwargs):
+    app = Flask(__name__, template_folder="templates", static_folder="static")
+    log = logging.getLogger('werkzeug')
+    log.disabled = True
+    model_path = Path(kwargs.get('saved_models_path', ''))
+    filename = 'preview.jpg'
+    preview_file = str(model_path / filename)
+
+    def gen():
+        frame = open(preview_file, 'rb').read()
+        while True:
+            try:
+                frame = open(preview_file, 'rb').read()
+            except OSError:  # preview.jpg may be mid-write; keep the last good frame
+                pass
+            yield b'--frame\r\nContent-Type: image/jpeg\r\n\r\n'
+            yield frame
+            yield b'\r\n\r\n'
+
+    def send(queue, op):
+        queue.put({'op': op})
+
+    def send_and_wait(queue, op):
+        while not s2flask.empty():
+            s2flask.get()
+        queue.put({'op': op})
+        while s2flask.empty():
+            pass
+        s2flask.get()
+
+    @app.route('/save', methods=['POST'])
+    def save():
+        send(s2c, 'save')
+        return '', 204
+
+    @app.route('/exit', methods=['POST'])
+    def exit():
+        send(c2s, 'close')
+        request.environ.get('werkzeug.server.shutdown')()
+        return '', 204
+
+    @app.route('/update', methods=['POST'])
+    def update():
+        send(c2s, 'update')
+        return '', 204
+
+    @app.route('/next_preview', methods=['POST'])
+    def next_preview():
+        send(c2s, 'next_preview')
+        return '', 204
+
+    @app.route('/change_history_range', methods=['POST'])
+    def change_history_range():
+        send(c2s, 'change_history_range')
+        return '', 204
+
+    @app.route('/zoom_prev', methods=['POST'])
+    def zoom_prev():
+        send(c2s, 'zoom_prev')
+        return '', 204
+
+    @app.route('/zoom_next', methods=['POST'])
+    def zoom_next():
+        send(c2s, 'zoom_next')
+        return '', 204
+
+    @app.route('/')
+    def index():
+        return render_template('index.html')
+
+    # @app.route('/preview_image')
+    # def preview_image():
+    #     return Response(gen(), mimetype='multipart/x-mixed-replace;boundary=frame')
+
+    @app.route('/preview_image')
+    def preview_image():
+        return send_file(preview_file, mimetype='image/jpeg', cache_timeout=-1)
+
+    socketio = SocketIO(app)
+
+    @socketio.on('connect', namespace='/')
+    def test_connect():
+        emit('my response', {'data': 'Connected'})
+
+    @socketio.on('disconnect', namespace='/test')
+    def test_disconnect():
+        print('Client disconnected')
+
+    return socketio, app
diff --git a/flaskr/static/favicon.ico b/flaskr/static/favicon.ico
new file mode 100644
index 0000000..46aec07
Binary files /dev/null and b/flaskr/static/favicon.ico differ
diff --git a/flaskr/templates/index.html b/flaskr/templates/index.html
new file mode 100644
index 0000000..d97326f
--- /dev/null
+++ b/flaskr/templates/index.html
@@ -0,0 +1,94 @@
+<!-- [The body of this 94-line template was lost when the patch text was extracted
+      (its HTML markup was stripped); only the visible text "Training Preview"
+      (page title and heading) survives. The page presumably shows /preview_image
+      and wires its controls to the POST routes defined in flaskr/app.py, with a
+      Socket.IO client for the 'loss' and 'preview' events.] -->
diff --git a/main.py b/main.py
index 2ba085f..c372e14 100644
--- a/main.py
+++ b/main.py
@@ -23,7 +23,7 @@ if __name__ == "__main__":
             setattr(namespace, self.dest, os.path.abspath(os.path.expanduser(values)))
 
     exit_code = 0
-    
+
     parser = argparse.ArgumentParser()
     subparsers = parser.add_subparsers()
 
@@ -52,9 +52,9 @@ if __name__ == "__main__":
     p.add_argument('--output-debug', action="store_true", dest="output_debug", default=None, help="Writes debug images to _debug\ directory.")
     p.add_argument('--no-output-debug', action="store_false", dest="output_debug", default=None, help="Don't writes debug images to _debug\ directory.")
     p.add_argument('--face-type', dest="face_type", choices=['half_face', 'full_face', 'whole_face', 'head', 'mark_only'], default=None)
-    p.add_argument('--max-faces-from-image', type=int, dest="max_faces_from_image", default=None, help="Max faces from image.")  
+    p.add_argument('--max-faces-from-image', type=int, dest="max_faces_from_image", default=None, help="Max faces from image.")
     p.add_argument('--image-size', type=int, dest="image_size", default=None, help="Output image size.")
-    p.add_argument('--jpeg-quality', type=int, dest="jpeg_quality", default=None, help="Jpeg quality.")  
+    p.add_argument('--jpeg-quality', type=int, dest="jpeg_quality", default=None, help="Jpeg quality.")
     p.add_argument('--manual-fix', action="store_true", dest="manual_fix", default=False, help="Enables manual extract only frames where faces were not recognized.")
     p.add_argument('--manual-output-debug-fix', action="store_true", dest="manual_output_debug_fix", default=False, help="Performs manual reextract input-dir frames which were deleted from [output_dir]_debug\ dir.")
     p.add_argument('--manual-window-size', type=int, dest="manual_window_size", default=1368, help="Manual fix window size. Default: 1368.")
@@ -127,6 +127,7 @@ if __name__ == "__main__":
                     'silent_start'     : arguments.silent_start,
                     'execute_programs' : [ [int(x[0]), x[1] ] for x in arguments.execute_program ],
                     'debug'            : arguments.debug,
+                    'flask_preview'    : arguments.flask_preview,
                     }
         from mainscripts import Trainer
         Trainer.main(**kwargs)
@@ -144,8 +145,9 @@ if __name__ == "__main__":
     p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Train on CPU.")
     p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.")
     p.add_argument('--silent-start', action="store_true", dest="silent_start", default=False, help="Silent start. Automatically chooses Best GPU and last used model.")
-
-    
+    p.add_argument('--flask-preview', action="store_true", dest="flask_preview", default=False,
+                   help="Launches a Flask server to view the training previews in a web browser.")
+
     p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+')
     p.set_defaults (func=process_train)
 
@@ -252,7 +254,7 @@ if __name__ == "__main__":
     p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.")
 
     p.set_defaults(func=process_faceset_enhancer)
-    
+
     def process_dev_test(arguments):
        osex.set_process_lowest_prio()
        from mainscripts import dev_misc
@@ -261,10 +263,10 @@ if __name__ == "__main__":
     p = subparsers.add_parser( "dev_test", help="")
     p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
     p.set_defaults (func=process_dev_test)
-    
+
     # ========== XSeg
     xseg_parser = subparsers.add_parser( "xseg", help="XSeg tools.").add_subparsers()
-    
+
     p = xseg_parser.add_parser( "editor", help="XSeg editor.")
 
     def process_xsegeditor(arguments):
@@ -272,11 +274,11 @@ if __name__ == "__main__":
         from XSegEditor import XSegEditor
         global exit_code
         exit_code = XSegEditor.start (Path(arguments.input_dir))
-        
+
     p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
 
     p.set_defaults (func=process_xsegeditor)
-    
+
     p = xseg_parser.add_parser( "apply", help="Apply trained XSeg model to the extracted faces.")
 
     def process_xsegapply(arguments):
@@ -286,8 +288,8 @@ if __name__ == "__main__":
     p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
     p.add_argument('--model-dir', required=True, action=fixPathAction, dest="model_dir")
     p.set_defaults (func=process_xsegapply)
-    
-    
+
+
     p = xseg_parser.add_parser( "remove", help="Remove applied XSeg masks from the extracted faces.")
 
     def process_xsegremove(arguments):
@@ -295,8 +297,8 @@ if __name__ == "__main__":
         osex.set_process_lowest_prio()
         from mainscripts import XSegUtil
         XSegUtil.remove_xseg (Path(arguments.input_dir) )
     p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
     p.set_defaults (func=process_xsegremove)
-    
-    
+
+
     p = xseg_parser.add_parser( "remove_labels", help="Remove XSeg labels from the extracted faces.")
 
     def process_xsegremovelabels(arguments):
@@ -304,8 +306,8 @@ if __name__ == "__main__":
         osex.set_process_lowest_prio()
         from mainscripts import XSegUtil
         XSegUtil.remove_xseg_labels (Path(arguments.input_dir) )
     p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
     p.set_defaults (func=process_xsegremovelabels)
-    
-    
+
+
     p = xseg_parser.add_parser( "fetch", help="Copies faces containing XSeg polygons in _xseg dir.")
 
     def process_xsegfetch(arguments):
@@ -314,7 +316,7 @@ if __name__ == "__main__":
         osex.set_process_lowest_prio()
         from mainscripts import XSegUtil
         XSegUtil.fetch_xseg (Path(arguments.input_dir) )
     p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
     p.set_defaults (func=process_xsegfetch)
-    
+
     def bad_args(arguments):
         parser.print_help()
         exit(0)
@@ -325,9 +327,9 @@ if __name__ == "__main__":
     parser.set_defaults(func=bad_args)
 
     arguments = parser.parse_args()
     arguments.func(arguments)
 
     if exit_code == 0:
         print ("Done.")
-        
+
     exit(exit_code)
-    
+
 '''
 import code
 code.interact(local=dict(globals(), **locals()))
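The control surface this adds is small: every route in flaskr/app.py just drops an op onto one of the queues that Trainer.main() already polls (s2c feeds the trainer thread, c2s feeds the preview loop, and s2flask is written by the preview loop after each preview.jpg refresh). A quick way to poke it by hand is sketched below; none of this snippet is part of the patch, and it assumes the flask-socketio development server is reachable at its default 127.0.0.1:5000.

    # Hypothetical smoke test for the new control routes (not part of the patch).
    import urllib.request

    BASE = 'http://127.0.0.1:5000'

    def post(route):
        # every control route accepts an empty POST body and answers 204 No Content
        req = urllib.request.Request(BASE + route, data=b'', method='POST')
        with urllib.request.urlopen(req) as resp:
            return resp.status

    post('/update')      # ask the trainer for a fresh preview (c2s 'update' -> s2c 'preview')
    post('/zoom_next')   # step the preview pane one Zoom level up
    post('/save')        # same effect as pressing [s] in the desktop preview window

    # the rendered preview pane is also served as a plain JPEG
    with urllib.request.urlopen(BASE + '/preview_image') as resp:
        open('preview_snapshot.jpg', 'wb').write(resp.read())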
diff --git a/mainscripts/Trainer.py b/mainscripts/Trainer.py
index 7d73e2f..507bc9c 100644
--- a/mainscripts/Trainer.py
+++ b/mainscripts/Trainer.py
@@ -4,6 +4,8 @@ import traceback
 import queue
 import threading
 import time
+from enum import Enum
+
 import numpy as np
 import itertools
 from pathlib import Path
@@ -13,21 +15,23 @@ import cv2
 import models
 from core.interact import interact as io
 
-def trainerThread (s2c, c2s, e,
-                    model_class_name = None,
-                    saved_models_path = None,
-                    training_data_src_path = None,
-                    training_data_dst_path = None,
-                    pretraining_data_path = None,
-                    pretrained_model_path = None,
-                    no_preview=False,
-                    force_model_name=None,
-                    force_gpu_idxs=None,
-                    cpu_only=None,
-                    silent_start=False,
-                    execute_programs = None,
-                    debug=False,
-                    **kwargs):
+
+def trainerThread(s2c, c2s, e,
+                  socketio=None,
+                  model_class_name=None,
+                  saved_models_path=None,
+                  training_data_src_path=None,
+                  training_data_dst_path=None,
+                  pretraining_data_path=None,
+                  pretrained_model_path=None,
+                  no_preview=False,
+                  force_model_name=None,
+                  force_gpu_idxs=None,
+                  cpu_only=None,
+                  silent_start=False,
+                  execute_programs=None,
+                  debug=False,
+                  **kwargs):
     while True:
         try:
             start_time = time.time()
@@ -44,67 +48,70 @@ def trainerThread (s2c, c2s, e,
                 saved_models_path.mkdir(exist_ok=True, parents=True)
 
             model = models.import_model(model_class_name)(
-                        is_training=True,
-                        saved_models_path=saved_models_path,
-                        training_data_src_path=training_data_src_path,
-                        training_data_dst_path=training_data_dst_path,
-                        pretraining_data_path=pretraining_data_path,
-                        pretrained_model_path=pretrained_model_path,
-                        no_preview=no_preview,
-                        force_model_name=force_model_name,
-                        force_gpu_idxs=force_gpu_idxs,
-                        cpu_only=cpu_only,
-                        silent_start=silent_start,
-                        debug=debug,
-                        )
+                is_training=True,
+                saved_models_path=saved_models_path,
+                training_data_src_path=training_data_src_path,
+                training_data_dst_path=training_data_dst_path,
+                pretraining_data_path=pretraining_data_path,
+                pretrained_model_path=pretrained_model_path,
+                no_preview=no_preview,
+                force_model_name=force_model_name,
+                force_gpu_idxs=force_gpu_idxs,
+                cpu_only=cpu_only,
+                silent_start=silent_start,
+                debug=debug,
+            )
 
             is_reached_goal = model.is_reached_iter_goal()
 
-            shared_state = { 'after_save' : False }
+            shared_state = {'after_save': False}
             loss_string = ""
-            save_iter = model.get_iter() 
+            save_iter = model.get_iter()
+
             def model_save():
                 if not debug and not is_reached_goal:
-                    io.log_info ("Saving....", end='\r')
+                    io.log_info("Saving....", end='\r')
                     model.save()
                     shared_state['after_save'] = True
-                
+
             def model_backup():
                 if not debug and not is_reached_goal:
-                    model.create_backup() 
+                    model.create_backup()
 
             def send_preview():
                 if not debug:
                     previews = model.get_previews()
-                    c2s.put ( {'op':'show', 'previews': previews, 'iter':model.get_iter(), 'loss_history': model.get_loss_history().copy() } )
+                    c2s.put({'op': 'show', 'previews': previews, 'iter': model.get_iter(),
+                             'loss_history': model.get_loss_history().copy()})
                 else:
-                    previews = [( 'debug, press update for new', model.debug_one_iter())]
-                    c2s.put ( {'op':'show', 'previews': previews} )
-                e.set() #Set the GUI Thread as Ready
+                    previews = [('debug, press update for new', model.debug_one_iter())]
+                    c2s.put({'op': 'show', 'previews': previews})
+                e.set()  # Set the GUI Thread as Ready
 
             if model.get_target_iter() != 0:
                 if is_reached_goal:
                     io.log_info('Model already trained to target iteration. You can use preview.')
                 else:
-                    io.log_info('Starting. Target iteration: %d. Press "Enter" to stop training and save model.' % ( model.get_target_iter() ) )
+                    io.log_info('Starting. Target iteration: %d. Press "Enter" to stop training and save model.' % (
+                        model.get_target_iter()))
             else:
                 io.log_info('Starting. Press "Enter" to stop training and save model.')
 
             last_save_time = time.time()
 
-            execute_programs = [ [x[0], x[1], time.time() ] for x in execute_programs ]
+            execute_programs = [[x[0], x[1], time.time()] for x in execute_programs]
 
-            for i in itertools.count(0,1):
+            for i in itertools.count(0, 1):
                 if not debug:
                     cur_time = time.time()
 
                     for x in execute_programs:
                         prog_time, prog, last_time = x
                         exec_prog = False
-                        if prog_time > 0 and (cur_time - start_time) >= prog_time:
+                        if 0 < prog_time <= (cur_time - start_time):
                             x[0] = 0
                             exec_prog = True
-                        elif prog_time < 0 and (cur_time - last_time) >= -prog_time: 
+                        elif prog_time < 0 and (cur_time - last_time) >= -prog_time:
                             x[2] = cur_time
                             exec_prog = True
@@ -112,18 +119,20 @@ def trainerThread (s2c, c2s, e,
                             try:
                                 exec(prog)
                             except Exception as e:
-                                print("Unable to execute program: %s" % (prog) )
+                                print("Unable to execute program: %s" % prog)
 
                 if not is_reached_goal:
 
                     if model.get_iter() == 0:
                         io.log_info("")
-                        io.log_info("Trying to do the first iteration. If an error occurs, reduce the model parameters.")
+                        io.log_info(
+                            "Trying to do the first iteration. If an error occurs, reduce the model parameters.")
                         io.log_info("")
-                        
+
                         if sys.platform[0:3] == 'win':
                             io.log_info("!!!")
-                            io.log_info("Windows 10 users IMPORTANT notice. You should set this setting in order to work correctly.")
+                            io.log_info(
+                                "Windows 10 users IMPORTANT notice. You should set this setting in order to work correctly.")
                             io.log_info("https://i.imgur.com/B7cmDCB.jpg")
                             io.log_info("!!!")
@@ -132,19 +141,19 @@ def trainerThread (s2c, c2s, e,
                     loss_history = model.get_loss_history()
                     time_str = time.strftime("[%H:%M:%S]")
                     if iter_time >= 10:
-                        loss_string = "{0}[#{1:06d}][{2:.5s}s]".format ( time_str, iter, '{:0.4f}'.format(iter_time) )
+                        loss_string = "{0}[#{1:06d}][{2:.5s}s]".format(time_str, iter, '{:0.4f}'.format(iter_time))
                     else:
-                        loss_string = "{0}[#{1:06d}][{2:04d}ms]".format ( time_str, iter, int(iter_time*1000) )
+                        loss_string = "{0}[#{1:06d}][{2:04d}ms]".format(time_str, iter, int(iter_time * 1000))
 
                     if shared_state['after_save']:
                         shared_state['after_save'] = False
-                        
-                        mean_loss = np.mean ( loss_history[save_iter:iter], axis=0)
+
+                        mean_loss = np.mean(loss_history[save_iter:iter], axis=0)
 
                         for loss_value in mean_loss:
                             loss_string += "[%.4f]" % (loss_value)
 
-                        io.log_info (loss_string)
+                        io.log_info(loss_string)
 
                         save_iter = iter
                     else:
@@ -152,25 +161,28 @@ def trainerThread (s2c, c2s, e,
                             loss_string += "[%.4f]" % (loss_value)
 
                     if io.is_colab():
-                        io.log_info ('\r' + loss_string, end='')
+                        io.log_info('\r' + loss_string, end='')
                     else:
-                        io.log_info (loss_string, end='\r')
+                        io.log_info(loss_string, end='\r')
+
+                    if socketio is not None:
+                        socketio.emit('loss', loss_string)
 
                     if model.get_iter() == 1:
                         model_save()
 
                     if model.get_target_iter() != 0 and model.is_reached_iter_goal():
-                        io.log_info ('Reached target iteration.')
+                        io.log_info('Reached target iteration.')
                         model_save()
                         is_reached_goal = True
-                        io.log_info ('You can use preview now.')
+                        io.log_info('You can use preview now.')
 
-                if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min*60:
-                    last_save_time += save_interval_min*60
+                if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min * 60:
+                    last_save_time += save_interval_min * 60
                     model_save()
                     send_preview()
 
-                if i==0:
+                if i == 0:
                     if is_reached_goal:
                         model.pass_one_iter()
                     send_preview()
@@ -179,8 +191,8 @@ def trainerThread (s2c, c2s, e,
                     time.sleep(0.005)
 
                 while not s2c.empty():
-                    input = s2c.get()
-                    op = input['op']
+                    item = s2c.get()
+                    op = item['op']
                     if op == 'save':
                         model_save()
                     elif op == 'backup':
@@ -197,43 +209,227 @@ def trainerThread (s2c, c2s, e,
 
                 if i == -1:
                     break
-                
-                
+
+
             model.finalize()
 
         except Exception as e:
-            print ('Error: %s' % (str(e)))
+            print('Error: %s' % (str(e)))
             traceback.print_exc()
             break
 
-    c2s.put ( {'op':'close'} )
+    c2s.put({'op': 'close'})
+
+
+class Zoom(Enum):
+    ZOOM_25 = (1 / 4, '25%')
+    ZOOM_33 = (1 / 3, '33%')
+    ZOOM_50 = (1 / 2, '50%')
+    ZOOM_67 = (2 / 3, '67%')
+    ZOOM_75 = (3 / 4, '75%')
+    ZOOM_80 = (4 / 5, '80%')
+    ZOOM_90 = (9 / 10, '90%')
+    ZOOM_100 = (1, '100%')
+    ZOOM_110 = (11 / 10, '110%')
+    ZOOM_125 = (5 / 4, '125%')
+    ZOOM_150 = (3 / 2, '150%')
+    ZOOM_175 = (7 / 4, '175%')
+    ZOOM_200 = (2, '200%')
+    ZOOM_250 = (5 / 2, '250%')
+    ZOOM_300 = (3, '300%')
+    ZOOM_400 = (4, '400%')
+    ZOOM_500 = (5, '500%')
+
+    def __init__(self, scale, label):
+        self.scale = scale
+        self.label = label
+
+    def prev(self):
+        cls = self.__class__
+        members = list(cls)
+        index = members.index(self) - 1
+        if index < 0:
+            return self
+        return members[index]
+
+    def next(self):
+        cls = self.__class__
+        members = list(cls)
+        index = members.index(self) + 1
+        if index >= len(members):
+            return self
+        return members[index]
+
+
+def scale_previews(previews, zoom=Zoom.ZOOM_100):
+    scaled = []
+    for preview in previews:
+        preview_name, preview_rgb = preview
+        scale_factor = zoom.scale
+        if scale_factor < 1:
+            scaled.append((preview_name, cv2.resize(preview_rgb, (0, 0),
+                                                    fx=scale_factor,
+                                                    fy=scale_factor,
+                                                    interpolation=cv2.INTER_AREA)))
+        elif scale_factor > 1:
+            scaled.append((preview_name, cv2.resize(preview_rgb, (0, 0),
+                                                    fx=scale_factor,
+                                                    fy=scale_factor,
+                                                    interpolation=cv2.INTER_LANCZOS4)))
+        else:
+            scaled.append((preview_name, preview_rgb))
+    return scaled
+
+
+def create_preview_pane_image(previews, selected_preview, loss_history,
+                              show_last_history_iters_count, iteration, batch_size, zoom=Zoom.ZOOM_100):
+    scaled_previews = scale_previews(previews, zoom)
+    selected_preview_name = scaled_previews[selected_preview][0]
+    selected_preview_rgb = scaled_previews[selected_preview][1]
+    h, w, c = selected_preview_rgb.shape
+
+    # HEAD
+    head_lines = [
+        '[s]:save [enter]:exit [-/+]:zoom: %s' % zoom.label,
+        '[p]:update [space]:next preview [l]:change history range',
+        'Preview: "%s" [%d/%d]' % (selected_preview_name, selected_preview + 1, len(previews))
+    ]
+    head_line_height = int(15 * zoom.scale)
+    head_height = len(head_lines) * head_line_height
+    head = np.ones((head_height, w, c)) * 0.1
+
+    for i in range(0, len(head_lines)):
+        t = i * head_line_height
+        b = (i + 1) * head_line_height
+        head[t:b, 0:w] += imagelib.get_text_image((head_line_height, w, c), head_lines[i], color=[0.8] * c)
+
+    final = head
+
+    if loss_history is not None:
+        if show_last_history_iters_count == 0:
+            loss_history_to_show = loss_history
+        else:
+            loss_history_to_show = loss_history[-show_last_history_iters_count:]
+        lh_height = int(100 * zoom.scale)
+        lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, iteration, w, c, lh_height)
+        final = np.concatenate([final, lh_img], axis=0)
+
+    final = np.concatenate([final, selected_preview_rgb], axis=0)
+    final = np.clip(final, 0, 1)
+    return (final * 255).astype(np.uint8)
 
 
 def main(**kwargs):
-    io.log_info ("Running trainer.\r\n")
+    io.log_info("Running trainer.\r\n")
 
     no_preview = kwargs.get('no_preview', False)
+    flask_preview = kwargs.get('flask_preview', False)
 
     s2c = queue.Queue()
     c2s = queue.Queue()
 
     e = threading.Event()
 
-    thread = threading.Thread(target=trainerThread, args=(s2c, c2s, e), kwargs=kwargs )
-    thread.start()
-
-    e.wait() #Wait for inital load to occur.
+    previews = None
+    loss_history = None
+    selected_preview = 0
+    update_preview = False
+    is_waiting_preview = False
+    show_last_history_iters_count = 0
+    iteration = 0
+    batch_size = 1
+    zoom = Zoom.ZOOM_100
+
+    if flask_preview:
+        from flaskr.app import create_flask_app
+        s2flask = queue.Queue()
+        socketio, flask_app = create_flask_app(s2c, c2s, s2flask, kwargs)
+
+        thread = threading.Thread(target=trainerThread, args=(s2c, c2s, e, socketio), kwargs=kwargs)
+        thread.start()
+
+        e.wait()  # Wait for initial load to occur.
+
+        flask_t = threading.Thread(target=socketio.run, args=(flask_app,),
+                                   kwargs={'debug': True, 'use_reloader': False})
+        flask_t.start()
+
+        while True:
+            if not c2s.empty():
+                item = c2s.get()
+                op = item['op']
+                if op == 'show':
+                    is_waiting_preview = False
+                    loss_history = item['loss_history'] if 'loss_history' in item.keys() else None
+                    previews = item['previews'] if 'previews' in item.keys() else None
+                    iteration = item['iter'] if 'iter' in item.keys() else 0
+                    # batch_size = input['batch_size'] if 'iter' in input.keys() else 1
+                    if previews is not None:
+                        update_preview = True
+                elif op == 'update':
+                    if not is_waiting_preview:
+                        is_waiting_preview = True
+                        s2c.put({'op': 'preview'})
+                elif op == 'next_preview':
+                    selected_preview = (selected_preview + 1) % len(previews)
+                    update_preview = True
+                elif op == 'change_history_range':
+                    if show_last_history_iters_count == 0:
+                        show_last_history_iters_count = 5000
+                    elif show_last_history_iters_count == 5000:
+                        show_last_history_iters_count = 10000
+                    elif show_last_history_iters_count == 10000:
+                        show_last_history_iters_count = 50000
+                    elif show_last_history_iters_count == 50000:
+                        show_last_history_iters_count = 100000
+                    elif show_last_history_iters_count == 100000:
+                        show_last_history_iters_count = 0
+                    update_preview = True
+                elif op == 'close':
+                    s2c.put({'op': 'close'})
+                    break
+                elif op == 'zoom_prev':
+                    zoom = zoom.prev()
+                    update_preview = True
+                elif op == 'zoom_next':
+                    zoom = zoom.next()
+                    update_preview = True
+
+            if update_preview:
+                update_preview = False
+                selected_preview = selected_preview % len(previews)
+                preview_pane_image = create_preview_pane_image(previews,
+                                                               selected_preview,
+                                                               loss_history,
+                                                               show_last_history_iters_count,
+                                                               iteration,
+                                                               batch_size,
+                                                               zoom)
+                # io.show_image(wnd_name, preview_pane_image)
+                model_path = Path(kwargs.get('saved_models_path', ''))
+                filename = 'preview.jpg'
+                preview_file = str(model_path / filename)
+                cv2.imwrite(preview_file, preview_pane_image)
+                s2flask.put({'op': 'show'})
+                socketio.emit('preview', {'iter': iteration, 'loss': loss_history[-1]})
+            try:
+                io.process_messages(0.01)
+            except KeyboardInterrupt:
+                s2c.put({'op': 'close'})
+    else:
+        thread = threading.Thread(target=trainerThread, args=(s2c, c2s, e), kwargs=kwargs)
+        thread.start()
+
+        e.wait()  # Wait for initial load to occur.
 
     if no_preview:
         while True:
             if not c2s.empty():
-                input = c2s.get()
-                op = input.get('op','')
+                item = c2s.get()
+                op = item.get('op', '')
                 if op == 'close':
                     break
             try:
                 io.process_messages(0.1)
             except KeyboardInterrupt:
-                s2c.put ( {'op': 'close'} )
+                s2c.put({'op': 'close'})
     else:
         wnd_name = "Training preview"
         io.named_window(wnd_name)
@@ -249,33 +445,33 @@ def main(**kwargs):
         iter = 0
         while True:
             if not c2s.empty():
-                input = c2s.get()
-                op = input['op']
+                item = c2s.get()
+                op = item['op']
                 if op == 'show':
                     is_waiting_preview = False
-                    loss_history = input['loss_history'] if 'loss_history' in input.keys() else None
-                    previews = input['previews'] if 'previews' in input.keys() else None
-                    iter = input['iter'] if 'iter' in input.keys() else 0
+                    loss_history = item['loss_history'] if 'loss_history' in item.keys() else None
+                    previews = item['previews'] if 'previews' in item.keys() else None
+                    iter = item['iter'] if 'iter' in item.keys() else 0
                     if previews is not None:
                         max_w = 0
                         max_h = 0
                         for (preview_name, preview_rgb) in previews:
                             (h, w, c) = preview_rgb.shape
-                            max_h = max (max_h, h)
-                            max_w = max (max_w, w)
+                            max_h = max(max_h, h)
+                            max_w = max(max_w, w)
 
                         max_size = 800
                         if max_h > max_size:
-                            max_w = int( max_w / (max_h / max_size) )
+                            max_w = int(max_w / (max_h / max_size))
                             max_h = max_size
 
-                        #make all previews size equal
+                        # make all previews size equal
                         for preview in previews[:]:
                             (preview_name, preview_rgb) = preview
                             (h, w, c) = preview_rgb.shape
                             if h != max_h or w != max_w:
                                 previews.remove(preview)
-                                previews.append ( (preview_name, cv2.resize(preview_rgb, (max_w, max_h))) )
+                                previews.append((preview_name, cv2.resize(preview_rgb, (max_w, max_h))))
 
                         selected_preview = selected_preview % len(previews)
                         update_preview = True
                 elif op == 'close':
@@ -286,22 +482,22 @@ def main(**kwargs):
                     selected_preview_name = previews[selected_preview][0]
                     selected_preview_rgb = previews[selected_preview][1]
-                    (h,w,c) = selected_preview_rgb.shape
+                    (h, w, c) = selected_preview_rgb.shape
 
                     # HEAD
                     head_lines = [
                         '[s]:save [b]:backup [enter]:exit',
                         '[p]:update [space]:next preview [l]:change history range',
-                        'Preview: "%s" [%d/%d]' % (selected_preview_name,selected_preview+1, len(previews) )
-                        ]
+                        'Preview: "%s" [%d/%d]' % (selected_preview_name, selected_preview + 1, len(previews))
+                    ]
                     head_line_height = 15
                     head_height = len(head_lines) * head_line_height
-                    head = np.ones ( (head_height,w,c) ) * 0.1
+                    head = np.ones((head_height, w, c)) * 0.1
 
                     for i in range(0, len(head_lines)):
-                        t = i*head_line_height
-                        b = (i+1)*head_line_height
-                        head[t:b, 0:w] += imagelib.get_text_image ( (head_line_height,w,c) , head_lines[i], color=[0.8]*c )
+                        t = i * head_line_height
+                        b = (i + 1) * head_line_height
+                        head[t:b, 0:w] += imagelib.get_text_image((head_line_height, w, c), head_lines[i], color=[0.8] * c)
 
                     final = head
@@ -312,27 +508,28 @@ def main(**kwargs):
                             loss_history_to_show = loss_history[-show_last_history_iters_count:]
 
                         lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, iter, w, c)
-                        final = np.concatenate ( [final, lh_img], axis=0 )
+                        final = np.concatenate([final, lh_img], axis=0)
 
-                    final = np.concatenate ( [final, selected_preview_rgb], axis=0 )
+                    final = np.concatenate([final, selected_preview_rgb], axis=0)
                     final = np.clip(final, 0, 1)
 
-                    io.show_image( wnd_name, (final*255).astype(np.uint8) )
+                    io.show_image(wnd_name, (final * 255).astype(np.uint8))
                     is_showing = True
 
                 key_events = io.get_key_events(wnd_name)
-                key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (0,0,False,False,False)
+                key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (
+                    0, 0, False, False, False)
 
                 if key == ord('\n') or key == ord('\r'):
-                    s2c.put ( {'op': 'close'} )
+                    s2c.put({'op': 'close'})
                 elif key == ord('s'):
-                    s2c.put ( {'op': 'save'} )
+                    s2c.put({'op': 'save'})
                 elif key == ord('b'):
-                    s2c.put ( {'op': 'backup'} )
+                    s2c.put({'op': 'backup'})
                 elif key == ord('p'):
                     if not is_waiting_preview:
                         is_waiting_preview = True
-                        s2c.put ( {'op': 'preview'} )
+                        s2c.put({'op': 'preview'})
                 elif key == ord('l'):
                     if show_last_history_iters_count == 0:
                         show_last_history_iters_count = 5000
@@ -352,6 +549,6 @@ def main(**kwargs):
                 try:
                     io.process_messages(0.1)
                 except KeyboardInterrupt:
-                    s2c.put ( {'op': 'close'} )
+                    s2c.put({'op': 'close'})
 
-    io.destroy_all_windows()
\ No newline at end of file
+    io.destroy_all_windows()
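The only new pure-Python surface in the trainer worth a second look is the Zoom enum: prev() and next() clamp at the first and last member rather than wrapping, which is what repeated /zoom_prev and /zoom_next POSTs rely on. A small illustrative check follows; it assumes it is run from the repository root with the usual DeepFaceLab dependencies installed, since importing mainscripts.Trainer pulls in cv2 and the models package.

    # Illustrative only: exercises the Zoom enum introduced in mainscripts/Trainer.py.
    from mainscripts.Trainer import Zoom

    z = Zoom.ZOOM_100
    print(z.scale, z.label)            # 1 100%
    print(z.next().label)              # 110%
    print(z.prev().prev().label)       # 80%
    print(Zoom.ZOOM_25.prev().label)   # 25%  (clamped at the first member)
    print(Zoom.ZOOM_500.next().label)  # 500% (clamped at the last member)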
diff --git a/models/ModelBase.py b/models/ModelBase.py
index 3cb88c5..29ca6b8 100644
--- a/models/ModelBase.py
+++ b/models/ModelBase.py
@@ -535,7 +535,7 @@ class ModelBase(object):
     def get_summary_text(self):
         visible_options = self.options.copy()
         visible_options.update(self.options_show_override)
-        
+
         ###Generate text summary of model hyperparameters
         #Find the longest key name and value string. Used as column widths.
         width_name = max([len(k) for k in visible_options.keys()] + [17]) + 1 # Single space buffer to left edge. Minimum of 17, the length of the longest static string used "Current iteration"
@@ -574,10 +574,9 @@ class ModelBase(object):
         return summary_text
 
     @staticmethod
-    def get_loss_history_preview(loss_history, iter, w, c):
+    def get_loss_history_preview(loss_history, iter, w, c, lh_height=100):
         loss_history = np.array (loss_history.copy())
 
-        lh_height = 100
         lh_img = np.ones ( (lh_height,w,c) ) * 0.1
 
         if len(loss_history) != 0:
diff --git a/requirements-cuda.txt b/requirements-cuda.txt
index ca18994..360ac11 100644
--- a/requirements-cuda.txt
+++ b/requirements-cuda.txt
@@ -7,4 +7,6 @@ scikit-image==0.14.2
 scipy==1.4.1
 colorama
 tensorflow-gpu==2.3.1
-pyqt5
\ No newline at end of file
+pyqt5
+Flask==1.1.1
+flask-socketio==4.2.1
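Only requirements-cuda.txt gains the two new pins (Flask 1.1.1, flask-socketio 4.2.1). After training is launched with the new flag (python main.py train ... --flask-preview), the trainer emits 'loss' and 'preview' Socket.IO events that the (stripped) index.html template presumably consumes. The same events can be consumed from Python, as sketched below; this client is not part of the patch, and a python-socketio 4.x client (Socket.IO protocol 2) is assumed so that it can talk to flask-socketio 4.2.1.

    # Illustrative only: a Python subscriber for the events the trainer emits.
    # Assumes: pip install "python-socketio[client]<5" and the preview server on 127.0.0.1:5000.
    import socketio

    sio = socketio.Client()

    @sio.on('loss')
    def on_loss(message):
        # trainerThread emits the formatted loss line after each iteration
        print('loss:', message)

    @sio.on('preview')
    def on_preview(data):
        # main() emits {'iter': ..., 'loss': ...} each time preview.jpg is rewritten
        print('preview updated at iteration', data.get('iter'))

    sio.connect('http://127.0.0.1:5000')
    sio.wait()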