aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorgysit <gysit@google.com>2022-02-14 10:55:23 +0000
committergysit <gysit@google.com>2022-02-14 11:19:53 +0000
commit01e04867e81c2f16968c1941f559704e25ca9fe4 (patch)
treea645e47554e8bcbc7b0b3b9ea9115dd99a3d42f9
parent03380c70ed548224e2fd8280bc92f69d356d258c (diff)
[mlir][OpDSL] Consistently use the term op_def (NFC).
... and remove unused type aliases. Depends On D119003 Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D119125
-rw-r--r--mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py4
-rw-r--r--mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py3
-rw-r--r--mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py3
-rw-r--r--mlir/python/mlir/dialects/linalg/opdsl/lang/config.py10
-rw-r--r--mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py24
-rw-r--r--mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py1
6 files changed, 20 insertions, 25 deletions
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py b/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py
index bacc0c302c5e..5a695d621677 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py
@@ -73,10 +73,10 @@ def main(args):
# TODO: This class layering is awkward.
if isinstance(value, DefinedOpCallable):
try:
- linalg_config = LinalgOpConfig.from_linalg_op_def(value.model)
+ linalg_config = LinalgOpConfig.from_linalg_op_def(value.op_def)
except Exception as e:
raise ValueError(
- f"Could not create LinalgOpConfig from {value.model}") from e
+ f"Could not create LinalgOpConfig from {value.op_def}") from e
configs.extend(linalg_config)
# Print.
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py
index 9c1bb3342032..038f06834542 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py
@@ -64,9 +64,6 @@ __all__ = [
"SymbolDef",
]
-# Type aliases.
-SymbolPosMap = Dict[str, int]
-
class AffineBuildState:
"""Internal state for the AffineExprDef._create impls.
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index f6f3e01443b8..ea25d85aa742 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -17,9 +17,6 @@ from .scalar_expr import *
from .types import *
from .yaml_helper import *
-# Type aliases.
-AffineDimList = Dict[str, _ir.AffineExpr]
-
class TensorExpression:
"""An expression that can appear on the RHS of a comprehension."""
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
index fec41decbb39..59a10998e102 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
@@ -421,18 +421,18 @@ class LinalgOpConfig(YAMLObject):
@staticmethod
def from_linalg_op_def(
- tc_op_def: LinalgOpDef,
+ op_def: LinalgOpDef,
context: Optional[_ir.Context] = None) -> Sequence["LinalgOpConfig"]:
"""Expands a LinalgOpDef into corresponding Linalg configured ops."""
# TODO: Many LinalgOpDef patterns need to expand to multiple generics.
assert len(
- tc_op_def.comprehensions) == 1, "Only one comprehension supported"
+ op_def.comprehensions) == 1, "Only one comprehension supported"
return [
LinalgOpConfig(
- tc_op_def.metadata,
+ op_def.metadata,
structured_op=LinalgStructuredOpConfig(
- tc_op_def.comprehensions[0], tc_op_def.domain,
- tc_op_def.registered_operands.values(), context)),
+ op_def.comprehensions[0], op_def.domain,
+ op_def.registered_operands.values(), context)),
]
def __repr__(self):
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
index 459b1206af45..22ed934905cf 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
@@ -22,12 +22,12 @@ StructuredOpOuts = Union[ir.Operation, ir.OpView, ir.OpResultList,
@contextmanager
-def bind_op_def(model: LinalgOpDef):
+def bind_op_def(op_def: LinalgOpDef):
if hasattr(_CONTEXT, "current_op_def"):
raise ValueError("Cannot recursively define an operation")
- _CONTEXT.current_op_def = model
+ _CONTEXT.current_op_def = op_def
try:
- yield model
+ yield op_def
finally:
del _CONTEXT.current_op_def
@@ -53,9 +53,9 @@ def _prepare_structured_op_outs(outs: StructuredOpOuts) -> ValueList:
class DefinedOpCallable:
"""Callable that wraps any defined op function."""
- def __init__(self, op_name: str, model: LinalgOpDef):
+ def __init__(self, op_name: str, op_def: LinalgOpDef):
self.op_name = op_name
- self.model = model
+ self.op_def = op_def
def __call__(self, *ins: Union[ir.Operation, ir.OpView, ir.Value],
outs: StructuredOpOuts, **kwargs):
@@ -73,7 +73,7 @@ class DefinedOpCallable:
f" of type bool but got {type(emit_generic)}")
op_configs = LinalgOpConfig.from_linalg_op_def(
- self.model, context=ir.Context.current)
+ self.op_def, context=ir.Context.current)
if len(op_configs) != 1:
# TODO: Support composite ops.
@@ -97,7 +97,7 @@ class DefinedOpCallable:
return emit_named_structured_op(
op_config.structured_op,
self.op_name,
- self.model.metadata.cpp_class_name,
+ self.op_def.metadata.cpp_class_name,
*in_values,
outs=out_values,
**kwargs)
@@ -121,7 +121,7 @@ def linalg_structured_op(dsl_func=None,
# Camel case it.
op_class_name = f"{''.join(x.title() for x in op_name.split('_'))}Op"
- tc_model = LinalgOpDef(
+ op_def = LinalgOpDef(
name=op_name, cpp_class_name=op_class_name, doc=inspect.getdoc(dsl_func))
# Extract arguments and TensorDefs from the signature.
@@ -130,7 +130,7 @@ def linalg_structured_op(dsl_func=None,
for param_name, param in sig.parameters.items():
param_default = param.default
if isinstance(param_default, (TensorDef, ScalarDef, IndexAttrDef)):
- tc_model.add_operand(param_name, param_default.operand_def)
+ op_def.add_operand(param_name, param_default.operand_def)
else:
raise ValueError(
f"@linalg_structured_op function parameters must be defaulted as "
@@ -138,13 +138,13 @@ def linalg_structured_op(dsl_func=None,
f"Found {param_name}: {param_default}")
dsl_func_args.append(param_default)
- # Invoke the DSL func to finish populating the model.
- with bind_op_def(tc_model):
+ # Invoke the DSL func to finish populating the op definition.
+ with bind_op_def(op_def):
dsl_func(*dsl_func_args)
# TODO: The returned callable should be an IR emitter but that is not
# upstreamed yet.
- return DefinedOpCallable(op_name, tc_model)
+ return DefinedOpCallable(op_name, op_def)
def implements(*interfaces: OpInterfaceDef):
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 643bcaa5c2f0..e4695f0c92a2 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -23,6 +23,7 @@ __all__ = [
"ValueList",
]
+# Type aliases.
ValueList = Union[Sequence[Value], OpResultList]