aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPeter Klausler <pklausler@nvidia.com>2022-08-03 10:25:43 -0700
committerPeter Klausler <pklausler@nvidia.com>2022-08-07 09:12:46 -0700
commit3f10091c04e1478e0e5b2deb7dd782ebca0d529c (patch)
tree265215f47597131ffdcaee0242b9853ab8d56067
parent7602e285f69c4e3af60629100c151067c27b9eca (diff)
[flang] Allow pure function references in expandable scalar
F18 disallows function references and coarray references from appearing in scalar expressions that are to be expanded into arrays to conform with other operands or actual arguments in an elemental expression. This is too strong, as pure procedures can be safely used. Differential Revision: https://reviews.llvm.org/D131096
-rw-r--r--flang/include/flang/Evaluate/call.h1
-rw-r--r--flang/include/flang/Evaluate/tools.h18
-rw-r--r--flang/lib/Evaluate/call.cpp14
-rw-r--r--flang/lib/Semantics/expression.cpp2
4 files changed, 29 insertions, 6 deletions
diff --git a/flang/include/flang/Evaluate/call.h b/flang/include/flang/Evaluate/call.h
index 7866bab63299..3a083ab574ce 100644
--- a/flang/include/flang/Evaluate/call.h
+++ b/flang/include/flang/Evaluate/call.h
@@ -199,6 +199,7 @@ struct ProcedureDesignator {
std::optional<DynamicType> GetType() const;
int Rank() const;
bool IsElemental() const;
+ bool IsPure() const;
std::optional<Expr<SubscriptInteger>> LEN() const;
llvm::raw_ostream &AsFortran(llvm::raw_ostream &) const;
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index fe8645b5b2ab..7d521612e42a 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -1007,17 +1007,25 @@ std::optional<std::string> FindImpureCall(
// Predicate: is a scalar expression suitable for naive scalar expansion
// in the flattening of an array expression?
// TODO: capture such scalar expansions in temporaries, flatten everything
-struct UnexpandabilityFindingVisitor
+class UnexpandabilityFindingVisitor
: public AnyTraverse<UnexpandabilityFindingVisitor> {
+public:
using Base = AnyTraverse<UnexpandabilityFindingVisitor>;
using Base::operator();
- UnexpandabilityFindingVisitor() : Base{*this} {}
- template <typename T> bool operator()(const FunctionRef<T> &) { return true; }
+ explicit UnexpandabilityFindingVisitor(bool admitPureCall)
+ : Base{*this}, admitPureCall_{admitPureCall} {}
+ template <typename T> bool operator()(const FunctionRef<T> &procRef) {
+ return !admitPureCall_ || !procRef.proc().IsPure();
+ }
bool operator()(const CoarrayRef &) { return true; }
+
+private:
+ bool admitPureCall_{false};
};
-template <typename T> bool IsExpandableScalar(const Expr<T> &expr) {
- return !UnexpandabilityFindingVisitor{}(expr);
+template <typename T>
+bool IsExpandableScalar(const Expr<T> &expr, bool admitPureCall = false) {
+ return !UnexpandabilityFindingVisitor{admitPureCall}(expr);
}
// Common handling for procedure pointer compatibility of left- and right-hand
diff --git a/flang/lib/Evaluate/call.cpp b/flang/lib/Evaluate/call.cpp
index 6b008cfbd0b1..2ff4c317969e 100644
--- a/flang/lib/Evaluate/call.cpp
+++ b/flang/lib/Evaluate/call.cpp
@@ -145,6 +145,20 @@ bool ProcedureDesignator::IsElemental() const {
return false;
}
+bool ProcedureDesignator::IsPure() const {
+ if (const Symbol * interface{GetInterfaceSymbol()}) {
+ return IsPureProcedure(*interface);
+ } else if (const Symbol * symbol{GetSymbol()}) {
+ return IsPureProcedure(*symbol);
+ } else if (const auto *intrinsic{std::get_if<SpecificIntrinsic>(&u)}) {
+ return intrinsic->characteristics.value().attrs.test(
+ characteristics::Procedure::Attr::Pure);
+ } else {
+ DIE("ProcedureDesignator::IsPure(): no case");
+ }
+ return false;
+}
+
const SpecificIntrinsic *ProcedureDesignator::GetSpecificIntrinsic() const {
return std::get_if<SpecificIntrinsic>(&u);
}
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index d4492d2f3e99..01bfea713e6a 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -1833,7 +1833,7 @@ MaybeExpr ExpressionAnalyzer::Analyze(
"component", "value")};
if (checked && *checked && GetRank(*componentShape) > 0 &&
GetRank(*valueShape) == 0 &&
- !IsExpandableScalar(*converted)) {
+ !IsExpandableScalar(*converted, true /*admit PURE call*/)) {
AttachDeclaration(
Say(expr.source,
"Scalar value cannot be expanded to shape of array component '%s'"_err_en_US,