better flask previews

This commit is contained in:
jh 2019-09-13 16:23:07 -07:00
commit 637bf69eb7

View file

@ -3,8 +3,9 @@ import traceback
import queue
import threading
import time
from io import BytesIO
import base64
from enum import Enum
from os.path import getmtime
import numpy as np
import itertools
from pathlib import Path
@ -14,20 +15,20 @@ import cv2
import models
from interact import interact as io
from flask import Flask, send_file, Response, render_template, render_template_string, request, g
from flask_caching import Cache
# from flask_socketio import SocketIO
def trainerThread(s2c, c2s, e, args, device_args):
def trainerThread (s2c, c2s, e, args, device_args):
while True:
try:
start_time = time.time()
training_data_src_path = Path(args.get('training_data_src_dir', ''))
training_data_dst_path = Path(args.get('training_data_dst_dir', ''))
training_data_src_path = Path( args.get('training_data_src_dir', '') )
training_data_dst_path = Path( args.get('training_data_dst_dir', '') )
pretraining_data_path = args.get('pretraining_data_dir', '')
pretraining_data_path = Path(pretraining_data_path) if pretraining_data_path is not None else None
model_path = Path(args.get('model_path', ''))
model_path = Path( args.get('model_path', '') )
model_name = args.get('model_name', '')
save_interval_min = 15
debug = args.get('debug', '')
@ -54,25 +55,24 @@ def trainerThread(s2c, c2s, e, args, device_args):
is_reached_goal = model.is_reached_iter_goal()
shared_state = {'after_save': False}
shared_state = { 'after_save' : False }
loss_string = ""
save_iter = model.get_iter()
save_iter = model.get_iter()
def model_save():
if not debug and not is_reached_goal:
io.log_info("Saving....", end='\r')
io.log_info ("Saving....", end='\r')
model.save()
shared_state['after_save'] = True
def send_preview():
if not debug:
previews = model.get_previews()
c2s.put({'op': 'show', 'previews': previews, 'iter': model.get_iter(),
'loss_history': model.get_loss_history().copy()})
c2s.put ( {'op':'show', 'previews': previews, 'iter':model.get_iter(), 'loss_history': model.get_loss_history().copy() } )
else:
previews = [('debug, press update for new', model.debug_one_iter())]
c2s.put({'op': 'show', 'previews': previews})
e.set() # Set the GUI Thread as Ready
previews = [( 'debug, press update for new', model.debug_one_iter())]
c2s.put ( {'op':'show', 'previews': previews} )
e.set() #Set the GUI Thread as Ready
if model.is_first_run():
model_save()
@ -81,16 +81,15 @@ def trainerThread(s2c, c2s, e, args, device_args):
if is_reached_goal:
io.log_info('Model already trained to target iteration. You can use preview.')
else:
io.log_info('Starting. Target iteration: %d. Press "Enter" to stop training and save model.' % (
model.get_target_iter()))
io.log_info('Starting. Target iteration: %d. Press "Enter" to stop training and save model.' % ( model.get_target_iter() ) )
else:
io.log_info('Starting. Press "Enter" to stop training and save model.')
last_save_time = time.time()
execute_programs = [[x[0], x[1], time.time()] for x in execute_programs]
execute_programs = [ [x[0], x[1], time.time() ] for x in execute_programs ]
for i in itertools.count(0, 1):
for i in itertools.count(0,1):
if not debug:
cur_time = time.time()
@ -100,7 +99,7 @@ def trainerThread(s2c, c2s, e, args, device_args):
if prog_time > 0 and (cur_time - start_time) >= prog_time:
x[0] = 0
exec_prog = True
elif prog_time < 0 and (cur_time - last_time) >= -prog_time:
elif prog_time < 0 and (cur_time - last_time) >= -prog_time:
x[2] = cur_time
exec_prog = True
@ -108,7 +107,7 @@ def trainerThread(s2c, c2s, e, args, device_args):
try:
exec(prog)
except Exception as e:
print("Unable to execute program: %s" % (prog))
print("Unable to execute program: %s" % (prog) )
if not is_reached_goal:
iter, iter_time, batch_size = model.train_one_iter()
@ -116,23 +115,20 @@ def trainerThread(s2c, c2s, e, args, device_args):
loss_history = model.get_loss_history()
time_str = time.strftime("[%H:%M:%S]")
if iter_time >= 10:
loss_string = "{0}[#{1:06d}][{2:.5s}s][bs: {3}]".format(time_str, iter,
'{:0.4f}'.format(iter_time),
batch_size)
loss_string = "{0}[#{1:06d}][{2:.5s}s][bs: {3}]".format ( time_str, iter, '{:0.4f}'.format(iter_time), batch_size )
else:
loss_string = "{0}[#{1:06d}][{2:04d}ms][bs: {3}]".format(time_str, iter,
int(iter_time * 1000), batch_size)
loss_string = "{0}[#{1:06d}][{2:04d}ms][bs: {3}]".format ( time_str, iter, int(iter_time*1000), batch_size)
if shared_state['after_save']:
shared_state['after_save'] = False
last_save_time = time.time() # upd last_save_time only after save+one_iter, because plaidML rebuilds programs after save https://github.com/plaidml/plaidml/issues/274
last_save_time = time.time() #upd last_save_time only after save+one_iter, because plaidML rebuilds programs after save https://github.com/plaidml/plaidml/issues/274
mean_loss = np.mean([np.array(loss_history[i]) for i in range(save_iter, iter)], axis=0)
mean_loss = np.mean ( [ np.array(loss_history[i]) for i in range(save_iter, iter) ], axis=0)
for loss_value in mean_loss:
loss_string += "[%.4f]" % (loss_value)
io.log_info(loss_string)
io.log_info (loss_string)
save_iter = iter
else:
@ -140,21 +136,21 @@ def trainerThread(s2c, c2s, e, args, device_args):
loss_string += "[%.4f]" % (loss_value)
if io.is_colab():
io.log_info('\r' + loss_string, end='')
io.log_info ('\r' + loss_string, end='')
else:
io.log_info(loss_string, end='\r')
io.log_info (loss_string, end='\r')
if model.get_target_iter() != 0 and model.is_reached_iter_goal():
io.log_info('Reached target iteration.')
io.log_info ('Reached target iteration.')
model_save()
is_reached_goal = True
io.log_info('You can use preview now.')
io.log_info ('You can use preview now.')
if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min * 60:
if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min*60:
model_save()
send_preview()
if i == 0:
if i==0:
if is_reached_goal:
model.pass_one_iter()
send_preview()
@ -179,169 +175,127 @@ def trainerThread(s2c, c2s, e, args, device_args):
if i == -1:
break
model.finalize()
except Exception as e:
print('Error: %s' % (str(e)))
print ('Error: %s' % (str(e)))
traceback.print_exc()
break
c2s.put({'op': 'close'})
c2s.put ( {'op':'close'} )
class Preview:
def __init__(self, c2s, s2c, preview_queue):
self.c2s = c2s
self.s2c = s2c
self.preview_queue = preview_queue
# self.wnd_name = "Training preview"
# io.named_window(wnd_name)
# io.capture_keys(wnd_name)
class Zoom(Enum):
ZOOM_25 = (1/4, '25%')
ZOOM_33 = (1/3, '33%')
ZOOM_50 = (1/2, '50%')
ZOOM_67 = (2/3, '67%')
ZOOM_75 = (3/4, '75%')
ZOOM_80 = (4/5, '80%')
ZOOM_90 = (9/10, '90%')
ZOOM_100 = (1, '100%')
ZOOM_110 = (11/10, '110%')
ZOOM_125 = (5/4, '125%')
ZOOM_150 = (3/2, '150%')
ZOOM_175 = (7/4, '175%')
ZOOM_200 = (2, '200%')
ZOOM_250 = (5/2, '250%')
ZOOM_300 = (3, '300%')
ZOOM_400 = (4, '400%')
ZOOM_500 = (5, '500%')
self.previews = None
self.loss_history = None
self.selected_preview = 0
self.update_preview = False
self.is_showing = False
self.is_waiting_preview = False
self.show_last_history_iters_count = 0
self.iter = 0
self.batch_size = 1
self.preview_min_height = 512
self.preview_max_height = 1024
self.close = False
def __init__(self, scale, label):
self.scale = scale
self.label = label
def get_preview(self):
while not self.close:
self.process_queue_items()
self.update_preview_frame()
def prev(self):
cls = self.__class__
members = list(cls)
index = members.index(self) - 1
if index < 0:
return self
return members[index]
def process_queue_items(self):
if not self.c2s.empty():
input = self.c2s.get()
op = input['op']
if op == 'show':
self.is_waiting_preview = False
self.loss_history = input['loss_history'] if 'loss_history' in input.keys() else None
self.previews = input['previews'] if 'previews' in input.keys() else None
self.iter = input['iter'] if 'iter' in input.keys() else 0
if self.previews is not None:
self.resize_previews()
self.selected_preview = self.selected_preview % len(self.previews)
self.update_preview = True
elif op == 'close':
self.close = True
elif op == 'update':
self.update()
elif op == 'next_preview':
self.next_preview()
elif op == 'change_history_range':
self.change_history_range()
def update_preview_frame(self):
if self.update_preview:
self.update_preview = False
selected_preview_name = self.previews[self.selected_preview][0]
selected_preview_rgb = self.previews[self.selected_preview][1]
(h, w, c) = selected_preview_rgb.shape
# HEAD
head_lines = [
'[s]:save [enter]:exit',
'[p]:update [space]:next preview [l]:change history range',
'Preview: "%s" [%d/%d]' % (selected_preview_name, self.selected_preview + 1, len(self.previews))
]
head_line_height = 15
head_height = len(head_lines) * head_line_height
head = np.ones((head_height, w, c)) * 0.1
for i in range(0, len(head_lines)):
t = i * head_line_height
b = (i + 1) * head_line_height
head[t:b, 0:w] += imagelib.get_text_image((head_line_height, w, c), head_lines[i], color=[0.8] * c)
final = head
if self.loss_history is not None:
if self.show_last_history_iters_count == 0:
loss_history_to_show = self.loss_history
else:
loss_history_to_show = self.loss_history[-self.show_last_history_iters_count:]
lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, self.iter, self.batch_size, w,
c)
final = np.concatenate([final, lh_img], axis=0)
final = np.concatenate([final, selected_preview_rgb], axis=0)
final = np.clip(final, 0, 1)
preview_pane = (final * 255).astype(np.uint8)
retval, buffer = cv2.imencode('.jpg', preview_pane)
# jpg_as_text = base64.b64encode(buffer)
jpg_as_text = buffer.tostring()
self.preview_queue.put(jpg_as_text)
def resize_previews(self):
preview_height = max((h for h, w, c in (im.shape for name, im in self.previews)))
if preview_height > self.preview_max_height:
preview_height = self.preview_max_height
elif preview_height < self.preview_min_height:
preview_height = self.preview_min_height
# make all previews size equal
for p in self.previews[:]:
(preview_name, preview_rgb) = p
(h, w, c) = preview_rgb.shape
if h != preview_height:
scale_factor = preview_height / float(h)
self.previews.remove(p)
self.previews.append((preview_name, cv2.resize(preview_rgb, (0, 0),
fx=scale_factor,
fy=scale_factor,
interpolation=cv2.INTER_AREA)))
self.selected_preview = self.selected_preview % len(self.previews)
def save(self):
self.s2c.put({'op': 'save'})
def exit(self):
self.s2c.put({'op': 'close'})
def update(self):
if not self.is_waiting_preview:
self.is_waiting_preview = True
self.s2c.put({'op': 'preview'})
def next_preview(self):
self.selected_preview = (self.selected_preview + 1) % len(self.previews)
self.update_preview = True
def change_history_range(self):
if self.show_last_history_iters_count == 0:
self.show_last_history_iters_count = 5000
elif self.show_last_history_iters_count == 5000:
self.show_last_history_iters_count = 10000
elif self.show_last_history_iters_count == 10000:
self.show_last_history_iters_count = 50000
elif self.show_last_history_iters_count == 50000:
self.show_last_history_iters_count = 100000
elif self.show_last_history_iters_count == 100000:
self.show_last_history_iters_count = 0
self.update_preview = True
def next(self):
cls = self.__class__
members = list(cls)
index = members.index(self) + 1
if index >= len(members):
return self
return members[index]
def flask_thread(s2c, c2s, preview_queue):
config = {
"DEBUG": True, # some Flask specific configs
"CACHE_TYPE": "simple", # Flask-Caching related configs
"CACHE_DEFAULT_TIMEOUT": 300
}
def scale_previews(previews, zoom=Zoom.ZOOM_100):
scaled = []
for preview in previews:
preview_name, preview_rgb = preview
scale_factor = zoom.scale
if scale_factor < 1:
scaled.append((preview_name, cv2.resize(preview_rgb, (0, 0),
fx=scale_factor,
fy=scale_factor,
interpolation=cv2.INTER_AREA)))
elif scale_factor > 1:
scaled.append((preview_name, cv2.resize(preview_rgb, (0, 0),
fx=scale_factor,
fy=scale_factor,
interpolation=cv2.INTER_LANCZOS4)))
else:
scaled.append((preview_name, preview_rgb))
return scaled
def create_preview_pane_image(previews, selected_preview, loss_history,
show_last_history_iters_count, iteration, batch_size, zoom=Zoom.ZOOM_100):
scaled_previews = scale_previews(previews, zoom)
selected_preview_name = scaled_previews[selected_preview][0]
selected_preview_rgb = scaled_previews[selected_preview][1]
h, w, c = selected_preview_rgb.shape
# HEAD
head_lines = [
'[s]:save [enter]:exit [-/+]:zoom: %s' % zoom.label,
'[p]:update [space]:next preview [l]:change history range',
'Preview: "%s" [%d/%d]' % (selected_preview_name,selected_preview+1, len(previews))
]
head_line_height = int(15 * zoom.scale)
head_height = len(head_lines) * head_line_height
head = np.ones((head_height, w, c)) * 0.1
for i in range(0, len(head_lines)):
t = i * head_line_height
b = (i+1) * head_line_height
head[t:b, 0:w] += imagelib.get_text_image((head_line_height, w, c), head_lines[i], color=[0.8]*c)
final = head
if loss_history is not None:
if show_last_history_iters_count == 0:
loss_history_to_show = loss_history
else:
loss_history_to_show = loss_history[-show_last_history_iters_count:]
lh_height = int(100 * zoom.scale)
lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, iteration, batch_size, w, c, lh_height)
final = np.concatenate ( [final, lh_img], axis=0 )
final = np.concatenate([final, selected_preview_rgb], axis=0)
final = np.clip(final, 0, 1)
return (final*255).astype(np.uint8)
def flask_thread(s2c, c2s, s2flask, args):
# config = {
# "DEBUG": True, # some Flask specific configs
# "CACHE_TYPE": "simple", # Flask-Caching related configs
# "CACHE_DEFAULT_TIMEOUT": 300
# }
app = Flask(__name__)
app.config.from_mapping(config)
cache = Cache(app)
# app.config.from_mapping(config)
# cache = Cache(app)
template = """<html>
<head>
<title>Video Streaming Demonstration</title>
<title>Flask Server Demonstration</title>
</head>
<body>
<h1>Video Streaming Demonstration</h1>
@ -357,8 +311,17 @@ def flask_thread(s2c, c2s, preview_queue):
</html>"""
def gen():
if not preview_queue.empty():
frame = preview_queue.get()
model_path = Path(args.get('model_path', ''))
print('[MainThread]', 'model_path:', model_path)
filename = 'preview.jpg'
preview_file = str(model_path / filename)
print('[MainThread]', 'preview_file:', preview_file)
frame = open(preview_file, 'rb').read()
while True:
try:
frame = open(preview_file, 'rb').read()
except:
pass
yield b'--frame\r\nContent-Type: image/jpeg\r\n\r\n'
yield frame
yield b'\r\n\r\n'
@ -371,46 +334,135 @@ def flask_thread(s2c, c2s, preview_queue):
elif 'exit' in request.form:
s2c.put({'op': 'close'})
elif 'update' in request.form:
while not s2flask.empty():
input = s2flask.get()
c2s.put({'op': 'update'})
while s2flask.empty():
pass
input = s2flask.get()
elif 'next_preview' in request.form:
c2s.put({'op': 'preview'})
while not s2flask.empty():
input = s2flask.get()
c2s.put({'op': 'next_preview'})
while s2flask.empty():
pass
input = s2flask.get()
elif 'change_history_range' in request.form:
while not s2flask.empty():
input = s2flask.get()
c2s.put({'op': 'change_history_range'})
while s2flask.empty():
pass
input = s2flask.get()
# return '', 204
return render_template_string(template)
def queue_not_empty():
return not preview_queue.empty()
# @app.route('/preview_image')
# @cache.cached(timeout=300, unless=queue_not_empty)
# def preview_image():
# yield Response(preview_queue.get(),
# mimetype='multipart/x-mixed-replace;boundary=frame')
# return Response(gen(), mimetype='multipart/x-mixed-replace;boundary=frame')
@app.route('/preview_image')
@cache.cached(timeout=300, unless=queue_not_empty)
def preview_image():
return Response(preview_queue.get(), mimetype='image/jpeg')
model_path = Path(args.get('model_path', ''))
filename = 'preview.jpg'
preview_file = str(model_path / filename)
return send_file(preview_file, mimetype='image/jpeg', cache_timeout=-1)
app.run(debug=True, use_reloader=False)
app.run(debug=False, use_reloader=False)
def main(args, device_args):
io.log_info("Running trainer.\r\n")
io.log_info ("Running trainer.\r\n")
no_preview = args.get('no_preview', False)
s2c = queue.Queue()
c2s = queue.Queue()
preview_queue = queue.Queue()
s2flask = queue.Queue()
e = threading.Event()
thread = threading.Thread(target=trainerThread, args=(s2c, c2s, e, args, device_args))
thread = threading.Thread(target=trainerThread, args=(s2c, c2s, e, args, device_args) )
thread.start()
e.wait() # Wait for inital load to occur.
e.wait() #Wait for inital load to occur.
flask_t = threading.Thread(target=flask_thread, args=(s2c, c2s, preview_queue))
flask_t = threading.Thread(target=flask_thread, args=(s2c, c2s, s2flask, args))
flask_t.start()
preview = Preview(c2s, s2c, preview_queue)
preview.get_preview()
wnd_name = "Training preview"
io.named_window(wnd_name)
io.capture_keys(wnd_name)
previews = None
loss_history = None
selected_preview = 0
update_preview = False
is_showing = False
is_waiting_preview = False
show_last_history_iters_count = 0
iteration = 0
batch_size = 1
zoom = Zoom.ZOOM_100
while True:
if not c2s.empty():
input = c2s.get()
op = input['op']
if op == 'show':
is_waiting_preview = False
loss_history = input['loss_history'] if 'loss_history' in input.keys() else None
previews = input['previews'] if 'previews' in input.keys() else None
iteration = input['iter'] if 'iter' in input.keys() else 0
#batch_size = input['batch_size'] if 'iter' in input.keys() else 1
if previews is not None:
update_preview = True
elif op == 'update':
if not is_waiting_preview:
is_waiting_preview = True
s2c.put({'op': 'preview'})
elif op == 'next_preview':
selected_preview = (selected_preview + 1) % len(previews)
update_preview = True
elif op == 'change_history_range':
if show_last_history_iters_count == 0:
show_last_history_iters_count = 5000
elif show_last_history_iters_count == 5000:
show_last_history_iters_count = 10000
elif show_last_history_iters_count == 10000:
show_last_history_iters_count = 50000
elif show_last_history_iters_count == 50000:
show_last_history_iters_count = 100000
elif show_last_history_iters_count == 100000:
show_last_history_iters_count = 0
update_preview = True
if update_preview:
update_preview = False
selected_preview = selected_preview % len(previews)
preview_pane_image = create_preview_pane_image(previews,
selected_preview,
loss_history,
show_last_history_iters_count,
iteration,
batch_size,
zoom)
# io.show_image(wnd_name, preview_pane_image)
model_path = Path(args.get('model_path', ''))
filename = 'preview.jpg'
preview_file = str(model_path / filename)
cv2.imwrite(preview_file, preview_pane_image)
s2flask.put({'op': 'show'})
# socketio.emit('some event', {'data': 42})
# cv2.imshow(wnd_name, preview_pane_image)
is_showing = True
try:
io.process_messages(0.01)
except KeyboardInterrupt:
s2c.put({'op': 'close'})
io.destroy_all_windows()