diff --git a/models/ModelBase.py b/models/ModelBase.py index 09cb31b..0fbbd9a 100644 --- a/models/ModelBase.py +++ b/models/ModelBase.py @@ -61,7 +61,7 @@ class ModelBase(object): # sort by modified datetime saved_models_names = sorted(saved_models_names, key=operator.itemgetter(1), reverse=True ) saved_models_names = [ x[0] for x in saved_models_names ] - + if len(saved_models_names) != 0: if silent_start: @@ -125,7 +125,7 @@ class ModelBase(object): self.model_name = self.model_name.replace('_', ' ') break - + self.model_name = self.model_name + '_' + self.model_class_name else: self.model_name = force_model_class_name @@ -150,10 +150,10 @@ class ModelBase(object): if self.is_first_run(): io.log_info ("\nModel first run.") - + if silent_start: self.device_config = nn.DeviceConfig.BestGPU() - io.log_info (f"Silent start: choosed device {'CPU' if self.device_config.cpu_only else self.device_config.devices[0].name}") + io.log_info (f"Silent start: choosed device {'CPU' if self.device_config.cpu_only else self.device_config.devices[0].name}") else: self.device_config = nn.DeviceConfig.GPUIndexes( force_gpu_idxs or nn.ask_choose_device_idxs(suggest_best_multi_gpu=True)) \ if not cpu_only else nn.DeviceConfig.CPU() @@ -208,7 +208,7 @@ class ModelBase(object): raise ValueError('training data generator is not subclass of SampleGeneratorBase') self.update_sample_for_preview(choose_preview_history=self.choose_preview_history) - + if self.autobackup_hour != 0: self.autobackup_start_time = time.time() @@ -216,19 +216,21 @@ class ModelBase(object): self.autobackups_path.mkdir(exist_ok=True) io.log_info( self.get_summary_text() ) - + def update_sample_for_preview(self, choose_preview_history=False, force_new=False): if self.sample_for_preview is None or choose_preview_history or force_new: if choose_preview_history and io.is_support_windows(): - io.log_info ("Choose image for the preview history. [p] - next. [enter] - confirm.") - wnd_name = "[p] - next. [enter] - confirm." + wnd_name = "[p] - next. [space] - switch preview type. [enter] - confirm." + io.log_info (f"Choose image for the preview history. {wnd_name}") io.named_window(wnd_name) io.capture_keys(wnd_name) choosed = False + preview_id_counter = 0 while not choosed: self.sample_for_preview = self.generate_next_samples() - preview = self.get_static_preview() - io.show_image( wnd_name, (preview*255).astype(np.uint8) ) + previews = self.get_static_previews() + + io.show_image( wnd_name, ( previews[preview_id_counter % len(previews) ][1] *255).astype(np.uint8) ) while True: key_events = io.get_key_events(wnd_name) @@ -236,6 +238,9 @@ class ModelBase(object): if key == ord('\n') or key == ord('\r'): choosed = True break + elif key == ord(' '): + preview_id_counter += 1 + break elif key == ord('p'): break @@ -249,12 +254,12 @@ class ModelBase(object): self.sample_for_preview = self.generate_next_samples() try: - self.get_static_preview() + self.get_static_previews() except: self.sample_for_preview = self.generate_next_samples() self.last_sample = self.sample_for_preview - + def load_or_def_option(self, name, def_value): options_val = self.options.get(name, None) if options_val is not None: @@ -354,8 +359,8 @@ class ModelBase(object): def get_previews(self): return self.onGetPreview ( self.last_sample ) - def get_static_preview(self): - return self.onGetPreview (self.sample_for_preview)[0][1] #first preview, and bgr + def get_static_previews(self): + return self.onGetPreview (self.sample_for_preview) def save(self): Path( self.get_summary_path() ).write_text( self.get_summary_text() ) @@ -377,16 +382,16 @@ class ModelBase(object): if diff_hour > 0 and diff_hour % self.autobackup_hour == 0: self.autobackup_start_time += self.autobackup_hour*3600 self.create_backup() - + def create_backup(self): io.log_info ("Creating backup...", end='\r') - + if not self.autobackups_path.exists(): self.autobackups_path.mkdir(exist_ok=True) - + bckp_filename_list = [ self.get_strpath_storage_for_file(filename) for _, filename in self.get_model_filename_list() ] bckp_filename_list += [ str(self.get_summary_path()), str(self.model_data_path) ] - + for i in range(24,0,-1): idx_str = '%.2d' % i next_idx_str = '%.2d' % (i+1) @@ -427,7 +432,7 @@ class ModelBase(object): return imagelib.equalize_and_stack_square (images) def generate_next_samples(self): - sample = [] + sample = [] for generator in self.generator_list: if generator.is_initialized(): sample.append ( generator.generate_next() ) @@ -445,7 +450,7 @@ class ModelBase(object): self.loss_history.append ( [float(loss[1]) for loss in losses] ) if (not io.is_colab() and self.iter % 10 == 0) or \ - (io.is_colab() and self.iter % 100 == 0): + (io.is_colab() and self.iter % 100 == 0): plist = [] if io.is_colab(): @@ -455,8 +460,15 @@ class ModelBase(object): plist += [ (bgr, self.get_strpath_storage_for_file('preview_%s.jpg' % (name) ) ) ] if self.write_preview_history: - plist += [ (self.get_static_preview(), str (self.preview_history_path / ('%.6d.jpg' % (self.iter))) ) ] - + previews = self.get_static_previews() + for i in range(len(previews)): + name, bgr = previews[i] + path = self.preview_history_path / name + path.mkdir(parents=True, exist_ok=True) + plist += [ ( bgr, str ( path / ( f'{self.iter:07d}.jpg') ) ) ] + if not io.is_colab(): + plist += [ ( bgr, str ( path / ( '_last.jpg' ) )) ] + for preview, filepath in plist: preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2]) img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8) @@ -508,7 +520,7 @@ class ModelBase(object): def get_summary_path(self): return self.get_strpath_storage_for_file('summary.txt') - + def get_summary_text(self): ###Generate text summary of model hyperparameters #Find the longest key name and value string. Used as column widths.