diff --git a/core/leras/models/Ternaus.py b/core/leras/models/Ternaus.py new file mode 100644 index 0000000..ad5ffc3 --- /dev/null +++ b/core/leras/models/Ternaus.py @@ -0,0 +1,92 @@ +""" +using https://github.com/ternaus/TernausNet +TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentation +""" + +from core.leras import nn +tf = nn.tf + +class Ternaus(nn.ModelBase): + def on_build(self, in_ch, base_ch): + + self.features_0 = nn.Conv2D (in_ch, base_ch, kernel_size=3, padding='SAME') + self.features_3 = nn.Conv2D (base_ch, base_ch*2, kernel_size=3, padding='SAME') + self.features_6 = nn.Conv2D (base_ch*2, base_ch*4, kernel_size=3, padding='SAME') + self.features_8 = nn.Conv2D (base_ch*4, base_ch*4, kernel_size=3, padding='SAME') + self.features_11 = nn.Conv2D (base_ch*4, base_ch*8, kernel_size=3, padding='SAME') + self.features_13 = nn.Conv2D (base_ch*8, base_ch*8, kernel_size=3, padding='SAME') + self.features_16 = nn.Conv2D (base_ch*8, base_ch*8, kernel_size=3, padding='SAME') + self.features_18 = nn.Conv2D (base_ch*8, base_ch*8, kernel_size=3, padding='SAME') + + self.blurpool_0 = nn.BlurPool (filt_size=3) + self.blurpool_3 = nn.BlurPool (filt_size=3) + self.blurpool_8 = nn.BlurPool (filt_size=3) + self.blurpool_13 = nn.BlurPool (filt_size=3) + self.blurpool_18 = nn.BlurPool (filt_size=3) + + self.conv_center = nn.Conv2D (base_ch*8, base_ch*8, kernel_size=3, padding='SAME') + + self.conv1_up = nn.Conv2DTranspose (base_ch*8, base_ch*4, kernel_size=3, padding='SAME') + self.conv1 = nn.Conv2D (base_ch*12, base_ch*8, kernel_size=3, padding='SAME') + + self.conv2_up = nn.Conv2DTranspose (base_ch*8, base_ch*4, kernel_size=3, padding='SAME') + self.conv2 = nn.Conv2D (base_ch*12, base_ch*8, kernel_size=3, padding='SAME') + + self.conv3_up = nn.Conv2DTranspose (base_ch*8, base_ch*2, kernel_size=3, padding='SAME') + self.conv3 = nn.Conv2D (base_ch*6, base_ch*4, kernel_size=3, padding='SAME') + + self.conv4_up = nn.Conv2DTranspose (base_ch*4, base_ch, kernel_size=3, padding='SAME') + self.conv4 = nn.Conv2D (base_ch*3, base_ch*2, kernel_size=3, padding='SAME') + + self.conv5_up = nn.Conv2DTranspose (base_ch*2, base_ch//2, kernel_size=3, padding='SAME') + self.conv5 = nn.Conv2D (base_ch//2+base_ch, base_ch, kernel_size=3, padding='SAME') + + self.out_conv = nn.Conv2D (base_ch, 1, kernel_size=3, padding='SAME') + + def forward(self, inp): + x, = inp + + x = x0 = tf.nn.relu(self.features_0(x)) + x = self.blurpool_0(x) + + x = x1 = tf.nn.relu(self.features_3(x)) + x = self.blurpool_3(x) + + x = tf.nn.relu(self.features_6(x)) + x = x2 = tf.nn.relu(self.features_8(x)) + x = self.blurpool_8(x) + + x = tf.nn.relu(self.features_11(x)) + x = x3 = tf.nn.relu(self.features_13(x)) + x = self.blurpool_13(x) + + x = tf.nn.relu(self.features_16(x)) + x = x4 = tf.nn.relu(self.features_18(x)) + x = self.blurpool_18(x) + + x = self.conv_center(x) + + x = tf.nn.relu(self.conv1_up(x)) + x = tf.concat( [x,x4], nn.conv2d_ch_axis) + x = tf.nn.relu(self.conv1(x)) + + x = tf.nn.relu(self.conv2_up(x)) + x = tf.concat( [x,x3], nn.conv2d_ch_axis) + x = tf.nn.relu(self.conv2(x)) + + x = tf.nn.relu(self.conv3_up(x)) + x = tf.concat( [x,x2], nn.conv2d_ch_axis) + x = tf.nn.relu(self.conv3(x)) + + x = tf.nn.relu(self.conv4_up(x)) + x = tf.concat( [x,x1], nn.conv2d_ch_axis) + x = tf.nn.relu(self.conv4(x)) + + x = tf.nn.relu(self.conv5_up(x)) + x = tf.concat( [x,x0], nn.conv2d_ch_axis) + x = tf.nn.relu(self.conv5(x)) + + logits = self.out_conv(x) + return logits, tf.nn.sigmoid(logits) + +nn.Ternaus = Ternaus \ No newline at end of file