diff --git a/java/java-impl/src/com/intellij/refactoring/extractMethod/ExtractMethodProcessor.java b/java/java-impl/src/com/intellij/refactoring/extractMethod/ExtractMethodProcessor.java index b24e9999d96a..99fc00ccf72f 100644 --- a/java/java-impl/src/com/intellij/refactoring/extractMethod/ExtractMethodProcessor.java +++ b/java/java-impl/src/com/intellij/refactoring/extractMethod/ExtractMethodProcessor.java @@ -27,6 +27,7 @@ import com.intellij.codeInsight.intention.impl.AddNullableNotNullAnnotationFix; import com.intellij.codeInsight.navigation.NavigationUtil; import com.intellij.codeInspection.dataFlow.*; import com.intellij.codeInspection.dataFlow.instructions.BranchingInstruction; +import com.intellij.codeInspection.dataFlow.instructions.CheckReturnValueInstruction; import com.intellij.codeInspection.dataFlow.instructions.Instruction; import com.intellij.ide.DataManager; import com.intellij.ide.util.PropertiesComponent; @@ -75,6 +76,7 @@ import com.intellij.util.IncorrectOperationException; import com.intellij.util.VisibilityUtil; import com.intellij.util.containers.ContainerUtil; import com.intellij.util.containers.MultiMap; +import one.util.streamex.StreamEx; import org.jetbrains.annotations.NonNls; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; @@ -245,7 +247,7 @@ public class ExtractMethodProcessor implements MatchProvider { } catch (ControlFlowWrapper.ExitStatementsNotSameException e) { myExitStatements = myControlFlowWrapper.getExitStatements(); - myNotNullConditionalCheck = areAllExitPointsAreNotNull(getExpectedReturnType()); + myNotNullConditionalCheck = areAllExitPointsNotNull(getExpectedReturnType()); if (!myNotNullConditionalCheck) { showMultipleExitPointsMessage(); return false; @@ -280,17 +282,9 @@ public class ExtractMethodProcessor implements MatchProvider { if (myGenerateConditionalExit && myOutputVariables.length == 1) { if (!(myOutputVariables[0].getType() instanceof PsiPrimitiveType)) { - myNullConditionalCheck = true; - for (PsiStatement exitStatement : myExitStatements) { - if (exitStatement instanceof PsiReturnStatement) { - final PsiExpression returnValue = ((PsiReturnStatement)exitStatement).getReturnValue(); - myNullConditionalCheck &= returnValue == null || isNullInferred(returnValue.getText(), true); - } - } - myNullConditionalCheck &= isNullInferred(myOutputVariables[0].getName(), false); + myNullConditionalCheck = isNullInferred(myOutputVariables[0].getName()) && getReturnsNullability(true); } - - myNotNullConditionalCheck = areAllExitPointsAreNotNull(returnStatementType); + myNotNullConditionalCheck = areAllExitPointsNotNull(returnStatementType); } if (!myHasReturnStatementOutput && checkOutputVariablesCount() && !myNullConditionalCheck && !myNotNullConditionalCheck) { @@ -385,16 +379,73 @@ public class ExtractMethodProcessor implements MatchProvider { return null; } - private boolean areAllExitPointsAreNotNull(PsiType returnStatementType) { + private boolean areAllExitPointsNotNull(PsiType returnStatementType) { if (insertNotNullCheckIfPossible() && myControlFlowWrapper.getOutputVariables(false).length == 0) { - boolean isNotNull = returnStatementType != null && !PsiType.VOID.equals(returnStatementType); - for (PsiStatement statement : myExitStatements) { - if (statement instanceof PsiReturnStatement) { - final PsiExpression returnValue = ((PsiReturnStatement)statement).getReturnValue(); - isNotNull &= returnValue != null && !isNullInferred(returnValue.getText(), true); + if (returnStatementType != null && !PsiType.VOID.equals(returnStatementType)) { + return getReturnsNullability(false); + } + } + return false; + } + + /** + * @param nullsExpected when true check that all returned values are null, when false check that all returned values can't be null + */ + private boolean getReturnsNullability(boolean nullsExpected) { + PsiElement body = null; + if (myCodeFragmentMember instanceof PsiMethod) { + body = ((PsiMethod)myCodeFragmentMember).getBody(); + } + else if (myCodeFragmentMember instanceof PsiLambdaExpression) { + body = ((PsiLambdaExpression)myCodeFragmentMember).getBody(); + } + if (body == null) return false; + + Set returnedExpressions = StreamEx.of(myExitStatements) + .select(PsiReturnStatement.class) + .map(PsiReturnStatement::getReturnValue) + .nonNull() + .toSet(); + + for (Iterator it = returnedExpressions.iterator(); it.hasNext(); ) { + PsiType type = it.next().getType(); + if (nullsExpected) { + if (type == PsiType.NULL) { + it.remove(); // don't need to check + } + else if (type instanceof PsiPrimitiveType) { + return false; } } - return isNotNull; + else { + if (type == PsiType.NULL) { + return false; + } + else if (type instanceof PsiPrimitiveType) { + it.remove(); // don't need to check + } + } + } + if (returnedExpressions.isEmpty()) return true; + + class ReturnChecker extends StandardInstructionVisitor { + boolean myResult = true; + + @Override + public DfaInstructionState[] visitCheckReturnValue(CheckReturnValueInstruction instruction, + DataFlowRunner runner, + DfaMemoryState memState) { + PsiElement aReturn = instruction.getReturn(); + if (aReturn instanceof PsiExpression && returnedExpressions.contains(aReturn)) { + myResult &= nullsExpected ? memState.isNull(memState.peek()) : memState.isNotNull(memState.peek()); + } + return super.visitCheckReturnValue(instruction, runner, memState); + } + } + final StandardDataFlowRunner dfaRunner = new StandardDataFlowRunner(); + final ReturnChecker returnChecker = new ReturnChecker(); + if (dfaRunner.analyzeMethod(body, returnChecker) == RunnerResult.OK) { + return returnChecker.myResult; } return false; } @@ -403,7 +454,7 @@ public class ExtractMethodProcessor implements MatchProvider { return true; } - private boolean isNullInferred(String exprText, boolean trueSet) { + private boolean isNullInferred(String exprText) { final PsiCodeBlock block = myElementFactory.createCodeBlockFromText("{}", myElements[0]); for (PsiElement element : myElements) { block.add(element); @@ -416,7 +467,7 @@ public class ExtractMethodProcessor implements MatchProvider { final RunnerResult rc = dfaRunner.analyzeMethod(block, visitor); if (rc == RunnerResult.OK) { final Pair, Set> expressions = dfaRunner.getConstConditionalExpressions(); - final Set set = trueSet ? expressions.getFirst() : expressions.getSecond(); + final Set set = expressions.getSecond(); for (Instruction instruction : set) { if (instruction instanceof BranchingInstruction) { if (((BranchingInstruction)instruction).getPsiAnchor().getText().equals(statementFromText.getCondition().getText())) { diff --git a/java/java-tests/testData/refactoring/extractMethod/ExitPoints10.java b/java/java-tests/testData/refactoring/extractMethod/ExitPoints10.java new file mode 100644 index 000000000000..93f0474872a2 --- /dev/null +++ b/java/java-tests/testData/refactoring/extractMethod/ExitPoints10.java @@ -0,0 +1,16 @@ +class C { + private int[] list; + + private Integer find(int id) { + + int n = 0; + for (int n1 : list) { + n = n1; + if (n == id) { + return n <= 0 ? null : n; + } + } + + throw new RuntimeException(); + } +} \ No newline at end of file diff --git a/java/java-tests/testData/refactoring/extractMethod/ExitPoints11.java b/java/java-tests/testData/refactoring/extractMethod/ExitPoints11.java new file mode 100644 index 000000000000..e2f3f809fa4b --- /dev/null +++ b/java/java-tests/testData/refactoring/extractMethod/ExitPoints11.java @@ -0,0 +1,12 @@ +class C { + private int[] list; + + private int find(int id) { + for (int n : list) { + if (n == id) { + return n <= 0 ? 0 : n; + } + } + throw new RuntimeException(); + } +} \ No newline at end of file diff --git a/java/java-tests/testData/refactoring/extractMethod/ExitPoints11_after.java b/java/java-tests/testData/refactoring/extractMethod/ExitPoints11_after.java new file mode 100644 index 000000000000..d4bcf6c334b1 --- /dev/null +++ b/java/java-tests/testData/refactoring/extractMethod/ExitPoints11_after.java @@ -0,0 +1,21 @@ +import org.jetbrains.annotations.Nullable; + +class C { + private int[] list; + + private int find(int id) { + Integer n = newMethod(id); + if (n != null) return n; + throw new RuntimeException(); + } + + @Nullable + private Integer newMethod(int id) { + for (int n : list) { + if (n == id) { + return n <= 0 ? 0 : n; + } + } + return null; + } +} \ No newline at end of file diff --git a/java/java-tests/testSrc/com/intellij/java/refactoring/ExtractMethodTest.java b/java/java-tests/testSrc/com/intellij/java/refactoring/ExtractMethodTest.java index 7663f495bc71..609292caf8c0 100644 --- a/java/java-tests/testSrc/com/intellij/java/refactoring/ExtractMethodTest.java +++ b/java/java-tests/testSrc/com/intellij/java/refactoring/ExtractMethodTest.java @@ -95,6 +95,14 @@ public class ExtractMethodTest extends LightCodeInsightTestCase { doTest(); } + public void testExitPoints10() throws Exception { + doExitPointsTest(false); + } + + public void testExitPoints11() throws Exception { + doTest(); + } + public void testNotNullCheckNameConflicts() throws Exception { doTest(); }