added XSeg model

This commit is contained in:
Colombo 2020-03-09 13:09:46 +04:00
parent a030ff6951
commit b0b9072981
5 changed files with 513 additions and 0 deletions

180
models/Model_XSeg/Model.py Normal file
View file

@ -0,0 +1,180 @@
import multiprocessing
import operator
from functools import partial
import numpy as np
from core import mathlib
from core.interact import interact as io
from core.leras import nn
from facelib import FaceType, TernausNet
from models import ModelBase
from samplelib import *
class XSegModel(ModelBase):
#override
def on_initialize_options(self):
device_config = nn.getCurrentDeviceConfig()
yn_str = {True:'y',False:'n'}
#default_resolution = 256
ask_override = self.ask_override()
if self.is_first_run() or ask_override:
self.ask_autobackup_hour()
self.ask_target_iter()
self.ask_batch_size(24)
#if self.is_first_run():
#resolution = io.input_int("Resolution", default_resolution, add_info="64-512")
#resolution = np.clip ( (resolution // 16) * 16, 64, 512)
#self.options['resolution'] = resolution
#override
def on_initialize(self):
device_config = nn.getCurrentDeviceConfig()
self.model_data_format = "NCHW" if len(device_config.devices) != 0 and not self.is_debug() else "NHWC"
nn.initialize(data_format=self.model_data_format)
tf = nn.tf
device_config = nn.getCurrentDeviceConfig()
devices = device_config.devices
self.resolution = resolution = 256#self.options['resolution']
place_model_on_cpu = True#len(devices) == 0
models_opt_device = '/CPU:0' if place_model_on_cpu else '/GPU:0'
bgr_shape = nn.get4Dshape(resolution,resolution,3)
mask_shape = nn.get4Dshape(resolution,resolution,1)
# Initializing model classes
self.model = TernausNet(f'{self.model_name}_SkinSeg',
resolution,
load_weights=not self.is_first_run(),
weights_file_root=self.get_model_root_path(),
training=True,
place_model_on_cpu=place_model_on_cpu,
data_format=nn.data_format)
if self.is_training:
# Adjust batch size for multiple GPU
gpu_count = max(1, len(devices) )
bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
self.set_batch_size( gpu_count*bs_per_gpu)
# Compute losses per GPU
gpu_pred_list = []
gpu_losses = []
gpu_loss_gvs = []
for gpu_id in range(gpu_count):
with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
with tf.device(f'/CPU:0'):
# slice on CPU, otherwise all batch data will be transfered to GPU first
batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
gpu_input_t = self.model.input_t [batch_slice,:,:,:]
gpu_target_t = self.model.target_t [batch_slice,:,:,:]
# process model tensors
gpu_pred_logits_t, gpu_pred_t = self.model.net([gpu_input_t])
gpu_pred_list.append(gpu_pred_t)
gpu_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=gpu_target_t, logits=gpu_pred_logits_t), axis=[1,2,3])
gpu_losses += [gpu_loss]
gpu_loss_gvs += [ nn.tf_gradients ( gpu_loss, self.model.net_weights ) ]
# Average losses and gradients, and create optimizer update ops
with tf.device (models_opt_device):
pred = nn.tf_concat(gpu_pred_list, 0)
loss = tf.reduce_mean(gpu_losses)
loss_gv_op = self.model.opt.get_update_op (nn.tf_average_gv_list (gpu_loss_gvs))
# Initializing training and view functions
def train(input_np, target_np):
l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np })
return l
self.train = train
def view(input_np):
return nn.tf_sess.run ( [pred], feed_dict={self.model.input_t :input_np})
self.view = view
# initializing sample generators
cpu_count = min(multiprocessing.cpu_count(), 8)
src_generators_count = cpu_count // 2
dst_generators_count = cpu_count // 2
src_generators_count = int(src_generators_count * 1.5)
src_generator = SampleGeneratorFaceCelebAMaskHQ ( root_path=self.training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(), resolution=256, generators_count=src_generators_count, data_format = nn.data_format)
dst_generator = SampleGeneratorImage(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=True),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.IMAGE, 'warp':False, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'data_format':nn.data_format, 'resolution': resolution} ],
generators_count=dst_generators_count,
raise_on_no_data=False )
if not dst_generator.is_initialized():
io.log_info(f"\nTo view the model on unseen faces, place any image faces in {self.training_data_dst_path}.\n")
self.set_training_data_generators ([src_generator, dst_generator])
#override
def get_model_filename_list(self):
return self.model.model_filename_list
#override
def onSave(self):
self.model.save_weights()
#override
def onTrainOneIter(self):
image_np, mask_np = self.generate_next_samples()[0]
loss = self.train (image_np, mask_np)
return ( ('loss', loss ), )
#override
def onGetPreview(self, samples):
n_samples = min(4, self.get_batch_size(), 800 // self.resolution )
src_samples, dst_samples = samples
image_np, mask_np = src_samples
I, M, IM, = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([image_np,mask_np] + self.view (image_np) ) ]
M, IM, = [ np.repeat (x, (3,), -1) for x in [M, IM] ]
green_bg = np.tile( np.array([0,1,0], dtype=np.float32)[None,None,...], (self.resolution,self.resolution,1) )
result = []
st = []
for i in range(n_samples):
ar = I[i]*M[i]+ green_bg*(1-M[i]), IM[i], I[i]*IM[i] + green_bg*(1-IM[i])
st.append ( np.concatenate ( ar, axis=1) )
result += [ ('XSeg training faces', np.concatenate (st, axis=0 )), ]
if len(dst_samples) != 0:
dst_np, = dst_samples
D, DM, = [ np.clip(nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([dst_np] + self.view (dst_np) ) ]
DM, = [ np.repeat (x, (3,), -1) for x in [DM] ]
st = []
for i in range(n_samples):
ar = D[i], DM[i], D[i]*DM[i]+ green_bg*(1-DM[i])
st.append ( np.concatenate ( ar, axis=1) )
result += [ ('XSeg unseen faces', np.concatenate (st, axis=0 )), ]
return result
Model = XSegModel

View file

@ -0,0 +1 @@
from .Model import Model

View file

@ -0,0 +1,264 @@
import multiprocessing
import pickle
import time
import traceback
from enum import IntEnum
import cv2
import numpy as np
from core import imagelib, mplib, pathex
from core.cv2ex import *
from core.interact import interact as io
from core.joblib import SubprocessGenerator, ThisThreadGenerator
from facelib import LandmarksProcessor
from samplelib import SampleGeneratorBase
class MaskType(IntEnum):
none = 0,
cloth = 1,
ear_r = 2,
eye_g = 3,
hair = 4,
hat = 5,
l_brow = 6,
l_ear = 7,
l_eye = 8,
l_lip = 9,
mouth = 10,
neck = 11,
neck_l = 12,
nose = 13,
r_brow = 14,
r_ear = 15,
r_eye = 16,
skin = 17,
u_lip = 18
MaskType_to_name = {
int(MaskType.none ) : 'none',
int(MaskType.cloth ) : 'cloth',
int(MaskType.ear_r ) : 'ear_r',
int(MaskType.eye_g ) : 'eye_g',
int(MaskType.hair ) : 'hair',
int(MaskType.hat ) : 'hat',
int(MaskType.l_brow) : 'l_brow',
int(MaskType.l_ear ) : 'l_ear',
int(MaskType.l_eye ) : 'l_eye',
int(MaskType.l_lip ) : 'l_lip',
int(MaskType.mouth ) : 'mouth',
int(MaskType.neck ) : 'neck',
int(MaskType.neck_l) : 'neck_l',
int(MaskType.nose ) : 'nose',
int(MaskType.r_brow) : 'r_brow',
int(MaskType.r_ear ) : 'r_ear',
int(MaskType.r_eye ) : 'r_eye',
int(MaskType.skin ) : 'skin',
int(MaskType.u_lip ) : 'u_lip',
}
MaskType_from_name = { MaskType_to_name[k] : k for k in MaskType_to_name.keys() }
class SampleGeneratorFaceCelebAMaskHQ(SampleGeneratorBase):
def __init__ (self, root_path, debug=False, batch_size=1, resolution=256,
generators_count=4, data_format="NHWC",
**kwargs):
super().__init__(debug, batch_size)
self.initialized = False
dataset_path = root_path / 'CelebAMask-HQ'
if not dataset_path.exists():
raise ValueError(f'Unable to find {dataset_path}')
images_path = dataset_path /'CelebA-HQ-img'
if not images_path.exists():
raise ValueError(f'Unable to find {images_path}')
masks_path = dataset_path / 'CelebAMask-HQ-mask-anno'
if not masks_path.exists():
raise ValueError(f'Unable to find {masks_path}')
if self.debug:
self.generators_count = 1
else:
self.generators_count = max(1, generators_count)
source_images_paths = pathex.get_image_paths(images_path, return_Path_class=True)
source_images_paths_len = len(source_images_paths)
mask_images_paths = pathex.get_image_paths(masks_path, subdirs=True, return_Path_class=True)
if source_images_paths_len == 0 or len(mask_images_paths) == 0:
raise ValueError('No training data provided.')
mask_file_id_hash = {}
for filepath in io.progress_bar_generator(mask_images_paths, "Loading"):
stem = filepath.stem
file_id, mask_type = stem.split('_', 1)
file_id = int(file_id)
if file_id not in mask_file_id_hash:
mask_file_id_hash[file_id] = {}
mask_file_id_hash[file_id][ MaskType_from_name[mask_type] ] = str(filepath.relative_to(masks_path))
source_file_id_set = set()
for filepath in source_images_paths:
stem = filepath.stem
file_id = int(stem)
source_file_id_set.update ( {file_id} )
for k in mask_file_id_hash.keys():
if k not in source_file_id_set:
io.log_err (f"Corrupted dataset: {k} not in {images_path}")
if self.debug:
self.generators = [ThisThreadGenerator ( self.batch_func, (images_path, masks_path, mask_file_id_hash, data_format) )]
else:
self.generators = [SubprocessGenerator ( self.batch_func, (images_path, masks_path, mask_file_id_hash, data_format), start_now=False ) \
for i in range(self.generators_count) ]
SubprocessGenerator.start_in_parallel( self.generators )
self.generator_counter = -1
self.initialized = True
#overridable
def is_initialized(self):
return self.initialized
def __iter__(self):
return self
def __next__(self):
self.generator_counter += 1
generator = self.generators[self.generator_counter % len(self.generators) ]
return next(generator)
def batch_func(self, param ):
images_path, masks_path, mask_file_id_hash, data_format = param
file_ids = list(mask_file_id_hash.keys())
shuffle_file_ids = []
shuffle_file_ids_random_ct = []
resolution = 256
random_flip = True
rotation_range=[-15,15]
scale_range=[-0.25, 0.75]
tx_range=[-0.3, 0.3]
ty_range=[-0.3, 0.3]
motion_blur = (25, 5)
gaussian_blur = (25, 5)
bs = self.batch_size
while True:
batches = None
n_batch = 0
while n_batch < bs:
try:
if len(shuffle_file_ids) == 0:
shuffle_file_ids = file_ids.copy()
np.random.shuffle(shuffle_file_ids)
if len(shuffle_file_ids_random_ct) == 0:
shuffle_file_ids_random_ct = file_ids.copy()
np.random.shuffle(shuffle_file_ids_random_ct)
file_id = shuffle_file_ids.pop()
#file_id_random_ct = shuffle_file_ids_random_ct.pop()
masks = mask_file_id_hash[file_id]
image_path = images_path / f'{file_id}.jpg'
#image_random_ct_path = images_path / f'{file_id_random_ct}.jpg'
skin_path = masks.get(MaskType.skin, None)
hair_path = masks.get(MaskType.hair, None)
hat_path = masks.get(MaskType.hat, None)
neck_path = masks.get(MaskType.neck, None)
img = cv2_imread(image_path).astype(np.float32) / 255.0
mask = cv2_imread(masks_path / skin_path)[...,0:1].astype(np.float32) / 255.0
if hair_path is not None:
hair_path = masks_path / hair_path
if hair_path.exists():
hair = cv2_imread(hair_path)[...,0:1].astype(np.float32) / 255.0
mask *= (1-hair)
if hat_path is not None:
hat_path = masks_path / hat_path
if hat_path.exists():
hat = cv2_imread(hat_path)[...,0:1].astype(np.float32) / 255.0
mask *= (1-hat)
warp_params = imagelib.gen_warp_params(resolution, random_flip, rotation_range=rotation_range, scale_range=scale_range, tx_range=tx_range, ty_range=ty_range )
img = cv2.resize( img, (resolution,resolution), cv2.INTER_LANCZOS4 )
h, s, v = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
h = ( h + np.random.randint(360) ) % 360
s = np.clip ( s + np.random.random()-0.5, 0, 1 )
v = np.clip ( v + np.random.random()/2-0.25, 0, 1 )
img = np.clip( cv2.cvtColor(cv2.merge([h, s, v]), cv2.COLOR_HSV2BGR) , 0, 1 )
#img_random_ct = cv2.resize( cv2_imread(image_random_ct_path).astype(np.float32) / 255.0, (resolution,resolution), cv2.INTER_LANCZOS4 )
#img = imagelib.color_transfer ('idt', img, img_random_ct )
if motion_blur is not None:
chance, mb_max_size = motion_blur
chance = np.clip(chance, 0, 100)
mblur_rnd_chance = np.random.randint(100)
mblur_rnd_kernel = np.random.randint(mb_max_size)+1
mblur_rnd_deg = np.random.randint(360)
if mblur_rnd_chance < chance:
img = imagelib.LinearMotionBlur (img, mblur_rnd_kernel, mblur_rnd_deg )
img = imagelib.warp_by_params (warp_params, img, can_warp=True, can_transform=True, can_flip=True, border_replicate=False, cv2_inter=cv2.INTER_LANCZOS4)
if gaussian_blur is not None:
chance, kernel_max_size = gaussian_blur
chance = np.clip(chance, 0, 100)
gblur_rnd_chance = np.random.randint(100)
gblur_rnd_kernel = np.random.randint(kernel_max_size)*2+1
if gblur_rnd_chance < chance:
img = cv2.GaussianBlur(img, (gblur_rnd_kernel,) *2 , 0)
mask = cv2.resize( mask, (resolution,resolution), cv2.INTER_LANCZOS4 )[...,None]
mask = imagelib.warp_by_params (warp_params, mask, can_warp=True, can_transform=True, can_flip=True, border_replicate=False, cv2_inter=cv2.INTER_LANCZOS4)
mask[mask < 0.5] = 0.0
mask[mask >= 0.5] = 1.0
mask = np.clip(mask, 0, 1)
if data_format == "NCHW":
img = np.transpose(img, (2,0,1) )
mask = np.transpose(mask, (2,0,1) )
if batches is None:
batches = [ [], [] ]
batches[0].append ( img )
batches[1].append ( mask )
n_batch += 1
except:
io.log_err ( traceback.format_exc() )
yield [ np.array(batch) for batch in batches]

View file

@ -0,0 +1,66 @@
import traceback
import cv2
import numpy as np
from core.joblib import SubprocessGenerator, ThisThreadGenerator
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
SampleType)
class SampleGeneratorImage(SampleGeneratorBase):
def __init__ (self, samples_path, debug, batch_size, sample_process_options=SampleProcessor.Options(), output_sample_types=[], raise_on_no_data=True, **kwargs):
super().__init__(debug, batch_size)
self.initialized = False
self.sample_process_options = sample_process_options
self.output_sample_types = output_sample_types
samples = SampleLoader.load (SampleType.IMAGE, samples_path)
if len(samples) == 0:
if raise_on_no_data:
raise ValueError('No training data provided.')
return
self.generators = [ThisThreadGenerator ( self.batch_func, samples )] if self.debug else \
[SubprocessGenerator ( self.batch_func, samples )]
self.generator_counter = -1
self.initialized = True
def __iter__(self):
return self
def __next__(self):
self.generator_counter += 1
generator = self.generators[self.generator_counter % len(self.generators) ]
return next(generator)
def batch_func(self, samples):
samples_len = len(samples)
idxs = [ *range(samples_len) ]
shuffle_idxs = []
while True:
batches = None
for n_batch in range(self.batch_size):
if len(shuffle_idxs) == 0:
shuffle_idxs = idxs.copy()
np.random.shuffle (shuffle_idxs)
idx = shuffle_idxs.pop()
sample = samples[idx]
x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug)
if batches is None:
batches = [ [] for _ in range(len(x)) ]
for i in range(len(x)):
batches[i].append ( x[i] )
yield [ np.array(batch) for batch in batches]

View file

@ -6,5 +6,7 @@ from .SampleGeneratorBase import SampleGeneratorBase
from .SampleGeneratorFace import SampleGeneratorFace from .SampleGeneratorFace import SampleGeneratorFace
from .SampleGeneratorFacePerson import SampleGeneratorFacePerson from .SampleGeneratorFacePerson import SampleGeneratorFacePerson
from .SampleGeneratorFaceTemporal import SampleGeneratorFaceTemporal from .SampleGeneratorFaceTemporal import SampleGeneratorFaceTemporal
from .SampleGeneratorImage import SampleGeneratorImage
from .SampleGeneratorImageTemporal import SampleGeneratorImageTemporal from .SampleGeneratorImageTemporal import SampleGeneratorImageTemporal
from .SampleGeneratorFaceCelebAMaskHQ import SampleGeneratorFaceCelebAMaskHQ
from .PackedFaceset import PackedFaceset from .PackedFaceset import PackedFaceset