[Topi][Op][PyTorch][Vitis] Fix inconsistent kernel layout conventions for conv2d_transpose by AndrewZhaoLuo · Pull Request #9336 · apache/tvm · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

[Topi][Op][PyTorch][Vitis] Fix inconsistent kernel layout conventions for conv2d_transpose #9336

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 36 commits into from
Nov 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
6dd1566
fix a lot of initial tests
AndrewZhaoLuo Oct 21, 2021
6b00883
make pytorch tests pass
AndrewZhaoLuo Oct 21, 2021
0447f10
lint
AndrewZhaoLuo Oct 21, 2021
2425b41
add test
AndrewZhaoLuo Oct 21, 2021
a1c8f6b
fix bug with layout transform
AndrewZhaoLuo Oct 21, 2021
26e1462
change layouts for conv2d_transpose too
AndrewZhaoLuo Oct 21, 2021
8b93a0e
fix vitis tests
AndrewZhaoLuo Oct 21, 2021
aa3daf9
fix qnn conv2d transpose tests
AndrewZhaoLuo Oct 21, 2021
5bb9e80
fix fake quantization pass
AndrewZhaoLuo Oct 21, 2021
8c526e2
add todo
AndrewZhaoLuo Oct 21, 2021
03bad08
lint
AndrewZhaoLuo Oct 21, 2021
962766c
undo just formatting changes
AndrewZhaoLuo Oct 21, 2021
83d4fa6
remove formatting only change
AndrewZhaoLuo Oct 21, 2021
b2f049d
remove f2qi for later pr
AndrewZhaoLuo Oct 21, 2021
525ab24
more frontend tests fixes
AndrewZhaoLuo Oct 22, 2021
306ac98
fix a lot of initial tests
AndrewZhaoLuo Oct 21, 2021
b200eeb
make pytorch tests pass
AndrewZhaoLuo Oct 21, 2021
3223e88
lint
AndrewZhaoLuo Oct 21, 2021
3c45340
add test
AndrewZhaoLuo Oct 21, 2021
2a95c3f
fix bug with layout transform
AndrewZhaoLuo Oct 21, 2021
737238e
change layouts for conv2d_transpose too
AndrewZhaoLuo Oct 21, 2021
08107d9
fix vitis tests
AndrewZhaoLuo Oct 21, 2021
d9de5d0
fix qnn conv2d transpose tests
AndrewZhaoLuo Oct 21, 2021
815cc4b
fix fake quantization pass
AndrewZhaoLuo Oct 21, 2021
9baac54
add todo
AndrewZhaoLuo Oct 21, 2021
435e3b2
lint
AndrewZhaoLuo Oct 21, 2021
1228fc5
undo just formatting changes
AndrewZhaoLuo Oct 21, 2021
35e5617
remove formatting only change
AndrewZhaoLuo Oct 21, 2021
30101fa
remove f2qi for later pr
AndrewZhaoLuo Oct 21, 2021
0fa4acb
more frontend tests fixes
AndrewZhaoLuo Oct 22, 2021
5424dcf
jostle
AndrewZhaoLuo Oct 22, 2021
54bd920
fix keras
AndrewZhaoLuo Oct 22, 2021
be20086
Merge branch 'aluo/qnn/conv2d-transpose-fixes' of github.com:AndrewZh…
AndrewZhaoLuo Oct 25, 2021
dbfa74b
fix another frontend test
AndrewZhaoLuo Oct 25, 2021
7a2838d
fix things
AndrewZhaoLuo Oct 25, 2021
2c2136e
jostle ci
AndrewZhaoLuo Oct 25, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions python/tvm/relay/frontend/caffe.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@
import numpy as np
import tvm
from tvm.ir import IRModule

from ... import nd as _nd
from .. import analysis
from .. import expr as _expr
from .. import function as _function
from .. import op as _op
from ... import nd as _nd
from .common import ExprTable
from .common import infer_shape as _infer_shape

Expand Down Expand Up @@ -513,14 +514,16 @@ def convert_deconv(self, op):
weight_shape = [-1, conv_params.num_output, kh, kw]
weight_value = np.asarray(weight.data, np.float32)
weight_value = np.reshape(weight_value, weight_shape)

# weight shape is in relay's IOHW format rn, we need it to be OIHW
weight_value = np.transpose(weight_value, [1, 0, 2, 3])
else:
raise Exception("No weight value of layer {} in caffemodel".format(op.name))

weight_expr = self.exp_tab.new_const(weight_value, dtype="float32")
in_expr = self.exp_tab.get_expr(inputs[0])
out = _op.nn.conv2d_transpose(data=in_expr, weight=weight_expr, **params)
if bias:

bias_value = np.asarray(bias.data, np.float32)
bias_expr = self.exp_tab.new_const(bias_value, dtype="float32")
out = _op.nn.bias_add(out, bias_expr)
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/relay/frontend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,11 +354,14 @@ def _convert_convolution(inexpr, keras_layer, etab):
else:
kernel_layout = "HWIO"
else:
kernel_layout = "OIHW"
if is_deconv:
kernel_layout = "IOHW"
else:
kernel_layout = "OIHW"

if is_deconv:
kernel_h, kernel_w, n_filters, in_channels = weight.shape
if kernel_layout == "OIHW":
if kernel_layout == "IOHW":
weight = weight.transpose([3, 2, 0, 1])
elif is_depthconv:
kernel_h, kernel_w, in_channels, depth_mult = weight.shape
Expand Down
46 changes: 28 additions & 18 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,40 +18,50 @@
"""MXNet symbol frontend."""
import json
import math

import numpy as np
import tvm
from tvm.ir import IRModule

from tvm import relay
from tvm.ir import IRModule
from tvm.topi.utils import get_const_tuple

from ... import nd as _nd
from .. import analysis
from .. import expr as _expr
from .. import function as _function
from .. import op as _op
from .. import scope_builder as _scope_builder
from ... import nd as _nd

from .common import StrAttrsDict
from .common import infer_type as _infer_type
from .common import get_name as _get_name
from .common import infer_shape as _infer_shape
from .common import infer_type as _infer_type
from .common import infer_value as _infer_value
from .common import get_name as _get_name
from .nnvm_common import _rename, _binop_scalar, _rbinop_scalar, _reduce
from .nnvm_common import _arg_reduce, _init_op, _softmax_op, _cast
from .nnvm_common import _clip, _transpose, _upsampling
from .nnvm_common import _elemwise_sum, _reshape
from .nnvm_common import _warn_not_used
from .mxnet_qnn_op_utils import (
quantize_mxnet_min_max,
quantize_conv_weights_bias_channel_mkldnn_from_var,
quantize_conv_bias_mkldnn_from_var,
get_conv_mkldnn_requantized_scale_outDtype,
dequantize_mxnet_min_max,
get_conv_mkldnn_requantized_scale_outDtype,
get_mkldnn_int8_scale,
get_mkldnn_uint8_scale,
get_mkldnn_requantize_scale_outDtype,
get_mkldnn_uint8_scale,
quantize_conv_bias_mkldnn_from_var,
quantize_conv_weights_bias_channel_mkldnn_from_var,
quantize_mxnet_min_max,
)
from .nnvm_common import (
_arg_reduce,
_binop_scalar,
_cast,
_clip,
_elemwise_sum,
_init_op,
_rbinop_scalar,
_reduce,
_rename,
_reshape,
_softmax_op,
_transpose,
_upsampling,
_warn_not_used,
)


__all__ = ["from_mxnet"]

Expand Down Expand Up @@ -329,7 +339,7 @@ def _mx_conv2d_transpose(inputs, attrs):
if "kernel_layout" in attrs.attrs:
kernel_layout = attrs.get_str("kernel_layout")
else:
kernel_layout = "HWIO" if data_layout == "NHWC" else "OIHW"
kernel_layout = "HWIO" if data_layout == "NHWC" else "IOHW"

new_attrs = {}
new_attrs["channels"] = attrs.get_int("num_filter")
Expand Down
11 changes: 7 additions & 4 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
# pylint: disable=import-outside-toplevel, simplifiable-if-expression, cell-var-from-loop, unnecessary-lambda
# pylint: disable=missing-function-docstring
"""PT: PyTorch frontend."""
import itertools
import functools
import itertools
import logging
import math
import sys
Expand All @@ -40,11 +40,11 @@
from ..prelude import Prelude, StaticTensorArrayOps
from ..ty import Any, TensorType, TupleType
from . import qnn_torch
from .common import AttrCvt, get_relay_op, unbind, lstm_cell, gru_cell
from .common import infer_value as _infer_value
from .common import AttrCvt, get_relay_op, gru_cell
from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value
from .common import infer_value_simulated as _infer_value_simulated
from .common import try_infer_value
from .common import lstm_cell, try_infer_value, unbind
from .pytorch_utils import is_version_greater_than

__all__ = ["from_pytorch"]
Expand Down Expand Up @@ -1022,6 +1022,9 @@ def convolution(self, inputs, input_types):
elif len(kernel_size) == 2:
data_layout = "NCHW"
kernel_layout = "OIHW"
if use_transpose:
# Transposed convolutions have IOHW layout.
kernel_layout = "IOHW"
else:
data_layout = "NCW"
kernel_layout = "OIW"
Expand Down
8 changes: 2 additions & 6 deletions python/tvm/relay/frontend/qnn_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import logging

import numpy as np

import tvm
from tvm import relay
from tvm.relay import expr as _expr
Expand Down Expand Up @@ -1043,11 +1042,8 @@ def _impl(inputs, _):

weight_shape = list(infer_shape(weight))

# Swap I and O dims to match shape relay expects for OIHW
weight_shape[0], weight_shape[1] = weight_shape[1], weight_shape[0]

kernel_size = (weight_shape[2], weight_shape[3])
out_channels = weight_shape[0]
out_channels = weight_shape[1]

conv_out = relay.qnn.op.conv2d_transpose(
inputs[0],
Expand All @@ -1064,7 +1060,7 @@ def _impl(inputs, _):
channels=out_channels,
output_padding=output_padding,
out_dtype="int32",
kernel_layout="OIHW",
kernel_layout="IOHW",
)

return _do_bias_and_requantize(
Expand Down
5 changes: 4 additions & 1 deletion python/tvm/relay/frontend/tensorflow_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,8 +461,11 @@ def _impl(inputs, attr, params, mod):
raise tvm.error.OpAttributeInvalid(msg.format(attr["padding"]))

if "kernel_layout" not in attr:
if opname in ["conv", "conv_transpose"]:
if opname == "conv":
attr["kernel_layout"] = "HWIO" if attr["data_format"] == "NHWC" else "OIHW"
elif opname == "conv_transpose":
# conv_transpose in TVM has weights be IOHW for NCHW
attr["kernel_layout"] = "HWIO" if attr["data_format"] == "NHWC" else "IOHW"
else:
attr["kernel_layout"] = "HWOI" if attr["data_format"] == "NHWC" else "OIHW"

Expand Down
39 changes: 20 additions & 19 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,25 @@
# under the License.
# pylint: disable=invalid-name, unused-argument, too-many-lines, import-outside-toplevel
"""Tensorflow lite frontend."""
import math
import itertools
import math

import numpy as np
import tvm
from tvm import relay
from tvm.ir import IRModule

from tvm import relay
from ... import nd as _nd
from .. import analysis
from .. import expr as _expr
from .. import function as _function
from .. import op as _op
from .. import qnn as _qnn
from ... import nd as _nd
from .common import ExprTable
from .common import infer_shape as _infer_shape, to_int_list
from .common import infer_shape as _infer_shape
from .common import to_int_list
from .tflite_flexbuffer import FlexBufferDecoder


__all__ = ["from_tflite"]


Expand All @@ -53,9 +54,9 @@ class OperatorConverter(object):
def __init__(self, model, subgraph, exp_tab):

try:
from tflite.ActivationFunctionType import ActivationFunctionType
from tflite.BuiltinOperator import BuiltinOperator
from tflite.BuiltinOptions import BuiltinOptions
from tflite.ActivationFunctionType import ActivationFunctionType
except ImportError:
raise ImportError("The tflite package must be installed")

Expand Down Expand Up @@ -1061,8 +1062,8 @@ def convert_log_softmax(self, op):
def convert_concatenation(self, op):
"""Convert TFLite concatenation"""
try:
from tflite.ConcatenationOptions import ConcatenationOptions
from tflite.BuiltinOptions import BuiltinOptions
from tflite.ConcatenationOptions import ConcatenationOptions
except ImportError:
raise ImportError("The tflite package must be installed")

Expand Down Expand Up @@ -1248,10 +1249,10 @@ def _convert_elemwise(self, relay_op, op, ignore_qnn_params=False):
"""Generic method to Convert TFLite elemwise"""
try:
from tflite.AddOptions import AddOptions
from tflite.SubOptions import SubOptions
from tflite.MulOptions import MulOptions
from tflite.DivOptions import DivOptions
from tflite.BuiltinOptions import BuiltinOptions
from tflite.DivOptions import DivOptions
from tflite.MulOptions import MulOptions
from tflite.SubOptions import SubOptions
except ImportError:
raise ImportError("The tflite package must be installed")

Expand Down Expand Up @@ -1810,9 +1811,9 @@ def convert_reduce_any(self, op):
def _convert_arg_min_max(self, relay_op, op):
"""Generic method converting TFLite arg_min_max"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.ArgMinOptions import ArgMinOptions
from tflite.ArgMaxOptions import ArgMaxOptions
from tflite.ArgMinOptions import ArgMinOptions
from tflite.BuiltinOptions import BuiltinOptions
except ImportError:
raise ImportError("The tflite package must be installed")

Expand Down Expand Up @@ -1859,8 +1860,8 @@ def convert_arg_max(self, op):
def convert_fully_connected(self, op):
"""Convert TFLite fully connected"""
try:
from tflite.FullyConnectedOptions import FullyConnectedOptions
from tflite.BuiltinOptions import BuiltinOptions
from tflite.FullyConnectedOptions import FullyConnectedOptions
from tflite.TensorType import TensorType
except ImportError:
raise ImportError("The tflite package must be installed")
Expand Down Expand Up @@ -2030,10 +2031,10 @@ def convert_conv(self, op, conv_type):
"""convolution implementation."""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.TensorType import TensorType
from tflite.Conv2DOptions import Conv2DOptions
from tflite.DepthwiseConv2DOptions import DepthwiseConv2DOptions
from tflite.Padding import Padding
from tflite.TensorType import TensorType
except ImportError:
raise ImportError("The tflite package must be installed")

Expand Down Expand Up @@ -2440,8 +2441,8 @@ def convert_pool2d(self, op, pool_type):
"""pool2d implementation."""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.Pool2DOptions import Pool2DOptions
from tflite.Padding import Padding
from tflite.Pool2DOptions import Pool2DOptions
except ImportError:
raise ImportError("The tflite package must be installed")

Expand Down Expand Up @@ -2856,9 +2857,9 @@ def convert_transpose_conv(self, op):
"""Convert TFLite TRANSPOSE_CONV"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.Padding import Padding
from tflite.TensorType import TensorType
from tflite.TransposeConvOptions import TransposeConvOptions
from tflite.Padding import Padding
except ImportError:
raise ImportError("The tflite package must be installed")

Expand Down Expand Up @@ -2952,7 +2953,7 @@ def convert_transpose_conv(self, op):
channels=int(out_channels),
kernel_size=(int(kernel_h), int(kernel_w)),
data_layout="NHWC",
kernel_layout="OIHW",
kernel_layout="IOHW",
out_dtype="int32",
)
else:
Expand All @@ -2964,7 +2965,7 @@ def convert_transpose_conv(self, op):
channels=int(out_channels),
kernel_size=(int(kernel_h), int(kernel_w)),
data_layout="NHWC",
kernel_layout="OIHW",
kernel_layout="IOHW",
out_dtype=output_tensor_type_str,
)

Expand Down Expand Up @@ -3723,8 +3724,8 @@ def from_tflite(model, shape_dict=None, dtype_dict=None, op_converter=OperatorCo
The parameter dict to be used by relay
"""
try:
import tflite.SubGraph
import tflite.BuiltinOperator
import tflite.SubGraph
except ImportError:
raise ImportError("The tflite package must be installed")

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ def conv2d_transpose(
channels=None,
kernel_size=None,
data_layout="NCHW",
kernel_layout="OIHW",
kernel_layout="IOHW",
out_layout="",
output_padding=(0, 0),
out_dtype="",
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/qnn/op/layout_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def convert_qnn_conv2d_transpose(attrs, inputs, tinfos, desired_layouts):

# Handle default kernel layouts
if desired_data_layout == "NCHW":
new_attrs["kernel_layout"] = "OIHW"
new_attrs["kernel_layout"] = "IOHW"
return relay.qnn.op.conv2d_transpose(*inputs, **new_attrs)
if desired_data_layout == "NHWC":
new_attrs["kernel_layout"] = "HWIO"
Expand Down
Loading
0