refactoring

This commit is contained in:
iperov 2018-12-24 13:45:40 +04:00
parent 44798c2b85
commit 8a223845fb
19 changed files with 963 additions and 468 deletions

View file

@ -10,7 +10,7 @@ from utils import image_utils
import numpy as np
import cv2
import gpufmkmgr
from .TrainingDataGeneratorBase import TrainingDataGeneratorBase
from samples import SampleGeneratorBase
'''
You can implement your own model. Check examples.
@ -47,13 +47,11 @@ class ModelBase(object):
self.epoch = model_data['epoch']
self.options = model_data['options']
self.loss_history = model_data['loss_history'] if 'loss_history' in model_data.keys() else []
self.generator_dict_states = model_data['generator_dict_states'] if 'generator_dict_states' in model_data.keys() else None
self.sample_for_preview = model_data['sample_for_preview'] if 'sample_for_preview' in model_data.keys() else None
else:
self.epoch = 0
self.options = {}
self.loss_history = []
self.generator_dict_states = None
self.sample_for_preview = None
if self.write_preview_history:
@ -97,11 +95,8 @@ class ModelBase(object):
raise Exception( 'You didnt set_training_data_generators()')
else:
for i, generator in enumerate(self.generator_list):
if not isinstance(generator, TrainingDataGeneratorBase):
raise Exception('training data generator is not subclass of TrainingDataGeneratorBase')
if self.generator_dict_states is not None and i < len(self.generator_dict_states):
generator.set_dict_state ( self.generator_dict_states[i] )
if not isinstance(generator, SampleGeneratorBase):
raise Exception('training data generator is not subclass of SampleGeneratorBase')
if self.sample_for_preview is None:
self.sample_for_preview = self.generate_next_sample()
@ -212,7 +207,6 @@ class ModelBase(object):
'epoch': self.epoch,
'options': self.options,
'loss_history': self.loss_history,
'generator_dict_states' : [generator.get_dict_state() for generator in self.generator_list],
'sample_for_preview' : self.sample_for_preview
}
self.model_data_path.write_bytes( pickle.dumps(model_data) )