Mirror of https://github.com/iperov/DeepFaceLab.git (synced 2025-07-08 05:51:40 -07:00)
Commit f3b4658810 ("fixes")
Parent: 2300da40e9
3 changed files with 14 additions and 14 deletions
@@ -25,8 +25,8 @@ def gen_warp_params (w, flip, rotation_range=[-10,10], scale_range=[-0.5, 0.5],

     half_cell_size = cell_size // 2

-    mapx = cv2.resize(mapx, (w+cell_size,)*2 )[half_cell_size:-half_cell_size-1,half_cell_size:-half_cell_size-1].astype(np.float32)
-    mapy = cv2.resize(mapy, (w+cell_size,)*2 )[half_cell_size:-half_cell_size-1,half_cell_size:-half_cell_size-1].astype(np.float32)
+    mapx = cv2.resize(mapx, (w+cell_size,)*2 )[half_cell_size:-half_cell_size,half_cell_size:-half_cell_size].astype(np.float32)
+    mapy = cv2.resize(mapy, (w+cell_size,)*2 )[half_cell_size:-half_cell_size,half_cell_size:-half_cell_size].astype(np.float32)

     #random transform
     random_transform_mat = cv2.getRotationMatrix2D((w // 2, w // 2), rotation, scale)
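A quick note on the mapx/mapy slice change, as a sketch of my own rather than part of the commit: cv2.resize grows the random grid to (w+cell_size) pixels per side, and with an even cell_size the old crop [half_cell_size:-half_cell_size-1] trims one row and column too many, yielding a (w-1, w-1) map instead of the intended (w, w). The sizes below are made up for illustration:

    import cv2
    import numpy as np

    w, cell_size = 64, 8                       # made-up example sizes
    half_cell_size = cell_size // 2
    grid = np.random.uniform(size=(w // cell_size + 1,)*2).astype(np.float32)

    resized = cv2.resize(grid, (w + cell_size,)*2)   # (72, 72) for these sizes
    old_crop = resized[half_cell_size:-half_cell_size-1, half_cell_size:-half_cell_size-1]
    new_crop = resized[half_cell_size:-half_cell_size, half_cell_size:-half_cell_size]

    print(old_crop.shape)   # (63, 63) -- one short of the target width
    print(new_crop.shape)   # (64, 64) -- matches w exactly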
@@ -64,7 +64,7 @@ class Sample(object):

    def get_pitch_yaw_roll(self):
        if self.pitch_yaw_roll is None:
-            self.pitch_yaw_roll = LandmarksProcessor.estimate_pitch_yaw_roll(landmarks, size=self.shape[1])
+            self.pitch_yaw_roll = LandmarksProcessor.estimate_pitch_yaw_roll(self.landmarks, size=self.shape[1])
        return self.pitch_yaw_roll

    def set_filename_offset_size(self, filename, offset, size):
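For context on the Sample change: the old call passed a bare landmarks name that is not defined inside the method, so the first call to get_pitch_yaw_roll() would raise a NameError; the fix reads the instance attribute instead. A stripped-down sketch of the lazily cached accessor, with the surrounding class invented for illustration (only the method body and the field names come from the hunk):

    class SampleSketch:
        def __init__(self, landmarks, shape, estimator):
            self.landmarks = landmarks          # facial landmarks array
            self.shape = shape                  # image shape; shape[1] is the width
            self._estimator = estimator         # stand-in for LandmarksProcessor.estimate_pitch_yaw_roll
            self.pitch_yaw_roll = None          # computed once, then cached

        def get_pitch_yaw_roll(self):
            if self.pitch_yaw_roll is None:     # lazy: estimate only on first access
                self.pitch_yaw_roll = self._estimator(self.landmarks, size=self.shape[1])
            return self.pitch_yaw_roll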
@@ -1,4 +1,5 @@
 import collections
+import math
 from enum import IntEnum

 import cv2
@@ -7,6 +8,7 @@ import numpy as np
 from core import imagelib
 from facelib import FaceType, LandmarksProcessor

+
 class SampleProcessor(object):
     class SampleType(IntEnum):
         NONE = 0
@@ -115,7 +117,7 @@ class SampleProcessor(object):
                if not is_face_sample:
                    raise ValueError("face_samples should be provided for sample_type FACE_*")

-                if is_face_sample:
+                if sample_type == SPST.FACE_IMAGE or sample_type == SPST.FACE_MASK:
                    face_type = opts.get('face_type', None)
                    face_mask_type = opts.get('face_mask_type', SPFMT.NONE)

@@ -125,7 +127,6 @@ class SampleProcessor(object):
                    if face_type > sample.face_type:
                        raise Exception ('sample %s type %s does not match model requirement %s. Consider extract necessary type of faces.' % (sample.filename, sample.face_type, face_type) )

-                if sample_type == SPST.FACE_IMAGE or sample_type == SPST.FACE_MASK:

                    if sample_type == SPST.FACE_MASK:

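Read together, these two hunks replace the generic is_face_sample gate with an explicit check on the requested sample type and drop the now-redundant inner check. A minimal sketch of that kind of IntEnum dispatch; member values other than NONE = 0 are assumed, not taken from the diff:

    from enum import IntEnum

    class SampleType(IntEnum):
        NONE = 0            # the only value visible in the imports hunk
        FACE_IMAGE = 1      # assumed value, for illustration
        FACE_MASK = 2       # assumed value, for illustration

    SPST = SampleType

    def wants_face_branch(sample_type):
        # explicit membership test, mirroring the new condition
        return sample_type in (SPST.FACE_IMAGE, SPST.FACE_MASK)

    print(wants_face_branch(SPST.FACE_MASK))   # True
    print(wants_face_branch(SPST.NONE))        # False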
@@ -175,12 +176,12 @@ class SampleProcessor(object):
                    else:
                        if w != resolution:
                            img = cv2.resize( img, (resolution, resolution), cv2.INTER_CUBIC )

                    img = imagelib.warp_by_params (params_per_resolution[resolution], img, warp, transform, can_flip=True, border_replicate=border_replicate)

                    img = np.clip(img.astype(np.float32), 0, 1)


                    # Apply random color transfer
                    if ct_mode is not None and ct_sample is not None:
                        if ct_sample_bgr is None:
@@ -273,9 +274,8 @@ class SampleProcessor(object):
                    l = np.clip(l, 0.0, 1.0)
                    out_sample = l
                elif sample_type == SPST.PITCH_YAW_ROLL or sample_type == SPST.PITCH_YAW_ROLL_SIGMOID:
-                    pitch_yaw_roll = sample.get_pitch_yaw_roll()
-
-                    if params['flip']:
+                    pitch,yaw,roll = sample.get_pitch_yaw_roll()
+                    if params_per_resolution[resolution]['flip']:
                        yaw = -yaw

                    if sample_type == SPST.PITCH_YAW_ROLL_SIGMOID:
@@ -283,7 +283,7 @@ class SampleProcessor(object):
                        yaw = np.clip( (yaw / math.pi) / 2.0 + 0.5, 0, 1)
                        roll = np.clip( (roll / math.pi) / 2.0 + 0.5, 0, 1)

-                    out_sample = (pitch, yaw, roll)
+                    out_sample = (pitch, yaw)
                else:
                    raise ValueError ('expected sample_type')
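A side note on the PITCH_YAW_ROLL_SIGMOID branch, as a sketch of my own: the math module added in the imports hunk is what the math.pi normalisation here relies on; it maps an angle in radians from roughly [-pi, pi] into [0, 1], with 0 rad landing on 0.5. Illustrative values only:

    import math
    import numpy as np

    def to_sigmoid_range(angle):
        # same transform as the hunk: radians -> [0, 1]
        return np.clip((angle / math.pi) / 2.0 + 0.5, 0, 1)

    for a in (-math.pi, -math.pi / 2, 0.0, math.pi / 2, math.pi):   # sample angles
        print(round(float(to_sigmoid_range(a)), 3))                 # 0.0, 0.25, 0.5, 0.75, 1.0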