mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 04:52:13 -07:00
refactoring
This commit is contained in:
parent
44798c2b85
commit
8a223845fb
19 changed files with 963 additions and 468 deletions
|
@ -10,7 +10,7 @@ from utils import image_utils
|
|||
import numpy as np
|
||||
import cv2
|
||||
import gpufmkmgr
|
||||
from .TrainingDataGeneratorBase import TrainingDataGeneratorBase
|
||||
from samples import SampleGeneratorBase
|
||||
|
||||
'''
|
||||
You can implement your own model. Check examples.
|
||||
|
@ -47,13 +47,11 @@ class ModelBase(object):
|
|||
self.epoch = model_data['epoch']
|
||||
self.options = model_data['options']
|
||||
self.loss_history = model_data['loss_history'] if 'loss_history' in model_data.keys() else []
|
||||
self.generator_dict_states = model_data['generator_dict_states'] if 'generator_dict_states' in model_data.keys() else None
|
||||
self.sample_for_preview = model_data['sample_for_preview'] if 'sample_for_preview' in model_data.keys() else None
|
||||
else:
|
||||
self.epoch = 0
|
||||
self.options = {}
|
||||
self.loss_history = []
|
||||
self.generator_dict_states = None
|
||||
self.sample_for_preview = None
|
||||
|
||||
if self.write_preview_history:
|
||||
|
@ -97,11 +95,8 @@ class ModelBase(object):
|
|||
raise Exception( 'You didnt set_training_data_generators()')
|
||||
else:
|
||||
for i, generator in enumerate(self.generator_list):
|
||||
if not isinstance(generator, TrainingDataGeneratorBase):
|
||||
raise Exception('training data generator is not subclass of TrainingDataGeneratorBase')
|
||||
|
||||
if self.generator_dict_states is not None and i < len(self.generator_dict_states):
|
||||
generator.set_dict_state ( self.generator_dict_states[i] )
|
||||
if not isinstance(generator, SampleGeneratorBase):
|
||||
raise Exception('training data generator is not subclass of SampleGeneratorBase')
|
||||
|
||||
if self.sample_for_preview is None:
|
||||
self.sample_for_preview = self.generate_next_sample()
|
||||
|
@ -212,7 +207,6 @@ class ModelBase(object):
|
|||
'epoch': self.epoch,
|
||||
'options': self.options,
|
||||
'loss_history': self.loss_history,
|
||||
'generator_dict_states' : [generator.get_dict_state() for generator in self.generator_list],
|
||||
'sample_for_preview' : self.sample_for_preview
|
||||
}
|
||||
self.model_data_path.write_bytes( pickle.dumps(model_data) )
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue