aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/DLTI/DLTI.cpp
blob: eaf6f1e619a01989c0c2912d7847907f6cbba873 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
//===- DLTI.cpp - Data Layout And Target Info MLIR Dialect Implementation -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;

#include "mlir/Dialect/DLTI/DLTIDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// DataLayoutEntryAttr
//===----------------------------------------------------------------------===//
//
// Out-of-line definition of the static constexpr keyword member. Presumably
// kept for ODR-uses when building in pre-C++17 mode, where static constexpr
// data members are not implicitly inline.
constexpr const StringLiteral mlir::DataLayoutEntryAttr::kAttrKeyword;

namespace mlir {
namespace impl {
class DataLayoutEntryStorage : public AttributeStorage {
public:
  using KeyTy = std::pair<DataLayoutEntryKey, Attribute>;

  DataLayoutEntryStorage(DataLayoutEntryKey entryKey, Attribute value)
      : entryKey(entryKey), value(value) {}

  static DataLayoutEntryStorage *construct(AttributeStorageAllocator &allocator,
                                           const KeyTy &key) {
    return new (allocator.allocate<DataLayoutEntryStorage>())
        DataLayoutEntryStorage(key.first, key.second);
  }

  bool operator==(const KeyTy &other) const {
    return other.first == entryKey && other.second == value;
  }

  DataLayoutEntryKey entryKey;
  Attribute value;
};
} // namespace impl
} // namespace mlir

/// Returns the entry attribute that associates the string key `key` with
/// `value`, in the context owning `key`.
DataLayoutEntryAttr DataLayoutEntryAttr::get(StringAttr key, Attribute value) {
  MLIRContext *context = key.getContext();
  return Base::get(context, key, value);
}

/// Returns the entry attribute that associates the type key `key` with
/// `value`, in the context owning `key`.
DataLayoutEntryAttr DataLayoutEntryAttr::get(Type key, Attribute value) {
  MLIRContext *context = key.getContext();
  return Base::get(context, key, value);
}

/// Returns the key of this entry: either a type or a string attribute.
DataLayoutEntryKey DataLayoutEntryAttr::getKey() const {
  const auto *storage = getImpl();
  return storage->entryKey;
}

/// Returns the value associated with this entry's key.
Attribute DataLayoutEntryAttr::getValue() const {
  const auto *storage = getImpl();
  return storage->value;
}

/// Parses the body of a data layout entry attribute with syntax:
///   attr ::= `#dlti.` `dl_entry` `<` (type | quoted-string) `,` attr `>`
/// The key is tried as a type first and, failing that, as a quoted string.
/// Returns a null attribute on failure; an error has then been reported.
DataLayoutEntryAttr DataLayoutEntryAttr::parse(AsmParser &parser) {
  if (failed(parser.parseLess()))
    return {};

  SMLoc keyLoc = parser.getCurrentLocation();
  Type typeKey = nullptr;
  std::string stringKey;

  OptionalParseResult typeResult = parser.parseOptionalType(typeKey);
  if (typeResult.has_value()) {
    // A type was present but malformed; the parser already reported it.
    if (failed(typeResult.value()))
      return {};
  } else {
    OptionalParseResult stringResult = parser.parseOptionalString(&stringKey);
    bool gotString =
        stringResult.has_value() && succeeded(stringResult.value());
    if (!gotString) {
      parser.emitError(keyLoc) << "expected a type or a quoted string";
      return {};
    }
  }

  Attribute value;
  bool malformedTail = failed(parser.parseComma()) ||
                       failed(parser.parseAttribute(value)) ||
                       failed(parser.parseGreater());
  if (malformedTail)
    return {};

  if (typeKey)
    return get(typeKey, value);
  return get(parser.getBuilder().getStringAttr(stringKey), value);
}

/// Prints the attribute as `dl_entry<key, value>`. Type keys are printed as
/// types. String keys are printed as quoted strings whose backslashes, quotes
/// and non-printable characters are escaped, so that keys containing `"` or
/// `\` do not produce malformed output.
void DataLayoutEntryAttr::print(AsmPrinter &os) const {
  os << DataLayoutEntryAttr::kAttrKeyword << "<";
  if (auto type = getKey().dyn_cast<Type>()) {
    os << type;
  } else {
    // Emitting the raw string would break on embedded quotes/backslashes.
    os << "\"";
    llvm::printEscapedString(getKey().get<StringAttr>().strref(),
                             os.getStream());
    os << "\"";
  }
  os << ", " << getValue() << ">";
}

//===----------------------------------------------------------------------===//
// DataLayoutSpecAttr
//===----------------------------------------------------------------------===//
//
// Out-of-line definition of the static constexpr keyword member. Presumably
// kept for ODR-uses when building in pre-C++17 mode, where static constexpr
// data members are not implicitly inline.
constexpr const StringLiteral mlir::DataLayoutSpecAttr::kAttrKeyword;

namespace mlir {
namespace impl {
class DataLayoutSpecStorage : public AttributeStorage {
public:
  using KeyTy = ArrayRef<DataLayoutEntryInterface>;

  DataLayoutSpecStorage(ArrayRef<DataLayoutEntryInterface> entries)
      : entries(entries) {}

  bool operator==(const KeyTy &key) const { return key == entries; }

  static DataLayoutSpecStorage *construct(AttributeStorageAllocator &allocator,
                                          const KeyTy &key) {
    return new (allocator.allocate<DataLayoutSpecStorage>())
        DataLayoutSpecStorage(allocator.copyInto(key));
  }

  ArrayRef<DataLayoutEntryInterface> entries;
};
} // namespace impl
} // namespace mlir

/// Returns the spec attribute holding `entries` in context `ctx`.
DataLayoutSpecAttr
DataLayoutSpecAttr::get(MLIRContext *ctx,
                        ArrayRef<DataLayoutEntryInterface> entries) {
  DataLayoutSpecAttr result = Base::get(ctx, entries);
  return result;
}

/// Same as `get`, but the entries are checked by `verify` first and any
/// problem is reported through `emitError`.
DataLayoutSpecAttr
DataLayoutSpecAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
                               MLIRContext *context,
                               ArrayRef<DataLayoutEntryInterface> entries) {
  DataLayoutSpecAttr result = Base::getChecked(emitError, context, entries);
  return result;
}

/// Checks that no two entries share a key. Type keys and string keys are
/// tracked in separate sets; the first repetition found is reported through
/// `emitError`.
LogicalResult
DataLayoutSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                           ArrayRef<DataLayoutEntryInterface> entries) {
  DenseSet<Type> seenTypes;
  DenseSet<StringAttr> seenIds;
  for (DataLayoutEntryInterface entry : entries) {
    DataLayoutEntryKey key = entry.getKey();
    if (auto typeKey = key.dyn_cast<Type>()) {
      if (!seenTypes.insert(typeKey).second)
        return emitError() << "repeated layout entry key: " << typeKey;
      continue;
    }

    auto idKey = key.get<StringAttr>();
    if (!seenIds.insert(idKey).second)
      return emitError() << "repeated layout entry key: " << idKey.getValue();
  }
  return success();
}

/// Merges `newEntries` into `oldEntries`. An entry whose key matches one of the
/// entries originally present in `oldEntries` replaces that entry in place;
/// any other entry is appended at the end.
static void
overwriteDuplicateEntries(SmallVectorImpl<DataLayoutEntryInterface> &oldEntries,
                          ArrayRef<DataLayoutEntryInterface> newEntries) {
  // Only the prefix that existed on entry is searched: entries appended below
  // come from `newEntries` and never need to be checked against each other.
  // The lists are expected to be short (dozens of entries), so a linear scan is
  // cheaper in practice than building an auxiliary hash map.
  const unsigned numPreexisting = oldEntries.size();
  for (DataLayoutEntryInterface incoming : newEntries) {
    unsigned pos = 0;
    while (pos < numPreexisting &&
           !(oldEntries[pos].getKey() == incoming.getKey()))
      ++pos;

    if (pos < numPreexisting)
      oldEntries[pos] = incoming;
    else
      oldEntries.push_back(incoming);
  }
}

/// Combines a data layout spec into the given lists of entries organized by
/// type class and identifier, overwriting them if necessary. Fails to combine
/// if two entries with identical keys are not compatible. A null `spec` is
/// accepted and leaves both maps unchanged.
static LogicalResult
combineOneSpec(DataLayoutSpecInterface spec,
               DenseMap<TypeID, DataLayoutEntryList> &entriesForType,
               DenseMap<StringAttr, DataLayoutEntryInterface> &entriesForID) {
  // A missing spec should be fine.
  if (!spec)
    return success();

  DenseMap<TypeID, DataLayoutEntryList> newEntriesForType;
  DenseMap<StringAttr, DataLayoutEntryInterface> newEntriesForID;
  spec.bucketEntriesByType(newEntriesForType, newEntriesForID);

  // Try overwriting the old entries with the new ones.
  for (auto &kvp : newEntriesForType) {
    auto it = entriesForType.find(kvp.first);
    if (it == entriesForType.end()) {
      entriesForType.try_emplace(kvp.first, std::move(kvp.second));
      continue;
    }
    // Work on the stored list in place; `DenseMap::lookup` would return a copy
    // of the whole list only to compare it.
    DataLayoutEntryList &oldEntries = it->second;

    Type typeSample = kvp.second.front().getKey().get<Type>();
    assert(&typeSample.getDialect() !=
               typeSample.getContext()->getLoadedDialect<BuiltinDialect>() &&
           "unexpected data layout entry for built-in type");

    auto interface = typeSample.cast<DataLayoutTypeInterface>();
    if (!interface.areCompatible(oldEntries, kvp.second))
      return failure();

    overwriteDuplicateEntries(oldEntries, kvp.second);
  }

  for (const auto &kvp : newEntriesForID) {
    StringAttr id = kvp.second.getKey().get<StringAttr>();
    auto it = entriesForID.find(id);
    if (it == entriesForID.end()) {
      entriesForID[id] = kvp.second;
      continue;
    }

    // Attempt to combine the entries using the dialect interface. If the
    // dialect is not loaded, or does not implement the interface, use the
    // default combinator that conservatively accepts identical entries only.
    Dialect *dialect = id.getReferencedDialect();
    const DataLayoutDialectInterface *iface =
        dialect ? dyn_cast<DataLayoutDialectInterface>(dialect) : nullptr;
    DataLayoutEntryInterface combined =
        iface ? iface->combine(it->second, kvp.second)
              : DataLayoutDialectInterface::defaultCombine(it->second,
                                                           kvp.second);
    if (!combined)
      return failure();
    it->second = combined;
  }

  return success();
}

/// Combines `specs`, followed by this spec, into a single spec.
///
/// Specs are merged in order, so entries from later specs (and finally from
/// `this`) overwrite compatible entries with the same key from earlier ones.
/// Null specs are skipped.
///
/// Returns null if any non-null spec is not a DataLayoutSpecAttr, or if two
/// entries with the same key cannot be combined. The order of entries in the
/// result follows the iteration order of the internal DenseMaps and is
/// therefore unspecified.
DataLayoutSpecAttr
DataLayoutSpecAttr::combineWith(ArrayRef<DataLayoutSpecInterface> specs) const {
  // Only combine with attributes of the same kind. Null specs are tolerated
  // (combineOneSpec skips them) and must not be passed to `isa`.
  // TODO: reconsider this when the need arises.
  if (llvm::any_of(specs, [](DataLayoutSpecInterface spec) {
        return spec && !spec.isa<DataLayoutSpecAttr>();
      }))
    return nullptr;

  // Combine all specs in order, with `this` being the last one.
  DenseMap<TypeID, DataLayoutEntryList> entriesForType;
  DenseMap<StringAttr, DataLayoutEntryInterface> entriesForID;
  for (DataLayoutSpecInterface spec : specs)
    if (failed(combineOneSpec(spec, entriesForType, entriesForID)))
      return nullptr;
  if (failed(combineOneSpec(*this, entriesForType, entriesForID)))
    return nullptr;

  // Rebuild the linear list of entries.
  SmallVector<DataLayoutEntryInterface> entries;
  llvm::append_range(entries, llvm::make_second_range(entriesForID));
  for (const auto &kvp : entriesForType)
    llvm::append_range(entries, kvp.getSecond());

  return DataLayoutSpecAttr::get(getContext(), entries);
}

/// Returns the entries of this spec, in the order they are stored.
DataLayoutEntryListRef DataLayoutSpecAttr::getEntries() const {
  const auto *storage = getImpl();
  return storage->entries;
}

/// Parses the body of a data layout spec attribute with syntax
///   attr ::= `#dlti.` `dl_spec` `<` attr-list? `>`
///   attr-list ::= attr
///               | attr `,` attr-list
/// Each listed attribute must be a data layout entry. A non-empty spec is
/// checked via `getChecked`, with errors reported at the parser's name
/// location.
DataLayoutSpecAttr DataLayoutSpecAttr::parse(AsmParser &parser) {
  if (failed(parser.parseLess()))
    return {};

  // `<>` denotes a spec without any entries.
  if (succeeded(parser.parseOptionalGreater()))
    return get(parser.getContext(), {});

  SmallVector<DataLayoutEntryInterface> entries;
  auto parseOneEntry = [&]() -> ParseResult {
    DataLayoutEntryInterface &slot = entries.emplace_back();
    return parser.parseAttribute(slot);
  };
  if (failed(parser.parseCommaSeparatedList(parseOneEntry)))
    return {};
  if (failed(parser.parseGreater()))
    return {};

  MLIRContext *context = parser.getContext();
  auto emitErrorAtName = [&] { return parser.emitError(parser.getNameLoc()); };
  return getChecked(emitErrorAtName, context, entries);
}

/// Prints the attribute as `dl_spec<` followed by its entries, separated by
/// ", " and each printed as a full attribute, then `>`.
void DataLayoutSpecAttr::print(AsmPrinter &os) const {
  os << DataLayoutSpecAttr::kAttrKeyword << "<";
  bool needsSeparator = false;
  for (DataLayoutEntryInterface entry : getEntries()) {
    if (needsSeparator)
      os << ", ";
    os << entry;
    needsSeparator = true;
  }
  os << ">";
}

//===----------------------------------------------------------------------===//
// DLTIDialect
//===----------------------------------------------------------------------===//

// Out-of-line definitions of the dialect's static constexpr string constants
// (attribute name, endianness key and its two accepted values). Presumably
// kept for ODR-uses when building in pre-C++17 mode, where static constexpr
// data members are not implicitly inline.
constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutAttrName;
constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutEndiannessKey;
constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutEndiannessBig;
constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutEndiannessLittle;

namespace {
/// Data layout dialect interface of the DLTI dialect. The only identifier
/// entry it accepts is the endianness key, whose value must be a string equal
/// to either the big or the little endianness constant.
class TargetDataLayoutInterface : public DataLayoutDialectInterface {
public:
  using DataLayoutDialectInterface::DataLayoutDialectInterface;

  LogicalResult verifyEntry(DataLayoutEntryInterface entry,
                            Location loc) const final {
    StringRef entryName = entry.getKey().get<StringAttr>().strref();
    if (entryName != DLTIDialect::kDataLayoutEndiannessKey)
      return emitError(loc) << "unknown data layout entry name: " << entryName;

    auto value = entry.getValue().dyn_cast<StringAttr>();
    bool isKnownEndianness =
        value && (value.getValue() == DLTIDialect::kDataLayoutEndiannessBig ||
                  value.getValue() == DLTIDialect::kDataLayoutEndiannessLittle);
    if (isKnownEndianness)
      return success();

    return emitError(loc) << "'" << entryName
                          << "' data layout entry is expected to be either '"
                          << DLTIDialect::kDataLayoutEndiannessBig << "' or '"
                          << DLTIDialect::kDataLayoutEndiannessLittle << "'";
  }
};
} // namespace

/// Registers the DLTI attribute classes and the dialect's data layout
/// interface implementation.
void DLTIDialect::initialize() {
  addAttributes<DataLayoutEntryAttr>();
  addAttributes<DataLayoutSpecAttr>();
  addInterfaces<TargetDataLayoutInterface>();
}

/// Parses a DLTI attribute: reads the kind keyword and dispatches to the
/// parser of the matching attribute class. `type` is not used. Returns a null
/// attribute after reporting an error for unknown or malformed attributes.
Attribute DLTIDialect::parseAttribute(DialectAsmParser &parser,
                                      Type type) const {
  StringRef kind;
  if (failed(parser.parseKeyword(&kind)))
    return Attribute();

  if (kind == DataLayoutEntryAttr::kAttrKeyword)
    return DataLayoutEntryAttr::parse(parser);
  if (kind == DataLayoutSpecAttr::kAttrKeyword)
    return DataLayoutSpecAttr::parse(parser);

  // NOTE(review): "attrribute" is misspelled; tests may match this exact text,
  // so fix it together with them.
  parser.emitError(parser.getNameLoc(), "unknown attrribute type: ") << kind;
  return Attribute();
}

/// Prints a DLTI attribute by delegating to the attribute's own `print`.
/// Receiving any other attribute kind is a programming error.
void DLTIDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
  if (auto entry = attr.dyn_cast<DataLayoutEntryAttr>()) {
    entry.print(os);
    return;
  }
  if (auto spec = attr.dyn_cast<DataLayoutSpecAttr>()) {
    spec.print(os);
    return;
  }
  llvm_unreachable("unknown attribute kind");
}

/// Verifies a DLTI attribute attached to `op`. Only the data layout attribute
/// name is accepted, and its value must be a DataLayoutSpecAttr. Builtin
/// modules additionally get their data layout verified; other ops are not
/// checked further here.
LogicalResult DLTIDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute attr) {
  if (attr.getName().getValue() != DLTIDialect::kDataLayoutAttrName)
    return op->emitError() << "attribute '" << attr.getName().getValue()
                           << "' not supported by dialect";

  if (!attr.getValue().isa<DataLayoutSpecAttr>())
    return op->emitError() << "'" << DLTIDialect::kDataLayoutAttrName
                           << "' is expected to be a #dlti.dl_spec attribute";

  return isa<ModuleOp>(op) ? detail::verifyDataLayoutOp(op) : success();
}