change torch.max to torch.amax when reducing on multiple dimensions · Issue #34 · Talmaj/onnx2pytorch · GitHub

Closed
maybeLee opened this issue Dec 22, 2021 · 2 comments · Fixed by #37
Comments

@maybeLee
maybeLee commented Dec 22, 2021

I was trying to convert Keras's GlobalMaxPooling2D to PyTorch through ONNX but failed. After a rough investigation, I noticed that onnx2pytorch converts ReduceMax to torch.max; however, torch.max only supports reducing over a single dimension, while ReduceMax can reduce over multiple dimensions. In this case we should use torch.amax instead of torch.max.
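
For illustration, a quick standalone sketch of the difference (assuming a PyTorch version >= 1.7, where torch.amax is available):

import torch

x = torch.rand(10, 32, 32, 3)
# torch.max only accepts a single integer dim, so passing a tuple raises the
# TypeError shown further below:
# torch.max(x, dim=(1, 2), keepdim=True)
# torch.amax accepts a tuple of dims and reduces over all of them at once:
y = torch.amax(x, dim=(1, 2), keepdim=True)
print(y.shape)  # torch.Size([10, 1, 1, 3])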

Script to reproduce is as follows:

# Build the model
import tensorflow as tf
import tf2onnx
x = tf.keras.layers.Input((32, 32, 3))
y = tf.keras.layers.GlobalMaxPooling2D()(x)
model = tf.keras.Model(x, y)
model.summary()
# Convert the model
input_shape = model.layers[0].input_shape[0]
spec = (tf.TensorSpec(input_shape, tf.float32, name="input"),)
_, _ = tf2onnx.convert.from_keras(model, input_signature=spec, \
        opset=15, output_path="temp.onnx")
from onnx2pytorch import ConvertModel
import onnx
onnx_model = onnx.load("temp.onnx")
torch_model = ConvertModel(onnx_model, experimental=True)

# Predict
import torch
import numpy as np
input = np.random.rand(10, *input_shape[1:])
input = torch.from_numpy(input)
torch_model.double()
pred = torch_model(input)
pred = pred.detach().numpy()
print("The prediction is: ", pred.shape)

You can also access the code in the Colab notebook below:
https://colab.research.google.com/drive/1mwoJDtjroZ6ynNtaFdu-2EjX9eVTd-pS?usp=sharing
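
For reference, a rough sketch of the reduction the converted model should end up performing (assuming tf2onnx keeps the Keras channels-last layout here, so the exported ReduceMax reduces over the spatial axes 1 and 2):

import torch
import numpy as np

x = np.random.rand(10, 32, 32, 3)
# GlobalMaxPooling2D takes the max over the spatial axes, which maps to a
# multi-dimension ReduceMax; in PyTorch that is torch.amax, not torch.max.
expected = torch.amax(torch.from_numpy(x), dim=(1, 2))
print(expected.shape)  # torch.Size([10, 3])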

The crash information is as follows:

/usr/local/lib/python3.7/dist-packages/onnx2pytorch/convert/model.py in forward(self, *input_list, **input_dict)
    222                     activations[out_op_id] = output
    223             else:
--> 224                 activations[out_op_id] = op(*in_activations)
    225 
    226             # Remove activations that are no longer needed

TypeError: max() received an invalid combination of arguments - got (Tensor, dim=tuple, keepdim=bool), but expected one of:
 * (Tensor input)
 * (Tensor input, Tensor other, *, Tensor out)
 * (Tensor input, int dim, bool keepdim, *, tuple of Tensors out)
 * (Tensor input, name dim, bool keepdim, *, tuple of Tensors out)

To fix this bug, my suggestion is to add a guard to the ReduceMax branch (line 200 in convert/operations.py) as follows:

200         elif node.op_type == "ReduceMax":
201             kwargs = dict(keepdim=True)
202             kwargs.update(extract_attributes(node))
+ 203             if isinstance(kwargs.get("dim"), (tuple, list)) and len(kwargs["dim"]) > 1:
+ 204                 op = partial(torch.amax, **kwargs)
+ 205             else:
206                 op = partial(torch.max, **kwargs)

Another idea is to just replace torch.max with torch.amax, but these two still have some differences, so I am not sure such a replacement is safe.
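
For example, one difference I am aware of: with a dim argument, torch.max returns a (values, indices) pair while torch.amax returns only the values, and the PyTorch docs also mention that gradients are handled differently when the maximum occurs at several positions. A quick sketch:

import torch

x = torch.rand(4, 5)
values, indices = torch.max(x, dim=1)    # returns a (values, indices) namedtuple
only_values = torch.amax(x, dim=1)       # returns just the values tensor
print(torch.equal(values, only_values))  # True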

Please let me know if this looks fine; I can contribute a pull request.

@calvinmccarter-at-lightmatter
Contributor

@maybeLee -- PR #37 should fix this for you.

@maybeLee
Author
maybeLee commented Jan 4, 2022

Thanks for your help!

@maybeLee maybeLee closed this as completed Jan 4, 2022