From 66bb72f1642305fdbd2223b550ef64b0f9ff714a Mon Sep 17 00:00:00 2001 From: iperov Date: Wed, 12 May 2021 13:28:48 +0400 Subject: [PATCH] XSeg model has been changed to work better with large amount of various faces, thus you should retrain existing xseg model. Windows build: Added Generic XSeg model pretrained on various faces. It is most suitable for src faceset because it contains clean faces, but also can be applied on dst footage without complex face obstructions. --- core/leras/models/XSeg.py | 55 +++++++++++++++++++++++++++------------ 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/core/leras/models/XSeg.py b/core/leras/models/XSeg.py index 0ba19a6..e6bde65 100644 --- a/core/leras/models/XSeg.py +++ b/core/leras/models/XSeg.py @@ -28,11 +28,12 @@ class XSeg(nn.ModelBase): x = self.frn(x) x = self.tlu(x) return x + + self.base_ch = base_ch self.conv01 = ConvBlock(in_ch, base_ch) self.conv02 = ConvBlock(base_ch, base_ch) - self.bp0 = nn.BlurPool (filt_size=3) - + self.bp0 = nn.BlurPool (filt_size=4) self.conv11 = ConvBlock(base_ch, base_ch*2) self.conv12 = ConvBlock(base_ch*2, base_ch*2) @@ -40,19 +41,30 @@ class XSeg(nn.ModelBase): self.conv21 = ConvBlock(base_ch*2, base_ch*4) self.conv22 = ConvBlock(base_ch*4, base_ch*4) - self.conv23 = ConvBlock(base_ch*4, base_ch*4) - self.bp2 = nn.BlurPool (filt_size=3) - + self.bp2 = nn.BlurPool (filt_size=2) self.conv31 = ConvBlock(base_ch*4, base_ch*8) self.conv32 = ConvBlock(base_ch*8, base_ch*8) self.conv33 = ConvBlock(base_ch*8, base_ch*8) - self.bp3 = nn.BlurPool (filt_size=3) + self.bp3 = nn.BlurPool (filt_size=2) self.conv41 = ConvBlock(base_ch*8, base_ch*8) self.conv42 = ConvBlock(base_ch*8, base_ch*8) self.conv43 = ConvBlock(base_ch*8, base_ch*8) - self.bp4 = nn.BlurPool (filt_size=3) + self.bp4 = nn.BlurPool (filt_size=2) + + self.conv51 = ConvBlock(base_ch*8, base_ch*8) + self.conv52 = ConvBlock(base_ch*8, base_ch*8) + self.conv53 = ConvBlock(base_ch*8, base_ch*8) + self.bp5 = nn.BlurPool (filt_size=2) + + self.dense1 = nn.Dense ( 4*4* base_ch*8, 512) + self.dense2 = nn.Dense ( 512, 4*4* base_ch*8) + + self.up5 = UpConvBlock (base_ch*8, base_ch*4) + self.uconv53 = ConvBlock(base_ch*12, base_ch*8) + self.uconv52 = ConvBlock(base_ch*8, base_ch*8) + self.uconv51 = ConvBlock(base_ch*8, base_ch*8) self.up4 = UpConvBlock (base_ch*8, base_ch*4) self.uconv43 = ConvBlock(base_ch*12, base_ch*8) @@ -65,8 +77,7 @@ class XSeg(nn.ModelBase): self.uconv31 = ConvBlock(base_ch*8, base_ch*8) self.up2 = UpConvBlock (base_ch*8, base_ch*4) - self.uconv23 = ConvBlock(base_ch*8, base_ch*4) - self.uconv22 = ConvBlock(base_ch*4, base_ch*4) + self.uconv22 = ConvBlock(base_ch*8, base_ch*4) self.uconv21 = ConvBlock(base_ch*4, base_ch*4) self.up1 = UpConvBlock (base_ch*4, base_ch*2) @@ -78,8 +89,7 @@ class XSeg(nn.ModelBase): self.uconv01 = ConvBlock(base_ch, base_ch) self.out_conv = nn.Conv2D (base_ch, out_ch, kernel_size=3, padding='SAME') - self.conv_center = ConvBlock(base_ch*8, base_ch*8) - + def forward(self, inp): x = inp @@ -92,8 +102,7 @@ class XSeg(nn.ModelBase): x = self.bp1(x) x = self.conv21(x) - x = self.conv22(x) - x = x2 = self.conv23(x) + x = x2 = self.conv22(x) x = self.bp2(x) x = self.conv31(x) @@ -106,8 +115,21 @@ class XSeg(nn.ModelBase): x = x4 = self.conv43(x) x = self.bp4(x) - x = self.conv_center(x) - + x = self.conv51(x) + x = self.conv52(x) + x = x5 = self.conv53(x) + x = self.bp5(x) + + x = nn.flatten(x) + x = self.dense1(x) + x = self.dense2(x) + x = nn.reshape_4D (x, 4, 4, self.base_ch*8 ) + + x = self.up5(x) + x = self.uconv53(tf.concat([x,x5],axis=nn.conv2d_ch_axis)) + x = self.uconv52(x) + x = self.uconv51(x) + x = self.up4(x) x = self.uconv43(tf.concat([x,x4],axis=nn.conv2d_ch_axis)) x = self.uconv42(x) @@ -119,8 +141,7 @@ class XSeg(nn.ModelBase): x = self.uconv31(x) x = self.up2(x) - x = self.uconv23(tf.concat([x,x2],axis=nn.conv2d_ch_axis)) - x = self.uconv22(x) + x = self.uconv22(tf.concat([x,x2],axis=nn.conv2d_ch_axis)) x = self.uconv21(x) x = self.up1(x)