[BUG] topi.sparse.csrmv only accepts float32, but not other data types · Issue #8406 · apache/tvm · GitHub
Closed
@learning-chip


Problem description

topi.sparse.csrmv hard-codes "float32" in two places, the ir_builder accumulator allocation and the te.extern output dtype. As a result it only works with float32 inputs and fails for float64 and other data types:

```python
with irb.for_range(0, num_rows, kind="parallel", name="row") as row:
    dot = irb.allocate("float32", (1,), name="dot", scope="local")  # accumulator dtype fixed to float32
    out_ptr[row] = 0.0
```

```python
matmul = te.extern(
    oshape,
    [data, indices, indptr, weight],
    lambda ins, outs: csrmv_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
    tag="csrmv",
    dtype="float32",  # output dtype fixed to float32
    name="csrmv",
)
```

topi.sparse.csrmm has the same problem.
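The snippets above fix the element type at float32 no matter what the inputs are. The plain-Python sketch below shows the intended behaviour of the row loop in csrmv_default_ir. The output's element type comes from data instead of being hard-coded. This is an analogy, not TVM code. Python's array typecodes 'f' and 'd' stand in for float32 and float64, csrmv_ref is a hypothetical helper, and x is 1-D here, whereas topi expects an (n, 1) tensor. In TVM itself, the analogous change would presumably derive the dtype from data.dtype wherever "float32" now appears. That is an assumption about the eventual patch, not something stated in this issue.

```python
from array import array


def csrmv_ref(data, indices, indptr, x):
    """Reference y = A @ x for a CSR matrix A and a dense vector x.

    data and x are array.array objects; typecode 'f' stands in for float32
    and 'd' for float64. The output reuses data.typecode instead of a fixed
    'f', which is the dtype-generic behaviour the issue asks for.
    """
    if data.typecode != x.typecode:
        # Reject mixed element types, loosely like TVM rejecting float64 * float32.
        raise TypeError(f"element type mismatch: {data.typecode!r} vs {x.typecode!r}")
    num_rows = len(indptr) - 1
    out = array(data.typecode, [0.0] * num_rows)  # output type follows the input
    for row in range(num_rows):
        dot = 0.0  # Python float accumulator (double precision)
        for elem in range(indptr[row], indptr[row + 1]):
            dot += data[elem] * x[indices[elem]]
        out[row] = dot  # rounded to the output's element type when stored
    return out


# 3 x 5 matrix with 6 nonzeros, the same shape as in the reproducer.
indptr = [0, 2, 3, 6]
indices = [0, 3, 1, 0, 2, 4]
for code in ("f", "d"):
    data = array(code, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    x = array(code, [1.0, 0.5, 2.0, 1.0, 0.25])
    y = csrmv_ref(data, indices, indptr, x)
    print(y.typecode, list(y))  # prints "f [3.0, 1.5, 15.5]", then "d [3.0, 1.5, 15.5]"
```

The accumulator here is always a double-precision Python float. A float32 kernel that also accumulates in float32 can therefore round slightly differently for inputs that are not exactly representable.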

Steps to reproduce

Build TVM 0.8dev from the latest master branch, and then run:

```python
# extracted from tests/python/topi/python/test_topi_sparse.py
from tvm import te
from tvm import topi
import tvm.contrib.sparse as tvmsp

dtype = "float64"  # "float32" works fine
nr, nc = (3, 5)
nnz = 6

A = tvmsp.placeholder(shape=(nr, nc), nonzeros=nnz, dtype=dtype, name="A")
B = te.placeholder((nc, 1), name="B")
out = topi.sparse.csrmv(A, B)  # TVMError: Cannot match type float64 vs float32
```

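Note that B is declared without a dtype, so it keeps te.placeholder's default of float32. The first failure is therefore the float64 * float32 multiplication inside csrmv_default_ir. Declaring B with dtype=dtype does not avoid the bug, because the dot accumulator is still allocated as float32. Full error message: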

```
---------------------------------------------------------------------------
TVMError                                  Traceback (most recent call last)
<ipython-input-1-6daa8fd2fb08> in <module>
      9 A = tvmsp.placeholder(shape=(nr, nc), nonzeros=nnz, dtype=dtype, name="A")
     10 B = te.placeholder((nc, 1), name="B")
---> 11 out = topi.sparse.csrmv(A, B)  # TVMError: Cannot match type float64 vs float32

/tvm_install/tvm/python/tvm/topi/sparse/csrmv.py in csrmv(a, x, y)
    111         2-D dense matrix with shape [m, 1]
    112     """
--> 113     return csrmv_default(a.data, a.indices, a.indptr, x, y)

/tvm_install/tvm/python/tvm/topi/sparse/csrmv.py in csrmv_default(data, indices, indptr, weight, bias)
     78
     79     oshape = (batch, 1)
---> 80     matmul = te.extern(
     81         oshape,
     82         [data, indices, indptr, weight],

/tvm_install/tvm/python/tvm/te/operation.py in extern(shape, inputs, fcompute, name, dtype, in_buffers, out_buffers, tag, attrs)
    315         for shp, dt in zip(shape, dtype):
    316             output_placeholders.append(tvm.tir.decl_buffer(shp, dt, name))
--> 317     body = fcompute(input_placeholders, output_placeholders)
    318     if isinstance(body, tvm.tir.PrimExpr):
    319         body = tvm.tir.Evaluate(body)

/tvm_install/tvm/python/tvm/topi/sparse/csrmv.py in <lambda>(ins, outs)
     81         oshape,
     82         [data, indices, indptr, weight],
---> 83         lambda ins, outs: csrmv_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
     84         tag="csrmv",
     85         dtype="float32",

/tvm_install/tvm/python/tvm/topi/sparse/csrmv.py in csrmv_default_ir(data, indices, indptr, weight, out)
     73             with irb.for_range(0, row_elems, name="elemidx") as elemidx:
     74                 elem = row_start + elemidx
---> 75                 dot[0] += data_ptr[elem] * weight_ptr[indices_ptr[elem]]
     76             out_ptr[row] += dot[0]
     77         return irb.get()

/tvm_install/tvm/python/tvm/tir/expr.py in __mul__(self, other)
     75
     76     def __mul__(self, other):
---> 77         return _generic.multiply(self, other)
     78
     79     def __rmul__(self, other):

/tvm_install/tvm/python/tvm/topi/generic_op_impl.py in _tensor_bop_impl(lhs, rhs)
     81         """
     82         if not isinstance(lhs, te.tensor.Tensor) and not isinstance(rhs, te.tensor.Tensor):
---> 83             return orig_bop(lhs, rhs)
     84         return broadcast_bop(lhs, rhs)
     85

/tvm_install/tvm/python/tvm/tir/generic.py in multiply(lhs, rhs, span)
     84         The result Expr of multiply operaton.
     85     """
---> 86     return _ffi_api._OpMul(lhs, rhs, span)
     87
     88

/tvm_install/tvm/python/tvm/_ffi/_ctypes/packed_func.py in __call__(self, *args)
    235             != 0
    236         ):
--> 237             raise get_last_ffi_error()
    238         _ = temp_args
    239         _ = args

TVMError: Traceback (most recent call last):
  3: TVMFuncCall
  2: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::PrimExpr (tvm::PrimExpr, tvm::PrimExpr, tvm::Span)>::AssignTypedLambda<tvm::{lambda(tvm::PrimExpr, tvm::PrimExpr, tvm::Span)#5}>(tvm::{lambda(tvm::PrimExpr, tvm::PrimExpr, tvm::Span)#5}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  1: tvm::mul(tvm::PrimExpr, tvm::PrimExpr, tvm::Span)
  0: tvm::BinaryOpMatchTypes(tvm::PrimExpr&, tvm::PrimExpr&, tvm::Span)
  File "/tvm_install/tvm/src/tir/op/op.cc", line 144
TVMError: Cannot match type float64 vs float32
```

Desired fix

  • topi.sparse.{csrmv, csrmm} should be independent of data type.
  • Add unit tests to tests/python/topi/python/test_topi_sparse.py to verify that multiple data types work.
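For the proposed unit tests, one option is to loop over several data types and compare each result against a dense product, checking both the values and the output element type. The stdlib-only sketch below does this for a csrmm-style product (a CSR matrix times a dense matrix). dense_to_csr, csrmm_ref and check_all_dtypes are hypothetical helpers, and array typecodes again stand in for TVM dtypes. A real test in test_topi_sparse.py would instead build and run the TVM operator for each dtype and compare it against a reference like this one.

```python
from array import array


def dense_to_csr(dense, typecode):
    """Convert a dense matrix (list of rows) into CSR arrays (data, indices, indptr)."""
    data, indices, indptr = array(typecode), [], [0]
    for row in dense:
        for col, value in enumerate(row):
            if value != 0:
                data.append(value)
                indices.append(col)
        indptr.append(len(indices))  # row boundary: number of nonzeros so far
    return data, indices, indptr


def csrmm_ref(data, indices, indptr, weight):
    """Reference C = A @ W for CSR A (m x k) and dense W (k rows, each an array).

    Accumulation uses Python floats; each output row is stored as an array
    with A's typecode, so 'f' inputs give 'f' outputs and 'd' gives 'd'.
    """
    n = len(weight[0])
    out = []
    for row in range(len(indptr) - 1):
        acc = [0.0] * n
        for elem in range(indptr[row], indptr[row + 1]):
            w_row = weight[indices[elem]]  # row of W selected by the column index
            for j in range(n):
                acc[j] += data[elem] * w_row[j]
        out.append(array(data.typecode, acc))
    return out


def check_all_dtypes(dense_a, dense_w, typecodes=("f", "d"), rtol=1e-5):
    """Check csrmm_ref against a dense product for every typecode in typecodes."""
    for code in typecodes:
        data, indices, indptr = dense_to_csr(dense_a, code)
        weight = [array(code, row) for row in dense_w]
        got = csrmm_ref(data, indices, indptr, weight)
        for i, a_row in enumerate(dense_a):
            assert got[i].typecode == code, f"row {i}: got {got[i].typecode}, want {code}"
            for j in range(len(dense_w[0])):
                expected = sum(a_row[k] * dense_w[k][j] for k in range(len(a_row)))
                assert abs(got[i][j] - expected) <= rtol * max(1.0, abs(expected))
    return True


A = [[1.0, 0.0, 2.0], [0.0, 0.0, 0.0], [0.5, 3.0, 0.0]]
W = [[1.0, 2.0], [0.25, 0.0], [4.0, 1.0]]
print(check_all_dtypes(A, W))  # True
```

The check uses a relative tolerance rather than exact equality so that the same assertion works for float32, whose results are rounded when stored. The tolerance assumes well-conditioned inputs; sums with heavy cancellation would need a looser bound.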
