inference incorporate optimization: avoid eq bounds propagation (IDEA-149952; IDEA-144822)

This commit is contained in:
Anna Kozlova
2016-01-19 20:38:41 +01:00
parent 5a0c674952
commit a3c0c965fb
5 changed files with 298 additions and 54 deletions

View File

@@ -218,54 +218,24 @@ public class InferenceIncorporationPhase {
if (inferenceVariable.getInstantiation() != PsiType.NULL) continue;
Map<InferenceBound, Set<PsiType>> boundsMap = myCurrentBounds.remove(inferenceVariable);
if (boundsMap == null) continue;
final Collection<PsiType> eqBounds = boundsMap.get(InferenceBound.EQ);
final List<PsiType> upperBounds = inferenceVariable.getBounds(InferenceBound.UPPER);
final List<PsiType> lowerBounds = inferenceVariable.getBounds(InferenceBound.LOWER);
needFurtherIncorporation |= crossVariables(inferenceVariable, upperBounds, lowerBounds, InferenceBound.LOWER);
needFurtherIncorporation |= crossVariables(inferenceVariable, lowerBounds, upperBounds, InferenceBound.UPPER);
if (eqBounds != null) {
needFurtherIncorporation |= eqCrossVariables(inferenceVariable, eqBounds);
final Set<PsiType> upperBounds = boundsMap.get(InferenceBound.UPPER);
final Set<PsiType> lowerBounds = boundsMap.get(InferenceBound.LOWER);
if (upperBounds != null) {
needFurtherIncorporation |= crossVariables(inferenceVariable, upperBounds, lowerBounds, InferenceBound.LOWER);
}
if (lowerBounds != null) {
needFurtherIncorporation |= crossVariables(inferenceVariable, lowerBounds, upperBounds, InferenceBound.UPPER);
}
}
return !needFurtherIncorporation;
}
/**
* a = b imply every bound of a matches a bound of b and vice versa
*/
private boolean eqCrossVariables(InferenceVariable inferenceVariable, Collection<PsiType> eqBounds) {
boolean needFurtherIncorporation = false;
for (PsiType eqBound : eqBounds) {
final InferenceVariable inferenceVar = mySession.getInferenceVariable(eqBound);
if (inferenceVar != null) {
for (InferenceBound inferenceBound : InferenceBound.values()) {
final List<PsiType> oldVarBounds = inferenceVar.getReadOnlyBounds(inferenceBound);
final List<PsiType> oldVariableBounds = inferenceVariable.getReadOnlyBounds(inferenceBound);
for (PsiType bound : oldVariableBounds) {
if (mySession.getInferenceVariable(bound) != inferenceVar) {
needFurtherIncorporation |= inferenceVar.addBound(bound, inferenceBound, this);
}
}
for (PsiType bound : oldVarBounds) {
if (mySession.getInferenceVariable(bound) != inferenceVariable) {
needFurtherIncorporation |= inferenceVariable.addBound(bound, inferenceBound, this);
}
}
}
}
}
return needFurtherIncorporation;
}
/**
* a < b & S <: a & b <: T imply S <: b & a <: T
*/
private boolean crossVariables(InferenceVariable inferenceVariable,
List<PsiType> upperBounds,
List<PsiType> lowerBounds,
Collection<PsiType> upperBounds,
Collection<PsiType> lowerBounds,
InferenceBound inferenceBound) {
final InferenceBound oppositeBound = inferenceBound == InferenceBound.LOWER
@@ -276,8 +246,10 @@ public class InferenceIncorporationPhase {
final InferenceVariable inferenceVar = mySession.getInferenceVariable(upperBound);
if (inferenceVar != null && inferenceVariable != inferenceVar) {
for (PsiType lowerBound : lowerBounds) {
result |= inferenceVar.addBound(lowerBound, inferenceBound, this);
if (lowerBounds != null) {
for (PsiType lowerBound : lowerBounds) {
result |= inferenceVar.addBound(lowerBound, inferenceBound, this);
}
}
for (PsiType varUpperBound : inferenceVar.getBounds(oppositeBound)) {

View File

@@ -0,0 +1,167 @@
import java.util.*;
import java.util.stream.*;
interface Graph<T> extends Iterable<T> {
double distance(T from, T to);
T startVertex();
default Stream<T> stream() {
return StreamSupport.stream(spliterator(), false);
}
}
class GeometricGraph<T extends Comparable<T>> implements Graph<GeometricGraph.Point<T>> {
private Point<T>[] vertices;
@SafeVarargs
public GeometricGraph(Point<T>... vertices) {
this.vertices = vertices;
}
@Override public double distance(Point<T> from, Point<T> to) {
double dx = to.x - from.x;
double dy = to.y - from.y;
return Math.sqrt(dx * dx + dy * dy);
}
@Override public Point<T> startVertex() {
return vertices[0];
}
@Override public Iterator<Point<T>> iterator() {
return Arrays.asList(vertices).iterator();
}
public static class Point<T extends Comparable<T>> implements Comparable<Point<T>> {
final T label;
final double x;
final double y;
public Point(T label, double x, double y) {
this.label = label;
this.x = x;
this.y = y;
}
@Override public String toString() {
return String.valueOf(label);
}
@Override public int compareTo(Point<T> o) {
return label.compareTo(o.label);
}
}
}
abstract class IndexedGraph<T> implements Graph<T> {
final double[][] distance;
public IndexedGraph(double[][] distance) {
this.distance = distance;
}
@Override public T startVertex() {
return getVertex(0);
}
@Override public double distance(T from, T to) {
return distance[getIndex(from)][getIndex(to)];
}
@Override public Iterator<T> iterator() {
return IntStream.range(0, getVertexCount()).mapToObj(this::getVertex).iterator();
}
private int getVertexCount() {
return distance.length;
}
protected abstract T getVertex(int index);
protected abstract int getIndex(T vertex);
}
class CharacterGraph extends IndexedGraph<Character> {
private final char base;
public CharacterGraph(double[][] distance, char base) {
super(distance);
this.base = base;
}
@Override protected Character getVertex(int index) {
return (char)(base + index);
}
@Override protected int getIndex(Character vertex) {
return vertex - base;
}
}
interface Crash<T> {
double INF = Double.POSITIVE_INFINITY;
void run(Graph<T> graph);
// From video, solution: 0->1->3->2->0 costing 21
CharacterGraph VIDEO = new CharacterGraph(new double[][] {
{0, 1, 15, 6},
{2, 0, 7, 3},
{9, 6, 0, 12},
{10, 4, 8, 0},
}, '0');
// From comments, solution: A->C->E->D->B->A costing 21
CharacterGraph COMMENT = new CharacterGraph(new double[][] {
{0, 3, 3, 1, INF},
{3, 0, 8, 5, INF},
{3, 8, 0, 1, 6},
{1, 5, 1, 0, 4},
{INF, INF, 6, 4, 0},
}, 'A');
// From http://lcm.csa.iisc.ernet.in/dsa/node186.html, solution a->c->d->e->f->b->a 48.39
GeometricGraph<Character> CARTESIAN = new GeometricGraph<>(
new GeometricGraph.Point<>('a', 0, 0),
new GeometricGraph.Point<>('b', 4, 3),
new GeometricGraph.Point<>('c', 1, 7),
new GeometricGraph.Point<>('d', 15, 7),
new GeometricGraph.Point<>('e', 15, 4),
new GeometricGraph.Point<>('f', 18, 0)
);
// http://www.geomidpoint.com/random/ whole world 100 points
GeometricGraph<Integer> GEO = new GeometricGraph<>(
new GeometricGraph.Point<>(1, 41.75887603, 45.54442576),
new GeometricGraph.Point<>(2, 25.95582633, 53.31372621),
new GeometricGraph.Point<>(3, 27.10968149, -148.3088281),
new GeometricGraph.Point<>(4, 47.89312627, 63.62800849),
new GeometricGraph.Point<>(5, -39.63985521, 22.72245952),
new GeometricGraph.Point<>(6, 10.75270177, 172.5620158),
new GeometricGraph.Point<>(7, -5.22786075, 174.0175703),
new GeometricGraph.Point<>(8, 6.09021552, -174.842083),
new GeometricGraph.Point<>(9, -41.93929433, -151.4679823),
new GeometricGraph.Point<>(10, 23.76929542, 52.02191021),
new GeometricGraph.Point<>(11, -27.07564288, -65.97458804),
new GeometricGraph.Point<>(12, -44.69115169, 3.9545051),
new GeometricGraph.Point<>(13, -5.43915001, -67.03701528),
new GeometricGraph.Point<>(14, -46.80168575, 167.7479893),
new GeometricGraph.Point<>(15, 3.37026877, -112.5740888),
new GeometricGraph.Point<>(16, 72.28180933, -29.27517743),
new GeometricGraph.Point<>(17, -42.08042944, -45.20059984),
new GeometricGraph.Point<>(18, -5.94878325, 65.81227912),
new GeometricGraph.Point<>(19, 0.82655482, -137.5756048),
new GeometricGraph.Point<>(20, -1.89649258, -34.85895025),
new GeometricGraph.Point<>(21, -8.41830692, 72.91705955),
new GeometricGraph.Point<>(22, -67.12475398, -30.98024614),
new GeometricGraph.Point<>(23, 3.27537627, -101.5926056),
new GeometricGraph.Point<>(24, 28.12320278, 171.0993409),
new GeometricGraph.Point<>(25, 43.81836686, 153.6367713),
new GeometricGraph.Point<>(26, -30.26453996, 125.4817181),
new GeometricGraph.Point<>(27, 30.42399561, 140.6854059),
new GeometricGraph.Point<>(28, 51.15497569, -118.603574),
new GeometricGraph.Point<>(29, -26.11317488, 165.3413163),
new GeometricGraph.Point<>(30, 17.3884151, 109.0310505),
new GeometricGraph.Point<>(31, -53.28586665, 113.3310133),
new GeometricGraph.Point<>(32, -36.91984178, 17.53340885),
new GeometricGraph.Point<>(33, -49.45998685, 111.9311892),
new GeometricGraph.Point<>(34, -63.1554812, 79.70629564),
new GeometricGraph.Point<>(35, 28.82084009, -9.14338737),
new GeometricGraph.Point<>(36, 37.52058234, -0.32285569),
new GeometricGraph.Point<>(37, 23.58437569, -138.7499972),
new GeometricGraph.Point<>(38, -28.07522086, -175.3760246),
new GeometricGraph.Point<>(39, -63.57013678, -100.3303656),
new GeometricGraph.Point<>(40, 16.2360492, -7.04890614),
new GeometricGraph.Point<>(41, 32.50586034, -93.26947618),
new GeometricGraph.Point<>(42, 0.37760791, 114.3663184),
new GeometricGraph.Point<>(43, -54.95460861, 173.9221499),
new GeometricGraph.Point<>(44, -62.88777314, 11.02357861),
new GeometricGraph.Point<>(45, -0.39552891, -10.24023055),
new GeometricGraph.Point<>(46, -32.82228853, 2.49278472),
new GeometricGraph.Point<>(47, -21.93177958, 104.425205),
new GeometricGraph.Point<>(48, 40.66726414, 1.38813168),
new GeometricGraph.Point<>(49, -10.17461981, -147.9987545),
new GeometricGraph.Point<>(50, -14.10034262, 115.3193397),
new GeometricGraph.Point<>(51, -60.18635059, -77.7990411)
);
}

View File

@@ -0,0 +1,71 @@
class GeometricGraph<T extends Comparable<T>> {
public static <K extends Comparable<K>> GeometricGraph<K> createGraph(Point<K>... points) {
return null;
}
}
class Point<T1 extends Comparable<T1>> implements Comparable<Point<T1>> {
@Override public int compareTo(Point<T1> o) {
return 0;
}
static <M extends Comparable<M>> Point<M> create(M m) {
return null;
}
}
class Graph {
GeometricGraph<Integer> GEO = GeometricGraph.createGraph(
Point.create(381),
Point.create(49),
Point.create(73),
Point.create(16),
Point.create(21),
Point.create(381),
Point.create(49),
Point.create(381),
Point.create(49),
Point.create(73),
Point.create(16),
Point.create(21),
Point.create(381),
Point.create(49),
Point.create(381),
Point.create(49),
Point.create(73),
Point.create(16),
Point.create(21),
Point.create(381),
Point.create(49),
Point.create(381),
Point.create(49),
Point.create(73),
Point.create(16),
Point.create(21),
Point.create(381),
Point.create(49),
Point.create(381),
Point.create(49),
Point.create(73),
Point.create(16),
Point.create(21),
Point.create(381),
Point.create(49),
Point.create(381),
Point.create(49),
Point.create(73),
Point.create(16),
Point.create(21),
Point.create(381),
Point.create(49),
Point.create(381),
Point.create(49),
Point.create(73),
Point.create(16),
Point.create(21),
Point.create(381),
Point.create(49),
Point.create(381)
);
}

View File

@@ -0,0 +1,45 @@
/*
* Copyright 2000-2016 JetBrains s.r.o.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.intellij.codeInsight.daemon.lambda;
import com.intellij.codeInsight.daemon.LightDaemonAnalyzerTestCase;
import com.intellij.openapi.projectRoots.JavaSdkVersion;
import com.intellij.openapi.projectRoots.Sdk;
import com.intellij.testFramework.IdeaTestUtil;
import com.intellij.testFramework.PlatformTestUtil;
import org.jetbrains.annotations.NonNls;
public class InferencePerformanceTest extends LightDaemonAnalyzerTestCase {
@NonNls static final String BASE_PATH = "/codeInsight/daemonCodeAnalyzer/lambda/performance";
public void testPolyMethodCallArgumentPassedToVarargs() throws Exception {
PlatformTestUtil.startPerformanceTest("50 poly method calls passed to Arrays.asList", 10000, this::doTest).useLegacyScaling().assertTiming();
}
public void testDiamondConstructorCallPassedToVarargs() throws Exception {
PlatformTestUtil.startPerformanceTest("50 diamond constructor calls passed to Arrays.asList", 10000, this::doTest).useLegacyScaling().assertTiming();
}
private void doTest() {
IdeaTestUtil.setTestVersion(JavaSdkVersion.JDK_1_8, getModule(), getTestRootDisposable());
doTest(BASE_PATH + "/" + getTestName(false) + ".java", false, false);
}
@Override
protected Sdk getProjectJDK() {
return IdeaTestUtil.getMockJdk18();
}
}

View File

@@ -21,7 +21,6 @@ import com.intellij.openapi.projectRoots.JavaSdkVersion;
import com.intellij.openapi.projectRoots.Sdk;
import com.intellij.testFramework.IdeaTestUtil;
import com.intellij.testFramework.PlatformTestUtil;
import com.intellij.util.ThrowableRunnable;
import org.jetbrains.annotations.NonNls;
public class OverloadResolutionTest extends LightDaemonAnalyzerTestCase {
@@ -98,21 +97,11 @@ public class OverloadResolutionTest extends LightDaemonAnalyzerTestCase {
}
public void testManyOverloadsWithVarargs() throws Exception {
PlatformTestUtil.startPerformanceTest("Overload resolution with 14 overloads", 20000, new ThrowableRunnable() {
@Override
public void run() throws Throwable {
doTest(false);
}
}).useLegacyScaling().assertTiming();
PlatformTestUtil.startPerformanceTest("Overload resolution with 14 overloads", 10000, () -> doTest(false)).useLegacyScaling().assertTiming();
}
public void testConstructorOverloadsWithDiamonds() throws Exception {
PlatformTestUtil.startPerformanceTest("Overload resolution with chain constructor calls with diamonds", 10000, new ThrowableRunnable() {
@Override
public void run() throws Throwable {
doTest(false);
}
}).useLegacyScaling().assertTiming();
PlatformTestUtil.startPerformanceTest("Overload resolution with chain constructor calls with diamonds", 5000, () -> doTest(false)).useLegacyScaling().assertTiming();
}
public void testMultipleOverloadsWithNestedGeneric() throws Exception {