Add optional tensorboard logging

Vincent Riemer 2020-11-23 01:25:32 -08:00
commit 075e8aaf82
2 changed files with 66 additions and 0 deletions


@@ -127,6 +127,8 @@ if __name__ == "__main__":
                  'silent_start' : arguments.silent_start,
                  'execute_programs' : [ [int(x[0]), x[1] ] for x in arguments.execute_program ],
                  'debug' : arguments.debug,
                  'tensorboard_dir' : arguments.tensorboard_dir,
                  'start_tensorboard' : arguments.start_tensorboard
                  }
        from mainscripts import Trainer
        Trainer.main(**kwargs)
@@ -144,6 +146,8 @@ if __name__ == "__main__":
    p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Train on CPU.")
    p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.")
    p.add_argument('--silent-start', action="store_true", dest="silent_start", default=False, help="Silent start. Automatically chooses Best GPU and last used model.")
    p.add_argument('--tensorboard-logdir', action=fixPathAction, dest="tensorboard_dir", help="Directory of the tensorboard output files")
    p.add_argument('--start-tensorboard', action="store_true", dest="start_tensorboard", default=False, help="Automatically start the tensorboard server pointed at --tensorboard-logdir")
    p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+')
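
For reference, a minimal self-contained sketch (not part of the commit; the path and values are placeholders, and fixPathAction is replaced by plain argparse handling) of how the two new flags flow into the trainer kwargs:

    import argparse

    p = argparse.ArgumentParser()
    p.add_argument('--tensorboard-logdir', dest="tensorboard_dir", default=None, help="Directory for the tensorboard event files")
    p.add_argument('--start-tensorboard', action="store_true", dest="start_tensorboard", default=False, help="Also launch a tensorboard server pointed at --tensorboard-logdir")

    # hypothetical command line, for illustration only
    arguments = p.parse_args(['--tensorboard-logdir', './workspace/tensorboard', '--start-tensorboard'])

    kwargs = {
        'tensorboard_dir'   : arguments.tensorboard_dir,   # None disables tensorboard logging entirely
        'start_tensorboard' : arguments.start_tensorboard, # only meaningful when a logdir is given
    }
    print(kwargs)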


@@ -11,6 +11,41 @@ from core import imagelib
import cv2
import models
from core.interact import interact as io

import logging
import datetime
import os

# adapted from https://stackoverflow.com/a/52295534
class TensorBoardTool:
    def __init__(self, dir_path):
        self.dir_path = dir_path

    def run(self):
        from tensorboard import default
        from tensorboard import program
        # silence werkzeug's per-request HTTP log messages
        logging.getLogger('werkzeug').setLevel(logging.ERROR)
        # start the tensorboard server on port 6006, bound to all interfaces
        tb = program.TensorBoard(default.get_plugins())
        tb.configure(argv=[None, '--logdir', self.dir_path, '--port', '6006', '--bind_all'])
        url = tb.launch()
        print('Launched TensorBoard at {}'.format(url))

def process_img_for_tensorboard(input_img):
    # convert format from bgr to rgb
    img = cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB)
    # adjust axis to put channel count at the beginning
    img = np.moveaxis(img, -1, 0)
    return img

def log_tensorboard_previews(iter, previews, folder_name, train_summary_writer):
    for preview in previews:
        (preview_name, preview_bgr) = preview
        preview_rgb = process_img_for_tensorboard(preview_bgr)
        train_summary_writer.add_image('{}/{}'.format(folder_name, preview_name), preview_rgb, iter)

def log_tensorboard_model_previews(iter, model, train_summary_writer):
    log_tensorboard_previews(iter, model.get_previews(), 'preview', train_summary_writer)
    log_tensorboard_previews(iter, model.get_static_previews(), 'static_preview', train_summary_writer)

def trainerThread (s2c, c2s, e,
                    model_class_name = None,
@@ -26,6 +61,8 @@ def trainerThread (s2c, c2s, e,
                    silent_start=False,
                    execute_programs = None,
                    debug=False,
                    tensorboard_dir=None,
                    start_tensorboard=False,
                    **kwargs):
    while True:
        try:
@@ -59,6 +96,21 @@ def trainerThread (s2c, c2s, e,
            is_reached_goal = model.is_reached_iter_goal()

            train_summary_writer = None
            if tensorboard_dir is not None:
                try:
                    import tensorboardX
                    if not os.path.exists(tensorboard_dir):
                        os.makedirs(tensorboard_dir)
                    summary_writer_folder = os.path.join(tensorboard_dir, model.model_name)
                    train_summary_writer = tensorboardX.SummaryWriter(summary_writer_folder)
                    if start_tensorboard:
                        tb_tool = TensorBoardTool(tensorboard_dir)
                        tb_tool.run()
                except Exception:
                    print("Error setting up tensorboard logging, please ensure tensorboardX is installed (pip install tensorboardX)")
                    print("Continuing training without tensorboard logging...")

            shared_state = { 'after_save' : False }
            loss_string = ""
            save_iter = model.get_iter()
@@ -74,6 +126,8 @@ def trainerThread (s2c, c2s, e,
            def send_preview():
                if not debug:
                    if train_summary_writer is not None:
                        log_tensorboard_model_previews(iter, model, train_summary_writer)

                    previews = model.get_previews()
                    c2s.put ( {'op':'show', 'previews': previews, 'iter':model.get_iter(), 'loss_history': model.get_loss_history().copy() } )
                else:
@@ -149,6 +203,14 @@ def trainerThread (s2c, c2s, e,
                        else:
                            io.log_info (loss_string, end='\r')

                    if train_summary_writer is not None:
                        # report iteration time summary
                        train_summary_writer.add_scalar('iteration time', iter_time, iter)
                        # report loss summary
                        src_loss, dst_loss = loss_history[-1]
                        train_summary_writer.add_scalar('loss/src', src_loss, iter)
                        train_summary_writer.add_scalar('loss/dst', dst_loss, iter)

                    if model.get_iter() == 1:
                        model_save()
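
To make the logging calls above concrete, here is a hedged, self-contained sketch of one logged iteration using tensorboardX directly; the log directory, model name, iteration values, and the random preview image are placeholders, and it assumes tensorboardX, numpy, and opencv-python are installed:

    import os
    import numpy as np
    import cv2
    import tensorboardX

    logdir = os.path.join('./tensorboard_logs', 'SAEHD')   # hypothetical model-name subfolder
    os.makedirs(logdir, exist_ok=True)
    writer = tensorboardX.SummaryWriter(logdir)

    iter = 100                          # placeholder iteration counter
    iter_time = 0.45                    # placeholder seconds per iteration
    src_loss, dst_loss = 0.731, 0.812   # placeholder loss values

    # scalars: one data point per training iteration
    writer.add_scalar('iteration time', iter_time, iter)
    writer.add_scalar('loss/src', src_loss, iter)
    writer.add_scalar('loss/dst', dst_loss, iter)

    # previews: BGR images are converted to RGB and moved to channels-first,
    # since add_image defaults to dataformats='CHW'
    preview_bgr = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)
    preview_rgb = np.moveaxis(cv2.cvtColor(preview_bgr, cv2.COLOR_BGR2RGB), -1, 0)
    writer.add_image('preview/random', preview_rgb, iter)

    writer.close()

The resulting event files can be viewed either through the --start-tensorboard path added above or by pointing a manually started TensorBoard server at the same log directory.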