diff --git a/core/leras/archis/DeepFakeArchi.py b/core/leras/archis/DeepFakeArchi.py
index 5f49258..b86ae2c 100644
--- a/core/leras/archis/DeepFakeArchi.py
+++ b/core/leras/archis/DeepFakeArchi.py
@@ -25,7 +25,7 @@ class DeepFakeArchi(nn.ArchiBase):
                 super().__init__(*kwargs)
 
             def on_build(self, *args, **kwargs ):
-                self.conv1 = nn.SeparableConv2D( self.in_ch,
+                self.conv1 = nn.Conv2D( self.in_ch,
                                           self.out_ch // (4 if self.subpixel else 1),
                                           kernel_size=self.kernel_size,
                                           strides=1 if self.subpixel else 2,
@@ -60,7 +60,7 @@ class DeepFakeArchi(nn.ArchiBase):
 
         class Upscale(nn.ModelBase):
             def on_build(self, in_ch, out_ch, kernel_size=3 ):
-                self.conv1 = nn.SeparableConv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
+                self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
 
             def forward(self, x):
                 x = self.conv1(x)
@@ -70,8 +70,8 @@ class DeepFakeArchi(nn.ArchiBase):
 
         class ResidualBlock(nn.ModelBase):
             def on_build(self, ch, kernel_size=3 ):
-                self.conv1 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
-                self.conv2 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
+                self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
+                self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
 
             def forward(self, inp):
                 x = self.conv1(inp)
@@ -161,12 +161,12 @@ class DeepFakeArchi(nn.ArchiBase):
                 self.res1 = ResidualBlock(d_ch*4, kernel_size=3)
                 self.res2 = ResidualBlock(d_ch*2, kernel_size=3)
 
-                self.out_conv  = nn.SeparableConv2D( d_ch*2, 3, kernel_size=1, padding='SAME')
+                self.out_conv  = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME')
 
                 self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
                 self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
                 self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
-                self.out_convm = nn.SeparableConv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME')
+                self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME')
 
             def forward(self, inp):
                 z = inp
@@ -211,7 +211,7 @@ class DeepFakeArchi(nn.ArchiBase):
                 super().__init__(*kwargs)
 
             def on_build(self, *args, **kwargs ):
-                self.conv1 = nn.SeparableConv2D( self.in_ch,
+                self.conv1 = nn.Conv2D( self.in_ch,
                                           self.out_ch // (4 if self.subpixel else 1),
                                           kernel_size=self.kernel_size,
                                           strides=1 if self.subpixel else 2,
@@ -248,7 +248,7 @@ class DeepFakeArchi(nn.ArchiBase):
 
         class Upscale(nn.ModelBase):
             def on_build(self, in_ch, out_ch, kernel_size=3 ):
-                self.conv1 = nn.SeparableConv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
+                self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
 
             def forward(self, x):
                 x = self.conv1(x)
@@ -258,8 +258,8 @@ class DeepFakeArchi(nn.ArchiBase):
 
         class ResidualBlock(nn.ModelBase):
             def on_build(self, ch, kernel_size=3 ):
-                self.conv1 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
-                self.conv2 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
+                self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
+                self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
 
             def forward(self, inp):
                 x = self.conv1(inp)
@@ -314,8 +314,8 @@ class DeepFakeArchi(nn.ArchiBase):
                 self.upscalem2 = Upscale(d_ch, d_ch//2)
                 self.upscalem3 = Upscale(d_ch//2, d_ch//2)
 
-                self.out_conv  = nn.SeparableConv2D( d_ch*1, 3, kernel_size=1, padding='SAME')
-                self.out_convm = nn.SeparableConv2D( d_ch//2, 1, kernel_size=1, padding='SAME')
+                self.out_conv  = nn.Conv2D( d_ch*1, 3, kernel_size=1, padding='SAME')
+                self.out_convm = nn.Conv2D( d_ch//2, 1, kernel_size=1, padding='SAME')
 
             def forward(self, inp):
                 z = inp
@@ -346,7 +346,6 @@ class DeepFakeArchi(nn.ArchiBase):
                 super().__init__(*kwargs)
 
             def on_build(self, *args, **kwargs ):
-                self.conv1 = nn.SeparableConv2D( self.in_ch,
+                self.conv1 = nn.Conv2D( self.in_ch,
                                           self.out_ch // (4 if self.subpixel else 1),
                                           kernel_size=self.kernel_size,
-                                          depth_multiplier=self.depth_multiplier,
@@ -443,7 +443,7 @@ class DeepFakeArchi(nn.ArchiBase):
 
         class Upscale(nn.ModelBase):
             def on_build(self, in_ch, out_ch, kernel_size=3, depth_multiplier=1 ):
-                self.conv1 = nn.SeparableConv2D( in_ch, out_ch*4, kernel_size=kernel_size, depth_multiplier=depth_multiplier, padding='SAME')
+                self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
                 #self.frn1 = nn.FRNorm2D(out_ch*4)
                 #self.tlu1 = nn.TLU(out_ch*4)
 
@@ -457,11 +457,11 @@ class DeepFakeArchi(nn.ArchiBase):
 
         class ResidualBlock(nn.ModelBase):
             def on_build(self, ch, kernel_size=3, depth_multiplier=1 ):
-                self.conv1 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, depth_multiplier=depth_multiplier, padding='SAME')
+                self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
                 #self.frn1 = nn.FRNorm2D(ch)
                 #self.tlu1 = nn.TLU(ch)
 
-                self.conv2 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, depth_multiplier=depth_multiplier, padding='SAME')
+                self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
                 #self.frn2 = nn.FRNorm2D(ch)
                 #self.tlu2 = nn.TLU(ch)
 
@@ -545,12 +545,12 @@ class DeepFakeArchi(nn.ArchiBase):
                 self.res1 = ResidualBlock(d_ch*4, kernel_size=3)
                 self.res2 = ResidualBlock(d_ch*2, kernel_size=3)
 
-                self.out_conv  = nn.SeparableConv2D( d_ch*2, 3, kernel_size=1, padding='SAME')
+                self.out_conv  = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME')
 
                 self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
                 self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
                 self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
-                self.out_convm = nn.SeparableConv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME')
+                self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME')
 
             def forward(self, inp):
                 z = inp
diff --git a/facelib/FANExtractor.py b/facelib/FANExtractor.py
index 5676fcc..0604cc2 100644
--- a/facelib/FANExtractor.py
+++ b/facelib/FANExtractor.py
@@ -2,170 +2,304 @@ import os
 import traceback
 from pathlib import Path
 
+import warnings
+#warnings.simplefilter(action='ignore', category=FutureWarning)
+
 import cv2
 import numpy as np
 from numpy import linalg as npla
 
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision.transforms import ToTensor
+
+import math
+
+from facelib.nn_pt import nn as nn_pt
+
 from facelib import FaceType, LandmarksProcessor
-from core.leras import nn
+from facelib.LandmarksProcessor import convert_98_to_68
+
+from facelib.coord_conv import CoordConvTh
+
 """
-ported from https://github.com/1adrianb/face-alignment
+ported from https://github.com/protossw512/AdaptiveWingLoss
 """
 
 class FANExtractor(object):
-    def __init__ (self, landmarks_3D=False, place_model_on_cpu=False):
+    def __init__ (self, place_model_on_cpu=False):
+        model_path = Path(__file__).parent / "AWL.pth"
+
+        nn_pt.initialize()
 
-        model_path = Path(__file__).parent / ( "2DFAN.npy" if not landmarks_3D else "3DFAN.npy")
         if not model_path.exists():
-            raise Exception("Unable to load FANExtractor model")
model") + raise Exception("Unable to load AWL.pth") - nn.initialize(data_format="NHWC") - tf = nn.tf + def conv3x3(in_planes, out_planes, strd=1, padding=1, + bias=False,dilation=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, + stride=strd, padding=padding, bias=bias, + dilation=dilation) - class ConvBlock(nn.ModelBase): - def on_build(self, in_planes, out_planes): + class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + # self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + # self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + # out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + # out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + + class ConvBlock(nn.Module): + def __init__(self, in_planes, out_planes): + super(ConvBlock, self).__init__() self.in_planes = in_planes self.out_planes = out_planes - self.bn1 = nn.BatchNorm2D(in_planes) - self.conv1 = nn.Conv2D (in_planes, out_planes/2, kernel_size=3, strides=1, padding='SAME', use_bias=False ) + self.bn1 = nn.BatchNorm2d(in_planes) + self.conv1 = conv3x3(in_planes, int(out_planes / 2)) - self.bn2 = nn.BatchNorm2D(out_planes//2) - self.conv2 = nn.Conv2D (out_planes/2, out_planes/4, kernel_size=3, strides=1, padding='SAME', use_bias=False ) + self.bn2 = nn.BatchNorm2d(out_planes//2) + self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4), + padding=1, dilation=1) - self.bn3 = nn.BatchNorm2D(out_planes//4) - self.conv3 = nn.Conv2D (out_planes/4, out_planes/4, kernel_size=3, strides=1, padding='SAME', use_bias=False ) + self.bn3 = nn.BatchNorm2d(out_planes//4) + self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4), + padding=1, dilation=1) if self.in_planes != self.out_planes: - self.down_bn1 = nn.BatchNorm2D(in_planes) - self.down_conv1 = nn.Conv2D (in_planes, out_planes, kernel_size=1, strides=1, padding='VALID', use_bias=False ) + self.downsample = nn.Sequential( + nn.BatchNorm2d(in_planes), + nn.ReLU(True), + nn.Conv2d(in_planes, out_planes, + kernel_size=1, stride=1, bias=False), + ) else: - self.down_bn1 = None - self.down_conv1 = None + self.downsample = None - def forward(self, input): - x = input - x = self.bn1(x) - x = tf.nn.relu(x) - x = out1 = self.conv1(x) + def forward(self, x): + residual = x - x = self.bn2(x) - x = tf.nn.relu(x) - x = out2 = self.conv2(x) + out1 = self.bn1(x) + out1 = F.relu(out1, True) + out1 = self.conv1(out1) - x = self.bn3(x) - x = tf.nn.relu(x) - x = out3 = self.conv3(x) + out2 = self.bn2(out1) + out2 = F.relu(out2, True) + out2 = self.conv2(out2) - x = tf.concat ([out1, out2, out3], axis=-1) + out3 = self.bn3(out2) + out3 = F.relu(out3, True) + out3 = self.conv3(out3) - if self.in_planes != self.out_planes: - downsample = self.down_bn1(input) - downsample = tf.nn.relu (downsample) - downsample = self.down_conv1 (downsample) - x = x + downsample + out3 = torch.cat((out1, out2, out3), 1) + + if self.downsample is not None: + residual = self.downsample(residual) + + out3 += residual + + return out3 + + class HourGlass (nn.Module): + def __init__(self, num_modules, depth, num_features, first_one=False): + super(HourGlass, 
+                self.num_modules = num_modules
+                self.depth = depth
+                self.features = num_features
+                self.coordconv = CoordConvTh(x_dim=64, y_dim=64,
+                                             with_r=True, with_boundary=True,
+                                             in_channels=256, first_one=first_one,
+                                             out_channels=256,
+                                             kernel_size=1,
+                                             stride=1, padding=0)
+                self._generate_network(self.depth)
+
+            def _generate_network(self, level):
+                self.add_module('b1_' + str(level), ConvBlock(256, 256))
+
+                self.add_module('b2_' + str(level), ConvBlock(256, 256))
+
+                if level > 1:
+                    self._generate_network(level - 1)
                 else:
-                    x = x + input
+                    self.add_module('b2_plus_' + str(level), ConvBlock(256, 256))
 
-                return x
+                self.add_module('b3_' + str(level), ConvBlock(256, 256))
 
-        class HourGlass (nn.ModelBase):
-            def on_build(self, in_planes, depth):
-                self.b1 = ConvBlock (in_planes, 256)
-                self.b2 = ConvBlock (in_planes, 256)
+            def _forward(self, level, inp):
+                # Upper branch
+                up1 = inp
+                up1 = self._modules['b1_' + str(level)](up1)
 
-                if depth > 1:
-                    self.b2_plus = HourGlass(256, depth-1)
+                # Lower branch
+                low1 = F.avg_pool2d(inp, 2, stride=2)
+                low1 = self._modules['b2_' + str(level)](low1)
+
+                if level > 1:
+                    low2 = self._forward(level - 1, low1)
                 else:
-                    self.b2_plus = ConvBlock(256, 256)
+                    low2 = low1
+                    low2 = self._modules['b2_plus_' + str(level)](low2)
 
-                self.b3 = ConvBlock(256, 256)
+                low3 = low2
+                low3 = self._modules['b3_' + str(level)](low3)
 
-            def forward(self, input):
-                up1 = self.b1(input)
+                up2 = F.upsample(low3, scale_factor=2, mode='nearest')
 
-                low1 = tf.nn.avg_pool(input, [1,2,2,1], [1,2,2,1], 'VALID')
-                low1 = self.b2 (low1)
+                return up1 + up2
 
-                low2 = self.b2_plus(low1)
-                low3 = self.b3(low2)
+            def forward(self, x, heatmap):
+                x, last_channel = self.coordconv(x, heatmap)
+                return self._forward(self.depth, x), last_channel
 
-                up2 = nn.upsample2d(low3)
-                return up1+up2
-
-        class FAN (nn.ModelBase):
-            def __init__(self):
-                super().__init__(name='FAN')
-
-            def on_build(self):
-                self.conv1 = nn.Conv2D (3, 64, kernel_size=7, strides=2, padding='SAME')
-                self.bn1 = nn.BatchNorm2D(64)
+        class FAN (nn.Module):
+            def __init__(self, num_modules=1, end_relu=False, gray_scale=False, num_landmarks=68):
+                super(FAN,self).__init__()
+                self.num_modules = num_modules
+                self.gray_scale = gray_scale
+                self.end_relu = end_relu
+                self.num_landmarks = num_landmarks
+
+                # Base part
+                if self.gray_scale:
+                    self.conv1 = CoordConvTh(x_dim=256, y_dim=256,
+                                             with_r=True, with_boundary=False,
+                                             in_channels=3, out_channels=64,
+                                             kernel_size=7,
+                                             stride=2, padding=3)
+                else:
+                    self.conv1 = CoordConvTh(x_dim=256, y_dim=256,
+                                             with_r=True, with_boundary=False,
+                                             in_channels=3, out_channels=64,
+                                             kernel_size=7,
+                                             stride=2, padding=3)
+
+                self.bn1 = nn.BatchNorm2d(64)
                 self.conv2 = ConvBlock(64, 128)
                 self.conv3 = ConvBlock(128, 128)
                 self.conv4 = ConvBlock(128, 256)
 
-                self.m = []
-                self.top_m = []
-                self.conv_last = []
-                self.bn_end = []
-                self.l = []
-                self.bl = []
-                self.al = []
-                for i in range(4):
-                    self.m += [ HourGlass(256, 4) ]
-                    self.top_m += [ ConvBlock(256, 256) ]
+                # Stacking part
+                for hg_module in range(self.num_modules):
+                    if hg_module == 0:
+                        first_one = True
+                    else:
+                        first_one = False
+                    self.add_module('m' + str(hg_module), HourGlass(1, 4, 256,
+                                                                    first_one))
+                    self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
+                    self.add_module('conv_last' + str(hg_module),
+                                    nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
+                    self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
+                    self.add_module('l' + str(hg_module), nn.Conv2d(256,
+                                                                    num_landmarks+1, kernel_size=1, stride=1, padding=0))
-                    self.conv_last += [ nn.Conv2D (256, 256, kernel_size=1, strides=1, padding='VALID') ]
-                    self.bn_end += [ nn.BatchNorm2D(256) ]
+                    if hg_module < self.num_modules - 1:
+                        self.add_module(
+                            'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
+                        self.add_module('al' + str(hg_module), nn.Conv2d(num_landmarks+1,
+                                                                         256, kernel_size=1, stride=1, padding=0))
+
-                    self.l += [ nn.Conv2D (256, 68, kernel_size=1, strides=1, padding='VALID') ]
-                    if i < 4-1:
-                        self.bl += [ nn.Conv2D (256, 256, kernel_size=1, strides=1, padding='VALID') ]
-                        self.al += [ nn.Conv2D (68, 256, kernel_size=1, strides=1, padding='VALID') ]
 
-            def forward(self, inp) :
-                x, = inp
-                x = self.conv1(x)
-                x = self.bn1(x)
-                x = tf.nn.relu(x)
-
-                x = self.conv2(x)
-                x = tf.nn.avg_pool(x, [1,2,2,1], [1,2,2,1], 'VALID')
+            def forward(self, x):
+                x, _ = self.conv1(x)
+                x = F.relu(self.bn1(x), True)
+                # x = F.relu(self.bn1(self.conv1(x)), True)
+                x = F.avg_pool2d(self.conv2(x), 2, stride=2)
                 x = self.conv3(x)
                 x = self.conv4(x)
 
-                outputs = []
                 previous = x
-                for i in range(4):
-                    ll = self.m[i] (previous)
-                    ll = self.top_m[i] (ll)
-                    ll = self.conv_last[i] (ll)
-                    ll = self.bn_end[i] (ll)
-                    ll = tf.nn.relu(ll)
-                    tmp_out = self.l[i](ll)
+
+                outputs = []
+                boundary_channels = []
+                tmp_out = None
+                for i in range(self.num_modules):
+                    hg, boundary_channel = self._modules['m' + str(i)](previous,
+                                                                       tmp_out)
+
+                    ll = hg
+                    ll = self._modules['top_m_' + str(i)](ll)
+
+                    ll = F.relu(self._modules['bn_end' + str(i)]
+                                (self._modules['conv_last' + str(i)](ll)), True)
+
+                    # Predict heatmaps
+                    tmp_out = self._modules['l' + str(i)](ll)
+                    if self.end_relu:
+                        tmp_out = F.relu(tmp_out)  # HACK: Added relu
+                    outputs.append(tmp_out)
-                    if i < 4 - 1:
-                        ll = self.bl[i](ll)
-                        previous = previous + ll + self.al[i](tmp_out)
-                x = outputs[-1]
-                x = tf.transpose(x, (0,3,1,2) )
-                return x
+                    boundary_channels.append(boundary_channel)
 
-        e = None
-        if place_model_on_cpu:
-            e = tf.device("/CPU:0")
+                    if i < self.num_modules - 1:
+                        ll = self._modules['bl' + str(i)](ll)
+                        tmp_out_ = self._modules['al' + str(i)](tmp_out)
+                        previous = previous + ll + tmp_out_
+
+                return outputs, boundary_channels
+
+        def load_model(model, pretrained_path, load_to_cpu):
+            if load_to_cpu:
+                checkpoint = torch.load(pretrained_path, map_location=lambda storage, loc: storage)
+            else:
+                device = torch.cuda.current_device()
+                checkpoint = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device))
+
+            if 'state_dict' not in checkpoint:
+                model.load_state_dict(checkpoint)
+            else:
+                pretrained_weights = checkpoint['state_dict']
+                model_weights = model.state_dict()
+                pretrained_weights = {k: v for k, v in pretrained_weights.items() \
+                                      if k in model_weights}
+                model_weights.update(pretrained_weights)
+                model.load_state_dict(model_weights)
+            return model
+
+        torch.set_grad_enabled(False)
+
+        self.model = FAN(num_modules=4, num_landmarks=98, end_relu=False, gray_scale=False)
+        self.model = load_model(self.model, model_path, place_model_on_cpu)
 
-        if e is not None: e.__enter__()
-        self.model = FAN()
-        self.model.load_weights(str(model_path))
-        if e is not None: e.__exit__(None,None,None)
+        self.model.eval()
+
+        self.device = torch.device("cpu" if place_model_on_cpu else "cuda")
+        self.model = self.model.to(self.device)
 
-        self.model.build_for_run ([ ( tf.float32, (None,256,256,3) ) ])
 
     def extract (self, input_image, rects, second_pass_extractor=None, is_bgr=True, multi_sample=False):
+
         if len(rects) == 0:
             return []
 
@@ -174,12 +308,14 @@ class FANExtractor(object):
             is_bgr = False
 
         (h, w, ch) = input_image.shape
+        print("Image shape:", input_image.shape)
 
         landmarks = []
         for (left, top, right, bottom) in rects:
             scale = (right - left + bottom - top) / 195.0
 
             center = np.array( [ (left + right) / 2.0, (top + bottom) / 2.0] )
+            print("Center:", center)
             centers = [ center ]
 
             if multi_sample:
@@ -190,26 +326,30 @@ class FANExtractor(object):
                              ]
 
             images = []
-            ptss = []
-
             try:
                 for c in centers:
-                    images += [ self.crop(input_image, c, scale) ]
+                    images += [ self.crop_old(input_image, c, scale) ]
 
                 images = np.stack (images)
                 images = images.astype(np.float32) / 255.0
-
-                predicted = []
-                for i in range( len(images) ):
-                    predicted += [ self.model.run ( [ images[i][None,...] ] )[0] ]
-
-                predicted = np.stack(predicted)
-
-                for i, pred in enumerate(predicted):
-                    ptss += [ self.get_pts_from_predict ( pred, centers[i], scale) ]
-                pts_img = np.mean ( np.array(ptss), 0 )
-
-                landmarks.append (pts_img)
+
+                i = 0
+                for img in images:
+                    img = ToTensor()(img)
+                    img = img.to(self.device)
+
+                    outputs, boundary_channels = self.model(img[None,...])
+
+                    pred_heatmap = outputs[-1][:, :-1, :, :][0].detach().cpu()
+
+                    _, pred_landmarks = self.get_pts_from_predict ( pred_heatmap.unsqueeze(0), centers[i], scale)
+                    i += 1
+
+                    pred_landmarks = pred_landmarks.squeeze().numpy()
+                    pred_landmarks = convert_98_to_68(pred_landmarks)
+
+                    landmarks += [pred_landmarks]
+
             except:
                 landmarks.append (None)
@@ -228,8 +368,8 @@ class FANExtractor(object):
                 pass
 
         return landmarks
-
-    def transform(self, point, center, scale, resolution):
+
+    def transform_old(self, point, center, scale, resolution):
         pt = np.array ( [point[0], point[1], 1.0] )
         h = 200.0 * scale
         m = np.eye(3)
@@ -240,9 +380,73 @@ class FANExtractor(object):
         m = np.linalg.inv(m)
         return np.matmul (m, pt)[0:2]
 
-    def crop(self, image, center, scale, resolution=256.0):
-        ul = self.transform([1, 1], center, scale, resolution).astype( np.int )
-        br = self.transform([resolution, resolution], center, scale, resolution).astype( np.int )
+    def transform(self, point, center, scale, resolution, rotation=0, invert=False):
+        _pt = np.ones(3)
+        _pt[0] = point[0]
+        _pt[1] = point[1]
+
+        h = 200.0 * scale
+        t = np.eye(3)
+        t[0, 0] = resolution / h
+        t[1, 1] = resolution / h
+        t[0, 2] = resolution * (-center[0] / h + 0.5)
+        t[1, 2] = resolution * (-center[1] / h + 0.5)
+
+        if rotation != 0:
+            rotation = -1*rotation
+            r = np.eye(3)
+            ang = rotation * math.pi / 180.0
+            s = math.sin(ang)
+            c = math.cos(ang)
+            r[0][0] = c
+            r[0][1] = -s
+            r[1][0] = s
+            r[1][1] = c
+
+            t_ = np.eye(3)
+            t_[0][2] = -resolution / 2.0
+            t_[1][2] = -resolution / 2.0
+            t_inv = np.eye(3)
+            t_inv[0][2] = resolution / 2.0
+            t_inv[1][2] = resolution / 2.0
+            t = t_inv @ r @ t_ @ t
+
+        if invert:
+            t = np.linalg.inv(t)
+        new_point = (np.matmul(t, _pt))[0:2]
+
+        return new_point.astype(int)
+
+    def crop(self, image, center, scale, resolution=256, center_shift=0):
+        new_image = cv2.copyMakeBorder(image, center_shift,
+                                       center_shift,
+                                       center_shift,
+                                       center_shift,
+                                       cv2.BORDER_CONSTANT, value=[0,0,0])
+        if center_shift != 0:
+            center[0] += center_shift
+            center[1] += center_shift
+        length = 200 * scale
+        top = int(center[1] - length // 2)
+        bottom = int(center[1] + length // 2)
+        left = int(center[0] - length // 2)
+        right = int(center[0] + length // 2)
+        y_pad = abs(min(top, new_image.shape[0] - bottom, 0))
+        x_pad = abs(min(left, new_image.shape[1] - right, 0))
+        top, bottom, left, right = top + y_pad, bottom + y_pad, left + x_pad, right + x_pad
+        new_image = cv2.copyMakeBorder(new_image, y_pad,
+                                       y_pad,
+                                       x_pad,
+                                       x_pad,
+                                       cv2.BORDER_CONSTANT, value=[0,0,0])
+        new_image = new_image[top:bottom, left:right]
+        new_image = cv2.resize(new_image, dsize=(int(resolution), int(resolution)),
+                               interpolation=cv2.INTER_LINEAR)
+        return new_image
+
+    def crop_old(self, image, center, scale, resolution=256.0):
+        ul = self.transform_old([1, 1], center, scale, resolution).astype( np.int )
+        br = self.transform_old([resolution, resolution], center, scale, resolution).astype( np.int )
 
         if image.ndim > 2:
             newDim = np.array([br[1] - ul[1], br[0] - ul[0], image.shape[2]], dtype=np.int32)
@@ -260,21 +464,32 @@ class FANExtractor(object):
             newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
                                 interpolation=cv2.INTER_LINEAR)
         return newImg
+
 
-    def get_pts_from_predict(self, a, center, scale):
-        a_ch, a_h, a_w = a.shape
+    def get_pts_from_predict(self, hm, center=None, scale=None, rot=None):
+        max, idx = torch.max(
+            hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
+        idx += 1
+        preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
+        preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
+        preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
 
-        b = a.reshape ( (a_ch, a_h*a_w) )
-        c = b.argmax(1).reshape ( (a_ch, 1) ).repeat(2, axis=1).astype(np.float)
-        c[:,0] %= a_w
-        c[:,1] = np.apply_along_axis ( lambda x: np.floor(x / a_w), 0, c[:,1] )
+        for i in range(preds.size(0)):
+            for j in range(preds.size(1)):
+                hm_ = hm[i, j, :]
+                pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
+                if pX > 0 and pX < 63 and pY > 0 and pY < 63:
+                    diff = torch.FloatTensor(
+                        [hm_[pY, pX + 1] - hm_[pY, pX - 1],
+                         hm_[pY + 1, pX] - hm_[pY - 1, pX]])
+                    preds[i, j].add_(diff.sign_().mul_(.25))
 
-        for i in range(a_ch):
-            pX, pY = int(c[i,0]), int(c[i,1])
-            if pX > 0 and pX < 63 and pY > 0 and pY < 63:
-                diff = np.array ( [a[i,pY,pX+1]-a[i,pY,pX-1], a[i,pY+1,pX]-a[i,pY-1,pX]] )
-                c[i] += np.sign(diff)*0.25
+        preds.add_(-0.5)
 
-        c += 0.5
+        preds_orig = torch.zeros(preds.size())
+        if center is not None and scale is not None:
+            for i in range(hm.size(0)):
+                for j in range(hm.size(1)):
+                    preds_orig[i, j] = torch.from_numpy(self.transform(preds[i, j], center, scale, hm.size(2), rot if (rot is not None) else (0), True))
 
-        return np.array( [ self.transform (c[i], center, scale, a_w) for i in range(a_ch) ] )
+        return preds, preds_orig
diff --git a/facelib/__init__.py b/facelib/__init__.py
index e46ca51..1690c13 100644
--- a/facelib/__init__.py
+++ b/facelib/__init__.py
@@ -1,5 +1,9 @@
 from .FaceType import FaceType
 from .S3FDExtractor import S3FDExtractor
+from .RetinaFaceExtractor import RetinaFaceExtractor
+from .nn_pt import nn
+from .device_pt import Devices
+from .ResNet50 import resnet50
 from .FANExtractor import FANExtractor
 from .FaceEnhancer import FaceEnhancer
 from .XSegNet import XSegNet
\ No newline at end of file
diff --git a/main.py b/main.py
index 36e79b1..7686a6a 100644
--- a/main.py
+++ b/main.py
@@ -43,7 +43,7 @@ if __name__ == "__main__":
                             )
 
     p = subparsers.add_parser( "extract", help="Extract the faces from a pictures.")
-    p.add_argument('--detector', dest="detector", choices=['s3fd','manual'], default=None, help="Type of detector.")
+    p.add_argument('--detector', dest="detector", choices=['retinaface', 'manual'], default=None, help="Type of detector.")
     p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.")
     p.add_argument('--output-dir', required=True, action=fixPathAction, dest="output_dir", help="Output directory. This is where the extracted files will be stored.")
     p.add_argument('--output-debug', action="store_true", dest="output_debug", default=None, help="Writes debug images to _debug\ directory.")
diff --git a/mainscripts/Extractor.py b/mainscripts/Extractor.py
index 908e723..cbe38df 100644
--- a/mainscripts/Extractor.py
+++ b/mainscripts/Extractor.py
@@ -1,4 +1,4 @@
-import traceback
+import traceback
 import math
 import multiprocessing
 import operator
@@ -15,9 +15,9 @@ import facelib
 from core import imagelib
 from core import mathlib
 from facelib import FaceType, LandmarksProcessor
+from facelib.nn_pt import nn
 from core.interact import interact as io
 from core.joblib import Subprocessor
-from core.leras import nn
 from core import pathex
 from core.cv2ex import *
 from DFLIMG import *
@@ -68,12 +68,11 @@ class ExtractSubprocessor(Subprocessor):
             self.log_info (f"Running on {client_dict['device_name'] }")
 
             if self.type == 'all' or self.type == 'rects-s3fd' or 'landmarks' in self.type:
-                self.rects_extractor = facelib.S3FDExtractor(place_model_on_cpu=place_model_on_cpu)
+                self.rects_extractor = facelib.RetinaFaceExtractor(place_model_on_cpu=place_model_on_cpu)
 
             if self.type == 'all' or 'landmarks' in self.type:
                 # for head type, extract "3D landmarks"
-                self.landmarks_extractor = facelib.FANExtractor(landmarks_3D=self.face_type >= FaceType.HEAD,
-                                                                place_model_on_cpu=place_model_on_cpu)
+                self.landmarks_extractor = facelib.FANExtractor(place_model_on_cpu=place_model_on_cpu)
 
             self.cached_image = (None, None)
 
@@ -716,9 +715,9 @@ def main(detector=None,
 
     if detector is None:
         io.log_info ("Choose detector type.")
-        io.log_info ("[0] S3FD")
+        io.log_info ("[0] RetinaFace")
         io.log_info ("[1] manual")
-        detector = {0:'s3fd', 1:'manual'}[ io.input_int("", 0, [0,1]) ]
+        detector = {0:'retinaface', 1:'manual'}[ io.input_int("", 0, [0,1]) ]
 
     device_config = nn.DeviceConfig.GPUIndexes( force_gpu_idxs or nn.ask_choose_device_idxs(choose_only_one=detector=='manual', suggest_all_gpu=True) ) \
                     if not cpu_only else nn.DeviceConfig.CPU()
@@ -795,4 +794,4 @@ def main(detector=None,
     io.log_info ('-------------------------')
     io.log_info ('Images found: %d' % (images_found) )
     io.log_info ('Faces detected: %d' % (faces_detected) )
-    io.log_info ('-------------------------')
+    io.log_info ('-------------------------')
\ No newline at end of file
diff --git a/requirements-colab.txt b/requirements-colab.txt
index 128a518..3104ef4 100644
--- a/requirements-colab.txt
+++ b/requirements-colab.txt
@@ -6,4 +6,6 @@ ffmpeg-python==0.1.17
 scikit-image==0.14.2
 scipy==1.4.1
 colorama
-tensorflow-gpu==1.13.2
\ No newline at end of file
+tensorflow-gpu==1.13.2
+torch
+torchvision
diff --git a/requirements-cuda.txt b/requirements-cuda.txt
index 8df3478..712d698 100644
--- a/requirements-cuda.txt
+++ b/requirements-cuda.txt
@@ -8,4 +8,6 @@ scipy==1.4.1
 colorama
 labelme==4.2.9
 tensorflow-gpu==1.13.2
+torch
+torchvision
 pyqt5
\ No newline at end of file