mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 21:12:07 -07:00
_
This commit is contained in:
parent
7140ee8684
commit
2cd028fb67
2 changed files with 33 additions and 60 deletions
|
@ -59,13 +59,13 @@ class Model(ModelBase):
|
|||
output_sample_types=[ {'types': (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR_SHUFFLE), 'resolution':self.resolution, 'motion_blur':(25, 1) },
|
||||
{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_BGR_SHUFFLE), 'resolution':self.resolution },
|
||||
{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M, t.FACE_MASK_FULL), 'resolution':self.resolution },
|
||||
{'types': (t.IMG_PITCH_YAW_ROLL,)}
|
||||
{'types': (t.IMG_PITCH_YAW_ROLL_SIGMOID,)}
|
||||
]),
|
||||
|
||||
SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size, generators_count=4,
|
||||
sample_process_options=SampleProcessor.Options( rotation_range=[0,0] ), #random_flip=True,
|
||||
output_sample_types=[ {'types': (t.IMG_TRANSFORMED, face_type, t.MODE_BGR_SHUFFLE), 'resolution':self.resolution },
|
||||
{'types': (t.IMG_PITCH_YAW_ROLL,)}
|
||||
{'types': (t.IMG_PITCH_YAW_ROLL_SIGMOID,)}
|
||||
])
|
||||
])
|
||||
|
||||
|
@ -77,9 +77,9 @@ class Model(ModelBase):
|
|||
def onTrainOneIter(self, generators_samples, generators_list):
|
||||
target_srcw, target_src, target_srcm, pitch_yaw_roll = generators_samples[0]
|
||||
|
||||
bgr_loss, pitch_loss,yaw_loss,roll_loss = self.pose_est.train_on_batch( target_srcw, target_src, target_srcm, pitch_yaw_roll, skip_bgr_train=not self.options['train_bgr'] )
|
||||
bgr_loss, pyr_loss = self.pose_est.train_on_batch( target_srcw, target_src, target_srcm, pitch_yaw_roll, skip_bgr_train=not self.options['train_bgr'] )
|
||||
|
||||
return ( ('bgr_loss', bgr_loss), ('pitch_loss', pitch_loss), ('yaw_loss', yaw_loss), ('roll_loss', roll_loss) )
|
||||
return ( ('bgr_loss', bgr_loss), ('pyr_loss', pyr_loss), )
|
||||
|
||||
#override
|
||||
def onGetPreview(self, generators_samples):
|
||||
|
@ -99,8 +99,11 @@ class Model(ModelBase):
|
|||
hor_imgs = []
|
||||
for i in range(len(img)):
|
||||
img_info = np.ones ( (h,w,c) ) * 0.1
|
||||
lines = ["%s" % ( str(pyr[i]) ),
|
||||
"%s" % ( str(pyr_pred[i]) ) ]
|
||||
|
||||
i_pyr = pyr[i]
|
||||
i_pyr_pred = pyr_pred[i]
|
||||
lines = ["%.4f %.4f %.4f" % (i_pyr[0],i_pyr[1],i_pyr[2]),
|
||||
"%.4f %.4f %.4f" % (i_pyr_pred[0],i_pyr_pred[1],i_pyr_pred[2]) ]
|
||||
|
||||
lines_count = len(lines)
|
||||
for ln in range(lines_count):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue