Merge pull request #115 from faceshiftlabs/updates-from-old-master

Adds flask preview pane
This commit is contained in:
Jeremy Hummel 2021-03-11 23:58:09 -08:00 committed by GitHub
commit 424b38d2c0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 516 additions and 120 deletions

0
flaskr/__init__.py Normal file
View file

102
flaskr/app.py Normal file
View file

@ -0,0 +1,102 @@
from pathlib import Path
from flask import Flask, send_file, Response, render_template, render_template_string, request, g
from flask_socketio import SocketIO, emit
import logging
def create_flask_app(s2c, c2s, s2flask, kwargs):
app = Flask(__name__, template_folder="templates", static_folder="static")
log = logging.getLogger('werkzeug')
log.disabled = True
model_path = Path(kwargs.get('saved_models_path', ''))
filename = 'preview.jpg'
preview_file = str(model_path / filename)
def gen():
frame = open(preview_file, 'rb').read()
while True:
try:
frame = open(preview_file, 'rb').read()
except:
pass
yield b'--frame\r\nContent-Type: image/jpeg\r\n\r\n'
yield frame
yield b'\r\n\r\n'
def send(queue, op):
queue.put({'op': op})
def send_and_wait(queue, op):
while not s2flask.empty():
s2flask.get()
queue.put({'op': op})
while s2flask.empty():
pass
s2flask.get()
@app.route('/save', methods=['POST'])
def save():
send(s2c, 'save')
return '', 204
@app.route('/exit', methods=['POST'])
def exit():
send(c2s, 'close')
request.environ.get('werkzeug.server.shutdown')()
return '', 204
@app.route('/update', methods=['POST'])
def update():
send(c2s, 'update')
return '', 204
@app.route('/next_preview', methods=['POST'])
def next_preview():
send(c2s, 'next_preview')
return '', 204
@app.route('/change_history_range', methods=['POST'])
def change_history_range():
send(c2s, 'change_history_range')
return '', 204
@app.route('/zoom_prev', methods=['POST'])
def zoom_prev():
send(c2s, 'zoom_prev')
return '', 204
@app.route('/zoom_next', methods=['POST'])
def zoom_next():
send(c2s, 'zoom_next')
return '', 204
@app.route('/')
def index():
return render_template('index.html')
# @app.route('/preview_image')
# def preview_image():
# return Response(gen(), mimetype='multipart/x-mixed-replace;boundary=frame')
@app.route('/preview_image')
def preview_image():
return send_file(preview_file, mimetype='image/jpeg', cache_timeout=-1)
socketio = SocketIO(app)
@socketio.on('connect', namespace='/')
def test_connect():
emit('my response', {'data': 'Connected'})
@socketio.on('disconnect', namespace='/test')
def test_disconnect():
print('Client disconnected')
return socketio, app

BIN
flaskr/static/favicon.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 284 KiB

View file

@ -0,0 +1,94 @@
<head>
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.4.1/jquery.min.js"
integrity="sha256-CSXorXvZcTkaix6Yvo6HppcZGetbYMGWSFlBw8HfCJo="
crossorigin="anonymous"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/2.2.0/socket.io.js"
integrity="sha256-yr4fRk/GU1ehYJPAs8P4JlTgu0Hdsp4ZKrx8bDEDC3I="
crossorigin="anonymous"></script>
<link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons">
<link rel="stylesheet" href="https://code.getmdl.io/1.3.0/material.indigo-pink.min.css">
<script defer src="https://code.getmdl.io/1.3.0/material.min.js"></script>
<title>Training Preview</title>
<link rel="shortcut icon" href="{{ url_for('static', filename='favicon.ico') }}">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<script type="text/javascript">
$(function() {
function save() {
$.post("{{ url_for('save') }}");
}
function exit() {
$.post("{{ url_for('exit') }}");
}
function update() {
$.post("{{ url_for('update') }}");
}
function next_preview() {
$.post("{{ url_for('next_preview') }}");
}
function change_history_range() {
$.post("{{ url_for('change_history_range') }}");
}
function zoom_prev() {
$.post("{{ url_for('zoom_prev') }}");
}
function zoom_next() {
$.post("{{ url_for('zoom_next') }}");
}
$(document).keypress(function (event) {
switch (event.key) {
case "s" : save(); break;
case "Enter" : exit(); break;
case "p" : update(); break;
case " " : next_preview(); break;
case "l" : change_history_range(); break;
case "-" : zoom_prev(); break;
case "=" : zoom_next(); break;
}
// console.log('kp:', event);
});
$('button#save').click(save);
$('button#exit').click(exit);
$('button#update').click(update);
$('button#next_preview').click(next_preview);
$('button#change_history_range').click(change_history_range);
$('button#zoom_prev').click(zoom_prev);
$('button#zoom_next').click(zoom_next);
$('img#preview').click(update);
const socket = io.connect();
socket.on('preview', function(msg) {
console.log(msg);
$('img#preview').attr("src", "{{ url_for('preview_image') }}?q=" + new Date().getTime());
});
socket.on('loss', function(loss_string) {
console.log(loss_string);
$('div#loss').html(loss_string);
});
});
</script>
</head>
<body>
<div class="mdl-typography--headline">Training Preview</div>
<div id="loss"></div>
<div>
<button class='mdl-button mdl-js-button mdl-button--raised mdl-js-ripple-effect' id='save'>Save</button>
<button class='mdl-button mdl-js-button mdl-button--raised mdl-js-ripple-effect' id='exit'>Exit</button>
<button class='mdl-button mdl-js-button mdl-button--raised mdl-js-ripple-effect' id='update'>Update</button>
<button class='mdl-button mdl-js-button mdl-button--raised mdl-js-ripple-effect' id='next_preview'>Next preview</button>
<button class='mdl-button mdl-js-button mdl-button--raised mdl-js-ripple-effect' id='change_history_range'>Change History Range</button>
<button class='mdl-button mdl-js-button mdl-button--raised mdl-js-ripple-effect' id='zoom_prev'>Zoom -</button>
<button class='mdl-button mdl-js-button mdl-button--raised mdl-js-ripple-effect' id='zoom_next'>Zoom +</button>
</div>
<img id='preview' src="{{ url_for('preview_image') }}" style="max-width: 100%">
</body>
</html>

40
main.py
View file

@ -23,7 +23,7 @@ if __name__ == "__main__":
setattr(namespace, self.dest, os.path.abspath(os.path.expanduser(values))) setattr(namespace, self.dest, os.path.abspath(os.path.expanduser(values)))
exit_code = 0 exit_code = 0
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers() subparsers = parser.add_subparsers()
@ -52,9 +52,9 @@ if __name__ == "__main__":
p.add_argument('--output-debug', action="store_true", dest="output_debug", default=None, help="Writes debug images to <output-dir>_debug\ directory.") p.add_argument('--output-debug', action="store_true", dest="output_debug", default=None, help="Writes debug images to <output-dir>_debug\ directory.")
p.add_argument('--no-output-debug', action="store_false", dest="output_debug", default=None, help="Don't writes debug images to <output-dir>_debug\ directory.") p.add_argument('--no-output-debug', action="store_false", dest="output_debug", default=None, help="Don't writes debug images to <output-dir>_debug\ directory.")
p.add_argument('--face-type', dest="face_type", choices=['half_face', 'full_face', 'whole_face', 'head', 'mark_only'], default=None) p.add_argument('--face-type', dest="face_type", choices=['half_face', 'full_face', 'whole_face', 'head', 'mark_only'], default=None)
p.add_argument('--max-faces-from-image', type=int, dest="max_faces_from_image", default=None, help="Max faces from image.") p.add_argument('--max-faces-from-image', type=int, dest="max_faces_from_image", default=None, help="Max faces from image.")
p.add_argument('--image-size', type=int, dest="image_size", default=None, help="Output image size.") p.add_argument('--image-size', type=int, dest="image_size", default=None, help="Output image size.")
p.add_argument('--jpeg-quality', type=int, dest="jpeg_quality", default=None, help="Jpeg quality.") p.add_argument('--jpeg-quality', type=int, dest="jpeg_quality", default=None, help="Jpeg quality.")
p.add_argument('--manual-fix', action="store_true", dest="manual_fix", default=False, help="Enables manual extract only frames where faces were not recognized.") p.add_argument('--manual-fix', action="store_true", dest="manual_fix", default=False, help="Enables manual extract only frames where faces were not recognized.")
p.add_argument('--manual-output-debug-fix', action="store_true", dest="manual_output_debug_fix", default=False, help="Performs manual reextract input-dir frames which were deleted from [output_dir]_debug\ dir.") p.add_argument('--manual-output-debug-fix', action="store_true", dest="manual_output_debug_fix", default=False, help="Performs manual reextract input-dir frames which were deleted from [output_dir]_debug\ dir.")
p.add_argument('--manual-window-size', type=int, dest="manual_window_size", default=1368, help="Manual fix window size. Default: 1368.") p.add_argument('--manual-window-size', type=int, dest="manual_window_size", default=1368, help="Manual fix window size. Default: 1368.")
@ -127,6 +127,7 @@ if __name__ == "__main__":
'silent_start' : arguments.silent_start, 'silent_start' : arguments.silent_start,
'execute_programs' : [ [int(x[0]), x[1] ] for x in arguments.execute_program ], 'execute_programs' : [ [int(x[0]), x[1] ] for x in arguments.execute_program ],
'debug' : arguments.debug, 'debug' : arguments.debug,
'flask_preview' : arguments.flask_preview,
} }
from mainscripts import Trainer from mainscripts import Trainer
Trainer.main(**kwargs) Trainer.main(**kwargs)
@ -144,8 +145,9 @@ if __name__ == "__main__":
p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Train on CPU.") p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Train on CPU.")
p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.") p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.")
p.add_argument('--silent-start', action="store_true", dest="silent_start", default=False, help="Silent start. Automatically chooses Best GPU and last used model.") p.add_argument('--silent-start', action="store_true", dest="silent_start", default=False, help="Silent start. Automatically chooses Best GPU and last used model.")
p.add_argument('--flask-preview', action="store_true", dest="flask_preview", default=False,
help="Launches a flask server to view the previews in a web browser")
p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+') p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+')
p.set_defaults (func=process_train) p.set_defaults (func=process_train)
@ -252,7 +254,7 @@ if __name__ == "__main__":
p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.") p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.")
p.set_defaults(func=process_faceset_enhancer) p.set_defaults(func=process_faceset_enhancer)
def process_dev_test(arguments): def process_dev_test(arguments):
osex.set_process_lowest_prio() osex.set_process_lowest_prio()
from mainscripts import dev_misc from mainscripts import dev_misc
@ -261,10 +263,10 @@ if __name__ == "__main__":
p = subparsers.add_parser( "dev_test", help="") p = subparsers.add_parser( "dev_test", help="")
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir") p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
p.set_defaults (func=process_dev_test) p.set_defaults (func=process_dev_test)
# ========== XSeg # ========== XSeg
xseg_parser = subparsers.add_parser( "xseg", help="XSeg tools.").add_subparsers() xseg_parser = subparsers.add_parser( "xseg", help="XSeg tools.").add_subparsers()
p = xseg_parser.add_parser( "editor", help="XSeg editor.") p = xseg_parser.add_parser( "editor", help="XSeg editor.")
def process_xsegeditor(arguments): def process_xsegeditor(arguments):
@ -272,11 +274,11 @@ if __name__ == "__main__":
from XSegEditor import XSegEditor from XSegEditor import XSegEditor
global exit_code global exit_code
exit_code = XSegEditor.start (Path(arguments.input_dir)) exit_code = XSegEditor.start (Path(arguments.input_dir))
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir") p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
p.set_defaults (func=process_xsegeditor) p.set_defaults (func=process_xsegeditor)
p = xseg_parser.add_parser( "apply", help="Apply trained XSeg model to the extracted faces.") p = xseg_parser.add_parser( "apply", help="Apply trained XSeg model to the extracted faces.")
def process_xsegapply(arguments): def process_xsegapply(arguments):
@ -286,8 +288,8 @@ if __name__ == "__main__":
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir") p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
p.add_argument('--model-dir', required=True, action=fixPathAction, dest="model_dir") p.add_argument('--model-dir', required=True, action=fixPathAction, dest="model_dir")
p.set_defaults (func=process_xsegapply) p.set_defaults (func=process_xsegapply)
p = xseg_parser.add_parser( "remove", help="Remove applied XSeg masks from the extracted faces.") p = xseg_parser.add_parser( "remove", help="Remove applied XSeg masks from the extracted faces.")
def process_xsegremove(arguments): def process_xsegremove(arguments):
osex.set_process_lowest_prio() osex.set_process_lowest_prio()
@ -295,8 +297,8 @@ if __name__ == "__main__":
XSegUtil.remove_xseg (Path(arguments.input_dir) ) XSegUtil.remove_xseg (Path(arguments.input_dir) )
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir") p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
p.set_defaults (func=process_xsegremove) p.set_defaults (func=process_xsegremove)
p = xseg_parser.add_parser( "remove_labels", help="Remove XSeg labels from the extracted faces.") p = xseg_parser.add_parser( "remove_labels", help="Remove XSeg labels from the extracted faces.")
def process_xsegremovelabels(arguments): def process_xsegremovelabels(arguments):
osex.set_process_lowest_prio() osex.set_process_lowest_prio()
@ -304,8 +306,8 @@ if __name__ == "__main__":
XSegUtil.remove_xseg_labels (Path(arguments.input_dir) ) XSegUtil.remove_xseg_labels (Path(arguments.input_dir) )
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir") p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
p.set_defaults (func=process_xsegremovelabels) p.set_defaults (func=process_xsegremovelabels)
p = xseg_parser.add_parser( "fetch", help="Copies faces containing XSeg polygons in <input_dir>_xseg dir.") p = xseg_parser.add_parser( "fetch", help="Copies faces containing XSeg polygons in <input_dir>_xseg dir.")
def process_xsegfetch(arguments): def process_xsegfetch(arguments):
@ -314,7 +316,7 @@ if __name__ == "__main__":
XSegUtil.fetch_xseg (Path(arguments.input_dir) ) XSegUtil.fetch_xseg (Path(arguments.input_dir) )
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir") p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir")
p.set_defaults (func=process_xsegfetch) p.set_defaults (func=process_xsegfetch)
def bad_args(arguments): def bad_args(arguments):
parser.print_help() parser.print_help()
exit(0) exit(0)
@ -325,9 +327,9 @@ if __name__ == "__main__":
if exit_code == 0: if exit_code == 0:
print ("Done.") print ("Done.")
exit(exit_code) exit(exit_code)
''' '''
import code import code
code.interact(local=dict(globals(), **locals())) code.interact(local=dict(globals(), **locals()))

View file

@ -4,6 +4,8 @@ import traceback
import queue import queue
import threading import threading
import time import time
from enum import Enum
import numpy as np import numpy as np
import itertools import itertools
from pathlib import Path from pathlib import Path
@ -13,21 +15,23 @@ import cv2
import models import models
from core.interact import interact as io from core.interact import interact as io
def trainerThread (s2c, c2s, e,
model_class_name = None, def trainerThread(s2c, c2s, e,
saved_models_path = None, socketio=None,
training_data_src_path = None, model_class_name=None,
training_data_dst_path = None, saved_models_path=None,
pretraining_data_path = None, training_data_src_path=None,
pretrained_model_path = None, training_data_dst_path=None,
no_preview=False, pretraining_data_path=None,
force_model_name=None, pretrained_model_path=None,
force_gpu_idxs=None, no_preview=False,
cpu_only=None, force_model_name=None,
silent_start=False, force_gpu_idxs=None,
execute_programs = None, cpu_only=None,
debug=False, silent_start=False,
**kwargs): execute_programs=None,
debug=False,
**kwargs):
while True: while True:
try: try:
start_time = time.time() start_time = time.time()
@ -44,67 +48,70 @@ def trainerThread (s2c, c2s, e,
saved_models_path.mkdir(exist_ok=True, parents=True) saved_models_path.mkdir(exist_ok=True, parents=True)
model = models.import_model(model_class_name)( model = models.import_model(model_class_name)(
is_training=True, is_training=True,
saved_models_path=saved_models_path, saved_models_path=saved_models_path,
training_data_src_path=training_data_src_path, training_data_src_path=training_data_src_path,
training_data_dst_path=training_data_dst_path, training_data_dst_path=training_data_dst_path,
pretraining_data_path=pretraining_data_path, pretraining_data_path=pretraining_data_path,
pretrained_model_path=pretrained_model_path, pretrained_model_path=pretrained_model_path,
no_preview=no_preview, no_preview=no_preview,
force_model_name=force_model_name, force_model_name=force_model_name,
force_gpu_idxs=force_gpu_idxs, force_gpu_idxs=force_gpu_idxs,
cpu_only=cpu_only, cpu_only=cpu_only,
silent_start=silent_start, silent_start=silent_start,
debug=debug, debug=debug,
) )
is_reached_goal = model.is_reached_iter_goal() is_reached_goal = model.is_reached_iter_goal()
shared_state = { 'after_save' : False } shared_state = {'after_save': False}
loss_string = "" loss_string = ""
save_iter = model.get_iter() save_iter = model.get_iter()
def model_save(): def model_save():
if not debug and not is_reached_goal: if not debug and not is_reached_goal:
io.log_info ("Saving....", end='\r') io.log_info("Saving....", end='\r')
model.save() model.save()
shared_state['after_save'] = True shared_state['after_save'] = True
def model_backup(): def model_backup():
if not debug and not is_reached_goal: if not debug and not is_reached_goal:
model.create_backup() model.create_backup()
def send_preview(): def send_preview():
if not debug: if not debug:
previews = model.get_previews() previews = model.get_previews()
c2s.put ( {'op':'show', 'previews': previews, 'iter':model.get_iter(), 'loss_history': model.get_loss_history().copy() } ) c2s.put({'op': 'show', 'previews': previews, 'iter': model.get_iter(),
'loss_history': model.get_loss_history().copy()})
else: else:
previews = [( 'debug, press update for new', model.debug_one_iter())] previews = [('debug, press update for new', model.debug_one_iter())]
c2s.put ( {'op':'show', 'previews': previews} ) c2s.put({'op': 'show', 'previews': previews})
e.set() #Set the GUI Thread as Ready e.set() # Set the GUI Thread as Ready
if model.get_target_iter() != 0: if model.get_target_iter() != 0:
if is_reached_goal: if is_reached_goal:
io.log_info('Model already trained to target iteration. You can use preview.') io.log_info('Model already trained to target iteration. You can use preview.')
else: else:
io.log_info('Starting. Target iteration: %d. Press "Enter" to stop training and save model.' % ( model.get_target_iter() ) ) io.log_info('Starting. Target iteration: %d. Press "Enter" to stop training and save model.' % (
model.get_target_iter()))
else: else:
io.log_info('Starting. Press "Enter" to stop training and save model.') io.log_info('Starting. Press "Enter" to stop training and save model.')
last_save_time = time.time() last_save_time = time.time()
execute_programs = [ [x[0], x[1], time.time() ] for x in execute_programs ] execute_programs = [[x[0], x[1], time.time()] for x in execute_programs]
for i in itertools.count(0,1): for i in itertools.count(0, 1):
if not debug: if not debug:
cur_time = time.time() cur_time = time.time()
for x in execute_programs: for x in execute_programs:
prog_time, prog, last_time = x prog_time, prog, last_time = x
exec_prog = False exec_prog = False
if prog_time > 0 and (cur_time - start_time) >= prog_time: if 0 < prog_time <= (cur_time - start_time):
x[0] = 0 x[0] = 0
exec_prog = True exec_prog = True
elif prog_time < 0 and (cur_time - last_time) >= -prog_time: elif prog_time < 0 and (cur_time - last_time) >= -prog_time:
x[2] = cur_time x[2] = cur_time
exec_prog = True exec_prog = True
@ -112,18 +119,20 @@ def trainerThread (s2c, c2s, e,
try: try:
exec(prog) exec(prog)
except Exception as e: except Exception as e:
print("Unable to execute program: %s" % (prog) ) print("Unable to execute program: %s" % prog)
if not is_reached_goal: if not is_reached_goal:
if model.get_iter() == 0: if model.get_iter() == 0:
io.log_info("") io.log_info("")
io.log_info("Trying to do the first iteration. If an error occurs, reduce the model parameters.") io.log_info(
"Trying to do the first iteration. If an error occurs, reduce the model parameters.")
io.log_info("") io.log_info("")
if sys.platform[0:3] == 'win': if sys.platform[0:3] == 'win':
io.log_info("!!!") io.log_info("!!!")
io.log_info("Windows 10 users IMPORTANT notice. You should set this setting in order to work correctly.") io.log_info(
"Windows 10 users IMPORTANT notice. You should set this setting in order to work correctly.")
io.log_info("https://i.imgur.com/B7cmDCB.jpg") io.log_info("https://i.imgur.com/B7cmDCB.jpg")
io.log_info("!!!") io.log_info("!!!")
@ -132,19 +141,19 @@ def trainerThread (s2c, c2s, e,
loss_history = model.get_loss_history() loss_history = model.get_loss_history()
time_str = time.strftime("[%H:%M:%S]") time_str = time.strftime("[%H:%M:%S]")
if iter_time >= 10: if iter_time >= 10:
loss_string = "{0}[#{1:06d}][{2:.5s}s]".format ( time_str, iter, '{:0.4f}'.format(iter_time) ) loss_string = "{0}[#{1:06d}][{2:.5s}s]".format(time_str, iter, '{:0.4f}'.format(iter_time))
else: else:
loss_string = "{0}[#{1:06d}][{2:04d}ms]".format ( time_str, iter, int(iter_time*1000) ) loss_string = "{0}[#{1:06d}][{2:04d}ms]".format(time_str, iter, int(iter_time * 1000))
if shared_state['after_save']: if shared_state['after_save']:
shared_state['after_save'] = False shared_state['after_save'] = False
mean_loss = np.mean ( loss_history[save_iter:iter], axis=0) mean_loss = np.mean(loss_history[save_iter:iter], axis=0)
for loss_value in mean_loss: for loss_value in mean_loss:
loss_string += "[%.4f]" % (loss_value) loss_string += "[%.4f]" % (loss_value)
io.log_info (loss_string) io.log_info(loss_string)
save_iter = iter save_iter = iter
else: else:
@ -152,25 +161,28 @@ def trainerThread (s2c, c2s, e,
loss_string += "[%.4f]" % (loss_value) loss_string += "[%.4f]" % (loss_value)
if io.is_colab(): if io.is_colab():
io.log_info ('\r' + loss_string, end='') io.log_info('\r' + loss_string, end='')
else: else:
io.log_info (loss_string, end='\r') io.log_info(loss_string, end='\r')
if socketio is not None:
socketio.emit('loss', loss_string)
if model.get_iter() == 1: if model.get_iter() == 1:
model_save() model_save()
if model.get_target_iter() != 0 and model.is_reached_iter_goal(): if model.get_target_iter() != 0 and model.is_reached_iter_goal():
io.log_info ('Reached target iteration.') io.log_info('Reached target iteration.')
model_save() model_save()
is_reached_goal = True is_reached_goal = True
io.log_info ('You can use preview now.') io.log_info('You can use preview now.')
if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min*60: if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min * 60:
last_save_time += save_interval_min*60 last_save_time += save_interval_min * 60
model_save() model_save()
send_preview() send_preview()
if i==0: if i == 0:
if is_reached_goal: if is_reached_goal:
model.pass_one_iter() model.pass_one_iter()
send_preview() send_preview()
@ -179,8 +191,8 @@ def trainerThread (s2c, c2s, e,
time.sleep(0.005) time.sleep(0.005)
while not s2c.empty(): while not s2c.empty():
input = s2c.get() item = s2c.get()
op = input['op'] op = item['op']
if op == 'save': if op == 'save':
model_save() model_save()
elif op == 'backup': elif op == 'backup':
@ -197,43 +209,227 @@ def trainerThread (s2c, c2s, e,
if i == -1: if i == -1:
break break
model.finalize() model.finalize()
except Exception as e: except Exception as e:
print ('Error: %s' % (str(e))) print('Error: %s' % (str(e)))
traceback.print_exc() traceback.print_exc()
break break
c2s.put ( {'op':'close'} ) c2s.put({'op': 'close'})
class Zoom(Enum):
ZOOM_25 = (1 / 4, '25%')
ZOOM_33 = (1 / 3, '33%')
ZOOM_50 = (1 / 2, '50%')
ZOOM_67 = (2 / 3, '67%')
ZOOM_75 = (3 / 4, '75%')
ZOOM_80 = (4 / 5, '80%')
ZOOM_90 = (9 / 10, '90%')
ZOOM_100 = (1, '100%')
ZOOM_110 = (11 / 10, '110%')
ZOOM_125 = (5 / 4, '125%')
ZOOM_150 = (3 / 2, '150%')
ZOOM_175 = (7 / 4, '175%')
ZOOM_200 = (2, '200%')
ZOOM_250 = (5 / 2, '250%')
ZOOM_300 = (3, '300%')
ZOOM_400 = (4, '400%')
ZOOM_500 = (5, '500%')
def __init__(self, scale, label):
self.scale = scale
self.label = label
def prev(self):
cls = self.__class__
members = list(cls)
index = members.index(self) - 1
if index < 0:
return self
return members[index]
def next(self):
cls = self.__class__
members = list(cls)
index = members.index(self) + 1
if index >= len(members):
return self
return members[index]
def scale_previews(previews, zoom=Zoom.ZOOM_100):
scaled = []
for preview in previews:
preview_name, preview_rgb = preview
scale_factor = zoom.scale
if scale_factor < 1:
scaled.append((preview_name, cv2.resize(preview_rgb, (0, 0),
fx=scale_factor,
fy=scale_factor,
interpolation=cv2.INTER_AREA)))
elif scale_factor > 1:
scaled.append((preview_name, cv2.resize(preview_rgb, (0, 0),
fx=scale_factor,
fy=scale_factor,
interpolation=cv2.INTER_LANCZOS4)))
else:
scaled.append((preview_name, preview_rgb))
return scaled
def create_preview_pane_image(previews, selected_preview, loss_history,
show_last_history_iters_count, iteration, batch_size, zoom=Zoom.ZOOM_100):
scaled_previews = scale_previews(previews, zoom)
selected_preview_name = scaled_previews[selected_preview][0]
selected_preview_rgb = scaled_previews[selected_preview][1]
h, w, c = selected_preview_rgb.shape
# HEAD
head_lines = [
'[s]:save [enter]:exit [-/+]:zoom: %s' % zoom.label,
'[p]:update [space]:next preview [l]:change history range',
'Preview: "%s" [%d/%d]' % (selected_preview_name, selected_preview + 1, len(previews))
]
head_line_height = int(15 * zoom.scale)
head_height = len(head_lines) * head_line_height
head = np.ones((head_height, w, c)) * 0.1
for i in range(0, len(head_lines)):
t = i * head_line_height
b = (i + 1) * head_line_height
head[t:b, 0:w] += imagelib.get_text_image((head_line_height, w, c), head_lines[i], color=[0.8] * c)
final = head
if loss_history is not None:
if show_last_history_iters_count == 0:
loss_history_to_show = loss_history
else:
loss_history_to_show = loss_history[-show_last_history_iters_count:]
lh_height = int(100 * zoom.scale)
lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, iteration, w, c, lh_height)
final = np.concatenate([final, lh_img], axis=0)
final = np.concatenate([final, selected_preview_rgb], axis=0)
final = np.clip(final, 0, 1)
return (final * 255).astype(np.uint8)
def main(**kwargs): def main(**kwargs):
io.log_info ("Running trainer.\r\n") io.log_info("Running trainer.\r\n")
no_preview = kwargs.get('no_preview', False) no_preview = kwargs.get('no_preview', False)
flask_preview = kwargs.get('flask_preview', False)
s2c = queue.Queue() s2c = queue.Queue()
c2s = queue.Queue() c2s = queue.Queue()
e = threading.Event() e = threading.Event()
thread = threading.Thread(target=trainerThread, args=(s2c, c2s, e), kwargs=kwargs )
thread.start()
e.wait() #Wait for inital load to occur. previews = None
loss_history = None
selected_preview = 0
update_preview = False
is_waiting_preview = False
show_last_history_iters_count = 0
iteration = 0
batch_size = 1
zoom = Zoom.ZOOM_100
if flask_preview:
from flaskr.app import create_flask_app
s2flask = queue.Queue()
socketio, flask_app = create_flask_app(s2c, c2s, s2flask, kwargs)
thread = threading.Thread(target=trainerThread, args=(s2c, c2s, e, socketio), kwargs=kwargs)
thread.start()
e.wait() # Wait for inital load to occur.
flask_t = threading.Thread(target=socketio.run, args=(flask_app,),
kwargs={'debug': True, 'use_reloader': False})
flask_t.start()
while True:
if not c2s.empty():
item = c2s.get()
op = item['op']
if op == 'show':
is_waiting_preview = False
loss_history = item['loss_history'] if 'loss_history' in item.keys() else None
previews = item['previews'] if 'previews' in item.keys() else None
iteration = item['iter'] if 'iter' in item.keys() else 0
# batch_size = input['batch_size'] if 'iter' in input.keys() else 1
if previews is not None:
update_preview = True
elif op == 'update':
if not is_waiting_preview:
is_waiting_preview = True
s2c.put({'op': 'preview'})
elif op == 'next_preview':
selected_preview = (selected_preview + 1) % len(previews)
update_preview = True
elif op == 'change_history_range':
if show_last_history_iters_count == 0:
show_last_history_iters_count = 5000
elif show_last_history_iters_count == 5000:
show_last_history_iters_count = 10000
elif show_last_history_iters_count == 10000:
show_last_history_iters_count = 50000
elif show_last_history_iters_count == 50000:
show_last_history_iters_count = 100000
elif show_last_history_iters_count == 100000:
show_last_history_iters_count = 0
update_preview = True
elif op == 'close':
s2c.put({'op': 'close'})
break
elif op == 'zoom_prev':
zoom = zoom.prev()
update_preview = True
elif op == 'zoom_next':
zoom = zoom.next()
update_preview = True
if update_preview:
update_preview = False
selected_preview = selected_preview % len(previews)
preview_pane_image = create_preview_pane_image(previews,
selected_preview,
loss_history,
show_last_history_iters_count,
iteration,
batch_size,
zoom)
# io.show_image(wnd_name, preview_pane_image)
model_path = Path(kwargs.get('saved_models_path', ''))
filename = 'preview.jpg'
preview_file = str(model_path / filename)
cv2.imwrite(preview_file, preview_pane_image)
s2flask.put({'op': 'show'})
socketio.emit('preview', {'iter': iteration, 'loss': loss_history[-1]})
try:
io.process_messages(0.01)
except KeyboardInterrupt:
s2c.put({'op': 'close'})
else:
thread = threading.Thread(target=trainerThread, args=(s2c, c2s, e), kwargs=kwargs)
thread.start()
e.wait() # Wait for inital load to occur.
if no_preview: if no_preview:
while True: while True:
if not c2s.empty(): if not c2s.empty():
input = c2s.get() item = c2s.get()
op = input.get('op','') op = item.get('op', '')
if op == 'close': if op == 'close':
break break
try: try:
io.process_messages(0.1) io.process_messages(0.1)
except KeyboardInterrupt: except KeyboardInterrupt:
s2c.put ( {'op': 'close'} ) s2c.put({'op': 'close'})
else: else:
wnd_name = "Training preview" wnd_name = "Training preview"
io.named_window(wnd_name) io.named_window(wnd_name)
@ -249,33 +445,33 @@ def main(**kwargs):
iter = 0 iter = 0
while True: while True:
if not c2s.empty(): if not c2s.empty():
input = c2s.get() item = c2s.get()
op = input['op'] op = input['op']
if op == 'show': if op == 'show':
is_waiting_preview = False is_waiting_preview = False
loss_history = input['loss_history'] if 'loss_history' in input.keys() else None loss_history = item['loss_history'] if 'loss_history' in item.keys() else None
previews = input['previews'] if 'previews' in input.keys() else None previews = item['previews'] if 'previews' in item.keys() else None
iter = input['iter'] if 'iter' in input.keys() else 0 iter = item['iter'] if 'iter' in item.keys() else 0
if previews is not None: if previews is not None:
max_w = 0 max_w = 0
max_h = 0 max_h = 0
for (preview_name, preview_rgb) in previews: for (preview_name, preview_rgb) in previews:
(h, w, c) = preview_rgb.shape (h, w, c) = preview_rgb.shape
max_h = max (max_h, h) max_h = max(max_h, h)
max_w = max (max_w, w) max_w = max(max_w, w)
max_size = 800 max_size = 800
if max_h > max_size: if max_h > max_size:
max_w = int( max_w / (max_h / max_size) ) max_w = int(max_w / (max_h / max_size))
max_h = max_size max_h = max_size
#make all previews size equal # make all previews size equal
for preview in previews[:]: for preview in previews[:]:
(preview_name, preview_rgb) = preview (preview_name, preview_rgb) = preview
(h, w, c) = preview_rgb.shape (h, w, c) = preview_rgb.shape
if h != max_h or w != max_w: if h != max_h or w != max_w:
previews.remove(preview) previews.remove(preview)
previews.append ( (preview_name, cv2.resize(preview_rgb, (max_w, max_h))) ) previews.append((preview_name, cv2.resize(preview_rgb, (max_w, max_h))))
selected_preview = selected_preview % len(previews) selected_preview = selected_preview % len(previews)
update_preview = True update_preview = True
elif op == 'close': elif op == 'close':
@ -286,22 +482,22 @@ def main(**kwargs):
selected_preview_name = previews[selected_preview][0] selected_preview_name = previews[selected_preview][0]
selected_preview_rgb = previews[selected_preview][1] selected_preview_rgb = previews[selected_preview][1]
(h,w,c) = selected_preview_rgb.shape (h, w, c) = selected_preview_rgb.shape
# HEAD # HEAD
head_lines = [ head_lines = [
'[s]:save [b]:backup [enter]:exit', '[s]:save [b]:backup [enter]:exit',
'[p]:update [space]:next preview [l]:change history range', '[p]:update [space]:next preview [l]:change history range',
'Preview: "%s" [%d/%d]' % (selected_preview_name,selected_preview+1, len(previews) ) 'Preview: "%s" [%d/%d]' % (selected_preview_name, selected_preview + 1, len(previews))
] ]
head_line_height = 15 head_line_height = 15
head_height = len(head_lines) * head_line_height head_height = len(head_lines) * head_line_height
head = np.ones ( (head_height,w,c) ) * 0.1 head = np.ones((head_height, w, c)) * 0.1
for i in range(0, len(head_lines)): for i in range(0, len(head_lines)):
t = i*head_line_height t = i * head_line_height
b = (i+1)*head_line_height b = (i + 1) * head_line_height
head[t:b, 0:w] += imagelib.get_text_image ( (head_line_height,w,c) , head_lines[i], color=[0.8]*c ) head[t:b, 0:w] += imagelib.get_text_image((head_line_height, w, c), head_lines[i], color=[0.8] * c)
final = head final = head
@ -312,27 +508,28 @@ def main(**kwargs):
loss_history_to_show = loss_history[-show_last_history_iters_count:] loss_history_to_show = loss_history[-show_last_history_iters_count:]
lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, iter, w, c) lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, iter, w, c)
final = np.concatenate ( [final, lh_img], axis=0 ) final = np.concatenate([final, lh_img], axis=0)
final = np.concatenate ( [final, selected_preview_rgb], axis=0 ) final = np.concatenate([final, selected_preview_rgb], axis=0)
final = np.clip(final, 0, 1) final = np.clip(final, 0, 1)
io.show_image( wnd_name, (final*255).astype(np.uint8) ) io.show_image(wnd_name, (final * 255).astype(np.uint8))
is_showing = True is_showing = True
key_events = io.get_key_events(wnd_name) key_events = io.get_key_events(wnd_name)
key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (0,0,False,False,False) key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (
0, 0, False, False, False)
if key == ord('\n') or key == ord('\r'): if key == ord('\n') or key == ord('\r'):
s2c.put ( {'op': 'close'} ) s2c.put({'op': 'close'})
elif key == ord('s'): elif key == ord('s'):
s2c.put ( {'op': 'save'} ) s2c.put({'op': 'save'})
elif key == ord('b'): elif key == ord('b'):
s2c.put ( {'op': 'backup'} ) s2c.put({'op': 'backup'})
elif key == ord('p'): elif key == ord('p'):
if not is_waiting_preview: if not is_waiting_preview:
is_waiting_preview = True is_waiting_preview = True
s2c.put ( {'op': 'preview'} ) s2c.put({'op': 'preview'})
elif key == ord('l'): elif key == ord('l'):
if show_last_history_iters_count == 0: if show_last_history_iters_count == 0:
show_last_history_iters_count = 5000 show_last_history_iters_count = 5000
@ -352,6 +549,6 @@ def main(**kwargs):
try: try:
io.process_messages(0.1) io.process_messages(0.1)
except KeyboardInterrupt: except KeyboardInterrupt:
s2c.put ( {'op': 'close'} ) s2c.put({'op': 'close'})
io.destroy_all_windows() io.destroy_all_windows()

View file

@ -535,7 +535,7 @@ class ModelBase(object):
def get_summary_text(self): def get_summary_text(self):
visible_options = self.options.copy() visible_options = self.options.copy()
visible_options.update(self.options_show_override) visible_options.update(self.options_show_override)
###Generate text summary of model hyperparameters ###Generate text summary of model hyperparameters
#Find the longest key name and value string. Used as column widths. #Find the longest key name and value string. Used as column widths.
width_name = max([len(k) for k in visible_options.keys()] + [17]) + 1 # Single space buffer to left edge. Minimum of 17, the length of the longest static string used "Current iteration" width_name = max([len(k) for k in visible_options.keys()] + [17]) + 1 # Single space buffer to left edge. Minimum of 17, the length of the longest static string used "Current iteration"
@ -574,10 +574,9 @@ class ModelBase(object):
return summary_text return summary_text
@staticmethod @staticmethod
def get_loss_history_preview(loss_history, iter, w, c): def get_loss_history_preview(loss_history, iter, w, c, lh_height=100):
loss_history = np.array (loss_history.copy()) loss_history = np.array (loss_history.copy())
lh_height = 100
lh_img = np.ones ( (lh_height,w,c) ) * 0.1 lh_img = np.ones ( (lh_height,w,c) ) * 0.1
if len(loss_history) != 0: if len(loss_history) != 0:

View file

@ -7,4 +7,6 @@ scikit-image==0.14.2
scipy==1.4.1 scipy==1.4.1
colorama colorama
tensorflow-gpu==2.3.1 tensorflow-gpu==2.3.1
pyqt5 pyqt5
Flask==1.1.1
flask-socketio==4.2.1