refactoring

This commit is contained in:
Colombo 2020-03-09 13:08:32 +04:00
parent 45abcff3d1
commit a030ff6951
2 changed files with 25 additions and 16 deletions

View file

@ -10,11 +10,12 @@ from facelib import FaceType, LandmarksProcessor
class SampleProcessor(object):
class SampleType(IntEnum):
NONE = 0
FACE_IMAGE = 1
FACE_MASK = 2
LANDMARKS_ARRAY = 3
PITCH_YAW_ROLL = 4
PITCH_YAW_ROLL_SIGMOID = 5
IMAGE = 1
FACE_IMAGE = 2
FACE_MASK = 3
LANDMARKS_ARRAY = 4
PITCH_YAW_ROLL = 5
PITCH_YAW_ROLL_SIGMOID = 6
class ChannelType(IntEnum):
NONE = 0
@ -92,11 +93,12 @@ class SampleProcessor(object):
ct_mode = opts.get('ct_mode', None)
data_format = opts.get('data_format', 'NHWC')
if sample_type == SPST.FACE_MASK:
if sample_type == SPST.FACE_MASK or sample_type == SPST.IMAGE:
border_replicate = False
elif sample_type == SPST.FACE_IMAGE:
border_replicate = True
border_replicate = opts.get('border_replicate', border_replicate)
borderMode = cv2.BORDER_REPLICATE if border_replicate else cv2.BORDER_CONSTANT
@ -230,9 +232,16 @@ class SampleProcessor(object):
out_sample = np.clip (out_sample * 2.0 - 1.0, -1.0, 1.0)
if data_format == "NCHW":
out_sample = np.transpose(out_sample, (2,0,1) )
#else:
# img = imagelib.warp_by_params (params, img, warp, transform, can_flip=True, border_replicate=True)
# img = cv2.resize( img, (resolution,resolution), cv2.INTER_CUBIC )
elif sample_type == SPST.IMAGE:
img = sample_bgr
img = imagelib.warp_by_params (params, img, warp, transform, can_flip=True, border_replicate=True)
img = cv2.resize( img, (resolution, resolution), cv2.INTER_CUBIC )
out_sample = img
if data_format == "NCHW":
out_sample = np.transpose(out_sample, (2,0,1) )
elif sample_type == SPST.LANDMARKS_ARRAY:
l = sample_landmarks
l = np.concatenate ( [ np.expand_dims(l[:,0] / w,-1), np.expand_dims(l[:,1] / h,-1) ], -1 )