mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-19 13:09:56 -07:00
Merge pull request #2 from seranus/tensorboard
merge tensorboard into the main branch
This commit is contained in:
commit
31522699ad
6 changed files with 138 additions and 8 deletions
|
@ -1,4 +1,4 @@
|
||||||
import sys
|
import sys
|
||||||
import locale
|
import locale
|
||||||
|
|
||||||
system_locale = locale.getdefaultlocale()[0]
|
system_locale = locale.getdefaultlocale()[0]
|
||||||
|
|
4
main.py
4
main.py
|
@ -127,6 +127,8 @@ if __name__ == "__main__":
|
||||||
'silent_start' : arguments.silent_start,
|
'silent_start' : arguments.silent_start,
|
||||||
'execute_programs' : [ [int(x[0]), x[1] ] for x in arguments.execute_program ],
|
'execute_programs' : [ [int(x[0]), x[1] ] for x in arguments.execute_program ],
|
||||||
'debug' : arguments.debug,
|
'debug' : arguments.debug,
|
||||||
|
'tensorboard_dir' : arguments.tensorboard_dir,
|
||||||
|
'start_tensorboard' : arguments.start_tensorboard
|
||||||
}
|
}
|
||||||
from mainscripts import Trainer
|
from mainscripts import Trainer
|
||||||
Trainer.main(**kwargs)
|
Trainer.main(**kwargs)
|
||||||
|
@ -144,6 +146,8 @@ if __name__ == "__main__":
|
||||||
p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Train on CPU.")
|
p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Train on CPU.")
|
||||||
p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.")
|
p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.")
|
||||||
p.add_argument('--silent-start', action="store_true", dest="silent_start", default=False, help="Silent start. Automatically chooses Best GPU and last used model.")
|
p.add_argument('--silent-start', action="store_true", dest="silent_start", default=False, help="Silent start. Automatically chooses Best GPU and last used model.")
|
||||||
|
p.add_argument('--tensorboard-logdir', action=fixPathAction, dest="tensorboard_dir", help="Directory of the tensorboard output files")
|
||||||
|
p.add_argument('--start-tensorboard', action="store_true", dest="start_tensorboard", default=False, help="Automatically start the tensorboard server preconfigured to the tensorboard-logdir")
|
||||||
|
|
||||||
|
|
||||||
p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+')
|
p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+')
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
import traceback
|
import traceback
|
||||||
import math
|
import math
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import operator
|
import operator
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
import math
|
import math
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import operator
|
import operator
|
||||||
import os
|
import os
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import os
|
import sys
|
||||||
import sys
|
|
||||||
import traceback
|
import traceback
|
||||||
import queue
|
import queue
|
||||||
import threading
|
import threading
|
||||||
|
@ -12,6 +11,41 @@ from core import imagelib
|
||||||
import cv2
|
import cv2
|
||||||
import models
|
import models
|
||||||
from core.interact import interact as io
|
from core.interact import interact as io
|
||||||
|
import logging
|
||||||
|
import datetime
|
||||||
|
import os
|
||||||
|
|
||||||
|
# adapted from https://stackoverflow.com/a/52295534
|
||||||
|
class TensorBoardTool:
    """Launch a TensorBoard server from inside the current process.

    Args:
        dir_path: Directory passed to TensorBoard as ``--logdir``.
        port: Port passed to TensorBoard as ``--port``. Defaults to 6006,
            the value that was previously hard-coded.
    """

    def __init__(self, dir_path, port=6006):
        self.dir_path = dir_path
        self.port = port

    def run(self):
        """Configure and launch TensorBoard, print its URL and return it."""
        # Imported lazily so that tensorboard is only needed when a server
        # is actually requested.
        from tensorboard import default
        from tensorboard import program

        # Remove HTTP messages: only errors from werkzeug are let through.
        logging.getLogger('werkzeug').setLevel(logging.ERROR)

        # Start tensorboard server.
        # NOTE(review): '--bind_all' is kept from the original code; it asks
        # TensorBoard to bind on all interfaces -- confirm this is intended.
        tb = program.TensorBoard(default.get_plugins())
        tb.configure(argv=[None,
                           '--logdir', str(self.dir_path),
                           '--port', str(self.port),
                           '--bind_all'])
        url = tb.launch()
        print('Launched TensorBoard at {}'.format(url))
        return url
|
||||||
|
|
||||||
|
def process_img_for_tensorboard(input_img):
    """Return ``input_img`` converted from BGR to RGB with its last axis moved first.

    The colour conversion is done by OpenCV; the axis move puts the channel
    count at the beginning of the shape.
    """
    rgb = cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB)
    # The conversion leaves the channel axis last; move it to the front.
    return np.moveaxis(rgb, -1, 0)
|
||||||
|
|
||||||
|
def log_tensorboard_previews(iter, previews, folder_name, train_summary_writer):
    """Write each ``(name, BGR image)`` pair in ``previews`` to the summary writer.

    Every image is converted with ``process_img_for_tensorboard`` and added
    under the tag ``'<folder_name>/<name>'`` at step ``iter``.
    """
    for name, bgr_image in previews:
        image = process_img_for_tensorboard(bgr_image)
        tag = '{}/{}'.format(folder_name, name)
        train_summary_writer.add_image(tag, image, iter)
|
||||||
|
|
||||||
|
def log_tensorboard_model_previews(iter, model, train_summary_writer):
    """Log the model's previews, then its static previews, at step ``iter``.

    The results of ``model.get_previews()`` go under 'preview' and those of
    ``model.get_static_previews()`` under 'static_preview'.
    """
    sources = (
        ('preview', 'get_previews'),
        ('static_preview', 'get_static_previews'),
    )
    for folder_name, getter_name in sources:
        # Look the getter up only when its turn comes, one group at a time.
        images = getattr(model, getter_name)()
        log_tensorboard_previews(iter, images, folder_name, train_summary_writer)
|
||||||
|
|
||||||
def trainerThread (s2c, c2s, e,
|
def trainerThread (s2c, c2s, e,
|
||||||
model_class_name = None,
|
model_class_name = None,
|
||||||
|
@ -27,12 +61,15 @@ def trainerThread (s2c, c2s, e,
|
||||||
silent_start=False,
|
silent_start=False,
|
||||||
execute_programs = None,
|
execute_programs = None,
|
||||||
debug=False,
|
debug=False,
|
||||||
|
tensorboard_dir=None,
|
||||||
|
start_tensorboard=False,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
save_interval_min = 15
|
save_interval_min = 15
|
||||||
|
tensorboard_preview_interval_min = 5
|
||||||
|
|
||||||
if not training_data_src_path.exists():
|
if not training_data_src_path.exists():
|
||||||
training_data_src_path.mkdir(exist_ok=True, parents=True)
|
training_data_src_path.mkdir(exist_ok=True, parents=True)
|
||||||
|
@ -60,6 +97,15 @@ def trainerThread (s2c, c2s, e,
|
||||||
|
|
||||||
is_reached_goal = model.is_reached_iter_goal()
|
is_reached_goal = model.is_reached_iter_goal()
|
||||||
|
|
||||||
|
if tensorboard_dir is not None:
|
||||||
|
c2s.put({
|
||||||
|
'op': 'tb',
|
||||||
|
'action': 'init',
|
||||||
|
'model_name': model.model_name,
|
||||||
|
'tensorboard_dir': tensorboard_dir,
|
||||||
|
'start_tensorboard': start_tensorboard
|
||||||
|
})
|
||||||
|
|
||||||
shared_state = { 'after_save' : False }
|
shared_state = { 'after_save' : False }
|
||||||
loss_string = ""
|
loss_string = ""
|
||||||
save_iter = model.get_iter()
|
save_iter = model.get_iter()
|
||||||
|
@ -73,6 +119,25 @@ def trainerThread (s2c, c2s, e,
|
||||||
if not debug and not is_reached_goal:
|
if not debug and not is_reached_goal:
|
||||||
model.create_backup()
|
model.create_backup()
|
||||||
|
|
||||||
|
def log_step(step, step_time, src_loss, dst_loss):
|
||||||
|
c2s.put({
|
||||||
|
'op': 'tb',
|
||||||
|
'action': 'step',
|
||||||
|
'step': step,
|
||||||
|
'step_time': step_time,
|
||||||
|
'src_loss': src_loss,
|
||||||
|
'dst_loss': dst_loss
|
||||||
|
})
|
||||||
|
|
||||||
|
def log_previews(step, previews, static_previews):
|
||||||
|
c2s.put({
|
||||||
|
'op': 'tb',
|
||||||
|
'action': 'preview',
|
||||||
|
'step': step,
|
||||||
|
'previews': previews,
|
||||||
|
'static_previews': static_previews
|
||||||
|
})
|
||||||
|
|
||||||
def send_preview():
|
def send_preview():
|
||||||
if not debug:
|
if not debug:
|
||||||
previews = model.get_previews()
|
previews = model.get_previews()
|
||||||
|
@ -91,6 +156,7 @@ def trainerThread (s2c, c2s, e,
|
||||||
io.log_info('Starting. Press "Enter" to stop training and save model.')
|
io.log_info('Starting. Press "Enter" to stop training and save model.')
|
||||||
|
|
||||||
last_save_time = time.time()
|
last_save_time = time.time()
|
||||||
|
last_preview_time = time.time()
|
||||||
|
|
||||||
execute_programs = [ [x[0], x[1], time.time() ] for x in execute_programs ]
|
execute_programs = [ [x[0], x[1], time.time() ] for x in execute_programs ]
|
||||||
|
|
||||||
|
@ -156,6 +222,9 @@ def trainerThread (s2c, c2s, e,
|
||||||
else:
|
else:
|
||||||
io.log_info (loss_string, end='\r')
|
io.log_info (loss_string, end='\r')
|
||||||
|
|
||||||
|
loss_entry = loss_history[-1]
|
||||||
|
log_step(iter, iter_time, loss_entry[0], loss_entry[1] if len(loss_entry) > 1 else None)
|
||||||
|
|
||||||
if model.get_iter() == 1:
|
if model.get_iter() == 1:
|
||||||
model_save()
|
model_save()
|
||||||
|
|
||||||
|
@ -165,6 +234,12 @@ def trainerThread (s2c, c2s, e,
|
||||||
is_reached_goal = True
|
is_reached_goal = True
|
||||||
io.log_info ('You can use preview now.')
|
io.log_info ('You can use preview now.')
|
||||||
|
|
||||||
|
if not is_reached_goal and (time.time() - last_preview_time) >= tensorboard_preview_interval_min*60:
|
||||||
|
last_preview_time += tensorboard_preview_interval_min*60
|
||||||
|
previews = model.get_previews()
|
||||||
|
static_previews = model.get_static_previews()
|
||||||
|
log_previews(iter, previews, static_previews)
|
||||||
|
|
||||||
if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min*60:
|
if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min*60:
|
||||||
last_save_time += save_interval_min*60
|
last_save_time += save_interval_min*60
|
||||||
model_save()
|
model_save()
|
||||||
|
@ -207,7 +282,54 @@ def trainerThread (s2c, c2s, e,
|
||||||
break
|
break
|
||||||
c2s.put ( {'op':'close'} )
|
c2s.put ( {'op':'close'} )
|
||||||
|
|
||||||
|
# Module-wide summary writer; stays None until init_writer() creates one.
_train_summary_writer = None


def init_writer(model_name, tensorboard_dir, start_tensorboard):
    """Create the module-wide tensorboardX writer and optionally start TensorBoard.

    Args:
        model_name: Name of the sub-folder of ``tensorboard_dir`` that the
            writer logs into.
        tensorboard_dir: Root log directory; created if it does not exist.
        start_tensorboard: When true, a TensorBoard server is launched on
            ``tensorboard_dir`` via ``TensorBoardTool``.

    Returns:
        The newly created ``tensorboardX.SummaryWriter``.
    """
    import tensorboardX
    global _train_summary_writer

    # A repeated init must not leak the previous writer: close it before it
    # is replaced.
    if _train_summary_writer is not None:
        _train_summary_writer.close()

    # exist_ok avoids the race between checking for the directory and
    # creating it.
    os.makedirs(tensorboard_dir, exist_ok=True)
    summary_writer_folder = os.path.join(tensorboard_dir, model_name)
    _train_summary_writer = tensorboardX.SummaryWriter(summary_writer_folder)

    if start_tensorboard:
        tb_tool = TensorBoardTool(tensorboard_dir)
        tb_tool.run()

    return _train_summary_writer
|
||||||
|
|
||||||
|
def get_writer():
    """Return the current value of the module-level ``_train_summary_writer``."""
    # Reading a module-level name needs no ``global`` declaration.
    return _train_summary_writer
|
||||||
|
|
||||||
|
def handle_tensorboard_op(input):
    """Apply one TensorBoard message (a dict with an 'action' key) to the writer.

    Actions handled:
      * 'init'    -- create the writer from 'model_name', 'tensorboard_dir'
                     and 'start_tensorboard'.
      * 'step'    -- log 'step_time' and 'src_loss' (and 'dst_loss' when it
                     is not None) as scalars at 'step'.
      * 'preview' -- log 'previews' and 'static_previews' (each skipped when
                     None) as images at 'step'.

    'step' and 'preview' do nothing while no writer exists.
    """
    writer = get_writer()
    action = input['action']

    if action == 'init':
        writer = init_writer(input['model_name'],
                             input['tensorboard_dir'],
                             input['start_tensorboard'])

    # Without a writer there is nowhere to log to.
    if writer is None:
        return

    if action == 'step':
        step = input['step']
        step_time = input['step_time']
        src_loss = input['src_loss']
        dst_loss = input['dst_loss']
        # report iteration time summary
        writer.add_scalar('iteration time', step_time, step)
        # report loss summary
        writer.add_scalar('loss/src', src_loss, step)
        if dst_loss is not None:
            writer.add_scalar('loss/dst', dst_loss, step)
    elif action == 'preview':
        step = input['step']
        # Read both payloads before logging anything.
        image_groups = (
            ('preview', input['previews']),
            ('static_preview', input['static_previews']),
        )
        for folder_name, images in image_groups:
            if images is not None:
                log_tensorboard_previews(step, images, folder_name, writer)
|
||||||
|
|
||||||
def main(**kwargs):
|
def main(**kwargs):
|
||||||
io.log_info ("Running trainer.\r\n")
|
io.log_info ("Running trainer.\r\n")
|
||||||
|
@ -228,7 +350,9 @@ def main(**kwargs):
|
||||||
if not c2s.empty():
|
if not c2s.empty():
|
||||||
input = c2s.get()
|
input = c2s.get()
|
||||||
op = input.get('op','')
|
op = input.get('op','')
|
||||||
if op == 'close':
|
if op == 'tb':
|
||||||
|
handle_tensorboard_op(input)
|
||||||
|
elif op == 'close':
|
||||||
break
|
break
|
||||||
try:
|
try:
|
||||||
io.process_messages(0.1)
|
io.process_messages(0.1)
|
||||||
|
@ -278,6 +402,8 @@ def main(**kwargs):
|
||||||
previews.append ( (preview_name, cv2.resize(preview_rgb, (max_w, max_h))) )
|
previews.append ( (preview_name, cv2.resize(preview_rgb, (max_w, max_h))) )
|
||||||
selected_preview = selected_preview % len(previews)
|
selected_preview = selected_preview % len(previews)
|
||||||
update_preview = True
|
update_preview = True
|
||||||
|
elif op == 'tb':
|
||||||
|
handle_tensorboard_op(input)
|
||||||
elif op == 'close':
|
elif op == 'close':
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
import pickle
|
import pickle
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue