aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/VectorToLLVM
diff options
context:
space:
mode:
authorJacques Pienaar <jpienaar@google.com>2022-03-28 11:24:47 -0700
committerJacques Pienaar <jpienaar@google.com>2022-03-28 11:24:47 -0700
commit7c38fd605ba85657a0ecbea75a8e3a68174d3dff (patch)
tree420972f033748b603360f29cd414847fcaa3bdbd /mlir/lib/Conversion/VectorToLLVM
parent1066e397fa907629f0da370f9721821c838ed30a (diff)
[mlir] Flip Vector dialect accessors used to prefixed form.
This has been on _Both for a couple of weeks. Flip usages in core with intention to flip flag to _Prefixed in follow up. Needed to add a couple of helper methods in AffineOps and Linalg to facilitate a pure flag flip in follow up as some of these classes are used in templates and so sensitive to Vector dialect changes. Differential Revision: https://reviews.llvm.org/D122151
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp118
1 files changed, 59 insertions, 59 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 20e51008c52b..3f6b3524b896 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -155,9 +155,9 @@ public:
matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
- matmulOp, typeConverter->convertType(matmulOp.res().getType()),
- adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
- matmulOp.lhs_columns(), matmulOp.rhs_columns());
+ matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
+ adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
+ matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
return success();
}
};
@@ -173,8 +173,8 @@ public:
matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
- transOp, typeConverter->convertType(transOp.res().getType()),
- adaptor.matrix(), transOp.rows(), transOp.columns());
+ transOp, typeConverter->convertType(transOp.getRes().getType()),
+ adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
return success();
}
};
@@ -194,14 +194,14 @@ static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
VectorType vectorTy, Value ptr, unsigned align,
ConversionPatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
- loadOp, vectorTy, ptr, adaptor.mask(), adaptor.pass_thru(), align);
+ loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
}
static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
vector::StoreOpAdaptor adaptor,
VectorType vectorTy, Value ptr, unsigned align,
ConversionPatternRewriter &rewriter) {
- rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.valueToStore(),
+ rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
ptr, align);
}
@@ -210,7 +210,7 @@ static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
VectorType vectorTy, Value ptr, unsigned align,
ConversionPatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
- storeOp, adaptor.valueToStore(), ptr, adaptor.mask(), align);
+ storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
}
/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
@@ -240,8 +240,8 @@ public:
// Resolve address.
auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
.template cast<VectorType>();
- Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.base(),
- adaptor.indices(), rewriter);
+ Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
+ adaptor.getIndices(), rewriter);
Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);
replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
@@ -269,16 +269,16 @@ public:
// Resolve address.
Value ptrs;
VectorType vType = gather.getVectorType();
- Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
- adaptor.indices(), rewriter);
- if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
- adaptor.index_vec(), memRefType, vType, ptrs)))
+ Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
+ adaptor.getIndices(), rewriter);
+ if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
+ adaptor.getIndexVec(), memRefType, vType, ptrs)))
return failure();
// Replace with the gather intrinsic.
rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
- gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
- adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
+ gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
+ adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
return success();
}
};
@@ -303,15 +303,15 @@ public:
// Resolve address.
Value ptrs;
VectorType vType = scatter.getVectorType();
- Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
- adaptor.indices(), rewriter);
- if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
- adaptor.index_vec(), memRefType, vType, ptrs)))
+ Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
+ adaptor.getIndices(), rewriter);
+ if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
+ adaptor.getIndexVec(), memRefType, vType, ptrs)))
return failure();
// Replace with the scatter intrinsic.
rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
- scatter, adaptor.valueToStore(), ptrs, adaptor.mask(),
+ scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
rewriter.getI32IntegerAttr(align));
return success();
}
@@ -331,11 +331,11 @@ public:
// Resolve address.
auto vtype = typeConverter->convertType(expand.getVectorType());
- Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
- adaptor.indices(), rewriter);
+ Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
+ adaptor.getIndices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
- expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
+ expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
return success();
}
};
@@ -353,11 +353,11 @@ public:
MemRefType memRefType = compress.getMemRefType();
// Resolve address.
- Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
- adaptor.indices(), rewriter);
+ Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
+ adaptor.getIndices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
- compress, adaptor.valueToStore(), ptr, adaptor.mask());
+ compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
return success();
}
};
@@ -374,8 +374,8 @@ public:
LogicalResult
matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto kind = reductionOp.kind();
- Type eltType = reductionOp.dest().getType();
+ auto kind = reductionOp.getKind();
+ Type eltType = reductionOp.getDest().getType();
Type llvmType = typeConverter->convertType(eltType);
Value operand = adaptor.getOperands()[0];
if (eltType.isIntOrIndex()) {
@@ -468,7 +468,7 @@ public:
auto v2Type = shuffleOp.getV2VectorType();
auto vectorType = shuffleOp.getVectorType();
Type llvmType = typeConverter->convertType(vectorType);
- auto maskArrayAttr = shuffleOp.mask();
+ auto maskArrayAttr = shuffleOp.getMask();
// Bail if result type cannot be lowered.
if (!llvmType)
@@ -484,7 +484,7 @@ public:
// there is direct shuffle support in LLVM. Use it!
if (rank == 1 && v1Type == v2Type) {
Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
- loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
+ loc, adaptor.getV1(), adaptor.getV2(), maskArrayAttr);
rewriter.replaceOp(shuffleOp, llvmShuffleOp);
return success();
}
@@ -499,10 +499,10 @@ public:
int64_t insPos = 0;
for (const auto &en : llvm::enumerate(maskArrayAttr)) {
int64_t extPos = en.value().cast<IntegerAttr>().getInt();
- Value value = adaptor.v1();
+ Value value = adaptor.getV1();
if (extPos >= v1Dim) {
extPos -= v1Dim;
- value = adaptor.v2();
+ value = adaptor.getV2();
}
Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
eltType, rank, extPos);
@@ -537,12 +537,12 @@ public:
loc, typeConverter->convertType(idxType),
rewriter.getIntegerAttr(idxType, 0));
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
- extractEltOp, llvmType, adaptor.vector(), zero);
+ extractEltOp, llvmType, adaptor.getVector(), zero);
return success();
}
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
- extractEltOp, llvmType, adaptor.vector(), adaptor.position());
+ extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
return success();
}
};
@@ -559,7 +559,7 @@ public:
auto vectorType = extractOp.getVectorType();
auto resultType = extractOp.getResult().getType();
auto llvmResultType = typeConverter->convertType(resultType);
- auto positionArrayAttr = extractOp.position();
+ auto positionArrayAttr = extractOp.getPosition();
// Bail if result type cannot be lowered.
if (!llvmResultType)
@@ -567,21 +567,21 @@ public:
// Extract entire vector. Should be handled by folder, but just to be safe.
if (positionArrayAttr.empty()) {
- rewriter.replaceOp(extractOp, adaptor.vector());
+ rewriter.replaceOp(extractOp, adaptor.getVector());
return success();
}
// One-shot extraction of vector from array (only requires extractvalue).
if (resultType.isa<VectorType>()) {
Value extracted = rewriter.create<LLVM::ExtractValueOp>(
- loc, llvmResultType, adaptor.vector(), positionArrayAttr);
+ loc, llvmResultType, adaptor.getVector(), positionArrayAttr);
rewriter.replaceOp(extractOp, extracted);
return success();
}
// Potential extraction of 1-D vector from array.
auto *context = extractOp->getContext();
- Value extracted = adaptor.vector();
+ Value extracted = adaptor.getVector();
auto positionAttrs = positionArrayAttr.getValue();
if (positionAttrs.size() > 1) {
auto oneDVectorType = reducedVectorTypeBack(vectorType);
@@ -628,8 +628,8 @@ public:
VectorType vType = fmaOp.getVectorType();
if (vType.getRank() != 1)
return failure();
- rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
- adaptor.rhs(), adaptor.acc());
+ rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
+ fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
return success();
}
};
@@ -656,13 +656,13 @@ public:
loc, typeConverter->convertType(idxType),
rewriter.getIntegerAttr(idxType, 0));
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
- insertEltOp, llvmType, adaptor.dest(), adaptor.source(), zero);
+ insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
return success();
}
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
- insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
- adaptor.position());
+ insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
+ adaptor.getPosition());
return success();
}
};
@@ -679,7 +679,7 @@ public:
auto sourceType = insertOp.getSourceType();
auto destVectorType = insertOp.getDestVectorType();
auto llvmResultType = typeConverter->convertType(destVectorType);
- auto positionArrayAttr = insertOp.position();
+ auto positionArrayAttr = insertOp.getPosition();
// Bail if result type cannot be lowered.
if (!llvmResultType)
@@ -688,14 +688,14 @@ public:
// Overwrite entire vector with value. Should be handled by folder, but
// just to be safe.
if (positionArrayAttr.empty()) {
- rewriter.replaceOp(insertOp, adaptor.source());
+ rewriter.replaceOp(insertOp, adaptor.getSource());
return success();
}
// One-shot insertion of a vector into an array (only requires insertvalue).
if (sourceType.isa<VectorType>()) {
Value inserted = rewriter.create<LLVM::InsertValueOp>(
- loc, llvmResultType, adaptor.dest(), adaptor.source(),
+ loc, llvmResultType, adaptor.getDest(), adaptor.getSource(),
positionArrayAttr);
rewriter.replaceOp(insertOp, inserted);
return success();
@@ -703,7 +703,7 @@ public:
// Potential extraction of 1-D vector from array.
auto *context = insertOp->getContext();
- Value extracted = adaptor.dest();
+ Value extracted = adaptor.getDest();
auto positionAttrs = positionArrayAttr.getValue();
auto position = positionAttrs.back().cast<IntegerAttr>();
auto oneDVectorType = destVectorType;
@@ -721,15 +721,15 @@ public:
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
Value inserted = rewriter.create<LLVM::InsertElementOp>(
loc, typeConverter->convertType(oneDVectorType), extracted,
- adaptor.source(), constant);
+ adaptor.getSource(), constant);
// Potential insertion of resulting 1-D vector into array.
if (positionAttrs.size() > 1) {
auto nMinusOnePositionAttrs =
ArrayAttr::get(context, positionAttrs.drop_back());
- inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
- adaptor.dest(), inserted,
- nMinusOnePositionAttrs);
+ inserted = rewriter.create<LLVM::InsertValueOp>(
+ loc, llvmResultType, adaptor.getDest(), inserted,
+ nMinusOnePositionAttrs);
}
rewriter.replaceOp(insertOp, inserted);
@@ -780,9 +780,9 @@ public:
loc, elemType, rewriter.getZeroAttr(elemType));
Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
- Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
- Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
- Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
+ Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
+ Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
+ Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
desc = rewriter.create<InsertOp>(loc, fma, desc, i);
}
@@ -1009,7 +1009,7 @@ public:
// Unroll vector into elementary print calls.
int64_t rank = vectorType ? vectorType.getRank() : 0;
Type type = vectorType ? vectorType : eltType;
- emitRanks(rewriter, printOp, adaptor.source(), type, printer, rank,
+ emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank,
conversion);
emitCall(rewriter, printOp->getLoc(),
LLVM::lookupOrCreatePrintNewlineFn(
@@ -1119,13 +1119,13 @@ struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
// For 0-d vector, we simply do `insertelement`.
if (resultType.getRank() == 0) {
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
- splatOp, vectorType, undef, adaptor.input(), zero);
+ splatOp, vectorType, undef, adaptor.getInput(), zero);
return success();
}
// For 1-d vector, we additionally do a `vectorshuffle`.
auto v = rewriter.create<LLVM::InsertElementOp>(
- splatOp.getLoc(), vectorType, undef, adaptor.input(), zero);
+ splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);
int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
SmallVector<int32_t, 4> zeroValues(width, 0);
@@ -1170,7 +1170,7 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
loc, typeConverter->convertType(rewriter.getIntegerType(32)),
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
- adaptor.input(), zero);
+ adaptor.getInput(), zero);
// Shuffle the value across the desired number of elements.
int64_t width = resultType.getDimSize(resultType.getRank() - 1);