diff --git a/models/ModelBase.py b/models/ModelBase.py index d243e87..b5b5b5c 100644 --- a/models/ModelBase.py +++ b/models/ModelBase.py @@ -143,12 +143,12 @@ class ModelBase(object): self.batch_size = 1 if self.is_training_mode: - if self.write_preview_history: - if self.device_args['force_gpu_idx'] == -1: - self.preview_history_path = self.model_path / ( '%s_history' % (self.get_model_name()) ) - else: - self.preview_history_path = self.model_path / ( '%d_%s_history' % (self.device_args['force_gpu_idx'], self.get_model_name()) ) + if self.device_args['force_gpu_idx'] == -1: + self.preview_history_path = self.model_path / ( '%s_history' % (self.get_model_name()) ) + else: + self.preview_history_path = self.model_path / ( '%d_%s_history' % (self.device_args['force_gpu_idx'], self.get_model_name()) ) + if self.write_preview_history or io.is_colab(): if not self.preview_history_path.exists(): self.preview_history_path.mkdir(exist_ok=True) else: @@ -365,12 +365,19 @@ class ModelBase(object): self.loss_history.append ( [float(loss[1]) for loss in losses] ) - if self.write_preview_history: - if self.iter % 10 == 0: - preview = self.get_static_preview() + if self.iter % 10 == 0: + plist = [] + if io.is_colab(): + plist += [ (self.get_previews()[0][1], 'preview.jpg') ] + + if self.write_preview_history: + plist += [ (self.get_static_preview(), '%.6d.jpg' %(self.iter) ) ] + + for preview, filename in plist: preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2]) img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8) - cv2_imwrite ( str (self.preview_history_path / ('%.6d.jpg' %( self.iter) )), img ) + cv2_imwrite ( str (self.preview_history_path / filename), img ) + self.iter += 1