aboutsummaryrefslogtreecommitdiff
path: root/mlir/tools/mlir-tblgen/DialectGen.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/tools/mlir-tblgen/DialectGen.cpp')
-rw-r--r--mlir/tools/mlir-tblgen/DialectGen.cpp166
1 files changed, 166 insertions, 0 deletions
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
new file mode 100644
index 000000000000..c0009d6e1231
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -0,0 +1,166 @@
+//===- DialectGen.cpp - MLIR dialect definitions generator ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// DialectGen uses the description of dialects to generate C++ definitions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/STLExtras.h"
+#include "mlir/Support/StringExtras.h"
+#include "mlir/TableGen/Format.h"
+#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/OpClass.h"
+#include "mlir/TableGen/OpInterfaces.h"
+#include "mlir/TableGen/OpTrait.h"
+#include "mlir/TableGen/Operator.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Signals.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+#include "llvm/TableGen/TableGenBackend.h"
+
+#define DEBUG_TYPE "mlir-tblgen-opdefgen"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+static llvm::cl::OptionCategory dialectGenCat("Options for -gen-dialect-*");
+static llvm::cl::opt<std::string>
+ selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
+ llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated);
+
+/// Given a set of records for a T, filter the ones that correspond to
+/// the given dialect.
+template <typename T>
+static auto filterForDialect(ArrayRef<llvm::Record *> records,
+ Dialect &dialect) {
+ return llvm::make_filter_range(records, [&](const llvm::Record *record) {
+ return T(record).getDialect() == dialect;
+ });
+}
+
+//===----------------------------------------------------------------------===//
+// GEN: Dialect declarations
+//===----------------------------------------------------------------------===//
+
+/// The code block for the start of a dialect class declaration.
+///
+/// {0}: The name of the dialect class.
+/// {1}: The dialect namespace.
+static const char *const dialectDeclBeginStr = R"(
+class {0} : public ::mlir::Dialect {
+public:
+ explicit {0}(::mlir::MLIRContext *context);
+ static ::llvm::StringRef getDialectNamespace() { return "{1}"; }
+)";
+
+/// The code block for the attribute parser/printer hooks.
+static const char *const attrParserDecl = R"(
+ /// Parse an attribute registered to this dialect.
+ ::mlir::Attribute parseAttribute(::mlir::DialectAsmParser &parser,
+ ::mlir::Type type) const override;
+
+ /// Print an attribute registered to this dialect.
+ void printAttribute(::mlir::Attribute attr,
+ ::mlir::DialectAsmPrinter &os) const override;
+)";
+
+/// The code block for the type parser/printer hooks.
+static const char *const typeParserDecl = R"(
+ /// Parse a type registered to this dialect.
+ ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;
+
+ /// Print a type registered to this dialect.
+ void printType(::mlir::Type type,
+ ::mlir::DialectAsmPrinter &os) const override;
+)";
+
+/// The code block for the constant materializer hook.
+static const char *const constantMaterializerDecl = R"(
+ /// Materialize a single constant operation from a given attribute value with
+ /// the desired resultant type.
+ ::mlir::Operation *materializeConstant(::mlir::OpBuilder &builder,
+ ::mlir::Attribute value,
+ ::mlir::Type type,
+ ::mlir::Location loc) override;
+)";
+
+/// Generate the declaration for the given dialect class.
+static void emitDialectDecl(
+ Dialect &dialect,
+ FunctionTraits<decltype(&filterForDialect<Attribute>)>::result_t
+ dialectAttrs,
+ FunctionTraits<decltype(&filterForDialect<Type>)>::result_t dialectTypes,
+ raw_ostream &os) {
+ // Emit the start of the decl.
+ std::string cppName = dialect.getCppClassName();
+ os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName());
+
+ // Check for any attributes/types registered to this dialect. If there are,
+ // add the hooks for parsing/printing.
+ if (!dialectAttrs.empty())
+ os << attrParserDecl;
+ if (!dialectTypes.empty())
+ os << typeParserDecl;
+
+ // Add the decls for the various features of the dialect.
+ if (dialect.hasConstantMaterializer())
+ os << constantMaterializerDecl;
+ if (llvm::Optional<StringRef> extraDecl = dialect.getExtraClassDeclaration())
+ os << *extraDecl;
+
+ // End the dialect decl.
+ os << "};\n";
+}
+
+static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
+ raw_ostream &os) {
+ emitSourceFileHeader("Dialect Declarations", os);
+
+ auto defs = recordKeeper.getAllDerivedDefinitions("Dialect");
+ if (defs.empty())
+ return false;
+
+ // Select the dialect to gen for.
+ const llvm::Record *dialectDef = nullptr;
+ if (defs.size() == 1 && selectedDialect.getNumOccurrences() == 0) {
+ dialectDef = defs.front();
+ } else if (selectedDialect.getNumOccurrences() == 0) {
+ llvm::errs() << "when more than 1 dialect is present, one must be selected "
+ "via '-dialect'";
+ return true;
+ } else {
+ auto dialectIt = llvm::find_if(defs, [](const llvm::Record *def) {
+ return Dialect(def).getName() == selectedDialect;
+ });
+ if (dialectIt == defs.end()) {
+ llvm::errs() << "selected dialect with '-dialect' does not exist";
+ return true;
+ }
+ dialectDef = *dialectIt;
+ }
+
+ auto attrDefs = recordKeeper.getAllDerivedDefinitions("DialectAttr");
+ auto typeDefs = recordKeeper.getAllDerivedDefinitions("DialectType");
+ Dialect dialect(dialectDef);
+ emitDialectDecl(dialect, filterForDialect<Attribute>(attrDefs, dialect),
+ filterForDialect<Type>(typeDefs, dialect), os);
+ return false;
+}
+
+//===----------------------------------------------------------------------===//
+// GEN: Dialect registration hooks
+//===----------------------------------------------------------------------===//
+
+static mlir::GenRegistration
+ genDialectDecls("gen-dialect-decls", "Generate dialect declarations",
+ [](const llvm::RecordKeeper &records, raw_ostream &os) {
+ return emitDialectDecls(records, os);
+ });