[pycharm] move type guards from control flow to PyDefUse stage

GitOrigin-RevId: e66971e619978ad179bb49a15820a7482b27df7c
This commit is contained in:
Vladimir Koshelev
2024-07-15 17:15:11 +02:00
committed by intellij-monorepo-bot
parent 25b01bf1db
commit 82e8947e95
11 changed files with 211 additions and 133 deletions

View File

@@ -29,21 +29,20 @@ import com.intellij.psi.util.QualifiedName;
import com.intellij.util.containers.ContainerUtil;
import com.jetbrains.python.PyNames;
import com.jetbrains.python.PyTokenTypes;
import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil;
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.impl.ParamHelper;
import com.jetbrains.python.psi.impl.PyAugAssignmentStatementNavigator;
import com.jetbrains.python.psi.impl.PyEvaluator;
import com.jetbrains.python.psi.impl.PyImportStatementNavigator;
import com.jetbrains.python.psi.resolve.PyResolveUtil;
import com.jetbrains.python.psi.types.TypeEvalContext;
import kotlin.Triple;
import one.util.streamex.StreamEx;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;
public class PyControlFlowBuilder extends PyRecursiveElementVisitor {
@@ -51,7 +50,6 @@ public class PyControlFlowBuilder extends PyRecursiveElementVisitor {
private static final Set<String> EXCEPTION_SUPPRESSORS = ImmutableSet.of("suppress", "assertRaises", "assertRaisesRegex");
private final ControlFlowBuilder myBuilder = new ControlFlowBuilder();
private final Map<PyExpression, PyFunction> expressionToGuards = new HashMap<>();
public ControlFlow buildControlFlow(@NotNull final ScopeOwner owner) {
return myBuilder.build(this, owner);
@@ -65,26 +63,6 @@ public class PyControlFlowBuilder extends PyRecursiveElementVisitor {
private void startConditionalNodeAndCheckGuards(@NotNull PsiElement element, @Nullable PyExpression condition, boolean result) {
myBuilder.startConditionalNode(element, condition, result);
addTypeGuardAssertions(condition, result);
}
private void addTypeGuardAssertions(@Nullable PyExpression condition, boolean result) {
final PyExpression actualExpression;
final boolean negation;
if (condition instanceof PyPrefixExpression prefixExpression && prefixExpression.getOperator() == PyTokenTypes.NOT_KEYWORD) {
actualExpression = prefixExpression.getOperand();
negation = true;
}
else {
actualExpression = condition;
negation = false;
}
var function = expressionToGuards.get(actualExpression);
if ((negation && !result || !negation && result) && function != null && actualExpression instanceof PyCallExpression callExpression) {
final var evaluator = new PyTypeAssertionEvaluator();
evaluator.handleTypeGuardCall(callExpression, function);
InstructionBuilder.addAssertInstructions(myBuilder, evaluator);
}
}
@Override
@@ -167,16 +145,7 @@ public class PyControlFlowBuilder extends PyRecursiveElementVisitor {
@Override
public void visitPyCallExpression(final @NotNull PyCallExpression node) {
final PyExpression callee = node.getCallee();
final var callNodeType = getCalleeNodeType(callee);
if (callNodeType instanceof TypeGuardCallKind typeGuardCallKind && node.getArguments().length > 0) {
expressionToGuards.put(node, typeGuardCallKind.pyFunction);
super.visitPyCallExpression(node);
}
else {
super.visitPyCallExpression(node);
}
super.visitPyCallExpression(node);
var callInstruction = new CallInstruction(myBuilder, node);
myBuilder.addNodeAndCheckPending(callInstruction);
@@ -583,12 +552,10 @@ public class PyControlFlowBuilder extends PyRecursiveElementVisitor {
final var outside = new ConditionalInstructionImpl(myBuilder, null, subExpression, !conditionResultToContinue);
myBuilder.addNode(outside);
addTypeGuardAssertions(subExpression, !conditionResultToContinue);
myBuilder.addPendingEdge(node, myBuilder.prevInstruction);
myBuilder.prevInstruction = branchingPoint;
final var toTheNext = new ConditionalInstructionImpl(myBuilder, null, subExpression, conditionResultToContinue);
addTypeGuardAssertions(subExpression, conditionResultToContinue);
myBuilder.addNode(toTheNext);
}
@@ -619,7 +586,6 @@ public class PyControlFlowBuilder extends PyRecursiveElementVisitor {
final var elsePartInstruction = new ConditionalInstructionImpl(myBuilder, elsePart, mainPartResults.getFirst(), false);
myBuilder.prevInstruction = null;
myBuilder.addNode(elsePartInstruction);
addTypeGuardAssertions(mainPartResults.getFirst(), false);
if (!isStaticallyTrue) {
for (Pair<PsiElement, Instruction> pair : branchingPoints) {
@@ -1051,40 +1017,6 @@ public class PyControlFlowBuilder extends PyRecursiveElementVisitor {
myBuilder.checkPending(instruction);
}
@Nullable
private static CallTypeKind getCalleeNodeType(@Nullable PyExpression callee) {
if (callee instanceof PyReferenceExpression expression) {
QualifiedName qName = expression.asQualifiedName();
if (qName == null) {
return null;
}
ScopeOwner scopeOwner = ScopeUtil.getScopeOwner(expression);
// Flow-insensitive context is required to prevent recursive control flow access during the resolve process
TypeEvalContext context = TypeEvalContext.codeInsightFallback(callee.getProject());
while (scopeOwner != null) {
final var result = StreamEx
.of(PyResolveUtil.resolveQualifiedNameInScope(qName, scopeOwner, context))
.select(PyFunction.class)
.map(function -> {
//if (PyTypingTypeProvider.isNoReturn(function, context)) {
// return NoReturnCallKind.INSTANCE;
//}
if (PyTypingTypeProvider.isTypeGuard(function, context)) {
return new TypeGuardCallKind(function);
}
return null;
}).findFirst( it -> it != null);
if (result.isPresent()) return result.get();
scopeOwner = ScopeUtil.getScopeOwner(scopeOwner);
}
}
return null;
}
private void abruptFlow(final PsiElement node) {
// Here we process pending instructions!!!
myBuilder.processPending((pendingScope, instruction) -> {
@@ -1106,9 +1038,5 @@ public class PyControlFlowBuilder extends PyRecursiveElementVisitor {
return !PsiTreeUtil.instanceOf(instruction.getElement(),
PyStatementList.class);
}
private interface CallTypeKind { }
private record TypeGuardCallKind(@NotNull PyFunction pyFunction) implements CallTypeKind {}
}

View File

@@ -6,13 +6,12 @@ import com.intellij.psi.PsiElement;
import com.intellij.util.containers.Stack;
import com.jetbrains.python.PyNames;
import com.jetbrains.python.PyTokenTypes;
import com.jetbrains.python.codeInsight.functionTypeComments.psi.PyFunctionTypeAnnotation;
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.impl.PyBuiltinCache;
import com.jetbrains.python.psi.impl.PyEvaluator;
import com.jetbrains.python.psi.impl.PyPsiUtils;
import com.jetbrains.python.psi.types.*;
import org.jetbrains.annotations.ApiStatus;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
@@ -48,29 +47,6 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
}
}
public void handleTypeGuardCall(@NotNull PyCallExpression call, @NotNull PyFunction function) {
if (call.getArguments().length == 0) return;
final var firstArgument = call.getArguments()[0];
final var annotation = function.getAnnotationValue();
if (annotation == null) return;
if (firstArgument instanceof PyReferenceExpression referenceExpression) {
pushAssertion(referenceExpression, myPositive, false, (context) -> {
var returnType = PyTypingTypeProvider.getReturnTypeAnnotation(function, context);
if (returnType instanceof PyStringLiteralExpression stringLiteralExpression) {
returnType = PyUtil.createExpressionFromFragment(stringLiteralExpression.getStringValue(),
function.getContainingFile());
}
if (returnType instanceof PySubscriptionExpression subscriptionExpression) {
var indexExpression = subscriptionExpression.getIndexExpression();
if (indexExpression != null) {
return Ref.deref(PyTypingTypeProvider.getType(indexExpression, context));
}
}
return null;
}, null);
}
}
@Override
public void visitPyCallExpression(@NotNull PyCallExpression node) {
if (node.isCalleeText(PyNames.ISINSTANCE, PyNames.ASSERT_IS_INSTANCE)) {
@@ -155,7 +131,8 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
}
@Nullable
private static Ref<PyType> createAssertionType(@Nullable PyType initial,
@ApiStatus.Internal
public static Ref<PyType> createAssertionType(@Nullable PyType initial,
@Nullable PyType suggested,
boolean positive,
boolean transformToDefinition,

View File

@@ -19,6 +19,7 @@ import com.jetbrains.python.PyCustomType;
import com.jetbrains.python.PyNames;
import com.jetbrains.python.PyTokenTypes;
import com.jetbrains.python.ast.PyAstFunction;
import com.jetbrains.python.ast.impl.PyPsiUtilsCore;
import com.jetbrains.python.ast.impl.PyUtilCore;
import com.jetbrains.python.codeInsight.controlflow.ControlFlowCache;
import com.jetbrains.python.codeInsight.controlflow.ScopeOwner;
@@ -394,6 +395,20 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
.orElse(null);
}
if (callSite instanceof PyCallExpression callExpression && isTypeGuard(function, context.myContext)) {
var arguments = callSite.getArguments(function);
if (!arguments.isEmpty() && arguments.get(0) instanceof PyReferenceExpression refExpr) {
var qname = PyPsiUtilsCore.asQualifiedName(refExpr);
if (qname != null) {
var narrowedType = getTypeFromTypeGuardLikeType(function, context.myContext);
if (narrowedType != null) {
return Ref.create(PyNarrowedType.Companion.create(callSite, qname.toString(), narrowedType, callExpression, false));
}
}
}
return Ref.create(PyBuiltinCache.getInstance(function).getBoolType());
}
if (callSite instanceof PyCallExpression) {
final LanguageLevel level = "open".equals(functionQName)
? LanguageLevel.forElement(callSite)
@@ -1214,6 +1229,22 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
typeHintedWithName(function, context, TYPE_GUARD, TYPE_GUARD_EXT));
}
@Nullable
public static PyType getTypeFromTypeGuardLikeType(@NotNull PyFunction function, @NotNull TypeEvalContext context) {
var returnType = getReturnTypeAnnotation(function, context);
if (returnType instanceof PyStringLiteralExpression stringLiteralExpression) {
returnType = PyUtil.createExpressionFromFragment(stringLiteralExpression.getStringValue(),
function.getContainingFile());
}
if (returnType instanceof PySubscriptionExpression subscriptionExpression) {
var indexExpression = subscriptionExpression.getIndexExpression();
if (indexExpression != null) {
return Ref.deref(getType(indexExpression, context));
}
}
return null;
}
private static boolean resolvesToQualifiedNames(@NotNull PyExpression expression, @NotNull TypeEvalContext context, String... names) {
final var qualifiedNames = resolveToQualifiedNames(expression, context);
return ContainerUtil.exists(names, qualifiedNames::contains);

View File

@@ -44,6 +44,12 @@ public class PyPrefixExpressionImpl extends PyElementImpl implements PyPrefixExp
@Override
public PyType getType(@NotNull TypeEvalContext context, @NotNull TypeEvalContext.Key key) {
if (getOperator() == PyTokenTypes.NOT_KEYWORD) {
final PyExpression operand = getOperand();
if (operand != null) {
final PyType operandType = context.getType(operand);
return (operandType instanceof PyNarrowedType) ? ((PyNarrowedType)operandType).negate()
: PyBuiltinCache.getInstance(this).getBoolType();
}
return PyBuiltinCache.getInstance(this).getBoolType();
}
final boolean isAwait = getOperator() == PyTokenTypes.AWAIT_KEYWORD;

View File

@@ -1,6 +1,7 @@
// Copyright 2000-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.jetbrains.python.psi.impl;
import com.intellij.codeInsight.controlflow.ConditionalInstruction;
import com.intellij.codeInsight.controlflow.Instruction;
import com.intellij.diagnostic.PluginException;
import com.intellij.lang.ASTNode;
@@ -13,6 +14,7 @@ import com.intellij.psi.util.QualifiedName;
import com.intellij.util.containers.ContainerUtil;
import com.jetbrains.python.PyNames;
import com.jetbrains.python.PythonRuntimeService;
import com.jetbrains.python.codeInsight.controlflow.PyTypeAssertionEvaluator;
import com.jetbrains.python.codeInsight.controlflow.ReadWriteInstruction;
import com.jetbrains.python.codeInsight.controlflow.ScopeOwner;
import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil;
@@ -415,7 +417,8 @@ public class PyReferenceExpressionImpl extends PyElementImpl implements PyRefere
final String name = ((PyElement)target).getName();
if (scopeOwner != null &&
name != null &&
!ScopeUtil.getElementsOfAccessType(name, scopeOwner, ReadWriteInstruction.ACCESS.ASSERTTYPE).isEmpty()) {
(!ScopeUtil.getElementsOfAccessType(name, scopeOwner, ReadWriteInstruction.ACCESS.ASSERTTYPE).isEmpty()
|| target instanceof PyTargetExpression || target instanceof PyNamedParameter)) {
final PyType type = getTypeByControlFlow(name, context, anchor, scopeOwner);
if (type != null) {
return type;
@@ -480,8 +483,30 @@ public class PyReferenceExpressionImpl extends PyElementImpl implements PyRefere
final List<Instruction> defs = PyDefUseUtil.getLatestDefs(scopeOwner, name, element, true, false, context);
// null means empty set of possible types, Ref(null) means Any
final @Nullable Ref<PyType> combinedType = StreamEx.of(defs)
.select(ReadWriteInstruction.class)
.map(instr -> instr.getType(context, anchor))
.map(instr -> {
if (instr instanceof ReadWriteInstruction readWriteInstruction) {
return readWriteInstruction.getType(context, anchor);
}
if (instr instanceof ConditionalInstruction conditionalInstruction) {
if (context.getType((PyTypedElement)conditionalInstruction.getCondition()) instanceof PyNarrowedType narrowedType) {
PyExpression[] arguments = narrowedType.getOriginal().getArguments();
if (arguments.length > 0) {
var firstArgument = arguments[0];
if (firstArgument instanceof PyReferenceExpression) {
return PyTypeAssertionEvaluator.createAssertionType(
context.getType(firstArgument),
narrowedType.getNarrowedType(),
conditionalInstruction.getResult() ^ narrowedType.getNegated(),
false,
context,
null);
}
}
}
}
return null;
})
.nonNull()
.collect(PyTypeUtil.toUnionFromRef());
return Ref.deref(combinedType);
}

View File

@@ -0,0 +1,31 @@
package com.jetbrains.python.psi.types
import com.jetbrains.python.psi.PyCallExpression
import com.jetbrains.python.psi.PyClass
import com.jetbrains.python.psi.PyElement
import com.jetbrains.python.psi.impl.PyBuiltinCache
/**
* Class is used for representing TypeGuard and TypeIs behavior
*/
class PyNarrowedType private constructor(
pyClass: PyClass,
val qname: String,
val narrowedType: PyType,
val original: PyCallExpression,
// used
val negated: Boolean)
: PyClassTypeImpl(pyClass, false) {
fun negate(): PyNarrowedType {
return PyNarrowedType(pyClass, qname, narrowedType, original, !negated)
}
companion object {
fun create(anchor: PyElement, name: String, narrowedType: PyType, original: PyCallExpression, negated: Boolean = false): PyNarrowedType? {
val pyClass = PyBuiltinCache.getInstance(anchor).getClass("bool")
if (pyClass == null) return null
return PyNarrowedType(pyClass, name, narrowedType, original, negated)
}
}
}

View File

@@ -15,6 +15,7 @@
*/
package com.jetbrains.python.refactoring;
import com.intellij.codeInsight.controlflow.ConditionalInstruction;
import com.intellij.codeInsight.controlflow.ControlFlow;
import com.intellij.codeInsight.controlflow.ControlFlowUtil;
import com.intellij.codeInsight.controlflow.Instruction;
@@ -29,6 +30,7 @@ import com.jetbrains.python.codeInsight.controlflow.ScopeOwner;
import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.impl.PyAugAssignmentStatementNavigator;
import com.jetbrains.python.psi.types.PyNarrowedType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
@@ -71,17 +73,41 @@ public final class PyDefUseUtil {
return new ArrayList<>(result);
}
private static Collection<Instruction> getLatestDefs(final String varName, final Instruction[] instructions, final int instr,
private static Collection<Instruction> getLatestDefs(final String varName, final Instruction[] instructions, final int startNum,
final boolean acceptTypeAssertions, final boolean acceptImplicitImports,
@NotNull final TypeEvalContext context) {
final Collection<Instruction> result = new LinkedHashSet<>();
ControlFlowUtil.iteratePrev(instr, instructions,
final HashMap<PyCallSiteExpression, ConditionalInstruction> pendingTypeGuard = new HashMap<>();
ControlFlowUtil.iteratePrev(startNum, instructions,
instruction -> {
if (instruction.num() < instructions[instr].num() && instruction instanceof CallInstruction callInstruction) {
if (callInstruction.isNoReturnCall(context)) return ControlFlowUtil.Operation.CONTINUE;
if (instruction instanceof CallInstruction callInstruction) {
var typeGuardInstruction = pendingTypeGuard.get(instruction.getElement());
if (acceptTypeAssertions && typeGuardInstruction != null) {
result.add(typeGuardInstruction);
return ControlFlowUtil.Operation.CONTINUE;
}
// not a back edge
if (instruction.num() < startNum && context.getOrigin() != null) {
// switch back to code analysis, since all other analyses are too aggressive
var newContext = TypeEvalContext.codeAnalysis(context.getOrigin().getProject(), context.getOrigin());
if (callInstruction.isNoReturnCall(newContext)) return ControlFlowUtil.Operation.CONTINUE;
}
}
final PsiElement element = instruction.getElement();
final PyImplicitImportNameDefiner implicit = PyUtil.as(element, PyImplicitImportNameDefiner.class);
if (acceptTypeAssertions
&& instruction instanceof ConditionalInstruction conditionalInstruction
&& instruction.num() < startNum) {
if (conditionalInstruction.getCondition() instanceof PyTypedElement typedElement && context.getOrigin() != null) {
// switch back to code analysis, since all other analyses are too aggressive
TypeEvalContext newContext = TypeEvalContext.codeAnalysis(context.getOrigin().getProject(), context.getOrigin());
if (newContext.getType(typedElement) instanceof PyNarrowedType narrowedType) {
if (narrowedType.getQname().equals(varName)) {
pendingTypeGuard.put(narrowedType.getOriginal(), conditionalInstruction);
}
}
}
}
if (instruction instanceof ReadWriteInstruction rwInstruction) {
final ReadWriteInstruction.ACCESS access = rwInstruction.getAccess();
if (access.isWriteAccess() || acceptTypeAssertions && access.isAssertTypeAccess()) {