diff options
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td | 20 | ||||
-rw-r--r-- | mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 13 |
2 files changed, 12 insertions, 21 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index d6bc06218a91..e01e30b73bff 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -985,26 +985,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { >, InterfaceMethod< /*desc=*/[{ - Return the range of position in the result of the affine map - computed by getLoopsToShapesMap() which correspond to the - AffineExprs used to access the outputs of the operation. - }], - /*retTy=*/"std::pair<int64_t, int64_t>", - /*methodName=*/"getResultsPositionInLoopsToShapeMap", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - int64_t inputRankSum = 0; - int64_t outputRankSum = 0; - for(OpOperand *input : getInputOperands()) - inputRankSum += getRank(input); - for(OpOperand *output : getOutputOperands()) - outputRankSum += getRank(output); - return {inputRankSum, inputRankSum + outputRankSum}; - }] - >, - InterfaceMethod< - /*desc=*/[{ Like `getShape`, but only returns statically-known information, without generating any new IR. For each shape dimension, returns >=0 if that dimension is statically known, or ShapeType::kDynamicSize otherwise. diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index f389c18efe5e..3b63824b829a 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -566,6 +566,17 @@ private: llvm::SmallBitVector positions; }; +static std::pair<int64_t, int64_t> +getResultsPositionInLoopsToShapeMap(LinalgOp &op) { + int64_t inputRankSum = 0; + int64_t outputRankSum = 0; + for (OpOperand *input : op.getInputOperands()) + inputRankSum += op.getRank(input); + for (OpOperand *output : op.getOutputOperands()) + outputRankSum += op.getRank(output); + return {inputRankSum, inputRankSum + outputRankSum}; +} + LogicalResult LinalgOp::reifyResultShapes(OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { @@ -582,7 +593,7 @@ LinalgOp::reifyResultShapes(OpBuilder &b, // Find the position in the above map that represents the shape of the // result:dim being inferred. - auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(); + auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(*this); /// From loopsToShapesMap extract the submap that represents the shape of the /// (resultIdx, dim) needed. |