mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-22 14:24:40 -07:00
Merge pull request #63 from faceshiftlabs/feat/zoomable-trainer-preview
Feat/zoomable trainer preview
This commit is contained in:
commit
fdef5622dd
2 changed files with 119 additions and 61 deletions
|
@ -3,6 +3,8 @@ import traceback
|
|||
import queue
|
||||
import threading
|
||||
import time
|
||||
from enum import Enum
|
||||
|
||||
import numpy as np
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
|
@ -181,6 +183,103 @@ def trainerThread (s2c, c2s, e, args, device_args):
|
|||
c2s.put ( {'op':'close'} )
|
||||
|
||||
|
||||
class Zoom(Enum):
|
||||
ZOOM_25 = (1/4, '25%')
|
||||
ZOOM_33 = (1/3, '33%')
|
||||
ZOOM_50 = (1/2, '50%')
|
||||
ZOOM_67 = (2/3, '67%')
|
||||
ZOOM_75 = (3/4, '75%')
|
||||
ZOOM_80 = (4/5, '80%')
|
||||
ZOOM_90 = (9/10, '90%')
|
||||
ZOOM_100 = (1, '100%')
|
||||
ZOOM_110 = (11/10, '110%')
|
||||
ZOOM_125 = (5/4, '125%')
|
||||
ZOOM_150 = (3/2, '150%')
|
||||
ZOOM_175 = (7/4, '175%')
|
||||
ZOOM_200 = (2, '200%')
|
||||
ZOOM_250 = (5/2, '250%')
|
||||
ZOOM_300 = (3, '300%')
|
||||
ZOOM_400 = (4, '400%')
|
||||
ZOOM_500 = (5, '500%')
|
||||
|
||||
def __init__(self, scale, label):
|
||||
self.scale = scale
|
||||
self.label = label
|
||||
|
||||
def prev(self):
|
||||
cls = self.__class__
|
||||
members = list(cls)
|
||||
index = members.index(self) - 1
|
||||
if index < 0:
|
||||
return self
|
||||
return members[index]
|
||||
|
||||
def next(self):
|
||||
cls = self.__class__
|
||||
members = list(cls)
|
||||
index = members.index(self) + 1
|
||||
if index >= len(members):
|
||||
return self
|
||||
return members[index]
|
||||
|
||||
|
||||
def scale_previews(previews, zoom=Zoom.ZOOM_100):
|
||||
scaled = []
|
||||
for preview in previews:
|
||||
preview_name, preview_rgb = preview
|
||||
scale_factor = zoom.scale
|
||||
if scale_factor < 1:
|
||||
scaled.append((preview_name, cv2.resize(preview_rgb, (0, 0),
|
||||
fx=scale_factor,
|
||||
fy=scale_factor,
|
||||
interpolation=cv2.INTER_AREA)))
|
||||
elif scale_factor > 1:
|
||||
scaled.append((preview_name, cv2.resize(preview_rgb, (0, 0),
|
||||
fx=scale_factor,
|
||||
fy=scale_factor,
|
||||
interpolation=cv2.INTER_LANCZOS4)))
|
||||
else:
|
||||
scaled.append((preview_name, preview_rgb))
|
||||
return scaled
|
||||
|
||||
|
||||
def create_preview_pane_image(previews, selected_preview, loss_history,
|
||||
show_last_history_iters_count, iteration, batch_size, zoom=Zoom.ZOOM_100):
|
||||
scaled_previews = scale_previews(previews, zoom)
|
||||
selected_preview_name = scaled_previews[selected_preview][0]
|
||||
selected_preview_rgb = scaled_previews[selected_preview][1]
|
||||
h, w, c = selected_preview_rgb.shape
|
||||
|
||||
# HEAD
|
||||
head_lines = [
|
||||
'[s]:save [enter]:exit [-/+]:zoom: %s' % zoom.label,
|
||||
'[p]:update [space]:next preview [l]:change history range',
|
||||
'Preview: "%s" [%d/%d]' % (selected_preview_name,selected_preview+1, len(previews))
|
||||
]
|
||||
head_line_height = int(15 * zoom.scale)
|
||||
head_height = len(head_lines) * head_line_height
|
||||
head = np.ones((head_height, w, c)) * 0.1
|
||||
|
||||
for i in range(0, len(head_lines)):
|
||||
t = i * head_line_height
|
||||
b = (i+1) * head_line_height
|
||||
head[t:b, 0:w] += imagelib.get_text_image((head_line_height, w, c), head_lines[i], color=[0.8]*c)
|
||||
|
||||
final = head
|
||||
|
||||
if loss_history is not None:
|
||||
if show_last_history_iters_count == 0:
|
||||
loss_history_to_show = loss_history
|
||||
else:
|
||||
loss_history_to_show = loss_history[-show_last_history_iters_count:]
|
||||
lh_height = int(100 * zoom.scale)
|
||||
lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, iteration, batch_size, w, c, lh_height)
|
||||
final = np.concatenate ( [final, lh_img], axis=0 )
|
||||
|
||||
final = np.concatenate([final, selected_preview_rgb], axis=0)
|
||||
final = np.clip(final, 0, 1)
|
||||
return (final*255).astype(np.uint8)
|
||||
|
||||
|
||||
def main(args, device_args):
|
||||
io.log_info ("Running trainer.\r\n")
|
||||
|
@ -220,10 +319,9 @@ def main(args, device_args):
|
|||
is_showing = False
|
||||
is_waiting_preview = False
|
||||
show_last_history_iters_count = 0
|
||||
iter = 0
|
||||
iteration = 0
|
||||
batch_size = 1
|
||||
preview_min_height = 512
|
||||
preview_max_height = 1024
|
||||
zoom = Zoom.ZOOM_100
|
||||
|
||||
while True:
|
||||
if not c2s.empty():
|
||||
|
@ -233,69 +331,24 @@ def main(args, device_args):
|
|||
is_waiting_preview = False
|
||||
loss_history = input['loss_history'] if 'loss_history' in input.keys() else None
|
||||
previews = input['previews'] if 'previews' in input.keys() else None
|
||||
iter = input['iter'] if 'iter' in input.keys() else 0
|
||||
iteration = input['iter'] if 'iter' in input.keys() else 0
|
||||
#batch_size = input['batch_size'] if 'iter' in input.keys() else 1
|
||||
if previews is not None:
|
||||
preview_height = max((h for h, w, c in (im.shape for name, im in previews)))
|
||||
|
||||
if preview_height > preview_max_height:
|
||||
preview_height = preview_max_height
|
||||
elif preview_height < preview_min_height:
|
||||
preview_height = preview_min_height
|
||||
|
||||
# make all previews size equal
|
||||
for preview in previews[:]:
|
||||
(preview_name, preview_rgb) = preview
|
||||
(h, w, c) = preview_rgb.shape
|
||||
if h != preview_height:
|
||||
scale_factor = preview_height / float(h)
|
||||
previews.remove(preview)
|
||||
previews.append((preview_name, cv2.resize(preview_rgb, (0, 0),
|
||||
fx=scale_factor,
|
||||
fy=scale_factor,
|
||||
interpolation=cv2.INTER_AREA)))
|
||||
selected_preview = selected_preview % len(previews)
|
||||
update_preview = True
|
||||
elif op == 'close':
|
||||
break
|
||||
|
||||
if update_preview:
|
||||
update_preview = False
|
||||
|
||||
selected_preview_name = previews[selected_preview][0]
|
||||
selected_preview_rgb = previews[selected_preview][1]
|
||||
(h,w,c) = selected_preview_rgb.shape
|
||||
|
||||
# HEAD
|
||||
head_lines = [
|
||||
'[s]:save [enter]:exit',
|
||||
'[p]:update [space]:next preview [l]:change history range',
|
||||
'Preview: "%s" [%d/%d]' % (selected_preview_name,selected_preview+1, len(previews) )
|
||||
]
|
||||
head_line_height = 15
|
||||
head_height = len(head_lines) * head_line_height
|
||||
head = np.ones ( (head_height,w,c) ) * 0.1
|
||||
|
||||
for i in range(0, len(head_lines)):
|
||||
t = i*head_line_height
|
||||
b = (i+1)*head_line_height
|
||||
head[t:b, 0:w] += imagelib.get_text_image ( (head_line_height,w,c) , head_lines[i], color=[0.8]*c )
|
||||
|
||||
final = head
|
||||
|
||||
if loss_history is not None:
|
||||
if show_last_history_iters_count == 0:
|
||||
loss_history_to_show = loss_history
|
||||
else:
|
||||
loss_history_to_show = loss_history[-show_last_history_iters_count:]
|
||||
|
||||
lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, iter, batch_size, w, c)
|
||||
final = np.concatenate ( [final, lh_img], axis=0 )
|
||||
|
||||
final = np.concatenate ( [final, selected_preview_rgb], axis=0 )
|
||||
final = np.clip(final, 0, 1)
|
||||
|
||||
io.show_image( wnd_name, (final*255).astype(np.uint8) )
|
||||
selected_preview = selected_preview % len(previews)
|
||||
preview_pane_image = create_preview_pane_image(previews,
|
||||
selected_preview,
|
||||
loss_history,
|
||||
show_last_history_iters_count,
|
||||
iteration,
|
||||
batch_size,
|
||||
zoom)
|
||||
io.show_image(wnd_name, preview_pane_image)
|
||||
is_showing = True
|
||||
|
||||
key_events = io.get_key_events(wnd_name)
|
||||
|
@ -324,7 +377,12 @@ def main(args, device_args):
|
|||
elif key == ord(' '):
|
||||
selected_preview = (selected_preview + 1) % len(previews)
|
||||
update_preview = True
|
||||
|
||||
elif key == ord('-'):
|
||||
zoom = zoom.prev()
|
||||
update_preview = True
|
||||
elif key == ord('=') or key == ord('+'):
|
||||
zoom = zoom.next()
|
||||
update_preview = True
|
||||
try:
|
||||
io.process_messages(0.1)
|
||||
except KeyboardInterrupt:
|
||||
|
|
|
@ -641,11 +641,11 @@ class ModelBase(object):
|
|||
self.batch_size = d[ keys[-1] ]
|
||||
|
||||
@staticmethod
|
||||
def get_loss_history_preview(loss_history, iter,batch_size, w, c):
|
||||
def get_loss_history_preview(loss_history, iter, batch_size, w, c, lh_height=100):
|
||||
loss_history = np.array (loss_history.copy())
|
||||
|
||||
lh_height = 100
|
||||
lh_img = np.ones ( (lh_height,w,c) ) * 0.1
|
||||
lh_img = np.ones((lh_height, w, c)) * 0.1
|
||||
|
||||
if len(loss_history) != 0:
|
||||
loss_count = len(loss_history[0])
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue