YOLOv3学习:(一)Darknet-53结构推导与实现

jupiter
2021-02-06 / 0 评论 / 1,044 阅读 / 正在检测是否收录...
温馨提示:
本文最后更新于2021年12月07日,已超过1110天没有更新,若内容或图片失效,请留言反馈。

YOLOv3学习:(一)Darknet-53结构推导与实现

原生Darknet-53网络结构

Darknet-53网络结构

代码实现-1(更易读)

模型代码

import torch
import torch.nn as nn

# Darknet-53 basic unit: convolution block made of Conv2d + BatchNorm2d + LeakyReLU.
class ConvBNReLU(nn.Module):
    """Convolution -> batch normalization -> leaky ReLU.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels (also the BatchNorm feature count).
        kernel_size: forwarded to ``nn.Conv2d``.
        stride: forwarded to ``nn.Conv2d``.
        padding: forwarded to ``nn.Conv2d``.
        negative_slope: slope applied to negative activations. Darknet uses 0.1.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
                 negative_slope=0.1):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.BN = nn.BatchNorm2d(out_channels)
        # Darknet-53 uses LeakyReLU(0.1). ReLU6 would zero every negative value
        # and clip activations at 6, which is not the reference architecture.
        self.leaky_relu = nn.LeakyReLU(negative_slope, inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.BN(x)
        x = self.leaky_relu(x)
        return x

# Darknet-53 basic unit: downsampling block implemented as a stride-2 convolution.
# As with every other Darknet conv layer, it is followed by BatchNorm and LeakyReLU.
class DownSample(nn.Module):
    """Halve the spatial resolution with a 3x3 stride-2 Conv -> BN -> LeakyReLU.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        negative_slope: LeakyReLU slope for negative inputs (Darknet uses 0.1).
    """

    def __init__(self, in_channels, out_channels, negative_slope=0.1):
        super(DownSample, self).__init__()
        # The conv keeps the attribute name ``down_samp`` so its parameter names are unchanged.
        self.down_samp = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
        self.BN = nn.BatchNorm2d(out_channels)
        self.leaky_relu = nn.LeakyReLU(negative_slope, inplace=True)

    def forward(self, x):
        x = self.down_samp(x)
        x = self.BN(x)
        x = self.leaky_relu(x)
        return x

# Darknet-53 basic unit: residual block (1x1 squeeze, 3x3 expand, identity skip).
class ResBlock(nn.Module):
    """Residual unit computing ``conv3x3(conv1x1(x)) + x``.

    The 1x1 conv squeezes ``nchannels`` down to ``nchannels // 2``. The 3x3 conv
    restores ``nchannels`` so the skip connection can be added element-wise.
    """

    def __init__(self, nchannels):
        super().__init__()
        half = nchannels // 2
        self.conv1x1 = ConvBNReLU(nchannels, half, 1, 1, 0)
        self.conv3x3 = ConvBNReLU(half, nchannels, 3, 1, 1)

    def forward(self, x):
        shortcut = x
        y = self.conv1x1(x)
        y = self.conv3x3(y)
        return y + shortcut

num_classes = 1000

# (output channels, number of residual blocks) for each of the five stages.
# Each stage starts with a stride-2 DownSample from the previous stage's width.
stage_cfg = [(64, 1), (128, 2), (256, 8), (512, 8), (1024, 4)]

darknet53 = nn.Sequential()
darknet53.add_module('conv_bn_relu', ConvBNReLU(3, 32, 3, 1, 1))
in_ch = 32
for stage_idx, (out_ch, num_blocks) in enumerate(stage_cfg):
    # Names match the hand-written layout: down_samp_0..4 and
    # res_block_<stage>_<block> with 1-based stage/block indices.
    darknet53.add_module('down_samp_%d' % stage_idx, DownSample(in_ch, out_ch))
    for block_idx in range(1, num_blocks + 1):
        darknet53.add_module('res_block_%d_%d' % (stage_idx + 1, block_idx),
                             ResBlock(out_ch))
    in_ch = out_ch

# Adaptive pooling reduces any spatial size to 1x1. A fixed AvgPool2d(kernel_size=8)
# only fits the 8x8 map produced by a 256x256 input; other input sizes would
# break the Linear layer below.
darknet53.add_module('avg_pool', nn.AdaptiveAvgPool2d(1))
darknet53.add_module('flatten', nn.Flatten())
darknet53.add_module('linear', nn.Linear(in_features=in_ch, out_features=num_classes))
# Produces class probabilities. If training with nn.CrossEntropyLoss, which
# expects raw logits, leave this layer out.
darknet53.add_module('softmax', nn.Softmax(dim=1))

print(darknet53)

输入输出验证

# Shape sanity check: one 256x256 RGB image should produce one row of class scores.
# eval() + no_grad() keep BatchNorm running statistics untouched and skip autograd.
fake_input = torch.zeros((1, 3, 256, 256))
print(fake_input.shape)

darknet53.eval()
with torch.no_grad():
    output = darknet53(fake_input)
print(output.shape)
输出结果:
torch.Size([1, 3, 256, 256])
torch.Size([1, 1000])

代码实现-2(代码更少)

import torch
import torch.nn as nn

def Conv3x3BNReLU(in_channels, out_channels, stride=1, negative_slope=0.1):
    """Build a 3x3 Conv -> BatchNorm -> LeakyReLU unit with padding 1.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        stride: convolution stride. 2 is used for downsampling.
        negative_slope: LeakyReLU slope for negative inputs (Darknet uses 0.1).

    Returns:
        An ``nn.Sequential`` of [Conv2d, BatchNorm2d, LeakyReLU].
    """
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                  kernel_size=3, stride=stride, padding=1),
        nn.BatchNorm2d(out_channels),
        # Darknet-53 uses LeakyReLU(0.1). ReLU6 would zero negatives and clip at 6.
        nn.LeakyReLU(negative_slope, inplace=True),
    )

def Conv1x1BNReLU(in_channels, out_channels, negative_slope=0.1):
    """Build a 1x1 Conv -> BatchNorm -> LeakyReLU unit (stride 1, no padding).

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        negative_slope: LeakyReLU slope for negative inputs (Darknet uses 0.1).

    Returns:
        An ``nn.Sequential`` of [Conv2d, BatchNorm2d, LeakyReLU].
    """
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                  kernel_size=1, stride=1, padding=0),
        nn.BatchNorm2d(out_channels),
        # Darknet-53 uses LeakyReLU(0.1). ReLU6 would zero negatives and clip at 6.
        nn.LeakyReLU(negative_slope, inplace=True),
    )

class Residual(nn.Module):
    """Darknet-53 residual unit: 1x1 squeeze to half width, 3x3 back to full width, plus identity.

    Args:
        nchannels: channel count of both the input and the output.
    """

    def __init__(self, nchannels):
        super().__init__()
        squeezed = nchannels // 2
        self.conv1x1 = Conv1x1BNReLU(in_channels=nchannels, out_channels=squeezed)
        self.conv3x3 = Conv3x3BNReLU(in_channels=squeezed, out_channels=nchannels)

    def forward(self, x):
        branch = self.conv1x1(x)
        branch = self.conv3x3(branch)
        return branch + x

class Darknet53(nn.Module):
    """Darknet-53 backbone with a global-pool + linear classification head.

    A 3x3 stem conv (32 channels) is followed by five stages. Each stage is a
    stride-2 3x3 conv and then 1, 2, 8, 8 and 4 residual units, respectively.
    The stages widen the features 64 -> 128 -> 256 -> 512 -> 1024.

    Args:
        num_classes: output size of the final linear layer.

    ``forward`` returns softmax probabilities of shape (N, num_classes).
    """

    def __init__(self, num_classes=1000):
        super(Darknet53, self).__init__()
        self.first_conv = Conv3x3BNReLU(in_channels=3, out_channels=32)

        self.block1 = self._make_layers(in_channels=32, out_channels=64, block_num=1)
        self.block2 = self._make_layers(in_channels=64, out_channels=128, block_num=2)
        self.block3 = self._make_layers(in_channels=128, out_channels=256, block_num=8)
        self.block4 = self._make_layers(in_channels=256, out_channels=512, block_num=8)
        self.block5 = self._make_layers(in_channels=512, out_channels=1024, block_num=4)

        # Adaptive pooling reduces any HxW to 1x1. AvgPool2d(kernel_size=8) only
        # matched the 8x8 map produced by a 256x256 input.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(in_features=1024, out_features=num_classes)
        self.softmax = nn.Softmax(dim=1)

    def _make_layers(self, in_channels, out_channels, block_num):
        """Build one stage: a stride-2 3x3 conv followed by ``block_num`` residual units."""
        layers = [Conv3x3BNReLU(in_channels=in_channels, out_channels=out_channels, stride=2)]
        layers.extend(Residual(nchannels=out_channels) for _ in range(block_num))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.first_conv(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)

        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        # Return the softmax output; previously it was computed and discarded.
        # When training with nn.CrossEntropyLoss, use the pre-softmax logits instead.
        return self.softmax(x)

model = Darknet53()
print(model)

# Shape check with one random 256x256 RGB image. The name avoids shadowing the
# builtin ``input``. eval()/no_grad() keep BatchNorm statistics unchanged and skip autograd.
dummy_input = torch.randn(1, 3, 256, 256)
model.eval()
with torch.no_grad():
    out = model(dummy_input)
print(out.shape)

YOLOv3 中 Darknet-53 网络各层参数详解

参考资料

  1. Pytorch实现Darknet-53:https://blog.csdn.net/qq_41979513/article/details/102680028
  2. Darknet53网络各层参数详解:https://blog.csdn.net/qq_40210586/article/details/106144197
0

评论 (0)

打卡
取消