This article covers a weird quirk of weight sharing in PyTorch modules.

The Weird Behavior of Weight Sharing in PyTorch Modules

This article was originally written by Jin Tian. Re-posting is welcome; it first appeared at https://jinfagang.github.io . Please keep this copyright info. Any questions can be sent via WeChat: jintianiloveu

Let me document a truly weird behavior.

Here is how it happened: while tracing a model, I ran into the following error:

  File "/usr/local/lib/python3.6/dist-packages/torch/jit/__init__.py", line 1469, in __init__
    check_unique(param)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/__init__.py", line 1461, in check_unique
    raise ValueError("TracedModules don't support parameter sharing between modules")
ValueError: TracedModules don't support parameter sharing between modules

Roughly, the error means that parameter sharing between modules is not allowed while tracing. Put simply, tracing builds a dictionary whose keys are module attributes and whose values are the weights. During tracing, PyTorch found that the same module had one weight but two attributes, i.e. two keys pointing at the same value. Logically that should not be a problem, but PyTorch simply does not support it when tracing a model.
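
To make the "two keys, one value" idea concrete, here is a minimal sketch (TinyNet and its attribute names are made up for illustration): the same nn.Linear instance is assigned to two attributes, so two parameter names resolve to one and the same tensor, which is exactly the situation the tracer refuses to index.

import torch
from torch import nn


class TinyNet(nn.Module):
    def __init__(self):
        super(TinyNet, self).__init__()
        shared = nn.Linear(4, 4)
        self.a = shared  # registered under the name "a"
        self.b = shared  # ...and again under "b": two names, one set of weights

    def forward(self, x):
        return self.b(self.a(x))


net = TinyNet()
# Both attribute paths point at the exact same parameter tensor.
print(net.a.weight is net.b.weight)  # True -> this is parameter sharing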

But this raises a question: what exactly counts as weight sharing in PyTorch?

How is weight sharing done?

As for how weight sharing happens in practice, we often write code like this:

class BasicConv(nn.Module):

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1,
                 groups=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
                              stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU(inplace=True) if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


Since we have lots of stacked Conv+BN+ReLU blocks, why not define them as a single module? Then we end up using it like this:

class FuckNet(nn.Module):
    def __init__(self):
        super(FuckNet, self).__init__()
        self.block_1 = BasicConv(256, 512, 3, 1, 1)
        self.fuck_layers = nn.Sequential()
        for i in range(5):
            self.fuck_layers.add_module('{}'.format(i), BasicConv(256, 256, 3, 1, 1))

Then you go and train a model written like this, and when you trace it, the error above shows up.

So here comes the question. What if I write the network above like this instead:

class FuckNet(nn.Module):
    def __init__(self):
        super(FuckNet, self).__init__()
        self.block_1 = BasicConv(256, 512, 3, 1, 1)

        self.fuck_layers = nn.Sequential()
        for i in range(5):
            self.fuck_layers.add_module('{}'.format(i), self.block_1)

Are the two the same? The answer is no!!! They are not equivalent!!!

To summarize:

  • First of all, these two ways of writing it are definitely not the same.

  • For the second style, the weights are definitely shared, because all 5 layers invoke that same module;

  • For the first style, is it shared or not????

This question needs a definite answer, so let's settle it with an experiment:

# -----------------------
#
# Copyright Jin Fagang @2018
# 
# 7/10/19
# trace_test
# -----------------------
from torch import nn
from alfred.utils.log import logger as logging
from alfred.dl.torch.common import device
import torch



"""
we want test which one 
can be trace

"""



class BasicConv(nn.Module):

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1,
                 groups=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
                              stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU(inplace=True) if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


class FuckNet(nn.Module):

    def __init__(self):
        super(FuckNet, self).__init__()

        self.welcome_layer = BasicConv(3, 256, 3, 1, 1)
        self.fuck_layers = nn.Sequential()
        for i in range(5):
            self.fuck_layers.add_module('{}'.format(i), BasicConv(256, 256, 3, 1, 1))

    def forward(self, x):
        x = self.welcome_layer(x)
        return self.fuck_layers(x)


class FuckNet2(nn.Module):

    def __init__(self):
        super(FuckNet2, self).__init__()

        self.welcome_layer = BasicConv(3, 256, 3, 1, 1)
        self.block_1 = BasicConv(256, 256, 3, 1, 1)
        self.fuck_layers = nn.Sequential()
        for i in range(5):
            self.fuck_layers.add_module('{}'.format(i), self.block_1)

    def forward(self, x):
        x = self.welcome_layer(x)
        return self.fuck_layers(x)



if __name__ == '__main__':
    model1 = FuckNet()
    model2 = FuckNet2()

    model2.eval().to(device)

    # start to trace model
    example = torch.rand(1, 3, 512, 512).to(device)

    traced_script_module = torch.jit.trace(model2, example)
    traced_script_module.save('test.pt')

Here is the conclusion of this simple experiment up front:

  • FuckNet2 cannot be traced;
  • FuckNet can be traced.

So why does the second network fail? My understanding is this:

When you add submodules to your module and a submodule is going to be used many times, for example registered repeatedly inside a for loop, then those submodules must not share weights. The reason is that once weights are shared, the same weight ends up with more than one name (or id). When the TracedModule is generated, the weight can no longer be assigned a single name; it is a conflict, because the tracer indexes by the weight itself.
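
You can check the difference between the two networks directly with identity checks. A small sketch, assuming the FuckNet and FuckNet2 definitions from the experiment above:

net1 = FuckNet()
net2 = FuckNet2()

# FuckNet builds a fresh BasicConv on every loop iteration,
# so each layer owns its own parameters.
print(net1.fuck_layers[0].conv.weight is net1.fuck_layers[1].conv.weight)  # False

# FuckNet2 registers the very same self.block_1 object five times,
# so one parameter tensor sits behind several attribute names
# (block_1 plus fuck_layers.0 ... fuck_layers.4).
print(net2.fuck_layers[0] is net2.block_1)                                 # True
print(net2.fuck_layers[0].conv.weight is net2.fuck_layers[4].conv.weight)  # True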

There is really no good way to work around this. The only fix is to not share the weights in the first place.
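
If you still like the "define one block, reuse it" style of writing but do not actually want sharing, one possible workaround (a sketch of my own, assuming the BasicConv definition and the `from torch import nn` import from the experiment above) is to deep-copy the template block so every slot gets its own independent parameters:

import copy

template = BasicConv(256, 256, 3, 1, 1)
# copy.deepcopy clones the module together with its parameters,
# so the five blocks no longer share any weights and tracing works again.
fuck_layers = nn.Sequential(*[copy.deepcopy(template) for _ in range(5)])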