diff --git a/main.py b/main.py index fb64f96..8e22b89 100644 --- a/main.py +++ b/main.py @@ -127,6 +127,8 @@ if __name__ == "__main__": 'silent_start' : arguments.silent_start, 'execute_programs' : [ [int(x[0]), x[1] ] for x in arguments.execute_program ], 'debug' : arguments.debug, + 'tensorboard_dir' : arguments.tensorboard_dir, + 'start_tensorboard' : arguments.start_tensorboard } from mainscripts import Trainer Trainer.main(**kwargs) @@ -144,6 +146,8 @@ if __name__ == "__main__": p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Train on CPU.") p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.") p.add_argument('--silent-start', action="store_true", dest="silent_start", default=False, help="Silent start. Automatically chooses Best GPU and last used model.") + p.add_argument('--tensorboard-logdir', action=fixPathAction, dest="tensorboard_dir", help="Directory of the tensorboard output files") + p.add_argument('--start-tensorboard', action="store_true", dest="start_tensorboard", default=False, help="Automatically start the tensorboard server preconfigured to the tensorboard-logdir") p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+') diff --git a/mainscripts/Trainer.py b/mainscripts/Trainer.py index 7facd29..00b030d 100644 --- a/mainscripts/Trainer.py +++ b/mainscripts/Trainer.py @@ -11,6 +11,41 @@ from core import imagelib import cv2 import models from core.interact import interact as io +import logging +import datetime +import os + +# adapted from https://stackoverflow.com/a/52295534 +class TensorBoardTool: + def __init__(self, dir_path): + self.dir_path = dir_path + def run(self): + from tensorboard import default + from tensorboard import program + # remove http messages + log = logging.getLogger('werkzeug').setLevel(logging.ERROR) + # Start tensorboard server + tb = program.TensorBoard(default.get_plugins()) + tb.configure(argv=[None, '--logdir', self.dir_path, '--port', '6006', '--bind_all']) + url = tb.launch() + print('Launched TensorBoard at {}'.format(url)) + +def process_img_for_tensorboard(input_img): + # convert format from bgr to rgb + img = cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB) + # adjust axis to put channel count at the beginning + img = np.moveaxis(img, -1, 0) + return img + +def log_tensorboard_previews(iter, previews, folder_name, train_summary_writer): + for preview in previews: + (preview_name, preview_bgr) = preview + preview_rgb = process_img_for_tensorboard(preview_bgr) + train_summary_writer.add_image('{}/{}'.format(folder_name, preview_name), preview_rgb, iter) + +def log_tensorboard_model_previews(iter, model, train_summary_writer): + log_tensorboard_previews(iter, model.get_previews(), 'preview', train_summary_writer) + log_tensorboard_previews(iter, model.get_static_previews(), 'static_preview', train_summary_writer) def trainerThread (s2c, c2s, e, model_class_name = None, @@ -26,6 +61,8 @@ def trainerThread (s2c, c2s, e, silent_start=False, execute_programs = None, debug=False, + tensorboard_dir=None, + start_tensorboard=False, **kwargs): while True: try: @@ -59,6 +96,21 @@ def trainerThread (s2c, c2s, e, is_reached_goal = model.is_reached_iter_goal() + train_summary_writer = None + if tensorboard_dir is not None: + try: + import tensorboardX + if not os.path.exists(tensorboard_dir): + os.makedirs(tensorboard_dir) + summary_writer_folder = os.path.join(tensorboard_dir, model.model_name) + train_summary_writer = tensorboardX.SummaryWriter(summary_writer_folder) + if start_tensorboard: + tb_tool = TensorBoardTool(tensorboard_dir) + tb_tool.run() + except: + print("Error importing tensorboardX, please ensure it is installed (pip install tensorboardX)") + print("Continuing training without tensorboard logging...") + shared_state = { 'after_save' : False } loss_string = "" save_iter = model.get_iter() @@ -74,6 +126,8 @@ def trainerThread (s2c, c2s, e, def send_preview(): if not debug: + if train_summary_writer is not None: + log_tensorboard_model_previews(iter, model, train_summary_writer) previews = model.get_previews() c2s.put ( {'op':'show', 'previews': previews, 'iter':model.get_iter(), 'loss_history': model.get_loss_history().copy() } ) else: @@ -149,6 +203,14 @@ def trainerThread (s2c, c2s, e, else: io.log_info (loss_string, end='\r') + if train_summary_writer is not None: + # report iteration time summary + train_summary_writer.add_scalar('iteration time', iter_time, iter) + # report loss summary + src_loss, dst_loss = loss_history[-1] + train_summary_writer.add_scalar('loss/src', src_loss, iter) + train_summary_writer.add_scalar('loss/dst', dst_loss, iter) + if model.get_iter() == 1: model_save()