Невозможно использовать matmul для данной ошибки тензоров при преобразовании pytorch в onnx JS

Я сделал простой pytorch MLP (генератор GAN) и преобразовал его в onnx с помощью учебника (https://www.youtube.com/watch?v=Vs730jsRgO8) мой код немного отличается, но я не могу обнаружить ошибку.

class Generator(nn.Module):
def __init__(self, g_input_dim, g_output_dim):
    super(Generator, self).__init__()
    # g_input = 100
    self.net = nn.Sequential(
      nn.Linear(g_input_dim, 256),
      nn.LeakyReLU(.2),
      nn.Linear(256, 512),
      nn.LeakyReLU(.2),
      nn.Linear(512, 1024),
      nn.LeakyReLU(.2),
      nn.Linear(1024, 784),
      nn.Tanh()
    )

# forward method
def forward(self, x): 
    return self.net(x)

После обучения экспортирую модель в onnx.

torch.save(G.state_dict(), "pytorch_model.pth")
import torch.onnx

model = Generator(z_dim,mnist_dim)
state_dict = torch.load("pytorch_model.pth")

model.load_state_dict(state_dict)
model.eval()

dummy_input = torch.zeros(100)

torch.onnx.export(model, dummy_input, "onnx_model.onnx", verbose=True)

Это дает следующий график onnx, который кажется точным.

graph(%input.1 : Float(100),
      %net.0.bias : Float(256),
      %net.2.bias : Float(512),
      %net.4.bias : Float(1024),
      %net.6.bias : Float(784),
      %25 : Float(100, 256),
      %26 : Float(256, 512),
      %27 : Float(512, 1024),
      %28 : Float(1024, 784)):
  %10 : Float(256) = onnx::MatMul(%input.1, %25) # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1612:0
  %11 : Float(256) = onnx::Add(%10, %net.0.bias)
  %12 : Float(256) = onnx::LeakyRelu[alpha=0.20000000000000001](%11) # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1239:0
  %14 : Float(512) = onnx::MatMul(%12, %26) # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1612:0
  %15 : Float(512) = onnx::Add(%14, %net.2.bias)
  %16 : Float(512) = onnx::LeakyRelu[alpha=0.20000000000000001](%15) # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1239:0
  %18 : Float(1024) = onnx::MatMul(%16, %27) # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1612:0
  %19 : Float(1024) = onnx::Add(%18, %net.4.bias)
  %20 : Float(1024) = onnx::LeakyRelu[alpha=0.20000000000000001](%19) # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1239:0
  %22 : Float(784) = onnx::MatMul(%20, %28) # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1612:0
  %23 : Float(784) = onnx::Add(%22, %net.6.bias)
  %24 : Float(784) 

Затем я импортировал код в javascript.

<html>
  <body>
    <script src="./onnx.min.js"></script>
    <script>
      async function test() {
        const sess = new onnx.InferenceSession()
        await sess.loadModel('./onnx_model.onnx')
        const input = new onnx.Tensor(new Float32Array(100), 'float32', [100])
        const outputMap = await sess.run([input])
        const outputTensor = outputMap.values().next().value
        console.log(`Output tensor: ${outputTensor.data}`)
      }
      test()
    </script>
  </body>
</html>

Я знаю, что размер ввода правильный, но onnx дает мне следующую ошибку.

onnx.min.js:8 Uncaught (in promise) Error: Can't use matmul on the given tensors
    at e.createProgramInfo (onnx.min.js:8)
    at t.run (onnx.min.js:8)
    at e.run (onnx.min.js:8)
    at t.<anonymous> (onnx.min.js:14)
    at onnx.min.js:14
    at Object.next (onnx.min.js:14)
    at onnx.min.js:14
    at new Promise (<anonymous>)
    at r (onnx.min.js:14)
    at onnx.min.js:14

Я также знаю, что matmul является оператором, поддерживаемым onnx, но я не могу понять, как и правильно ли мой входной тензор.


person Mazeyar Moeini Feizabadi    schedule 22.05.2020    source источник


Ответы (1)


Я думаю, что оператор matmul ожидает, что ввод будет двумерным. Кажется, это работает, когда я добавляю размер партии к входу (размер партии 1):

Раньше: dummy_input = torch.zeros(100)

После: dummy_input = torch.zeros(1, 100)

Раньше: const input = new onnx.Tensor(new Float32Array(100), 'float32', [100])

После: const input = new onnx.Tensor(new Float32Array(100), 'float32', [1, 100])

person Elliot Waite    schedule 22.05.2020
comment
OMG спасибо ВАМ ОЧЕНЬ БОЛЬШОЕ !!!! это был последний групповой проект, и нам нужно было его исправить, еще раз спасибо. Вы абсолютная легенда. ❤️ - person Mazeyar Moeini Feizabadi; 22.05.2020