diff --git a/main.py b/main.py index 3232657..be0f27a 100644 --- a/main.py +++ b/main.py @@ -135,6 +135,7 @@ if __name__ == "__main__": device_args = {'cpu_only' : arguments.cpu_only, 'force_gpu_idx' : arguments.force_gpu_idx, } + from mainscripts import Trainer Trainer.main(args, device_args) diff --git a/mainscripts/Trainer.py b/mainscripts/Trainer.py index 0701d7b..b7ce915 100644 --- a/mainscripts/Trainer.py +++ b/mainscripts/Trainer.py @@ -8,25 +8,24 @@ import itertools import numpy as np from pathlib import Path -from flaskr.app import create_flask_app import imagelib import cv2 import models from interact import interact as io -def trainer_thread (s2c, c2s, e, args, device_args, socketio=None): +def trainer_thread(s2c, c2s, e, args, device_args, socketio=None): while True: try: start_time = time.time() - training_data_src_path = Path( args.get('training_data_src_dir', '') ) - training_data_dst_path = Path( args.get('training_data_dst_dir', '') ) + training_data_src_path = Path(args.get('training_data_src_dir', '')) + training_data_dst_path = Path(args.get('training_data_dst_dir', '')) pretraining_data_path = args.get('pretraining_data_dir', '') pretraining_data_path = Path(pretraining_data_path) if pretraining_data_path is not None else None - model_path = Path( args.get('model_path', '') ) + model_path = Path(args.get('model_path', '')) model_name = args.get('model_name', '') save_interval_min = 15 debug = args.get('debug', '') @@ -44,33 +43,34 @@ def trainer_thread (s2c, c2s, e, args, device_args, socketio=None): model_path.mkdir(exist_ok=True) model = models.import_model(model_name)( - model_path, - training_data_src_path=training_data_src_path, - training_data_dst_path=training_data_dst_path, - pretraining_data_path=pretraining_data_path, - debug=debug, - device_args=device_args) + model_path, + training_data_src_path=training_data_src_path, + training_data_dst_path=training_data_dst_path, + pretraining_data_path=pretraining_data_path, + debug=debug, + device_args=device_args) is_reached_goal = model.is_reached_iter_goal() - shared_state = { 'after_save' : False } + shared_state = {'after_save': False} loss_string = "" - save_iter = model.get_iter() + save_iter = model.get_iter() + def model_save(): if not debug and not is_reached_goal: - io.log_info ("Saving....", end='\r') + io.log_info("Saving....", end='\r') model.save() shared_state['after_save'] = True def send_preview(): if not debug: previews = model.get_previews() - c2s.put ( {'op':'show', 'previews': previews, 'iter':model.get_iter(), 'loss_history': model.get_loss_history().copy() } ) + c2s.put({'op': 'show', 'previews': previews, 'iter': model.get_iter(), + 'loss_history': model.get_loss_history().copy()}) else: - previews = [( 'debug, press update for new', model.debug_one_iter())] - c2s.put ( {'op':'show', 'previews': previews} ) - e.set() #Set the GUI Thread as Ready - + previews = [('debug, press update for new', model.debug_one_iter())] + c2s.put({'op': 'show', 'previews': previews}) + e.set() # Set the GUI Thread as Ready if model.is_first_run(): model_save() @@ -79,25 +79,26 @@ def trainer_thread (s2c, c2s, e, args, device_args, socketio=None): if is_reached_goal: io.log_info('Model already trained to target iteration. You can use preview.') else: - io.log_info('Starting. Target iteration: %d. Press "Enter" to stop training and save model.' % ( model.get_target_iter() ) ) + io.log_info('Starting. Target iteration: %d. Press "Enter" to stop training and save model.' % ( + model.get_target_iter())) else: io.log_info('Starting. Press "Enter" to stop training and save model.') last_save_time = time.time() - execute_programs = [ [x[0], x[1], time.time() ] for x in execute_programs ] + execute_programs = [[x[0], x[1], time.time()] for x in execute_programs] - for i in itertools.count(0,1): + for i in itertools.count(0, 1): if not debug: cur_time = time.time() for x in execute_programs: prog_time, prog, last_time = x exec_prog = False - if prog_time > 0 and (cur_time - start_time) >= prog_time: + if 0 < prog_time <= (cur_time - start_time): x[0] = 0 exec_prog = True - elif prog_time < 0 and (cur_time - last_time) >= -prog_time: + elif prog_time < 0 and (cur_time - last_time) >= -prog_time: x[2] = cur_time exec_prog = True @@ -105,7 +106,7 @@ def trainer_thread (s2c, c2s, e, args, device_args, socketio=None): try: exec(prog) except Exception as e: - print("Unable to execute program: %s" % (prog) ) + print("Unable to execute program: %s" % (prog)) if not is_reached_goal: iter, iter_time, batch_size = model.train_one_iter() @@ -113,45 +114,49 @@ def trainer_thread (s2c, c2s, e, args, device_args, socketio=None): loss_history = model.get_loss_history() time_str = time.strftime("[%H:%M:%S]") if iter_time >= 10: - loss_string = "{0}[#{1:06d}][{2:.5s}s][bs: {3}]".format ( time_str, iter, '{:0.4f}'.format(iter_time), batch_size ) + loss_string = "{0}[#{1:06d}][{2:.5s}s][bs: {3}]".format(time_str, iter, + '{:0.4f}'.format(iter_time), + batch_size) else: - loss_string = "{0}[#{1:06d}][{2:04d}ms][bs: {3}]".format ( time_str, iter, int(iter_time*1000), batch_size) + loss_string = "{0}[#{1:06d}][{2:04d}ms][bs: {3}]".format(time_str, iter, + int(iter_time * 1000), batch_size) if shared_state['after_save']: shared_state['after_save'] = False - last_save_time = time.time() #upd last_save_time only after save+one_iter, because plaidML rebuilds programs after save https://github.com/plaidml/plaidml/issues/274 + last_save_time = time.time() # upd last_save_time only after save+one_iter, because + # plaidML rebuilds programs after save https://github.com/plaidml/plaidml/issues/274 - mean_loss = np.mean ( [ np.array(loss_history[i]) for i in range(save_iter, iter) ], axis=0) + mean_loss = np.mean([np.array(loss_history[i]) for i in range(save_iter, iter)], axis=0) for loss_value in mean_loss: - loss_string += "[%.4f]" % (loss_value) + loss_string += "[%.4f]" % loss_value - io.log_info (loss_string) + io.log_info(loss_string) save_iter = iter else: for loss_value in loss_history[-1]: - loss_string += "[%.4f]" % (loss_value) + loss_string += "[%.4f]" % loss_value if io.is_colab(): - io.log_info ('\r' + loss_string, end='') + io.log_info('\r' + loss_string, end='') else: - io.log_info (loss_string, end='\r') + io.log_info(loss_string, end='\r') if socketio is not None: socketio.emit('loss', loss_string) if model.get_target_iter() != 0 and model.is_reached_iter_goal(): - io.log_info ('Reached target iteration.') + io.log_info('Reached target iteration.') model_save() is_reached_goal = True - io.log_info ('You can use preview now.') + io.log_info('You can use preview now.') - if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min*60: + if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min * 60: model_save() send_preview() - if i==0: + if i == 0: if is_reached_goal: model.pass_one_iter() send_preview() @@ -160,8 +165,8 @@ def trainer_thread (s2c, c2s, e, args, device_args, socketio=None): time.sleep(0.005) while not s2c.empty(): - input = s2c.get() - op = input['op'] + item = s2c.get() + op = item['op'] if op == 'save': model_save() elif op == 'preview': @@ -176,32 +181,30 @@ def trainer_thread (s2c, c2s, e, args, device_args, socketio=None): if i == -1: break - - model.finalize() except Exception as e: - print ('Error: %s' % (str(e))) + print('Error: %s' % (str(e))) traceback.print_exc() break - c2s.put ( {'op':'close'} ) + c2s.put({'op': 'close'}) class Zoom(Enum): - ZOOM_25 = (1/4, '25%') - ZOOM_33 = (1/3, '33%') - ZOOM_50 = (1/2, '50%') - ZOOM_67 = (2/3, '67%') - ZOOM_75 = (3/4, '75%') - ZOOM_80 = (4/5, '80%') - ZOOM_90 = (9/10, '90%') + ZOOM_25 = (1 / 4, '25%') + ZOOM_33 = (1 / 3, '33%') + ZOOM_50 = (1 / 2, '50%') + ZOOM_67 = (2 / 3, '67%') + ZOOM_75 = (3 / 4, '75%') + ZOOM_80 = (4 / 5, '80%') + ZOOM_90 = (9 / 10, '90%') ZOOM_100 = (1, '100%') - ZOOM_110 = (11/10, '110%') - ZOOM_125 = (5/4, '125%') - ZOOM_150 = (3/2, '150%') - ZOOM_175 = (7/4, '175%') + ZOOM_110 = (11 / 10, '110%') + ZOOM_125 = (5 / 4, '125%') + ZOOM_150 = (3 / 2, '150%') + ZOOM_175 = (7 / 4, '175%') ZOOM_200 = (2, '200%') - ZOOM_250 = (5/2, '250%') + ZOOM_250 = (5 / 2, '250%') ZOOM_300 = (3, '300%') ZOOM_400 = (4, '400%') ZOOM_500 = (5, '500%') @@ -258,7 +261,7 @@ def create_preview_pane_image(previews, selected_preview, loss_history, head_lines = [ '[s]:save [enter]:exit [-/+]:zoom: %s' % zoom.label, '[p]:update [space]:next preview [l]:change history range', - 'Preview: "%s" [%d/%d]' % (selected_preview_name,selected_preview+1, len(previews)) + 'Preview: "%s" [%d/%d]' % (selected_preview_name, selected_preview + 1, len(previews)) ] head_line_height = int(15 * zoom.scale) head_height = len(head_lines) * head_line_height @@ -266,8 +269,8 @@ def create_preview_pane_image(previews, selected_preview, loss_history, for i in range(0, len(head_lines)): t = i * head_line_height - b = (i+1) * head_line_height - head[t:b, 0:w] += imagelib.get_text_image((head_line_height, w, c), head_lines[i], color=[0.8]*c) + b = (i + 1) * head_line_height + head[t:b, 0:w] += imagelib.get_text_image((head_line_height, w, c), head_lines[i], color=[0.8] * c) final = head @@ -278,15 +281,15 @@ def create_preview_pane_image(previews, selected_preview, loss_history, loss_history_to_show = loss_history[-show_last_history_iters_count:] lh_height = int(100 * zoom.scale) lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, iteration, batch_size, w, c, lh_height) - final = np.concatenate ( [final, lh_img], axis=0 ) + final = np.concatenate([final, lh_img], axis=0) final = np.concatenate([final, selected_preview_rgb], axis=0) final = np.clip(final, 0, 1) - return (final*255).astype(np.uint8) + return (final * 255).astype(np.uint8) def main(args, device_args): - io.log_info ("Running trainer.\r\n") + io.log_info("Running trainer.\r\n") no_preview = args.get('no_preview', False) flask_preview = args.get('flask_preview', False) @@ -295,7 +298,18 @@ def main(args, device_args): c2s = queue.Queue() e = threading.Event() + previews = None + loss_history = None + selected_preview = 0 + update_preview = False + is_waiting_preview = False + show_last_history_iters_count = 0 + iteration = 0 + batch_size = 1 + zoom = Zoom.ZOOM_100 + if flask_preview: + from flaskr.app import create_flask_app s2flask = queue.Queue() socketio, flask_app = create_flask_app(s2c, c2s, s2flask, args) @@ -303,36 +317,22 @@ def main(args, device_args): thread = threading.Thread(target=trainer_thread, args=(s2c, c2s, e, args, device_args, socketio)) thread.start() - e.wait() #Wait for inital load to occur. + e.wait() # Wait for inital load to occur. - flask_t = threading.Thread(target=socketio.run, args=(flask_app,), kwargs={'debug': True, 'use_reloader': False}) + flask_t = threading.Thread(target=socketio.run, args=(flask_app,), + kwargs={'debug': True, 'use_reloader': False}) flask_t.start() - wnd_name = "Training preview" - io.named_window(wnd_name) - io.capture_keys(wnd_name) - - previews = None - loss_history = None - selected_preview = 0 - update_preview = False - is_showing = False - is_waiting_preview = False - show_last_history_iters_count = 0 - iteration = 0 - batch_size = 1 - zoom = Zoom.ZOOM_100 - while True: if not c2s.empty(): - input = c2s.get() - op = input['op'] + item = c2s.get() + op = item['op'] if op == 'show': is_waiting_preview = False - loss_history = input['loss_history'] if 'loss_history' in input.keys() else None - previews = input['previews'] if 'previews' in input.keys() else None - iteration = input['iter'] if 'iter' in input.keys() else 0 - #batch_size = input['batch_size'] if 'iter' in input.keys() else 1 + loss_history = item['loss_history'] if 'loss_history' in item.keys() else None + previews = item['previews'] if 'previews' in item.keys() else None + iteration = item['iter'] if 'iter' in item.keys() else 0 + # batch_size = input['batch_size'] if 'iter' in input.keys() else 1 if previews is not None: update_preview = True elif op == 'update': @@ -375,7 +375,6 @@ def main(args, device_args): cv2.imwrite(preview_file, preview_pane_image) s2flask.put({'op': 'show'}) socketio.emit('preview', {'iter': iteration, 'loss': loss_history[-1]}) - is_showing = True try: io.process_messages(0.01) except KeyboardInterrupt: @@ -384,19 +383,19 @@ def main(args, device_args): thread = threading.Thread(target=trainer_thread, args=(s2c, c2s, e, args, device_args)) thread.start() - e.wait() #Wait for inital load to occur. + e.wait() # Wait for inital load to occur. if no_preview: while True: if not c2s.empty(): - input = c2s.get() - op = input.get('op','') + item = c2s.get() + op = item.get('op', '') if op == 'close': break try: io.process_messages(0.1) except KeyboardInterrupt: - s2c.put ( {'op': 'close'} ) + s2c.put({'op': 'close'}) else: wnd_name = "Training preview" io.named_window(wnd_name) @@ -415,14 +414,14 @@ def main(args, device_args): while True: if not c2s.empty(): - input = c2s.get() - op = input['op'] + item = c2s.get() + op = item['op'] if op == 'show': is_waiting_preview = False - loss_history = input['loss_history'] if 'loss_history' in input.keys() else None - previews = input['previews'] if 'previews' in input.keys() else None - iteration = input['iter'] if 'iter' in input.keys() else 0 - #batch_size = input['batch_size'] if 'iter' in input.keys() else 1 + loss_history = item['loss_history'] if 'loss_history' in item.keys() else None + previews = item['previews'] if 'previews' in item.keys() else None + iteration = item['iter'] if 'iter' in item.keys() else 0 + # batch_size = input['batch_size'] if 'iter' in input.keys() else 1 if previews is not None: update_preview = True elif op == 'close': @@ -439,19 +438,19 @@ def main(args, device_args): batch_size, zoom) io.show_image(wnd_name, preview_pane_image) - is_showing = True key_events = io.get_key_events(wnd_name) - key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (0,0,False,False,False) + key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else ( + 0, 0, False, False, False) if key == ord('\n') or key == ord('\r'): - s2c.put ( {'op': 'close'} ) + s2c.put({'op': 'close'}) elif key == ord('s'): - s2c.put ( {'op': 'save'} ) + s2c.put({'op': 'save'}) elif key == ord('p'): if not is_waiting_preview: is_waiting_preview = True - s2c.put ( {'op': 'preview'} ) + s2c.put({'op': 'preview'}) elif key == ord('l'): if show_last_history_iters_count == 0: show_last_history_iters_count = 5000 @@ -476,6 +475,6 @@ def main(args, device_args): try: io.process_messages(0.1) except KeyboardInterrupt: - s2c.put ( {'op': 'close'} ) + s2c.put({'op': 'close'}) io.destroy_all_windows()