preview file for colab

This commit is contained in:
iperov 2019-03-26 17:14:59 +04:00
parent 96de328221
commit 6241c8af41

View file

@ -143,12 +143,12 @@ class ModelBase(object):
self.batch_size = 1 self.batch_size = 1
if self.is_training_mode: if self.is_training_mode:
if self.write_preview_history: if self.device_args['force_gpu_idx'] == -1:
if self.device_args['force_gpu_idx'] == -1: self.preview_history_path = self.model_path / ( '%s_history' % (self.get_model_name()) )
self.preview_history_path = self.model_path / ( '%s_history' % (self.get_model_name()) ) else:
else: self.preview_history_path = self.model_path / ( '%d_%s_history' % (self.device_args['force_gpu_idx'], self.get_model_name()) )
self.preview_history_path = self.model_path / ( '%d_%s_history' % (self.device_args['force_gpu_idx'], self.get_model_name()) )
if self.write_preview_history or io.is_colab():
if not self.preview_history_path.exists(): if not self.preview_history_path.exists():
self.preview_history_path.mkdir(exist_ok=True) self.preview_history_path.mkdir(exist_ok=True)
else: else:
@ -365,12 +365,19 @@ class ModelBase(object):
self.loss_history.append ( [float(loss[1]) for loss in losses] ) self.loss_history.append ( [float(loss[1]) for loss in losses] )
if self.write_preview_history: if self.iter % 10 == 0:
if self.iter % 10 == 0: plist = []
preview = self.get_static_preview() if io.is_colab():
plist += [ (self.get_previews()[0][1], 'preview.jpg') ]
if self.write_preview_history:
plist += [ (self.get_static_preview(), '%.6d.jpg' %(self.iter) ) ]
for preview, filename in plist:
preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2]) preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2])
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8) img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
cv2_imwrite ( str (self.preview_history_path / ('%.6d.jpg' %( self.iter) )), img ) cv2_imwrite ( str (self.preview_history_path / filename), img )
self.iter += 1 self.iter += 1