aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td20
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp13
2 files changed, 12 insertions, 21 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index d6bc06218a91..e01e30b73bff 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -985,26 +985,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
>,
InterfaceMethod<
/*desc=*/[{
- Return the range of position in the result of the affine map
- computed by getLoopsToShapesMap() which correspond to the
- AffineExprs used to access the outputs of the operation.
- }],
- /*retTy=*/"std::pair<int64_t, int64_t>",
- /*methodName=*/"getResultsPositionInLoopsToShapeMap",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- int64_t inputRankSum = 0;
- int64_t outputRankSum = 0;
- for(OpOperand *input : getInputOperands())
- inputRankSum += getRank(input);
- for(OpOperand *output : getOutputOperands())
- outputRankSum += getRank(output);
- return {inputRankSum, inputRankSum + outputRankSum};
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
Like `getShape`, but only returns statically-known information, without
generating any new IR. For each shape dimension, returns >=0 if that
dimension is statically known, or ShapeType::kDynamicSize otherwise.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index f389c18efe5e..3b63824b829a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -566,6 +566,17 @@ private:
llvm::SmallBitVector positions;
};
+static std::pair<int64_t, int64_t>
+getResultsPositionInLoopsToShapeMap(LinalgOp &op) {
+ int64_t inputRankSum = 0;
+ int64_t outputRankSum = 0;
+ for (OpOperand *input : op.getInputOperands())
+ inputRankSum += op.getRank(input);
+ for (OpOperand *output : op.getOutputOperands())
+ outputRankSum += op.getRank(output);
+ return {inputRankSum, inputRankSum + outputRankSum};
+}
+
LogicalResult
LinalgOp::reifyResultShapes(OpBuilder &b,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
@@ -582,7 +593,7 @@ LinalgOp::reifyResultShapes(OpBuilder &b,
// Find the position in the above map that represents the shape of the
// result:dim being inferred.
- auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap();
+ auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(*this);
/// From loopsToShapesMap extract the submap that represents the shape of the
/// (resultIdx, dim) needed.