mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-19 13:09:56 -07:00
Fixed XSeg pretraining
This commit is contained in:
parent
a3f9410a71
commit
224af63704
2 changed files with 16 additions and 4 deletions
|
@ -88,9 +88,9 @@ class XSeg(nn.ModelBase):
|
|||
self.uconv02 = ConvBlock(base_ch*2, base_ch)
|
||||
self.uconv01 = ConvBlock(base_ch, base_ch)
|
||||
self.out_conv = nn.Conv2D (base_ch, out_ch, kernel_size=3, padding='SAME')
|
||||
|
||||
|
||||
|
||||
def forward(self, inp):
|
||||
def forward(self, inp, pretrain=False):
|
||||
x = inp
|
||||
|
||||
x = self.conv01(x)
|
||||
|
@ -126,29 +126,41 @@ class XSeg(nn.ModelBase):
|
|||
x = nn.reshape_4D (x, 4, 4, self.base_ch*8 )
|
||||
|
||||
x = self.up5(x)
|
||||
if pretrain:
|
||||
x5 = tf.zeros_like(x5)
|
||||
x = self.uconv53(tf.concat([x,x5],axis=nn.conv2d_ch_axis))
|
||||
x = self.uconv52(x)
|
||||
x = self.uconv51(x)
|
||||
|
||||
x = self.up4(x)
|
||||
if pretrain:
|
||||
x4 = tf.zeros_like(x4)
|
||||
x = self.uconv43(tf.concat([x,x4],axis=nn.conv2d_ch_axis))
|
||||
x = self.uconv42(x)
|
||||
x = self.uconv41(x)
|
||||
|
||||
x = self.up3(x)
|
||||
if pretrain:
|
||||
x3 = tf.zeros_like(x3)
|
||||
x = self.uconv33(tf.concat([x,x3],axis=nn.conv2d_ch_axis))
|
||||
x = self.uconv32(x)
|
||||
x = self.uconv31(x)
|
||||
|
||||
x = self.up2(x)
|
||||
if pretrain:
|
||||
x2 = tf.zeros_like(x2)
|
||||
x = self.uconv22(tf.concat([x,x2],axis=nn.conv2d_ch_axis))
|
||||
x = self.uconv21(x)
|
||||
|
||||
x = self.up1(x)
|
||||
if pretrain:
|
||||
x1 = tf.zeros_like(x1)
|
||||
x = self.uconv12(tf.concat([x,x1],axis=nn.conv2d_ch_axis))
|
||||
x = self.uconv11(x)
|
||||
|
||||
x = self.up0(x)
|
||||
if pretrain:
|
||||
x0 = tf.zeros_like(x0)
|
||||
x = self.uconv02(tf.concat([x,x0],axis=nn.conv2d_ch_axis))
|
||||
x = self.uconv01(x)
|
||||
|
||||
|
|
|
@ -81,8 +81,8 @@ class XSegNet(object):
|
|||
def get_resolution(self):
|
||||
return self.resolution
|
||||
|
||||
def flow(self, x):
|
||||
return self.model(x)
|
||||
def flow(self, x, pretrain=False):
|
||||
return self.model(x, pretrain=pretrain)
|
||||
|
||||
def get_weights(self):
|
||||
return self.model_weights
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue