Pass loss string through socket

This commit is contained in:
Jeremy Hummel 2019-09-14 16:32:09 -07:00
commit 7a8fca84cc
4 changed files with 23 additions and 11 deletions

View file

@ -34,11 +34,16 @@
console.log(new Date(), '- new preview -', msg);
$('img#preview').attr("src", "{{ url_for('preview_image') }}?q=" + new Date().getTime());
});
socket.on('loss', function(loss_string) {
console.log(new Date(), '- loss string -', loss_string);
$('h1#loss').html(loss_string);
});
});
</script>
</head>
<body>
<h1>Flask Server Demonstration</h1>
<h1 id="loss"></h1>
<div>
<button class='btn btn-default' id='save'>Save</button>
<button class='btn btn-default' id='exit'>Exit</button>

17
main.py
View file

@ -48,8 +48,8 @@ if __name__ == "__main__":
p.add_argument('--manual-window-size', type=int, dest="manual_window_size", default=1368, help="Manual fix window size. Default: 1368.")
p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Extract on CPU. Forces to use MT extractor.")
p.set_defaults (func=process_extract)
def process_dev_extract_umd_csv(arguments):
os_utils.set_process_lowest_prio()
from mainscripts import Extractor
@ -80,7 +80,7 @@ if __name__ == "__main__":
p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Extract on CPU.")
p.set_defaults (func=process_extract_fanseg)
"""
def process_sort(arguments):
os_utils.set_process_lowest_prio()
from mainscripts import Sorter
@ -109,7 +109,7 @@ if __name__ == "__main__":
if arguments.remove_ie_polys:
Util.remove_ie_polys_folder (input_path=arguments.input_dir)
p = subparsers.add_parser( "util", help="Utilities.")
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.")
p.add_argument('--convert-png-to-jpg', action="store_true", dest="convert_png_to_jpg", default=False, help="Convert DeepFaceLAB PNG files to JPEG.")
@ -129,13 +129,14 @@ if __name__ == "__main__":
'model_name' : arguments.model_name,
'no_preview' : arguments.no_preview,
'debug' : arguments.debug,
'flask_preview' : arguments.flask_preview,
'execute_programs' : [ [int(x[0]), x[1] ] for x in arguments.execute_program ]
}
device_args = {'cpu_only' : arguments.cpu_only,
'force_gpu_idx' : arguments.force_gpu_idx,
}
from mainscripts import FlaskTrainer
FlaskTrainer.main(args, device_args)
from mainscripts import Trainer
Trainer.main(args, device_args)
p = subparsers.add_parser( "train", help="Trainer")
p.add_argument('--training-data-src-dir', required=True, action=fixPathAction, dest="training_data_src_dir",
@ -155,6 +156,8 @@ if __name__ == "__main__":
p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+')
p.add_argument('--pingpong', dest="ping_pong", default=False,
help="Cycle between a batch size of 1 and the chosen batch size")
p.add_argument('--flask-preview', action="store_true", dest="flask_preview", default=False,
help="Launches a flask server to view the previews in a web browser")
p.set_defaults (func=process_train)
def process_convert(arguments):
@ -251,7 +254,7 @@ if __name__ == "__main__":
p.add_argument('--confirmed-dir', required=True, action=fixPathAction, dest="confirmed_dir", help="This is where the labeled faces will be stored.")
p.add_argument('--skipped-dir', required=True, action=fixPathAction, dest="skipped_dir", help="This is where the labeled faces will be stored.")
p.add_argument('--no-default-mask', action="store_true", dest="no_default_mask", default=False, help="Don't use default mask.")
p.set_defaults(func=process_labelingtool_edit_mask)
def bad_args(arguments):

View file

@ -15,7 +15,7 @@ import models
from interact import interact as io
def trainerThread (s2c, c2s, e, args, device_args):
def trainerThread (s2c, c2s, e, args, device_args, socketio=None):
while True:
try:
start_time = time.time()
@ -138,6 +138,9 @@ def trainerThread (s2c, c2s, e, args, device_args):
else:
io.log_info (loss_string, end='\r')
if socketio is not None:
socketio.emit('loss', loss_string)
if model.get_target_iter() != 0 and model.is_reached_iter_goal():
io.log_info ('Reached target iteration.')
model_save()
@ -289,14 +292,14 @@ def main(args, device_args):
s2c = queue.Queue()
c2s = queue.Queue()
s2flask = queue.Queue()
socketio, flask_app = create_flask_app(s2c, c2s, s2flask, args)
e = threading.Event()
thread = threading.Thread(target=trainerThread, args=(s2c, c2s, e, args, device_args) )
thread = threading.Thread(target=trainerThread, args=(s2c, c2s, e, args, device_args, socketio))
thread.start()
e.wait() #Wait for inital load to occur.
socketio, flask_app = create_flask_app(s2c, c2s, s2flask, args)
flask_t = threading.Thread(target=socketio.run, args=(flask_app,), kwargs={'debug': True, 'use_reloader': False})
flask_t.start()

View file

@ -285,6 +285,7 @@ def main(args, device_args):
io.log_info ("Running trainer.\r\n")
no_preview = args.get('no_preview', False)
flask_preview = args.get('flask_preview', False)
s2c = queue.Queue()