mirror of
https://github.com/iperov/DeepFaceLive
synced 2025-07-11 15:47:02 -07:00
_
This commit is contained in:
parent
69f71ddecd
commit
d40fce6a5a
4 changed files with 99 additions and 75 deletions
|
@ -30,8 +30,9 @@ class TLU(nn.Module):
|
|||
return torch.max(x, self.tau)
|
||||
|
||||
class BlurPool(nn.Module):
|
||||
def __init__(self, filt_size=3, stride=2, pad_off=0):
|
||||
def __init__(self, in_ch, filt_size=3, stride=2, pad_off=0):
|
||||
super().__init__()
|
||||
self.in_ch = in_ch
|
||||
self.filt_size = filt_size
|
||||
self.pad_off = pad_off
|
||||
self.pad_sizes = [int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)), int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2))]
|
||||
|
@ -54,13 +55,12 @@ class BlurPool(nn.Module):
|
|||
|
||||
filt = torch.Tensor(a[:,None]*a[None,:])
|
||||
filt = filt/torch.sum(filt)
|
||||
self.register_buffer('filt', filt[None,None,:,:])
|
||||
self.register_buffer('filt', filt[None,None,:,:].repeat(in_ch,1,1,1) )
|
||||
|
||||
self.pad = nn.ZeroPad2d(self.pad_sizes)
|
||||
|
||||
def forward(self, inp):
|
||||
filt = self.filt.repeat((inp.shape[1],1,1,1))
|
||||
return F.conv2d(self.pad(inp), filt, stride=self.stride, groups=inp.shape[1])
|
||||
return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=self.in_ch)
|
||||
|
||||
|
||||
class ConvBlock(nn.Module):
|
||||
|
@ -99,30 +99,30 @@ class XSegNet(nn.Module):
|
|||
|
||||
self.conv01 = ConvBlock(in_ch, base_ch)
|
||||
self.conv02 = ConvBlock(base_ch, base_ch)
|
||||
self.bp0 = BlurPool (filt_size=4)
|
||||
self.bp0 = BlurPool (base_ch, filt_size=4)
|
||||
|
||||
self.conv11 = ConvBlock(base_ch, base_ch*2)
|
||||
self.conv12 = ConvBlock(base_ch*2, base_ch*2)
|
||||
self.bp1 = BlurPool (filt_size=3)
|
||||
self.bp1 = BlurPool (base_ch*2, filt_size=3)
|
||||
|
||||
self.conv21 = ConvBlock(base_ch*2, base_ch*4)
|
||||
self.conv22 = ConvBlock(base_ch*4, base_ch*4)
|
||||
self.bp2 = BlurPool (filt_size=2)
|
||||
self.bp2 = BlurPool (base_ch*4, filt_size=2)
|
||||
|
||||
self.conv31 = ConvBlock(base_ch*4, base_ch*8)
|
||||
self.conv32 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.conv33 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.bp3 = BlurPool (filt_size=2)
|
||||
self.bp3 = BlurPool (base_ch*8, filt_size=2)
|
||||
|
||||
self.conv41 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.conv42 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.conv43 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.bp4 = BlurPool (filt_size=2)
|
||||
self.bp4 = BlurPool (base_ch*8, filt_size=2)
|
||||
|
||||
self.conv51 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.conv52 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.conv53 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.bp5 = BlurPool (filt_size=2)
|
||||
self.bp5 = BlurPool (base_ch*8, filt_size=2)
|
||||
|
||||
self.dense1 = nn.Linear ( 4*4* base_ch*8, 512)
|
||||
self.dense2 = nn.Linear ( 512, 4*4* base_ch*8)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue