Refactor preview image code

This commit is contained in:
jh 2019-09-11 18:28:56 -07:00
commit f94c77b226

View file

@ -181,6 +181,64 @@ def trainerThread (s2c, c2s, e, args, device_args):
c2s.put ( {'op':'close'} )
def scale_previews(previews, preview_min_height, preview_max_height):
preview_height = max((h for h, w, c in (im.shape for name, im in previews)))
if preview_height > preview_max_height:
preview_height = preview_max_height
elif preview_height < preview_min_height:
preview_height = preview_min_height
# make all previews size equal
for preview in previews[:]:
(preview_name, preview_rgb) = preview
(h, w, c) = preview_rgb.shape
if h != preview_height:
scale_factor = preview_height / float(h)
previews.remove(preview)
previews.append((preview_name, cv2.resize(preview_rgb, (0, 0),
fx=scale_factor,
fy=scale_factor,
interpolation=cv2.INTER_AREA)))
return previews
def create_preview_pane_image(previews, selected_preview, loss_history,
show_last_history_iters_count, batch_size):
selected_preview_name = previews[selected_preview][0]
selected_preview_rgb = previews[selected_preview][1]
(h,w,c) = selected_preview_rgb.shape
# HEAD
head_lines = [
'[s]:save [enter]:exit',
'[p]:update [space]:next preview [l]:change history range',
'Preview: "%s" [%d/%d]' % (selected_preview_name,selected_preview+1, len(previews) )
]
head_line_height = 15
head_height = len(head_lines) * head_line_height
head = np.ones ( (head_height,w,c) ) * 0.1
for i in range(0, len(head_lines)):
t = i*head_line_height
b = (i+1)*head_line_height
head[t:b, 0:w] += imagelib.get_text_image ( (head_line_height,w,c) , head_lines[i], color=[0.8]*c )
final = head
if loss_history is not None:
if show_last_history_iters_count == 0:
loss_history_to_show = loss_history
else:
loss_history_to_show = loss_history[-show_last_history_iters_count:]
lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, iter, batch_size, w, c)
final = np.concatenate ( [final, lh_img], axis=0 )
final = np.concatenate ( [final, selected_preview_rgb], axis=0 )
final = np.clip(final, 0, 1)
return (final*255).astype(np.uint8)
def main(args, device_args):
io.log_info ("Running trainer.\r\n")
@ -236,24 +294,7 @@ def main(args, device_args):
iter = input['iter'] if 'iter' in input.keys() else 0
#batch_size = input['batch_size'] if 'iter' in input.keys() else 1
if previews is not None:
preview_height = max((h for h, w, c in (im.shape for name, im in previews)))
if preview_height > preview_max_height:
preview_height = preview_max_height
elif preview_height < preview_min_height:
preview_height = preview_min_height
# make all previews size equal
for preview in previews[:]:
(preview_name, preview_rgb) = preview
(h, w, c) = preview_rgb.shape
if h != preview_height:
scale_factor = preview_height / float(h)
previews.remove(preview)
previews.append((preview_name, cv2.resize(preview_rgb, (0, 0),
fx=scale_factor,
fy=scale_factor,
interpolation=cv2.INTER_AREA)))
previews = scale_previews(previews, preview_min_height, preview_max_height)
selected_preview = selected_preview % len(previews)
update_preview = True
elif op == 'close':
@ -261,41 +302,12 @@ def main(args, device_args):
if update_preview:
update_preview = False
selected_preview_name = previews[selected_preview][0]
selected_preview_rgb = previews[selected_preview][1]
(h,w,c) = selected_preview_rgb.shape
# HEAD
head_lines = [
'[s]:save [enter]:exit',
'[p]:update [space]:next preview [l]:change history range',
'Preview: "%s" [%d/%d]' % (selected_preview_name,selected_preview+1, len(previews) )
]
head_line_height = 15
head_height = len(head_lines) * head_line_height
head = np.ones ( (head_height,w,c) ) * 0.1
for i in range(0, len(head_lines)):
t = i*head_line_height
b = (i+1)*head_line_height
head[t:b, 0:w] += imagelib.get_text_image ( (head_line_height,w,c) , head_lines[i], color=[0.8]*c )
final = head
if loss_history is not None:
if show_last_history_iters_count == 0:
loss_history_to_show = loss_history
else:
loss_history_to_show = loss_history[-show_last_history_iters_count:]
lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, iter, batch_size, w, c)
final = np.concatenate ( [final, lh_img], axis=0 )
final = np.concatenate ( [final, selected_preview_rgb], axis=0 )
final = np.clip(final, 0, 1)
io.show_image( wnd_name, (final*255).astype(np.uint8) )
preview_pane_image = create_preview_pane_image(previews,
selected_preview,
loss_history,
show_last_history_iters_count,
batch_size)
io.show_image( wnd_name, preview_pane_image)
is_showing = True
key_events = io.get_key_events(wnd_name)
@ -324,6 +336,12 @@ def main(args, device_args):
elif key == ord(' '):
selected_preview = (selected_preview + 1) % len(previews)
update_preview = True
elif key == ord('-'):
# Decrease zoom
pass
elif key == ord('+'):
# Increase zoom
pass
try:
io.process_messages(0.1)