When designing deep learning modules such as a classification head, we have to calculate the number of input features ourselves. PyTorch Lazy modules come to the rescue by automating this.
In this post, we will explore how to use PyTorch Lazy modules to rewrite PyTorch models used for:
- Image classifiers
- Unet
When solving deep learning problems, it is common to try out different architectures such as EfficientNet and ResNet.
You can find the code for the blog here
Lazy module with Image classification
Let’s look at example code that modifies the head of two architectures.
import timm
import torch
from torch import nn

resnet = timm.create_model(model_name='resnet50', pretrained=True)
effnet = timm.create_model(model_name='tf_efficientnetv2_b0', pretrained=True)
resnet.fc, effnet.classifier
Output
(Linear(in_features=2048, out_features=1000, bias=True),
 Linear(in_features=1280, out_features=1000, bias=True))
Let’s change the architecture of the model head.
def create_head_old(in_features, out_features):
    head = nn.Sequential(
        nn.Linear(in_features=in_features, out_features=512, bias=False),
        nn.ReLU(inplace=True),
        nn.BatchNorm1d(512),
        nn.Dropout(),
        nn.Linear(in_features=512, out_features=out_features, bias=False))
    return head

# The input features (2048 and 1280) must be looked up for each model
resnet.fc = create_head_old(2048, 1)
effnet.classifier = create_head_old(1280, 1)
In the above code, we replace the Linear layer in each of these models with a slightly more complex network.
The important thing to notice is that we have to hardcode the input features required by each model. This makes it difficult to try different model families.
The Lazy Module feature comes with Lazy variants of Linear, BatchNorm, Conv and ConvTranspose which will help in automatic initialisation.
This is easiest to understand by looking at an example. Let’s replace the first Linear layer in create_head_old with its Lazy variant.
def create_head_new(out_features):
    head = nn.Sequential(
        # Only out_features is given; in_features is inferred on the first forward pass
        nn.LazyLinear(512, bias=False),
        nn.ReLU(inplace=True),
        nn.BatchNorm1d(512),
        nn.Dropout(),
        nn.Linear(in_features=512,
                  out_features=out_features,
                  bias=False))
    return head

resnet.fc = create_head_new(1)
effnet.classifier = create_head_new(1)
resnet.fc
Output:
Sequential(
  (0): LazyLinear(in_features=0, out_features=512, bias=False)
  (1): ReLU(inplace=True)
  (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Dropout(p=0.5, inplace=False)
  (4): Linear(in_features=512, out_features=1, bias=False)
)
We replace only the first Linear layer of our custom head. In the output, in_features is shown as 0; it is set to the actual value on the first forward pass. We can simply pass a dummy batch through the model, after which the layer is initialised and the model behaves like a normal one.
dummy_tensor = torch.randn((2,3,224,224))
_ = resnet(dummy_tensor)
resnet.fc
Output:
Sequential(
  (0): Linear(in_features=2048, out_features=512, bias=False)
  (1): ReLU(inplace=True)
  (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Dropout(p=0.5, inplace=False)
  (4): Linear(in_features=512, out_features=1, bias=False)
)
Now we can use our new model head on any model without hardcoding input features. The advantage of Lazy modules may not be obvious in such a simple example, but they make life a lot easier when you are building architectures for segmentation and object detection. Let’s see how Lazy modules can simplify a Unet architecture and make it flexible enough to adopt different models as encoders.
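The same behaviour can be checked without downloading any pretrained weights. The following is a minimal sketch, assuming two stand-in backbones (`fake_resnet` and `fake_effnet`, hypothetical names that are not part of the original code) which only mimic the 2048- and 1280-wide feature vectors of the real models:

```python
import torch
from torch import nn
from torch.nn.parameter import UninitializedParameter


def create_head_new(out_features):
    # Same head as above: only the first layer is lazy.
    return nn.Sequential(
        nn.LazyLinear(512, bias=False),
        nn.ReLU(inplace=True),
        nn.BatchNorm1d(512),
        nn.Dropout(),
        nn.Linear(in_features=512, out_features=out_features, bias=False),
    )


# Hypothetical stand-ins for real backbones: they only reproduce the
# feature widths (2048 for resnet50, 1280 for tf_efficientnetv2_b0).
fake_resnet = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 2048))
fake_effnet = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1280))

head_a, head_b = create_head_new(1), create_head_new(1)

# Before the first forward pass the lazy weight has no shape yet.
lazy_before = isinstance(head_a[0].weight, UninitializedParameter)

dummy = torch.randn(2, 3, 8, 8)  # batch of 2 so BatchNorm1d works in train mode
out_a = head_a(fake_resnet(dummy))
out_b = head_b(fake_effnet(dummy))

print(lazy_before)                   # whether head_a started uninitialised
print(type(head_a[0]).__name__)      # class of the first layer after the dry run
print(head_a[0].in_features, head_b[0].in_features)
print(out_a.shape, out_b.shape)
```

The identical head definition adapts to both feature widths. One practical consequence is that the dummy forward pass should happen before anything that needs the real parameter shapes, such as counting parameters.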
Lazy module with Unet Architecture:
For this example, I use a trimmed version of the Unet shown in one of my earlier videos, where I explained how to build a Unet model using Timm and Fastai. Looking through the complete code could be overwhelming, so I recommend focusing on the Unet and UnetDecoder classes.
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self, model_name='resnext50_32x4d'):
        super().__init__()
        # features_only=True makes timm return the intermediate feature maps
        self.encoder = timm.create_model(model_name, features_only=True, pretrained=False)

    def forward(self, x):
        return self.encoder(x)

def conv_block(in_feat, out_feat):
    # 3x3 convolution (stride 1, padding 1) followed by BatchNorm and ReLU
    conv_block = nn.Sequential(nn.Conv2d(in_feat, out_feat, 3, 1, 1, bias=False),
                               nn.BatchNorm2d(out_feat),
                               nn.ReLU())
    return conv_block

class UnetBlock(nn.Module):
    def __init__(self, in_channels, channels, out_channels):
        super().__init__()
        self.conv1 = conv_block(in_channels, channels)
        self.conv2 = conv_block(channels, out_channels)

    def forward(self, x):
        # Upsample by 2x, then apply the two conv blocks
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        x = self.conv1(x)
        x = self.conv2(x)
        return x

class UnetDecoder(nn.Module):
    def __init__(self, fs=32, expansion=4, n_out=1):
        super().__init__()
        center_ch = 512*expansion
        decoder5_ch = center_ch + (256*expansion)
        channels = 512
        self.center = nn.Sequential(conv_block(center_ch, center_ch), conv_block(center_ch, center_ch//2))
        self.decoder5 = UnetBlock(decoder5_ch, channels, fs)
        self.decoder4 = UnetBlock(256*expansion+fs, 256, fs)
        self.decoder3 = UnetBlock(128*expansion+fs, 128, fs)
        self.decoder2 = UnetBlock(64*expansion+fs, 64, fs)
        self.decoder1 = UnetBlock(fs, fs, fs)
        self.logit = nn.Sequential(conv_block(fs, fs//2), conv_block(fs//2, fs//2), nn.Conv2d(fs//2, n_out, kernel_size=1))

    def forward(self, feats):
        e1, e2, e3, e4, e5 = feats  # channels: 64, 256, 512, 1024, 2048
        f = self.center(e5)
        d5 = self.decoder5(torch.cat([f, e5], 1))
        d4 = self.decoder4(torch.cat([d5, e4], 1))
        d3 = self.decoder3(torch.cat([d4, e3], 1))
        d2 = self.decoder2(torch.cat([d3, e2], 1))
        d1 = self.decoder1(d2)
        return self.logit(d1)

class Unet(nn.Module):
    def __init__(self, fs=32, expansion=4, model_name='resnext50_32x4d', n_out=1):
        super().__init__()
        self.encoder = Encoder(model_name)
        self.decoder = UnetDecoder(fs=fs, expansion=expansion, n_out=n_out)

    def forward(self, x):
        feats = self.encoder(x)
        out = self.decoder(feats)
        return out
Let’s focus on the UnetBlock instances in the UnetDecoder class and observe how we have to hardcode the input features of each UnetBlock.
We have to calculate the input features manually, and they can change when we try a different architecture, which breaks our code.
Someone else reading the code will often wonder how we came up with the input features passed to each UnetBlock.
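To make that manual calculation explicit, here is a small sketch that reproduces the arithmetic hidden in UnetDecoder. The helper name `decoder_in_channels` is hypothetical and only mirrors the expressions in the constructor above; it assumes the encoder's last four feature maps have 64, 128, 256 and 512 channels multiplied by `expansion`.

```python
def decoder_in_channels(fs=32, expansion=4):
    """Input channels each UnetBlock in UnetDecoder is built for."""
    e5_ch = 512 * expansion        # deepest encoder feature map
    center_out = e5_ch // 2        # output of self.center
    return {
        "decoder5": center_out + e5_ch,       # cat([f, e5])
        "decoder4": fs + 256 * expansion,     # cat([d5, e4])
        "decoder3": fs + 128 * expansion,     # cat([d4, e3])
        "decoder2": fs + 64 * expansion,      # cat([d3, e2])
        "decoder1": fs,                       # d2 only, no skip connection
    }


bottleneck = decoder_in_channels(expansion=4)  # e.g. resnext50_32x4d
basic = decoder_in_channels(expansion=1)       # encoders with 4x narrower features
print(bottleneck)
print(basic)
```

Every one of these numbers depends on the encoder family, which is exactly the bookkeeping the Lazy modules remove.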
Let’s simplify the above code by adopting the Lazy modules. Most of the code from above remains unchanged; only conv_block, UnetBlock and UnetDecoder change. By simply replacing Conv2d with LazyConv2d, we no longer need any hardcoding or manual calculation of input features.
def lazy_conv_block(out_feat):
    # Input channels are inferred on the first forward pass
    conv_block = nn.Sequential(nn.LazyConv2d(out_feat, 3, 1, 1, bias=False),
                               nn.BatchNorm2d(out_feat),
                               nn.ReLU())
    return conv_block

class LazyUnetDecoder(nn.Module):
    def __init__(self, fs=32, expansion=4, n_out=1):
        super().__init__()
        channels = 512
        center_ch = channels*expansion
        self.center = nn.Sequential(lazy_conv_block(center_ch), lazy_conv_block(center_ch//2))
        self.decoder5 = LazyUnetBlock(channels, fs)
        self.decoder4 = LazyUnetBlock(channels//2, fs)
        self.decoder3 = LazyUnetBlock(channels//4, fs)
        self.decoder2 = LazyUnetBlock(channels//8, fs)
        self.decoder1 = LazyUnetBlock(fs, fs)
        self.logit = nn.Sequential(lazy_conv_block(fs//2), lazy_conv_block(fs//2), nn.Conv2d(fs//2, n_out, kernel_size=1))

    def forward(self, feats):
        e1, e2, e3, e4, e5 = feats  # e.g. channels: 64, 256, 512, 1024, 2048
        f = self.center(e5)
        d5 = self.decoder5(torch.cat([f, e5], 1))
        d4 = self.decoder4(torch.cat([d5, e4], 1))
        d3 = self.decoder3(torch.cat([d4, e3], 1))
        d2 = self.decoder2(torch.cat([d3, e2], 1))
        d1 = self.decoder1(d2)
        return self.logit(d1)
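LazyUnetBlock is used above but its definition is not shown in this post. A plausible sketch, assuming it simply mirrors UnetBlock with the in_channels argument dropped (this is an assumption, not the code from the linked notebook):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def lazy_conv_block(out_feat):
    # Same as conv_block, but the input channel count is inferred on first call.
    return nn.Sequential(
        nn.LazyConv2d(out_feat, 3, 1, 1, bias=False),
        nn.BatchNorm2d(out_feat),
        nn.ReLU(),
    )


class LazyUnetBlock(nn.Module):
    # Assumed definition: UnetBlock without the in_channels argument.
    def __init__(self, channels, out_channels):
        super().__init__()
        self.conv1 = lazy_conv_block(channels)
        self.conv2 = lazy_conv_block(out_channels)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        x = self.conv1(x)
        return self.conv2(x)


# Two blocks built with identical arguments accept different input widths.
block_a = LazyUnetBlock(512, 32)
block_b = LazyUnetBlock(512, 32)
out_a = block_a(torch.randn(2, 3072, 7, 7))  # decoder5 input for a 2048-channel e5
out_b = block_b(torch.randn(2, 768, 7, 7))   # decoder5 input for a 512-channel e5

print(out_a.shape, out_b.shape)
print(block_a.conv1[0].in_channels, block_b.conv1[0].in_channels)
```

A LazyUnet would then presumably be the Unet class from earlier with UnetDecoder swapped for LazyUnetDecoder.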
In addition to simplifying the code, the lazy version also works with most backbones and architectures. Let’s try a different backbone and check what happens.
Fails ❌
Unet(model_name='resnet18')(dummy_batch).shape
Works ✅
LazyUnet(model_name='resnet18')(dummy_batch).shape
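The failure comes from a channel mismatch: the hardcoded decoder is sized for a 2048-channel deepest feature map, while resnet18 produces 512 channels there. The sketch below isolates that single step with plain tensors, so it does not need timm. The tensor `e5` is a stand-in for the encoder output, and the 64 output channels are chosen only to keep the example light.

```python
import torch
import torch.nn as nn

# Stand-in for the deepest feature map of a resnet18-like encoder (512 channels).
e5 = torch.randn(2, 512, 7, 7)

# Hardcoded for a 2048-channel input, as in the non-lazy decoder.
hardcoded = nn.Conv2d(2048, 64, 3, 1, 1, bias=False)
try:
    hardcoded(e5)
    hardcoded_failed = False
except RuntimeError:
    # PyTorch rejects the input because the channel counts disagree.
    hardcoded_failed = True

# The lazy variant infers in_channels from the first input it sees.
lazy = nn.LazyConv2d(64, 3, 1, 1, bias=False)
out = lazy(e5)

print(hardcoded_failed)
print(lazy.in_channels, out.shape)
```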
Conclusion:
PyTorch’s Lazy modules are going to simplify writing many future architectures. Since this is a new feature under heavy development, the API or functionality could change in the future.