Closed
Problem description
topi.sparse.csrmv has "float32" hard-coded inside ir_builder and te.extern, so it accepts only float32, not float64 or other data types:
tvm/python/tvm/topi/sparse/csrmv.py, lines 66 to 68 in d3fc562
tvm/python/tvm/topi/sparse/csrmv.py, lines 80 to 87 in d3fc562
Same problem for topi.sparse.csrmm.
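To make the hard-coding problem concrete, here is a minimal plain-Python sketch of the same CSR matrix-vector loop. It is not TVM code: csrmv_reference is a hypothetical helper, and the standard-library array typecodes "f" and "d" merely stand in for float32 and float64. The point it illustrates is that the output element type is taken from the input data rather than fixed in the kernel.

```python
from array import array


def csrmv_reference(data, indices, indptr, weight):
    """Return y = A @ x for a CSR matrix A given as (data, indices, indptr)."""
    num_rows = len(indptr) - 1
    # Take the output element type from the input data instead of
    # hard-coding "f" (the array-module analogue of "float32").
    out = array(data.typecode, [0.0] * num_rows)
    for row in range(num_rows):
        dot = 0.0  # Python floats are double precision; rounding happens on store
        for elem in range(indptr[row], indptr[row + 1]):
            dot += data[elem] * weight[indices[elem]]
        out[row] = dot
    return out


# A 3x5 matrix with 6 nonzeros (values are made up for illustration).
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
indices = array("i", [0, 3, 1, 0, 2, 4])
indptr = array("i", [0, 2, 3, 6])
x = [1.0, 2.0, 3.0, 4.0, 5.0]

for typecode in ("f", "d"):  # stand-ins for float32 and float64
    y = csrmv_reference(array(typecode, values), indices, indptr, array(typecode, x))
    print(typecode, list(y))
```

In the TVM kernels the analogous change would presumably be to derive the dtype from the input tensors where "float32" is currently written, but that is an assumption about the fix, not something the report states.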
Steps to reproduce
Build TVM 0.8dev from the latest master branch, and then run:
# extracted from tests/python/topi/python/test_topi_sparse.py
from tvm import te
from tvm import topi
import tvm.contrib.sparse as tvmsp
dtype = "float64" # "float32" works fine
nr, nc = (3, 5)
nnz = 6
A = tvmsp.placeholder(shape=(nr, nc), nonzeros=nnz, dtype=dtype, name="A")
B = te.placeholder((nc, 1), name="B")
out = topi.sparse.csrmv(A, B) # TVMError: Cannot match type float64 vs float32
Full error message:
---------------------------------------------------------------------------
TVMError Traceback (most recent call last)
<ipython-input-1-6daa8fd2fb08> in <module>
9 A = tvmsp.placeholder(shape=(nr, nc), nonzeros=nnz, dtype=dtype, name="A")
10 B = te.placeholder((nc, 1), name="B")
---> 11 out = topi.sparse.csrmv(A, B) # TVMError: Cannot match type float64 vs float32
/tvm_install/tvm/python/tvm/topi/sparse/csrmv.py in csrmv(a, x, y)
111 2-D dense matrix with shape [m, 1]
112 """
--> 113 return csrmv_default(a.data, a.indices, a.indptr, x, y)
/tvm_install/tvm/python/tvm/topi/sparse/csrmv.py in csrmv_default(data, indices, indptr, weight, bias)
78
79 oshape = (batch, 1)
---> 80 matmul = te.extern(
81 oshape,
82 [data, indices, indptr, weight],
/tvm_install/tvm/python/tvm/te/operation.py in extern(shape, inputs, fcompute, name, dtype, in_buffers, out_buffers, tag, attrs)
315 for shp, dt in zip(shape, dtype):
316 output_placeholders.append(tvm.tir.decl_buffer(shp, dt, name))
--> 317 body = fcompute(input_placeholders, output_placeholders)
318 if isinstance(body, tvm.tir.PrimExpr):
319 body = tvm.tir.Evaluate(body)
/tvm_install/tvm/python/tvm/topi/sparse/csrmv.py in <lambda>(ins, outs)
81 oshape,
82 [data, indices, indptr, weight],
---> 83 lambda ins, outs: csrmv_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
84 tag="csrmv",
85 dtype="float32",
/tvm_install/tvm/python/tvm/topi/sparse/csrmv.py in csrmv_default_ir(data, indices, indptr, weight, out)
73 with irb.for_range(0, row_elems, name="elemidx") as elemidx:
74 elem = row_start + elemidx
---> 75 dot[0] += data_ptr[elem] * weight_ptr[indices_ptr[elem]]
76 out_ptr[row] += dot[0]
77 return irb.get()
/tvm_install/tvm/python/tvm/tir/expr.py in __mul__(self, other)
75
76 def __mul__(self, other):
---> 77 return _generic.multiply(self, other)
78
79 def __rmul__(self, other):
/tvm_install/tvm/python/tvm/topi/generic_op_impl.py in _tensor_bop_impl(lhs, rhs)
81 """
82 if not isinstance(lhs, te.tensor.Tensor) and not isinstance(rhs, te.tensor.Tensor):
---> 83 return orig_bop(lhs, rhs)
84 return broadcast_bop(lhs, rhs)
85
/tvm_install/tvm/python/tvm/tir/generic.py in multiply(lhs, rhs, span)
84 The result Expr of multiply operaton.
85 """
---> 86 return _ffi_api._OpMul(lhs, rhs, span)
87
88
/tvm_install/tvm/python/tvm/_ffi/_ctypes/packed_func.py in __call__(self, *args)
235 != 0
236 ):
--> 237 raise get_last_ffi_error()
238 _ = temp_args
239 _ = args
TVMError: Traceback (most recent call last):
3: TVMFuncCall
2: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::PrimExpr (tvm::PrimExpr, tvm::PrimExpr, tvm::Span)>::AssignTypedLambda<tvm::{lambda(tvm::PrimExpr, tvm::PrimExpr, tvm::Span)#5}>(tvm::{lambda(tvm::PrimExpr, tvm::PrimExpr, tvm::Span)#5}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
1: tvm::mul(tvm::PrimExpr, tvm::PrimExpr, tvm::Span)
0: tvm::BinaryOpMatchTypes(tvm::PrimExpr&, tvm::PrimExpr&, tvm::Span)
File "/tvm_install/tvm/src/tir/op/op.cc", line 144
TVMError: Cannot match type float64 vs float32
Desired fix
- topi.sparse.{csrmv, csrmm} should be independent of data type.
- Add unit tests to tests/python/topi/python/test_topi_sparse.py to make sure multiple data types work.
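The shape of such a multi-dtype test could look roughly like the sketch below. It is plain Python, not the actual test_topi_sparse.py code: dense_to_csr, csrmm_reference and check_all_dtypes are hypothetical helpers, and the array typecodes "f" and "d" are stand-ins for float32 and float64. The idea is simply to run the same sparse product once per data type and compare each result against a dense reference.

```python
from array import array


def dense_to_csr(dense, typecode):
    """Convert a list-of-rows dense matrix to CSR arrays of the given typecode."""
    data, indices, indptr = array(typecode), array("i"), array("i", [0])
    for row in dense:
        for col, value in enumerate(row):
            if value != 0:
                data.append(value)
                indices.append(col)
        indptr.append(len(data))
    return data, indices, indptr


def csrmm_reference(data, indices, indptr, weight):
    """Return A @ W for CSR A and dense W (list of rows); rows keep data's typecode."""
    num_out_cols = len(weight[0])
    out = []
    for row in range(len(indptr) - 1):
        acc = [0.0] * num_out_cols
        for elem in range(indptr[row], indptr[row + 1]):
            for j in range(num_out_cols):
                acc[j] += data[elem] * weight[indices[elem]][j]
        out.append(array(data.typecode, acc))
    return out


def check_all_dtypes(dense, weight, typecodes=("f", "d")):
    """Run the sparse product once per typecode and compare with a dense product."""
    expected = [
        [sum(a * w for a, w in zip(row, col)) for col in zip(*weight)]
        for row in dense
    ]
    for typecode in typecodes:
        data, indices, indptr = dense_to_csr(dense, typecode)
        result = csrmm_reference(data, indices, indptr, weight)
        assert all(r.typecode == typecode for r in result)
        assert [list(r) for r in result] == expected
    return expected


# Small integer-valued inputs so that both precisions give exact results.
dense = [
    [1.0, 0.0, 0.0, 2.0, 0.0],
    [0.0, 3.0, 0.0, 0.0, 0.0],
    [4.0, 0.0, 5.0, 0.0, 6.0],
]
weight = [[1.0, 0.0], [2.0, 1.0], [3.0, 0.0], [4.0, 1.0], [5.0, 2.0]]
expected = check_all_dtypes(dense, weight)
```

With non-integer inputs an exact comparison would no longer be appropriate, and a per-dtype tolerance would be needed instead.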