diff --git a/models/ModelBase.py b/models/ModelBase.py index 1ae8dbb..af31e7e 100644 --- a/models/ModelBase.py +++ b/models/ModelBase.py @@ -74,8 +74,8 @@ class ModelBase(object): self.options.pop('epoch') if self.iter != 0: self.options = model_data['options'] - self.loss_history = model_data['loss_history'] if 'loss_history' in model_data.keys() else [] - self.sample_for_preview = model_data['sample_for_preview'] if 'sample_for_preview' in model_data.keys() else None + self.loss_history = model_data.get('loss_history', []) + self.sample_for_preview = model_data.get('sample_for_preview', None) ask_override = self.is_training_mode and self.iter != 0 and io.input_in_time ("Press enter in 2 seconds to override model settings.", 5 if io.is_colab() else 2 ) @@ -178,36 +178,34 @@ class ModelBase(object): if not isinstance(generator, SampleGeneratorBase): raise ValueError('training data generator is not subclass of SampleGeneratorBase') - if (self.sample_for_preview is None) or (self.iter == 0): - - if self.iter == 0: - if choose_preview_history and io.is_support_windows(): - wnd_name = "[p] - next. [enter] - confirm." - io.named_window(wnd_name) - io.capture_keys(wnd_name) - choosed = False - while not choosed: - self.sample_for_preview = self.generate_next_sample() - preview = self.get_static_preview() - io.show_image( wnd_name, (preview*255).astype(np.uint8) ) - - while True: - key_events = io.get_key_events(wnd_name) - key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (0,0,False,False,False) - if key == ord('\n') or key == ord('\r'): - choosed = True - break - elif key == ord('p'): - break - - try: - io.process_messages(0.1) - except KeyboardInterrupt: - choosed = True - - io.destroy_window(wnd_name) - else: + if self.sample_for_preview is None or choose_preview_history: + if choose_preview_history and io.is_support_windows(): + wnd_name = "[p] - next. [enter] - confirm." + io.named_window(wnd_name) + io.capture_keys(wnd_name) + choosed = False + while not choosed: self.sample_for_preview = self.generate_next_sample() + preview = self.get_static_preview() + io.show_image( wnd_name, (preview*255).astype(np.uint8) ) + + while True: + key_events = io.get_key_events(wnd_name) + key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (0,0,False,False,False) + if key == ord('\n') or key == ord('\r'): + choosed = True + break + elif key == ord('p'): + break + + try: + io.process_messages(0.1) + except KeyboardInterrupt: + choosed = True + + io.destroy_window(wnd_name) + else: + self.sample_for_preview = self.generate_next_sample() model_summary_text = []