Mirror of https://github.com/iperov/DeepFaceLab.git (synced 2025-08-19 21:13:20 -07:00)
Add optional tensorboard logging
This commit is contained in:
parent 73c21eb335
commit 075e8aaf82

2 changed files with 66 additions and 0 deletions

main.py (+4)
@@ -127,6 +127,8 @@ if __name__ == "__main__":
                     'silent_start' : arguments.silent_start,
                     'execute_programs' : [ [int(x[0]), x[1] ] for x in arguments.execute_program ],
                     'debug' : arguments.debug,
+                    'tensorboard_dir' : arguments.tensorboard_dir,
+                    'start_tensorboard' : arguments.start_tensorboard
                     }
         from mainscripts import Trainer
         Trainer.main(**kwargs)
@@ -144,6 +146,8 @@ if __name__ == "__main__":
     p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Train on CPU.")
     p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.")
     p.add_argument('--silent-start', action="store_true", dest="silent_start", default=False, help="Silent start. Automatically chooses Best GPU and last used model.")
+    p.add_argument('--tensorboard-logdir', action=fixPathAction, dest="tensorboard_dir", help="Directory of the tensorboard output files")
+    p.add_argument('--start-tensorboard', action="store_true", dest="start_tensorboard", default=False, help="Automatically start the tensorboard server preconfigured to the tensorboard-logdir")

     p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+')
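
Together, the two main.py hunks route the new flags from argparse into the kwargs dict passed to Trainer.main. Below is a minimal standalone sketch of that flow; plain argparse is used here, os.path.abspath stands in for DeepFaceLab's fixPathAction (which is not shown in this diff), and the example path is arbitrary:

    import argparse
    import os

    p = argparse.ArgumentParser()
    # stand-in for fixPathAction: just absolutize the supplied path
    p.add_argument('--tensorboard-logdir', dest="tensorboard_dir", type=os.path.abspath,
                   help="Directory of the tensorboard output files")
    p.add_argument('--start-tensorboard', action="store_true", dest="start_tensorboard", default=False,
                   help="Automatically start the tensorboard server")

    arguments = p.parse_args(['--tensorboard-logdir', 'workspace/tb_logs', '--start-tensorboard'])
    kwargs = {
        'tensorboard_dir'   : arguments.tensorboard_dir,    # None if the flag is omitted
        'start_tensorboard' : arguments.start_tensorboard,  # False if the flag is omitted
    }
    print(kwargs)

With neither flag given, tensorboard_dir stays None and the trainer skips TensorBoard logging entirely, which keeps the feature opt-in.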
Second changed file (the trainer script; +62):

@@ -11,6 +11,41 @@ from core import imagelib
 import cv2
 import models
 from core.interact import interact as io
+import logging
+import datetime
+import os
+
+# adapted from https://stackoverflow.com/a/52295534
+class TensorBoardTool:
+    def __init__(self, dir_path):
+        self.dir_path = dir_path
+    def run(self):
+        from tensorboard import default
+        from tensorboard import program
+        # silence werkzeug's per-request HTTP log messages
+        logging.getLogger('werkzeug').setLevel(logging.ERROR)
+        # start the tensorboard server on port 6006, bound to all interfaces
+        tb = program.TensorBoard(default.get_plugins())
+        tb.configure(argv=[None, '--logdir', self.dir_path, '--port', '6006', '--bind_all'])
+        url = tb.launch()
+        print('Launched TensorBoard at {}'.format(url))
+
+def process_img_for_tensorboard(input_img):
+    # convert from OpenCV's BGR ordering to RGB
+    img = cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB)
+    # move the channel axis to the front (HWC -> CHW, the layout add_image expects by default)
+    img = np.moveaxis(img, -1, 0)
+    return img
+
+def log_tensorboard_previews(iter, previews, folder_name, train_summary_writer):
+    for preview in previews:
+        (preview_name, preview_bgr) = preview
+        preview_rgb = process_img_for_tensorboard(preview_bgr)
+        train_summary_writer.add_image('{}/{}'.format(folder_name, preview_name), preview_rgb, iter)
+
+def log_tensorboard_model_previews(iter, model, train_summary_writer):
+    log_tensorboard_previews(iter, model.get_previews(), 'preview', train_summary_writer)
+    log_tensorboard_previews(iter, model.get_static_previews(), 'static_preview', train_summary_writer)
+
 def trainerThread (s2c, c2s, e,
                    model_class_name = None,
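
These helpers land in the trainer script (presumably mainscripts/Trainer.py, given the "from mainscripts import Trainer" call in main.py; the file header is not preserved in this rendering). The sketch below shows the intended usage with a synthetic BGR image standing in for a model preview; it assumes numpy, opencv-python, tensorboardX and tensorboard are installed, and the log directory name is arbitrary:

    import numpy as np
    import cv2
    import tensorboardX

    # synthetic 128x128 BGR preview standing in for an entry of model.get_previews()
    preview_bgr = np.zeros((128, 128, 3), dtype=np.uint8)
    preview_bgr[:, :, 0] = 255   # pure blue in BGR

    writer = tensorboardX.SummaryWriter('tb_demo_logs')   # arbitrary demo directory

    # same conversion the diff performs: BGR -> RGB, then HWC -> CHW
    img = cv2.cvtColor(preview_bgr, cv2.COLOR_BGR2RGB)
    img = np.moveaxis(img, -1, 0)
    writer.add_image('preview/demo', img, global_step=0)
    writer.close()

    # optionally serve the logs the same way TensorBoardTool does
    # (binds port 6006 on all interfaces):
    # from tensorboard import default, program
    # tb = program.TensorBoard(default.get_plugins())
    # tb.configure(argv=[None, '--logdir', 'tb_demo_logs', '--port', '6006', '--bind_all'])
    # print('Launched TensorBoard at', tb.launch())

Pointing TensorBoard at tb_demo_logs should then show the image under the Images tab with the tag preview/demo.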
@@ -26,6 +61,8 @@ def trainerThread (s2c, c2s, e,
                    silent_start=False,
                    execute_programs = None,
                    debug=False,
+                   tensorboard_dir=None,
+                   start_tensorboard=False,
                    **kwargs):
     while True:
         try:
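
Trainer.main itself is not part of this diff; since it is invoked as Trainer.main(**kwargs) and trainerThread now accepts tensorboard_dir and start_tensorboard with safe defaults, the new options only need to be forwarded unchanged. The following is a hypothetical, heavily simplified sketch of that forwarding; the stub trainerThread and every name in it are placeholders, not the real implementation:

    import threading
    import queue

    def trainerThread(s2c, c2s, e, tensorboard_dir=None, start_tensorboard=False, **kwargs):
        # stub standing in for the real training loop
        print('trainer got tensorboard_dir={}, start_tensorboard={}'.format(tensorboard_dir, start_tensorboard))
        e.set()

    def main(**kwargs):
        # hypothetical simplified forwarding; the real Trainer.main also drives
        # the preview window and the s2c/c2s command queues
        s2c, c2s = queue.Queue(), queue.Queue()
        e = threading.Event()
        threading.Thread(target=trainerThread, args=(s2c, c2s, e), kwargs=kwargs).start()
        e.wait()

    main(tensorboard_dir='workspace/tb_logs', start_tensorboard=False, debug=False)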
@@ -59,6 +96,21 @@ def trainerThread (s2c, c2s, e,

             is_reached_goal = model.is_reached_iter_goal()

+            train_summary_writer = None
+            if tensorboard_dir is not None:
+                try:
+                    import tensorboardX
+                    if not os.path.exists(tensorboard_dir):
+                        os.makedirs(tensorboard_dir)
+                    summary_writer_folder = os.path.join(tensorboard_dir, model.model_name)
+                    train_summary_writer = tensorboardX.SummaryWriter(summary_writer_folder)
+                    if start_tensorboard:
+                        tb_tool = TensorBoardTool(tensorboard_dir)
+                        tb_tool.run()
+                except:
+                    print("Error importing tensorboardX, please ensure it is installed (pip install tensorboardX)")
+                    print("Continuing training without tensorboard logging...")
+
             shared_state = { 'after_save' : False }
             loss_string = ""
             save_iter = model.get_iter()
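
Two details of this initialization are worth noting: tensorboardX is imported lazily inside the try block, so it stays an optional dependency, and the writer logs into a per-model subfolder (os.path.join(tensorboard_dir, model.model_name)), so TensorBoard shows one run per model name. Here is a minimal sketch of the same pattern in isolation; the function name and the example paths are placeholders:

    import os

    def make_summary_writer(tensorboard_dir, model_name):
        """Return a tensorboardX SummaryWriter, or None if logging is disabled or unavailable."""
        if tensorboard_dir is None:
            return None
        try:
            import tensorboardX
        except ImportError:
            print("tensorboardX not installed (pip install tensorboardX); continuing without logging")
            return None
        run_dir = os.path.join(tensorboard_dir, model_name)   # one TensorBoard run per model name
        os.makedirs(run_dir, exist_ok=True)
        return tensorboardX.SummaryWriter(run_dir)

    writer = make_summary_writer('workspace/tb_logs', 'demo_model')   # placeholder names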
@@ -74,6 +126,8 @@ def trainerThread (s2c, c2s, e,

             def send_preview():
                 if not debug:
+                    if train_summary_writer is not None:
+                        log_tensorboard_model_previews(iter, model, train_summary_writer)
                     previews = model.get_previews()
                     c2s.put ( {'op':'show', 'previews': previews, 'iter':model.get_iter(), 'loss_history': model.get_loss_history().copy() } )
                 else:
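
send_preview now mirrors the previews into TensorBoard before they are queued for the GUI. The np.moveaxis call in process_img_for_tensorboard exists only because tensorboardX's add_image expects CHW layout by default; an equivalent alternative (not what this commit does) is to keep HWC layout and pass dataformats explicitly, as sketched below with a synthetic image and an arbitrary log directory:

    import numpy as np
    import cv2
    import tensorboardX

    writer = tensorboardX.SummaryWriter('tb_demo_logs')   # arbitrary demo directory
    preview_bgr = np.random.randint(0, 256, (128, 128, 3), dtype=np.uint8)  # stand-in preview

    # keep HWC layout and tell add_image about it instead of calling np.moveaxis
    rgb_hwc = cv2.cvtColor(preview_bgr, cv2.COLOR_BGR2RGB)
    writer.add_image('preview/demo_hwc', rgb_hwc, global_step=0, dataformats='HWC')
    writer.close()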
@@ -149,6 +203,14 @@ def trainerThread (s2c, c2s, e,
                 else:
                     io.log_info (loss_string, end='\r')

+                if train_summary_writer is not None:
+                    # report iteration time summary
+                    train_summary_writer.add_scalar('iteration time', iter_time, iter)
+                    # report loss summary
+                    src_loss, dst_loss = loss_history[-1]
+                    train_summary_writer.add_scalar('loss/src', src_loss, iter)
+                    train_summary_writer.add_scalar('loss/dst', dst_loss, iter)
+
                 if model.get_iter() == 1:
                     model_save()
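
Because the tags are slash-separated, 'loss/src' and 'loss/dst' are grouped under a single "loss" section in TensorBoard's Scalars tab, next to the 'iteration time' chart. A short self-contained sketch of the same add_scalar calls; the loss curves and timing values below are synthetic, purely for illustration:

    import math
    import tensorboardX

    writer = tensorboardX.SummaryWriter('tb_demo_logs')   # arbitrary demo directory

    for it in range(1, 101):
        # synthetic stand-ins for iter_time and loss_history[-1]
        iter_time = 0.15
        src_loss = math.exp(-it / 50.0)
        dst_loss = math.exp(-it / 40.0)

        writer.add_scalar('iteration time', iter_time, it)
        writer.add_scalar('loss/src', src_loss, it)   # grouped under "loss" in the UI
        writer.add_scalar('loss/dst', dst_loss, it)

    writer.close()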