This commit is contained in:
iperov 2019-04-28 21:16:27 +04:00
parent 7140ee8684
commit 2cd028fb67
2 changed files with 33 additions and 60 deletions

View file

@ -59,13 +59,13 @@ class Model(ModelBase):
output_sample_types=[ {'types': (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR_SHUFFLE), 'resolution':self.resolution, 'motion_blur':(25, 1) },
{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_BGR_SHUFFLE), 'resolution':self.resolution },
{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M, t.FACE_MASK_FULL), 'resolution':self.resolution },
{'types': (t.IMG_PITCH_YAW_ROLL,)}
{'types': (t.IMG_PITCH_YAW_ROLL_SIGMOID,)}
]),
SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size, generators_count=4,
sample_process_options=SampleProcessor.Options( rotation_range=[0,0] ), #random_flip=True,
output_sample_types=[ {'types': (t.IMG_TRANSFORMED, face_type, t.MODE_BGR_SHUFFLE), 'resolution':self.resolution },
{'types': (t.IMG_PITCH_YAW_ROLL,)}
{'types': (t.IMG_PITCH_YAW_ROLL_SIGMOID,)}
])
])
@ -77,9 +77,9 @@ class Model(ModelBase):
def onTrainOneIter(self, generators_samples, generators_list):
target_srcw, target_src, target_srcm, pitch_yaw_roll = generators_samples[0]
bgr_loss, pitch_loss,yaw_loss,roll_loss = self.pose_est.train_on_batch( target_srcw, target_src, target_srcm, pitch_yaw_roll, skip_bgr_train=not self.options['train_bgr'] )
bgr_loss, pyr_loss = self.pose_est.train_on_batch( target_srcw, target_src, target_srcm, pitch_yaw_roll, skip_bgr_train=not self.options['train_bgr'] )
return ( ('bgr_loss', bgr_loss), ('pitch_loss', pitch_loss), ('yaw_loss', yaw_loss), ('roll_loss', roll_loss) )
return ( ('bgr_loss', bgr_loss), ('pyr_loss', pyr_loss), )
#override
def onGetPreview(self, generators_samples):
@ -99,8 +99,11 @@ class Model(ModelBase):
hor_imgs = []
for i in range(len(img)):
img_info = np.ones ( (h,w,c) ) * 0.1
lines = ["%s" % ( str(pyr[i]) ),
"%s" % ( str(pyr_pred[i]) ) ]
i_pyr = pyr[i]
i_pyr_pred = pyr_pred[i]
lines = ["%.4f %.4f %.4f" % (i_pyr[0],i_pyr[1],i_pyr[2]),
"%.4f %.4f %.4f" % (i_pyr_pred[0],i_pyr_pred[1],i_pyr_pred[2]) ]
lines_count = len(lines)
for ln in range(lines_count):