mirror of
https://github.com/iperov/DeepFaceLive
synced 2025-08-20 21:43:22 -07:00
upd modelhub.torch
This commit is contained in:
parent
6e84cbe8b6
commit
d02f46dfc7
2 changed files with 1 additions and 122 deletions
|
@ -1,121 +0,0 @@
|
|||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from xlib.file import SplittedFile
|
||||
from xlib.torch import TorchDeviceInfo, get_cpu_device_info
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, ch):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1)
|
||||
self.conv2 = nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, inp):
|
||||
x = inp
|
||||
x = F.leaky_relu(self.conv1(x), 0.2)
|
||||
x = F.leaky_relu(self.conv2(x) + inp, 0.2)
|
||||
return x
|
||||
|
||||
class AutoEncoder(nn.Module):
|
||||
def __init__(self, resolution, in_ch, ae_ch):
|
||||
super().__init__()
|
||||
self._resolution = resolution
|
||||
self._in_ch = in_ch
|
||||
self._ae_ch = ae_ch
|
||||
|
||||
self.conv1 = nn.Conv2d(in_ch*1, in_ch*2, kernel_size=5, stride=2, padding=2)
|
||||
self.conv2 = nn.Conv2d(in_ch*2, in_ch*4, kernel_size=5, stride=2, padding=2)
|
||||
self.conv3 = nn.Conv2d(in_ch*4, in_ch*8, kernel_size=5, stride=2, padding=2)
|
||||
self.conv4 = nn.Conv2d(in_ch*8, in_ch*8, kernel_size=5, stride=2, padding=2)
|
||||
|
||||
self.dense_in = nn.Linear( in_ch*8 * ( resolution // (2**4) )**2, ae_ch)
|
||||
self.dense_out = nn.Linear( ae_ch, in_ch*8 * ( resolution // (2**4) )**2 )
|
||||
|
||||
self.up_conv4 = nn.ConvTranspose2d(in_ch*8, in_ch*8, kernel_size=3, stride=2, padding=1, output_padding=(1,1), )
|
||||
self.up_conv3 = nn.ConvTranspose2d(in_ch*8, in_ch*4, kernel_size=3, stride=2, padding=1, output_padding=(1,1), )
|
||||
self.up_conv2 = nn.ConvTranspose2d(in_ch*4, in_ch*2, kernel_size=3, stride=2, padding=1, output_padding=(1,1), )
|
||||
self.up_conv1 = nn.ConvTranspose2d(in_ch*2, in_ch*1, kernel_size=3, stride=2, padding=1, output_padding=(1,1), )
|
||||
|
||||
|
||||
def forward(self, inp):
|
||||
x = inp
|
||||
x = F.leaky_relu(self.conv1(x), 0.1)
|
||||
x = F.leaky_relu(self.conv2(x), 0.1)
|
||||
x = F.leaky_relu(self.conv3(x), 0.1)
|
||||
x = F.leaky_relu(self.conv4(x), 0.1)
|
||||
|
||||
x = x.view( (x.shape[0], np.prod(x.shape[1:])) )
|
||||
|
||||
x = self.dense_in(x)
|
||||
x = self.dense_out(x)
|
||||
|
||||
x = x.view( (x.shape[0], self._in_ch*8, self._resolution // (2**4), self._resolution // (2**4) ))
|
||||
x = F.leaky_relu(self.up_conv4(x), 0.1)
|
||||
x = F.leaky_relu(self.up_conv3(x), 0.1)
|
||||
x = F.leaky_relu(self.up_conv2(x), 0.1)
|
||||
x = F.leaky_relu(self.up_conv1(x), 0.1)
|
||||
|
||||
# from xlib.console import diacon
|
||||
# diacon.Diacon.stop()
|
||||
# import code
|
||||
# code.interact(local=dict(globals(), **locals()))
|
||||
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class CTSOTNet(nn.Module):
|
||||
def __init__(self, resolution, in_ch=6, inner_ch=64, ae_ch=256, out_ch=3, res_block_count = 12):
|
||||
super().__init__()
|
||||
self.in_conv = nn.Conv2d(in_ch, inner_ch, kernel_size=1, stride=1, padding=0)
|
||||
self.ae = AutoEncoder(resolution, inner_ch, ae_ch)
|
||||
self.out_conv = nn.Conv2d(inner_ch, out_ch, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, img1_t, img2_t):
|
||||
x = torch.cat([img1_t, img2_t], dim=1)
|
||||
|
||||
x = self.in_conv(x)
|
||||
|
||||
x = self.ae(x)
|
||||
x = self.out_conv(x)
|
||||
x = torch.tanh(x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class CTSOT:
|
||||
def __init__(self, device_info : TorchDeviceInfo = None,
|
||||
state_dict : Union[dict, None] = None,
|
||||
training : bool = False):
|
||||
if device_info is None:
|
||||
device_info = get_cpu_device_info()
|
||||
self.device_info = device_info
|
||||
|
||||
self._net = net = CTSOTNet()
|
||||
|
||||
if state_dict is not None:
|
||||
net.load_state_dict(state_dict)
|
||||
|
||||
if training:
|
||||
net.train()
|
||||
else:
|
||||
net.eval()
|
||||
|
||||
self.set_device(device_info)
|
||||
|
||||
def set_device(self, device_info : TorchDeviceInfo = None):
|
||||
if device_info is None or device_info.is_cpu():
|
||||
self._net.cpu()
|
||||
else:
|
||||
self._net.cuda(device_info.get_index())
|
||||
|
||||
def get_state_dict(self):
|
||||
return self.net.state_dict()
|
||||
|
||||
def get_net(self) -> CTSOTNet:
|
||||
return self._net
|
|
@ -1,3 +1,3 @@
|
|||
from .CenterFace.CenterFace import CenterFace_to_onnx
|
||||
from .S3FD.S3FD import S3FD
|
||||
from .CTSOT.CTSOT import CTSOT, CTSOTNet
|
||||
from .FaceAligner.FaceAligner import FaceAlignerNet
|
Loading…
Add table
Add a link
Reference in a new issue