mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 04:52:13 -07:00
added XSeg model
This commit is contained in:
parent
a030ff6951
commit
b0b9072981
5 changed files with 513 additions and 0 deletions
180
models/Model_XSeg/Model.py
Normal file
180
models/Model_XSeg/Model.py
Normal file
|
@ -0,0 +1,180 @@
|
||||||
|
import multiprocessing
|
||||||
|
import operator
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from core import mathlib
|
||||||
|
from core.interact import interact as io
|
||||||
|
from core.leras import nn
|
||||||
|
from facelib import FaceType, TernausNet
|
||||||
|
from models import ModelBase
|
||||||
|
from samplelib import *
|
||||||
|
|
||||||
|
class XSegModel(ModelBase):
    """Trainer for the XSeg segmentation model.

    Wraps a ``TernausNet`` and, when training, builds the per-device loss and
    gradient graph, the ``train``/``view`` callables and the two sample
    generators (labelled source faces and optional unlabelled images).
    """

    #override
    def on_initialize_options(self):
        """Ask for the generic training options on first run or on override."""
        ask_override = self.ask_override()
        if self.is_first_run() or ask_override:
            self.ask_autobackup_hour()
            self.ask_target_iter()
            self.ask_batch_size(24)

        # NOTE: resolution is not a user option yet; on_initialize fixes it at 256.

    #override
    def on_initialize(self):
        """Create the network and, if training, the graph ops and generators."""
        device_config = nn.getCurrentDeviceConfig()
        # NCHW is used only when at least one device is listed and we are not debugging.
        self.model_data_format = "NCHW" if len(device_config.devices) != 0 and not self.is_debug() else "NHWC"
        nn.initialize(data_format=self.model_data_format)
        tf = nn.tf

        # The config is read again after nn.initialize, as in the original code.
        device_config = nn.getCurrentDeviceConfig()
        devices = device_config.devices

        self.resolution = resolution = 256

        # Model placement is currently forced to CPU, so the optimizer device is '/CPU:0'.
        place_model_on_cpu = True
        models_opt_device = '/CPU:0' if place_model_on_cpu else '/GPU:0'

        # Initializing model classes
        self.model = TernausNet(f'{self.model_name}_SkinSeg',
                                resolution,
                                load_weights=not self.is_first_run(),
                                weights_file_root=self.get_model_root_path(),
                                training=True,
                                place_model_on_cpu=place_model_on_cpu,
                                data_format=nn.data_format)

        if self.is_training:
            # Adjust batch size for multiple GPU: make it a multiple of the device count.
            gpu_count = max(1, len(devices))
            bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
            self.set_batch_size(gpu_count * bs_per_gpu)

            # Compute losses per GPU
            gpu_pred_list = []
            gpu_losses = []
            gpu_loss_gvs = []

            for gpu_id in range(gpu_count):
                with tf.device(f'/GPU:{gpu_id}' if len(devices) != 0 else '/CPU:0'):

                    with tf.device('/CPU:0'):
                        # slice on CPU, otherwise all batch data will be transferred to GPU first
                        batch_slice = slice(gpu_id * bs_per_gpu, (gpu_id + 1) * bs_per_gpu)
                        gpu_input_t = self.model.input_t[batch_slice, :, :, :]
                        gpu_target_t = self.model.target_t[batch_slice, :, :, :]

                    # process model tensors
                    gpu_pred_logits_t, gpu_pred_t = self.model.net([gpu_input_t])
                    gpu_pred_list.append(gpu_pred_t)

                    # Per-sample loss: cross-entropy averaged over axes 1, 2 and 3.
                    gpu_loss = tf.reduce_mean(
                        tf.nn.sigmoid_cross_entropy_with_logits(labels=gpu_target_t, logits=gpu_pred_logits_t),
                        axis=[1, 2, 3])
                    gpu_losses.append(gpu_loss)

                    gpu_loss_gvs.append(nn.tf_gradients(gpu_loss, self.model.net_weights))

            # Average losses and gradients, and create optimizer update ops
            with tf.device(models_opt_device):
                pred = nn.tf_concat(gpu_pred_list, 0)
                loss = tf.reduce_mean(gpu_losses)

                loss_gv_op = self.model.opt.get_update_op(nn.tf_average_gv_list(gpu_loss_gvs))

            # Initializing training and view functions
            def train(input_np, target_np):
                """Run one optimizer step and return the loss value."""
                l, _ = nn.tf_sess.run([loss, loss_gv_op],
                                      feed_dict={self.model.input_t: input_np,
                                                 self.model.target_t: target_np})
                return l
            self.train = train

            def view(input_np):
                """Return a one-element list holding the predictions for ``input_np``."""
                return nn.tf_sess.run([pred], feed_dict={self.model.input_t: input_np})
            self.view = view

            # initializing sample generators
            cpu_count = min(multiprocessing.cpu_count(), 8)
            dst_generators_count = cpu_count // 2
            # The source generator gets 1.5x the worker share of the destination generator.
            src_generators_count = int((cpu_count // 2) * 1.5)

            src_generator = SampleGeneratorFaceCelebAMaskHQ(root_path=self.training_data_src_path,
                                                            debug=self.is_debug(),
                                                            batch_size=self.get_batch_size(),
                                                            resolution=resolution,
                                                            generators_count=src_generators_count,
                                                            data_format=nn.data_format)

            dst_generator = SampleGeneratorImage(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                                                 sample_process_options=SampleProcessor.Options(random_flip=True),
                                                 output_sample_types=[{'sample_type': SampleProcessor.SampleType.IMAGE,
                                                                       'warp': False,
                                                                       'transform': True,
                                                                       'channel_type': SampleProcessor.ChannelType.BGR,
                                                                       'data_format': nn.data_format,
                                                                       'resolution': resolution}],
                                                 generators_count=dst_generators_count,
                                                 raise_on_no_data=False)
            if not dst_generator.is_initialized():
                io.log_info(f"\nTo view the model on unseen faces, place any image faces in {self.training_data_dst_path}.\n")

            self.set_training_data_generators([src_generator, dst_generator])

    #override
    def get_model_filename_list(self):
        """Return the filename list exposed by the wrapped model."""
        return self.model.model_filename_list

    #override
    def onSave(self):
        """Save the wrapped model's weights."""
        self.model.save_weights()

    #override
    def onTrainOneIter(self):
        """Train on the next source batch and return ``(('loss', value),)``."""
        image_np, mask_np = self.generate_next_samples()[0]
        loss = self.train(image_np, mask_np)

        return (('loss', loss),)

    #override
    def onGetPreview(self, samples):
        """Build preview images as a list of ``(title, image)`` tuples.

        Each preview row shows three tiles side by side; masked-out pixels are
        replaced with green. The "unseen faces" preview is added only when the
        destination sample list is non-empty.
        """
        n_samples = min(4, self.get_batch_size(), 800 // self.resolution)

        src_samples, dst_samples = samples
        image_np, mask_np = src_samples

        green_bg = np.tile(np.array([0, 1, 0], dtype=np.float32)[None, None, ...],
                           (self.resolution, self.resolution, 1))

        def to_preview(x):
            # Convert to NHWC and clamp to the displayable [0, 1] range.
            return np.clip(nn.to_data_format(x, "NHWC", self.model_data_format), 0.0, 1.0)

        def to_3ch(x):
            # Repeat the last axis three times so masks can be tiled next to BGR images.
            return np.repeat(x, (3,), -1)

        def over_green(img, m):
            # Keep the image where the mask is 1 and show green where it is 0.
            return img * m + green_bg * (1 - m)

        I, M, IM = [to_preview(x) for x in ([image_np, mask_np] + self.view(image_np))]
        M, IM = to_3ch(M), to_3ch(IM)

        result = []
        st = []
        for i in range(n_samples):
            ar = over_green(I[i], M[i]), IM[i], over_green(I[i], IM[i])
            st.append(np.concatenate(ar, axis=1))
        result += [('XSeg training faces', np.concatenate(st, axis=0)), ]

        if len(dst_samples) != 0:
            dst_np, = dst_samples

            D, DM = [to_preview(x) for x in ([dst_np] + self.view(dst_np))]
            DM = to_3ch(DM)

            st = []
            for i in range(n_samples):
                ar = D[i], DM[i], over_green(D[i], DM[i])
                st.append(np.concatenate(ar, axis=1))

            result += [('XSeg unseen faces', np.concatenate(st, axis=0)), ]

        return result


Model = XSegModel
|
1
models/Model_XSeg/__init__.py
Normal file
1
models/Model_XSeg/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
from .Model import Model
|
264
samplelib/SampleGeneratorFaceCelebAMaskHQ.py
Normal file
264
samplelib/SampleGeneratorFaceCelebAMaskHQ.py
Normal file
|
@ -0,0 +1,264 @@
|
||||||
|
import multiprocessing
|
||||||
|
import pickle
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from enum import IntEnum
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from core import imagelib, mplib, pathex
|
||||||
|
from core.cv2ex import *
|
||||||
|
from core.interact import interact as io
|
||||||
|
from core.joblib import SubprocessGenerator, ThisThreadGenerator
|
||||||
|
from facelib import LandmarksProcessor
|
||||||
|
from samplelib import SampleGeneratorBase
|
||||||
|
|
||||||
|
|
||||||
|
class MaskType(IntEnum):
    """Integer ids for the mask annotation kinds used by this generator.

    Member names double as the filename suffixes of the mask images
    (``<file_id>_<name>``), so they must not be renamed.
    """
    none = 0
    cloth = 1
    ear_r = 2
    eye_g = 3
    hair = 4
    hat = 5
    l_brow = 6
    l_ear = 7
    l_eye = 8
    l_lip = 9
    mouth = 10
    neck = 11
    neck_l = 12
    nose = 13
    r_brow = 14
    r_ear = 15
    r_eye = 16
    skin = 17
    u_lip = 18
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Integer mask id -> mask filename suffix. Derived from the enum so the two
# can never drift apart; every suffix is exactly the enum member's name.
MaskType_to_name = {int(mask_type): mask_type.name for mask_type in MaskType}

# Reverse lookup: mask filename suffix -> integer mask id.
MaskType_from_name = {name: mask_id for mask_id, name in MaskType_to_name.items()}
|
||||||
|
|
||||||
|
class SampleGeneratorFaceCelebAMaskHQ(SampleGeneratorBase):
    """Endless generator of augmented (image, mask) batches from CelebAMask-HQ.

    Expects ``root_path / 'CelebAMask-HQ'`` to contain ``CelebA-HQ-img`` with
    ``<id>.jpg`` images and ``CelebAMask-HQ-mask-anno`` with mask images named
    ``<id>_<mask name>``. The target mask is the skin mask with the hair and
    hat masks subtracted.
    """

    def __init__ (self, root_path, debug=False, batch_size=1, resolution=256,
                        generators_count=4, data_format="NHWC",
                        **kwargs):
        """Index the dataset and start the worker generators.

        Parameters:
            root_path: directory that contains the ``CelebAMask-HQ`` folder.
            debug: run a single generator in the current thread.
            batch_size: samples per yielded batch (passed to the base class).
            resolution: side length of the produced square images and masks.
            generators_count: number of subprocess generators (at least 1).
            data_format: "NHWC" or "NCHW" layout of the yielded arrays.

        Raises:
            ValueError: if a dataset directory is missing or holds no images.
        """
        super().__init__(debug, batch_size)
        self.initialized = False
        # Stored on the instance so batch_func can read it, like self.batch_size.
        self.resolution = resolution

        dataset_path = root_path / 'CelebAMask-HQ'
        if not dataset_path.exists():
            raise ValueError(f'Unable to find {dataset_path}')

        images_path = dataset_path / 'CelebA-HQ-img'
        if not images_path.exists():
            raise ValueError(f'Unable to find {images_path}')

        masks_path = dataset_path / 'CelebAMask-HQ-mask-anno'
        if not masks_path.exists():
            raise ValueError(f'Unable to find {masks_path}')

        if self.debug:
            self.generators_count = 1
        else:
            self.generators_count = max(1, generators_count)

        source_images_paths = pathex.get_image_paths(images_path, return_Path_class=True)
        source_images_paths_len = len(source_images_paths)
        mask_images_paths = pathex.get_image_paths(masks_path, subdirs=True, return_Path_class=True)

        if source_images_paths_len == 0 or len(mask_images_paths) == 0:
            raise ValueError('No training data provided.')

        # file id -> { mask type id : mask path relative to masks_path }
        mask_file_id_hash = {}

        for filepath in io.progress_bar_generator(mask_images_paths, "Loading"):
            # Mask stems look like "<file id>_<mask name>"; the name itself may contain '_'.
            file_id, mask_type = filepath.stem.split('_', 1)
            file_id = int(file_id)

            masks = mask_file_id_hash.setdefault(file_id, {})
            masks[MaskType_from_name[mask_type]] = str(filepath.relative_to(masks_path))

        source_file_id_set = {int(filepath.stem) for filepath in source_images_paths}

        # Report, but do not abort on, masks whose source image is missing.
        for k in mask_file_id_hash:
            if k not in source_file_id_set:
                io.log_err (f"Corrupted dataset: {k} not in {images_path}")

        generator_param = (images_path, masks_path, mask_file_id_hash, data_format)
        if self.debug:
            self.generators = [ThisThreadGenerator ( self.batch_func, generator_param )]
        else:
            self.generators = [SubprocessGenerator ( self.batch_func, generator_param, start_now=False )
                               for _ in range(self.generators_count)]

            SubprocessGenerator.start_in_parallel( self.generators )

        self.generator_counter = -1

        self.initialized = True

    #overridable
    def is_initialized(self):
        """Return True once the constructor has finished setting up the generators."""
        return self.initialized

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch, taking the worker generators in round-robin order."""
        self.generator_counter += 1
        generator = self.generators[self.generator_counter % len(self.generators)]
        return next(generator)

    def batch_func(self, param):
        """Yield ``[images, masks]`` batches forever.

        ``param`` is the tuple ``(images_path, masks_path, mask_file_id_hash,
        data_format)`` built by the constructor. A sample that fails to load or
        process is logged and skipped; the batch is filled with later samples.
        """
        images_path, masks_path, mask_file_id_hash, data_format = param

        file_ids = list(mask_file_id_hash.keys())
        shuffle_file_ids = []

        resolution = self.resolution
        random_flip = True
        rotation_range = [-15, 15]
        scale_range = [-0.25, 0.75]
        tx_range = [-0.3, 0.3]
        ty_range = [-0.3, 0.3]

        # (chance in percent, maximum kernel parameter)
        motion_blur = (25, 5)
        gaussian_blur = (25, 5)

        def read_mask(path):
            # First channel only, scaled to [0, 1]; the trailing axis is kept.
            return cv2_imread(path)[..., 0:1].astype(np.float32) / 255.0

        bs = self.batch_size
        while True:
            batches = [[], []]

            n_batch = 0
            while n_batch < bs:
                try:
                    # Refill and reshuffle once every id has been consumed.
                    if len(shuffle_file_ids) == 0:
                        shuffle_file_ids = file_ids.copy()
                        np.random.shuffle(shuffle_file_ids)

                    file_id = shuffle_file_ids.pop()
                    masks = mask_file_id_hash[file_id]

                    image_path = images_path / f'{file_id}.jpg'
                    skin_path = masks.get(MaskType.skin, None)

                    img = cv2_imread(image_path).astype(np.float32) / 255.0
                    # A missing skin mask raises here; the sample is then logged and skipped.
                    mask = read_mask(masks_path / skin_path)

                    # Remove the regions covered by hair and hat from the skin mask.
                    for occluder_type in (MaskType.hair, MaskType.hat):
                        occluder_path = masks.get(occluder_type, None)
                        if occluder_path is not None:
                            occluder_path = masks_path / occluder_path
                            if occluder_path.exists():
                                mask *= (1 - read_mask(occluder_path))

                    warp_params = imagelib.gen_warp_params(resolution, random_flip, rotation_range=rotation_range, scale_range=scale_range, tx_range=tx_range, ty_range=ty_range )

                    # interpolation must be passed by keyword: the third positional
                    # argument of cv2.resize is `dst`, not the interpolation flag.
                    img = cv2.resize( img, (resolution, resolution), interpolation=cv2.INTER_LANCZOS4 )

                    # Random colour jitter in HSV space (float images: hue is in degrees).
                    h, s, v = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
                    h = ( h + np.random.randint(360) ) % 360
                    s = np.clip ( s + np.random.random()-0.5, 0, 1 )
                    v = np.clip ( v + np.random.random()/2-0.25, 0, 1 )
                    img = np.clip( cv2.cvtColor(cv2.merge([h, s, v]), cv2.COLOR_HSV2BGR) , 0, 1 )

                    if motion_blur is not None:
                        chance, mb_max_size = motion_blur
                        chance = np.clip(chance, 0, 100)

                        mblur_rnd_chance = np.random.randint(100)
                        mblur_rnd_kernel = np.random.randint(mb_max_size)+1
                        mblur_rnd_deg = np.random.randint(360)

                        if mblur_rnd_chance < chance:
                            img = imagelib.LinearMotionBlur (img, mblur_rnd_kernel, mblur_rnd_deg )

                    img = imagelib.warp_by_params (warp_params, img, can_warp=True, can_transform=True, can_flip=True, border_replicate=False, cv2_inter=cv2.INTER_LANCZOS4)

                    if gaussian_blur is not None:
                        chance, kernel_max_size = gaussian_blur
                        chance = np.clip(chance, 0, 100)

                        gblur_rnd_chance = np.random.randint(100)
                        # *2+1 keeps the Gaussian kernel size odd.
                        gblur_rnd_kernel = np.random.randint(kernel_max_size)*2+1

                        if gblur_rnd_chance < chance:
                            img = cv2.GaussianBlur(img, (gblur_rnd_kernel,) *2 , 0)

                    # cv2.resize drops the single channel axis, so it is added back.
                    mask = cv2.resize( mask, (resolution, resolution), interpolation=cv2.INTER_LANCZOS4 )[..., None]
                    mask = imagelib.warp_by_params (warp_params, mask, can_warp=True, can_transform=True, can_flip=True, border_replicate=False, cv2_inter=cv2.INTER_LANCZOS4)
                    # Binarize: interpolation and warping leave fractional values.
                    mask[mask < 0.5] = 0.0
                    mask[mask >= 0.5] = 1.0
                    mask = np.clip(mask, 0, 1)

                    if data_format == "NCHW":
                        img = np.transpose(img, (2,0,1) )
                        mask = np.transpose(mask, (2,0,1) )

                    batches[0].append ( img )
                    batches[1].append ( mask )

                    n_batch += 1
                except Exception:
                    # Best effort: log the failure and try another sample.
                    # KeyboardInterrupt/SystemExit are deliberately not caught.
                    io.log_err ( traceback.format_exc() )

            yield [ np.array(batch) for batch in batches]
|
66
samplelib/SampleGeneratorImage.py
Normal file
66
samplelib/SampleGeneratorImage.py
Normal file
|
@ -0,0 +1,66 @@
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from core.joblib import SubprocessGenerator, ThisThreadGenerator
|
||||||
|
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
|
||||||
|
SampleType)
|
||||||
|
|
||||||
|
|
||||||
|
class SampleGeneratorImage(SampleGeneratorBase):
    """Endless generator of processed image batches loaded from a folder."""

    def __init__ (self, samples_path, debug, batch_size, sample_process_options=None, output_sample_types=None, raise_on_no_data=True, **kwargs):
        """Load the image samples and create the worker generator.

        Parameters:
            samples_path: folder passed to ``SampleLoader.load``.
            debug: run the generator in the current thread instead of a subprocess.
            batch_size: samples per yielded batch (passed to the base class).
            sample_process_options: ``SampleProcessor.Options``; a fresh default
                instance is created when omitted.
            output_sample_types: list of output descriptions for
                ``SampleProcessor.process``; defaults to an empty list.
            raise_on_no_data: raise instead of leaving the generator
                uninitialized when no samples are found.

        Raises:
            ValueError: if no samples are found and ``raise_on_no_data`` is true.
        """
        super().__init__(debug, batch_size)
        self.initialized = False
        # None sentinels avoid sharing one default object between all instances.
        self.sample_process_options = SampleProcessor.Options() if sample_process_options is None else sample_process_options
        self.output_sample_types = [] if output_sample_types is None else output_sample_types

        samples = SampleLoader.load (SampleType.IMAGE, samples_path)

        if len(samples) == 0:
            if raise_on_no_data:
                raise ValueError('No training data provided.')
            # Stay uninitialized; callers are expected to check is_initialized().
            return

        self.generators = [ThisThreadGenerator ( self.batch_func, samples )] if self.debug else \
                          [SubprocessGenerator ( self.batch_func, samples )]

        self.generator_counter = -1
        self.initialized = True

    #overridable
    def is_initialized(self):
        """Return False when the generator was created without any samples."""
        return self.initialized

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch, taking the worker generators in round-robin order."""
        self.generator_counter += 1
        generator = self.generators[self.generator_counter % len(self.generators)]
        return next(generator)

    def batch_func(self, samples):
        """Yield batches forever as a list with one array per output of a processed sample."""
        idxs = [ *range(len(samples)) ]
        shuffle_idxs = []

        while True:
            batches = None
            for _ in range(self.batch_size):
                # Refill and reshuffle once every sample has been used.
                if len(shuffle_idxs) == 0:
                    shuffle_idxs = idxs.copy()
                    np.random.shuffle (shuffle_idxs)

                sample = samples[shuffle_idxs.pop()]

                x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug)

                # The number of outputs is only known after the first sample is processed.
                if batches is None:
                    batches = [ [] for _ in range(len(x)) ]

                for batch, item in zip(batches, x):
                    batch.append ( item )

            yield [ np.array(batch) for batch in batches]
|
|
@ -6,5 +6,7 @@ from .SampleGeneratorBase import SampleGeneratorBase
|
||||||
from .SampleGeneratorFace import SampleGeneratorFace
|
from .SampleGeneratorFace import SampleGeneratorFace
|
||||||
from .SampleGeneratorFacePerson import SampleGeneratorFacePerson
|
from .SampleGeneratorFacePerson import SampleGeneratorFacePerson
|
||||||
from .SampleGeneratorFaceTemporal import SampleGeneratorFaceTemporal
|
from .SampleGeneratorFaceTemporal import SampleGeneratorFaceTemporal
|
||||||
|
from .SampleGeneratorImage import SampleGeneratorImage
|
||||||
from .SampleGeneratorImageTemporal import SampleGeneratorImageTemporal
|
from .SampleGeneratorImageTemporal import SampleGeneratorImageTemporal
|
||||||
|
from .SampleGeneratorFaceCelebAMaskHQ import SampleGeneratorFaceCelebAMaskHQ
|
||||||
from .PackedFaceset import PackedFaceset
|
from .PackedFaceset import PackedFaceset
|
Loading…
Add table
Add a link
Reference in a new issue