[mlir][SPIRV] Add decorateType method for MatrixType #112018
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-spirv

Author: MingZhu Yan (trdthg)

Changes

try fix #108161

This PR adds a decorateType method for MatrixType, ensuring that `spirv.matrix` with offset in `spirv.struct` can be handled correctly.

Full diff: https://github.com/llvm/llvm-project/pull/112018.diff

3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h b/mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h
index 0c61f7eb54e2da..72683d50d74117 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Utils/LayoutUtils.h
@@ -24,6 +24,7 @@ namespace spirv {
class ArrayType;
class RuntimeArrayType;
class StructType;
+class MatrixType;
} // namespace spirv
/// According to the Vulkan spec "15.6.4. Offset and Stride Assignment":
@@ -67,6 +68,8 @@ class VulkanLayoutUtils {
static Type decorateType(VectorType vectorType, Size &size, Size &alignment);
static Type decorateType(spirv::ArrayType arrayType, Size &size,
Size &alignment);
+ static Type decorateType(spirv::MatrixType matrixType, Size &size,
+ Size &alignment);
static Type decorateType(spirv::RuntimeArrayType arrayType, Size &alignment);
static spirv::StructType decorateType(spirv::StructType structType,
Size &size, Size &alignment);
diff --git a/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
index b19495bc374452..ede9397fbc552e 100644
--- a/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
@@ -91,6 +91,8 @@ Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
return decorateType(arrayType, size, alignment);
if (auto vectorType = dyn_cast<VectorType>(type))
return decorateType(vectorType, size, alignment);
+ if (auto matrixType = dyn_cast<spirv::MatrixType>(type))
+ return decorateType(matrixType, size, alignment);
if (auto arrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
size = std::numeric_limits<Size>().max();
return decorateType(arrayType, alignment);
@@ -138,6 +140,25 @@ Type VulkanLayoutUtils::decorateType(spirv::ArrayType arrayType,
return spirv::ArrayType::get(memberType, numElements, elementSize);
}
+Type VulkanLayoutUtils::decorateType(spirv::MatrixType matrixType,
+ VulkanLayoutUtils::Size &size,
+ VulkanLayoutUtils::Size &alignment) {
+ const auto numColumns = matrixType.getNumColumns();
+ const auto columnType = matrixType.getColumnType();
+ const auto numElements = matrixType.getNumElements();
+ auto elementType = matrixType.getElementType();
+ Size elementSize = 0;
+ Size elementAlignment = 1;
+
+ decorateType(elementType, elementSize, elementAlignment);
+ // According to the Vulkan spec:
+ // "A matrix type inherits scalar alignment from the equivalent array
+ // declaration."
+ size = elementSize * numElements;
+ alignment = elementAlignment;
+ return spirv::MatrixType::get(columnType, numColumns);
+}
+
Type VulkanLayoutUtils::decorateType(spirv::RuntimeArrayType arrayType,
VulkanLayoutUtils::Size &alignment) {
auto elementType = arrayType.getElementType();
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index 05ab91b6db6bd9..b63a08d96e6af9 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -497,6 +497,11 @@ func.func private @matrix_type(!spirv.matrix<4 x vector<4xf16>>) -> ()
// -----
+// CHECK: func private @matrix_type(!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0])>)
+func.func private @matrix_type(!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0])>) -> ()
+
+// -----
+
// expected-error @+1 {{matrix is expected to have 2, 3, or 4 columns}}
func.func private @matrix_invalid_size(!spirv.matrix<5 x vector<3xf32>>) -> ()
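The arithmetic in the new `decorateType(spirv::MatrixType, ...)` overload is `size = elementSize * numElements` and `alignment = elementAlignment`. The stand-alone C++ sketch below makes that concrete without depending on MLIR. `MatrixLayout` and `matrixLayout` are hypothetical names, not MLIR APIs. The scalar sizes passed in, such as 4 bytes for `f32`, are assumptions supplied by the caller.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-alone model (not MLIR code) of the rule applied by the
// new decorateType(spirv::MatrixType, ...) overload: the matrix is laid out
// like the equivalent array of its scalar elements.
struct MatrixLayout {
  std::uint64_t size;      // Bytes occupied by all scalar elements.
  std::uint64_t alignment; // Bytes; inherited from the scalar element type.
};

MatrixLayout matrixLayout(unsigned numColumns, unsigned numRows,
                          std::uint64_t scalarSize,
                          std::uint64_t scalarAlignment) {
  // spirv.matrix only allows 2, 3, or 4 columns (see the types.mlir test).
  assert(numColumns >= 2 && numColumns <= 4 && "2, 3, or 4 columns expected");
  const std::uint64_t numElements =
      static_cast<std::uint64_t>(numColumns) * numRows;
  // size = elementSize * numElements; alignment = elementAlignment.
  return {scalarSize * numElements, scalarAlignment};
}
```

For `!spirv.matrix<3 x vector<3xf32>>` with 4-byte `f32` scalars, `matrixLayout(3, 3, 4, 4)` gives a size of 36 bytes and an alignment of 4. The alignment is the scalar alignment, not the column-vector alignment.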
The logic looks OK, just a few minor issues
Branch updated from 32bd29d to 34f0aea.
LGTM % nit
Branch updated from 34f0aea to cd90206.
@trdthg seems like the code doesn't build.
This PR adds a decorateType method for MatrixType, ensuring that `spirv.matrix` with offset in `spirv.struct` can be handled correctly.

Signed-off-by: MingZhu Yan <yanmingzhu@iscas.ac.cn>
Branch updated from cd90206 to 16ce0e8.
fixed
Resolve #112018 (comment)

As described in clang-tidy, the auto type specifier will only be introduced in:

- Iterators
- New expressions
- Cast expressions

https://clang.llvm.org/extra/clang-tidy/checks/modernize/use-auto.html
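The original diff above declares locals such as `numColumns` and `elementType` with `auto`, initialized from ordinary function calls. Those calls fall outside the three forms listed. The plain C++ sketch below is not from the PR. It shows each of the three forms next to variables whose types are spelled out. `sumSizes` is a hypothetical function used only for illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>

// Hypothetical example (not from the PR): sums member sizes, using `auto`
// only in the three forms that modernize-use-auto targets.
std::uint64_t sumSizes(const std::map<std::string, std::uint32_t> &sizes) {
  // Initialized from a literal, not an iterator/new/cast: spell the type.
  std::uint64_t total = 0;
  // 1. Iterator: the full iterator type adds nothing for the reader.
  for (auto it = sizes.begin(); it != sizes.end(); ++it) {
    // 3. Cast expression: the target type is already named by the cast.
    auto widened = static_cast<std::uint64_t>(it->second);
    total += widened;
    assert(total >= widened && "sum overflowed");
  }
  // 2. New expression: `new` names the type, so `auto *` is acceptable.
  //    (Raw new/delete only illustrates the rule; prefer std::make_unique.)
  auto *copy = new std::uint64_t(total);
  std::uint64_t result = *copy; // Plain dereference: spell the type.
  delete copy;
  return result;
}
```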
try fix #108161

This PR adds a decorateType method for MatrixType, ensuring that `spirv.matrix` with offset in `spirv.struct` can be handled correctly.
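This PR concerns matrices placed at an offset inside a struct, as in the `!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0])>` test. The hedged, stand-alone sketch below shows one way a layout pass can turn each member's (size, alignment) pair into offsets. It aligns each member's offset up to that member's alignment, then rounds the struct's size up to its largest member alignment. This is an assumption about the general technique, not a copy of MLIR's `decorateType(spirv::StructType, ...)`. `MemberLayout`, `StructLayout`, `alignUp`, and `layoutStruct` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical types and helpers (not MLIR APIs); all sizes are in bytes.
struct MemberLayout {
  std::uint64_t size;
  std::uint64_t alignment;
};

struct StructLayout {
  std::vector<std::uint64_t> offsets; // One offset per member.
  std::uint64_t size = 0;
  std::uint64_t alignment = 1;        // Largest member alignment.
};

// Round `value` up to the next multiple of `alignment` (a power of two).
std::uint64_t alignUp(std::uint64_t value, std::uint64_t alignment) {
  assert(alignment != 0 && (alignment & (alignment - 1)) == 0 &&
         "alignment must be a non-zero power of two");
  return (value + alignment - 1) & ~(alignment - 1);
}

// Place each member at the first suitably aligned offset after the previous
// member, then pad the struct to a multiple of its largest member alignment.
StructLayout layoutStruct(const std::vector<MemberLayout> &members) {
  StructLayout layout;
  std::uint64_t offset = 0;
  for (const MemberLayout &member : members) {
    offset = alignUp(offset, member.alignment);
    layout.offsets.push_back(offset);
    offset += member.size;
    layout.alignment = std::max(layout.alignment, member.alignment);
  }
  layout.size = alignUp(offset, layout.alignment);
  return layout;
}
```

Consider a struct with three members: a 4-byte float, a matrix with size 36 and alignment 4 (nine `f32` scalars under the rule in this PR), and another float. These members get offsets 0, 4, and 40, and the struct's size is 44. Because the matrix is only scalar-aligned, it packs directly after the first float with no padding.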