refactoring

This commit is contained in:
iperov 2018-12-24 13:45:40 +04:00
parent 44798c2b85
commit 8a223845fb
19 changed files with 963 additions and 468 deletions

View file

@ -1,7 +1,7 @@
from models import ModelBase
from models import TrainingDataType
import numpy as np
import cv2
from models import ModelBase
from samples import *
from nnlib import tf_dssim
from nnlib import DSSIMLossClass
from nnlib import conv
@ -72,21 +72,20 @@ class Model(ModelBase):
self.BA256_view = K.function ([input_B_warped64], [BA_rec256])
if self.is_training_mode:
from models import TrainingDataGenerator
f = TrainingDataGenerator.SampleTypeFlags
f = SampleProcessor.TypeFlags
self.set_training_data_generators ([
TrainingDataGenerator(TrainingDataType.FACE, self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size, output_sample_types=[
[f.WARPED_TRANSFORMED | f.HALF_FACE | f.MODE_BGR, 64],
[f.TRANSFORMED | f.HALF_FACE | f.MODE_BGR, 64],
[f.TRANSFORMED | f.FULL_FACE | f.MODE_BGR, 256],
[f.SOURCE | f.HALF_FACE | f.MODE_BGR, 64],
[f.SOURCE | f.HALF_FACE | f.MODE_BGR, 256] ] ),
SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size, output_sample_types=[
[f.WARPED_TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 64],
[f.TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 64],
[f.TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_BGR, 256],
[f.SOURCE | f.FACE_ALIGN_HALF | f.MODE_BGR, 64],
[f.SOURCE | f.FACE_ALIGN_HALF | f.MODE_BGR, 256] ] ),
TrainingDataGenerator(TrainingDataType.FACE, self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size, output_sample_types=[
[f.WARPED_TRANSFORMED | f.HALF_FACE | f.MODE_BGR, 64],
[f.TRANSFORMED | f.HALF_FACE | f.MODE_BGR, 64],
[f.SOURCE | f.HALF_FACE | f.MODE_BGR, 64],
[f.SOURCE | f.HALF_FACE | f.MODE_BGR, 256] ] )
SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size, output_sample_types=[
[f.WARPED_TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 64],
[f.TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 64],
[f.SOURCE | f.FACE_ALIGN_HALF | f.MODE_BGR, 64],
[f.SOURCE | f.FACE_ALIGN_HALF | f.MODE_BGR, 256] ] )
])
#override
def onSave(self):