mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 21:12:07 -07:00
_
This commit is contained in:
parent
1f1f94848b
commit
659aa5705a
2 changed files with 107 additions and 102 deletions
|
@ -37,7 +37,7 @@ class Model(ModelBase):
|
|||
#override
|
||||
def onInitialize(self):
|
||||
exec(nnlib.import_all(), locals(), globals())
|
||||
self.set_vram_batch_requirements( {4:32} )
|
||||
self.set_vram_batch_requirements( {4:64} )
|
||||
|
||||
self.resolution = 128
|
||||
self.face_type = FaceType.FULL if self.options['face_type'] == 'f' else FaceType.HALF
|
||||
|
@ -58,14 +58,13 @@ class Model(ModelBase):
|
|||
sample_process_options=SampleProcessor.Options( rotation_range=[0,0] ), #random_flip=True,
|
||||
output_sample_types=[ {'types': (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR_SHUFFLE), 'resolution':self.resolution, 'motion_blur':(25, 1) },
|
||||
{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_BGR_SHUFFLE), 'resolution':self.resolution },
|
||||
{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M, t.FACE_MASK_FULL), 'resolution':self.resolution },
|
||||
{'types': (t.IMG_PITCH_YAW_ROLL_SIGMOID,)}
|
||||
{'types': (t.IMG_PITCH_YAW_ROLL,)}
|
||||
]),
|
||||
|
||||
SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size, generators_count=4,
|
||||
sample_process_options=SampleProcessor.Options( rotation_range=[0,0] ), #random_flip=True,
|
||||
output_sample_types=[ {'types': (t.IMG_TRANSFORMED, face_type, t.MODE_BGR_SHUFFLE), 'resolution':self.resolution },
|
||||
{'types': (t.IMG_PITCH_YAW_ROLL_SIGMOID,)}
|
||||
output_sample_types=[ {'types': (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'resolution':self.resolution },
|
||||
{'types': (t.IMG_PITCH_YAW_ROLL,)}
|
||||
])
|
||||
])
|
||||
|
||||
|
@ -75,16 +74,16 @@ class Model(ModelBase):
|
|||
|
||||
#override
|
||||
def onTrainOneIter(self, generators_samples, generators_list):
|
||||
target_srcw, target_src, target_srcm, pitch_yaw_roll = generators_samples[0]
|
||||
target_srcw, target_src, pitch_yaw_roll = generators_samples[0]
|
||||
|
||||
bgr_loss, pyr_loss = self.pose_est.train_on_batch( target_srcw, target_src, target_srcm, pitch_yaw_roll, skip_bgr_train=not self.options['train_bgr'] )
|
||||
bgr_loss, pyr_loss = self.pose_est.train_on_batch( target_srcw, target_src, pitch_yaw_roll, skip_bgr_train=not self.options['train_bgr'] )
|
||||
|
||||
return ( ('bgr_loss', bgr_loss), ('pyr_loss', pyr_loss), )
|
||||
|
||||
#override
|
||||
def onGetPreview(self, generators_samples):
|
||||
test_src = generators_samples[0][1][0:4] #first 4 samples
|
||||
test_pyr_src = generators_samples[0][3][0:4]
|
||||
test_pyr_src = generators_samples[0][2][0:4]
|
||||
test_dst = generators_samples[1][0][0:4]
|
||||
test_pyr_dst = generators_samples[1][1][0:4]
|
||||
|
||||
|
@ -94,8 +93,8 @@ class Model(ModelBase):
|
|||
result = []
|
||||
for name, img, pyr in [ ['training data', test_src, test_pyr_src], \
|
||||
['evaluating data',test_dst, test_pyr_dst] ]:
|
||||
pyr_pred = self.pose_est.extract(img)
|
||||
|
||||
bgr_pred, pyr_pred = self.pose_est.extract(img)
|
||||
|
||||
hor_imgs = []
|
||||
for i in range(len(img)):
|
||||
img_info = np.ones ( (h,w,c) ) * 0.1
|
||||
|
@ -112,6 +111,7 @@ class Model(ModelBase):
|
|||
|
||||
hor_imgs.append ( np.concatenate ( (
|
||||
img[i,:,:,0:3],
|
||||
bgr_pred[i],
|
||||
img_info
|
||||
), axis=1) )
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue