FOMO - Embedded Image Segmentation PyTorch Model

PyTorch implementation of the FOMO model from Edge Impulse
pytorch python
Author

Deebul Nair

Published

June 16, 2023

Faster Objects, More Objects, aka FOMO

PyTorch implementation

FOMO, introduced by Edge Impulse, is essentially a rebranding of an architecture called bnn, which was initially developed by Mat Palm and explained in his blog. The TensorFlow code is available on GitHub.

The architecture diagram of FOMO/BNN is shown below:

Here I try to convert the above diagram into a PyTorch model. I hope it helps anyone looking to deploy the FOMO model in the real world.

Model Description as per Mat Palm

the architecture of the network is a very vanilla u-net.

- a fully convolutional network trained on half resolution patches but run against full resolution images
- encoding is a sequence of 4 3x3 convolutions with stride 2
- decoding is a sequence of nearest neighbours resizes + 3x3 convolution (stride 1) + skip connection from the encoders
- final layer is a 1x1 convolution (stride 1) with sigmoid activation (i.e. binary bee / no bee choice per pixel)

after some empirical experiments i chose to only decode back to half the resolution of the input. it was good enough.

i did the decoding using a nearest neighbour resize instead of a deconvolution
pretty much out of habit.
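Mat Palm's decoder upsamples with a nearest-neighbour resize, while the PyTorch class below uses `mode='bilinear'`. The sketch here is an illustration only (not part of the original code): it compares the two `torch.nn.Upsample` modes with a learned transposed convolution (the "deconvolution" alternative) on a tiny 2x2 map. Nearest copies each value into a 2x2 block, bilinear interpolates between neighbours, and `ConvTranspose2d` learns its own upsampling weights.

```python
import torch
import torch.nn as nn

# A tiny 1x1x2x2 feature map (batch, channels, height, width)
x = torch.tensor([[[[1.0, 2.0],
                    [3.0, 4.0]]]])

# Nearest-neighbour resize, as in Mat Palm's description:
# every input value is copied into a 2x2 block of the output
nearest = nn.Upsample(scale_factor=2, mode="nearest")(x)

# Bilinear resize, as used by the FOMO class in this post:
# output values are interpolated between neighbouring inputs
bilinear = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)(x)

# Learned alternative ("deconvolution"): a transposed conv that also doubles H and W
deconv = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=2, stride=2)
learned = deconv(x)

print(nearest)
print(bilinear)
print(nearest.shape, bilinear.shape, learned.shape)  # all (1, 1, 4, 4)
```

Both `Upsample` modes are parameter-free, so switching between them does not change the model size; the transposed convolution adds trainable weights.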
import torch
import torch.nn as nn

# Open questions (ToDo):
# - Padding: the original uses 'same' padding, but the stride-2 convolutions
#   below use explicit zero padding of 1 pixel on every side. Does this match?
# - Upsampling: Mat Palm's description uses nearest-neighbour resizes, while
#   this implementation uses bilinear. Which mode should it be?

class FOMO(nn.Module):
    def __init__(self):
        super().__init__()

        # Encoder (reduction): four 3x3 convs with stride 2, each halving H and W
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, stride=2, padding=(1, 1))
        self.conv2 = nn.Conv2d(in_channels=4, out_channels=8, kernel_size=3, stride=2, padding=(1, 1))
        self.conv3 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=2, padding=(1, 1))
        self.conv4 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=(1, 1))

        # Decoder (increasing): x2 upsample -> 3x3 conv (stride 1) -> concat with
        # the matching encoder output (skip connection)
        self.conv5 = nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3, stride=1, padding='same')  # + out3 (16) = 32
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')
        self.conv6 = nn.Conv2d(in_channels=32, out_channels=8, kernel_size=3, stride=1, padding='same')   # + out2 (8) = 16
        self.conv7 = nn.Conv2d(in_channels=16, out_channels=4, kernel_size=3, stride=1, padding='same')  # + out1 (4) = 8

        # Final layer: 1x1 conv giving one score per output pixel
        self.conv8 = nn.Conv2d(in_channels=8, out_channels=1, kernel_size=1, stride=1, padding='same')

    def forward(self, x):
        # Encoder; keep the intermediate outputs for the skip connections.
        # ReLU after every conv: without it the whole network is a linear map.
        out1 = torch.relu(self.conv1(x))       # H/2
        out2 = torch.relu(self.conv2(out1))    # H/4
        out3 = torch.relu(self.conv3(out2))    # H/8
        output = torch.relu(self.conv4(out3))  # H/16

        # Decoder: upsample, convolve, concatenate the skip along the channel dim
        output = torch.relu(self.conv5(self.upsample(output)))  # H/8
        output = torch.cat((output, out3), dim=1)
        output = torch.relu(self.conv6(self.upsample(output)))  # H/4
        output = torch.cat((output, out2), dim=1)
        output = torch.relu(self.conv7(self.upsample(output)))  # H/2
        output = torch.cat((output, out1), dim=1)

        # Raw per-pixel scores (logits) at half the input resolution; apply
        # torch.sigmoid to get the bee / no-bee probability per pixel
        return self.conv8(output)

model = FOMO()
print(model)
FOMO(
  (conv1): Conv2d(3, 4, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv2): Conv2d(4, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv3): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv4): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv5): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (upsample): Upsample(scale_factor=2.0, mode=bilinear)
  (conv6): Conv2d(32, 8, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (conv7): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (conv8): Conv2d(8, 1, kernel_size=(1, 1), stride=(1, 1), padding=same)
)
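The padding question in the code comes down to how TensorFlow's `padding='SAME'` handles stride 2: it computes the total padding needed for an output of ceil(n / stride) pixels and puts any odd remainder on the bottom/right, whereas `padding=(1, 1)` in PyTorch pads one pixel on every side. (PyTorch's own `padding='same'` is only allowed for stride 1, which is why conv5 to conv8 can use it but conv1 to conv4 cannot.) The sketch below reproduces the TensorFlow rule with a hypothetical helper, `tf_same_padding`, and `torch.nn.functional.pad`. For the even input sizes used here both conventions give the same output shape, but the TensorFlow version samples windows shifted by one pixel.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def tf_same_padding(size: int, kernel: int, stride: int) -> tuple[int, int]:
    """Hypothetical helper: (before, after) padding that TensorFlow's
    padding='SAME' would apply along one spatial dimension."""
    out = -(-size // stride)  # ceil(size / stride)
    total = max((out - 1) * stride + kernel - size, 0)
    return total // 2, total - total // 2  # any odd remainder goes after


torch.manual_seed(0)
x = torch.randn(1, 3, 512, 384)
conv = nn.Conv2d(3, 4, kernel_size=3, stride=2, padding=0)  # padding applied manually below

# PyTorch-style, as in FOMO.conv1: one pixel of zeros on every side
y_torch = conv(F.pad(x, (1, 1, 1, 1)))

# TensorFlow-style 'SAME'; F.pad takes (left, right, top, bottom) for 4D input
top, bottom = tf_same_padding(x.shape[2], kernel=3, stride=2)
left, right = tf_same_padding(x.shape[3], kernel=3, stride=2)
y_tf = conv(F.pad(x, (left, right, top, bottom)))

print(top, bottom, left, right)                        # 0 1 0 1
print(y_torch.shape == y_tf.shape, tuple(y_tf.shape))  # True (1, 4, 256, 192)
```

For odd sizes the TensorFlow rule pads symmetrically (e.g. `tf_same_padding(5, 3, 2)` returns `(1, 1)`), so the two conventions only differ on even inputs like the ones in this post.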
x = torch.randn(1, 3, 512, 384)
y = model(x)
print(y.shape)
torch.Size([1, 1, 256, 192])
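The output is 256x192, half of the 512x384 input, matching Mat Palm's choice to decode only back to half resolution. Because the network is fully convolutional, the same weights also run on other input sizes (such as half-resolution training patches), as long as every upsampled decoder map lines up with its skip connection. The pure-Python sketch below uses two hypothetical helpers, `conv_out` and `fomo_output_size`, to trace the spatial sizes through conv1 to conv4 and the three x2 upsamples; a size that does not line up would make `torch.cat` fail in `forward`.

```python
def conv_out(n: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    """Output length of a Conv2d along one dimension (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1


def fomo_output_size(h: int, w: int) -> tuple[int, int]:
    """Hypothetical helper: trace (H, W) through the FOMO encoder and decoder."""
    sizes = [(h, w)]
    for _ in range(4):  # conv1..conv4: 3x3, stride 2, padding 1
        ph, pw = sizes[-1]
        sizes.append((conv_out(ph), conv_out(pw)))
    # sizes = [input, out1, out2, out3, conv4 output]
    cur = sizes[4]
    for skip in (sizes[3], sizes[2], sizes[1]):
        cur = (cur[0] * 2, cur[1] * 2)  # x2 upsample; the 'same' 3x3 conv keeps the size
        if cur != skip:
            raise ValueError(f"upsampled {cur} does not match skip {skip}; torch.cat would fail")
    return cur  # conv8 is a 1x1, stride-1 conv, so the size is unchanged


print(fomo_output_size(512, 384))  # (256, 192): full-resolution image
print(fomo_output_size(256, 192))  # (128, 96): half-resolution patch
try:
    fomo_output_size(100, 100)     # 100 -> 50 -> 25 -> 13 -> 7, and 7 * 2 != 13
except ValueError as err:
    print(err)
```

Input heights and widths that are multiples of 16 always line up, since each of the four stride-2 convolutions then halves the size exactly.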
%timeit y=model(x)
19.8 ms ± 2.77 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
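For an embedded target, model size matters alongside latency. The sketch below rebuilds only the layer shapes of the `FOMO` class (`layer_specs` is a hypothetical name) and counts weights and biases. With the real model, `sum(p.numel() for p in model.parameters())` should give the same number.

```python
import torch.nn as nn

# (in_channels, out_channels, kernel_size) of conv1..conv8, copied from the FOMO class;
# stride and padding do not change the parameter count
layer_specs = [(3, 4, 3), (4, 8, 3), (8, 16, 3), (16, 32, 3),
               (32, 16, 3), (32, 8, 3), (16, 4, 3), (8, 1, 1)]
layers = [nn.Conv2d(c_in, c_out, k) for c_in, c_out, k in layer_specs]

n_params = sum(p.numel() for layer in layers for p in layer.parameters())
print(n_params)      # 13741 weights and biases
print(n_params * 4)  # 54964 bytes when stored as float32
```

The Upsample layer has no parameters, so roughly 55 kB of float32 weights covers the whole network.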

ToDo: train the model on a dataset.
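As a possible starting point for training, the sketch below shows one way to set up the loss. It assumes binary target masks at half the input resolution (one bee / no-bee label per output pixel) and uses `BCEWithLogitsLoss`, which applies the sigmoid internally, so the model should return raw scores. To stay self-contained it uses a small hypothetical stand-in model with the same input/output contract and random data; both are placeholders, not a dataset or results from this post.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in with the same contract as FOMO: 3-channel image in,
# 1-channel score map at half resolution out. Use `model = FOMO()` for the real network.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 1, kernel_size=1),
)

# Synthetic batch: 4 images and binary "bee / no bee" masks at half resolution
images = torch.randn(4, 3, 128, 96)
masks = (torch.rand(4, 1, 64, 48) > 0.9).float()

criterion = nn.BCEWithLogitsLoss()  # sigmoid + binary cross-entropy in one stable op
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
for step in range(5):
    optimizer.zero_grad()
    logits = model(images)  # shape (4, 1, 64, 48)
    loss = criterion(logits, masks)
    loss.backward()
    optimizer.step()
    print(f"step {step}: loss {loss.item():.4f}")

# Inference: threshold the sigmoid output for a per-pixel binary decision
model.eval()
with torch.no_grad():
    pred_mask = torch.sigmoid(model(images)) > 0.5
print(pred_mask.shape)  # torch.Size([4, 1, 64, 48])
```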