diff options
Diffstat (limited to 'mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp')
-rw-r--r-- | mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 20 |
1 files changed, 14 insertions, 6 deletions
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index adab8da3d518..070b1fc4eb82 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -909,12 +909,20 @@ DiagnosedSilenceableFailure transform::TileToForeachThreadOp::applyToOne( IRRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); auto maybeThreadDimMappingAttr = getThreadDimMapping(); - FailureOr<ForeachThreadTilingResult> tilingResult = - linalg::tileToForeachThreadOp( - rewriter, target, getAsOpFoldResult(getNumThreads()), - maybeThreadDimMappingAttr - ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr) - : ArrayRef<int64_t>{}); + auto dimMapping = + llvm::to_vector(maybeThreadDimMappingAttr + ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr) + : ArrayRef<int64_t>{}); + + FailureOr<ForeachThreadTilingResult> tilingResult = failure(); + if (Optional<ArrayAttr> numThreads = getNumThreads()) + tilingResult = linalg::tileToForeachThreadOp( + rewriter, target, getAsOpFoldResult(*numThreads), dimMapping); + + if (Optional<ArrayAttr> tileSizes = getTileSizes()) + tilingResult = linalg::tileToForeachThreadOpUsingTileSizes( + rewriter, target, getAsOpFoldResult(*tileSizes), dimMapping); + if (failed(tilingResult)) return emitDefaultSilenceableFailure(target); rewriter.replaceOp(target, tilingResult->tileOp->getResults()); |