From 0cc8bc79283a776a3ce60f2e4e476e2ec5b84b8d From: Pavel Yaskevich Date: Fri, 11 Feb 2022 12:55:37 -0800 Subject: [PATCH] [CSClosure] SE-0326: Type-check statement conditions individually Instead of referencing the whole statement condition, break it down to individual elements and solve them separately. Resolves: rdar://88720312 --- include/swift/AST/ASTNode.h | 4 +-- include/swift/AST/Stmt.h | 4 +-- include/swift/AST/TypeAlignments.h | 4 +++ include/swift/Sema/ConstraintLocator.h | 2 +- lib/AST/ASTNode.cpp | 26 ++++++--------- lib/Sema/CSClosure.cpp | 45 ++++++++++++++------------ lib/Sema/ConstraintSystem.cpp | 4 +-- 7 files changed, 44 insertions(+), 45 deletions(-) diff --git a/swift/include/swift/AST/ASTNode.h b/swift/include/swift/AST/ASTNode.h index ac208f1bf06d7..5724ac649b0a8 100644 --- a/swift/include/swift/AST/ASTNode.h +++ b/swift/include/swift/AST/ASTNode.h @@ -43,11 +43,9 @@ namespace swift { enum class PatternKind : uint8_t; enum class StmtKind; - using StmtCondition = llvm::MutableArrayRef; - struct ASTNode : public llvm::PointerUnion { + StmtConditionElement *, CaseLabelItem *> { // Inherit the constructors from PointerUnion. using PointerUnion::PointerUnion; diff --git a/swift/include/swift/AST/Stmt.h b/swift/include/swift/AST/Stmt.h index dab00742450f2..d4820894a169e 100644 --- a/swift/include/swift/AST/Stmt.h +++ b/swift/include/swift/AST/Stmt.h @@ -395,7 +395,7 @@ class alignas(8) PoundAvailableInfo final : /// the "x" binding, one for the "y" binding, one for the where clause, one for /// "z"'s binding. A simple "if" statement is represented as a single binding. /// -class StmtConditionElement { +class alignas(1 << PatternAlignInBits) StmtConditionElement { /// If this is a pattern binding, it may be the first one in a declaration, in /// which case this is the location of the var/let/case keyword. If this is /// the second pattern (e.g. 
for 'y' in "var x = ..., y = ...") then this @@ -818,7 +818,7 @@ class ForEachStmt : public LabeledStmt { }; /// A pattern and an optional guard expression used in a 'case' statement. -class CaseLabelItem { +class alignas(1 << PatternAlignInBits) CaseLabelItem { enum class Kind { /// A normal pattern Normal = 0, diff --git a/swift/include/swift/AST/TypeAlignments.h b/swift/include/swift/AST/TypeAlignments.h index 1c0f5e3b86846..911c29494ff19 100644 --- a/swift/include/swift/AST/TypeAlignments.h +++ b/swift/include/swift/AST/TypeAlignments.h @@ -61,6 +61,7 @@ namespace swift { class TypeRepr; class ValueDecl; class CaseLabelItem; + class StmtConditionElement; /// We frequently use three tag bits on all of these types. constexpr size_t AttrAlignInBits = 3; @@ -155,6 +156,9 @@ LLVM_DECLARE_TYPE_ALIGNMENT(swift::TypeRepr, swift::TypeReprAlignInBits) LLVM_DECLARE_TYPE_ALIGNMENT(swift::CaseLabelItem, swift::PatternAlignInBits) +LLVM_DECLARE_TYPE_ALIGNMENT(swift::StmtConditionElement, + swift::PatternAlignInBits) + static_assert(alignof(void*) >= 2, "pointer alignment is too small"); #endif diff --git a/swift/include/swift/Sema/ConstraintLocator.h b/swift/include/swift/Sema/ConstraintLocator.h index 74d4d87df9552..c1e19ab12ac4a 100644 --- a/swift/include/swift/Sema/ConstraintLocator.h +++ b/swift/include/swift/Sema/ConstraintLocator.h @@ -1036,7 +1036,7 @@ class LocatorPathElt::ClosureBodyElement final if (auto *repr = node.dyn_cast()) return repr; - if (auto *cond = node.dyn_cast()) + if (auto *cond = node.dyn_cast()) return cond; if (auto *caseItem = node.dyn_cast()) diff --git a/swift/lib/AST/ASTNode.cpp b/swift/lib/AST/ASTNode.cpp index bce57fb943dee..a9b0d3ae3342d 100644 --- a/swift/lib/AST/ASTNode.cpp +++ b/swift/lib/AST/ASTNode.cpp @@ -35,15 +35,8 @@ SourceRange ASTNode::getSourceRange() const { return P->getSourceRange(); if (const auto *T = this->dyn_cast()) return T->getSourceRange(); - if (const auto *C = this->dyn_cast()) { - if (C->empty()) - return 
SourceRange(); - - auto first = C->front(); - auto last = C->back(); - - return {first.getStartLoc(), last.getEndLoc()}; - } + if (const auto *C = this->dyn_cast()) + return C->getSourceRange(); if (const auto *I = this->dyn_cast()) { return I->getSourceRange(); } @@ -85,7 +78,7 @@ bool ASTNode::isImplicit() const { return P->isImplicit(); if (const auto *T = this->dyn_cast()) return false; - if (const auto *C = this->dyn_cast()) + if (const auto *C = this->dyn_cast()) return false; if (const auto *I = this->dyn_cast()) return false; @@ -103,10 +96,9 @@ void ASTNode::walk(ASTWalker &Walker) { P->walk(Walker); else if (auto *T = this->dyn_cast()) T->walk(Walker); - else if (auto *C = this->dyn_cast()) { - for (auto &elt : *C) - elt.walk(Walker); - } else if (auto *I = this->dyn_cast()) { + else if (auto *C = this->dyn_cast()) + C->walk(Walker); + else if (auto *I = this->dyn_cast()) { if (auto *P = I->getPattern()) P->walk(Walker); @@ -127,9 +119,9 @@ void ASTNode::dump(raw_ostream &OS, unsigned Indent) const { P->dump(OS, Indent); else if (auto T = dyn_cast()) T->print(OS); - else if (auto C = dyn_cast()) { - OS.indent(Indent) << "(statement conditions)"; - } else if (auto *I = dyn_cast()) { + else if (auto *C = dyn_cast()) + OS.indent(Indent) << "(statement condition)"; + else if (auto *I = dyn_cast()) { OS.indent(Indent) << "(case label item)"; } else llvm_unreachable("unsupported AST node"); diff --git a/swift/lib/Sema/CSClosure.cpp b/swift/lib/Sema/CSClosure.cpp index 6bcd97622ebc4..6f9e59a1513e9 100644 --- a/swift/lib/Sema/CSClosure.cpp +++ b/swift/lib/Sema/CSClosure.cpp @@ -531,6 +531,15 @@ class ClosureConstraintGenerator "Unsupported statement: Fallthrough"); } + void visitStmtCondition(LabeledConditionalStmt *S, + SmallVectorImpl &elements, + ConstraintLocator *locator) { + auto *condLocator = + cs.getConstraintLocator(locator, ConstraintLocator::Condition); + for (auto &condition : S->getCond()) + elements.push_back(makeElement(&condition, condLocator)); 
+ } + void visitIfStmt(IfStmt *ifStmt) { assert(isSupportedMultiStatementClosure() && "Unsupported statement: If"); @@ -538,11 +547,7 @@ class ClosureConstraintGenerator SmallVector elements; // Condition - { - auto *condLoc = - cs.getConstraintLocator(locator, ConstraintLocator::Condition); - elements.push_back(makeElement(ifStmt->getCondPointer(), condLoc)); - } + visitStmtCondition(ifStmt, elements, locator); // Then Branch { @@ -565,24 +570,24 @@ private: if (!isSupportedMultiStatementClosure()) llvm_unreachable("Unsupported statement: Guard"); - createConjunction(cs, - {makeElement(guardStmt->getCondPointer(), - cs.getConstraintLocator( - locator, ConstraintLocator::Condition)), - makeElement(guardStmt->getBody(), locator)}, - locator); + SmallVector elements; + + visitStmtCondition(guardStmt, elements, locator); + elements.push_back(makeElement(guardStmt->getBody(), locator)); + + createConjunction(cs, elements, locator); } void visitWhileStmt(WhileStmt *whileStmt) { if (!isSupportedMultiStatementClosure()) llvm_unreachable("Unsupported statement: Guard"); - createConjunction(cs, - {makeElement(whileStmt->getCondPointer(), - cs.getConstraintLocator( - locator, ConstraintLocator::Condition)), - makeElement(whileStmt->getBody(), locator)}, - locator); + SmallVector elements; + + visitStmtCondition(whileStmt, elements, locator); + elements.push_back(makeElement(whileStmt->getBody(), locator)); + + createConjunction(cs, elements, locator); } void visitDoStmt(DoStmt *doStmt) { @@ -970,8 +975,8 @@ ConstraintSystem::simplifyClosureBodyElementConstraint( return SolutionKind::Solved; } else if (auto *stmt = element.dyn_cast()) { generator.visit(stmt); - } else if (auto *cond = element.dyn_cast()) { - if (generateConstraints(*cond, closure)) + } else if (auto *cond = element.dyn_cast()) { + if (generateConstraints({*cond}, closure)) return SolutionKind::Error; } else if (auto *pattern = element.dyn_cast()) { generator.visitPattern(pattern, context); @@ -1571,7 +1576,7 
@@ void ConjunctionElement::findReferencedVariables( TypeVariableRefFinder refFinder(cs, locator->getAnchor(), typeVars); - if (element.is() || element.is() || + if (element.is() || element.is() || element.is() || element.isStmt(StmtKind::Return)) element.walk(refFinder); } diff --git a/swift/lib/Sema/ConstraintSystem.cpp b/swift/lib/Sema/ConstraintSystem.cpp index 5523828730804..22cbe9a94e280 100644 --- a/swift/lib/Sema/ConstraintSystem.cpp +++ b/swift/lib/Sema/ConstraintSystem.cpp @@ -6050,8 +6050,8 @@ SourceLoc constraints::getLoc(ASTNode anchor) { return S->getStartLoc(); } else if (auto *P = anchor.dyn_cast()) { return P->getLoc(); - } else if (auto *C = anchor.dyn_cast()) { - return C->front().getStartLoc(); + } else if (auto *C = anchor.dyn_cast()) { + return C->getStartLoc(); } else { auto *I = anchor.get(); return I->getStartLoc();