Description
I tried to convert Keras's GlobalMaxPooling2D to PyTorch via ONNX, but the converted model crashes at inference time. After a brief investigation, I noticed that onnx2pytorch converts ReduceMax to torch.max. However, torch.max can only reduce over a single dimension, whereas ReduceMax may reduce over two or more dimensions at once. In such cases, we could use torch.amax instead of torch.max, since torch.amax accepts a tuple of dimensions.
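A minimal sketch of the difference between the two functions. It assumes an NHWC input of shape (10, 32, 32, 3), like the Keras model below, and assumes the exported ReduceMax reduces the two spatial axes (1, 2); the actual axes stored in temp.onnx are not shown in this report.

```python
import torch

# Assumed NHWC input, matching the Keras model's (32, 32, 3) input shape.
x = torch.rand(10, 32, 32, 3)

# torch.amax reduces over several dims in a single call.
y = torch.amax(x, dim=(1, 2), keepdim=False)

# With torch.max, the same result needs one call per dim, because torch.max
# only accepts a single int dim and returns a (values, indices) pair.
ref = x.max(dim=2).values.max(dim=1).values

print(y.shape)
print(torch.equal(y, ref))

# Passing a tuple of dims to torch.max is what triggers the TypeError below.
try:
    torch.max(x, dim=(1, 2), keepdim=True)
except TypeError as err:
    print("torch.max rejected dim=(1, 2):", type(err).__name__)
```

Since the maximum is exact, the single torch.amax call and the chained torch.max calls produce identical tensors.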
The script to reproduce the issue is as follows:
```python
# Build the model
import tensorflow as tf
import tf2onnx

x = tf.keras.layers.Input((32, 32, 3))
y = tf.keras.layers.GlobalMaxPooling2D()(x)
model = tf.keras.Model(x, y)
model.summary()

# Convert the model
input_shape = model.layers[0].input_shape[0]
spec = (tf.TensorSpec(input_shape, tf.float32, name="input"),)
_, _ = tf2onnx.convert.from_keras(model, input_signature=spec,
                                  opset=15, output_path="temp.onnx")

from onnx2pytorch import ConvertModel
import onnx

onnx_model = onnx.load("temp.onnx")
torch_model = ConvertModel(onnx_model, experimental=True)

# Predict
import torch
import numpy as np

# np.random.rand returns float64, so the model is cast to double to match.
inputs = np.random.rand(10, *input_shape[1:])
inputs = torch.from_numpy(inputs)
torch_model.double()
pred = torch_model(inputs)  # the TypeError below is raised here
pred = pred.detach().numpy()
print("The prediction is: ", pred.shape)
```
The same code is also available in this Colab notebook:
https://colab.research.google.com/drive/1mwoJDtjroZ6ynNtaFdu-2EjX9eVTd-pS?usp=sharing
The crash information is as follows:
```
/usr/local/lib/python3.7/dist-packages/onnx2pytorch/convert/model.py in forward(self, *input_list, **input_dict)
222 activations[out_op_id] = output
223 else:
--> 224 activations[out_op_id] = op(*in_activations)
225
226 # Remove activations that are no longer needed

TypeError: max() received an invalid combination of arguments - got (Tensor, dim=tuple, keepdim=bool), but expected one of:
* (Tensor input)
* (Tensor input, Tensor other, *, Tensor out)
* (Tensor input, int dim, bool keepdim, *, tuple of Tensors out)
* (Tensor input, name dim, bool keepdim, *, tuple of Tensors out)
```
To fix this bug, I suggest adding a guard to the ReduceMax branch (line 200 in convert/operations.py), as follows:
```diff
 elif node.op_type == "ReduceMax":
     kwargs = dict(keepdim=True)
     kwargs.update(extract_attributes(node))
-    op = partial(torch.max, **kwargs)
+    if isinstance(kwargs["dim"], (tuple, list)) and len(kwargs["dim"]) > 1:
+        op = partial(torch.amax, **kwargs)
+    else:
+        op = partial(torch.max, **kwargs)
```
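The proposed guard can be tried in isolation with a small stand-alone sketch. The helper name select_reduce_max is hypothetical, and its attrs argument is a hand-built stand-in for the dict that extract_attributes(node) returns in onnx2pytorch (an assumption, since that dict is not shown here). Unlike the patch, the helper reads the dimension with kwargs.get("dim"), so a node without an axes attribute does not raise a KeyError inside the guard itself.

```python
from functools import partial

import torch


def select_reduce_max(attrs):
    """Hypothetical stand-alone version of the proposed ReduceMax guard.

    `attrs` is a hand-built stand-in for extract_attributes(node).
    """
    kwargs = dict(keepdim=True)
    kwargs.update(attrs)
    dim = kwargs.get("dim")
    if isinstance(dim, (tuple, list)) and len(dim) > 1:
        # Multi-axis reduction: torch.amax accepts a tuple of dims.
        return partial(torch.amax, **kwargs)
    # Single-axis reduction: keep the existing torch.max path.
    return partial(torch.max, **kwargs)


x = torch.rand(10, 32, 32, 3)

# Assumed attributes for a GlobalMaxPooling2D-style reduction over H and W.
multi = select_reduce_max({"dim": (1, 2), "keepdim": False})
print(multi(x).shape)

# A single-axis reduction still goes through torch.max.
single = select_reduce_max({"dim": 1})
print(single(x).values.shape)
```

Note that the single-axis branch still returns torch.max's (values, indices) pair, exactly as before the patch, while the multi-axis branch returns a plain tensor.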
Another option is to simply replace torch.max with torch.amax, but the two functions still differ in some respects (for example, torch.amax does not return indices), so I am not sure such a replacement is safe.
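Two of the differences mentioned above can be checked directly: the return type, and how the gradient is split when several elements share the maximum (the PyTorch documentation for torch.amax describes both). A minimal sketch on a hand-picked tensor with a tie:

```python
import torch

values = [[1.0, 3.0, 3.0]]  # the maximum 3.0 appears twice (a tie)

# torch.max with a dim returns a (values, indices) named tuple.
x_max = torch.tensor(values, requires_grad=True)
out_max = torch.max(x_max, dim=1)
print(out_max.values, out_max.indices)

# torch.amax returns only the values, as a plain tensor.
x_amax = torch.tensor(values, requires_grad=True)
out_amax = torch.amax(x_amax, dim=1)
print(out_amax)

# Backpropagate through each reduction.
out_max.values.sum().backward()
out_amax.sum().backward()

# torch.max(dim) sends the whole gradient to the single selected index,
print(x_max.grad)
# while torch.amax splits it evenly between the tied maxima.
print(x_amax.grad)
```

The forward values agree, so for inference the swap mainly matters where downstream code expects the (values, indices) pair; for training, the gradient routing on ties also changes.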
Please let me know whether this approach is acceptable; I can contribute a pull request.