aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp')
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp20
1 files changed, 14 insertions, 6 deletions
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index adab8da3d518..070b1fc4eb82 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -909,12 +909,20 @@ DiagnosedSilenceableFailure transform::TileToForeachThreadOp::applyToOne(
IRRewriter rewriter(getContext());
rewriter.setInsertionPoint(target);
auto maybeThreadDimMappingAttr = getThreadDimMapping();
- FailureOr<ForeachThreadTilingResult> tilingResult =
- linalg::tileToForeachThreadOp(
- rewriter, target, getAsOpFoldResult(getNumThreads()),
- maybeThreadDimMappingAttr
- ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr)
- : ArrayRef<int64_t>{});
+ auto dimMapping =
+ llvm::to_vector(maybeThreadDimMappingAttr
+ ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr)
+ : ArrayRef<int64_t>{});
+
+ FailureOr<ForeachThreadTilingResult> tilingResult = failure();
+ if (Optional<ArrayAttr> numThreads = getNumThreads())
+ tilingResult = linalg::tileToForeachThreadOp(
+ rewriter, target, getAsOpFoldResult(*numThreads), dimMapping);
+
+ if (Optional<ArrayAttr> tileSizes = getTileSizes())
+ tilingResult = linalg::tileToForeachThreadOpUsingTileSizes(
+ rewriter, target, getAsOpFoldResult(*tileSizes), dimMapping);
+
if (failed(tilingResult))
return emitDefaultSilenceableFailure(target);
rewriter.replaceOp(target, tilingResult->tileOp->getResults());