added XSeg model

This commit is contained in:
Colombo 2020-03-09 13:09:46 +04:00
parent a030ff6951
commit b0b9072981
5 changed files with 513 additions and 0 deletions

View file

@ -0,0 +1,264 @@
import multiprocessing
import pickle
import time
import traceback
from enum import IntEnum
import cv2
import numpy as np
from core import imagelib, mplib, pathex
from core.cv2ex import *
from core.interact import interact as io
from core.joblib import SubprocessGenerator, ThisThreadGenerator
from facelib import LandmarksProcessor
from samplelib import SampleGeneratorBase
class MaskType(IntEnum):
none = 0,
cloth = 1,
ear_r = 2,
eye_g = 3,
hair = 4,
hat = 5,
l_brow = 6,
l_ear = 7,
l_eye = 8,
l_lip = 9,
mouth = 10,
neck = 11,
neck_l = 12,
nose = 13,
r_brow = 14,
r_ear = 15,
r_eye = 16,
skin = 17,
u_lip = 18
MaskType_to_name = {
int(MaskType.none ) : 'none',
int(MaskType.cloth ) : 'cloth',
int(MaskType.ear_r ) : 'ear_r',
int(MaskType.eye_g ) : 'eye_g',
int(MaskType.hair ) : 'hair',
int(MaskType.hat ) : 'hat',
int(MaskType.l_brow) : 'l_brow',
int(MaskType.l_ear ) : 'l_ear',
int(MaskType.l_eye ) : 'l_eye',
int(MaskType.l_lip ) : 'l_lip',
int(MaskType.mouth ) : 'mouth',
int(MaskType.neck ) : 'neck',
int(MaskType.neck_l) : 'neck_l',
int(MaskType.nose ) : 'nose',
int(MaskType.r_brow) : 'r_brow',
int(MaskType.r_ear ) : 'r_ear',
int(MaskType.r_eye ) : 'r_eye',
int(MaskType.skin ) : 'skin',
int(MaskType.u_lip ) : 'u_lip',
}
MaskType_from_name = { MaskType_to_name[k] : k for k in MaskType_to_name.keys() }
class SampleGeneratorFaceCelebAMaskHQ(SampleGeneratorBase):
def __init__ (self, root_path, debug=False, batch_size=1, resolution=256,
generators_count=4, data_format="NHWC",
**kwargs):
super().__init__(debug, batch_size)
self.initialized = False
dataset_path = root_path / 'CelebAMask-HQ'
if not dataset_path.exists():
raise ValueError(f'Unable to find {dataset_path}')
images_path = dataset_path /'CelebA-HQ-img'
if not images_path.exists():
raise ValueError(f'Unable to find {images_path}')
masks_path = dataset_path / 'CelebAMask-HQ-mask-anno'
if not masks_path.exists():
raise ValueError(f'Unable to find {masks_path}')
if self.debug:
self.generators_count = 1
else:
self.generators_count = max(1, generators_count)
source_images_paths = pathex.get_image_paths(images_path, return_Path_class=True)
source_images_paths_len = len(source_images_paths)
mask_images_paths = pathex.get_image_paths(masks_path, subdirs=True, return_Path_class=True)
if source_images_paths_len == 0 or len(mask_images_paths) == 0:
raise ValueError('No training data provided.')
mask_file_id_hash = {}
for filepath in io.progress_bar_generator(mask_images_paths, "Loading"):
stem = filepath.stem
file_id, mask_type = stem.split('_', 1)
file_id = int(file_id)
if file_id not in mask_file_id_hash:
mask_file_id_hash[file_id] = {}
mask_file_id_hash[file_id][ MaskType_from_name[mask_type] ] = str(filepath.relative_to(masks_path))
source_file_id_set = set()
for filepath in source_images_paths:
stem = filepath.stem
file_id = int(stem)
source_file_id_set.update ( {file_id} )
for k in mask_file_id_hash.keys():
if k not in source_file_id_set:
io.log_err (f"Corrupted dataset: {k} not in {images_path}")
if self.debug:
self.generators = [ThisThreadGenerator ( self.batch_func, (images_path, masks_path, mask_file_id_hash, data_format) )]
else:
self.generators = [SubprocessGenerator ( self.batch_func, (images_path, masks_path, mask_file_id_hash, data_format), start_now=False ) \
for i in range(self.generators_count) ]
SubprocessGenerator.start_in_parallel( self.generators )
self.generator_counter = -1
self.initialized = True
#overridable
def is_initialized(self):
return self.initialized
def __iter__(self):
return self
def __next__(self):
self.generator_counter += 1
generator = self.generators[self.generator_counter % len(self.generators) ]
return next(generator)
def batch_func(self, param ):
images_path, masks_path, mask_file_id_hash, data_format = param
file_ids = list(mask_file_id_hash.keys())
shuffle_file_ids = []
shuffle_file_ids_random_ct = []
resolution = 256
random_flip = True
rotation_range=[-15,15]
scale_range=[-0.25, 0.75]
tx_range=[-0.3, 0.3]
ty_range=[-0.3, 0.3]
motion_blur = (25, 5)
gaussian_blur = (25, 5)
bs = self.batch_size
while True:
batches = None
n_batch = 0
while n_batch < bs:
try:
if len(shuffle_file_ids) == 0:
shuffle_file_ids = file_ids.copy()
np.random.shuffle(shuffle_file_ids)
if len(shuffle_file_ids_random_ct) == 0:
shuffle_file_ids_random_ct = file_ids.copy()
np.random.shuffle(shuffle_file_ids_random_ct)
file_id = shuffle_file_ids.pop()
#file_id_random_ct = shuffle_file_ids_random_ct.pop()
masks = mask_file_id_hash[file_id]
image_path = images_path / f'{file_id}.jpg'
#image_random_ct_path = images_path / f'{file_id_random_ct}.jpg'
skin_path = masks.get(MaskType.skin, None)
hair_path = masks.get(MaskType.hair, None)
hat_path = masks.get(MaskType.hat, None)
neck_path = masks.get(MaskType.neck, None)
img = cv2_imread(image_path).astype(np.float32) / 255.0
mask = cv2_imread(masks_path / skin_path)[...,0:1].astype(np.float32) / 255.0
if hair_path is not None:
hair_path = masks_path / hair_path
if hair_path.exists():
hair = cv2_imread(hair_path)[...,0:1].astype(np.float32) / 255.0
mask *= (1-hair)
if hat_path is not None:
hat_path = masks_path / hat_path
if hat_path.exists():
hat = cv2_imread(hat_path)[...,0:1].astype(np.float32) / 255.0
mask *= (1-hat)
warp_params = imagelib.gen_warp_params(resolution, random_flip, rotation_range=rotation_range, scale_range=scale_range, tx_range=tx_range, ty_range=ty_range )
img = cv2.resize( img, (resolution,resolution), cv2.INTER_LANCZOS4 )
h, s, v = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
h = ( h + np.random.randint(360) ) % 360
s = np.clip ( s + np.random.random()-0.5, 0, 1 )
v = np.clip ( v + np.random.random()/2-0.25, 0, 1 )
img = np.clip( cv2.cvtColor(cv2.merge([h, s, v]), cv2.COLOR_HSV2BGR) , 0, 1 )
#img_random_ct = cv2.resize( cv2_imread(image_random_ct_path).astype(np.float32) / 255.0, (resolution,resolution), cv2.INTER_LANCZOS4 )
#img = imagelib.color_transfer ('idt', img, img_random_ct )
if motion_blur is not None:
chance, mb_max_size = motion_blur
chance = np.clip(chance, 0, 100)
mblur_rnd_chance = np.random.randint(100)
mblur_rnd_kernel = np.random.randint(mb_max_size)+1
mblur_rnd_deg = np.random.randint(360)
if mblur_rnd_chance < chance:
img = imagelib.LinearMotionBlur (img, mblur_rnd_kernel, mblur_rnd_deg )
img = imagelib.warp_by_params (warp_params, img, can_warp=True, can_transform=True, can_flip=True, border_replicate=False, cv2_inter=cv2.INTER_LANCZOS4)
if gaussian_blur is not None:
chance, kernel_max_size = gaussian_blur
chance = np.clip(chance, 0, 100)
gblur_rnd_chance = np.random.randint(100)
gblur_rnd_kernel = np.random.randint(kernel_max_size)*2+1
if gblur_rnd_chance < chance:
img = cv2.GaussianBlur(img, (gblur_rnd_kernel,) *2 , 0)
mask = cv2.resize( mask, (resolution,resolution), cv2.INTER_LANCZOS4 )[...,None]
mask = imagelib.warp_by_params (warp_params, mask, can_warp=True, can_transform=True, can_flip=True, border_replicate=False, cv2_inter=cv2.INTER_LANCZOS4)
mask[mask < 0.5] = 0.0
mask[mask >= 0.5] = 1.0
mask = np.clip(mask, 0, 1)
if data_format == "NCHW":
img = np.transpose(img, (2,0,1) )
mask = np.transpose(mask, (2,0,1) )
if batches is None:
batches = [ [], [] ]
batches[0].append ( img )
batches[1].append ( mask )
n_batch += 1
except:
io.log_err ( traceback.format_exc() )
yield [ np.array(batch) for batch in batches]

View file

@ -0,0 +1,66 @@
import traceback
import cv2
import numpy as np
from core.joblib import SubprocessGenerator, ThisThreadGenerator
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
SampleType)
class SampleGeneratorImage(SampleGeneratorBase):
def __init__ (self, samples_path, debug, batch_size, sample_process_options=SampleProcessor.Options(), output_sample_types=[], raise_on_no_data=True, **kwargs):
super().__init__(debug, batch_size)
self.initialized = False
self.sample_process_options = sample_process_options
self.output_sample_types = output_sample_types
samples = SampleLoader.load (SampleType.IMAGE, samples_path)
if len(samples) == 0:
if raise_on_no_data:
raise ValueError('No training data provided.')
return
self.generators = [ThisThreadGenerator ( self.batch_func, samples )] if self.debug else \
[SubprocessGenerator ( self.batch_func, samples )]
self.generator_counter = -1
self.initialized = True
def __iter__(self):
return self
def __next__(self):
self.generator_counter += 1
generator = self.generators[self.generator_counter % len(self.generators) ]
return next(generator)
def batch_func(self, samples):
samples_len = len(samples)
idxs = [ *range(samples_len) ]
shuffle_idxs = []
while True:
batches = None
for n_batch in range(self.batch_size):
if len(shuffle_idxs) == 0:
shuffle_idxs = idxs.copy()
np.random.shuffle (shuffle_idxs)
idx = shuffle_idxs.pop()
sample = samples[idx]
x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug)
if batches is None:
batches = [ [] for _ in range(len(x)) ]
for i in range(len(x)):
batches[i].append ( x[i] )
yield [ np.array(batch) for batch in batches]

View file

@ -6,5 +6,7 @@ from .SampleGeneratorBase import SampleGeneratorBase
from .SampleGeneratorFace import SampleGeneratorFace
from .SampleGeneratorFacePerson import SampleGeneratorFacePerson
from .SampleGeneratorFaceTemporal import SampleGeneratorFaceTemporal
from .SampleGeneratorImage import SampleGeneratorImage
from .SampleGeneratorImageTemporal import SampleGeneratorImageTemporal
from .SampleGeneratorFaceCelebAMaskHQ import SampleGeneratorFaceCelebAMaskHQ
from .PackedFaceset import PackedFaceset