diff options
author | Jacques Pienaar <jpienaar@google.com> | 2022-03-28 11:24:47 -0700 |
---|---|---|
committer | Jacques Pienaar <jpienaar@google.com> | 2022-03-28 11:24:47 -0700 |
commit | 7c38fd605ba85657a0ecbea75a8e3a68174d3dff (patch) | |
tree | 420972f033748b603360f29cd414847fcaa3bdbd /mlir/lib/Conversion/VectorToLLVM | |
parent | 1066e397fa907629f0da370f9721821c838ed30a (diff) |
[mlir] Flip Vector dialect accessors used to prefixed form.
This has been on _Both for a couple of weeks. Flip usages in core with
intention to flip flag to _Prefixed in follow up. Needed to add a couple
of helper methods in AffineOps and Linalg to facilitate a pure flag flip
in follow up as some of these classes are used in templates and so
sensitive to Vector dialect changes.
Differential Revision: https://reviews.llvm.org/D122151
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 118 |
1 files changed, 59 insertions, 59 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 20e51008c52b..3f6b3524b896 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -155,9 +155,9 @@ public: matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>( - matmulOp, typeConverter->convertType(matmulOp.res().getType()), - adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(), - matmulOp.lhs_columns(), matmulOp.rhs_columns()); + matmulOp, typeConverter->convertType(matmulOp.getRes().getType()), + adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(), + matmulOp.getLhsColumns(), matmulOp.getRhsColumns()); return success(); } }; @@ -173,8 +173,8 @@ public: matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>( - transOp, typeConverter->convertType(transOp.res().getType()), - adaptor.matrix(), transOp.rows(), transOp.columns()); + transOp, typeConverter->convertType(transOp.getRes().getType()), + adaptor.getMatrix(), transOp.getRows(), transOp.getColumns()); return success(); } }; @@ -194,14 +194,14 @@ static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp, VectorType vectorTy, Value ptr, unsigned align, ConversionPatternRewriter &rewriter) { rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( - loadOp, vectorTy, ptr, adaptor.mask(), adaptor.pass_thru(), align); + loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align); } static void replaceLoadOrStoreOp(vector::StoreOp storeOp, vector::StoreOpAdaptor adaptor, VectorType vectorTy, Value ptr, unsigned align, ConversionPatternRewriter &rewriter) { - rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.valueToStore(), + rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(), ptr, align); } @@ -210,7 +210,7 @@ static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp, VectorType vectorTy, Value ptr, unsigned align, ConversionPatternRewriter &rewriter) { rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( - storeOp, adaptor.valueToStore(), ptr, adaptor.mask(), align); + storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align); } /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and @@ -240,8 +240,8 @@ public: // Resolve address. auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType()) .template cast<VectorType>(); - Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.base(), - adaptor.indices(), rewriter); + Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(), + adaptor.getIndices(), rewriter); Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype); replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter); @@ -269,16 +269,16 @@ public: // Resolve address. Value ptrs; VectorType vType = gather.getVectorType(); - Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(), - adaptor.indices(), rewriter); - if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr, - adaptor.index_vec(), memRefType, vType, ptrs))) + Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), + adaptor.getIndices(), rewriter); + if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr, + adaptor.getIndexVec(), memRefType, vType, ptrs))) return failure(); // Replace with the gather intrinsic. rewriter.replaceOpWithNewOp<LLVM::masked_gather>( - gather, typeConverter->convertType(vType), ptrs, adaptor.mask(), - adaptor.pass_thru(), rewriter.getI32IntegerAttr(align)); + gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(), + adaptor.getPassThru(), rewriter.getI32IntegerAttr(align)); return success(); } }; @@ -303,15 +303,15 @@ public: // Resolve address. Value ptrs; VectorType vType = scatter.getVectorType(); - Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(), - adaptor.indices(), rewriter); - if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr, - adaptor.index_vec(), memRefType, vType, ptrs))) + Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), + adaptor.getIndices(), rewriter); + if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr, + adaptor.getIndexVec(), memRefType, vType, ptrs))) return failure(); // Replace with the scatter intrinsic. rewriter.replaceOpWithNewOp<LLVM::masked_scatter>( - scatter, adaptor.valueToStore(), ptrs, adaptor.mask(), + scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(), rewriter.getI32IntegerAttr(align)); return success(); } @@ -331,11 +331,11 @@ public: // Resolve address. auto vtype = typeConverter->convertType(expand.getVectorType()); - Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(), - adaptor.indices(), rewriter); + Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), + adaptor.getIndices(), rewriter); rewriter.replaceOpWithNewOp<LLVM::masked_expandload>( - expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru()); + expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru()); return success(); } }; @@ -353,11 +353,11 @@ public: MemRefType memRefType = compress.getMemRefType(); // Resolve address. - Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(), - adaptor.indices(), rewriter); + Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), + adaptor.getIndices(), rewriter); rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>( - compress, adaptor.valueToStore(), ptr, adaptor.mask()); + compress, adaptor.getValueToStore(), ptr, adaptor.getMask()); return success(); } }; @@ -374,8 +374,8 @@ public: LogicalResult matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto kind = reductionOp.kind(); - Type eltType = reductionOp.dest().getType(); + auto kind = reductionOp.getKind(); + Type eltType = reductionOp.getDest().getType(); Type llvmType = typeConverter->convertType(eltType); Value operand = adaptor.getOperands()[0]; if (eltType.isIntOrIndex()) { @@ -468,7 +468,7 @@ public: auto v2Type = shuffleOp.getV2VectorType(); auto vectorType = shuffleOp.getVectorType(); Type llvmType = typeConverter->convertType(vectorType); - auto maskArrayAttr = shuffleOp.mask(); + auto maskArrayAttr = shuffleOp.getMask(); // Bail if result type cannot be lowered. if (!llvmType) @@ -484,7 +484,7 @@ public: // there is direct shuffle support in LLVM. Use it! if (rank == 1 && v1Type == v2Type) { Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>( - loc, adaptor.v1(), adaptor.v2(), maskArrayAttr); + loc, adaptor.getV1(), adaptor.getV2(), maskArrayAttr); rewriter.replaceOp(shuffleOp, llvmShuffleOp); return success(); } @@ -499,10 +499,10 @@ public: int64_t insPos = 0; for (const auto &en : llvm::enumerate(maskArrayAttr)) { int64_t extPos = en.value().cast<IntegerAttr>().getInt(); - Value value = adaptor.v1(); + Value value = adaptor.getV1(); if (extPos >= v1Dim) { extPos -= v1Dim; - value = adaptor.v2(); + value = adaptor.getV2(); } Value extract = extractOne(rewriter, *getTypeConverter(), loc, value, eltType, rank, extPos); @@ -537,12 +537,12 @@ public: loc, typeConverter->convertType(idxType), rewriter.getIntegerAttr(idxType, 0)); rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( - extractEltOp, llvmType, adaptor.vector(), zero); + extractEltOp, llvmType, adaptor.getVector(), zero); return success(); } rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( - extractEltOp, llvmType, adaptor.vector(), adaptor.position()); + extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition()); return success(); } }; @@ -559,7 +559,7 @@ public: auto vectorType = extractOp.getVectorType(); auto resultType = extractOp.getResult().getType(); auto llvmResultType = typeConverter->convertType(resultType); - auto positionArrayAttr = extractOp.position(); + auto positionArrayAttr = extractOp.getPosition(); // Bail if result type cannot be lowered. if (!llvmResultType) @@ -567,21 +567,21 @@ public: // Extract entire vector. Should be handled by folder, but just to be safe. if (positionArrayAttr.empty()) { - rewriter.replaceOp(extractOp, adaptor.vector()); + rewriter.replaceOp(extractOp, adaptor.getVector()); return success(); } // One-shot extraction of vector from array (only requires extractvalue). if (resultType.isa<VectorType>()) { Value extracted = rewriter.create<LLVM::ExtractValueOp>( - loc, llvmResultType, adaptor.vector(), positionArrayAttr); + loc, llvmResultType, adaptor.getVector(), positionArrayAttr); rewriter.replaceOp(extractOp, extracted); return success(); } // Potential extraction of 1-D vector from array. auto *context = extractOp->getContext(); - Value extracted = adaptor.vector(); + Value extracted = adaptor.getVector(); auto positionAttrs = positionArrayAttr.getValue(); if (positionAttrs.size() > 1) { auto oneDVectorType = reducedVectorTypeBack(vectorType); @@ -628,8 +628,8 @@ public: VectorType vType = fmaOp.getVectorType(); if (vType.getRank() != 1) return failure(); - rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(), - adaptor.rhs(), adaptor.acc()); + rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>( + fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc()); return success(); } }; @@ -656,13 +656,13 @@ public: loc, typeConverter->convertType(idxType), rewriter.getIntegerAttr(idxType, 0)); rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( - insertEltOp, llvmType, adaptor.dest(), adaptor.source(), zero); + insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero); return success(); } rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( - insertEltOp, llvmType, adaptor.dest(), adaptor.source(), - adaptor.position()); + insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), + adaptor.getPosition()); return success(); } }; @@ -679,7 +679,7 @@ public: auto sourceType = insertOp.getSourceType(); auto destVectorType = insertOp.getDestVectorType(); auto llvmResultType = typeConverter->convertType(destVectorType); - auto positionArrayAttr = insertOp.position(); + auto positionArrayAttr = insertOp.getPosition(); // Bail if result type cannot be lowered. if (!llvmResultType) @@ -688,14 +688,14 @@ public: // Overwrite entire vector with value. Should be handled by folder, but // just to be safe. if (positionArrayAttr.empty()) { - rewriter.replaceOp(insertOp, adaptor.source()); + rewriter.replaceOp(insertOp, adaptor.getSource()); return success(); } // One-shot insertion of a vector into an array (only requires insertvalue). if (sourceType.isa<VectorType>()) { Value inserted = rewriter.create<LLVM::InsertValueOp>( - loc, llvmResultType, adaptor.dest(), adaptor.source(), + loc, llvmResultType, adaptor.getDest(), adaptor.getSource(), positionArrayAttr); rewriter.replaceOp(insertOp, inserted); return success(); @@ -703,7 +703,7 @@ public: // Potential extraction of 1-D vector from array. auto *context = insertOp->getContext(); - Value extracted = adaptor.dest(); + Value extracted = adaptor.getDest(); auto positionAttrs = positionArrayAttr.getValue(); auto position = positionAttrs.back().cast<IntegerAttr>(); auto oneDVectorType = destVectorType; @@ -721,15 +721,15 @@ public: auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); Value inserted = rewriter.create<LLVM::InsertElementOp>( loc, typeConverter->convertType(oneDVectorType), extracted, - adaptor.source(), constant); + adaptor.getSource(), constant); // Potential insertion of resulting 1-D vector into array. if (positionAttrs.size() > 1) { auto nMinusOnePositionAttrs = ArrayAttr::get(context, positionAttrs.drop_back()); - inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType, - adaptor.dest(), inserted, - nMinusOnePositionAttrs); + inserted = rewriter.create<LLVM::InsertValueOp>( + loc, llvmResultType, adaptor.getDest(), inserted, + nMinusOnePositionAttrs); } rewriter.replaceOp(insertOp, inserted); @@ -780,9 +780,9 @@ public: loc, elemType, rewriter.getZeroAttr(elemType)); Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero); for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { - Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i); - Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i); - Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i); + Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i); + Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i); + Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i); Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC); desc = rewriter.create<InsertOp>(loc, fma, desc, i); } @@ -1009,7 +1009,7 @@ public: // Unroll vector into elementary print calls. int64_t rank = vectorType ? vectorType.getRank() : 0; Type type = vectorType ? vectorType : eltType; - emitRanks(rewriter, printOp, adaptor.source(), type, printer, rank, + emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank, conversion); emitCall(rewriter, printOp->getLoc(), LLVM::lookupOrCreatePrintNewlineFn( @@ -1119,13 +1119,13 @@ struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> { // For 0-d vector, we simply do `insertelement`. if (resultType.getRank() == 0) { rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( - splatOp, vectorType, undef, adaptor.input(), zero); + splatOp, vectorType, undef, adaptor.getInput(), zero); return success(); } // For 1-d vector, we additionally do a `vectorshuffle`. auto v = rewriter.create<LLVM::InsertElementOp>( - splatOp.getLoc(), vectorType, undef, adaptor.input(), zero); + splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero); int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0); SmallVector<int32_t, 4> zeroValues(width, 0); @@ -1170,7 +1170,7 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> { loc, typeConverter->convertType(rewriter.getIntegerType(32)), rewriter.getZeroAttr(rewriter.getIntegerType(32))); Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc, - adaptor.input(), zero); + adaptor.getInput(), zero); // Shuffle the value across the desired number of elements. int64_t width = resultType.getDimSize(resultType.getRank() - 1); |