feature: random LAB rotation

This commit is contained in:
jh 2021-03-12 14:24:58 -08:00
commit fcde64fab5
3 changed files with 85 additions and 56 deletions

View file

@ -92,7 +92,7 @@ def color_transfer_mkl(x0, x1):
def color_transfer_idt(i0, i1, bins=256, n_rot=20): def color_transfer_idt(i0, i1, bins=256, n_rot=20):
import scipy.stats import scipy.stats
relaxation = 1 / n_rot relaxation = 1 / n_rot
h,w,c = i0.shape h,w,c = i0.shape
h1,w1,c1 = i1.shape h1,w1,c1 = i1.shape
@ -381,6 +381,25 @@ def color_augmentation(img):
return (face / 255.0).astype(np.float32) return (face / 255.0).astype(np.float32)
def random_lab_rotation(image, seed=None):
"""
Randomly rotates image color around the L axis in LAB colorspace,
keeping perceptual lightness constant.
"""
image = cv2.cvtColor(image.astype(np.float32), cv2.COLOR_BGR2LAB)
M = np.eye(3)
M[1:, 1:] = special_ortho_group.rvs(2, 1, seed)
image = image.dot(M)
l, a, b = cv2.split(image)
l = np.clip(l, 0, 100)
a = np.clip(a, -127, 127)
b = np.clip(b, -127, 127)
image = cv2.merge([l, a, b])
image = cv2.cvtColor(image.astype(np.float32), cv2.COLOR_LAB2BGR)
np.clip(image, 0, 1, out=image)
return image
def random_lab(image): def random_lab(image):
""" Perform random color/lightness adjustment in L*a*b* colorspace """ """ Perform random color/lightness adjustment in L*a*b* colorspace """
amount_l = 30 / 100 amount_l = 30 / 100
@ -416,4 +435,4 @@ def random_clahe(image):
tileGridSize=(grid_size, grid_size)) tileGridSize=(grid_size, grid_size))
for chan in range(3): for chan in range(3):
image[:, :, chan] = clahe.apply(image[:, :, chan]) image[:, :, chan] = clahe.apply(image[:, :, chan])
return image return image

View file

@ -58,6 +58,7 @@ class SAEHDModel(ModelBase):
default_face_style_power = self.options['face_style_power'] = self.load_or_def_option('face_style_power', 0.0) default_face_style_power = self.options['face_style_power'] = self.load_or_def_option('face_style_power', 0.0)
default_bg_style_power = self.options['bg_style_power'] = self.load_or_def_option('bg_style_power', 0.0) default_bg_style_power = self.options['bg_style_power'] = self.load_or_def_option('bg_style_power', 0.0)
default_ct_mode = self.options['ct_mode'] = self.load_or_def_option('ct_mode', 'none') default_ct_mode = self.options['ct_mode'] = self.load_or_def_option('ct_mode', 'none')
default_random_color = self.options['random_color'] = self.load_or_def_option('random_color', False)
default_clipgrad = self.options['clipgrad'] = self.load_or_def_option('clipgrad', False) default_clipgrad = self.options['clipgrad'] = self.load_or_def_option('clipgrad', False)
default_pretrain = self.options['pretrain'] = self.load_or_def_option('pretrain', False) default_pretrain = self.options['pretrain'] = self.load_or_def_option('pretrain', False)
@ -167,6 +168,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
self.options['bg_style_power'] = np.clip ( io.input_number("Background style power", default_bg_style_power, add_info="0.0..100.0", help_message="Learn the area outside mask of the predicted face to be the same as dst. If you want to use this option with 'whole_face' you have to use XSeg trained mask. For whole_face you have to use XSeg trained mask. This can make face more like dst. Enabling this option increases the chance of model collapse. Typical value is 2.0"), 0.0, 100.0 ) self.options['bg_style_power'] = np.clip ( io.input_number("Background style power", default_bg_style_power, add_info="0.0..100.0", help_message="Learn the area outside mask of the predicted face to be the same as dst. If you want to use this option with 'whole_face' you have to use XSeg trained mask. For whole_face you have to use XSeg trained mask. This can make face more like dst. Enabling this option increases the chance of model collapse. Typical value is 2.0"), 0.0, 100.0 )
self.options['ct_mode'] = io.input_str (f"Color transfer for src faceset", default_ct_mode, ['none','rct','lct','mkl','idt','sot', 'fs-aug'], help_message="Change color distribution of src samples close to dst samples. Try all modes to find the best. FS aug adds random color to dst and src") self.options['ct_mode'] = io.input_str (f"Color transfer for src faceset", default_ct_mode, ['none','rct','lct','mkl','idt','sot', 'fs-aug'], help_message="Change color distribution of src samples close to dst samples. Try all modes to find the best. FS aug adds random color to dst and src")
self.options['random_color'] = io.input_bool ("Random color", default_random_color, help_message="Samples are randomly rotated around the L axis in LAB colorspace, helps generalize training")
self.options['clipgrad'] = io.input_bool ("Enable gradient clipping", default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.") self.options['clipgrad'] = io.input_bool ("Enable gradient clipping", default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.")
self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain, help_message="Pretrain the model with large amount of various faces. After that, model can be used to train the fakes more quickly.") self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain, help_message="Pretrain the model with large amount of various faces. After that, model can be used to train the fakes more quickly.")
@ -650,11 +652,14 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
if ct_mode == 'fs-aug': if ct_mode == 'fs-aug':
fs_aug = 'fs-aug' fs_aug = 'fs-aug'
channel_type = SampleProcessor.ChannelType.LAB_RAND_TRANSFORM if self.options['random_color'] else SampleProcessor.ChannelType.BGR
self.set_training_data_generators ([ self.set_training_data_generators ([
SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(), SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip), sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : channel_type, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : channel_type, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
], ],
@ -663,8 +668,8 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip), sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': fs_aug, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : channel_type, 'ct_mode': fs_aug, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': fs_aug, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : channel_type, 'ct_mode': fs_aug, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
], ],

View file

@ -8,6 +8,7 @@ import numpy as np
from core import imagelib from core import imagelib
from core.cv2ex import * from core.cv2ex import *
from core.imagelib import sd from core.imagelib import sd
from core.imagelib.color_transfer import random_lab_rotation
from facelib import FaceType, LandmarksProcessor from facelib import FaceType, LandmarksProcessor
@ -26,6 +27,8 @@ class SampleProcessor(object):
BGR = 1 #BGR BGR = 1 #BGR
G = 2 #Grayscale G = 2 #Grayscale
GGG = 3 #3xGrayscale GGG = 3 #3xGrayscale
LAB_RAND_TRANSFORM = 4 # LAB random transform
class FaceMaskType(IntEnum): class FaceMaskType(IntEnum):
NONE = 0 NONE = 0
@ -56,18 +59,18 @@ class SampleProcessor(object):
sample_landmarks = sample.landmarks sample_landmarks = sample.landmarks
ct_sample_bgr = None ct_sample_bgr = None
h,w,c = sample_bgr.shape h,w,c = sample_bgr.shape
def get_full_face_mask(): def get_full_face_mask():
xseg_mask = sample.get_xseg_mask() xseg_mask = sample.get_xseg_mask()
if xseg_mask is not None: if xseg_mask is not None:
if xseg_mask.shape[0] != h or xseg_mask.shape[1] != w: if xseg_mask.shape[0] != h or xseg_mask.shape[1] != w:
xseg_mask = cv2.resize(xseg_mask, (w,h), interpolation=cv2.INTER_CUBIC) xseg_mask = cv2.resize(xseg_mask, (w,h), interpolation=cv2.INTER_CUBIC)
xseg_mask = imagelib.normalize_channels(xseg_mask, 1) xseg_mask = imagelib.normalize_channels(xseg_mask, 1)
return np.clip(xseg_mask, 0, 1) return np.clip(xseg_mask, 0, 1)
else: else:
full_face_mask = LandmarksProcessor.get_image_hull_mask (sample_bgr.shape, sample_landmarks, eyebrows_expand_mod=sample.eyebrows_expand_mod ) full_face_mask = LandmarksProcessor.get_image_hull_mask (sample_bgr.shape, sample_landmarks, eyebrows_expand_mod=sample.eyebrows_expand_mod )
return np.clip(full_face_mask, 0, 1) return np.clip(full_face_mask, 0, 1)
def get_eyes_mask(): def get_eyes_mask():
eyes_mask = LandmarksProcessor.get_image_eye_mask (sample_bgr.shape, sample_landmarks) eyes_mask = LandmarksProcessor.get_image_eye_mask (sample_bgr.shape, sample_landmarks)
# set eye masks to 1-2 # set eye masks to 1-2
@ -86,25 +89,25 @@ class SampleProcessor(object):
if debug and is_face_sample: if debug and is_face_sample:
LandmarksProcessor.draw_landmarks (sample_bgr, sample_landmarks, (0, 1, 0)) LandmarksProcessor.draw_landmarks (sample_bgr, sample_landmarks, (0, 1, 0))
params_per_resolution = {} params_per_resolution = {}
warp_rnd_state = np.random.RandomState (sample_rnd_seed-1) warp_rnd_state = np.random.RandomState (sample_rnd_seed-1)
for opts in output_sample_types: for opts in output_sample_types:
resolution = opts.get('resolution', None) resolution = opts.get('resolution', None)
if resolution is None: if resolution is None:
continue continue
params_per_resolution[resolution] = imagelib.gen_warp_params(resolution, params_per_resolution[resolution] = imagelib.gen_warp_params(resolution,
sample_process_options.random_flip, sample_process_options.random_flip,
rotation_range=sample_process_options.rotation_range, rotation_range=sample_process_options.rotation_range,
scale_range=sample_process_options.scale_range, scale_range=sample_process_options.scale_range,
tx_range=sample_process_options.tx_range, tx_range=sample_process_options.tx_range,
ty_range=sample_process_options.ty_range, ty_range=sample_process_options.ty_range,
rnd_state=warp_rnd_state) rnd_state=warp_rnd_state)
outputs_sample = [] outputs_sample = []
for opts in output_sample_types: for opts in output_sample_types:
sample_type = opts.get('sample_type', SPST.NONE) sample_type = opts.get('sample_type', SPST.NONE)
channel_type = opts.get('channel_type', SPCT.NONE) channel_type = opts.get('channel_type', SPCT.NONE)
resolution = opts.get('resolution', 0) resolution = opts.get('resolution', 0)
nearest_resize_to = opts.get('nearest_resize_to', None) nearest_resize_to = opts.get('nearest_resize_to', None)
warp = opts.get('warp', False) warp = opts.get('warp', False)
@ -118,29 +121,29 @@ class SampleProcessor(object):
normalize_tanh = opts.get('normalize_tanh', False) normalize_tanh = opts.get('normalize_tanh', False)
ct_mode = opts.get('ct_mode', None) ct_mode = opts.get('ct_mode', None)
data_format = opts.get('data_format', 'NHWC') data_format = opts.get('data_format', 'NHWC')
if sample_type == SPST.FACE_MASK or sample_type == SPST.IMAGE: if sample_type == SPST.FACE_MASK or sample_type == SPST.IMAGE:
border_replicate = False border_replicate = False
elif sample_type == SPST.FACE_IMAGE: elif sample_type == SPST.FACE_IMAGE:
border_replicate = True border_replicate = True
border_replicate = opts.get('border_replicate', border_replicate) border_replicate = opts.get('border_replicate', border_replicate)
borderMode = cv2.BORDER_REPLICATE if border_replicate else cv2.BORDER_CONSTANT borderMode = cv2.BORDER_REPLICATE if border_replicate else cv2.BORDER_CONSTANT
if sample_type == SPST.FACE_IMAGE or sample_type == SPST.FACE_MASK: if sample_type == SPST.FACE_IMAGE or sample_type == SPST.FACE_MASK:
if not is_face_sample: if not is_face_sample:
raise ValueError("face_samples should be provided for sample_type FACE_*") raise ValueError("face_samples should be provided for sample_type FACE_*")
if sample_type == SPST.FACE_IMAGE or sample_type == SPST.FACE_MASK: if sample_type == SPST.FACE_IMAGE or sample_type == SPST.FACE_MASK:
face_type = opts.get('face_type', None) face_type = opts.get('face_type', None)
face_mask_type = opts.get('face_mask_type', SPFMT.NONE) face_mask_type = opts.get('face_mask_type', SPFMT.NONE)
if face_type is None: if face_type is None:
raise ValueError("face_type must be defined for face samples") raise ValueError("face_type must be defined for face samples")
if sample_type == SPST.FACE_MASK: if sample_type == SPST.FACE_MASK:
if face_mask_type == SPFMT.FULL_FACE: if face_mask_type == SPFMT.FULL_FACE:
img = get_full_face_mask() img = get_full_face_mask()
elif face_mask_type == SPFMT.EYES: elif face_mask_type == SPFMT.EYES:
@ -149,42 +152,42 @@ class SampleProcessor(object):
# sets both eyes and mouth mask parts # sets both eyes and mouth mask parts
img = get_full_face_mask() img = get_full_face_mask()
mask = img.copy() mask = img.copy()
mask[mask != 0.0] = 1.0 mask[mask != 0.0] = 1.0
eye_mask = get_eyes_mask() * mask eye_mask = get_eyes_mask() * mask
img = np.where(eye_mask > 1, eye_mask, img) img = np.where(eye_mask > 1, eye_mask, img)
mouth_mask = get_mouth_mask() * mask mouth_mask = get_mouth_mask() * mask
img = np.where(mouth_mask > 2, mouth_mask, img) img = np.where(mouth_mask > 2, mouth_mask, img)
else: else:
img = np.zeros ( sample_bgr.shape[0:2]+(1,), dtype=np.float32) img = np.zeros ( sample_bgr.shape[0:2]+(1,), dtype=np.float32)
if sample_face_type == FaceType.MARK_ONLY: if sample_face_type == FaceType.MARK_ONLY:
mat = LandmarksProcessor.get_transform_mat (sample_landmarks, warp_resolution, face_type) mat = LandmarksProcessor.get_transform_mat (sample_landmarks, warp_resolution, face_type)
img = cv2.warpAffine( img, mat, (warp_resolution, warp_resolution), flags=cv2.INTER_LINEAR ) img = cv2.warpAffine( img, mat, (warp_resolution, warp_resolution), flags=cv2.INTER_LINEAR )
img = imagelib.warp_by_params (params_per_resolution[resolution], img, warp, transform, can_flip=True, border_replicate=border_replicate, cv2_inter=cv2.INTER_LINEAR) img = imagelib.warp_by_params (params_per_resolution[resolution], img, warp, transform, can_flip=True, border_replicate=border_replicate, cv2_inter=cv2.INTER_LINEAR)
img = cv2.resize( img, (resolution,resolution), interpolation=cv2.INTER_LINEAR ) img = cv2.resize( img, (resolution,resolution), interpolation=cv2.INTER_LINEAR )
else: else:
if face_type != sample_face_type: if face_type != sample_face_type:
mat = LandmarksProcessor.get_transform_mat (sample_landmarks, resolution, face_type) mat = LandmarksProcessor.get_transform_mat (sample_landmarks, resolution, face_type)
img = cv2.warpAffine( img, mat, (resolution,resolution), borderMode=borderMode, flags=cv2.INTER_LINEAR ) img = cv2.warpAffine( img, mat, (resolution,resolution), borderMode=borderMode, flags=cv2.INTER_LINEAR )
else: else:
if w != resolution: if w != resolution:
img = cv2.resize( img, (resolution, resolution), interpolation=cv2.INTER_LINEAR ) img = cv2.resize( img, (resolution, resolution), interpolation=cv2.INTER_LINEAR )
img = imagelib.warp_by_params (params_per_resolution[resolution], img, warp, transform, can_flip=True, border_replicate=border_replicate, cv2_inter=cv2.INTER_LINEAR) img = imagelib.warp_by_params (params_per_resolution[resolution], img, warp, transform, can_flip=True, border_replicate=border_replicate, cv2_inter=cv2.INTER_LINEAR)
if len(img.shape) == 2: if len(img.shape) == 2:
img = img[...,None] img = img[...,None]
if channel_type == SPCT.G: if channel_type == SPCT.G:
out_sample = img.astype(np.float32) out_sample = img.astype(np.float32)
else: else:
raise ValueError("only channel_type.G supported for the mask") raise ValueError("only channel_type.G supported for the mask")
elif sample_type == SPST.FACE_IMAGE: elif sample_type == SPST.FACE_IMAGE:
img = sample_bgr img = sample_bgr
if random_rgb_levels: if random_rgb_levels:
random_mask = sd.random_circle_faded ([w,w], rnd_state=np.random.RandomState (sample_rnd_seed) ) if random_circle_mask else None random_mask = sd.random_circle_faded ([w,w], rnd_state=np.random.RandomState (sample_rnd_seed) ) if random_circle_mask else None
img = imagelib.apply_random_rgb_levels(img, mask=random_mask, rnd_state=np.random.RandomState (sample_rnd_seed) ) img = imagelib.apply_random_rgb_levels(img, mask=random_mask, rnd_state=np.random.RandomState (sample_rnd_seed) )
@ -193,15 +196,15 @@ class SampleProcessor(object):
random_mask = sd.random_circle_faded ([w,w], rnd_state=np.random.RandomState (sample_rnd_seed+1) ) if random_circle_mask else None random_mask = sd.random_circle_faded ([w,w], rnd_state=np.random.RandomState (sample_rnd_seed+1) ) if random_circle_mask else None
img = imagelib.apply_random_hsv_shift(img, mask=random_mask, rnd_state=np.random.RandomState (sample_rnd_seed+1) ) img = imagelib.apply_random_hsv_shift(img, mask=random_mask, rnd_state=np.random.RandomState (sample_rnd_seed+1) )
if face_type != sample_face_type: if face_type != sample_face_type:
mat = LandmarksProcessor.get_transform_mat (sample_landmarks, resolution, face_type) mat = LandmarksProcessor.get_transform_mat (sample_landmarks, resolution, face_type)
img = cv2.warpAffine( img, mat, (resolution,resolution), borderMode=borderMode, flags=cv2.INTER_CUBIC ) img = cv2.warpAffine( img, mat, (resolution,resolution), borderMode=borderMode, flags=cv2.INTER_CUBIC )
else: else:
if w != resolution: if w != resolution:
img = cv2.resize( img, (resolution, resolution), interpolation=cv2.INTER_CUBIC ) img = cv2.resize( img, (resolution, resolution), interpolation=cv2.INTER_CUBIC )
# Apply random color transfer # Apply random color transfer
if ct_mode is not None and ct_sample is not None or ct_mode == 'fs-aug': if ct_mode is not None and ct_sample is not None or ct_mode == 'fs-aug':
if ct_mode == 'fs-aug': if ct_mode == 'fs-aug':
img = imagelib.color_augmentation(img) img = imagelib.color_augmentation(img)
@ -210,27 +213,29 @@ class SampleProcessor(object):
ct_sample_bgr = ct_sample.load_bgr() ct_sample_bgr = ct_sample.load_bgr()
img = imagelib.color_transfer (ct_mode, img, cv2.resize( ct_sample_bgr, (resolution,resolution), interpolation=cv2.INTER_LINEAR ) ) img = imagelib.color_transfer (ct_mode, img, cv2.resize( ct_sample_bgr, (resolution,resolution), interpolation=cv2.INTER_LINEAR ) )
img = imagelib.warp_by_params (params_per_resolution[resolution], img, warp, transform, can_flip=True, border_replicate=border_replicate) img = imagelib.warp_by_params (params_per_resolution[resolution], img, warp, transform, can_flip=True, border_replicate=border_replicate)
img = np.clip(img.astype(np.float32), 0, 1) img = np.clip(img.astype(np.float32), 0, 1)
if motion_blur is not None: if motion_blur is not None:
random_mask = sd.random_circle_faded ([resolution,resolution], rnd_state=np.random.RandomState (sample_rnd_seed+2)) if random_circle_mask else None random_mask = sd.random_circle_faded ([resolution,resolution], rnd_state=np.random.RandomState (sample_rnd_seed+2)) if random_circle_mask else None
img = imagelib.apply_random_motion_blur(img, *motion_blur, mask=random_mask,rnd_state=np.random.RandomState (sample_rnd_seed+2) ) img = imagelib.apply_random_motion_blur(img, *motion_blur, mask=random_mask,rnd_state=np.random.RandomState (sample_rnd_seed+2) )
if gaussian_blur is not None: if gaussian_blur is not None:
random_mask = sd.random_circle_faded ([resolution,resolution], rnd_state=np.random.RandomState (sample_rnd_seed+3)) if random_circle_mask else None random_mask = sd.random_circle_faded ([resolution,resolution], rnd_state=np.random.RandomState (sample_rnd_seed+3)) if random_circle_mask else None
img = imagelib.apply_random_gaussian_blur(img, *gaussian_blur, mask=random_mask,rnd_state=np.random.RandomState (sample_rnd_seed+3) ) img = imagelib.apply_random_gaussian_blur(img, *gaussian_blur, mask=random_mask,rnd_state=np.random.RandomState (sample_rnd_seed+3) )
if random_bilinear_resize is not None: if random_bilinear_resize is not None:
random_mask = sd.random_circle_faded ([resolution,resolution], rnd_state=np.random.RandomState (sample_rnd_seed+4)) if random_circle_mask else None random_mask = sd.random_circle_faded ([resolution,resolution], rnd_state=np.random.RandomState (sample_rnd_seed+4)) if random_circle_mask else None
img = imagelib.apply_random_bilinear_resize(img, *random_bilinear_resize, mask=random_mask,rnd_state=np.random.RandomState (sample_rnd_seed+4) ) img = imagelib.apply_random_bilinear_resize(img, *random_bilinear_resize, mask=random_mask,rnd_state=np.random.RandomState (sample_rnd_seed+4) )
# Transform from BGR to desired channel_type # Transform from BGR to desired channel_type
if channel_type == SPCT.BGR: if channel_type == SPCT.BGR:
out_sample = img out_sample = img
elif channel_type == SPCT.LAB_RAND_TRANSFORM:
out_sample = random_lab_rotation(img)
elif channel_type == SPCT.G: elif channel_type == SPCT.G:
out_sample = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[...,None] out_sample = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[...,None]
elif channel_type == SPCT.GGG: elif channel_type == SPCT.GGG:
@ -239,22 +244,22 @@ class SampleProcessor(object):
# Final transformations # Final transformations
if nearest_resize_to is not None: if nearest_resize_to is not None:
out_sample = cv2_resize(out_sample, (nearest_resize_to,nearest_resize_to), interpolation=cv2.INTER_NEAREST) out_sample = cv2_resize(out_sample, (nearest_resize_to,nearest_resize_to), interpolation=cv2.INTER_NEAREST)
if not debug: if not debug:
if normalize_tanh: if normalize_tanh:
out_sample = np.clip (out_sample * 2.0 - 1.0, -1.0, 1.0) out_sample = np.clip (out_sample * 2.0 - 1.0, -1.0, 1.0)
if data_format == "NCHW": if data_format == "NCHW":
out_sample = np.transpose(out_sample, (2,0,1) ) out_sample = np.transpose(out_sample, (2,0,1) )
elif sample_type == SPST.IMAGE: elif sample_type == SPST.IMAGE:
img = sample_bgr img = sample_bgr
img = imagelib.warp_by_params (params_per_resolution[resolution], img, warp, transform, can_flip=True, border_replicate=True) img = imagelib.warp_by_params (params_per_resolution[resolution], img, warp, transform, can_flip=True, border_replicate=True)
img = cv2.resize( img, (resolution, resolution), interpolation=cv2.INTER_CUBIC ) img = cv2.resize( img, (resolution, resolution), interpolation=cv2.INTER_CUBIC )
out_sample = img out_sample = img
if data_format == "NCHW": if data_format == "NCHW":
out_sample = np.transpose(out_sample, (2,0,1) ) out_sample = np.transpose(out_sample, (2,0,1) )
elif sample_type == SPST.LANDMARKS_ARRAY: elif sample_type == SPST.LANDMARKS_ARRAY:
l = sample_landmarks l = sample_landmarks
l = np.concatenate ( [ np.expand_dims(l[:,0] / w,-1), np.expand_dims(l[:,1] / h,-1) ], -1 ) l = np.concatenate ( [ np.expand_dims(l[:,0] / w,-1), np.expand_dims(l[:,1] / h,-1) ], -1 )