SAE collapse fix (#245)

* test

* _

* _

* upd dev_poseest

* SAE: finally collapses are fixed

* fix batch size help
This commit is contained in:
iperov 2019-04-24 09:38:26 +04:00 committed by GitHub
parent d62785ca5a
commit e1da9c56b4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 87 additions and 70 deletions

View file

@ -44,18 +44,17 @@ class Model(ModelBase):
if self.is_training_mode:
f = SampleProcessor.TypeFlags
face_type = f.FACE_TYPE_FULL if self.options['face_type'] == 'f' else f.FACE_TYPE_HALF
normalize_vgg = False
self.set_training_data_generators ([
SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size, generators_count=4,
sample_process_options=SampleProcessor.Options( rotation_range=[0,0], motion_blur = [25, 1] ), #random_flip=True,
output_sample_types=[ [f.TRANSFORMED | face_type | f.MODE_BGR_SHUFFLE | f.OPT_APPLY_MOTION_BLUR, self.resolution, {'normalize_vgg':normalize_vgg} ],
output_sample_types=[ [f.TRANSFORMED | face_type | f.MODE_BGR_SHUFFLE | f.OPT_APPLY_MOTION_BLUR, self.resolution ],
[f.PITCH_YAW_ROLL],
]),
SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size, generators_count=4,
sample_process_options=SampleProcessor.Options( rotation_range=[0,0] ), #random_flip=True,
output_sample_types=[ [f.TRANSFORMED | face_type | f.MODE_BGR_SHUFFLE, self.resolution, {'normalize_vgg':normalize_vgg} ],
output_sample_types=[ [f.TRANSFORMED | face_type | f.MODE_BGR_SHUFFLE, self.resolution ],
[f.PITCH_YAW_ROLL],
])
])
@ -68,9 +67,9 @@ class Model(ModelBase):
def onTrainOneIter(self, generators_samples, generators_list):
target_src, pitch_yaw_roll = generators_samples[0]
loss = self.pose_est.train_on_batch( target_src, pitch_yaw_roll )
pitch_loss,yaw_loss,roll_loss = self.pose_est.train_on_batch( target_src, pitch_yaw_roll )
return ( ('loss', loss), )
return ( ('pitch_loss', pitch_loss), ('yaw_loss', yaw_loss), ('roll_loss', roll_loss) )
#override
def onGetPreview(self, generators_samples):