mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 21:12:07 -07:00
SAE collapse fix (#245)
* test * _ * _ * upd dev_poseest * SAE: finally collapses are fixed * fix batch size help
This commit is contained in:
parent
d62785ca5a
commit
e1da9c56b4
4 changed files with 87 additions and 70 deletions
|
@ -44,18 +44,17 @@ class Model(ModelBase):
|
|||
if self.is_training_mode:
|
||||
f = SampleProcessor.TypeFlags
|
||||
face_type = f.FACE_TYPE_FULL if self.options['face_type'] == 'f' else f.FACE_TYPE_HALF
|
||||
|
||||
normalize_vgg = False
|
||||
|
||||
self.set_training_data_generators ([
|
||||
SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||
SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size, generators_count=4,
|
||||
sample_process_options=SampleProcessor.Options( rotation_range=[0,0], motion_blur = [25, 1] ), #random_flip=True,
|
||||
output_sample_types=[ [f.TRANSFORMED | face_type | f.MODE_BGR_SHUFFLE | f.OPT_APPLY_MOTION_BLUR, self.resolution, {'normalize_vgg':normalize_vgg} ],
|
||||
output_sample_types=[ [f.TRANSFORMED | face_type | f.MODE_BGR_SHUFFLE | f.OPT_APPLY_MOTION_BLUR, self.resolution ],
|
||||
[f.PITCH_YAW_ROLL],
|
||||
]),
|
||||
|
||||
SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||
SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size, generators_count=4,
|
||||
sample_process_options=SampleProcessor.Options( rotation_range=[0,0] ), #random_flip=True,
|
||||
output_sample_types=[ [f.TRANSFORMED | face_type | f.MODE_BGR_SHUFFLE, self.resolution, {'normalize_vgg':normalize_vgg} ],
|
||||
output_sample_types=[ [f.TRANSFORMED | face_type | f.MODE_BGR_SHUFFLE, self.resolution ],
|
||||
[f.PITCH_YAW_ROLL],
|
||||
])
|
||||
])
|
||||
|
@ -68,9 +67,9 @@ class Model(ModelBase):
|
|||
def onTrainOneIter(self, generators_samples, generators_list):
|
||||
target_src, pitch_yaw_roll = generators_samples[0]
|
||||
|
||||
loss = self.pose_est.train_on_batch( target_src, pitch_yaw_roll )
|
||||
pitch_loss,yaw_loss,roll_loss = self.pose_est.train_on_batch( target_src, pitch_yaw_roll )
|
||||
|
||||
return ( ('loss', loss), )
|
||||
return ( ('pitch_loss', pitch_loss), ('yaw_loss', yaw_loss), ('roll_loss', roll_loss) )
|
||||
|
||||
#override
|
||||
def onGetPreview(self, generators_samples):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue