Merge pull request #71 from faceshiftlabs/experiment/flask-preview

Experiment/flask preview
This commit is contained in:
Jeremy Hummel 2019-09-17 10:29:35 -07:00 committed by GitHub
commit 0f28352d57
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 444 additions and 143 deletions

0
flaskr/__init__.py Normal file
View file

99
flaskr/app.py Normal file
View file

@ -0,0 +1,99 @@
from pathlib import Path
from flask import Flask, send_file, Response, render_template, render_template_string, request, g
from flask_socketio import SocketIO, emit
def create_flask_app(s2c, c2s, s2flask, args):
app = Flask(__name__, template_folder="templates", static_folder="static")
model_path = Path(args.get('model_path', ''))
filename = 'preview.jpg'
preview_file = str(model_path / filename)
def gen():
frame = open(preview_file, 'rb').read()
while True:
try:
frame = open(preview_file, 'rb').read()
except:
pass
yield b'--frame\r\nContent-Type: image/jpeg\r\n\r\n'
yield frame
yield b'\r\n\r\n'
def send(queue, op):
queue.put({'op': op})
def send_and_wait(queue, op):
while not s2flask.empty():
s2flask.get()
queue.put({'op': op})
while s2flask.empty():
pass
s2flask.get()
@app.route('/save', methods=['POST'])
def save():
send(s2c, 'save')
return '', 204
@app.route('/exit', methods=['POST'])
def exit():
send(c2s, 'close')
request.environ.get('werkzeug.server.shutdown')()
return '', 204
@app.route('/update', methods=['POST'])
def update():
send(c2s, 'update')
return '', 204
@app.route('/next_preview', methods=['POST'])
def next_preview():
send(c2s, 'next_preview')
return '', 204
@app.route('/change_history_range', methods=['POST'])
def change_history_range():
send(c2s, 'change_history_range')
return '', 204
@app.route('/zoom_prev', methods=['POST'])
def zoom_prev():
send(c2s, 'zoom_prev')
return '', 204
@app.route('/zoom_next', methods=['POST'])
def zoom_next():
send(c2s, 'zoom_next')
return '', 204
@app.route('/')
def index():
return render_template('index.html')
# @app.route('/preview_image')
# def preview_image():
# return Response(gen(), mimetype='multipart/x-mixed-replace;boundary=frame')
@app.route('/preview_image')
def preview_image():
return send_file(preview_file, mimetype='image/jpeg', cache_timeout=-1)
socketio = SocketIO(app)
@socketio.on('connect', namespace='/')
def test_connect():
emit('my response', {'data': 'Connected'})
@socketio.on('disconnect', namespace='/test')
def test_disconnect():
print('Client disconnected')
return socketio, app

BIN
flaskr/static/favicon.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 284 KiB

View file

@ -0,0 +1,88 @@
<head>
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.4.1/jquery.min.js"
integrity="sha256-CSXorXvZcTkaix6Yvo6HppcZGetbYMGWSFlBw8HfCJo="
crossorigin="anonymous"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/2.2.0/socket.io.js"
integrity="sha256-yr4fRk/GU1ehYJPAs8P4JlTgu0Hdsp4ZKrx8bDEDC3I="
crossorigin="anonymous"></script>
<title>Training Preview</title>
<link rel="shortcut icon" href="{{ url_for('static', filename='favicon.ico') }}">
<script type="text/javascript">
$(function() {
function save() {
$.post("{{ url_for('save') }}");
}
function exit() {
$.post("{{ url_for('exit') }}");
}
function update() {
$.post("{{ url_for('update') }}");
}
function next_preview() {
$.post("{{ url_for('next_preview') }}");
}
function change_history_range() {
$.post("{{ url_for('change_history_range') }}");
}
function zoom_prev() {
$.post("{{ url_for('zoom_prev') }}");
}
function zoom_next() {
$.post("{{ url_for('zoom_next') }}");
}
$(document).keypress(function (event) {
switch (event.key) {
case "s" : save(); break;
case "Enter" : exit(); break;
case "p" : update(); break;
case " " : next_preview(); break;
case "l" : change_history_range(); break;
case "-" : zoom_prev(); break;
case "=" : zoom_next(); break;
}
// console.log('kp:', event);
});
$('button#save').click(save);
$('button#exit').click(exit);
$('button#update').click(update);
$('button#next_preview').click(next_preview);
$('button#change_history_range').click(change_history_range);
$('button#zoom_prev').click(zoom_prev);
$('button#zoom_next').click(zoom_next);
const socket = io.connect('http://' + document.domain + ':' + location.port);
socket.on('preview', function(msg) {
console.log(msg);
$('img#preview').attr("src", "{{ url_for('preview_image') }}?q=" + new Date().getTime());
});
socket.on('loss', function(loss_string) {
console.log(loss_string);
$('h2#loss').html(loss_string);
});
});
</script>
</head>
<body>
<h1>Training Preview</h1>
<h2 id="loss"></h2>
<div>
<button class='btn btn-default' id='save'>Save</button>
<button class='btn btn-default' id='exit'>Exit</button>
<button class='btn btn-default' id='update'>Update</button>
<button class='btn btn-default' id='next_preview'>Next preview</button>
<button class='btn btn-default' id='change_history_range'>Change History Range</button>
<button class='btn btn-default' id='zoom_prev'>Zoom -</button>
<button class='btn btn-default' id='zoom_next'>Zoom +</button>
</div>
<img id='preview' src="{{ url_for('preview_image') }}">
</body>
</html>

10
main.py
View file

@ -106,10 +106,10 @@ if __name__ == "__main__":
#if arguments.remove_fanseg:
# Util.remove_fanseg_folder (input_path=arguments.input_dir)
if arguments.remove_ie_polys:
Util.remove_ie_polys_folder (input_path=arguments.input_dir)
p = subparsers.add_parser( "util", help="Utilities.")
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.")
p.add_argument('--convert-png-to-jpg', action="store_true", dest="convert_png_to_jpg", default=False, help="Convert DeepFaceLAB PNG files to JPEG.")
@ -129,11 +129,13 @@ if __name__ == "__main__":
'model_name' : arguments.model_name,
'no_preview' : arguments.no_preview,
'debug' : arguments.debug,
'flask_preview' : arguments.flask_preview,
'execute_programs' : [ [int(x[0]), x[1] ] for x in arguments.execute_program ]
}
device_args = {'cpu_only' : arguments.cpu_only,
'force_gpu_idx' : arguments.force_gpu_idx,
}
from mainscripts import Trainer
Trainer.main(args, device_args)
@ -155,6 +157,8 @@ if __name__ == "__main__":
p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+')
p.add_argument('--pingpong', dest="ping_pong", default=False,
help="Cycle between a batch size of 1 and the chosen batch size")
p.add_argument('--flask-preview', action="store_true", dest="flask_preview", default=False,
help="Launches a flask server to view the previews in a web browser")
p.set_defaults (func=process_train)
def process_convert(arguments):
@ -251,7 +255,7 @@ if __name__ == "__main__":
p.add_argument('--confirmed-dir', required=True, action=fixPathAction, dest="confirmed_dir", help="This is where the labeled faces will be stored.")
p.add_argument('--skipped-dir', required=True, action=fixPathAction, dest="skipped_dir", help="This is where the labeled faces will be stored.")
p.add_argument('--no-default-mask', action="store_true", dest="no_default_mask", default=False, help="Don't use default mask.")
p.set_defaults(func=process_labelingtool_edit_mask)
def bad_args(arguments):

View file

@ -1,31 +1,31 @@
import sys
import traceback
import traceback
import queue
import threading
import time
from enum import Enum
import itertools
import numpy as np
import itertools
from pathlib import Path
from utils import Path_utils
import imagelib
import cv2
import models
from interact import interact as io
def trainerThread (s2c, c2s, e, args, device_args):
def trainer_thread(s2c, c2s, e, args, device_args, socketio=None):
while True:
try:
start_time = time.time()
training_data_src_path = Path( args.get('training_data_src_dir', '') )
training_data_dst_path = Path( args.get('training_data_dst_dir', '') )
training_data_src_path = Path(args.get('training_data_src_dir', ''))
training_data_dst_path = Path(args.get('training_data_dst_dir', ''))
pretraining_data_path = args.get('pretraining_data_dir', '')
pretraining_data_path = Path(pretraining_data_path) if pretraining_data_path is not None else None
model_path = Path( args.get('model_path', '') )
model_path = Path(args.get('model_path', ''))
model_name = args.get('model_name', '')
save_interval_min = 15
debug = args.get('debug', '')
@ -43,33 +43,34 @@ def trainerThread (s2c, c2s, e, args, device_args):
model_path.mkdir(exist_ok=True)
model = models.import_model(model_name)(
model_path,
training_data_src_path=training_data_src_path,
training_data_dst_path=training_data_dst_path,
pretraining_data_path=pretraining_data_path,
debug=debug,
device_args=device_args)
model_path,
training_data_src_path=training_data_src_path,
training_data_dst_path=training_data_dst_path,
pretraining_data_path=pretraining_data_path,
debug=debug,
device_args=device_args)
is_reached_goal = model.is_reached_iter_goal()
shared_state = { 'after_save' : False }
shared_state = {'after_save': False}
loss_string = ""
save_iter = model.get_iter()
save_iter = model.get_iter()
def model_save():
if not debug and not is_reached_goal:
io.log_info ("Saving....", end='\r')
io.log_info("Saving....", end='\r')
model.save()
shared_state['after_save'] = True
def send_preview():
if not debug:
previews = model.get_previews()
c2s.put ( {'op':'show', 'previews': previews, 'iter':model.get_iter(), 'loss_history': model.get_loss_history().copy() } )
c2s.put({'op': 'show', 'previews': previews, 'iter': model.get_iter(),
'loss_history': model.get_loss_history().copy()})
else:
previews = [( 'debug, press update for new', model.debug_one_iter())]
c2s.put ( {'op':'show', 'previews': previews} )
e.set() #Set the GUI Thread as Ready
previews = [('debug, press update for new', model.debug_one_iter())]
c2s.put({'op': 'show', 'previews': previews})
e.set() # Set the GUI Thread as Ready
if model.is_first_run():
model_save()
@ -78,25 +79,26 @@ def trainerThread (s2c, c2s, e, args, device_args):
if is_reached_goal:
io.log_info('Model already trained to target iteration. You can use preview.')
else:
io.log_info('Starting. Target iteration: %d. Press "Enter" to stop training and save model.' % ( model.get_target_iter() ) )
io.log_info('Starting. Target iteration: %d. Press "Enter" to stop training and save model.' % (
model.get_target_iter()))
else:
io.log_info('Starting. Press "Enter" to stop training and save model.')
last_save_time = time.time()
execute_programs = [ [x[0], x[1], time.time() ] for x in execute_programs ]
execute_programs = [[x[0], x[1], time.time()] for x in execute_programs]
for i in itertools.count(0,1):
for i in itertools.count(0, 1):
if not debug:
cur_time = time.time()
for x in execute_programs:
prog_time, prog, last_time = x
exec_prog = False
if prog_time > 0 and (cur_time - start_time) >= prog_time:
if 0 < prog_time <= (cur_time - start_time):
x[0] = 0
exec_prog = True
elif prog_time < 0 and (cur_time - last_time) >= -prog_time:
elif prog_time < 0 and (cur_time - last_time) >= -prog_time:
x[2] = cur_time
exec_prog = True
@ -104,7 +106,7 @@ def trainerThread (s2c, c2s, e, args, device_args):
try:
exec(prog)
except Exception as e:
print("Unable to execute program: %s" % (prog) )
print("Unable to execute program: %s" % (prog))
if not is_reached_goal:
iter, iter_time, batch_size = model.train_one_iter()
@ -112,42 +114,49 @@ def trainerThread (s2c, c2s, e, args, device_args):
loss_history = model.get_loss_history()
time_str = time.strftime("[%H:%M:%S]")
if iter_time >= 10:
loss_string = "{0}[#{1:06d}][{2:.5s}s][bs: {3}]".format ( time_str, iter, '{:0.4f}'.format(iter_time), batch_size )
loss_string = "{0}[#{1:06d}][{2:.5s}s][bs: {3}]".format(time_str, iter,
'{:0.4f}'.format(iter_time),
batch_size)
else:
loss_string = "{0}[#{1:06d}][{2:04d}ms][bs: {3}]".format ( time_str, iter, int(iter_time*1000), batch_size)
loss_string = "{0}[#{1:06d}][{2:04d}ms][bs: {3}]".format(time_str, iter,
int(iter_time * 1000), batch_size)
if shared_state['after_save']:
shared_state['after_save'] = False
last_save_time = time.time() #upd last_save_time only after save+one_iter, because plaidML rebuilds programs after save https://github.com/plaidml/plaidml/issues/274
last_save_time = time.time() # upd last_save_time only after save+one_iter, because
# plaidML rebuilds programs after save https://github.com/plaidml/plaidml/issues/274
mean_loss = np.mean ( [ np.array(loss_history[i]) for i in range(save_iter, iter) ], axis=0)
mean_loss = np.mean([np.array(loss_history[i]) for i in range(save_iter, iter)], axis=0)
for loss_value in mean_loss:
loss_string += "[%.4f]" % (loss_value)
loss_string += "[%.4f]" % loss_value
io.log_info (loss_string)
io.log_info(loss_string)
save_iter = iter
else:
for loss_value in loss_history[-1]:
loss_string += "[%.4f]" % (loss_value)
loss_string += "[%.4f]" % loss_value
if io.is_colab():
io.log_info ('\r' + loss_string, end='')
io.log_info('\r' + loss_string, end='')
else:
io.log_info (loss_string, end='\r')
io.log_info(loss_string, end='\r')
if socketio is not None:
socketio.emit('loss', loss_string)
if model.get_target_iter() != 0 and model.is_reached_iter_goal():
io.log_info ('Reached target iteration.')
io.log_info('Reached target iteration.')
model_save()
is_reached_goal = True
io.log_info ('You can use preview now.')
io.log_info('You can use preview now.')
if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min*60:
if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min * 60:
model_save()
send_preview()
if i==0:
if i == 0:
if is_reached_goal:
model.pass_one_iter()
send_preview()
@ -156,8 +165,8 @@ def trainerThread (s2c, c2s, e, args, device_args):
time.sleep(0.005)
while not s2c.empty():
input = s2c.get()
op = input['op']
item = s2c.get()
op = item['op']
if op == 'save':
model_save()
elif op == 'preview':
@ -172,32 +181,30 @@ def trainerThread (s2c, c2s, e, args, device_args):
if i == -1:
break
model.finalize()
except Exception as e:
print ('Error: %s' % (str(e)))
print('Error: %s' % (str(e)))
traceback.print_exc()
break
c2s.put ( {'op':'close'} )
c2s.put({'op': 'close'})
class Zoom(Enum):
ZOOM_25 = (1/4, '25%')
ZOOM_33 = (1/3, '33%')
ZOOM_50 = (1/2, '50%')
ZOOM_67 = (2/3, '67%')
ZOOM_75 = (3/4, '75%')
ZOOM_80 = (4/5, '80%')
ZOOM_90 = (9/10, '90%')
ZOOM_25 = (1 / 4, '25%')
ZOOM_33 = (1 / 3, '33%')
ZOOM_50 = (1 / 2, '50%')
ZOOM_67 = (2 / 3, '67%')
ZOOM_75 = (3 / 4, '75%')
ZOOM_80 = (4 / 5, '80%')
ZOOM_90 = (9 / 10, '90%')
ZOOM_100 = (1, '100%')
ZOOM_110 = (11/10, '110%')
ZOOM_125 = (5/4, '125%')
ZOOM_150 = (3/2, '150%')
ZOOM_175 = (7/4, '175%')
ZOOM_110 = (11 / 10, '110%')
ZOOM_125 = (5 / 4, '125%')
ZOOM_150 = (3 / 2, '150%')
ZOOM_175 = (7 / 4, '175%')
ZOOM_200 = (2, '200%')
ZOOM_250 = (5/2, '250%')
ZOOM_250 = (5 / 2, '250%')
ZOOM_300 = (3, '300%')
ZOOM_400 = (4, '400%')
ZOOM_500 = (5, '500%')
@ -254,7 +261,7 @@ def create_preview_pane_image(previews, selected_preview, loss_history,
head_lines = [
'[s]:save [enter]:exit [-/+]:zoom: %s' % zoom.label,
'[p]:update [space]:next preview [l]:change history range',
'Preview: "%s" [%d/%d]' % (selected_preview_name,selected_preview+1, len(previews))
'Preview: "%s" [%d/%d]' % (selected_preview_name, selected_preview + 1, len(previews))
]
head_line_height = int(15 * zoom.scale)
head_height = len(head_lines) * head_line_height
@ -262,8 +269,8 @@ def create_preview_pane_image(previews, selected_preview, loss_history,
for i in range(0, len(head_lines)):
t = i * head_line_height
b = (i+1) * head_line_height
head[t:b, 0:w] += imagelib.get_text_image((head_line_height, w, c), head_lines[i], color=[0.8]*c)
b = (i + 1) * head_line_height
head[t:b, 0:w] += imagelib.get_text_image((head_line_height, w, c), head_lines[i], color=[0.8] * c)
final = head
@ -274,69 +281,88 @@ def create_preview_pane_image(previews, selected_preview, loss_history,
loss_history_to_show = loss_history[-show_last_history_iters_count:]
lh_height = int(100 * zoom.scale)
lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, iteration, batch_size, w, c, lh_height)
final = np.concatenate ( [final, lh_img], axis=0 )
final = np.concatenate([final, lh_img], axis=0)
final = np.concatenate([final, selected_preview_rgb], axis=0)
final = np.clip(final, 0, 1)
return (final*255).astype(np.uint8)
return (final * 255).astype(np.uint8)
def main(args, device_args):
io.log_info ("Running trainer.\r\n")
io.log_info("Running trainer.\r\n")
no_preview = args.get('no_preview', False)
flask_preview = args.get('flask_preview', False)
s2c = queue.Queue()
c2s = queue.Queue()
e = threading.Event()
thread = threading.Thread(target=trainerThread, args=(s2c, c2s, e, args, device_args) )
thread.start()
e.wait() #Wait for inital load to occur.
previews = None
loss_history = None
selected_preview = 0
update_preview = False
is_waiting_preview = False
show_last_history_iters_count = 0
iteration = 0
batch_size = 1
zoom = Zoom.ZOOM_100
if no_preview:
while True:
if not c2s.empty():
input = c2s.get()
op = input.get('op','')
if op == 'close':
break
try:
io.process_messages(0.1)
except KeyboardInterrupt:
s2c.put ( {'op': 'close'} )
else:
wnd_name = "Training preview"
io.named_window(wnd_name)
io.capture_keys(wnd_name)
if flask_preview:
from flaskr.app import create_flask_app
s2flask = queue.Queue()
socketio, flask_app = create_flask_app(s2c, c2s, s2flask, args)
previews = None
loss_history = None
selected_preview = 0
update_preview = False
is_showing = False
is_waiting_preview = False
show_last_history_iters_count = 0
iteration = 0
batch_size = 1
zoom = Zoom.ZOOM_100
e = threading.Event()
thread = threading.Thread(target=trainer_thread, args=(s2c, c2s, e, args, device_args, socketio))
thread.start()
e.wait() # Wait for inital load to occur.
flask_t = threading.Thread(target=socketio.run, args=(flask_app,),
kwargs={'debug': True, 'use_reloader': False})
flask_t.start()
while True:
if not c2s.empty():
input = c2s.get()
op = input['op']
item = c2s.get()
op = item['op']
if op == 'show':
is_waiting_preview = False
loss_history = input['loss_history'] if 'loss_history' in input.keys() else None
previews = input['previews'] if 'previews' in input.keys() else None
iteration = input['iter'] if 'iter' in input.keys() else 0
#batch_size = input['batch_size'] if 'iter' in input.keys() else 1
loss_history = item['loss_history'] if 'loss_history' in item.keys() else None
previews = item['previews'] if 'previews' in item.keys() else None
iteration = item['iter'] if 'iter' in item.keys() else 0
# batch_size = input['batch_size'] if 'iter' in input.keys() else 1
if previews is not None:
update_preview = True
elif op == 'update':
if not is_waiting_preview:
is_waiting_preview = True
s2c.put({'op': 'preview'})
elif op == 'next_preview':
selected_preview = (selected_preview + 1) % len(previews)
update_preview = True
elif op == 'change_history_range':
if show_last_history_iters_count == 0:
show_last_history_iters_count = 5000
elif show_last_history_iters_count == 5000:
show_last_history_iters_count = 10000
elif show_last_history_iters_count == 10000:
show_last_history_iters_count = 50000
elif show_last_history_iters_count == 50000:
show_last_history_iters_count = 100000
elif show_last_history_iters_count == 100000:
show_last_history_iters_count = 0
update_preview = True
elif op == 'close':
s2c.put({'op': 'close'})
break
elif op == 'zoom_prev':
zoom = zoom.prev()
update_preview = True
elif op == 'zoom_next':
zoom = zoom.next()
update_preview = True
if update_preview:
update_preview = False
@ -348,44 +374,113 @@ def main(args, device_args):
iteration,
batch_size,
zoom)
io.show_image(wnd_name, preview_pane_image)
is_showing = True
key_events = io.get_key_events(wnd_name)
key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (0,0,False,False,False)
if key == ord('\n') or key == ord('\r'):
s2c.put ( {'op': 'close'} )
elif key == ord('s'):
s2c.put ( {'op': 'save'} )
elif key == ord('p'):
if not is_waiting_preview:
is_waiting_preview = True
s2c.put ( {'op': 'preview'} )
elif key == ord('l'):
if show_last_history_iters_count == 0:
show_last_history_iters_count = 5000
elif show_last_history_iters_count == 5000:
show_last_history_iters_count = 10000
elif show_last_history_iters_count == 10000:
show_last_history_iters_count = 50000
elif show_last_history_iters_count == 50000:
show_last_history_iters_count = 100000
elif show_last_history_iters_count == 100000:
show_last_history_iters_count = 0
update_preview = True
elif key == ord(' '):
selected_preview = (selected_preview + 1) % len(previews)
update_preview = True
elif key == ord('-'):
zoom = zoom.prev()
update_preview = True
elif key == ord('=') or key == ord('+'):
zoom = zoom.next()
update_preview = True
# io.show_image(wnd_name, preview_pane_image)
model_path = Path(args.get('model_path', ''))
filename = 'preview.jpg'
preview_file = str(model_path / filename)
cv2.imwrite(preview_file, preview_pane_image)
s2flask.put({'op': 'show'})
socketio.emit('preview', {'iter': iteration, 'loss': loss_history[-1]})
try:
io.process_messages(0.1)
io.process_messages(0.01)
except KeyboardInterrupt:
s2c.put ( {'op': 'close'} )
s2c.put({'op': 'close'})
else:
thread = threading.Thread(target=trainer_thread, args=(s2c, c2s, e, args, device_args))
thread.start()
io.destroy_all_windows()
e.wait() # Wait for inital load to occur.
if no_preview:
while True:
if not c2s.empty():
item = c2s.get()
op = item.get('op', '')
if op == 'close':
break
try:
io.process_messages(0.1)
except KeyboardInterrupt:
s2c.put({'op': 'close'})
else:
wnd_name = "Training preview"
io.named_window(wnd_name)
io.capture_keys(wnd_name)
previews = None
loss_history = None
selected_preview = 0
update_preview = False
is_showing = False
is_waiting_preview = False
show_last_history_iters_count = 0
iteration = 0
batch_size = 1
zoom = Zoom.ZOOM_100
while True:
if not c2s.empty():
item = c2s.get()
op = item['op']
if op == 'show':
is_waiting_preview = False
loss_history = item['loss_history'] if 'loss_history' in item.keys() else None
previews = item['previews'] if 'previews' in item.keys() else None
iteration = item['iter'] if 'iter' in item.keys() else 0
# batch_size = input['batch_size'] if 'iter' in input.keys() else 1
if previews is not None:
update_preview = True
elif op == 'close':
break
if update_preview:
update_preview = False
selected_preview = selected_preview % len(previews)
preview_pane_image = create_preview_pane_image(previews,
selected_preview,
loss_history,
show_last_history_iters_count,
iteration,
batch_size,
zoom)
io.show_image(wnd_name, preview_pane_image)
key_events = io.get_key_events(wnd_name)
key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (
0, 0, False, False, False)
if key == ord('\n') or key == ord('\r'):
s2c.put({'op': 'close'})
elif key == ord('s'):
s2c.put({'op': 'save'})
elif key == ord('p'):
if not is_waiting_preview:
is_waiting_preview = True
s2c.put({'op': 'preview'})
elif key == ord('l'):
if show_last_history_iters_count == 0:
show_last_history_iters_count = 5000
elif show_last_history_iters_count == 5000:
show_last_history_iters_count = 10000
elif show_last_history_iters_count == 10000:
show_last_history_iters_count = 50000
elif show_last_history_iters_count == 50000:
show_last_history_iters_count = 100000
elif show_last_history_iters_count == 100000:
show_last_history_iters_count = 0
update_preview = True
elif key == ord(' '):
selected_preview = (selected_preview + 1) % len(previews)
update_preview = True
elif key == ord('-'):
zoom = zoom.prev()
update_preview = True
elif key == ord('=') or key == ord('+'):
zoom = zoom.next()
update_preview = True
try:
io.process_messages(0.1)
except KeyboardInterrupt:
s2c.put({'op': 'close'})
io.destroy_all_windows()

View file

@ -7,4 +7,6 @@ plaidml-keras==0.5.0
scikit-image
tqdm
ffmpeg-python==0.1.17
git+https://www.github.com/keras-team/keras-contrib.git
git+https://www.github.com/keras-team/keras-contrib.git
Flask==1.1.1
flask-socketio==4.2.1

View file

@ -7,3 +7,5 @@ scikit-image
tqdm
ffmpeg-python==0.1.17
git+https://www.github.com/keras-team/keras-contrib.git
Flask==1.1.1
flask-socketio==4.2.1

View file

@ -9,3 +9,5 @@ scikit-image
tqdm
ffmpeg-python==0.1.17
git+https://www.github.com/keras-team/keras-contrib.git
Flask==1.1.1
flask-socketio==4.2.1

View file

@ -9,3 +9,5 @@ scikit-image
tqdm
ffmpeg-python==0.1.17
git+https://www.github.com/keras-team/keras-contrib.git
Flask==1.1.1
flask-socketio==4.2.1

View file

@ -0,0 +1,7 @@
@echo off
call ..\setenv.bat
python -m pip install Flask==1.1.1
python -m pip install flask-socketio==4.2.1
pause