diff --git a/facelib/FaceEnhancer.h5 b/facelib/FaceEnhancer.h5 new file mode 100644 index 0000000..201105b Binary files /dev/null and b/facelib/FaceEnhancer.h5 differ diff --git a/facelib/FaceEnhancer.py b/facelib/FaceEnhancer.py new file mode 100644 index 0000000..c3b2016 --- /dev/null +++ b/facelib/FaceEnhancer.py @@ -0,0 +1,154 @@ +import operator +from pathlib import Path + +import cv2 +import numpy as np + + + +class FaceEnhancer(object): + """ + x4 face enhancer + """ + def __init__(self): + from nnlib import nnlib + exec( nnlib.import_all(), locals(), globals() ) + + model_path = Path(__file__).parent / "FaceEnhancer.h5" + if not model_path.exists(): + return + + bgr_inp = Input ( (192,192,3) ) + t_param_inp = Input ( (1,) ) + t_param1_inp = Input ( (1,) ) + x = Conv2D (64, 3, strides=1, padding='same' )(bgr_inp) + + a = Dense (64, use_bias=False) ( t_param_inp ) + a = Reshape( (1,1,64) )(a) + b = Dense (64, use_bias=False ) ( t_param1_inp ) + b = Reshape( (1,1,64) )(b) + x = Add()([x,a,b]) + + x = LeakyReLU(0.1)(x) + + x = LeakyReLU(0.1)(Conv2D (64, 3, strides=1, padding='same' )(x)) + x = e0 = LeakyReLU(0.1)(Conv2D (64, 3, strides=1, padding='same')(x)) + + x = AveragePooling2D()(x) + x = LeakyReLU(0.1)(Conv2D (112, 3, strides=1, padding='same')(x)) + x = e1 = LeakyReLU(0.1)(Conv2D (112, 3, strides=1, padding='same')(x)) + + x = AveragePooling2D()(x) + x = LeakyReLU(0.1)(Conv2D (192, 3, strides=1, padding='same')(x)) + x = e2 = LeakyReLU(0.1)(Conv2D (192, 3, strides=1, padding='same')(x)) + + x = AveragePooling2D()(x) + x = LeakyReLU(0.1)(Conv2D (336, 3, strides=1, padding='same')(x)) + x = e3 = LeakyReLU(0.1)(Conv2D (336, 3, strides=1, padding='same')(x)) + + x = AveragePooling2D()(x) + x = LeakyReLU(0.1)(Conv2D (512, 3, strides=1, padding='same')(x)) + x = e4 = LeakyReLU(0.1)(Conv2D (512, 3, strides=1, padding='same')(x)) + + x = AveragePooling2D()(x) + x = LeakyReLU(0.1)(Conv2D (512, 3, strides=1, padding='same')(x)) + x = LeakyReLU(0.1)(Conv2D (512, 3, strides=1, padding='same')(x)) + x = LeakyReLU(0.1)(Conv2D (512, 3, strides=1, padding='same')(x)) + x = LeakyReLU(0.1)(Conv2D (512, 3, strides=1, padding='same')(x)) + + x = Concatenate()([ BilinearInterpolation()(x), e4 ]) + + x = LeakyReLU(0.1)(Conv2D (512, 3, strides=1, padding='same')(x)) + x = LeakyReLU(0.1)(Conv2D (512, 3, strides=1, padding='same')(x)) + + x = Concatenate()([ BilinearInterpolation()(x), e3 ]) + x = LeakyReLU(0.1)(Conv2D (512, 3, strides=1, padding='same')(x)) + x = LeakyReLU(0.1)(Conv2D (512, 3, strides=1, padding='same')(x)) + + x = Concatenate()([ BilinearInterpolation()(x), e2 ]) + x = LeakyReLU(0.1)(Conv2D (288, 3, strides=1, padding='same')(x)) + x = LeakyReLU(0.1)(Conv2D (288, 3, strides=1, padding='same')(x)) + + x = Concatenate()([ BilinearInterpolation()(x), e1 ]) + x = LeakyReLU(0.1)(Conv2D (160, 3, strides=1, padding='same')(x)) + x = LeakyReLU(0.1)(Conv2D (160, 3, strides=1, padding='same')(x)) + + x = Concatenate()([ BilinearInterpolation()(x), e0 ]) + x = LeakyReLU(0.1)(Conv2D (96, 3, strides=1, padding='same')(x)) + x = d0 = LeakyReLU(0.1)(Conv2D (96, 3, strides=1, padding='same')(x)) + + x = LeakyReLU(0.1)(Conv2D (48, 3, strides=1, padding='same')(x)) + + x = Conv2D (3, 3, strides=1, padding='same', activation='tanh')(x) + out1x = Add()([bgr_inp, x]) + + x = d0 + x = LeakyReLU(0.1)(Conv2D (96, 3, strides=1, padding='same')(x)) + x = LeakyReLU(0.1)(Conv2D (96, 3, strides=1, padding='same')(x)) + x = d2x = BilinearInterpolation()(x) + + x = LeakyReLU(0.1)(Conv2D (48, 3, strides=1, padding='same')(x)) + x = Conv2D (3, 3, strides=1, padding='same', activation='tanh')(x) + + out2x = Add()([BilinearInterpolation()(out1x), x]) + + x = d2x + x = LeakyReLU(0.1)(Conv2D (72, 3, strides=1, padding='same')(x)) + x = LeakyReLU(0.1)(Conv2D (72, 3, strides=1, padding='same')(x)) + x = d4x = BilinearInterpolation()(x) + + x = LeakyReLU(0.1)(Conv2D (36, 3, strides=1, padding='same')(x)) + x = Conv2D (3, 3, strides=1, padding='same', activation='tanh')(x) + out4x = Add()([BilinearInterpolation()(out2x), x ]) + + self.model = keras.models.Model ( [bgr_inp,t_param_inp,t_param1_inp], [out4x] ) + self.model.load_weights (str(model_path)) + + + def enhance (self, inp_img, is_tanh=False, preserve_size=True): + if not is_tanh: + inp_img = np.clip( inp_img * 2 -1, -1, 1 ) + + param = np.array([0.2]) + param1 = np.array([1.0]) + up_res = 4 + patch_size = 192 + patch_size_half = patch_size // 2 + + h,w,c = inp_img.shape + + i_max = w-patch_size+1 + j_max = h-patch_size+1 + + final_img = np.zeros ( (h*up_res,w*up_res,c), dtype=np.float32 ) + final_img_div = np.zeros ( (h*up_res,w*up_res,1), dtype=np.float32 ) + + x = np.concatenate ( [ np.linspace (0,1,patch_size_half*up_res), np.linspace (1,0,patch_size_half*up_res) ] ) + x,y = np.meshgrid(x,x) + patch_mask = (x*y)[...,None] + + j=0 + while j < j_max: + i = 0 + while i < i_max: + patch_img = inp_img[j:j+patch_size, i:i+patch_size,:] + x = self.model.predict( [ patch_img[None,...], param, param1 ] )[0] + final_img [j*up_res:(j+patch_size)*up_res, i*up_res:(i+patch_size)*up_res,:] += x*patch_mask + final_img_div[j*up_res:(j+patch_size)*up_res, i*up_res:(i+patch_size)*up_res,:] += patch_mask + if i == i_max-1: + break + i = min( i+patch_size_half, i_max-1) + if j == j_max-1: + break + j = min( j+patch_size_half, j_max-1) + + final_img_div[final_img_div==0] = 1.0 + final_img /= final_img_div + + if preserve_size: + final_img = cv2.resize (final_img, (w,h), cv2.INTER_LANCZOS4) + + if not is_tanh: + final_img = np.clip( final_img/2+0.5, 0, 1 ) + + return final_img diff --git a/facelib/__init__.py b/facelib/__init__.py index 8ca071b..cde2ab5 100644 --- a/facelib/__init__.py +++ b/facelib/__init__.py @@ -3,4 +3,5 @@ from .DLIBExtractor import DLIBExtractor from .MTCExtractor import MTCExtractor from .S3FDExtractor import S3FDExtractor from .FANExtractor import FANExtractor -from .PoseEstimator import PoseEstimator \ No newline at end of file +from .PoseEstimator import PoseEstimator +from .FaceEnhancer import FaceEnhancer \ No newline at end of file diff --git a/main.py b/main.py index 86cf1f5..2166817 100644 --- a/main.py +++ b/main.py @@ -286,6 +286,21 @@ if __name__ == "__main__": p.set_defaults(func=process_labelingtool_edit_mask) + facesettool_parser = subparsers.add_parser( "facesettool", help="Faceset tools.").add_subparsers() + + def process_faceset_enhancer(arguments): + os_utils.set_process_lowest_prio() + from mainscripts import FacesetEnhancer + FacesetEnhancer.process_folder ( Path(arguments.input_dir), multi_gpu=arguments.multi_gpu, cpu_only=arguments.cpu_only ) + + p = facesettool_parser.add_parser ("enhance", help="Enhance details in DFL faceset.") + p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory of aligned faces.") + p.add_argument('--multi-gpu', action="store_true", dest="multi_gpu", default=False, help="Enables multi GPU.") + p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Process on CPU.") + + p.set_defaults(func=process_faceset_enhancer) + + """ def process_relight_faceset(arguments): os_utils.set_process_lowest_prio() from mainscripts import FacesetRelighter @@ -295,9 +310,7 @@ if __name__ == "__main__": os_utils.set_process_lowest_prio() from mainscripts import FacesetRelighter FacesetRelighter.delete_relighted (arguments.input_dir) - - facesettool_parser = subparsers.add_parser( "facesettool", help="Faceset tools.").add_subparsers() - + p = facesettool_parser.add_parser ("relight", help="Synthesize new faces from existing ones by relighting them. With the relighted faces neural network will better reproduce face shadows.") p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory of aligned faces.") p.add_argument('--lighten', action="store_true", dest="lighten", default=None, help="Lighten the faces.") @@ -307,6 +320,7 @@ if __name__ == "__main__": p = facesettool_parser.add_parser ("delete_relighted", help="Delete relighted faces.") p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory of aligned faces.") p.set_defaults(func=process_delete_relighted) + """ def bad_args(arguments): parser.print_help() diff --git a/mainscripts/FacesetEnhancer.py b/mainscripts/FacesetEnhancer.py new file mode 100644 index 0000000..fb8558c --- /dev/null +++ b/mainscripts/FacesetEnhancer.py @@ -0,0 +1,163 @@ +import multiprocessing +import shutil + +from DFLIMG import * +from interact import interact as io +from joblib import Subprocessor +from nnlib import nnlib +from utils import Path_utils +from utils.cv2_utils import * + + +class FacesetEnhancerSubprocessor(Subprocessor): + + #override + def __init__(self, image_paths, output_dirpath, multi_gpu=False, cpu_only=False): + self.image_paths = image_paths + self.output_dirpath = output_dirpath + self.result = [] + self.devices = FacesetEnhancerSubprocessor.get_devices_for_config(multi_gpu, cpu_only) + + super().__init__('FacesetEnhancer', FacesetEnhancerSubprocessor.Cli, 600) + + #override + def on_clients_initialized(self): + io.progress_bar (None, len (self.image_paths)) + + #override + def on_clients_finalized(self): + io.progress_bar_close() + + #override + def process_info_generator(self): + base_dict = {'output_dirpath':self.output_dirpath} + + for (device_idx, device_type, device_name, device_total_vram_gb) in self.devices: + client_dict = base_dict.copy() + client_dict['device_idx'] = device_idx + client_dict['device_name'] = device_name + client_dict['device_type'] = device_type + yield client_dict['device_name'], {}, client_dict + + #override + def get_data(self, host_dict): + if len (self.image_paths) > 0: + return self.image_paths.pop(0) + + #override + def on_data_return (self, host_dict, data): + self.image_paths.insert(0, data) + + #override + def on_result (self, host_dict, data, result): + io.progress_bar_inc(1) + if result[0] == 1: + self.result +=[ (result[1], result[2]) ] + + #override + def get_result(self): + return self.result + + @staticmethod + def get_devices_for_config (multi_gpu, cpu_only): + backend = nnlib.device.backend + if 'cpu' in backend: + cpu_only = True + + if not cpu_only and backend == "plaidML": + cpu_only = True + + if not cpu_only: + devices = [] + if multi_gpu: + devices = nnlib.device.getValidDevicesWithAtLeastTotalMemoryGB(2) + + if len(devices) == 0: + idx = nnlib.device.getBestValidDeviceIdx() + if idx != -1: + devices = [idx] + + if len(devices) == 0: + cpu_only = True + + result = [] + for idx in devices: + dev_name = nnlib.device.getDeviceName(idx) + dev_vram = nnlib.device.getDeviceVRAMTotalGb(idx) + + result += [ (idx, 'GPU', dev_name, dev_vram) ] + + return result + + if cpu_only: + return [ (i, 'CPU', 'CPU%d' % (i), 0 ) for i in range( min(8, multiprocessing.cpu_count() // 2) ) ] + + class Cli(Subprocessor.Cli): + + #override + def on_initialize(self, client_dict): + device_idx = client_dict['device_idx'] + cpu_only = client_dict['device_type'] == 'CPU' + self.output_dirpath = client_dict['output_dirpath'] + + device_config = nnlib.DeviceConfig ( cpu_only=cpu_only, force_gpu_idx=device_idx, allow_growth=True) + nnlib.import_all (device_config) + + device_vram = device_config.gpu_vram_gb[0] + + intro_str = 'Running on %s.' % (client_dict['device_name']) + if not cpu_only and device_vram <= 2: + intro_str += " Recommended to close all programs using this device." + + self.log_info (intro_str) + + from facelib import FaceEnhancer + self.fe = FaceEnhancer() + + #override + def process_data(self, filepath): + try: + dflimg = DFLIMG.load (filepath) + if dflimg is None: + self.log_err ("%s is not a dfl image file" % (filepath.name) ) + else: + img = cv2_imread(filepath).astype(np.float32) / 255.0 + + img = self.fe.enhance(img) + + img = np.clip (img*255, 0, 255).astype(np.uint8) + + output_filepath = self.output_dirpath / filepath.name + + cv2_imwrite ( str(output_filepath), img, [int(cv2.IMWRITE_JPEG_QUALITY), 100] ) + dflimg.embed_and_set ( str(output_filepath) ) + return (1, filepath, output_filepath) + except: + self.log_err (f"Exception occured while processing file {filepath}. Error: {traceback.format_exc()}") + + return (0, filepath, None) + +def process_folder ( dirpath, multi_gpu=False, cpu_only=False ): + output_dirpath = dirpath.parent / (dirpath.name + '_enhanced') + output_dirpath.mkdir (exist_ok=True, parents=True) + + dirpath_parts = '/'.join( dirpath.parts[-2:]) + output_dirpath_parts = '/'.join( output_dirpath.parts[-2:] ) + io.log_info (f"Enhancing faceset in {dirpath_parts}.") + io.log_info ( f"Processing to {output_dirpath_parts}.") + + output_images_paths = Path_utils.get_image_paths(output_dirpath) + if len(output_images_paths) > 0: + for filename in output_images_paths: + Path(filename).unlink() + + image_paths = [Path(x) for x in Path_utils.get_image_paths( dirpath )] + result = FacesetEnhancerSubprocessor ( image_paths, output_dirpath, multi_gpu=multi_gpu, cpu_only=cpu_only).run() + + io.log_info (f"Copying processed files to {dirpath_parts}.") + + for (filepath, output_filepath) in result: + shutil.copy (output_filepath, filepath) + + io.log_info (f"Removing {output_dirpath_parts}.") + shutil.rmtree(output_dirpath) diff --git a/nnlib/nnlib.py b/nnlib/nnlib.py index c5af230..5952b70 100644 --- a/nnlib/nnlib.py +++ b/nnlib/nnlib.py @@ -28,7 +28,8 @@ class nnlib(object): tf = None tf_sess = None - + tf_sess_config = None + PML = None PMLK = None PMLTile= None @@ -105,6 +106,7 @@ PixelShuffler = nnlib.PixelShuffler SubpixelUpscaler = nnlib.SubpixelUpscaler SubpixelDownscaler = nnlib.SubpixelDownscaler Scale = nnlib.Scale +BilinearInterpolation = nnlib.BilinearInterpolation BlurPool = nnlib.BlurPool FUNITAdain = nnlib.FUNITAdain SelfAttention = nnlib.SelfAttention @@ -192,7 +194,8 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator config.gpu_options.force_gpu_compatible = True config.gpu_options.allow_growth = device_config.allow_growth - + nnlib.tf_sess_config = config + nnlib.tf_sess = tf.Session(config=config) if suppressor is not None: @@ -710,6 +713,141 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator base_config = super(Scale, self).get_config() return dict(list(base_config.items()) + list(config.items())) nnlib.Scale = Scale + + + """ + unable to work in plaidML, due to unimplemented ops + + class BilinearInterpolation(KL.Layer): + def __init__(self, size=(2,2), **kwargs): + self.size = size + super(BilinearInterpolation, self).__init__(**kwargs) + + def compute_output_shape(self, input_shape): + return (input_shape[0], input_shape[1]*self.size[1], input_shape[2]*self.size[0], input_shape[3]) + + + def call(self, X): + _,h,w,_ = K.int_shape(X) + + X = K.concatenate( [ X, X[:,:,-2:-1,:] ],axis=2 ) + X = K.concatenate( [ X, X[:,:,-2:-1,:] ],axis=2 ) + X = K.concatenate( [ X, X[:,-2:-1,:,:] ],axis=1 ) + X = K.concatenate( [ X, X[:,-2:-1,:,:] ],axis=1 ) + + X_sh = K.shape(X) + batch_size, height, width, num_channels = X_sh[0], X_sh[1], X_sh[2], X_sh[3] + + output_h, output_w = (h*self.size[1]+4, w*self.size[0]+4) + + x_linspace = np.linspace(-1. , 1. - 2/output_w, output_w)# + y_linspace = np.linspace(-1. , 1. - 2/output_h, output_h)# + + x_coordinates, y_coordinates = np.meshgrid(x_linspace, y_linspace) + x_coordinates = K.flatten(K.constant(x_coordinates, dtype=K.floatx() )) + y_coordinates = K.flatten(K.constant(y_coordinates, dtype=K.floatx() )) + + grid = K.concatenate([x_coordinates, y_coordinates, K.ones_like(x_coordinates)], 0) + grid = K.flatten(grid) + + + grids = K.tile(grid, ( batch_size, ) ) + grids = K.reshape(grids, (batch_size, 3, output_h * output_w )) + + + x = K.cast(K.flatten(grids[:, 0:1, :]), dtype='float32') + y = K.cast(K.flatten(grids[:, 1:2, :]), dtype='float32') + x = .5 * (x + 1.0) * K.cast(width, dtype='float32') + y = .5 * (y + 1.0) * K.cast(height, dtype='float32') + x0 = K.cast(x, 'int32') + x1 = x0 + 1 + y0 = K.cast(y, 'int32') + y1 = y0 + 1 + max_x = int(K.int_shape(X)[2] -1) + max_y = int(K.int_shape(X)[1] -1) + + x0 = K.clip(x0, 0, max_x) + x1 = K.clip(x1, 0, max_x) + y0 = K.clip(y0, 0, max_y) + y1 = K.clip(y1, 0, max_y) + + + pixels_batch = K.constant ( np.arange(0, batch_size) * (height * width), dtype=K.floatx() ) + + pixels_batch = K.expand_dims(pixels_batch, axis=-1) + + base = K.tile(pixels_batch, (1, output_h * output_w ) ) + base = K.flatten(base) + + # base_y0 = base + (y0 * width) + base_y0 = y0 * width + base_y0 = base + base_y0 + # base_y1 = base + (y1 * width) + base_y1 = y1 * width + base_y1 = base_y1 + base + + indices_a = base_y0 + x0 + indices_b = base_y1 + x0 + indices_c = base_y0 + x1 + indices_d = base_y1 + x1 + + flat_image = K.reshape(X, (-1, num_channels) ) + flat_image = K.cast(flat_image, dtype='float32') + pixel_values_a = K.gather(flat_image, indices_a) + pixel_values_b = K.gather(flat_image, indices_b) + pixel_values_c = K.gather(flat_image, indices_c) + pixel_values_d = K.gather(flat_image, indices_d) + + x0 = K.cast(x0, 'float32') + x1 = K.cast(x1, 'float32') + y0 = K.cast(y0, 'float32') + y1 = K.cast(y1, 'float32') + + area_a = K.expand_dims(((x1 - x) * (y1 - y)), 1) + area_b = K.expand_dims(((x1 - x) * (y - y0)), 1) + area_c = K.expand_dims(((x - x0) * (y1 - y)), 1) + area_d = K.expand_dims(((x - x0) * (y - y0)), 1) + + values_a = area_a * pixel_values_a + values_b = area_b * pixel_values_b + values_c = area_c * pixel_values_c + values_d = area_d * pixel_values_d + interpolated_image = values_a + values_b + values_c + values_d + + new_shape = (batch_size, output_h, output_w, num_channels) + interpolated_image = K.reshape(interpolated_image, new_shape) + + interpolated_image = interpolated_image[:,:-4,:-4,:] + return interpolated_image + + def get_config(self): + config = {"size": self.size} + base_config = super(BilinearInterpolation, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + """ + class BilinearInterpolation(KL.Layer): + def __init__(self, size=(2,2), **kwargs): + self.size = size + super(BilinearInterpolation, self).__init__(**kwargs) + + def compute_output_shape(self, input_shape): + return (input_shape[0], input_shape[1]*self.size[1], input_shape[2]*self.size[0], input_shape[3]) + + def call(self, X): + _,h,w,_ = K.int_shape(X) + + return K.cast( K.tf.image.resize_images(X, (h*self.size[1],w*self.size[0]) ), K.floatx() ) + + def get_config(self): + config = {"size": self.size} + base_config = super(BilinearInterpolation, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + nnlib.BilinearInterpolation = BilinearInterpolation + + + + class SelfAttention(KL.Layer): def __init__(self, nc, squeeze_factor=8, **kwargs):