Problems and Solutions When Converting PyTorch to ONNX


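  • ONNX does not support torch.linspace
  • ONNX does not support torch's grid_sampler operator
  • **Working solution: replace it with mmcv's grid sample, which supports ONNX model export, onnxruntime inference, and ONNX-IR conversion.**
  • ONNX dynamic input

    ONNX does not support torch.linspace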

    Error message:

    RuntimeError: Exporting the operator linspace to ONNX opset version 11 is not supported. 
    Please feel free to request support or submit a pull request on PyTorch GitHub.
    

    Original code:

    tenHorizontal = torch.linspace(-1.0, 1.0, 1080, device=device)
    

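    Solution:
    Replace it with torch.range() (note that torch.range() is deprecated in favor of torch.arange()):

    tenHorizontal = torch.range(-1.0, 1.0, 2 / 1080, device=device)

    Caveat: linspace with 1080 points uses a step of 2/(1080-1), so a step of 2/1080 only approximates the original line. An exact replacement is worked out further down.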
    

    Note:
    The difference between torch.arange() and torch.range():

    >>> torch.range(-1.0,1.0,0.5)
    tensor([-1.0000, -0.5000,  0.0000,  0.5000,  1.0000])
    >>> torch.arange(-1.0,1.0,0.5)
    tensor([-1.0000, -0.5000,  0.0000,  0.5000])
    

    The output of range includes end, while the output of arange does not.

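    I am not sure whether the cause is a floating-point precision difference between CUDA and CPU or something else, but the same line of arange code produced tensors of different sizes on my local machine and on the server, even though the PyTorch and Python versions were identical.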
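One plausible explanation, offered as an assumption since the actual cause was not established here: the PyTorch docs give the length of arange as ceil((end - start) / step), so when that quotient sits right on an integer, a rounding difference in the last bit of step changes the element count. The sketch below reproduces only the length arithmetic in plain Python. The helper names `arange_len` and `range_len` are made up for illustration.

```python
import math

def arange_len(start, end, step):
    """Element count of the half-open sequence [start, end), per the torch.arange docs."""
    return math.ceil((end - start) / step)

def range_len(start, end, step):
    """Element count of the closed sequence [start, end], per the torch.range docs."""
    return math.floor((end - start) / step) + 1

# The example above: range keeps the end point, arange drops it.
print(range_len(-1.0, 1.0, 0.5), arange_len(-1.0, 1.0, 0.5))  # 5 4

# A step that is a hair too small or too large flips the count.
n = 1080
step = 2.0 / (n - 1)
count_small_step = arange_len(-1.0, 1.0, step * (1 - 1e-12))
count_large_step = arange_len(-1.0, 1.0, step * (1 + 1e-12))
print(count_small_step, count_large_step)  # 1080 1079

# Padding the end by a bit less than one step keeps the count away from the boundary.
m = 224
padded_count = arange_len(-1.0, 1.0 + 2.0 / m, 2.0 / (m - 1))
print(padded_count)  # 224
```

The last case corresponds to passing an end value slightly beyond 1.0: the quotient lands near 223.996 rather than on an integer, so small rounding differences no longer matter.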

    After many attempts, the following finally resolved this puzzling problem:

    >>> a=torch.linspace(-1.0, 1.0, 224)
    >>> print(a.shape)
    torch.Size([224])
    >>> b=torch.arange(-1.0, 1.0+(2.0/224), 2.0/(224-1))
    >>> print(b.shape)
    torch.Size([224])
    
    >>> print(a-b)
    tensor([ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
             0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
             0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
             0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
             0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  5.9605e-08,
             5.9605e-08,  5.9605e-08,  5.9605e-08,  5.9605e-08,  5.9605e-08,
             5.9605e-08,  5.9605e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
             0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
             0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
             0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
             0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
             0.0000e+00,  5.9605e-08,  5.9605e-08,  5.9605e-08,  5.9605e-08,
             5.9605e-08,  5.9605e-08,  5.9605e-08,  5.9605e-08,  2.9802e-08,
             2.9802e-08,  2.9802e-08,  2.9802e-08,  2.9802e-08,  2.9802e-08,
             2.9802e-08,  2.9802e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
             0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
             2.9802e-08,  2.9802e-08,  2.9802e-08,  2.9802e-08,  2.9802e-08,
             2.9802e-08,  4.4703e-08,  2.9802e-08,  7.4506e-08,  7.4506e-08,
             7.4506e-08,  7.4506e-08,  7.4506e-08,  7.4506e-08,  8.9407e-08,
             7.4506e-08,  4.4703e-08,  4.4703e-08,  4.4703e-08,  5.2154e-08,
             4.4703e-08,  4.4703e-08,  4.4703e-08,  5.2154e-08,  1.4901e-08,
             1.4901e-08,  1.4901e-08,  1.4901e-08,  1.4901e-08,  1.6764e-08,
             1.7695e-08,  2.1886e-08, -6.8918e-08, -6.8918e-08, -6.8918e-08,
            -6.7055e-08, -6.7055e-08, -6.7055e-08, -6.7055e-08, -5.9605e-08,
            -3.7253e-08, -3.7253e-08, -3.7253e-08, -2.9802e-08, -3.7253e-08,
            -3.7253e-08, -2.9802e-08, -2.9802e-08, -5.9605e-08, -5.9605e-08,
            -5.9605e-08, -5.9605e-08, -5.9605e-08, -5.9605e-08, -4.4703e-08,
            -5.9605e-08, -2.9802e-08, -2.9802e-08, -2.9802e-08, -2.9802e-08,
            -2.9802e-08, -2.9802e-08, -2.9802e-08, -2.9802e-08, -5.9605e-08,
            -5.9605e-08, -5.9605e-08, -5.9605e-08, -5.9605e-08, -5.9605e-08,
            -5.9605e-08, -5.9605e-08, -2.9802e-08, -2.9802e-08, -2.9802e-08,
            -2.9802e-08, -2.9802e-08, -2.9802e-08, -2.9802e-08, -2.9802e-08,
            -5.9605e-08, -5.9605e-08, -5.9605e-08, -5.9605e-08, -5.9605e-08,
            -5.9605e-08, -5.9605e-08, -5.9605e-08,  0.0000e+00,  0.0000e+00,
             0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
             0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
             0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
             0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
             0.0000e+00,  0.0000e+00, -5.9605e-08, -5.9605e-08, -5.9605e-08,
            -5.9605e-08, -5.9605e-08, -5.9605e-08, -5.9605e-08, -5.9605e-08,
             0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
             0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
             0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
             0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
             0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00])
            
    

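    The arithmetic sequences produced by the two functions are nearly identical: the largest deviation above is about 9e-08, which makes no difference in practice.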
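A variation on the same idea, sketched here as an alternative rather than as the method used above: build an integer-valued index with arange, whose length is exact, and then scale and shift it. The function name `linspace_via_arange` is hypothetical, and whether this exports cleanly depends on your PyTorch and opset versions, so treat that part as an assumption to verify.

```python
import torch

def linspace_via_arange(start, end, steps):
    """Approximate torch.linspace(start, end, steps) with arange plus arithmetic.

    Assumes steps > 1. Add device=... to arange if the result must live on a GPU.
    """
    # An integer-valued index always has exactly `steps` elements,
    # unlike an arange with a floating-point step.
    index = torch.arange(steps, dtype=torch.float32)
    step = (end - start) / (steps - 1)
    return index * step + start

a = torch.linspace(-1.0, 1.0, 224)
b = linspace_via_arange(-1.0, 1.0, 224)
print(b.shape)
print((a - b).abs().max())
```

The second print reports the largest deviation from torch.linspace, which should stay at float32 rounding level.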

    ONNX does not support torch's grid_sampler operator

    Error message:

    RuntimeError: Exporting the operator grid_sampler to ONNX opset version 11 is not supported. 
    Please feel free to request support or submit a pull request on PyTorch GitHub.
    

    Attempt 1: turn grid_sampler into a custom op and register it, after which the ONNX model exports successfully

    # Define a custom op named grid_sampler
    import torch
    import torch.nn.functional as F
    import torch.onnx.symbolic_opset11 as sym_opset
    import torch.onnx.symbolic_helper as sym_help
    from torch.onnx import register_custom_op_symbolic
    
    def grid_sampler(g, input, grid, mode, padding_mode, align_corners):  # mode, padding_mode, align_corners: integer (long) constants
        mode_i = sym_help._maybe_get_scalar(mode)
        paddingmode_i = sym_help._maybe_get_scalar(padding_mode)
        aligncorners_i = sym_help._maybe_get_scalar(align_corners)
        # Just a dummy definition: nothing implements this op on the ONNX side, so it is only usable if ONNX inference is not needed
        return g.op("myonnx_plugin::GridSampler", input, grid, interpolationmode_i=mode_i, paddingmode_i=paddingmode_i, aligncorners_i=aligncorners_i)
    
    # Register the custom op
    sym_opset.grid_sampler = grid_sampler
    register_custom_op_symbolic('myop::GridSampler', grid_sampler, 11)
    
    class GridSampler(torch.nn.Module):
        def __init__(self):
            super(GridSampler, self).__init__()
        
        def forward(self, tenInput, g):
            return F.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
    
    """
    tenInput = 
    g =
    """ 
    gridsampler = GridSampler()
    img = gridsampler(tenInput, g)
    

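    The ONNX model exported successfully.

    However, I later found that this approach only gets past the ONNX checks: at inference time the runtime still reports an unregistered op.
    I also found a hand-written custom torch op online that implements this function, but it is so slow that it is almost unusable.
    More recently I found a drop-in replacement for torch.nn.functional.grid_sample.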

    Working solution: replace it with mmcv's grid sample, which supports ONNX model export, onnxruntime inference, and ONNX-IR conversion.

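    import torch.nn.functional as F
    from mmcv.ops.point_sample import bilinear_grid_sample
    # Pass whatever input and grid your task uses; they match the inputs of torch's grid sampler.
    # Use the same grid and the same align_corners in both calls so the outputs are comparable.
    # (Assumption: bilinear_grid_sample pads with zeros outside the image, so grid points
    # outside [-1, 1] may differ from padding_mode='border'.)
    img = bilinear_grid_sample(tenInput, grid, align_corners=True)
    img_o = F.grid_sample(input=tenInput, grid=grid, mode='bilinear', padding_mode='border', align_corners=True)
    print(img - img_o)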
    

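    In my own experiments, the outputs of the torch and mmcv grid sample functions were almost identical, so one can replace the other.

    Exporting to ONNX works, and onnxruntime no longer reports an error.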
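To illustrate why a replacement of this kind exports where grid_sampler does not, here is a minimal sketch of bilinear sampling written only with elementwise arithmetic, clamp, floor, and gather. It is not the mmcv implementation. It assumes align_corners=True and border padding, matching the GridSampler module above, and the function name `grid_sample_bilinear_border` is made up for this example.

```python
import torch
import torch.nn.functional as F

def grid_sample_bilinear_border(im, grid):
    """Bilinear sampling with align_corners=True and border padding."""
    n, c, h, w = im.shape
    _, gh, gw, _ = grid.shape
    # Map normalized coordinates in [-1, 1] to pixel coordinates.
    x = (grid[..., 0] + 1) / 2 * (w - 1)
    y = (grid[..., 1] + 1) / 2 * (h - 1)
    # Border padding: clamp the sampling position into the image.
    x = x.clamp(0, w - 1).reshape(n, -1)
    y = y.clamp(0, h - 1).reshape(n, -1)
    # Integer corners around each sampling position.
    x0, y0 = x.floor(), y.floor()
    x1 = (x0 + 1).clamp(max=w - 1)
    y1 = (y0 + 1).clamp(max=h - 1)
    # Interpolation weights along each axis.
    wx1, wy1 = x - x0, y - y0
    wx0, wy0 = 1 - wx1, 1 - wy1
    flat = im.reshape(n, c, h * w)

    def gather(yy, xx):
        # Look up pixel values at flattened positions yy * w + xx.
        idx = (yy * w + xx).long().unsqueeze(1).expand(-1, c, -1)
        return torch.gather(flat, 2, idx)

    out = (gather(y0, x0) * (wx0 * wy0).unsqueeze(1)
           + gather(y0, x1) * (wx1 * wy0).unsqueeze(1)
           + gather(y1, x0) * (wx0 * wy1).unsqueeze(1)
           + gather(y1, x1) * (wx1 * wy1).unsqueeze(1))
    return out.reshape(n, c, gh, gw)

torch.manual_seed(0)
im = torch.rand(2, 3, 5, 7)
grid = torch.rand(2, 4, 6, 2) * 2.4 - 1.2  # includes points slightly outside [-1, 1]
out = grid_sample_bilinear_border(im, grid)
ref = F.grid_sample(im, grid, mode='bilinear', padding_mode='border', align_corners=True)
print((out - ref).abs().max())
```

Comparing against F.grid_sample on random inputs, as in the last lines, is a quick check worth repeating for any substitute before exporting.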

    Also: if you need to run on CUDA, I recommend modifying the mmcv source file.
    The path is /anaconda3/envs/pytorch/lib/python3.9/site-packages/mmcv/ops/point_sample.py
    Modify lines 64-71:

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x0 = torch.where(x0 < 0, torch.tensor(0).to(device), x0)
    x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1).to(device), x0)
    x1 = torch.where(x1 < 0, torch.tensor(0).to(device), x1)
    x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1).to(device), x1)
    y0 = torch.where(y0 < 0, torch.tensor(0).to(device), y0)
    y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1).to(device), y0)
    y1 = torch.where(y1 < 0, torch.tensor(0).to(device), y1)
    y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1).to(device), y1)
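A possible refinement of the patch above, shown as a sketch and not as part of mmcv: take the device from the index tensor itself, so the code also behaves when the model runs on the CPU of a machine that has a GPU. The helper name `clip_to_range` is hypothetical.

```python
import torch

def clip_to_range(idx, upper):
    """Clip integer indices to [0, upper] on whatever device idx lives on."""
    low = torch.tensor(0, dtype=idx.dtype, device=idx.device)
    high = torch.tensor(upper, dtype=idx.dtype, device=idx.device)
    idx = torch.where(idx < 0, low, idx)
    idx = torch.where(idx > upper, high, idx)
    return idx

x0 = torch.tensor([-2, 0, 3, 9])
clipped = clip_to_range(x0, 5)
print(clipped)  # tensor([0, 0, 3, 5])
```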
    

    ONNX dynamic input

    torch_input = torch.from_numpy(input_rand).half().to(device)  # rand(4, 1, 4, 4) # N C H W
    torch_grid = torch.from_numpy(grid_rand).half().to(device)  # rand(4, 4, 4, 2)
    model = GridSampler()
    torch.onnx.export(model, (torch_input, torch_grid), onnx_model_file, verbose=False, 
                      input_names=['input', 'grid'], output_names=['output'], opset_version=11, 
                      dynamic_axes={"input": {1: 'channel', 2: 'height', 3: 'width'}, 
                      "grid": {1: 'height', 2: 'width', 3: 'channel'}}, enable_onnx_checker=False)
    

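    The parameter to set is dynamic_axes.
    List in it every input dimension that may change at run time.
    In the example above, "input" is a dynamic input whose dimensions 1, 2, and 3 may vary.
    Other inputs or outputs can be made dynamic in the same way.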
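As an illustration of extending the dictionary to other inputs and outputs, the sketch below writes a variant of dynamic_axes and sanity-checks it against the example shapes. The variant is a suggestion and an assumption, not the configuration used above: it leaves the last grid axis (always 2) fixed, marks the output axes as dynamic too, and gives input and grid sizes distinct labels so they are not read as the same dimension. The helper `check_dynamic_axes` is hypothetical and is not part of torch.onnx.

```python
def check_dynamic_axes(dynamic_axes, shapes):
    """Return the example size behind every dynamic axis, raising on a bad index."""
    resolved = {}
    for name, axes in dynamic_axes.items():
        rank = len(shapes[name])
        for axis, label in axes.items():
            if not 0 <= axis < rank:
                raise ValueError(f"{name}: axis {axis} out of range for rank {rank}")
            resolved[(name, label)] = shapes[name][axis]
    return resolved

# Example shapes from the export call above; grid_sample returns N, C, H_out, W_out.
shapes = {"input": (4, 1, 4, 4), "grid": (4, 4, 4, 2), "output": (4, 1, 4, 4)}

dynamic_axes = {
    "input": {1: "channel", 2: "in_height", 3: "in_width"},
    "grid": {1: "out_height", 2: "out_width"},
    "output": {1: "channel", 2: "out_height", 3: "out_width"},
}

resolved = check_dynamic_axes(dynamic_axes, shapes)
print(resolved[("grid", "out_height")])
```

The resulting dictionary can be passed as the dynamic_axes argument of torch.onnx.export in place of the one shown earlier.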

    Source: JoeyChen1219
