Fixed XSeg pretraining

This commit is contained in:
Cioscos 2021-10-02 20:22:08 +02:00
commit 224af63704
2 changed files with 16 additions and 4 deletions

View file

@ -90,7 +90,7 @@ class XSeg(nn.ModelBase):
self.out_conv = nn.Conv2D (base_ch, out_ch, kernel_size=3, padding='SAME')
def forward(self, inp):
def forward(self, inp, pretrain=False):
x = inp
x = self.conv01(x)
@ -126,29 +126,41 @@ class XSeg(nn.ModelBase):
x = nn.reshape_4D (x, 4, 4, self.base_ch*8 )
x = self.up5(x)
if pretrain:
x5 = tf.zeros_like(x5)
x = self.uconv53(tf.concat([x,x5],axis=nn.conv2d_ch_axis))
x = self.uconv52(x)
x = self.uconv51(x)
x = self.up4(x)
if pretrain:
x4 = tf.zeros_like(x4)
x = self.uconv43(tf.concat([x,x4],axis=nn.conv2d_ch_axis))
x = self.uconv42(x)
x = self.uconv41(x)
x = self.up3(x)
if pretrain:
x3 = tf.zeros_like(x3)
x = self.uconv33(tf.concat([x,x3],axis=nn.conv2d_ch_axis))
x = self.uconv32(x)
x = self.uconv31(x)
x = self.up2(x)
if pretrain:
x2 = tf.zeros_like(x2)
x = self.uconv22(tf.concat([x,x2],axis=nn.conv2d_ch_axis))
x = self.uconv21(x)
x = self.up1(x)
if pretrain:
x1 = tf.zeros_like(x1)
x = self.uconv12(tf.concat([x,x1],axis=nn.conv2d_ch_axis))
x = self.uconv11(x)
x = self.up0(x)
if pretrain:
x0 = tf.zeros_like(x0)
x = self.uconv02(tf.concat([x,x0],axis=nn.conv2d_ch_axis))
x = self.uconv01(x)

View file

@ -81,8 +81,8 @@ class XSegNet(object):
def get_resolution(self):
return self.resolution
def flow(self, x):
return self.model(x)
def flow(self, x, pretrain=False):
return self.model(x, pretrain=pretrain)
def get_weights(self):
return self.model_weights