mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-05 20:42:11 -07:00
XSeg model has been changed to work better with large amount of various faces, thus you should retrain existing xseg model.
Windows build: Added Generic XSeg model pretrained on various faces. It is most suitable for src faceset because it contains clean faces, but also can be applied on dst footage without complex face obstructions.
This commit is contained in:
parent
65a703c024
commit
66bb72f164
1 changed files with 38 additions and 17 deletions
|
@ -28,11 +28,12 @@ class XSeg(nn.ModelBase):
|
|||
x = self.frn(x)
|
||||
x = self.tlu(x)
|
||||
return x
|
||||
|
||||
self.base_ch = base_ch
|
||||
|
||||
self.conv01 = ConvBlock(in_ch, base_ch)
|
||||
self.conv02 = ConvBlock(base_ch, base_ch)
|
||||
self.bp0 = nn.BlurPool (filt_size=3)
|
||||
|
||||
self.bp0 = nn.BlurPool (filt_size=4)
|
||||
|
||||
self.conv11 = ConvBlock(base_ch, base_ch*2)
|
||||
self.conv12 = ConvBlock(base_ch*2, base_ch*2)
|
||||
|
@ -40,19 +41,30 @@ class XSeg(nn.ModelBase):
|
|||
|
||||
self.conv21 = ConvBlock(base_ch*2, base_ch*4)
|
||||
self.conv22 = ConvBlock(base_ch*4, base_ch*4)
|
||||
self.conv23 = ConvBlock(base_ch*4, base_ch*4)
|
||||
self.bp2 = nn.BlurPool (filt_size=3)
|
||||
|
||||
self.bp2 = nn.BlurPool (filt_size=2)
|
||||
|
||||
self.conv31 = ConvBlock(base_ch*4, base_ch*8)
|
||||
self.conv32 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.conv33 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.bp3 = nn.BlurPool (filt_size=3)
|
||||
self.bp3 = nn.BlurPool (filt_size=2)
|
||||
|
||||
self.conv41 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.conv42 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.conv43 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.bp4 = nn.BlurPool (filt_size=3)
|
||||
self.bp4 = nn.BlurPool (filt_size=2)
|
||||
|
||||
self.conv51 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.conv52 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.conv53 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.bp5 = nn.BlurPool (filt_size=2)
|
||||
|
||||
self.dense1 = nn.Dense ( 4*4* base_ch*8, 512)
|
||||
self.dense2 = nn.Dense ( 512, 4*4* base_ch*8)
|
||||
|
||||
self.up5 = UpConvBlock (base_ch*8, base_ch*4)
|
||||
self.uconv53 = ConvBlock(base_ch*12, base_ch*8)
|
||||
self.uconv52 = ConvBlock(base_ch*8, base_ch*8)
|
||||
self.uconv51 = ConvBlock(base_ch*8, base_ch*8)
|
||||
|
||||
self.up4 = UpConvBlock (base_ch*8, base_ch*4)
|
||||
self.uconv43 = ConvBlock(base_ch*12, base_ch*8)
|
||||
|
@ -65,8 +77,7 @@ class XSeg(nn.ModelBase):
|
|||
self.uconv31 = ConvBlock(base_ch*8, base_ch*8)
|
||||
|
||||
self.up2 = UpConvBlock (base_ch*8, base_ch*4)
|
||||
self.uconv23 = ConvBlock(base_ch*8, base_ch*4)
|
||||
self.uconv22 = ConvBlock(base_ch*4, base_ch*4)
|
||||
self.uconv22 = ConvBlock(base_ch*8, base_ch*4)
|
||||
self.uconv21 = ConvBlock(base_ch*4, base_ch*4)
|
||||
|
||||
self.up1 = UpConvBlock (base_ch*4, base_ch*2)
|
||||
|
@ -78,8 +89,7 @@ class XSeg(nn.ModelBase):
|
|||
self.uconv01 = ConvBlock(base_ch, base_ch)
|
||||
self.out_conv = nn.Conv2D (base_ch, out_ch, kernel_size=3, padding='SAME')
|
||||
|
||||
self.conv_center = ConvBlock(base_ch*8, base_ch*8)
|
||||
|
||||
|
||||
def forward(self, inp):
|
||||
x = inp
|
||||
|
||||
|
@ -92,8 +102,7 @@ class XSeg(nn.ModelBase):
|
|||
x = self.bp1(x)
|
||||
|
||||
x = self.conv21(x)
|
||||
x = self.conv22(x)
|
||||
x = x2 = self.conv23(x)
|
||||
x = x2 = self.conv22(x)
|
||||
x = self.bp2(x)
|
||||
|
||||
x = self.conv31(x)
|
||||
|
@ -106,8 +115,21 @@ class XSeg(nn.ModelBase):
|
|||
x = x4 = self.conv43(x)
|
||||
x = self.bp4(x)
|
||||
|
||||
x = self.conv_center(x)
|
||||
|
||||
x = self.conv51(x)
|
||||
x = self.conv52(x)
|
||||
x = x5 = self.conv53(x)
|
||||
x = self.bp5(x)
|
||||
|
||||
x = nn.flatten(x)
|
||||
x = self.dense1(x)
|
||||
x = self.dense2(x)
|
||||
x = nn.reshape_4D (x, 4, 4, self.base_ch*8 )
|
||||
|
||||
x = self.up5(x)
|
||||
x = self.uconv53(tf.concat([x,x5],axis=nn.conv2d_ch_axis))
|
||||
x = self.uconv52(x)
|
||||
x = self.uconv51(x)
|
||||
|
||||
x = self.up4(x)
|
||||
x = self.uconv43(tf.concat([x,x4],axis=nn.conv2d_ch_axis))
|
||||
x = self.uconv42(x)
|
||||
|
@ -119,8 +141,7 @@ class XSeg(nn.ModelBase):
|
|||
x = self.uconv31(x)
|
||||
|
||||
x = self.up2(x)
|
||||
x = self.uconv23(tf.concat([x,x2],axis=nn.conv2d_ch_axis))
|
||||
x = self.uconv22(x)
|
||||
x = self.uconv22(tf.concat([x,x2],axis=nn.conv2d_ch_axis))
|
||||
x = self.uconv21(x)
|
||||
|
||||
x = self.up1(x)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue