This commit is contained in:
Colombo 2020-03-16 22:40:55 +04:00
parent 2300da40e9
commit f3b4658810
3 changed files with 14 additions and 14 deletions

View file

@ -1,4 +1,5 @@
import collections
import math
from enum import IntEnum
import cv2
@ -7,6 +8,7 @@ import numpy as np
from core import imagelib
from facelib import FaceType, LandmarksProcessor
class SampleProcessor(object):
class SampleType(IntEnum):
NONE = 0
@ -114,8 +116,8 @@ class SampleProcessor(object):
if sample_type == SPST.FACE_IMAGE or sample_type == SPST.FACE_MASK:
if not is_face_sample:
raise ValueError("face_samples should be provided for sample_type FACE_*")
if is_face_sample:
if sample_type == SPST.FACE_IMAGE or sample_type == SPST.FACE_MASK:
face_type = opts.get('face_type', None)
face_mask_type = opts.get('face_mask_type', SPFMT.NONE)
@ -125,7 +127,6 @@ class SampleProcessor(object):
if face_type > sample.face_type:
raise Exception ('sample %s type %s does not match model requirement %s. Consider extract necessary type of faces.' % (sample.filename, sample.face_type, face_type) )
if sample_type == SPST.FACE_IMAGE or sample_type == SPST.FACE_MASK:
if sample_type == SPST.FACE_MASK:
@ -156,7 +157,7 @@ class SampleProcessor(object):
img = cv2.resize( img, (resolution, resolution), cv2.INTER_CUBIC )
img = imagelib.warp_by_params (params_per_resolution[resolution], img, warp, transform, can_flip=True, border_replicate=border_replicate, cv2_inter=cv2.INTER_LINEAR)
if len(img.shape) == 2:
img = img[...,None]
@ -175,11 +176,11 @@ class SampleProcessor(object):
else:
if w != resolution:
img = cv2.resize( img, (resolution, resolution), cv2.INTER_CUBIC )
img = imagelib.warp_by_params (params_per_resolution[resolution], img, warp, transform, can_flip=True, border_replicate=border_replicate)
img = np.clip(img.astype(np.float32), 0, 1)
# Apply random color transfer
if ct_mode is not None and ct_sample is not None:
@ -273,9 +274,8 @@ class SampleProcessor(object):
l = np.clip(l, 0.0, 1.0)
out_sample = l
elif sample_type == SPST.PITCH_YAW_ROLL or sample_type == SPST.PITCH_YAW_ROLL_SIGMOID:
pitch_yaw_roll = sample.get_pitch_yaw_roll()
if params['flip']:
pitch,yaw,roll = sample.get_pitch_yaw_roll()
if params_per_resolution[resolution]['flip']:
yaw = -yaw
if sample_type == SPST.PITCH_YAW_ROLL_SIGMOID:
@ -283,7 +283,7 @@ class SampleProcessor(object):
yaw = np.clip( (yaw / math.pi) / 2.0 + 0.5, 0, 1)
roll = np.clip( (roll / math.pi) / 2.0 + 0.5, 0, 1)
out_sample = (pitch, yaw, roll)
out_sample = (pitch, yaw)
else:
raise ValueError ('expected sample_type')