258 lines
10 KiB
Diff
258 lines
10 KiB
Diff
From 0cc8bc79283a776a3ce60f2e4e476e2ec5b84b8d
|
|
From: Pavel Yaskevich <pyaskevich@apple.com>
|
|
Date: Fri, 11 Feb 2022 12:55:37 -0800
|
|
Subject: [PATCH] [CSClosure] SE-0326: Type-check statement conditions
|
|
individually
|
|
|
|
Instead of referencing the whole statement condition, break it down into
|
|
individual elements and solve them separately.
|
|
|
|
Resolves: rdar://88720312
|
|
---
|
|
include/swift/AST/ASTNode.h | 4 +--
|
|
include/swift/AST/Stmt.h | 4 +--
|
|
include/swift/AST/TypeAlignments.h | 4 +++
|
|
include/swift/Sema/ConstraintLocator.h | 2 +-
|
|
lib/AST/ASTNode.cpp | 26 ++++++---------
|
|
lib/Sema/CSClosure.cpp | 45 ++++++++++++++------------
|
|
lib/Sema/ConstraintSystem.cpp | 4 +--
|
|
7 files changed, 44 insertions(+), 45 deletions(-)
|
|
|
|
diff --git a/swift/include/swift/AST/ASTNode.h b/swift/include/swift/AST/ASTNode.h
|
|
index ac208f1bf06d7..5724ac649b0a8 100644
|
|
--- a/swift/include/swift/AST/ASTNode.h
|
|
+++ b/swift/include/swift/AST/ASTNode.h
|
|
@@ -43,11 +43,9 @@ namespace swift {
|
|
enum class PatternKind : uint8_t;
|
|
enum class StmtKind;
|
|
|
|
- using StmtCondition = llvm::MutableArrayRef<StmtConditionElement>;
|
|
-
|
|
struct ASTNode
|
|
: public llvm::PointerUnion<Expr *, Stmt *, Decl *, Pattern *, TypeRepr *,
|
|
- StmtCondition *, CaseLabelItem *> {
|
|
+ StmtConditionElement *, CaseLabelItem *> {
|
|
// Inherit the constructors from PointerUnion.
|
|
using PointerUnion::PointerUnion;
|
|
|
|
diff --git a/swift/include/swift/AST/Stmt.h b/swift/include/swift/AST/Stmt.h
|
|
index dab00742450f2..d4820894a169e 100644
|
|
--- a/swift/include/swift/AST/Stmt.h
|
|
+++ b/swift/include/swift/AST/Stmt.h
|
|
@@ -395,7 +395,7 @@ class alignas(8) PoundAvailableInfo final :
|
|
/// the "x" binding, one for the "y" binding, one for the where clause, one for
|
|
/// "z"'s binding. A simple "if" statement is represented as a single binding.
|
|
///
|
|
-class StmtConditionElement {
|
|
+class alignas(1 << PatternAlignInBits) StmtConditionElement {
|
|
/// If this is a pattern binding, it may be the first one in a declaration, in
|
|
/// which case this is the location of the var/let/case keyword. If this is
|
|
/// the second pattern (e.g. for 'y' in "var x = ..., y = ...") then this
|
|
@@ -818,7 +818,7 @@ class ForEachStmt : public LabeledStmt {
|
|
};
|
|
|
|
/// A pattern and an optional guard expression used in a 'case' statement.
|
|
-class CaseLabelItem {
|
|
+class alignas(1 << PatternAlignInBits) CaseLabelItem {
|
|
enum class Kind {
|
|
/// A normal pattern
|
|
Normal = 0,
|
|
diff --git a/swift/include/swift/AST/TypeAlignments.h b/swift/include/swift/AST/TypeAlignments.h
|
|
index 1c0f5e3b86846..911c29494ff19 100644
|
|
--- a/swift/include/swift/AST/TypeAlignments.h
|
|
+++ b/swift/include/swift/AST/TypeAlignments.h
|
|
@@ -61,6 +61,7 @@ namespace swift {
|
|
class TypeRepr;
|
|
class ValueDecl;
|
|
class CaseLabelItem;
|
|
+ class StmtConditionElement;
|
|
|
|
/// We frequently use three tag bits on all of these types.
|
|
constexpr size_t AttrAlignInBits = 3;
|
|
@@ -155,6 +156,9 @@ LLVM_DECLARE_TYPE_ALIGNMENT(swift::TypeRepr, swift::TypeReprAlignInBits)
|
|
|
|
LLVM_DECLARE_TYPE_ALIGNMENT(swift::CaseLabelItem, swift::PatternAlignInBits)
|
|
|
|
+LLVM_DECLARE_TYPE_ALIGNMENT(swift::StmtConditionElement,
|
|
+ swift::PatternAlignInBits)
|
|
+
|
|
static_assert(alignof(void*) >= 2, "pointer alignment is too small");
|
|
|
|
#endif
|
|
diff --git a/swift/include/swift/Sema/ConstraintLocator.h b/swift/include/swift/Sema/ConstraintLocator.h
|
|
index 74d4d87df9552..c1e19ab12ac4a 100644
|
|
--- a/swift/include/swift/Sema/ConstraintLocator.h
|
|
+++ b/swift/include/swift/Sema/ConstraintLocator.h
|
|
@@ -1036,7 +1036,7 @@ class LocatorPathElt::ClosureBodyElement final
|
|
if (auto *repr = node.dyn_cast<TypeRepr *>())
|
|
return repr;
|
|
|
|
- if (auto *cond = node.dyn_cast<StmtCondition *>())
|
|
+ if (auto *cond = node.dyn_cast<StmtConditionElement *>())
|
|
return cond;
|
|
|
|
if (auto *caseItem = node.dyn_cast<CaseLabelItem *>())
|
|
diff --git a/swift/lib/AST/ASTNode.cpp b/swift/lib/AST/ASTNode.cpp
|
|
index bce57fb943dee..a9b0d3ae3342d 100644
|
|
--- a/swift/lib/AST/ASTNode.cpp
|
|
+++ b/swift/lib/AST/ASTNode.cpp
|
|
@@ -35,15 +35,8 @@ SourceRange ASTNode::getSourceRange() const {
|
|
return P->getSourceRange();
|
|
if (const auto *T = this->dyn_cast<TypeRepr *>())
|
|
return T->getSourceRange();
|
|
- if (const auto *C = this->dyn_cast<StmtCondition *>()) {
|
|
- if (C->empty())
|
|
- return SourceRange();
|
|
-
|
|
- auto first = C->front();
|
|
- auto last = C->back();
|
|
-
|
|
- return {first.getStartLoc(), last.getEndLoc()};
|
|
- }
|
|
+ if (const auto *C = this->dyn_cast<StmtConditionElement *>())
|
|
+ return C->getSourceRange();
|
|
if (const auto *I = this->dyn_cast<CaseLabelItem *>()) {
|
|
return I->getSourceRange();
|
|
}
|
|
@@ -85,7 +78,7 @@ bool ASTNode::isImplicit() const {
|
|
return P->isImplicit();
|
|
if (const auto *T = this->dyn_cast<TypeRepr*>())
|
|
return false;
|
|
- if (const auto *C = this->dyn_cast<StmtCondition *>())
|
|
+ if (const auto *C = this->dyn_cast<StmtConditionElement *>())
|
|
return false;
|
|
if (const auto *I = this->dyn_cast<CaseLabelItem *>())
|
|
return false;
|
|
@@ -103,10 +96,9 @@ void ASTNode::walk(ASTWalker &Walker) {
|
|
P->walk(Walker);
|
|
else if (auto *T = this->dyn_cast<TypeRepr*>())
|
|
T->walk(Walker);
|
|
- else if (auto *C = this->dyn_cast<StmtCondition *>()) {
|
|
- for (auto &elt : *C)
|
|
- elt.walk(Walker);
|
|
- } else if (auto *I = this->dyn_cast<CaseLabelItem *>()) {
|
|
+ else if (auto *C = this->dyn_cast<StmtConditionElement *>())
|
|
+ C->walk(Walker);
|
|
+ else if (auto *I = this->dyn_cast<CaseLabelItem *>()) {
|
|
if (auto *P = I->getPattern())
|
|
P->walk(Walker);
|
|
|
|
@@ -127,9 +119,9 @@ void ASTNode::dump(raw_ostream &OS, unsigned Indent) const {
|
|
P->dump(OS, Indent);
|
|
else if (auto T = dyn_cast<TypeRepr*>())
|
|
T->print(OS);
|
|
- else if (auto C = dyn_cast<StmtCondition *>()) {
|
|
- OS.indent(Indent) << "(statement conditions)";
|
|
- } else if (auto *I = dyn_cast<CaseLabelItem *>()) {
|
|
+ else if (auto *C = dyn_cast<StmtConditionElement *>())
|
|
+ OS.indent(Indent) << "(statement condition)";
|
|
+ else if (auto *I = dyn_cast<CaseLabelItem *>()) {
|
|
OS.indent(Indent) << "(case label item)";
|
|
} else
|
|
llvm_unreachable("unsupported AST node");
|
|
diff --git a/swift/lib/Sema/CSClosure.cpp b/swift/lib/Sema/CSClosure.cpp
|
|
index 6bcd97622ebc4..6f9e59a1513e9 100644
|
|
--- a/swift/lib/Sema/CSClosure.cpp
|
|
+++ b/swift/lib/Sema/CSClosure.cpp
|
|
@@ -531,6 +531,15 @@ class ClosureConstraintGenerator
|
|
"Unsupported statement: Fallthrough");
|
|
}
|
|
|
|
+ void visitStmtCondition(LabeledConditionalStmt *S,
|
|
+ SmallVectorImpl<ElementInfo> &elements,
|
|
+ ConstraintLocator *locator) {
|
|
+ auto *condLocator =
|
|
+ cs.getConstraintLocator(locator, ConstraintLocator::Condition);
|
|
+ for (auto &condition : S->getCond())
|
|
+ elements.push_back(makeElement(&condition, condLocator));
|
|
+ }
|
|
+
|
|
void visitIfStmt(IfStmt *ifStmt) {
|
|
assert(isSupportedMultiStatementClosure() &&
|
|
"Unsupported statement: If");
|
|
@@ -538,11 +547,7 @@ class ClosureConstraintGenerator
|
|
SmallVector<ElementInfo, 4> elements;
|
|
|
|
// Condition
|
|
- {
|
|
- auto *condLoc =
|
|
- cs.getConstraintLocator(locator, ConstraintLocator::Condition);
|
|
- elements.push_back(makeElement(ifStmt->getCondPointer(), condLoc));
|
|
- }
|
|
+ visitStmtCondition(ifStmt, elements, locator);
|
|
|
|
// Then Branch
|
|
{
|
|
@@ -565,24 +570,24 @@ private:
|
|
if (!isSupportedMultiStatementClosure())
|
|
llvm_unreachable("Unsupported statement: Guard");
|
|
|
|
- createConjunction(cs,
|
|
- {makeElement(guardStmt->getCondPointer(),
|
|
- cs.getConstraintLocator(
|
|
- locator, ConstraintLocator::Condition)),
|
|
- makeElement(guardStmt->getBody(), locator)},
|
|
- locator);
|
|
+ SmallVector<ElementInfo, 4> elements;
|
|
+
|
|
+ visitStmtCondition(guardStmt, elements, locator);
|
|
+ elements.push_back(makeElement(guardStmt->getBody(), locator));
|
|
+
|
|
+ createConjunction(cs, elements, locator);
|
|
}
|
|
|
|
void visitWhileStmt(WhileStmt *whileStmt) {
|
|
if (!isSupportedMultiStatementClosure())
|
|
llvm_unreachable("Unsupported statement: Guard");
|
|
|
|
- createConjunction(cs,
|
|
- {makeElement(whileStmt->getCondPointer(),
|
|
- cs.getConstraintLocator(
|
|
- locator, ConstraintLocator::Condition)),
|
|
- makeElement(whileStmt->getBody(), locator)},
|
|
- locator);
|
|
+ SmallVector<ElementInfo, 4> elements;
|
|
+
|
|
+ visitStmtCondition(whileStmt, elements, locator);
|
|
+ elements.push_back(makeElement(whileStmt->getBody(), locator));
|
|
+
|
|
+ createConjunction(cs, elements, locator);
|
|
}
|
|
|
|
void visitDoStmt(DoStmt *doStmt) {
|
|
@@ -970,8 +975,8 @@ ConstraintSystem::simplifyClosureBodyElementConstraint(
|
|
return SolutionKind::Solved;
|
|
} else if (auto *stmt = element.dyn_cast<Stmt *>()) {
|
|
generator.visit(stmt);
|
|
- } else if (auto *cond = element.dyn_cast<StmtCondition *>()) {
|
|
- if (generateConstraints(*cond, closure))
|
|
+ } else if (auto *cond = element.dyn_cast<StmtConditionElement *>()) {
|
|
+ if (generateConstraints({*cond}, closure))
|
|
return SolutionKind::Error;
|
|
} else if (auto *pattern = element.dyn_cast<Pattern *>()) {
|
|
generator.visitPattern(pattern, context);
|
|
@@ -1571,7 +1576,7 @@ void ConjunctionElement::findReferencedVariables(
|
|
|
|
TypeVariableRefFinder refFinder(cs, locator->getAnchor(), typeVars);
|
|
|
|
- if (element.is<Decl *>() || element.is<StmtCondition *>() ||
|
|
+ if (element.is<Decl *>() || element.is<StmtConditionElement *>() ||
|
|
element.is<Expr *>() || element.isStmt(StmtKind::Return))
|
|
element.walk(refFinder);
|
|
}
|
|
diff --git a/swift/lib/Sema/ConstraintSystem.cpp b/swift/lib/Sema/ConstraintSystem.cpp
|
|
index 5523828730804..22cbe9a94e280 100644
|
|
--- a/swift/lib/Sema/ConstraintSystem.cpp
|
|
+++ b/swift/lib/Sema/ConstraintSystem.cpp
|
|
@@ -6050,8 +6050,8 @@ SourceLoc constraints::getLoc(ASTNode anchor) {
|
|
return S->getStartLoc();
|
|
} else if (auto *P = anchor.dyn_cast<Pattern *>()) {
|
|
return P->getLoc();
|
|
- } else if (auto *C = anchor.dyn_cast<StmtCondition *>()) {
|
|
- return C->front().getStartLoc();
|
|
+ } else if (auto *C = anchor.dyn_cast<StmtConditionElement *>()) {
|
|
+ return C->getStartLoc();
|
|
} else {
|
|
auto *I = anchor.get<CaseLabelItem *>();
|
|
return I->getStartLoc();
|