diff --git a/core/leras/models.py b/core/leras/models.py
index 8dedab9..c6c90ba 100644
--- a/core/leras/models.py
+++ b/core/leras/models.py
@@ -9,6 +9,7 @@ def initialize_models(nn):
         def __init__(self, *args, name=None, **kwargs):
             super().__init__(name=name)
             self.layers = []
+            self.layers_by_name = {}
             self.built = False
             self.args = args
             self.kwargs = kwargs
@@ -31,7 +32,8 @@ def initialize_models(nn):
                             layer.build()

                         self.layers.append (layer)
-
+                        self.layers_by_name[layer.name] = layer
+
         def xor_list(self, lst1, lst2):
             return [value for value in lst1+lst2 if (value not in lst1) or (value not in lst2) ]

@@ -76,6 +78,9 @@ def initialize_models(nn):
                 weights += layer.get_weights()
             return weights

+        def get_layer_by_name(self, name):
+            return self.layers_by_name.get(name, None)
+
         def get_layers(self):
             if not self.built:
                 self.build()
diff --git a/facelib/TernausNet.py b/facelib/TernausNet.py
index 94bde01..00a807e 100644
--- a/facelib/TernausNet.py
+++ b/facelib/TernausNet.py
@@ -24,44 +24,44 @@ class TernausNet(object):
         tf = nn.tf

         class Ternaus(nn.ModelBase):
-            def on_build(self, in_ch, ch):
+            def on_build(self, in_ch, base_ch):

-                self.features_0 = nn.Conv2D (in_ch, ch, kernel_size=3, padding='SAME')
+                self.features_0 = nn.Conv2D (in_ch, base_ch, kernel_size=3, padding='SAME')
                 self.blurpool_0 = nn.BlurPool (filt_size=3)

-                self.features_3 = nn.Conv2D (ch, ch*2, kernel_size=3, padding='SAME')
+                self.features_3 = nn.Conv2D (base_ch, base_ch*2, kernel_size=3, padding='SAME')
                 self.blurpool_3 = nn.BlurPool (filt_size=3)

-                self.features_6 = nn.Conv2D (ch*2, ch*4, kernel_size=3, padding='SAME')
-                self.features_8 = nn.Conv2D (ch*4, ch*4, kernel_size=3, padding='SAME')
+                self.features_6 = nn.Conv2D (base_ch*2, base_ch*4, kernel_size=3, padding='SAME')
+                self.features_8 = nn.Conv2D (base_ch*4, base_ch*4, kernel_size=3, padding='SAME')
                 self.blurpool_8 = nn.BlurPool (filt_size=3)

-                self.features_11 = nn.Conv2D (ch*4, ch*8, kernel_size=3, padding='SAME')
-                self.features_13 = nn.Conv2D (ch*8, ch*8, kernel_size=3, padding='SAME')
+                self.features_11 = nn.Conv2D (base_ch*4, base_ch*8, kernel_size=3, padding='SAME')
+                self.features_13 = nn.Conv2D (base_ch*8, base_ch*8, kernel_size=3, padding='SAME')
                 self.blurpool_13 = nn.BlurPool (filt_size=3)

-                self.features_16 = nn.Conv2D (ch*8, ch*8, kernel_size=3, padding='SAME')
-                self.features_18 = nn.Conv2D (ch*8, ch*8, kernel_size=3, padding='SAME')
+                self.features_16 = nn.Conv2D (base_ch*8, base_ch*8, kernel_size=3, padding='SAME')
+                self.features_18 = nn.Conv2D (base_ch*8, base_ch*8, kernel_size=3, padding='SAME')
                 self.blurpool_18 = nn.BlurPool (filt_size=3)

-                self.conv_center = nn.Conv2D (ch*8, ch*8, kernel_size=3, padding='SAME')
+                self.conv_center = nn.Conv2D (base_ch*8, base_ch*8, kernel_size=3, padding='SAME')

-                self.conv1_up = nn.Conv2DTranspose (ch*8, ch*4, kernel_size=3, padding='SAME')
-                self.conv1 = nn.Conv2D (ch*12, ch*8, kernel_size=3, padding='SAME')
+                self.conv1_up = nn.Conv2DTranspose (base_ch*8, base_ch*4, kernel_size=3, padding='SAME')
+                self.conv1 = nn.Conv2D (base_ch*12, base_ch*8, kernel_size=3, padding='SAME')

-                self.conv2_up = nn.Conv2DTranspose (ch*8, ch*4, kernel_size=3, padding='SAME')
-                self.conv2 = nn.Conv2D (ch*12, ch*8, kernel_size=3, padding='SAME')
+                self.conv2_up = nn.Conv2DTranspose (base_ch*8, base_ch*4, kernel_size=3, padding='SAME')
+                self.conv2 = nn.Conv2D (base_ch*12, base_ch*8, kernel_size=3, padding='SAME')

-                self.conv3_up = nn.Conv2DTranspose (ch*8, ch*2, kernel_size=3, padding='SAME')
-                self.conv3 = nn.Conv2D (ch*6, ch*4, kernel_size=3, padding='SAME')
+                self.conv3_up = nn.Conv2DTranspose (base_ch*8, base_ch*2, kernel_size=3, padding='SAME')
+                self.conv3 = nn.Conv2D (base_ch*6, base_ch*4, kernel_size=3, padding='SAME')

-                self.conv4_up = nn.Conv2DTranspose (ch*4, ch, kernel_size=3, padding='SAME')
-                self.conv4 = nn.Conv2D (ch*3, ch*2, kernel_size=3, padding='SAME')
+                self.conv4_up = nn.Conv2DTranspose (base_ch*4, base_ch, kernel_size=3, padding='SAME')
+                self.conv4 = nn.Conv2D (base_ch*3, base_ch*2, kernel_size=3, padding='SAME')

-                self.conv5_up = nn.Conv2DTranspose (ch*2, ch//2, kernel_size=3, padding='SAME')
-                self.conv5 = nn.Conv2D (ch//2+ch, ch, kernel_size=3, padding='SAME')
+                self.conv5_up = nn.Conv2DTranspose (base_ch*2, base_ch//2, kernel_size=3, padding='SAME')
+                self.conv5 = nn.Conv2D (base_ch//2+base_ch, base_ch, kernel_size=3, padding='SAME')

-                self.out_conv = nn.Conv2D (ch, 1, kernel_size=3, padding='SAME')
+                self.out_conv = nn.Conv2D (base_ch, 1, kernel_size=3, padding='SAME')

             def forward(self, inp):
                 x, = inp
@@ -106,92 +106,120 @@ class TernausNet(object):
                 x = tf.concat( [x,x0], -1)
                 x = tf.nn.relu(self.conv5(x))

-                x = tf.nn.sigmoid(self.out_conv(x))
-                return x
+                logits = self.out_conv(x)
+                return logits, tf.nn.sigmoid(logits)

         if weights_file_root is not None:
             weights_file_root = Path(weights_file_root)
         else:
             weights_file_root = Path(__file__).parent
-        self.weights_path = weights_file_root / ('%s_%d_%s.npy' % (name, resolution, face_type_str) )
+        self.weights_file_root = weights_file_root

-        e = tf.device('/CPU:0') if place_model_on_cpu else None
+        with tf.device ('/CPU:0'):
+            # Placeholders on CPU
+            self.input_t  = tf.placeholder (nn.tf_floatx, nn.get4Dshape(resolution,resolution,3) )
+            self.target_t = tf.placeholder (nn.tf_floatx, nn.get4Dshape(resolution,resolution,1) )

-        if e is not None: e.__enter__()
-        self.net = Ternaus(3, 64, name='Ternaus')
-        if load_weights:
-            self.net.load_weights (self.weights_path)
+        # Initializing model classes
+        with tf.device ('/CPU:0' if place_model_on_cpu else '/GPU:0'):
+            self.net = Ternaus(3, 64, name='Ternaus')
+            self.net_weights = self.net.get_weights()
+
+        self.model_filename_list = [ [self.net, '%s_%d_%s.npy' % (name, resolution, face_type_str) ] ]
+
+        if training:
+            self.opt = nn.TFRMSpropOptimizer(lr=0.0001, name='opt')
+            self.opt.initialize_variables (self.net_weights, vars_on_cpu=place_model_on_cpu)
+            self.model_filename_list += [ [self.opt, '%s_%d_%s_opt.npy' % (name, resolution, face_type_str) ] ]
         else:
-            self.net.init_weights()
-        if e is not None: e.__exit__(None,None,None)
+            _, pred = self.net([self.input_t])
+            def net_run(input_np):
+                return nn.tf_sess.run ( [pred], feed_dict={self.input_t :input_np})[0]
+            self.net_run = net_run

-        self.net.build_for_run ( [(tf.float32, nn.get4Dshape (resolution,resolution,3) )] )
+        # Loading/initializing all models/optimizers weights
+        for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
+            do_init = not load_weights

-        if training:
-            raise Exception("training not supported yet")
+            if not do_init:
+                do_init = not model.load_weights( self.weights_file_root / filename )

+            if do_init:
+                model.init_weights()
+                if model == self.net:
+                    try:
+                        with open( Path(__file__).parent / 'vgg11_enc_weights.npy', 'rb' ) as f:
+                            d = pickle.loads (f.read())

-        """
-        if training:
-            try:
-                with open( Path(__file__).parent / 'vgg11_enc_weights.npy', 'rb' ) as f:
-                    d = pickle.loads (f.read())
-
-                for i in [0,3,6,8,11,13,16,18]:
-                    s = 'features.%d' % i
-
-                    self.model.get_layer (s).set_weights ( d[s] )
-            except:
-                io.log_err("Unable to load VGG11 pretrained weights from vgg11_enc_weights.npy")
-
-            conv_weights_list = []
-            for layer in self.model.layers:
-                if 'CA.' in layer.name:
-                    conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights
-            CAInitializerMP ( conv_weights_list )
-        """
-
-
-        """
-        if training:
-            inp_t = Input ( (resolution, resolution, 3) )
-            real_t = Input ( (resolution, resolution, 1) )
-            out_t = self.model(inp_t)
-
-            loss = K.mean(10*K.binary_crossentropy(real_t,out_t) )
-
-            out_t_diff1 = out_t[:, 1:, :, :] - out_t[:, :-1, :, :]
-            out_t_diff2 = out_t[:, :, 1:, :] - out_t[:, :, :-1, :]
-
-            total_var_loss = K.mean( 0.1*K.abs(out_t_diff1), axis=[1, 2, 3] ) + K.mean( 0.1*K.abs(out_t_diff2), axis=[1, 2, 3] )
-
-            opt = Adam(lr=0.0001, beta_1=0.5, beta_2=0.999, tf_cpu_mode=2)
-
-            self.train_func = K.function ( [inp_t, real_t], [K.mean(loss)], opt.get_updates( [loss], self.model.trainable_weights) )
-        """
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
-        return False #pass exception between __enter__ and __exit__ to outter level
+                        for i in [0,3,6,8,11,13,16,18]:
+                            model.get_layer_by_name ('features_%d' % i).set_weights ( d['features.%d' % i] )
+                    except:
+                        io.log_err("Unable to load VGG11 pretrained weights from vgg11_enc_weights.npy")

     def save_weights(self):
-        self.net.save_weights (str(self.weights_path))
-
-    def train(self, inp, real):
-        loss, = self.train_func ([inp, real])
-        return loss
+        for model, filename in io.progress_bar_generator(self.model_filename_list, "Saving"):
+            model.save_weights( self.weights_file_root / filename )

     def extract (self, input_image):
         input_shape_len = len(input_image.shape)
         if input_shape_len == 3:
-            input_image = input_image[np.newaxis,...]
+            input_image = input_image[None,...]

-        result = np.clip ( self.net.run([input_image]), 0, 1.0 )
+        result = np.clip ( self.net_run(input_image), 0, 1.0 )
         result[result < 0.1] = 0 #get rid of noise

         if input_shape_len == 3:
             result = result[0]

         return result
+
+"""
+if load_weights:
+    self.net.load_weights (self.weights_path)
+else:
+    self.net.init_weights()
+
+if load_weights:
+    self.opt.load_weights (self.opt_path)
+else:
+    self.opt.init_weights()
+"""
+"""
+if training:
+    try:
+        with open( Path(__file__).parent / 'vgg11_enc_weights.npy', 'rb' ) as f:
+            d = pickle.loads (f.read())
+
+        for i in [0,3,6,8,11,13,16,18]:
+            s = 'features.%d' % i
+
+            self.model.get_layer (s).set_weights ( d[s] )
+    except:
+        io.log_err("Unable to load VGG11 pretrained weights from vgg11_enc_weights.npy")
+
+    conv_weights_list = []
+    for layer in self.model.layers:
+        if 'CA.' in layer.name:
+            conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights
+    CAInitializerMP ( conv_weights_list )
+"""
+
+
+"""
+if training:
+    inp_t = Input ( (resolution, resolution, 3) )
+    real_t = Input ( (resolution, resolution, 1) )
+    out_t = self.model(inp_t)
+
+    loss = K.mean(10*K.binary_crossentropy(real_t,out_t) )
+
+    out_t_diff1 = out_t[:, 1:, :, :] - out_t[:, :-1, :, :]
+    out_t_diff2 = out_t[:, :, 1:, :] - out_t[:, :, :-1, :]
+
+    total_var_loss = K.mean( 0.1*K.abs(out_t_diff1), axis=[1, 2, 3] ) + K.mean( 0.1*K.abs(out_t_diff2), axis=[1, 2, 3] )
+
+    opt = Adam(lr=0.0001, beta_1=0.5, beta_2=0.999, tf_cpu_mode=2)
+
+    self.train_func = K.function ( [inp_t, real_t], [K.mean(loss)], opt.get_updates( [loss], self.model.trainable_weights) )
+"""
diff --git a/mainscripts/dev_misc.py b/mainscripts/dev_misc.py
index 05b6b62..1e47a08 100644
--- a/mainscripts/dev_misc.py
+++ b/mainscripts/dev_misc.py
@@ -1,3 +1,4 @@
+import json
 import multiprocessing
 import shutil
 from pathlib import Path
@@ -5,12 +6,14 @@ from pathlib import Path
 import cv2
 import numpy as np

-from DFLIMG import *
-from facelib import FaceType, LandmarksProcessor
+from core import imagelib, pathex
+from core.cv2ex import *
+from core.imagelib import IEPolys
 from core.interact import interact as io
 from core.joblib import Subprocessor
-from core import pathex, imagelib
-from core.cv2ex import *
+from core.leras import nn
+from DFLIMG import *
+from facelib import FaceType, LandmarksProcessor

 from . import Extractor, Sorter
 from .Extractor import ExtractSubprocessor
@@ -393,7 +396,8 @@ def extract_fanseg(input_dir, device_args={} ):
         data = ExtractSubprocessor ([ ExtractSubprocessor.Data(filename) for filename in paths_to_extract ], 'fanseg', multi_gpu=multi_gpu, cpu_only=cpu_only).run()

 #unused in end user workflow
-def dev_test(input_dir ):
+def dev_test_68(input_dir ):
+    # process 68 landmarks dataset with .pts files
     input_path = Path(input_dir)
     if not input_path.exists():
         raise ValueError('input_dir not found. Please ensure it exists.')
@@ -441,7 +445,7 @@ def dev_test(input_dir ):
                 continue

         rect = LandmarksProcessor.get_rect_from_landmarks(lmrks)
-
+
         output_filepath = output_path / (filepath.stem+'.jpg')

         img = cv2_imread(filepath)
@@ -549,3 +553,108 @@ def dev_test1(input_dir):

     #import code
     #code.interact(local=dict(globals(), **locals()))
+
+#unused in end user workflow
+def dev_test(input_dir ):
+    # extract and merge .json labelme files within the faces
+
+    device_config = nn.DeviceConfig.GPUIndexes( nn.ask_choose_device_idxs(suggest_all_gpu=True) )
+
+    input_path = Path(input_dir)
+    if not input_path.exists():
+        raise ValueError('input_dir not found. Please ensure it exists.')
+
+    output_path = input_path.parent / (input_path.name+'_merged')
+
+    io.log_info(f'Output dir is {output_path}')
+
+    if output_path.exists():
+        output_images_paths = pathex.get_image_paths(output_path)
+        if len(output_images_paths) > 0:
+            io.input_bool("WARNING !!! \n %s contains files! \n They will be deleted. \n Press enter to continue." % (str(output_path)), False )
+            for filename in output_images_paths:
+                Path(filename).unlink()
+    else:
+        output_path.mkdir(parents=True, exist_ok=True)
+
+    images_paths = pathex.get_image_paths(input_path)
+
+    extract_data = []
+
+    images_jsons = {}
+
+    for filepath in io.progress_bar_generator(images_paths, "Processing"):
+        filepath = Path(filepath)
+
+        json_filepath = filepath.parent / (filepath.stem+'.json')
+
+        if json_filepath.exists():
+            json_dict = json.loads(json_filepath.read_text())
+            images_jsons[filepath] = json_dict
+
+            total_points = [ [x,y] for shape in json_dict['shapes'] for x,y in shape['points'] ]
+            total_points = np.array(total_points)
+
+            l,r = int(total_points[:,0].min()), int(total_points[:,0].max())
+            t,b = int(total_points[:,1].min()), int(total_points[:,1].max())
+
+            extract_data.append ( ExtractSubprocessor.Data(filepath, rects=[ [l,t,r,b] ] ) )
+
+    image_size = 1024
+    face_type = FaceType.HEAD
+    extract_data = ExtractSubprocessor (extract_data, 'landmarks', image_size, face_type, device_config=device_config).run()
+    extract_data = ExtractSubprocessor (extract_data, 'final', image_size, face_type, final_output_path=output_path, device_config=device_config).run()
+
+    for data in extract_data:
+        filepath = output_path / (data.filepath.stem+'_0.jpg')
+
+        dflimg = DFLIMG.load(filepath)
+        image_to_face_mat = dflimg.get_image_to_face_mat()
+
+        json_dict = images_jsons[data.filepath]
+
+        ie_polys = IEPolys()
+        for shape in json_dict['shapes']:
+            ie_poly = ie_polys.add(1)
+
+            points = np.array( [ [x,y] for x,y in shape['points'] ] )
+            points = LandmarksProcessor.transform_points(points, image_to_face_mat)
+
+            for x,y in points:
+                ie_poly.add( int(x), int(y) )
+
+        dflimg.embed_and_set (filepath, ie_polys=ie_polys)
+
+    """
+    #mark only
+    for data in extract_data:
+        filepath = data.filepath
+        output_filepath = output_path / (filepath.stem+'.jpg')
+
+        img = cv2_imread(filepath)
+        img = imagelib.normalize_channels(img, 3)
+        cv2_imwrite(output_filepath, img, [int(cv2.IMWRITE_JPEG_QUALITY), 100] )
+
+        json_dict = images_jsons[filepath]
+
+        ie_polys = IEPolys()
+        for shape in json_dict['shapes']:
+            ie_poly = ie_polys.add(1)
+            for x,y in shape['points']:
+                ie_poly.add( int(x), int(y) )
+
+        DFLJPG.embed_data(output_filepath, face_type=FaceType.toString(FaceType.MARK_ONLY),
+                                           landmarks=data.landmarks[0],
+                                           ie_polys=ie_polys,
+                                           source_filename=filepath.name,
+                                           source_rect=data.rects[0],
+                                           source_landmarks=data.landmarks[0]
+                          )
+    """
+
+    io.log_info("Done.")
diff --git a/models/Model_FANSeg/Model.py b/models/Model_FANSeg/Model.py
new file mode 100644
index 0000000..d8babde
--- /dev/null
+++ b/models/Model_FANSeg/Model.py
@@ -0,0 +1,171 @@
+import multiprocessing
+import operator
+from functools import partial
+
+import numpy as np
+
+from core import mathlib
+from core.interact import interact as io
+from core.leras import nn
+from facelib import FaceType, TernausNet
+from models import ModelBase
+from samplelib import *
+
+class FANSegModel(ModelBase):
+
+    #override
+    def on_initialize_options(self):
+        device_config = nn.getCurrentDeviceConfig()
+        yn_str = {True:'y',False:'n'}
+
+        #default_resolution = 256
+        default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'wf')
+
+        ask_override = self.ask_override()
+        if self.is_first_run() or ask_override:
+            self.ask_autobackup_hour()
+            self.ask_write_preview_history()
+            self.ask_target_iter()
+            self.ask_batch_size(4)
+
+        if self.is_first_run():
+            #resolution = io.input_int("Resolution", default_resolution, add_info="64-512")
+            #resolution = np.clip ( (resolution // 16) * 16, 64, 512)
+            #self.options['resolution'] = resolution
+            self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf']).lower()
+
+    #override
+    def on_initialize(self):
+        device_config = nn.getCurrentDeviceConfig()
+        nn.initialize(data_format="NHWC")
+        tf = nn.tf
+
+        device_config = nn.getCurrentDeviceConfig()
+        devices = device_config.devices
+
+        self.resolution = resolution = 256 #self.options['resolution']
+        self.face_type = {'h'  : FaceType.HALF,
+                          'mf' : FaceType.MID_FULL,
+                          'f'  : FaceType.FULL,
+                          'wf' : FaceType.WHOLE_FACE}[ self.options['face_type'] ]
+
+        place_model_on_cpu = len(devices) == 0
+        models_opt_device = '/CPU:0' if place_model_on_cpu else '/GPU:0'
+
+        bgr_shape = nn.get4Dshape(resolution,resolution,3)
+        mask_shape = nn.get4Dshape(resolution,resolution,1)
+
+        # Initializing model classes
+        self.model = TernausNet('FANSeg',
+                                resolution,
+                                FaceType.toString(self.face_type),
+                                load_weights=not self.is_first_run(),
+                                weights_file_root=self.get_model_root_path(),
+                                training=True,
+                                place_model_on_cpu=place_model_on_cpu)
+
+        if self.is_training:
+            # Adjust batch size for multiple GPU
+            gpu_count = max(1, len(devices) )
+            bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
+            self.set_batch_size( gpu_count*bs_per_gpu)
+
+            # Compute losses per GPU
+            gpu_pred_list = []
+
+            gpu_losses = []
+            gpu_loss_gvs = []
+
+            for gpu_id in range(gpu_count):
+                with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
+
+                    with tf.device(f'/CPU:0'):
+                        # slice on CPU, otherwise all batch data will be transferred to GPU first
+                        batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
+                        gpu_input_t  = self.model.input_t [batch_slice,:,:,:]
+                        gpu_target_t = self.model.target_t [batch_slice,:,:,:]
+
+                    # process model tensors
+                    gpu_pred_logits_t, gpu_pred_t = self.model.net([gpu_input_t])
+                    gpu_pred_list.append(gpu_pred_t)
+
+                    gpu_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=gpu_target_t, logits=gpu_pred_logits_t), axis=[1,2,3])
+                    gpu_losses += [gpu_loss]
+
+                    gpu_loss_gvs += [ nn.tf_gradients ( gpu_loss, self.model.net_weights ) ]
+
+            # Average losses and gradients, and create optimizer update ops
+            with tf.device (models_opt_device):
+                pred = nn.tf_concat(gpu_pred_list, 0)
+                loss = tf.reduce_mean(gpu_losses)
+
+                loss_gv_op = self.model.opt.get_update_op (nn.tf_average_gv_list (gpu_loss_gvs))
+
+            # Initializing training and view functions
+            def train(input_np, target_np):
+                l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np })
+                return l
+            self.train = train
+
+            def view(input_np):
+                return nn.tf_sess.run ( [pred], feed_dict={self.model.input_t :input_np})
+            self.view = view
+
+            # initializing sample generators
+            training_data_src_path = self.training_data_src_path
+            #training_data_dst_path = self.training_data_dst_path
+
+            cpu_count = min(multiprocessing.cpu_count(), 8)
+            src_generators_count = cpu_count // 2
+            dst_generators_count = cpu_count // 2
+            src_generators_count = int(src_generators_count * 1.5)
+
+            self.set_training_data_generators ([
+                    SampleGeneratorFace(training_data_src_path, random_ct_samples_path=training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
+                        sample_process_options=SampleProcessor.Options(random_flip=True),
+                        output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'ct_mode':'idt', 'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'motion_blur':(25, 5), 'gaussian_blur':(25,5), 'data_format':nn.data_format, 'resolution': resolution},
+                                                {'sample_type': SampleProcessor.SampleType.FACE_MASK,  'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.NONE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
+                                              ],
+                        generators_count=src_generators_count ),
+                ])
+
+    #override
+    def get_model_filename_list(self):
+        return self.model.model_filename_list
+
+    #override
+    def onSave(self):
+        self.model.save_weights()
+
+    #override
+    def onTrainOneIter(self):
+        ( (source_np, target_np), ) = self.generate_next_samples()
+        loss = self.train (source_np, target_np)
+
+        return ( ('loss', loss ), )
+
+    #override
+    def onGetPreview(self, samples):
+        n_samples = min(4, self.get_batch_size(), 800 // self.resolution )
+
+        ( (source_np, target_np), ) = samples
+
+        S, T, SM, = [ np.clip(x, 0.0, 1.0) for x in ([source_np,target_np] + self.view (source_np) ) ]
+        T, SM, = [ np.repeat (x, (3,), -1) for x in [T, SM] ]
+
+        result = []
+
+        st = []
+        for i in range(n_samples):
+            ar = S[i], T[i], SM[i], S[i]*SM[i]
+            #todo green bg
+            st.append ( np.concatenate ( ar, axis=1) )
+        result += [ ('FANSeg', np.concatenate (st, axis=0 )), ]
+
+        return result
+
+Model = FANSegModel
diff --git a/models/Model_FANSeg/__init__.py b/models/Model_FANSeg/__init__.py
new file mode 100644
index 0000000..0188f11
--- /dev/null
+++ b/models/Model_FANSeg/__init__.py
@@ -0,0 +1 @@
+from .Model import Model
diff --git a/samplelib/SampleProcessor.py b/samplelib/SampleProcessor.py
index 0dbd113..d14759d 100644
--- a/samplelib/SampleProcessor.py
+++ b/samplelib/SampleProcessor.py
@@ -148,7 +148,28 @@ class SampleProcessor(object):
                         raise ValueError("only channel_type.G supported for the mask")

                 elif sample_type == SPST.FACE_IMAGE:
-                    img = sample_bgr
+                    img = sample_bgr
+
+                    if sample_face_type == FaceType.MARK_ONLY:
+                        mat = LandmarksProcessor.get_transform_mat (sample_landmarks, warp_resolution, face_type)
+                        img = cv2.warpAffine( img, mat, (warp_resolution,warp_resolution), flags=cv2.INTER_CUBIC )
+                        img = imagelib.warp_by_params (params, img, warp, transform, can_flip=True, border_replicate=border_replicate)
+                        img = cv2.resize( img, (resolution,resolution), cv2.INTER_CUBIC )
+                    else:
+                        img = imagelib.warp_by_params (params, img, warp, transform, can_flip=True, border_replicate=border_replicate)
+                        mat = LandmarksProcessor.get_transform_mat (sample_landmarks, resolution, face_type)
+                        img = cv2.warpAffine( img, mat, (resolution,resolution), borderMode=borderMode, flags=cv2.INTER_CUBIC )
+
+                    img = np.clip(img.astype(np.float32), 0, 1)
+
+                    # Apply random color transfer
+                    if ct_mode is not None and ct_sample is not None:
+                        if ct_sample_bgr is None:
+                            ct_sample_bgr = ct_sample.load_bgr()
+                        img = imagelib.color_transfer (ct_mode, img, cv2.resize( ct_sample_bgr, (resolution,resolution), cv2.INTER_LINEAR ) )
+
                     if motion_blur is not None:
                         chance, mb_max_size = motion_blur
                         chance = np.clip(chance, 0, 100)
@@ -171,25 +192,7 @@ class SampleProcessor(object):
                         if gblur_rnd_chance < chance:
                             img = cv2.GaussianBlur(img, (gblur_rnd_kernel,) *2 , 0)

-
-                    if sample_face_type == FaceType.MARK_ONLY:
-                        mat = LandmarksProcessor.get_transform_mat (sample_landmarks, warp_resolution, face_type)
-                        img = cv2.warpAffine( img, mat, (warp_resolution,warp_resolution), flags=cv2.INTER_CUBIC )
-                        img = imagelib.warp_by_params (params, img, warp, transform, can_flip=True, border_replicate=border_replicate)
-                        img = cv2.resize( img, (resolution,resolution), cv2.INTER_CUBIC )
-                    else:
-                        img = imagelib.warp_by_params (params, img, warp, transform, can_flip=True, border_replicate=border_replicate)
-                        mat = LandmarksProcessor.get_transform_mat (sample_landmarks, resolution, face_type)
-                        img = cv2.warpAffine( img, mat, (resolution,resolution), borderMode=borderMode, flags=cv2.INTER_CUBIC )
-
-                    img = np.clip(img.astype(np.float32), 0, 1)
-
-                    # Apply random color transfer
-                    if ct_mode is not None and ct_sample is not None:
-                        if ct_sample_bgr is None:
-                            ct_sample_bgr = ct_sample.load_bgr()
-                        img = imagelib.color_transfer (ct_mode, img, cv2.resize( ct_sample_bgr, (resolution,resolution), cv2.INTER_LINEAR ) )
-
+
                     # Transform from BGR to desired channel_type
                     if channel_type == SPCT.BGR:
                         out_sample = img
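
The layers_by_name registry added to core/leras/models.py is what lets TernausNet map the torchvision-style keys in vgg11_enc_weights.npy ('features.N') onto its own layer names ('features_N'). A minimal sketch of that loading step, assuming a built nn.ModelBase subclass and the pickled dict layout used above:

    import pickle
    from pathlib import Path

    def load_vgg11_encoder_weights(model, npy_path):
        # vgg11_enc_weights.npy holds a pickled dict keyed 'features.N';
        # the corresponding Ternaus layers are named 'features_N'.
        with open(Path(npy_path), 'rb') as f:
            d = pickle.loads(f.read())
        for i in [0, 3, 6, 8, 11, 13, 16, 18]:
            layer = model.get_layer_by_name('features_%d' % i)
            if layer is not None:  # the registry returns None for unknown names
                layer.set_weights(d['features.%d' % i])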
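
Returning raw logits from Ternaus.forward alongside the sigmoid lets the trainer feed tf.nn.sigmoid_cross_entropy_with_logits, which folds the sigmoid into the loss instead of computing log(sigmoid(x)) on an already-saturated output. A NumPy sketch of the stable identity that op evaluates:

    import numpy as np

    def bce_with_logits(x, z):
        # -z*log(sigmoid(x)) - (1-z)*log(1-sigmoid(x)) rewritten as
        # max(x, 0) - x*z + log(1 + exp(-|x|)), which cannot overflow for large |x|.
        return np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x)))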
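
In dev_test, each image's extraction rect is simply the bounding box over every labelme polygon vertex in its .json. A standalone sketch of that step, assuming labelme's usual 'shapes'/'points' layout:

    import json
    from pathlib import Path

    import numpy as np

    def labelme_bounding_rect(json_path):
        # Enclosing (l, t, r, b) rect over all polygon vertices in a labelme .json.
        d = json.loads(Path(json_path).read_text())
        pts = np.array([p for shape in d['shapes'] for p in shape['points']])
        l, r = int(pts[:, 0].min()), int(pts[:, 0].max())
        t, b = int(pts[:, 1].min()), int(pts[:, 1].max())
        return l, t, r, b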
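
The FANSeg trainer slices the input batch on the CPU, builds one loss/gradient tower per GPU, and averages the per-tower (grad, var) lists before a single optimizer update. nn.tf_average_gv_list is the repo's helper for that; the sketch below shows the averaging it implies (TF1 graph code, tf in scope, every tower listing variables in the same order):

    def average_gv_list(towers):
        # towers: one [(grad, var), ...] list per GPU.
        averaged = []
        for pairs in zip(*towers):
            grads = [g for g, _ in pairs]
            var = pairs[0][1]  # same variable in every tower
            averaged.append((tf.add_n(grads) / len(grads), var))
        return averaged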
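
The SampleProcessor hunks add no new logic: they move the align/warp, clip, and color-transfer steps ahead of the blur steps, so motion and gaussian blur now land last, on the aligned and color-transferred face. The resulting FACE_IMAGE pipeline, as a sketch (helper names here are illustrative, not the module's API):

    import numpy as np

    def augment_face_image(img):
        img = warp_and_align(img)        # warp_by_params + cv2.warpAffine into face space
        img = np.clip(img.astype(np.float32), 0, 1)
        img = maybe_color_transfer(img)  # ct_mode / ct_sample, optional
        img = maybe_motion_blur(img)     # chance-gated, random kernel size
        img = maybe_gaussian_blur(img)   # chance-gated, odd kernel
        return img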