mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-22 14:24:40 -07:00
Merge pull request #71 from faceshiftlabs/experiment/flask-preview
Experiment/flask preview
This commit is contained in:
commit
0f28352d57
11 changed files with 444 additions and 143 deletions
0
flaskr/__init__.py
Normal file
0
flaskr/__init__.py
Normal file
99
flaskr/app.py
Normal file
99
flaskr/app.py
Normal file
|
@ -0,0 +1,99 @@
|
|||
from pathlib import Path
|
||||
|
||||
from flask import Flask, send_file, Response, render_template, render_template_string, request, g
|
||||
from flask_socketio import SocketIO, emit
|
||||
|
||||
|
||||
def create_flask_app(s2c, c2s, s2flask, args):
|
||||
app = Flask(__name__, template_folder="templates", static_folder="static")
|
||||
model_path = Path(args.get('model_path', ''))
|
||||
filename = 'preview.jpg'
|
||||
preview_file = str(model_path / filename)
|
||||
|
||||
def gen():
|
||||
frame = open(preview_file, 'rb').read()
|
||||
while True:
|
||||
try:
|
||||
frame = open(preview_file, 'rb').read()
|
||||
except:
|
||||
pass
|
||||
yield b'--frame\r\nContent-Type: image/jpeg\r\n\r\n'
|
||||
yield frame
|
||||
yield b'\r\n\r\n'
|
||||
|
||||
def send(queue, op):
|
||||
queue.put({'op': op})
|
||||
|
||||
def send_and_wait(queue, op):
|
||||
while not s2flask.empty():
|
||||
s2flask.get()
|
||||
queue.put({'op': op})
|
||||
while s2flask.empty():
|
||||
pass
|
||||
s2flask.get()
|
||||
|
||||
@app.route('/save', methods=['POST'])
|
||||
def save():
|
||||
send(s2c, 'save')
|
||||
return '', 204
|
||||
|
||||
@app.route('/exit', methods=['POST'])
|
||||
def exit():
|
||||
send(c2s, 'close')
|
||||
request.environ.get('werkzeug.server.shutdown')()
|
||||
return '', 204
|
||||
|
||||
@app.route('/update', methods=['POST'])
|
||||
def update():
|
||||
send(c2s, 'update')
|
||||
return '', 204
|
||||
|
||||
@app.route('/next_preview', methods=['POST'])
|
||||
def next_preview():
|
||||
send(c2s, 'next_preview')
|
||||
return '', 204
|
||||
|
||||
@app.route('/change_history_range', methods=['POST'])
|
||||
def change_history_range():
|
||||
send(c2s, 'change_history_range')
|
||||
return '', 204
|
||||
|
||||
@app.route('/zoom_prev', methods=['POST'])
|
||||
def zoom_prev():
|
||||
send(c2s, 'zoom_prev')
|
||||
return '', 204
|
||||
|
||||
@app.route('/zoom_next', methods=['POST'])
|
||||
def zoom_next():
|
||||
send(c2s, 'zoom_next')
|
||||
return '', 204
|
||||
|
||||
@app.route('/')
|
||||
def index():
|
||||
return render_template('index.html')
|
||||
|
||||
# @app.route('/preview_image')
|
||||
# def preview_image():
|
||||
# return Response(gen(), mimetype='multipart/x-mixed-replace;boundary=frame')
|
||||
|
||||
@app.route('/preview_image')
|
||||
def preview_image():
|
||||
return send_file(preview_file, mimetype='image/jpeg', cache_timeout=-1)
|
||||
|
||||
socketio = SocketIO(app)
|
||||
|
||||
@socketio.on('connect', namespace='/')
|
||||
def test_connect():
|
||||
emit('my response', {'data': 'Connected'})
|
||||
|
||||
@socketio.on('disconnect', namespace='/test')
|
||||
def test_disconnect():
|
||||
print('Client disconnected')
|
||||
|
||||
return socketio, app
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
BIN
flaskr/static/favicon.ico
Normal file
BIN
flaskr/static/favicon.ico
Normal file
Binary file not shown.
After Width: | Height: | Size: 284 KiB |
88
flaskr/templates/index.html
Normal file
88
flaskr/templates/index.html
Normal file
|
@ -0,0 +1,88 @@
|
|||
<head>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.4.1/jquery.min.js"
|
||||
integrity="sha256-CSXorXvZcTkaix6Yvo6HppcZGetbYMGWSFlBw8HfCJo="
|
||||
crossorigin="anonymous"></script>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/2.2.0/socket.io.js"
|
||||
integrity="sha256-yr4fRk/GU1ehYJPAs8P4JlTgu0Hdsp4ZKrx8bDEDC3I="
|
||||
crossorigin="anonymous"></script>
|
||||
<title>Training Preview</title>
|
||||
<link rel="shortcut icon" href="{{ url_for('static', filename='favicon.ico') }}">
|
||||
<script type="text/javascript">
|
||||
$(function() {
|
||||
function save() {
|
||||
$.post("{{ url_for('save') }}");
|
||||
}
|
||||
|
||||
function exit() {
|
||||
$.post("{{ url_for('exit') }}");
|
||||
}
|
||||
|
||||
function update() {
|
||||
$.post("{{ url_for('update') }}");
|
||||
}
|
||||
|
||||
function next_preview() {
|
||||
$.post("{{ url_for('next_preview') }}");
|
||||
}
|
||||
|
||||
function change_history_range() {
|
||||
$.post("{{ url_for('change_history_range') }}");
|
||||
}
|
||||
|
||||
function zoom_prev() {
|
||||
$.post("{{ url_for('zoom_prev') }}");
|
||||
}
|
||||
|
||||
function zoom_next() {
|
||||
$.post("{{ url_for('zoom_next') }}");
|
||||
}
|
||||
|
||||
$(document).keypress(function (event) {
|
||||
switch (event.key) {
|
||||
case "s" : save(); break;
|
||||
case "Enter" : exit(); break;
|
||||
case "p" : update(); break;
|
||||
case " " : next_preview(); break;
|
||||
case "l" : change_history_range(); break;
|
||||
case "-" : zoom_prev(); break;
|
||||
case "=" : zoom_next(); break;
|
||||
}
|
||||
// console.log('kp:', event);
|
||||
});
|
||||
|
||||
$('button#save').click(save);
|
||||
$('button#exit').click(exit);
|
||||
$('button#update').click(update);
|
||||
$('button#next_preview').click(next_preview);
|
||||
$('button#change_history_range').click(change_history_range);
|
||||
$('button#zoom_prev').click(zoom_prev);
|
||||
$('button#zoom_next').click(zoom_next);
|
||||
|
||||
const socket = io.connect('http://' + document.domain + ':' + location.port);
|
||||
socket.on('preview', function(msg) {
|
||||
console.log(msg);
|
||||
$('img#preview').attr("src", "{{ url_for('preview_image') }}?q=" + new Date().getTime());
|
||||
});
|
||||
|
||||
socket.on('loss', function(loss_string) {
|
||||
console.log(loss_string);
|
||||
$('h2#loss').html(loss_string);
|
||||
});
|
||||
});
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Training Preview</h1>
|
||||
<h2 id="loss"></h2>
|
||||
<div>
|
||||
<button class='btn btn-default' id='save'>Save</button>
|
||||
<button class='btn btn-default' id='exit'>Exit</button>
|
||||
<button class='btn btn-default' id='update'>Update</button>
|
||||
<button class='btn btn-default' id='next_preview'>Next preview</button>
|
||||
<button class='btn btn-default' id='change_history_range'>Change History Range</button>
|
||||
<button class='btn btn-default' id='zoom_prev'>Zoom -</button>
|
||||
<button class='btn btn-default' id='zoom_next'>Zoom +</button>
|
||||
</div>
|
||||
<img id='preview' src="{{ url_for('preview_image') }}">
|
||||
</body>
|
||||
</html>
|
10
main.py
10
main.py
|
@ -106,10 +106,10 @@ if __name__ == "__main__":
|
|||
|
||||
#if arguments.remove_fanseg:
|
||||
# Util.remove_fanseg_folder (input_path=arguments.input_dir)
|
||||
|
||||
|
||||
if arguments.remove_ie_polys:
|
||||
Util.remove_ie_polys_folder (input_path=arguments.input_dir)
|
||||
|
||||
|
||||
p = subparsers.add_parser( "util", help="Utilities.")
|
||||
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.")
|
||||
p.add_argument('--convert-png-to-jpg', action="store_true", dest="convert_png_to_jpg", default=False, help="Convert DeepFaceLAB PNG files to JPEG.")
|
||||
|
@ -129,11 +129,13 @@ if __name__ == "__main__":
|
|||
'model_name' : arguments.model_name,
|
||||
'no_preview' : arguments.no_preview,
|
||||
'debug' : arguments.debug,
|
||||
'flask_preview' : arguments.flask_preview,
|
||||
'execute_programs' : [ [int(x[0]), x[1] ] for x in arguments.execute_program ]
|
||||
}
|
||||
device_args = {'cpu_only' : arguments.cpu_only,
|
||||
'force_gpu_idx' : arguments.force_gpu_idx,
|
||||
}
|
||||
|
||||
from mainscripts import Trainer
|
||||
Trainer.main(args, device_args)
|
||||
|
||||
|
@ -155,6 +157,8 @@ if __name__ == "__main__":
|
|||
p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+')
|
||||
p.add_argument('--pingpong', dest="ping_pong", default=False,
|
||||
help="Cycle between a batch size of 1 and the chosen batch size")
|
||||
p.add_argument('--flask-preview', action="store_true", dest="flask_preview", default=False,
|
||||
help="Launches a flask server to view the previews in a web browser")
|
||||
p.set_defaults (func=process_train)
|
||||
|
||||
def process_convert(arguments):
|
||||
|
@ -251,7 +255,7 @@ if __name__ == "__main__":
|
|||
p.add_argument('--confirmed-dir', required=True, action=fixPathAction, dest="confirmed_dir", help="This is where the labeled faces will be stored.")
|
||||
p.add_argument('--skipped-dir', required=True, action=fixPathAction, dest="skipped_dir", help="This is where the labeled faces will be stored.")
|
||||
p.add_argument('--no-default-mask', action="store_true", dest="no_default_mask", default=False, help="Don't use default mask.")
|
||||
|
||||
|
||||
p.set_defaults(func=process_labelingtool_edit_mask)
|
||||
|
||||
def bad_args(arguments):
|
||||
|
|
|
@ -1,31 +1,31 @@
|
|||
import sys
|
||||
import traceback
|
||||
import traceback
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from enum import Enum
|
||||
import itertools
|
||||
|
||||
import numpy as np
|
||||
import itertools
|
||||
|
||||
from pathlib import Path
|
||||
from utils import Path_utils
|
||||
import imagelib
|
||||
import cv2
|
||||
import models
|
||||
from interact import interact as io
|
||||
|
||||
def trainerThread (s2c, c2s, e, args, device_args):
|
||||
|
||||
def trainer_thread(s2c, c2s, e, args, device_args, socketio=None):
|
||||
while True:
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
training_data_src_path = Path( args.get('training_data_src_dir', '') )
|
||||
training_data_dst_path = Path( args.get('training_data_dst_dir', '') )
|
||||
training_data_src_path = Path(args.get('training_data_src_dir', ''))
|
||||
training_data_dst_path = Path(args.get('training_data_dst_dir', ''))
|
||||
|
||||
pretraining_data_path = args.get('pretraining_data_dir', '')
|
||||
pretraining_data_path = Path(pretraining_data_path) if pretraining_data_path is not None else None
|
||||
|
||||
model_path = Path( args.get('model_path', '') )
|
||||
model_path = Path(args.get('model_path', ''))
|
||||
model_name = args.get('model_name', '')
|
||||
save_interval_min = 15
|
||||
debug = args.get('debug', '')
|
||||
|
@ -43,33 +43,34 @@ def trainerThread (s2c, c2s, e, args, device_args):
|
|||
model_path.mkdir(exist_ok=True)
|
||||
|
||||
model = models.import_model(model_name)(
|
||||
model_path,
|
||||
training_data_src_path=training_data_src_path,
|
||||
training_data_dst_path=training_data_dst_path,
|
||||
pretraining_data_path=pretraining_data_path,
|
||||
debug=debug,
|
||||
device_args=device_args)
|
||||
model_path,
|
||||
training_data_src_path=training_data_src_path,
|
||||
training_data_dst_path=training_data_dst_path,
|
||||
pretraining_data_path=pretraining_data_path,
|
||||
debug=debug,
|
||||
device_args=device_args)
|
||||
|
||||
is_reached_goal = model.is_reached_iter_goal()
|
||||
|
||||
shared_state = { 'after_save' : False }
|
||||
shared_state = {'after_save': False}
|
||||
loss_string = ""
|
||||
save_iter = model.get_iter()
|
||||
save_iter = model.get_iter()
|
||||
|
||||
def model_save():
|
||||
if not debug and not is_reached_goal:
|
||||
io.log_info ("Saving....", end='\r')
|
||||
io.log_info("Saving....", end='\r')
|
||||
model.save()
|
||||
shared_state['after_save'] = True
|
||||
|
||||
def send_preview():
|
||||
if not debug:
|
||||
previews = model.get_previews()
|
||||
c2s.put ( {'op':'show', 'previews': previews, 'iter':model.get_iter(), 'loss_history': model.get_loss_history().copy() } )
|
||||
c2s.put({'op': 'show', 'previews': previews, 'iter': model.get_iter(),
|
||||
'loss_history': model.get_loss_history().copy()})
|
||||
else:
|
||||
previews = [( 'debug, press update for new', model.debug_one_iter())]
|
||||
c2s.put ( {'op':'show', 'previews': previews} )
|
||||
e.set() #Set the GUI Thread as Ready
|
||||
|
||||
previews = [('debug, press update for new', model.debug_one_iter())]
|
||||
c2s.put({'op': 'show', 'previews': previews})
|
||||
e.set() # Set the GUI Thread as Ready
|
||||
|
||||
if model.is_first_run():
|
||||
model_save()
|
||||
|
@ -78,25 +79,26 @@ def trainerThread (s2c, c2s, e, args, device_args):
|
|||
if is_reached_goal:
|
||||
io.log_info('Model already trained to target iteration. You can use preview.')
|
||||
else:
|
||||
io.log_info('Starting. Target iteration: %d. Press "Enter" to stop training and save model.' % ( model.get_target_iter() ) )
|
||||
io.log_info('Starting. Target iteration: %d. Press "Enter" to stop training and save model.' % (
|
||||
model.get_target_iter()))
|
||||
else:
|
||||
io.log_info('Starting. Press "Enter" to stop training and save model.')
|
||||
|
||||
last_save_time = time.time()
|
||||
|
||||
execute_programs = [ [x[0], x[1], time.time() ] for x in execute_programs ]
|
||||
execute_programs = [[x[0], x[1], time.time()] for x in execute_programs]
|
||||
|
||||
for i in itertools.count(0,1):
|
||||
for i in itertools.count(0, 1):
|
||||
if not debug:
|
||||
cur_time = time.time()
|
||||
|
||||
for x in execute_programs:
|
||||
prog_time, prog, last_time = x
|
||||
exec_prog = False
|
||||
if prog_time > 0 and (cur_time - start_time) >= prog_time:
|
||||
if 0 < prog_time <= (cur_time - start_time):
|
||||
x[0] = 0
|
||||
exec_prog = True
|
||||
elif prog_time < 0 and (cur_time - last_time) >= -prog_time:
|
||||
elif prog_time < 0 and (cur_time - last_time) >= -prog_time:
|
||||
x[2] = cur_time
|
||||
exec_prog = True
|
||||
|
||||
|
@ -104,7 +106,7 @@ def trainerThread (s2c, c2s, e, args, device_args):
|
|||
try:
|
||||
exec(prog)
|
||||
except Exception as e:
|
||||
print("Unable to execute program: %s" % (prog) )
|
||||
print("Unable to execute program: %s" % (prog))
|
||||
|
||||
if not is_reached_goal:
|
||||
iter, iter_time, batch_size = model.train_one_iter()
|
||||
|
@ -112,42 +114,49 @@ def trainerThread (s2c, c2s, e, args, device_args):
|
|||
loss_history = model.get_loss_history()
|
||||
time_str = time.strftime("[%H:%M:%S]")
|
||||
if iter_time >= 10:
|
||||
loss_string = "{0}[#{1:06d}][{2:.5s}s][bs: {3}]".format ( time_str, iter, '{:0.4f}'.format(iter_time), batch_size )
|
||||
loss_string = "{0}[#{1:06d}][{2:.5s}s][bs: {3}]".format(time_str, iter,
|
||||
'{:0.4f}'.format(iter_time),
|
||||
batch_size)
|
||||
else:
|
||||
loss_string = "{0}[#{1:06d}][{2:04d}ms][bs: {3}]".format ( time_str, iter, int(iter_time*1000), batch_size)
|
||||
loss_string = "{0}[#{1:06d}][{2:04d}ms][bs: {3}]".format(time_str, iter,
|
||||
int(iter_time * 1000), batch_size)
|
||||
|
||||
if shared_state['after_save']:
|
||||
shared_state['after_save'] = False
|
||||
last_save_time = time.time() #upd last_save_time only after save+one_iter, because plaidML rebuilds programs after save https://github.com/plaidml/plaidml/issues/274
|
||||
last_save_time = time.time() # upd last_save_time only after save+one_iter, because
|
||||
# plaidML rebuilds programs after save https://github.com/plaidml/plaidml/issues/274
|
||||
|
||||
mean_loss = np.mean ( [ np.array(loss_history[i]) for i in range(save_iter, iter) ], axis=0)
|
||||
mean_loss = np.mean([np.array(loss_history[i]) for i in range(save_iter, iter)], axis=0)
|
||||
|
||||
for loss_value in mean_loss:
|
||||
loss_string += "[%.4f]" % (loss_value)
|
||||
loss_string += "[%.4f]" % loss_value
|
||||
|
||||
io.log_info (loss_string)
|
||||
io.log_info(loss_string)
|
||||
|
||||
save_iter = iter
|
||||
else:
|
||||
for loss_value in loss_history[-1]:
|
||||
loss_string += "[%.4f]" % (loss_value)
|
||||
loss_string += "[%.4f]" % loss_value
|
||||
|
||||
if io.is_colab():
|
||||
io.log_info ('\r' + loss_string, end='')
|
||||
io.log_info('\r' + loss_string, end='')
|
||||
else:
|
||||
io.log_info (loss_string, end='\r')
|
||||
io.log_info(loss_string, end='\r')
|
||||
|
||||
if socketio is not None:
|
||||
socketio.emit('loss', loss_string)
|
||||
|
||||
if model.get_target_iter() != 0 and model.is_reached_iter_goal():
|
||||
io.log_info ('Reached target iteration.')
|
||||
io.log_info('Reached target iteration.')
|
||||
model_save()
|
||||
is_reached_goal = True
|
||||
io.log_info ('You can use preview now.')
|
||||
io.log_info('You can use preview now.')
|
||||
|
||||
if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min*60:
|
||||
if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min * 60:
|
||||
model_save()
|
||||
send_preview()
|
||||
|
||||
if i==0:
|
||||
if i == 0:
|
||||
if is_reached_goal:
|
||||
model.pass_one_iter()
|
||||
send_preview()
|
||||
|
@ -156,8 +165,8 @@ def trainerThread (s2c, c2s, e, args, device_args):
|
|||
time.sleep(0.005)
|
||||
|
||||
while not s2c.empty():
|
||||
input = s2c.get()
|
||||
op = input['op']
|
||||
item = s2c.get()
|
||||
op = item['op']
|
||||
if op == 'save':
|
||||
model_save()
|
||||
elif op == 'preview':
|
||||
|
@ -172,32 +181,30 @@ def trainerThread (s2c, c2s, e, args, device_args):
|
|||
if i == -1:
|
||||
break
|
||||
|
||||
|
||||
|
||||
model.finalize()
|
||||
|
||||
except Exception as e:
|
||||
print ('Error: %s' % (str(e)))
|
||||
print('Error: %s' % (str(e)))
|
||||
traceback.print_exc()
|
||||
break
|
||||
c2s.put ( {'op':'close'} )
|
||||
c2s.put({'op': 'close'})
|
||||
|
||||
|
||||
class Zoom(Enum):
|
||||
ZOOM_25 = (1/4, '25%')
|
||||
ZOOM_33 = (1/3, '33%')
|
||||
ZOOM_50 = (1/2, '50%')
|
||||
ZOOM_67 = (2/3, '67%')
|
||||
ZOOM_75 = (3/4, '75%')
|
||||
ZOOM_80 = (4/5, '80%')
|
||||
ZOOM_90 = (9/10, '90%')
|
||||
ZOOM_25 = (1 / 4, '25%')
|
||||
ZOOM_33 = (1 / 3, '33%')
|
||||
ZOOM_50 = (1 / 2, '50%')
|
||||
ZOOM_67 = (2 / 3, '67%')
|
||||
ZOOM_75 = (3 / 4, '75%')
|
||||
ZOOM_80 = (4 / 5, '80%')
|
||||
ZOOM_90 = (9 / 10, '90%')
|
||||
ZOOM_100 = (1, '100%')
|
||||
ZOOM_110 = (11/10, '110%')
|
||||
ZOOM_125 = (5/4, '125%')
|
||||
ZOOM_150 = (3/2, '150%')
|
||||
ZOOM_175 = (7/4, '175%')
|
||||
ZOOM_110 = (11 / 10, '110%')
|
||||
ZOOM_125 = (5 / 4, '125%')
|
||||
ZOOM_150 = (3 / 2, '150%')
|
||||
ZOOM_175 = (7 / 4, '175%')
|
||||
ZOOM_200 = (2, '200%')
|
||||
ZOOM_250 = (5/2, '250%')
|
||||
ZOOM_250 = (5 / 2, '250%')
|
||||
ZOOM_300 = (3, '300%')
|
||||
ZOOM_400 = (4, '400%')
|
||||
ZOOM_500 = (5, '500%')
|
||||
|
@ -254,7 +261,7 @@ def create_preview_pane_image(previews, selected_preview, loss_history,
|
|||
head_lines = [
|
||||
'[s]:save [enter]:exit [-/+]:zoom: %s' % zoom.label,
|
||||
'[p]:update [space]:next preview [l]:change history range',
|
||||
'Preview: "%s" [%d/%d]' % (selected_preview_name,selected_preview+1, len(previews))
|
||||
'Preview: "%s" [%d/%d]' % (selected_preview_name, selected_preview + 1, len(previews))
|
||||
]
|
||||
head_line_height = int(15 * zoom.scale)
|
||||
head_height = len(head_lines) * head_line_height
|
||||
|
@ -262,8 +269,8 @@ def create_preview_pane_image(previews, selected_preview, loss_history,
|
|||
|
||||
for i in range(0, len(head_lines)):
|
||||
t = i * head_line_height
|
||||
b = (i+1) * head_line_height
|
||||
head[t:b, 0:w] += imagelib.get_text_image((head_line_height, w, c), head_lines[i], color=[0.8]*c)
|
||||
b = (i + 1) * head_line_height
|
||||
head[t:b, 0:w] += imagelib.get_text_image((head_line_height, w, c), head_lines[i], color=[0.8] * c)
|
||||
|
||||
final = head
|
||||
|
||||
|
@ -274,69 +281,88 @@ def create_preview_pane_image(previews, selected_preview, loss_history,
|
|||
loss_history_to_show = loss_history[-show_last_history_iters_count:]
|
||||
lh_height = int(100 * zoom.scale)
|
||||
lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, iteration, batch_size, w, c, lh_height)
|
||||
final = np.concatenate ( [final, lh_img], axis=0 )
|
||||
final = np.concatenate([final, lh_img], axis=0)
|
||||
|
||||
final = np.concatenate([final, selected_preview_rgb], axis=0)
|
||||
final = np.clip(final, 0, 1)
|
||||
return (final*255).astype(np.uint8)
|
||||
return (final * 255).astype(np.uint8)
|
||||
|
||||
|
||||
def main(args, device_args):
|
||||
io.log_info ("Running trainer.\r\n")
|
||||
io.log_info("Running trainer.\r\n")
|
||||
|
||||
no_preview = args.get('no_preview', False)
|
||||
|
||||
flask_preview = args.get('flask_preview', False)
|
||||
|
||||
s2c = queue.Queue()
|
||||
c2s = queue.Queue()
|
||||
|
||||
e = threading.Event()
|
||||
thread = threading.Thread(target=trainerThread, args=(s2c, c2s, e, args, device_args) )
|
||||
thread.start()
|
||||
|
||||
e.wait() #Wait for inital load to occur.
|
||||
previews = None
|
||||
loss_history = None
|
||||
selected_preview = 0
|
||||
update_preview = False
|
||||
is_waiting_preview = False
|
||||
show_last_history_iters_count = 0
|
||||
iteration = 0
|
||||
batch_size = 1
|
||||
zoom = Zoom.ZOOM_100
|
||||
|
||||
if no_preview:
|
||||
while True:
|
||||
if not c2s.empty():
|
||||
input = c2s.get()
|
||||
op = input.get('op','')
|
||||
if op == 'close':
|
||||
break
|
||||
try:
|
||||
io.process_messages(0.1)
|
||||
except KeyboardInterrupt:
|
||||
s2c.put ( {'op': 'close'} )
|
||||
else:
|
||||
wnd_name = "Training preview"
|
||||
io.named_window(wnd_name)
|
||||
io.capture_keys(wnd_name)
|
||||
if flask_preview:
|
||||
from flaskr.app import create_flask_app
|
||||
s2flask = queue.Queue()
|
||||
socketio, flask_app = create_flask_app(s2c, c2s, s2flask, args)
|
||||
|
||||
previews = None
|
||||
loss_history = None
|
||||
selected_preview = 0
|
||||
update_preview = False
|
||||
is_showing = False
|
||||
is_waiting_preview = False
|
||||
show_last_history_iters_count = 0
|
||||
iteration = 0
|
||||
batch_size = 1
|
||||
zoom = Zoom.ZOOM_100
|
||||
e = threading.Event()
|
||||
thread = threading.Thread(target=trainer_thread, args=(s2c, c2s, e, args, device_args, socketio))
|
||||
thread.start()
|
||||
|
||||
e.wait() # Wait for inital load to occur.
|
||||
|
||||
flask_t = threading.Thread(target=socketio.run, args=(flask_app,),
|
||||
kwargs={'debug': True, 'use_reloader': False})
|
||||
flask_t.start()
|
||||
|
||||
while True:
|
||||
if not c2s.empty():
|
||||
input = c2s.get()
|
||||
op = input['op']
|
||||
item = c2s.get()
|
||||
op = item['op']
|
||||
if op == 'show':
|
||||
is_waiting_preview = False
|
||||
loss_history = input['loss_history'] if 'loss_history' in input.keys() else None
|
||||
previews = input['previews'] if 'previews' in input.keys() else None
|
||||
iteration = input['iter'] if 'iter' in input.keys() else 0
|
||||
#batch_size = input['batch_size'] if 'iter' in input.keys() else 1
|
||||
loss_history = item['loss_history'] if 'loss_history' in item.keys() else None
|
||||
previews = item['previews'] if 'previews' in item.keys() else None
|
||||
iteration = item['iter'] if 'iter' in item.keys() else 0
|
||||
# batch_size = input['batch_size'] if 'iter' in input.keys() else 1
|
||||
if previews is not None:
|
||||
update_preview = True
|
||||
elif op == 'update':
|
||||
if not is_waiting_preview:
|
||||
is_waiting_preview = True
|
||||
s2c.put({'op': 'preview'})
|
||||
elif op == 'next_preview':
|
||||
selected_preview = (selected_preview + 1) % len(previews)
|
||||
update_preview = True
|
||||
elif op == 'change_history_range':
|
||||
if show_last_history_iters_count == 0:
|
||||
show_last_history_iters_count = 5000
|
||||
elif show_last_history_iters_count == 5000:
|
||||
show_last_history_iters_count = 10000
|
||||
elif show_last_history_iters_count == 10000:
|
||||
show_last_history_iters_count = 50000
|
||||
elif show_last_history_iters_count == 50000:
|
||||
show_last_history_iters_count = 100000
|
||||
elif show_last_history_iters_count == 100000:
|
||||
show_last_history_iters_count = 0
|
||||
update_preview = True
|
||||
elif op == 'close':
|
||||
s2c.put({'op': 'close'})
|
||||
break
|
||||
elif op == 'zoom_prev':
|
||||
zoom = zoom.prev()
|
||||
update_preview = True
|
||||
elif op == 'zoom_next':
|
||||
zoom = zoom.next()
|
||||
update_preview = True
|
||||
|
||||
if update_preview:
|
||||
update_preview = False
|
||||
|
@ -348,44 +374,113 @@ def main(args, device_args):
|
|||
iteration,
|
||||
batch_size,
|
||||
zoom)
|
||||
io.show_image(wnd_name, preview_pane_image)
|
||||
is_showing = True
|
||||
|
||||
key_events = io.get_key_events(wnd_name)
|
||||
key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (0,0,False,False,False)
|
||||
|
||||
if key == ord('\n') or key == ord('\r'):
|
||||
s2c.put ( {'op': 'close'} )
|
||||
elif key == ord('s'):
|
||||
s2c.put ( {'op': 'save'} )
|
||||
elif key == ord('p'):
|
||||
if not is_waiting_preview:
|
||||
is_waiting_preview = True
|
||||
s2c.put ( {'op': 'preview'} )
|
||||
elif key == ord('l'):
|
||||
if show_last_history_iters_count == 0:
|
||||
show_last_history_iters_count = 5000
|
||||
elif show_last_history_iters_count == 5000:
|
||||
show_last_history_iters_count = 10000
|
||||
elif show_last_history_iters_count == 10000:
|
||||
show_last_history_iters_count = 50000
|
||||
elif show_last_history_iters_count == 50000:
|
||||
show_last_history_iters_count = 100000
|
||||
elif show_last_history_iters_count == 100000:
|
||||
show_last_history_iters_count = 0
|
||||
update_preview = True
|
||||
elif key == ord(' '):
|
||||
selected_preview = (selected_preview + 1) % len(previews)
|
||||
update_preview = True
|
||||
elif key == ord('-'):
|
||||
zoom = zoom.prev()
|
||||
update_preview = True
|
||||
elif key == ord('=') or key == ord('+'):
|
||||
zoom = zoom.next()
|
||||
update_preview = True
|
||||
# io.show_image(wnd_name, preview_pane_image)
|
||||
model_path = Path(args.get('model_path', ''))
|
||||
filename = 'preview.jpg'
|
||||
preview_file = str(model_path / filename)
|
||||
cv2.imwrite(preview_file, preview_pane_image)
|
||||
s2flask.put({'op': 'show'})
|
||||
socketio.emit('preview', {'iter': iteration, 'loss': loss_history[-1]})
|
||||
try:
|
||||
io.process_messages(0.1)
|
||||
io.process_messages(0.01)
|
||||
except KeyboardInterrupt:
|
||||
s2c.put ( {'op': 'close'} )
|
||||
s2c.put({'op': 'close'})
|
||||
else:
|
||||
thread = threading.Thread(target=trainer_thread, args=(s2c, c2s, e, args, device_args))
|
||||
thread.start()
|
||||
|
||||
io.destroy_all_windows()
|
||||
e.wait() # Wait for inital load to occur.
|
||||
|
||||
if no_preview:
|
||||
while True:
|
||||
if not c2s.empty():
|
||||
item = c2s.get()
|
||||
op = item.get('op', '')
|
||||
if op == 'close':
|
||||
break
|
||||
try:
|
||||
io.process_messages(0.1)
|
||||
except KeyboardInterrupt:
|
||||
s2c.put({'op': 'close'})
|
||||
else:
|
||||
wnd_name = "Training preview"
|
||||
io.named_window(wnd_name)
|
||||
io.capture_keys(wnd_name)
|
||||
|
||||
previews = None
|
||||
loss_history = None
|
||||
selected_preview = 0
|
||||
update_preview = False
|
||||
is_showing = False
|
||||
is_waiting_preview = False
|
||||
show_last_history_iters_count = 0
|
||||
iteration = 0
|
||||
batch_size = 1
|
||||
zoom = Zoom.ZOOM_100
|
||||
|
||||
while True:
|
||||
if not c2s.empty():
|
||||
item = c2s.get()
|
||||
op = item['op']
|
||||
if op == 'show':
|
||||
is_waiting_preview = False
|
||||
loss_history = item['loss_history'] if 'loss_history' in item.keys() else None
|
||||
previews = item['previews'] if 'previews' in item.keys() else None
|
||||
iteration = item['iter'] if 'iter' in item.keys() else 0
|
||||
# batch_size = input['batch_size'] if 'iter' in input.keys() else 1
|
||||
if previews is not None:
|
||||
update_preview = True
|
||||
elif op == 'close':
|
||||
break
|
||||
|
||||
if update_preview:
|
||||
update_preview = False
|
||||
selected_preview = selected_preview % len(previews)
|
||||
preview_pane_image = create_preview_pane_image(previews,
|
||||
selected_preview,
|
||||
loss_history,
|
||||
show_last_history_iters_count,
|
||||
iteration,
|
||||
batch_size,
|
||||
zoom)
|
||||
io.show_image(wnd_name, preview_pane_image)
|
||||
|
||||
key_events = io.get_key_events(wnd_name)
|
||||
key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (
|
||||
0, 0, False, False, False)
|
||||
|
||||
if key == ord('\n') or key == ord('\r'):
|
||||
s2c.put({'op': 'close'})
|
||||
elif key == ord('s'):
|
||||
s2c.put({'op': 'save'})
|
||||
elif key == ord('p'):
|
||||
if not is_waiting_preview:
|
||||
is_waiting_preview = True
|
||||
s2c.put({'op': 'preview'})
|
||||
elif key == ord('l'):
|
||||
if show_last_history_iters_count == 0:
|
||||
show_last_history_iters_count = 5000
|
||||
elif show_last_history_iters_count == 5000:
|
||||
show_last_history_iters_count = 10000
|
||||
elif show_last_history_iters_count == 10000:
|
||||
show_last_history_iters_count = 50000
|
||||
elif show_last_history_iters_count == 50000:
|
||||
show_last_history_iters_count = 100000
|
||||
elif show_last_history_iters_count == 100000:
|
||||
show_last_history_iters_count = 0
|
||||
update_preview = True
|
||||
elif key == ord(' '):
|
||||
selected_preview = (selected_preview + 1) % len(previews)
|
||||
update_preview = True
|
||||
elif key == ord('-'):
|
||||
zoom = zoom.prev()
|
||||
update_preview = True
|
||||
elif key == ord('=') or key == ord('+'):
|
||||
zoom = zoom.next()
|
||||
update_preview = True
|
||||
try:
|
||||
io.process_messages(0.1)
|
||||
except KeyboardInterrupt:
|
||||
s2c.put({'op': 'close'})
|
||||
|
||||
io.destroy_all_windows()
|
||||
|
|
|
@ -7,4 +7,6 @@ plaidml-keras==0.5.0
|
|||
scikit-image
|
||||
tqdm
|
||||
ffmpeg-python==0.1.17
|
||||
git+https://www.github.com/keras-team/keras-contrib.git
|
||||
git+https://www.github.com/keras-team/keras-contrib.git
|
||||
Flask==1.1.1
|
||||
flask-socketio==4.2.1
|
||||
|
|
|
@ -7,3 +7,5 @@ scikit-image
|
|||
tqdm
|
||||
ffmpeg-python==0.1.17
|
||||
git+https://www.github.com/keras-team/keras-contrib.git
|
||||
Flask==1.1.1
|
||||
flask-socketio==4.2.1
|
||||
|
|
|
@ -9,3 +9,5 @@ scikit-image
|
|||
tqdm
|
||||
ffmpeg-python==0.1.17
|
||||
git+https://www.github.com/keras-team/keras-contrib.git
|
||||
Flask==1.1.1
|
||||
flask-socketio==4.2.1
|
||||
|
|
|
@ -9,3 +9,5 @@ scikit-image
|
|||
tqdm
|
||||
ffmpeg-python==0.1.17
|
||||
git+https://www.github.com/keras-team/keras-contrib.git
|
||||
Flask==1.1.1
|
||||
flask-socketio==4.2.1
|
||||
|
|
7
update-prebuilt-dependecies.bat
Normal file
7
update-prebuilt-dependecies.bat
Normal file
|
@ -0,0 +1,7 @@
|
|||
@echo off
|
||||
call ..\setenv.bat
|
||||
|
||||
python -m pip install Flask==1.1.1
|
||||
python -m pip install flask-socketio==4.2.1
|
||||
|
||||
pause
|
Loading…
Add table
Add a link
Reference in a new issue