//===- LinalgToSPIRV.cpp - Linalg to SPIR-V Patterns ----------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; //===----------------------------------------------------------------------===// // Utilities //===----------------------------------------------------------------------===// /// Returns a `Value` containing the `dim`-th dimension's size of SPIR-V /// location invocation ID. This function will create necessary operations with /// `builder` at the proper region containing `op`. static Value getLocalInvocationDimSize(Operation *op, int dim, Type integerType, Location loc, OpBuilder *builder) { assert(dim >= 0 && dim < 3 && "local invocation only has three dimensions"); Value invocation = spirv::getBuiltinVariableValue( op, spirv::BuiltIn::LocalInvocationId, integerType, *builder); Type xType = invocation.getType().cast().getElementType(); return builder->create( loc, xType, invocation, builder->getI32ArrayAttr({dim})); } //===----------------------------------------------------------------------===// // Reduction (single workgroup) //===----------------------------------------------------------------------===// namespace { /// A pattern to convert a linalg.generic op to SPIR-V ops under the condition /// that the linalg.generic op is performing reduction with a workload size that /// can fit in one workgroup. struct SingleWorkgroupReduction final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; /// Matches the given linalg.generic op as performing reduction and returns /// the binary op kind if successful. static Optional matchAsPerformingReduction(linalg::GenericOp genericOp); LogicalResult matchAndRewrite(linalg::GenericOp genericOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace Optional SingleWorkgroupReduction::matchAsPerformingReduction( linalg::GenericOp genericOp) { Operation *op = genericOp.getOperation(); // Make sure the linalg.generic is working on memrefs. if (!genericOp.hasBufferSemantics()) return llvm::None; // Make sure this is reduction with one input and one output. if (genericOp.getNumInputs() != 1 || genericOp.getNumOutputs() != 1) return llvm::None; auto originalInputType = op->getOperand(0).getType().cast(); auto originalOutputType = op->getOperand(1).getType().cast(); // Make sure the original input has one dimension. if (!originalInputType.hasStaticShape() || originalInputType.getRank() != 1) return llvm::None; // Make sure the original output has one element. if (!originalOutputType.hasStaticShape() || originalOutputType.getNumElements() != 1) return llvm::None; if (!genericOp.hasSingleReductionLoop()) return llvm::None; auto indexingMaps = genericOp.getIndexingMapsArray(); if (indexingMaps.size() != 2) return llvm::None; // TODO: create utility functions for these checks in Linalg // and use them. auto inputMap = indexingMaps[0]; auto outputMap = indexingMaps[1]; // The indexing map for the input should be `(i) -> (i)`. if (inputMap != AffineMap::get(1, 0, getAffineDimExpr(0, op->getContext()))) return llvm::None; // The indexing map for the input should be `(i) -> (0)`. if (outputMap != AffineMap::get(1, 0, getAffineConstantExpr(0, op->getContext()))) return llvm::None; return linalg::RegionMatcher::matchAsScalarBinaryOp(genericOp); } LogicalResult SingleWorkgroupReduction::matchAndRewrite( linalg::GenericOp genericOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Operation *op = genericOp.getOperation(); auto originalInputType = op->getOperand(0).getType().cast(); auto originalOutputType = op->getOperand(1).getType().cast(); auto binaryOpKind = matchAsPerformingReduction(genericOp); if (!binaryOpKind) return failure(); // Query the shader interface for local workgroup size to make sure the // invocation configuration fits with the input memref's shape. DenseIntElementsAttr localSize = spirv::lookupLocalWorkGroupSize(genericOp); if (!localSize) return failure(); if ((*localSize.begin()).getSExtValue() != originalInputType.getDimSize(0)) return failure(); if (llvm::any_of(llvm::drop_begin(localSize.getValues(), 1), [](const APInt &size) { return !size.isOneValue(); })) return failure(); // TODO: Query the target environment to make sure the current // workload fits in a local workgroup. Value convertedInput = adaptor.getOperands()[0]; Value convertedOutput = adaptor.getOperands()[1]; Location loc = genericOp.getLoc(); auto *typeConverter = getTypeConverter(); auto indexType = typeConverter->getIndexType(); // Get the invocation ID. Value x = getLocalInvocationDimSize(genericOp, /*dim=*/0, indexType, loc, &rewriter); // TODO: Load to Workgroup storage class first. // Get the input element accessed by this invocation. Value inputElementPtr = spirv::getElementPtr( *typeConverter, originalInputType, convertedInput, {x}, loc, rewriter); Value inputElement = rewriter.create(loc, inputElementPtr); // Perform the group reduction operation. Value groupOperation; #define CREATE_GROUP_NON_UNIFORM_BIN_OP(opKind, spvOp) \ case linalg::RegionMatcher::BinaryOpKind::opKind: { \ groupOperation = rewriter.create( \ loc, originalInputType.getElementType(), spirv::Scope::Subgroup, \ spirv::GroupOperation::Reduce, inputElement, \ /*cluster_size=*/nullptr); \ } break switch (*binaryOpKind) { CREATE_GROUP_NON_UNIFORM_BIN_OP(IAdd, GroupNonUniformIAddOp); } #undef CREATE_GROUP_NON_UNIFORM_BIN_OP // Get the output element accessed by this reduction. Value zero = spirv::ConstantOp::getZero(indexType, loc, rewriter); SmallVector zeroIndices(originalOutputType.getRank(), zero); Value outputElementPtr = spirv::getElementPtr(*typeConverter, originalOutputType, convertedOutput, zeroIndices, loc, rewriter); // Write out the final reduction result. This should be only conducted by one // invocation. We use spv.GroupNonUniformElect to find the invocation with the // lowest ID. // // ``` // if (spv.GroupNonUniformElect) { output = ... } // ``` Value condition = rewriter.create( loc, spirv::Scope::Subgroup); auto createAtomicOp = [&](OpBuilder &builder) { #define CREATE_ATOMIC_BIN_OP(opKind, spvOp) \ case linalg::RegionMatcher::BinaryOpKind::opKind: { \ builder.create(loc, outputElementPtr, spirv::Scope::Device, \ spirv::MemorySemantics::AcquireRelease, \ groupOperation); \ } break switch (*binaryOpKind) { CREATE_ATOMIC_BIN_OP(IAdd, AtomicIAddOp); } #undef CREATE_ATOMIC_BIN_OP }; spirv::SelectionOp::createIfThen(loc, condition, createAtomicOp, rewriter); rewriter.eraseOp(genericOp); return success(); } //===----------------------------------------------------------------------===// // Pattern population //===----------------------------------------------------------------------===// void mlir::populateLinalgToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.add(typeConverter, patterns.getContext()); }