aboutsummaryrefslogtreecommitdiff
path: root/mlir/utils/generate-test-checks.py
diff options
context:
space:
mode:
authorTim Shen <timshen@google.com>2020-06-15 19:41:03 -0700
committerTim Shen <timshen@google.com>2020-06-16 11:15:46 -0700
commit25b3806788aed8633fa32afbe842b0dd48552938 (patch)
tree786ab3a225b0273ba855ea5d921f3284d78c7456 /mlir/utils/generate-test-checks.py
parent3f0c9c1634237834af6b74e9319cb15f6ab89d11 (diff)
[MLIR] Rework generate-test-checks.py to attach CHECK lines to the source (test) file.
Summary: This patch adds --source flag to indicate the source file. Then it tries to find insert points in the source file and insert corresponding checks at those places. Example output from Tensorflow XLA: // ----- // CHECK-LABEL: func @main.3( // CHECK-SAME: %[[VAL_0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 : index}, // CHECK-SAME: %[[VAL_1:.*]]: memref<16xi8> {xla_lhlo.alloc = 0 : index, xla_lhlo.liveout = true}) { // CHECK: %[[VAL_2:.*]] = constant 0 : index // CHECK: %[[VAL_3:.*]] = constant 0 : index // CHECK: %[[VAL_4:.*]] = std.view %[[VAL_1]]{{\[}}%[[VAL_3]]][] : memref<16xi8> to memref<2x2xf32> // CHECK: "xla_lhlo.tanh"(%[[VAL_0]], %[[VAL_4]]) : (memref<2x2xf32>, memref<2x2xf32>) -> () // CHECK: return // CHECK: } func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { %res = "xla_hlo.tanh"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %res : tensor<2x2xf32> } Differential Revision: https://reviews.llvm.org/D81903
Diffstat (limited to 'mlir/utils/generate-test-checks.py')
-rwxr-xr-xmlir/utils/generate-test-checks.py106
1 files changed, 85 insertions, 21 deletions
diff --git a/mlir/utils/generate-test-checks.py b/mlir/utils/generate-test-checks.py
index 5fac81bffea9..e08f64e16475 100755
--- a/mlir/utils/generate-test-checks.py
+++ b/mlir/utils/generate-test-checks.py
@@ -56,6 +56,12 @@ class SSAVariableNamer:
def pop_name_scope(self):
self.scopes.pop()
+ def num_scopes(self):
+ return len(self.scopes)
+
+ def clear_counter(self):
+ self.name_counter = 0
+
# Process a line of input that has been split at each SSA identifier '%'.
def process_line(line_chunks, variable_namer):
@@ -87,6 +93,22 @@ def process_line(line_chunks, variable_namer):
return output_line + '\n'
+def process_source_lines(source_lines, note, args):
+ source_split_re = re.compile(args.source_delim_regex)
+
+ source_segments = [[]]
+ for line in source_lines:
+ if line == note:
+ continue
+ if line.find(args.check_prefix) != -1:
+ continue
+ if source_split_re.search(line):
+ source_segments.append([])
+
+ source_segments[-1].append(line + '\n')
+ return source_segments
+
+
# Pre-process a line of input to remove any character sequences that will be
# problematic with FileCheck.
def preprocess_line(line):
@@ -112,25 +134,51 @@ def main():
'--output',
nargs='?',
type=argparse.FileType('w'),
- default=sys.stdout)
+ default=None)
parser.add_argument(
'input',
nargs='?',
type=argparse.FileType('r'),
default=sys.stdin)
+ parser.add_argument(
+ '--source', type=str,
+ help='Print each CHECK chunk before each delimeter line in the source'
+ 'file, respectively. The delimeter lines are identified by '
+ '--source_delim_regex.')
+ parser.add_argument('--source_delim_regex', type=str, default='func @')
+ parser.add_argument(
+ '--starts_from_scope', type=int, default=1,
+ help='Omit the top specified level of content. For example, by default '
+ 'it omits "module {"')
+ parser.add_argument('-i', '--inplace', action='store_true', default=False)
+
args = parser.parse_args()
# Open the given input file.
input_lines = [l.rstrip() for l in args.input]
args.input.close()
- output_lines = []
-
# Generate a note used for the generated check file.
script_name = os.path.basename(__file__)
autogenerated_note = (ADVERT + 'utils/' + script_name)
- output_lines.append(autogenerated_note + '\n')
+ source_segments = None
+ if args.source:
+ source_segments = process_source_lines(
+ [l.rstrip() for l in open(args.source, 'r')],
+ autogenerated_note,
+ args
+ )
+
+ if args.inplace:
+ assert args.output is None
+ output = open(args.source, 'w')
+ elif args.output is None:
+ output = sys.stdout
+ else:
+ output = args.output
+
+ output_segments = [[]]
# A map containing data used for naming SSA value names.
variable_namer = SSAVariableNamer()
for input_line in input_lines:
@@ -144,17 +192,25 @@ def main():
if is_block:
input_line = input_line.rsplit('//', 1)[0].rstrip()
- # Top-level operations are heuristically the operations at nesting level 1.
- is_toplevel_op = (not is_block and input_line.startswith(' ') and
- input_line[2] != ' ' and input_line[2] != '}')
+ cur_level = variable_namer.num_scopes()
# If the line starts with a '}', pop the last name scope.
if lstripped_input_line[0] == '}':
variable_namer.pop_name_scope()
+ cur_level = variable_namer.num_scopes()
# If the line ends with a '{', push a new name scope.
if input_line[-1] == '{':
variable_namer.push_name_scope()
+ if cur_level == args.starts_from_scope:
+ output_segments.append([])
+
+ # Omit lines at the near top level e.g. "module {".
+ if cur_level < args.starts_from_scope:
+ continue
+
+ if len(output_segments[-1]) == 0:
+ variable_namer.clear_counter()
# Preprocess the input to remove any sequences that may be problematic with
# FileCheck.
@@ -164,7 +220,7 @@ def main():
ssa_split = input_line.split('%')
# If this is a top-level operation use 'CHECK-LABEL', otherwise 'CHECK:'.
- if not is_toplevel_op or not ssa_split[0]:
+ if len(output_segments[-1]) != 0 or not ssa_split[0]:
output_line = '// ' + args.check_prefix + ': '
# Pad to align with the 'LABEL' statements.
output_line += (' ' * len('-LABEL'))
@@ -176,32 +232,40 @@ def main():
output_line += process_line(ssa_split[1:], variable_namer)
else:
- # Append a newline to the output to separate the logical blocks.
- output_lines.append('\n')
- output_line = '// ' + args.check_prefix + '-LABEL: '
-
# Output the first line chunk that does not contain an SSA name for the
# label.
- output_line += ssa_split[0] + '\n'
+ output_line = '// ' + args.check_prefix + '-LABEL: ' + ssa_split[0] + '\n'
- # Process the rest of the input line on a separate check line.
- if len(ssa_split) > 1:
+ # Process the rest of the input line on separate check lines.
+ for argument in ssa_split[1:]:
output_line += '// ' + args.check_prefix + '-SAME: '
# Pad to align with the original position in the line.
output_line += ' ' * len(ssa_split[0])
# Process the rest of the line.
- output_line += process_line(ssa_split[1:], variable_namer)
+ output_line += process_line([argument], variable_namer)
# Append the output line.
- output_lines.append(output_line)
+ output_segments[-1].append(output_line)
+
+ output.write(autogenerated_note + '\n')
# Write the output.
- for output_line in output_lines:
- args.output.write(output_line)
- args.output.write('\n')
- args.output.close()
+ if source_segments:
+ assert len(output_segments) == len(source_segments)
+ for check_segment, source_segment in zip(output_segments, source_segments):
+ for line in check_segment:
+ output.write(line)
+ for line in source_segment:
+ output.write(line)
+ else:
+ for segment in output_segments:
+ output.write('\n')
+ for output_line in segment:
+ output.write(output_line)
+ output.write('\n')
+ output.close()
if __name__ == '__main__':