XSeg model has been changed to work better with large amount of various faces, thus you should retrain existing xseg model.

Windows build: Added Generic XSeg model pretrained on various faces. It is most suitable for src faceset because it contains clean faces, but also can be applied on dst footage without complex face obstructions.
This commit is contained in:
iperov 2021-05-12 13:28:48 +04:00
parent 65a703c024
commit 66bb72f164

View file

@ -28,11 +28,12 @@ class XSeg(nn.ModelBase):
x = self.frn(x)
x = self.tlu(x)
return x
self.base_ch = base_ch
self.conv01 = ConvBlock(in_ch, base_ch)
self.conv02 = ConvBlock(base_ch, base_ch)
self.bp0 = nn.BlurPool (filt_size=3)
self.bp0 = nn.BlurPool (filt_size=4)
self.conv11 = ConvBlock(base_ch, base_ch*2)
self.conv12 = ConvBlock(base_ch*2, base_ch*2)
@ -40,19 +41,30 @@ class XSeg(nn.ModelBase):
self.conv21 = ConvBlock(base_ch*2, base_ch*4)
self.conv22 = ConvBlock(base_ch*4, base_ch*4)
self.conv23 = ConvBlock(base_ch*4, base_ch*4)
self.bp2 = nn.BlurPool (filt_size=3)
self.bp2 = nn.BlurPool (filt_size=2)
self.conv31 = ConvBlock(base_ch*4, base_ch*8)
self.conv32 = ConvBlock(base_ch*8, base_ch*8)
self.conv33 = ConvBlock(base_ch*8, base_ch*8)
self.bp3 = nn.BlurPool (filt_size=3)
self.bp3 = nn.BlurPool (filt_size=2)
self.conv41 = ConvBlock(base_ch*8, base_ch*8)
self.conv42 = ConvBlock(base_ch*8, base_ch*8)
self.conv43 = ConvBlock(base_ch*8, base_ch*8)
self.bp4 = nn.BlurPool (filt_size=3)
self.bp4 = nn.BlurPool (filt_size=2)
self.conv51 = ConvBlock(base_ch*8, base_ch*8)
self.conv52 = ConvBlock(base_ch*8, base_ch*8)
self.conv53 = ConvBlock(base_ch*8, base_ch*8)
self.bp5 = nn.BlurPool (filt_size=2)
self.dense1 = nn.Dense ( 4*4* base_ch*8, 512)
self.dense2 = nn.Dense ( 512, 4*4* base_ch*8)
self.up5 = UpConvBlock (base_ch*8, base_ch*4)
self.uconv53 = ConvBlock(base_ch*12, base_ch*8)
self.uconv52 = ConvBlock(base_ch*8, base_ch*8)
self.uconv51 = ConvBlock(base_ch*8, base_ch*8)
self.up4 = UpConvBlock (base_ch*8, base_ch*4)
self.uconv43 = ConvBlock(base_ch*12, base_ch*8)
@ -65,8 +77,7 @@ class XSeg(nn.ModelBase):
self.uconv31 = ConvBlock(base_ch*8, base_ch*8)
self.up2 = UpConvBlock (base_ch*8, base_ch*4)
self.uconv23 = ConvBlock(base_ch*8, base_ch*4)
self.uconv22 = ConvBlock(base_ch*4, base_ch*4)
self.uconv22 = ConvBlock(base_ch*8, base_ch*4)
self.uconv21 = ConvBlock(base_ch*4, base_ch*4)
self.up1 = UpConvBlock (base_ch*4, base_ch*2)
@ -78,8 +89,7 @@ class XSeg(nn.ModelBase):
self.uconv01 = ConvBlock(base_ch, base_ch)
self.out_conv = nn.Conv2D (base_ch, out_ch, kernel_size=3, padding='SAME')
self.conv_center = ConvBlock(base_ch*8, base_ch*8)
def forward(self, inp):
x = inp
@ -92,8 +102,7 @@ class XSeg(nn.ModelBase):
x = self.bp1(x)
x = self.conv21(x)
x = self.conv22(x)
x = x2 = self.conv23(x)
x = x2 = self.conv22(x)
x = self.bp2(x)
x = self.conv31(x)
@ -106,8 +115,21 @@ class XSeg(nn.ModelBase):
x = x4 = self.conv43(x)
x = self.bp4(x)
x = self.conv_center(x)
x = self.conv51(x)
x = self.conv52(x)
x = x5 = self.conv53(x)
x = self.bp5(x)
x = nn.flatten(x)
x = self.dense1(x)
x = self.dense2(x)
x = nn.reshape_4D (x, 4, 4, self.base_ch*8 )
x = self.up5(x)
x = self.uconv53(tf.concat([x,x5],axis=nn.conv2d_ch_axis))
x = self.uconv52(x)
x = self.uconv51(x)
x = self.up4(x)
x = self.uconv43(tf.concat([x,x4],axis=nn.conv2d_ch_axis))
x = self.uconv42(x)
@ -119,8 +141,7 @@ class XSeg(nn.ModelBase):
x = self.uconv31(x)
x = self.up2(x)
x = self.uconv23(tf.concat([x,x2],axis=nn.conv2d_ch_axis))
x = self.uconv22(x)
x = self.uconv22(tf.concat([x,x2],axis=nn.conv2d_ch_axis))
x = self.uconv21(x)
x = self.up1(x)