"""Inverted-residual image classifier and XSeg-style segmentation network.

Reconstructed from a whitespace-flattened patch that added three files:

* ``xlib/torch/model/MobileNet.py`` -- ``MobileNet`` and its building blocks
* ``xlib/torch/model/XsegNet.py``   -- ``XSegNet`` and its building blocks
* ``xlib/torch/model/__init__.py``  -- re-exports ``XSegNet`` and ``MobileNet``
"""

from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


# ---------------------------------------------------------------------------
# xlib/torch/model/MobileNet.py
# ---------------------------------------------------------------------------

def _make_divisible(v: float, divisor: int, min_value=None) -> int:
    """Round ``v`` to the nearest multiple of ``divisor``.

    The result is never smaller than ``min_value`` (``divisor`` when omitted)
    and is bumped up by one ``divisor`` step if rounding would otherwise lose
    more than 10% of ``v``.
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class SqueezeExcitation(nn.Module):
    """Channel gating: global-average-pool -> 1x1 conv -> act -> 1x1 conv -> gate.

    Args:
        in_ch: channels of the input (and of the produced gate).
        squeeze_channels: channels of the bottleneck between the two convs.
        activation: module class applied after the first conv.
        scale_activation: module class producing the final per-channel gate.
    """

    def __init__(self, in_ch: int, squeeze_channels: int, activation=nn.ReLU, scale_activation=nn.Sigmoid):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(in_ch, squeeze_channels, 1)
        self.fc2 = nn.Conv2d(squeeze_channels, in_ch, 1)
        self.activation = activation()
        self.scale_activation = scale_activation()

    def forward(self, input):
        """Return ``input`` multiplied by its per-channel gate."""
        scale = self.avgpool(input)
        scale = self.activation(self.fc1(scale))
        scale = self.scale_activation(self.fc2(scale))
        return scale * input


class ConvNormActivation(nn.Sequential):
    """``Conv2d`` optionally followed by a norm layer and an activation.

    Args:
        in_ch: input channels.
        out_ch: output channels (also stored as ``self.out_ch``).
        kernel_size: square kernel size.
        stride: convolution stride.
        padding: explicit padding; ``None`` selects ``(kernel_size - 1) // 2``.
        groups: convolution groups.
        norm_layer: callable ``norm_layer(out_ch)``; ``None`` skips the norm
            layer and gives the convolution a bias instead.
        activation_layer: zero-argument callable; ``None`` skips the activation.
    """

    def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3, stride: int = 1, padding=None,
                 groups: int = 1, norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU) -> None:
        if padding is None:
            padding = (kernel_size - 1) // 2

        # The conv bias is redundant when a norm layer (with its own shift) follows.
        layers = [nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding,
                            groups=groups, bias=norm_layer is None)]
        if norm_layer is not None:
            layers.append(norm_layer(out_ch))
        if activation_layer is not None:
            layers.append(activation_layer())

        super().__init__(*layers)
        self.out_ch = out_ch


class InvertedResidual(nn.Module):
    """Expand (1x1) -> depthwise (kxk) -> optional SE -> project (1x1) block.

    ``mid_ch`` and ``out_ch`` are scaled by ``width_mult`` and rounded to a
    multiple of 8; ``in_ch`` is used as given. The input is added to the
    output only when ``stride == 1`` and ``in_ch`` equals the scaled ``out_ch``.

    Args:
        in_ch: actual number of input channels.
        mid_ch: nominal expanded channels (before ``width_mult``).
        out_ch: nominal output channels (before ``width_mult``); the scaled
            value is exposed as ``self.out_ch``.
        kernel: depthwise kernel size.
        stride: depthwise stride.
        use_se: insert a ``SqueezeExcitation`` (Hardsigmoid gate) after the
            depthwise conv.
        hs_act: use ``nn.Hardswish`` instead of ``nn.ReLU``.
        width_mult: channel multiplier.
        norm_layer: forwarded to ``ConvNormActivation``; the default ``None``
            therefore means *no* normalisation layer.
    """

    def __init__(self, in_ch: int, mid_ch: int, out_ch: int, kernel: int, stride: int, use_se: bool,
                 hs_act: bool, width_mult: float = 1.0, norm_layer=None):
        super().__init__()

        mid_ch = _make_divisible(mid_ch * width_mult, 8)
        out_ch = _make_divisible(out_ch * width_mult, 8)
        self._is_res_connect = stride == 1 and in_ch == out_ch
        activation_layer = nn.Hardswish if hs_act else nn.ReLU

        layers = []

        # Expansion is skipped when it would be an identity-width 1x1 conv.
        if mid_ch != in_ch:
            layers.append(ConvNormActivation(in_ch, mid_ch, kernel_size=1,
                                             norm_layer=norm_layer, activation_layer=activation_layer))

        # Depthwise convolution: one group per channel.
        layers.append(ConvNormActivation(mid_ch, mid_ch, kernel_size=kernel, stride=stride, groups=mid_ch,
                                         norm_layer=norm_layer, activation_layer=activation_layer))

        if use_se:
            layers.append(SqueezeExcitation(mid_ch, _make_divisible(mid_ch // 4, 8),
                                            scale_activation=nn.Hardsigmoid))

        # Linear projection: no activation after the last 1x1 conv.
        layers.append(ConvNormActivation(mid_ch, out_ch, kernel_size=1,
                                         norm_layer=norm_layer, activation_layer=None))

        self.block = nn.Sequential(*layers)
        self.out_ch = out_ch

    def forward(self, input):
        """Apply the block, adding the input back when the residual is enabled."""
        result = self.block(input)
        if self._is_res_connect:
            result = result + input
        return result


class MobileNet(nn.Module):
    """Stack of inverted-residual blocks followed by a two-layer classifier head.

    NOTE(review): the stage table looks like the MobileNetV3-Large layout --
    confirm against the reference implementation before relying on that.

    Args:
        in_ch: channels of the input image tensor.
        out_ch: size of the output vector.
        width_mult: channel multiplier applied to every stage.
    """

    def __init__(self, in_ch, out_ch, width_mult=1.0):
        super().__init__()
        norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)
        # Every residual stage shares the same norm layer and width multiplier.
        stage = partial(InvertedResidual, norm_layer=norm_layer, width_mult=width_mult)

        self.c0 = c0 = ConvNormActivation(in_ch, _make_divisible(16 * width_mult, 8), kernel_size=3, stride=2,
                                          norm_layer=norm_layer, activation_layer=nn.Hardswish)
        #                       in_ch       mid  out  k  s
        self.c1 = c1 = stage(c0.out_ch, 16, 16, 3, 1, use_se=False, hs_act=False)
        self.c2 = c2 = stage(c1.out_ch, 64, 24, 3, 2, use_se=False, hs_act=False)
        self.c3 = c3 = stage(c2.out_ch, 72, 24, 3, 1, use_se=False, hs_act=False)
        self.c4 = c4 = stage(c3.out_ch, 72, 40, 5, 2, use_se=True, hs_act=False)
        self.c5 = c5 = stage(c4.out_ch, 120, 40, 5, 1, use_se=True, hs_act=False)
        self.c6 = c6 = stage(c5.out_ch, 120, 40, 5, 1, use_se=True, hs_act=False)
        self.c7 = c7 = stage(c6.out_ch, 240, 80, 3, 2, use_se=False, hs_act=True)
        self.c8 = c8 = stage(c7.out_ch, 200, 80, 3, 1, use_se=False, hs_act=True)
        self.c9 = c9 = stage(c8.out_ch, 184, 80, 3, 1, use_se=False, hs_act=True)
        self.c10 = c10 = stage(c9.out_ch, 184, 80, 3, 1, use_se=False, hs_act=True)
        self.c11 = c11 = stage(c10.out_ch, 480, 112, 3, 1, use_se=True, hs_act=True)
        self.c12 = c12 = stage(c11.out_ch, 672, 112, 3, 1, use_se=True, hs_act=True)
        self.c13 = c13 = stage(c12.out_ch, 672, 160, 5, 2, use_se=True, hs_act=True)
        self.c14 = c14 = stage(c13.out_ch, 960, 160, 5, 1, use_se=True, hs_act=True)
        self.c15 = c15 = stage(c14.out_ch, 960, 160, 5, 1, use_se=True, hs_act=True)
        self.c16 = c16 = ConvNormActivation(c15.out_ch, _make_divisible(6 * 160 * width_mult, 8), kernel_size=1,
                                            norm_layer=norm_layer, activation_layer=nn.Hardswish)

        self.fc1 = nn.Linear(c16.out_ch, _make_divisible(c16.out_ch * 1.33, 8))
        self.fc1_act = nn.Hardswish()
        self.fc2 = nn.Linear(self.fc1.out_features, out_ch)

        # Explicit weight initialisation for every conv / norm / linear layer.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def forward(self, inp):
        """Run all stages, average over the two trailing dims, apply the head."""
        x = inp

        x = self.c0(x)
        x = self.c1(x)
        x = self.c2(x)
        x = self.c3(x)
        x = self.c4(x)
        x = self.c5(x)
        x = self.c6(x)
        x = self.c7(x)
        x = self.c8(x)
        x = self.c9(x)
        x = self.c10(x)
        x = self.c11(x)
        x = self.c12(x)
        x = self.c13(x)
        x = self.c14(x)
        x = self.c15(x)
        x = self.c16(x)

        # Global average pooling over the last two (spatial) dimensions.
        x = self.fc1(x.mean((-2, -1)))
        x = self.fc1_act(x)
        x = self.fc2(x)
        return x


# ---------------------------------------------------------------------------
# xlib/torch/model/XsegNet.py
# ---------------------------------------------------------------------------

class FRNorm2D(nn.Module):
    """Filter-response normalisation with learnable scale, bias and epsilon.

    Each channel is divided by the root of its mean squared activation over
    dims 2 and 3, then scaled and shifted per channel.
    """

    def __init__(self, in_ch):
        super().__init__()
        self.in_ch = in_ch
        self.weight = nn.Parameter(torch.ones(1, in_ch, 1, 1), requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(1, in_ch, 1, 1), requires_grad=True)
        self.eps = nn.Parameter(torch.full((1,), 1e-6), requires_grad=True)

    def forward(self, x):
        nu2 = x.pow(2).mean(dim=[2, 3], keepdim=True)
        # abs() keeps the learnable epsilon non-negative inside the rsqrt.
        x = x * torch.rsqrt(nu2 + self.eps.abs())
        return self.weight * x + self.bias


class TLU(nn.Module):
    """Thresholded activation: ``max(x, tau)`` with a learnable per-channel ``tau``."""

    def __init__(self, in_ch):
        super().__init__()
        self.in_ch = in_ch
        self.tau = nn.Parameter(torch.zeros(1, in_ch, 1, 1), requires_grad=True)

    def forward(self, x):
        return torch.max(x, self.tau)


class BlurPool(nn.Module):
    """Zero-pad, blur with a fixed binomial kernel, and subsample by ``stride``.

    Args:
        filt_size: side length of the square blur kernel; any integer >= 1.
            The 1-D kernel is row ``filt_size - 1`` of Pascal's triangle
            (e.g. ``[1, 2, 1]`` for 3), normalised so the 2-D kernel sums to 1.
        stride: subsampling stride.
        pad_off: extra padding added on every side.

    Raises:
        ValueError: if ``filt_size`` is not an integer >= 1.
    """

    def __init__(self, filt_size=3, stride=2, pad_off=0):
        super().__init__()
        if int(filt_size) != filt_size or filt_size < 1:
            raise ValueError(f'filt_size must be an integer >= 1, got {filt_size!r}')

        self.filt_size = filt_size
        self.pad_off = pad_off
        # (left, right, top, bottom); the extra pixel of an even kernel goes right/bottom.
        pad_lo = (int(filt_size) - 1) // 2
        pad_hi = int(filt_size) // 2
        self.pad_sizes = [pad_size + pad_off for pad_size in (pad_lo, pad_hi, pad_lo, pad_hi)]
        self.stride = stride
        self.off = int((self.stride - 1) / 2.)

        # Binomial coefficients via repeated convolution with [1, 1]; this
        # replaces a hard-coded table that only covered sizes 2..7.
        a = np.array([1.])
        for _ in range(int(filt_size) - 1):
            a = np.convolve(a, [1., 1.])

        filt = torch.Tensor(np.outer(a, a))
        filt = filt / torch.sum(filt)
        self.register_buffer('filt', filt[None, None, :, :])

        self.pad = nn.ZeroPad2d(self.pad_sizes)

    def forward(self, inp):
        """Blur every channel independently (depthwise conv) and subsample."""
        channels = inp.shape[1]
        filt = self.filt.repeat((channels, 1, 1, 1))
        return F.conv2d(self.pad(inp), filt, stride=self.stride, groups=channels)


class ConvBlock(nn.Module):
    """3x3 same-padding convolution followed by ``FRNorm2D`` and ``TLU``."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.frn = FRNorm2D(out_ch)
        self.tlu = TLU(out_ch)

    def forward(self, x):
        return self.tlu(self.frn(self.conv(x)))


class UpConvBlock(nn.Module):
    """Stride-2 transposed 3x3 convolution (doubles H and W), then FRN and TLU."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.frn = FRNorm2D(out_ch)
        self.tlu = TLU(out_ch)

    def forward(self, x):
        return self.tlu(self.frn(self.conv(x)))


class XSegNet(nn.Module):
    """Encoder / dense bottleneck / decoder network with skip connections.

    The encoder halves the resolution six times and the bottleneck is a pair
    of linear layers sized for a ``4 x 4`` feature map, while the decoder
    doubles the resolution six times and concatenates encoder features at
    every level. Together these constraints require a ``256 x 256`` input.

    Args:
        in_ch: channels of the input image tensor.
        out_ch: channels of the output map (same spatial size as the input).
        base_ch: channel width of the first level; deeper levels use
            2x, 4x and 8x this value.
    """

    def __init__(self, in_ch, out_ch, base_ch=32):
        super().__init__()
        self.base_ch = base_ch

        # ---- encoder ----
        self.conv01 = ConvBlock(in_ch, base_ch)
        self.conv02 = ConvBlock(base_ch, base_ch)
        self.bp0 = BlurPool(filt_size=4)

        self.conv11 = ConvBlock(base_ch, base_ch * 2)
        self.conv12 = ConvBlock(base_ch * 2, base_ch * 2)
        self.bp1 = BlurPool(filt_size=3)

        self.conv21 = ConvBlock(base_ch * 2, base_ch * 4)
        self.conv22 = ConvBlock(base_ch * 4, base_ch * 4)
        self.bp2 = BlurPool(filt_size=2)

        self.conv31 = ConvBlock(base_ch * 4, base_ch * 8)
        self.conv32 = ConvBlock(base_ch * 8, base_ch * 8)
        self.conv33 = ConvBlock(base_ch * 8, base_ch * 8)
        self.bp3 = BlurPool(filt_size=2)

        self.conv41 = ConvBlock(base_ch * 8, base_ch * 8)
        self.conv42 = ConvBlock(base_ch * 8, base_ch * 8)
        self.conv43 = ConvBlock(base_ch * 8, base_ch * 8)
        self.bp4 = BlurPool(filt_size=2)

        self.conv51 = ConvBlock(base_ch * 8, base_ch * 8)
        self.conv52 = ConvBlock(base_ch * 8, base_ch * 8)
        self.conv53 = ConvBlock(base_ch * 8, base_ch * 8)
        self.bp5 = BlurPool(filt_size=2)

        # ---- dense bottleneck on the flattened 4x4 feature map ----
        self.dense1 = nn.Linear(4 * 4 * base_ch * 8, 512)
        self.dense2 = nn.Linear(512, 4 * 4 * base_ch * 8)

        # ---- decoder; the first conv of each level consumes [upsampled, skip] ----
        self.up5 = UpConvBlock(base_ch * 8, base_ch * 4)
        self.uconv53 = ConvBlock(base_ch * 12, base_ch * 8)
        self.uconv52 = ConvBlock(base_ch * 8, base_ch * 8)
        self.uconv51 = ConvBlock(base_ch * 8, base_ch * 8)

        self.up4 = UpConvBlock(base_ch * 8, base_ch * 4)
        self.uconv43 = ConvBlock(base_ch * 12, base_ch * 8)
        self.uconv42 = ConvBlock(base_ch * 8, base_ch * 8)
        self.uconv41 = ConvBlock(base_ch * 8, base_ch * 8)

        self.up3 = UpConvBlock(base_ch * 8, base_ch * 4)
        self.uconv33 = ConvBlock(base_ch * 12, base_ch * 8)
        self.uconv32 = ConvBlock(base_ch * 8, base_ch * 8)
        self.uconv31 = ConvBlock(base_ch * 8, base_ch * 8)

        self.up2 = UpConvBlock(base_ch * 8, base_ch * 4)
        self.uconv22 = ConvBlock(base_ch * 8, base_ch * 4)
        self.uconv21 = ConvBlock(base_ch * 4, base_ch * 4)

        self.up1 = UpConvBlock(base_ch * 4, base_ch * 2)
        self.uconv12 = ConvBlock(base_ch * 4, base_ch * 2)
        self.uconv11 = ConvBlock(base_ch * 2, base_ch * 2)

        self.up0 = UpConvBlock(base_ch * 2, base_ch)
        self.uconv02 = ConvBlock(base_ch * 2, base_ch)
        self.uconv01 = ConvBlock(base_ch, base_ch)

        self.out_conv = nn.Conv2d(base_ch, out_ch, kernel_size=7, padding=3)

    def forward(self, inp):
        """Return an ``out_ch``-channel map with the input's spatial size.

        No output activation is applied; the values are the raw output of
        the final convolution.
        """
        x = inp

        # Encoder: keep the pre-pooling features of every level for the skips.
        x = self.conv01(x)
        x = x0 = self.conv02(x)
        x = self.bp0(x)

        x = self.conv11(x)
        x = x1 = self.conv12(x)
        x = self.bp1(x)

        x = self.conv21(x)
        x = x2 = self.conv22(x)
        x = self.bp2(x)

        x = self.conv31(x)
        x = self.conv32(x)
        x = x3 = self.conv33(x)
        x = self.bp3(x)

        x = self.conv41(x)
        x = self.conv42(x)
        x = x4 = self.conv43(x)
        x = self.bp4(x)

        x = self.conv51(x)
        x = self.conv52(x)
        x = x5 = self.conv53(x)
        x = self.bp5(x)

        # Bottleneck: flatten, two linear layers, reshape back to 4x4.
        x = x.view(x.shape[0], -1)
        x = self.dense1(x)
        x = self.dense2(x)
        x = x.view(-1, self.base_ch * 8, 4, 4)

        # Decoder: upsample, concatenate the matching skip on the channel dim.
        x = self.up5(x)
        x = self.uconv53(torch.cat([x, x5], dim=1))
        x = self.uconv52(x)
        x = self.uconv51(x)

        x = self.up4(x)
        x = self.uconv43(torch.cat([x, x4], dim=1))
        x = self.uconv42(x)
        x = self.uconv41(x)

        x = self.up3(x)
        x = self.uconv33(torch.cat([x, x3], dim=1))
        x = self.uconv32(x)
        x = self.uconv31(x)

        x = self.up2(x)
        x = self.uconv22(torch.cat([x, x2], dim=1))
        x = self.uconv21(x)

        x = self.up1(x)
        x = self.uconv12(torch.cat([x, x1], dim=1))
        x = self.uconv11(x)

        x = self.up0(x)
        x = self.uconv02(torch.cat([x, x0], dim=1))
        x = self.uconv01(x)

        x = self.out_conv(x)

        return x


# ---------------------------------------------------------------------------
# xlib/torch/model/__init__.py
# ---------------------------------------------------------------------------
# The package initialiser re-exports the two public networks:
# XSegNet (from .XsegNet) and MobileNet (from .MobileNet).