diff --git a/core/leras/models/XSeg.py b/core/leras/models/XSeg.py index e6bde65..f59eb8c 100644 --- a/core/leras/models/XSeg.py +++ b/core/leras/models/XSeg.py @@ -88,9 +88,9 @@ class XSeg(nn.ModelBase): self.uconv02 = ConvBlock(base_ch*2, base_ch) self.uconv01 = ConvBlock(base_ch, base_ch) self.out_conv = nn.Conv2D (base_ch, out_ch, kernel_size=3, padding='SAME') + - - def forward(self, inp): + def forward(self, inp, pretrain=False): x = inp x = self.conv01(x) @@ -126,29 +126,41 @@ class XSeg(nn.ModelBase): x = nn.reshape_4D (x, 4, 4, self.base_ch*8 ) x = self.up5(x) + if pretrain: + x5 = tf.zeros_like(x5) x = self.uconv53(tf.concat([x,x5],axis=nn.conv2d_ch_axis)) x = self.uconv52(x) x = self.uconv51(x) x = self.up4(x) + if pretrain: + x4 = tf.zeros_like(x4) x = self.uconv43(tf.concat([x,x4],axis=nn.conv2d_ch_axis)) x = self.uconv42(x) x = self.uconv41(x) x = self.up3(x) + if pretrain: + x3 = tf.zeros_like(x3) x = self.uconv33(tf.concat([x,x3],axis=nn.conv2d_ch_axis)) x = self.uconv32(x) x = self.uconv31(x) x = self.up2(x) + if pretrain: + x2 = tf.zeros_like(x2) x = self.uconv22(tf.concat([x,x2],axis=nn.conv2d_ch_axis)) x = self.uconv21(x) x = self.up1(x) + if pretrain: + x1 = tf.zeros_like(x1) x = self.uconv12(tf.concat([x,x1],axis=nn.conv2d_ch_axis)) x = self.uconv11(x) x = self.up0(x) + if pretrain: + x0 = tf.zeros_like(x0) x = self.uconv02(tf.concat([x,x0],axis=nn.conv2d_ch_axis)) x = self.uconv01(x) diff --git a/facelib/XSegNet.py b/facelib/XSegNet.py index 5621a65..ff2bd08 100644 --- a/facelib/XSegNet.py +++ b/facelib/XSegNet.py @@ -81,8 +81,8 @@ class XSegNet(object): def get_resolution(self): return self.resolution - def flow(self, x): - return self.model(x) + def flow(self, x, pretrain=False): + return self.model(x, pretrain=pretrain) def get_weights(self): return self.model_weights