This commit is contained in:
iperov 2019-04-30 17:14:02 +04:00
parent 1f1f94848b
commit 659aa5705a
2 changed files with 107 additions and 102 deletions

View file

@ -37,7 +37,7 @@ class Model(ModelBase):
#override
def onInitialize(self):
exec(nnlib.import_all(), locals(), globals())
self.set_vram_batch_requirements( {4:32} )
self.set_vram_batch_requirements( {4:64} )
self.resolution = 128
self.face_type = FaceType.FULL if self.options['face_type'] == 'f' else FaceType.HALF
@ -58,14 +58,13 @@ class Model(ModelBase):
sample_process_options=SampleProcessor.Options( rotation_range=[0,0] ), #random_flip=True,
output_sample_types=[ {'types': (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR_SHUFFLE), 'resolution':self.resolution, 'motion_blur':(25, 1) },
{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_BGR_SHUFFLE), 'resolution':self.resolution },
{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M, t.FACE_MASK_FULL), 'resolution':self.resolution },
{'types': (t.IMG_PITCH_YAW_ROLL_SIGMOID,)}
{'types': (t.IMG_PITCH_YAW_ROLL,)}
]),
SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size, generators_count=4,
sample_process_options=SampleProcessor.Options( rotation_range=[0,0] ), #random_flip=True,
output_sample_types=[ {'types': (t.IMG_TRANSFORMED, face_type, t.MODE_BGR_SHUFFLE), 'resolution':self.resolution },
{'types': (t.IMG_PITCH_YAW_ROLL_SIGMOID,)}
output_sample_types=[ {'types': (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'resolution':self.resolution },
{'types': (t.IMG_PITCH_YAW_ROLL,)}
])
])
@ -75,16 +74,16 @@ class Model(ModelBase):
#override
def onTrainOneIter(self, generators_samples, generators_list):
target_srcw, target_src, target_srcm, pitch_yaw_roll = generators_samples[0]
target_srcw, target_src, pitch_yaw_roll = generators_samples[0]
bgr_loss, pyr_loss = self.pose_est.train_on_batch( target_srcw, target_src, target_srcm, pitch_yaw_roll, skip_bgr_train=not self.options['train_bgr'] )
bgr_loss, pyr_loss = self.pose_est.train_on_batch( target_srcw, target_src, pitch_yaw_roll, skip_bgr_train=not self.options['train_bgr'] )
return ( ('bgr_loss', bgr_loss), ('pyr_loss', pyr_loss), )
#override
def onGetPreview(self, generators_samples):
test_src = generators_samples[0][1][0:4] #first 4 samples
test_pyr_src = generators_samples[0][3][0:4]
test_pyr_src = generators_samples[0][2][0:4]
test_dst = generators_samples[1][0][0:4]
test_pyr_dst = generators_samples[1][1][0:4]
@ -94,8 +93,8 @@ class Model(ModelBase):
result = []
for name, img, pyr in [ ['training data', test_src, test_pyr_src], \
['evaluating data',test_dst, test_pyr_dst] ]:
pyr_pred = self.pose_est.extract(img)
bgr_pred, pyr_pred = self.pose_est.extract(img)
hor_imgs = []
for i in range(len(img)):
img_info = np.ones ( (h,w,c) ) * 0.1
@ -112,6 +111,7 @@ class Model(ModelBase):
hor_imgs.append ( np.concatenate ( (
img[i,:,:,0:3],
bgr_pred[i],
img_info
), axis=1) )