mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-19 21:13:20 -07:00
Add optional tensorboard logging
This commit is contained in:
parent
73c21eb335
commit
075e8aaf82
2 changed files with 66 additions and 0 deletions
4
main.py
4
main.py
|
@ -127,6 +127,8 @@ if __name__ == "__main__":
|
||||||
'silent_start' : arguments.silent_start,
|
'silent_start' : arguments.silent_start,
|
||||||
'execute_programs' : [ [int(x[0]), x[1] ] for x in arguments.execute_program ],
|
'execute_programs' : [ [int(x[0]), x[1] ] for x in arguments.execute_program ],
|
||||||
'debug' : arguments.debug,
|
'debug' : arguments.debug,
|
||||||
|
'tensorboard_dir' : arguments.tensorboard_dir,
|
||||||
|
'start_tensorboard' : arguments.start_tensorboard
|
||||||
}
|
}
|
||||||
from mainscripts import Trainer
|
from mainscripts import Trainer
|
||||||
Trainer.main(**kwargs)
|
Trainer.main(**kwargs)
|
||||||
|
@ -144,6 +146,8 @@ if __name__ == "__main__":
|
||||||
p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Train on CPU.")
|
p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Train on CPU.")
|
||||||
p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.")
|
p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.")
|
||||||
p.add_argument('--silent-start', action="store_true", dest="silent_start", default=False, help="Silent start. Automatically chooses Best GPU and last used model.")
|
p.add_argument('--silent-start', action="store_true", dest="silent_start", default=False, help="Silent start. Automatically chooses Best GPU and last used model.")
|
||||||
|
p.add_argument('--tensorboard-logdir', action=fixPathAction, dest="tensorboard_dir", help="Directory of the tensorboard output files")
|
||||||
|
p.add_argument('--start-tensorboard', action="store_true", dest="start_tensorboard", default=False, help="Automatically start the tensorboard server preconfigured to the tensorboard-logdir")
|
||||||
|
|
||||||
|
|
||||||
p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+')
|
p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+')
|
||||||
|
|
|
@ -11,6 +11,41 @@ from core import imagelib
|
||||||
import cv2
|
import cv2
|
||||||
import models
|
import models
|
||||||
from core.interact import interact as io
|
from core.interact import interact as io
|
||||||
|
import logging
|
||||||
|
import datetime
|
||||||
|
import os
|
||||||
|
|
||||||
|
# adapted from https://stackoverflow.com/a/52295534
class TensorBoardTool:
    """Launches an in-process TensorBoard server serving a log directory.

    The heavy tensorboard imports are deferred to run() so that training
    still works when tensorboard is not installed and logging is disabled.
    """

    def __init__(self, dir_path, port='6006'):
        # dir_path: directory containing the tensorboard event files
        # port: TCP port to serve on; default kept at tensorboard's standard 6006
        #       (new parameter, backward-compatible with the original one-arg call)
        self.dir_path = dir_path
        self.port = port

    def run(self):
        """Start the TensorBoard server and print the URL it is reachable at."""
        from tensorboard import default
        from tensorboard import program

        # remove http messages: silence werkzeug's per-request request log.
        # (setLevel returns None, so the result is deliberately not kept.)
        logging.getLogger('werkzeug').setLevel(logging.ERROR)

        # Start tensorboard server
        tb = program.TensorBoard(default.get_plugins())
        tb.configure(argv=[None, '--logdir', self.dir_path, '--port', self.port, '--bind_all'])
        url = tb.launch()
        print('Launched TensorBoard at {}'.format(url))
|
||||||
|
|
||||||
|
def process_img_for_tensorboard(input_img):
    """Convert a BGR preview image into the channel-first RGB layout tensorboard expects."""
    # OpenCV previews are BGR; tensorboard renders RGB
    rgb_img = cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB)
    # add_image wants (C, H, W), so bring the channel axis to the front
    return np.moveaxis(rgb_img, -1, 0)
|
||||||
|
|
||||||
|
def log_tensorboard_previews(iter, previews, folder_name, train_summary_writer):
    """Write every (name, BGR image) preview pair to tensorboard under folder_name."""
    for preview_name, preview_bgr in previews:
        # reformat to channel-first RGB before handing off to tensorboardX
        preview_rgb = process_img_for_tensorboard(preview_bgr)
        tag = '{}/{}'.format(folder_name, preview_name)
        train_summary_writer.add_image(tag, preview_rgb, iter)
|
||||||
|
|
||||||
|
def log_tensorboard_model_previews(iter, model, train_summary_writer):
    """Log both the dynamic and the static model previews to tensorboard."""
    preview_sets = (
        ('preview', model.get_previews()),
        ('static_preview', model.get_static_previews()),
    )
    for folder_name, previews in preview_sets:
        log_tensorboard_previews(iter, previews, folder_name, train_summary_writer)
|
||||||
|
|
||||||
def trainerThread (s2c, c2s, e,
|
def trainerThread (s2c, c2s, e,
|
||||||
model_class_name = None,
|
model_class_name = None,
|
||||||
|
@ -26,6 +61,8 @@ def trainerThread (s2c, c2s, e,
|
||||||
silent_start=False,
|
silent_start=False,
|
||||||
execute_programs = None,
|
execute_programs = None,
|
||||||
debug=False,
|
debug=False,
|
||||||
|
tensorboard_dir=None,
|
||||||
|
start_tensorboard=False,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
@ -59,6 +96,21 @@ def trainerThread (s2c, c2s, e,
|
||||||
|
|
||||||
is_reached_goal = model.is_reached_iter_goal()
|
is_reached_goal = model.is_reached_iter_goal()
|
||||||
|
|
||||||
|
train_summary_writer = None
|
||||||
|
if tensorboard_dir is not None:
|
||||||
|
try:
|
||||||
|
import tensorboardX
|
||||||
|
if not os.path.exists(tensorboard_dir):
|
||||||
|
os.makedirs(tensorboard_dir)
|
||||||
|
summary_writer_folder = os.path.join(tensorboard_dir, model.model_name)
|
||||||
|
train_summary_writer = tensorboardX.SummaryWriter(summary_writer_folder)
|
||||||
|
if start_tensorboard:
|
||||||
|
tb_tool = TensorBoardTool(tensorboard_dir)
|
||||||
|
tb_tool.run()
|
||||||
|
except:
|
||||||
|
print("Error importing tensorboardX, please ensure it is installed (pip install tensorboardX)")
|
||||||
|
print("Continuing training without tensorboard logging...")
|
||||||
|
|
||||||
shared_state = { 'after_save' : False }
|
shared_state = { 'after_save' : False }
|
||||||
loss_string = ""
|
loss_string = ""
|
||||||
save_iter = model.get_iter()
|
save_iter = model.get_iter()
|
||||||
|
@ -74,6 +126,8 @@ def trainerThread (s2c, c2s, e,
|
||||||
|
|
||||||
def send_preview():
|
def send_preview():
|
||||||
if not debug:
|
if not debug:
|
||||||
|
if train_summary_writer is not None:
|
||||||
|
log_tensorboard_model_previews(iter, model, train_summary_writer)
|
||||||
previews = model.get_previews()
|
previews = model.get_previews()
|
||||||
c2s.put ( {'op':'show', 'previews': previews, 'iter':model.get_iter(), 'loss_history': model.get_loss_history().copy() } )
|
c2s.put ( {'op':'show', 'previews': previews, 'iter':model.get_iter(), 'loss_history': model.get_loss_history().copy() } )
|
||||||
else:
|
else:
|
||||||
|
@ -149,6 +203,14 @@ def trainerThread (s2c, c2s, e,
|
||||||
else:
|
else:
|
||||||
io.log_info (loss_string, end='\r')
|
io.log_info (loss_string, end='\r')
|
||||||
|
|
||||||
|
if train_summary_writer is not None:
|
||||||
|
# report iteration time summary
|
||||||
|
train_summary_writer.add_scalar('iteration time', iter_time, iter)
|
||||||
|
# report loss summary
|
||||||
|
src_loss, dst_loss = loss_history[-1]
|
||||||
|
train_summary_writer.add_scalar('loss/src', src_loss, iter)
|
||||||
|
train_summary_writer.add_scalar('loss/dst', dst_loss, iter)
|
||||||
|
|
||||||
if model.get_iter() == 1:
|
if model.get_iter() == 1:
|
||||||
model_save()
|
model_save()
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue