aboutsummaryrefslogtreecommitdiff
path: root/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
blob: 405d93aa02fe41a2ea7fee396ac66c1ecc842470 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
//===- LLVMDialect.h - MLIR LLVM IR dialect ---------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the LLVM IR dialect in MLIR, containing LLVM operations and
// LLVM type system.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_
#define MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_

#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffects.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"

#include "mlir/Dialect/LLVMIR/LLVMOpsEnums.h.inc"

// Forward declarations of the LLVM classes that this header refers to by name.
namespace llvm {
class Type;
class LLVMContext;
} // end namespace llvm

namespace mlir {
namespace LLVM {
// Forward declaration: LLVMType mentions the dialect before the dialect class
// itself is defined further down in this header.
class LLVMDialect;

// Implementation details. Both structs are only declared in this header.
namespace detail {
// Storage class handed to the TypeBase instantiation that LLVMType derives
// from.
struct LLVMTypeStorage;
// Pointee type of the `impl` member of LLVMDialect.
struct LLVMDialectImpl;
} // namespace detail

/// MLIR type that wraps an `llvm::Type`. The wrapped type is reachable through
/// getUnderlyingType(); the remaining members are either queries on a wrapped
/// type or static factories producing new wrapped types.
class LLVMType : public mlir::Type::TypeBase<LLVMType, mlir::Type,
                                             detail::LLVMTypeStorage> {
public:
  /// Type kinds handled by this class.
  enum Kind {
    LLVM_TYPE = FIRST_LLVM_TYPE,
  };

  // Reuse the constructors of the base class.
  using Base::Base;

  /// Returns true only when `kind` is LLVM_TYPE.
  static bool kindof(unsigned kind) { return LLVM_TYPE == kind; }

  /// Accessors for the owning dialect and for the wrapped `llvm::Type`.
  LLVMDialect &getDialect();
  llvm::Type *getUnderlyingType() const;

  /// Type predicates. Each one forwards to the query of the same name on the
  /// underlying `llvm::Type`.
  bool isHalfTy() { return getUnderlyingType()->isHalfTy(); }
  bool isFloatTy() { return getUnderlyingType()->isFloatTy(); }
  bool isDoubleTy() { return getUnderlyingType()->isDoubleTy(); }
  bool isIntegerTy() { return getUnderlyingType()->isIntegerTy(); }
  bool isIntegerTy(unsigned bitwidth) {
    return getUnderlyingType()->isIntegerTy(bitwidth);
  }

  /// Queries for array types (defined out of line).
  LLVMType getArrayElementType();
  unsigned getArrayNumElements();
  bool isArrayTy();

  /// Queries for vector types (defined out of line).
  LLVMType getVectorElementType();
  bool isVectorTy();

  /// Queries for function types (defined out of line).
  LLVMType getFunctionParamType(unsigned argIdx);
  unsigned getFunctionNumParams();
  LLVMType getFunctionResultType();
  bool isFunctionTy();

  /// Queries and construction for pointer types (defined out of line).
  /// `addrSpace` defaults to 0.
  LLVMType getPointerTo(unsigned addrSpace = 0);
  LLVMType getPointerElementTy();
  bool isPointerTy();

  /// Queries for struct types (defined out of line).
  LLVMType getStructElementType(unsigned i);
  unsigned getStructNumElements();
  bool isStructTy();

  /// Factories for floating point types.
  static LLVMType getDoubleTy(LLVMDialect *dialect);
  static LLVMType getFloatTy(LLVMDialect *dialect);
  static LLVMType getHalfTy(LLVMDialect *dialect);
  static LLVMType getFP128Ty(LLVMDialect *dialect);
  static LLVMType getX86_FP80Ty(LLVMDialect *dialect);

  /// Factories for integer types. The fixed-width helpers below all forward to
  /// getIntNTy with the matching bit count.
  static LLVMType getIntNTy(LLVMDialect *dialect, unsigned numBits);
  static LLVMType getInt1Ty(LLVMDialect *dialect) {
    return getIntNTy(dialect, /*numBits=*/1);
  }
  static LLVMType getInt8Ty(LLVMDialect *dialect) {
    return getIntNTy(dialect, /*numBits=*/8);
  }
  /// Calls getPointerTo() with its default address space on the 8-bit integer
  /// type.
  static LLVMType getInt8PtrTy(LLVMDialect *dialect) {
    LLVMType int8Ty = getInt8Ty(dialect);
    return int8Ty.getPointerTo();
  }
  static LLVMType getInt16Ty(LLVMDialect *dialect) {
    return getIntNTy(dialect, /*numBits=*/16);
  }
  static LLVMType getInt32Ty(LLVMDialect *dialect) {
    return getIntNTy(dialect, /*numBits=*/32);
  }
  static LLVMType getInt64Ty(LLVMDialect *dialect) {
    return getIntNTy(dialect, /*numBits=*/64);
  }

  /// Factories for other miscellaneous types.
  static LLVMType getArrayTy(LLVMType elementType, uint64_t numElements);
  static LLVMType getFunctionTy(LLVMType result, ArrayRef<LLVMType> params,
                                bool isVarArg);
  /// Same as above with an empty parameter list.
  static LLVMType getFunctionTy(LLVMType result, bool isVarArg) {
    return getFunctionTy(result, llvm::None, isVarArg);
  }
  static LLVMType getStructTy(LLVMDialect *dialect, ArrayRef<LLVMType> elements,
                              bool isPacked = false);
  /// Same as above with an empty element list.
  static LLVMType getStructTy(LLVMDialect *dialect, bool isPacked = false) {
    return getStructTy(dialect, llvm::None, isPacked);
  }
  /// Variadic convenience overload: gathers `elt1` and `elts` into a list and
  /// takes the dialect from `elt1`. Enabled only when every type in `Args`
  /// derives from LLVMType.
  template <typename... Args>
  static std::enable_if_t<llvm::are_base_of<LLVMType, Args...>::value, LLVMType>
  getStructTy(LLVMType elt1, Args... elts) {
    SmallVector<LLVMType, 8> members({elt1, elts...});
    return getStructTy(&elt1.getDialect(), members);
  }
  static LLVMType getVectorTy(LLVMType elementType, unsigned numElements);

  /// Void type factory and predicate.
  static LLVMType getVoidTy(LLVMDialect *dialect);
  bool isVoidTy();

  /// Creation of LLVM's identified struct types, and setting of their body.
  static LLVMType createStructTy(LLVMDialect *dialect,
                                 ArrayRef<LLVMType> elements,
                                 Optional<StringRef> name,
                                 bool isPacked = false);

  /// Same as above with an empty element list.
  static LLVMType createStructTy(LLVMDialect *dialect,
                                 Optional<StringRef> name) {
    return createStructTy(dialect, llvm::None, name);
  }

  /// Overload that takes the dialect from the first entry of `elements`, which
  /// therefore must not be empty (checked by the assertion).
  static LLVMType createStructTy(ArrayRef<LLVMType> elements,
                                 Optional<StringRef> name,
                                 bool isPacked = false) {
    assert(!elements.empty() &&
           "This method may not be invoked with an empty list");
    // getDialect() is not const, so work on a copy of the first element.
    LLVMType firstElement = elements.front();
    LLVMDialect *dialect = &firstElement.getDialect();
    return createStructTy(dialect, elements, name, isPacked);
  }

  /// Variadic convenience overload: gathers `elt1` and `elts` into a list and
  /// takes the dialect from `elt1`. Enabled only when every type in `Args`
  /// derives from LLVMType.
  template <typename... Args>
  static std::enable_if_t<llvm::are_base_of<LLVMType, Args...>::value, LLVMType>
  createStructTy(StringRef name, LLVMType elt1, Args... elts) {
    SmallVector<LLVMType, 8> members({elt1, elts...});
    return createStructTy(&elt1.getDialect(), members,
                          Optional<StringRef>(name));
  }

  static LLVMType setStructTyBody(LLVMType structType,
                                  ArrayRef<LLVMType> elements,
                                  bool isPacked = false);

  /// Variadic convenience overload: gathers `elt1` and `elts` into a list and
  /// forwards to the ArrayRef overload. Enabled only when every type in `Args`
  /// derives from LLVMType.
  template <typename... Args>
  static std::enable_if_t<llvm::are_base_of<LLVMType, Args...>::value, LLVMType>
  setStructTyBody(LLVMType structType, LLVMType elt1, Args... elts) {
    SmallVector<LLVMType, 8> members({elt1, elts...});
    return setStructTyBody(structType, members);
  }

private:
  // The dialect may call the private factories below.
  friend LLVMDialect;

  /// Get an LLVMType with a pre-existing llvm type.
  static LLVMType get(MLIRContext *context, llvm::Type *llvmType);

  /// Get an LLVMType with an llvm type that may cause changes to the underlying
  /// llvm context when constructed. The llvm type is produced by calling
  /// `typeBuilder`.
  static LLVMType getLocked(LLVMDialect *dialect,
                            function_ref<llvm::Type *()> typeBuilder);
};

///// Ops /////
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/LLVMOps.h.inc"

/// Dialect class registered under the "llvm" namespace. It gives access to an
/// `llvm::LLVMContext` and an `llvm::Module`, and declares the hooks for
/// parsing and printing types of this dialect and for verifying region
/// argument attributes.
class LLVMDialect : public Dialect {
public:
  /// Constructs the dialect for the given MLIR context.
  explicit LLVMDialect(MLIRContext *context);
  /// Declared here and defined out of line: `detail::LLVMDialectImpl` is only
  /// forward-declared in this header, and destroying the `impl` unique_ptr
  /// requires the complete type.
  ~LLVMDialect();
  /// Returns the namespace string of this dialect, "llvm".
  static StringRef getDialectNamespace() { return "llvm"; }

  /// Accessors for the LLVM context and LLVM module of this dialect.
  /// NOTE(review): presumably both live inside `impl`; confirm against the
  /// implementation file.
  llvm::LLVMContext &getLLVMContext();
  llvm::Module &getLLVMModule();

  /// Parse a type registered to this dialect.
  Type parseType(DialectAsmParser &parser) const override;

  /// Print a type registered to this dialect.
  void printType(Type type, DialectAsmPrinter &os) const override;

  /// Verify a region argument attribute registered to this dialect.
  /// Returns failure if the verification failed, success otherwise.
  LogicalResult verifyRegionArgAttribute(Operation *op, unsigned regionIdx,
                                         unsigned argIdx,
                                         NamedAttribute argAttr) override;

private:
  // Lets LLVMType reach the private members of the dialect.
  friend LLVMType;

  /// Opaque implementation state; its definition is not visible in this
  /// header.
  std::unique_ptr<detail::LLVMDialectImpl> impl;
};

/// Create an LLVM global containing the string "value" in the module
/// surrounding the insertion point of builder. Obtain the address of that
/// global and use it to compute the address of the first character in the
/// string (these operations are inserted at the builder insertion point).
/// NOTE(review): `name` is presumably the symbol name given to the global and
/// `linkage` its linkage; confirm against the definition.
Value createGlobalString(Location loc, OpBuilder &builder, StringRef name,
                         StringRef value, LLVM::Linkage linkage,
                         LLVM::LLVMDialect *llvmDialect);

/// LLVM requires some operations to be inside of a Module operation. This
/// function confirms that the Operation has the desired properties and reports
/// the outcome as a bool. The exact properties checked are defined with the
/// function body, which is not part of this header.
bool satisfiesLLVMModule(Operation *op);

} // end namespace LLVM
} // end namespace mlir

#endif // MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_