# Mirror of https://github.com/iperov/DeepFaceLive (synced 2025-07-15)
from pathlib import Path
from typing import Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from xlib.file import SplittedFile
from xlib.torch import TorchDeviceInfo, get_cpu_device_info


class ResBlock(nn.Module):
    """Residual block: two 3x3 same-resolution convs with a skip connection.

    Both activations use leaky ReLU with negative slope 0.2; the input is
    added back in before the second activation.
    """

    def __init__(self, ch):
        super().__init__()
        self.conv1 = nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1)

    def forward(self, inp):
        out = F.leaky_relu(self.conv1(inp), 0.2)
        out = self.conv2(out) + inp          # residual connection
        return F.leaky_relu(out, 0.2)


class AutoEncoder(nn.Module):
    """Convolutional autoencoder with a fully-connected bottleneck.

    Four stride-2 5x5 convs downsample the input 16x spatially, the result is
    flattened through a dense bottleneck of `ae_ch` units, then four stride-2
    transposed convs restore the original resolution. All activations are
    leaky ReLU with negative slope 0.1.

    Args:
        resolution: spatial size of the (square) input; must be divisible by 16
                    so the dense layer sizes and the final view line up.
        in_ch:      number of input (and output) channels.
        ae_ch:      width of the dense bottleneck.
    """
    def __init__(self, resolution, in_ch, ae_ch):
        super().__init__()
        self._resolution = resolution
        self._in_ch = in_ch
        self._ae_ch = ae_ch

        self.conv1 = nn.Conv2d(in_ch*1, in_ch*2, kernel_size=5, stride=2, padding=2)
        self.conv2 = nn.Conv2d(in_ch*2, in_ch*4, kernel_size=5, stride=2, padding=2)
        self.conv3 = nn.Conv2d(in_ch*4, in_ch*8, kernel_size=5, stride=2, padding=2)
        self.conv4 = nn.Conv2d(in_ch*8, in_ch*8, kernel_size=5, stride=2, padding=2)

        # Spatial size after the four stride-2 convs (16x downsample).
        low_res = resolution // 16
        self.dense_in  = nn.Linear(in_ch*8 * low_res**2, ae_ch)
        self.dense_out = nn.Linear(ae_ch, in_ch*8 * low_res**2)

        self.up_conv4 = nn.ConvTranspose2d(in_ch*8, in_ch*8, kernel_size=3, stride=2, padding=1, output_padding=(1,1), )
        self.up_conv3 = nn.ConvTranspose2d(in_ch*8, in_ch*4, kernel_size=3, stride=2, padding=1, output_padding=(1,1), )
        self.up_conv2 = nn.ConvTranspose2d(in_ch*4, in_ch*2, kernel_size=3, stride=2, padding=1, output_padding=(1,1), )
        self.up_conv1 = nn.ConvTranspose2d(in_ch*2, in_ch*1, kernel_size=3, stride=2, padding=1, output_padding=(1,1), )

    def forward(self, inp):
        """Encode `inp` through the bottleneck and decode back.

        Returns a tensor with the same (N, in_ch, resolution, resolution)
        shape as the input.
        """
        x = inp
        x = F.leaky_relu(self.conv1(x), 0.1)
        x = F.leaky_relu(self.conv2(x), 0.1)
        x = F.leaky_relu(self.conv3(x), 0.1)
        x = F.leaky_relu(self.conv4(x), 0.1)

        # Flatten spatial dims for the dense bottleneck
        # (was a manual view via np.prod; torch.flatten is equivalent).
        x = torch.flatten(x, 1)

        x = self.dense_in(x)
        x = self.dense_out(x)

        low_res = self._resolution // 16
        x = x.view( (x.shape[0], self._in_ch*8, low_res, low_res) )
        x = F.leaky_relu(self.up_conv4(x), 0.1)
        x = F.leaky_relu(self.up_conv3(x), 0.1)
        x = F.leaky_relu(self.up_conv2(x), 0.1)
        x = F.leaky_relu(self.up_conv1(x), 0.1)
        return x


class CTSOTNet(nn.Module):
    """Color-transfer network: 1x1 conv in, autoencoder core, 1x1 conv out.

    Takes two images, concatenates them along the channel axis, and produces
    a tanh-bounded output image.
    """

    def __init__(self, resolution, in_ch=6, inner_ch=64, ae_ch=256, out_ch=3, res_block_count = 12):
        super().__init__()
        self.in_conv  = nn.Conv2d(in_ch, inner_ch, kernel_size=1, stride=1, padding=0)
        self.ae       = AutoEncoder(resolution, inner_ch, ae_ch)
        self.out_conv = nn.Conv2d(inner_ch, out_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, img1_t, img2_t):
        # Stack the two inputs channel-wise, run the pipeline, bound to [-1, 1].
        stacked = torch.cat([img1_t, img2_t], dim=1)
        features = self.ae( self.in_conv(stacked) )
        return torch.tanh( self.out_conv(features) )


class CTSOT:
    """Thin wrapper around CTSOTNet handling device placement, train/eval
    mode and state-dict load/save."""

    def __init__(self, device_info : TorchDeviceInfo = None,
                       state_dict : Union[dict, None] = None,
                       training : bool = False,
                       resolution : int = 224):
        """
        Args:
            device_info: target device; defaults to CPU.
            state_dict:  optional pre-trained weights to load.
            training:    put the net in train() mode instead of eval().
            resolution:  input resolution forwarded to CTSOTNet.
        """
        if device_info is None:
            device_info = get_cpu_device_info()
        self.device_info = device_info

        # Fixed: original called CTSOTNet() with no arguments, but CTSOTNet
        # requires `resolution` and would raise TypeError. Exposed it here as
        # a backward-compatible keyword; default 224 — TODO confirm against
        # the resolution actually used by callers/trainers.
        self._net = net = CTSOTNet(resolution)

        if state_dict is not None:
            net.load_state_dict(state_dict)

        if training:
            net.train()
        else:
            net.eval()

        self.set_device(device_info)

    def set_device(self, device_info : TorchDeviceInfo = None):
        """Move the net to the given device (CPU if None or a CPU device)."""
        if device_info is None or device_info.is_cpu():
            self._net.cpu()
        else:
            self._net.cuda(device_info.get_index())

    def get_state_dict(self):
        """Return the net's state dict for serialization."""
        # Fixed: was `self.net.state_dict()` — the attribute is `self._net`,
        # so this always raised AttributeError.
        return self._net.state_dict()

    def get_net(self) -> CTSOTNet:
        return self._net