diff --git a/plugins/kotlin/uast/uast-kotlin-fir/src/org/jetbrains/uast/kotlin/FirKotlinUastAnalysisPlugin.kt b/plugins/kotlin/uast/uast-kotlin-fir/src/org/jetbrains/uast/kotlin/FirKotlinUastAnalysisPlugin.kt index 38eb0eebf7f4..5589f649f1c6 100644 --- a/plugins/kotlin/uast/uast-kotlin-fir/src/org/jetbrains/uast/kotlin/FirKotlinUastAnalysisPlugin.kt +++ b/plugins/kotlin/uast/uast-kotlin-fir/src/org/jetbrains/uast/kotlin/FirKotlinUastAnalysisPlugin.kt @@ -2,8 +2,13 @@ package org.jetbrains.uast.kotlin import com.intellij.lang.Language +import org.jetbrains.kotlin.analysis.api.KaSession +import org.jetbrains.kotlin.analysis.api.types.KaType +import org.jetbrains.kotlin.analysis.api.types.KaTypeNullability import org.jetbrains.kotlin.idea.KotlinLanguage +import org.jetbrains.kotlin.psi.KtElement import org.jetbrains.kotlin.psi.KtExpression +import org.jetbrains.kotlin.psi.KtTypeReference import org.jetbrains.uast.UExpression import org.jetbrains.uast.analysis.UExpressionFact import org.jetbrains.uast.analysis.UNullability @@ -14,19 +19,38 @@ class FirKotlinUastAnalysisPlugin : UastAnalysisPlugin { override val language: Language get() = KotlinLanguage.INSTANCE override fun UExpression.getExpressionFact(fact: UExpressionFact): T? { - val psiExpression = (sourcePsi as? KtExpression)?.unwrapBlockOrParenthesis() ?: return null + val ktElement = (sourcePsi as? KtElement) ?: return null @Suppress("UNCHECKED_CAST") return when (fact) { - UExpressionFact.UNullabilityFact -> { - analyzeForUast(psiExpression) { - when { - psiExpression.isDefinitelyNotNull -> UNullability.NOT_NULL - psiExpression.isDefinitelyNull -> UNullability.NULL - psiExpression.expressionType?.isMarkedNullable == true -> UNullability.NULLABLE - else -> UNullability.UNKNOWN - } - } - } + UExpressionFact.UNullabilityFact -> checkNullability(ktElement) } as T } + + private fun checkNullability(ktElement: KtElement): UNullability? { + return analyzeForUast(ktElement) { + when (ktElement) { + is KtExpression -> checkNullabilityForExpression(ktElement) + is KtTypeReference -> checkNullabilityForType(ktElement.type) + else -> null + } + } + } + + private fun KaSession.checkNullabilityForType(kaType: KaType): UNullability? { + return when (kaType.nullability) { + KaTypeNullability.NULLABLE -> UNullability.NULLABLE + KaTypeNullability.NON_NULLABLE -> UNullability.NOT_NULL + KaTypeNullability.UNKNOWN -> UNullability.UNKNOWN + } + } + + private fun KaSession.checkNullabilityForExpression(expression: KtExpression): UNullability? { + val unwrappedExpression = expression.unwrapBlockOrParenthesis() + + return when { + unwrappedExpression.isDefinitelyNotNull -> UNullability.NOT_NULL + unwrappedExpression.isDefinitelyNull -> UNullability.NULL + else -> unwrappedExpression.expressionType?.let { checkNullabilityForType(it) } + } + } } \ No newline at end of file diff --git a/plugins/kotlin/uast/uast-kotlin-fir/tests/test/org/jetbrains/fir/uast/test/FirUastAnalysisPluginTest.kt b/plugins/kotlin/uast/uast-kotlin-fir/tests/test/org/jetbrains/fir/uast/test/FirUastAnalysisPluginTest.kt index 296aef7a6b6e..01a657e70d11 100644 --- a/plugins/kotlin/uast/uast-kotlin-fir/tests/test/org/jetbrains/fir/uast/test/FirUastAnalysisPluginTest.kt +++ b/plugins/kotlin/uast/uast-kotlin-fir/tests/test/org/jetbrains/fir/uast/test/FirUastAnalysisPluginTest.kt @@ -7,6 +7,7 @@ import org.jetbrains.kotlin.idea.KotlinLanguage import org.jetbrains.kotlin.idea.base.plugin.KotlinPluginMode import org.jetbrains.kotlin.idea.test.KotlinLightCodeInsightFixtureTestCase import org.jetbrains.uast.UExpression +import org.jetbrains.uast.UField import org.jetbrains.uast.UastLanguagePlugin import org.jetbrains.uast.analysis.UExpressionFact import org.jetbrains.uast.analysis.UNullability @@ -14,7 +15,6 @@ import org.jetbrains.uast.kotlin.FirKotlinUastAnalysisPlugin import org.jetbrains.uast.test.common.kotlin.orFail import org.jetbrains.uast.toUElement import org.jetbrains.uast.visitor.AbstractUastVisitor -import kotlin.text.trimIndent class FirUastAnalysisPluginTest : KotlinLightCodeInsightFixtureTestCase() { @@ -164,12 +164,36 @@ class FirUastAnalysisPluginTest : KotlinLightCodeInsightFixtureTestCase() { } """.trimIndent()) + fun `test nullable properties with primitive types`() = doTest(""" + data class SomeClass(val a:/*NULLABLE*/String?, var b:/*NULLABLE*/Int? = null, val c:/*NULLABLE*/Int? = 1) + """.trimIndent()) + + fun `test non nullable properties with primitive types`() = doTest(""" + data class SomeClass(val a:/*NOT_NULL*/String, var b:/*NOT_NULL*/Int = 1) + """.trimIndent()) + + fun `test complex properties`() = doTest(""" + data class SomeClass( + val a:/*NOT_NULL*/String, + var b:/*NULLABLE*/Int? = null, + val c:/*NULLABLE*/D?, + val d:/*NOT_NULL*/D + ) + + class D + """.trimIndent()) + private fun doTest(@Language("kotlin") source: String) { val uastAnalysisPlugin = UastLanguagePlugin.byLanguage(KotlinLanguage.INSTANCE)?.analysisPlugin.orFail("Can not find analysis plugin for Kotlin") assertInstanceOf(uastAnalysisPlugin, FirKotlinUastAnalysisPlugin::class.java) val file = myFixture.configureByText("file.kt", source).toUElement().orFail("Cannot create UFile") var visitAny = false file.accept(object : AbstractUastVisitor() { + override fun visitField(node: UField): Boolean { + val typeReference = node.typeReference ?: return super.visitField(node) + return visitExpression(typeReference) + } + override fun visitExpression(node: UExpression): Boolean { val uNullability = node.comments.firstOrNull()?.text ?.removePrefix("/*") diff --git a/plugins/kotlin/uast/uast-kotlin-idea/src/org/jetbrains/uast/kotlin/analysis/KotlinUastAnalysisPlugin.kt b/plugins/kotlin/uast/uast-kotlin-idea/src/org/jetbrains/uast/kotlin/analysis/KotlinUastAnalysisPlugin.kt index d1ac338e79de..4a16c47518ff 100644 --- a/plugins/kotlin/uast/uast-kotlin-idea/src/org/jetbrains/uast/kotlin/analysis/KotlinUastAnalysisPlugin.kt +++ b/plugins/kotlin/uast/uast-kotlin-idea/src/org/jetbrains/uast/kotlin/analysis/KotlinUastAnalysisPlugin.kt @@ -7,14 +7,19 @@ import io.vavr.control.Option import org.jetbrains.kotlin.descriptors.VariableDescriptor import org.jetbrains.kotlin.idea.KotlinLanguage import org.jetbrains.kotlin.idea.caches.resolve.analyze +import org.jetbrains.kotlin.psi.KtElement import org.jetbrains.kotlin.psi.KtExpression import org.jetbrains.kotlin.psi.KtReferenceExpression +import org.jetbrains.kotlin.psi.KtTypeReference import org.jetbrains.kotlin.resolve.BindingContext import org.jetbrains.kotlin.resolve.calls.smartcasts.DataFlowValue import org.jetbrains.kotlin.resolve.calls.smartcasts.IdentifierInfo import org.jetbrains.kotlin.resolve.calls.smartcasts.Nullability import org.jetbrains.kotlin.types.FlexibleType +import org.jetbrains.kotlin.types.KotlinType import org.jetbrains.kotlin.types.SimpleType +import org.jetbrains.kotlin.types.typeUtil.TypeNullability +import org.jetbrains.kotlin.types.typeUtil.nullability import org.jetbrains.uast.UExpression import org.jetbrains.uast.analysis.UExpressionFact import org.jetbrains.uast.analysis.UNullability @@ -38,47 +43,69 @@ class KotlinUastAnalysisPlugin : UastAnalysisPlugin { } override fun UExpression.getExpressionFact(fact: UExpressionFact): T? { - val psiExpression = (sourcePsi as? KtExpression)?.unwrapBlockOrParenthesis() ?: return null + val ktElement = (sourcePsi as? KtElement) ?: return null @Suppress("UNCHECKED_CAST") return when (fact) { UExpressionFact.UNullabilityFact -> { - val typeInfo = psiExpression.analyze()[BindingContext.EXPRESSION_TYPE_INFO, psiExpression] - val dfaInfo = typeInfo?.dataFlowInfo - val variableDescriptor = (psiExpression as? KtReferenceExpression)?.getVariableDescriptor(psiExpression.analyze()) - - val variableNullability = dfaInfo?.completeNullabilityInfo?.find { (value, _) -> - isValueCorrespondsToDescriptor(value, variableDescriptor) - } - when { - typeInfo?.type is SimpleType && typeInfo.type?.isMarkedNullable == false -> UNullability.NOT_NULL - variableNullability is Option.Some<*> -> { - val (_, info) = variableNullability.get() - when (info) { - Nullability.NULL -> UNullability.NULL - Nullability.NOT_NULL -> UNullability.NOT_NULL - else -> UNullability.UNKNOWN - } - } - else -> { - val type = typeInfo?.type - when { - type is FlexibleType -> { - when { - !type.lowerBound.isMarkedNullable && type.upperBound.isMarkedNullable -> UNullability.UNKNOWN - type.upperBound.isMarkedNullable -> UNullability.NULLABLE - else -> UNullability.NOT_NULL - } - } - type?.isMarkedNullable == true -> UNullability.NULLABLE - else -> UNullability.UNKNOWN - } - } + when (ktElement) { + is KtExpression -> getNullabilityForExpression(ktElement) + is KtTypeReference -> getNullabilityForTypeReference(ktElement) + else -> null } } } as T } + private fun getNullabilityForExpression(psiExpression: KtExpression): UNullability? { + val ktExpression = psiExpression.unwrapBlockOrParenthesis() + + val typeInfo = ktExpression.analyze()[BindingContext.EXPRESSION_TYPE_INFO, ktExpression] + val dfaInfo = typeInfo?.dataFlowInfo + val variableDescriptor = (ktExpression as? KtReferenceExpression)?.getVariableDescriptor(ktExpression.analyze()) + + val variableNullability = dfaInfo?.completeNullabilityInfo?.find { (value, _) -> + isValueCorrespondsToDescriptor(value, variableDescriptor) + } + + val type = typeInfo?.type ?: return null + + return when { + type is SimpleType && type.isMarkedNullable == false -> UNullability.NOT_NULL + variableNullability is Option.Some<*> -> { + val (_, info) = variableNullability.get() + when (info) { + Nullability.NULL -> UNullability.NULL + Nullability.NOT_NULL -> UNullability.NOT_NULL + else -> UNullability.UNKNOWN + } + } + else -> getNullabilityForType(type) + } + } + + private fun getNullabilityForTypeReference(typeReference: KtTypeReference): UNullability? { + val type = typeReference.analyze()[BindingContext.TYPE, typeReference] ?: return null + + return getNullabilityForType(type) + } + + private fun getNullabilityForType(type: KotlinType): UNullability { + return when { + type is FlexibleType -> { + when { + !type.lowerBound.isMarkedNullable && type.upperBound.isMarkedNullable -> UNullability.UNKNOWN + type.upperBound.isMarkedNullable -> UNullability.NULLABLE + else -> UNullability.NOT_NULL + } + } + + type.nullability() == TypeNullability.NULLABLE -> UNullability.NULLABLE + type.nullability() == TypeNullability.NOT_NULL -> UNullability.NOT_NULL + else -> UNullability.UNKNOWN + } + } + private operator fun Tuple2.component1(): T1 = this._1 private operator fun Tuple2.component2(): T2 = this._2 } diff --git a/plugins/kotlin/uast/uast-kotlin-idea/tests/test/org/jetbrains/uast/test/kotlin/analysis/KotlinUastAnalysisPluginTest.kt b/plugins/kotlin/uast/uast-kotlin-idea/tests/test/org/jetbrains/uast/test/kotlin/analysis/KotlinUastAnalysisPluginTest.kt index c368275da296..c3ecff22a55c 100644 --- a/plugins/kotlin/uast/uast-kotlin-idea/tests/test/org/jetbrains/uast/test/kotlin/analysis/KotlinUastAnalysisPluginTest.kt +++ b/plugins/kotlin/uast/uast-kotlin-idea/tests/test/org/jetbrains/uast/test/kotlin/analysis/KotlinUastAnalysisPluginTest.kt @@ -7,6 +7,7 @@ import org.jetbrains.kotlin.idea.KotlinLanguage import org.jetbrains.kotlin.idea.base.plugin.KotlinPluginMode import org.jetbrains.kotlin.idea.test.KotlinLightCodeInsightFixtureTestCase import org.jetbrains.uast.UExpression +import org.jetbrains.uast.UField import org.jetbrains.uast.UastLanguagePlugin import org.jetbrains.uast.analysis.UExpressionFact import org.jetbrains.uast.analysis.UNullability @@ -164,6 +165,25 @@ class KotlinUastAnalysisPluginTest : KotlinLightCodeInsightFixtureTestCase() { } """.trimIndent()) + fun `test nullable properties with primitive types`() = doTest(""" + data class SomeClass(val a:/*NULLABLE*/String?, var b:/*NULLABLE*/Int? = null, val c:/*NULLABLE*/Int? = 1) + """.trimIndent()) + + fun `test non nullable properties with primitive types`() = doTest(""" + data class SomeClass(val a:/*NOT_NULL*/String, var b:/*NOT_NULL*/Int = 1) + """.trimIndent()) + + fun `test complex properties`() = doTest(""" + data class SomeClass( + val a:/*NOT_NULL*/String, + var b:/*NULLABLE*/Int? = null, + val c:/*NULLABLE*/D?, + val d:/*NOT_NULL*/D + ) + + class D + """.trimIndent()) + private fun doTest( @Language("kotlin") source: String ) { @@ -173,6 +193,11 @@ class KotlinUastAnalysisPluginTest : KotlinLightCodeInsightFixtureTestCase() { val file = myFixture.configureByText("file.kt", source).toUElement() ?: kFail("Cannot create UFile") var visitAny = false file.accept(object : AbstractUastVisitor() { + override fun visitField(node: UField): Boolean { + val typeReference = node.typeReference ?: return super.visitField(node) + return visitExpression(typeReference) + } + override fun visitExpression(node: UExpression): Boolean { val uNullability = node.comments.firstOrNull()?.text ?.removePrefix("/*") diff --git a/uast/uast-java-ide/src/org/jetbrains/uast/java/analysis/JavaUastAnalysisPlugin.kt b/uast/uast-java-ide/src/org/jetbrains/uast/java/analysis/JavaUastAnalysisPlugin.kt index 756a212dc599..252386b54485 100644 --- a/uast/uast-java-ide/src/org/jetbrains/uast/java/analysis/JavaUastAnalysisPlugin.kt +++ b/uast/uast-java-ide/src/org/jetbrains/uast/java/analysis/JavaUastAnalysisPlugin.kt @@ -1,30 +1,53 @@ // Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. package org.jetbrains.uast.java.analysis +import com.intellij.codeInsight.Nullability +import com.intellij.codeInsight.NullableNotNullManager import com.intellij.codeInspection.dataFlow.CommonDataflow import com.intellij.codeInspection.dataFlow.DfaNullability import com.intellij.lang.java.JavaLanguage +import com.intellij.psi.PsiElement import com.intellij.psi.PsiExpression +import com.intellij.psi.PsiModifierListOwner +import com.intellij.psi.PsiTypeElement +import com.intellij.psi.util.findParentOfType import org.jetbrains.uast.UExpression import org.jetbrains.uast.analysis.UExpressionFact import org.jetbrains.uast.analysis.UNullability import org.jetbrains.uast.analysis.UastAnalysisPlugin class JavaUastAnalysisPlugin : UastAnalysisPlugin { + override val language: JavaLanguage = JavaLanguage.INSTANCE + override fun UExpression.getExpressionFact(fact: UExpressionFact): T? { when (fact) { is UExpressionFact.UNullabilityFact -> { - val psiExpression = sourcePsi as? PsiExpression ?: return null - val dfType = CommonDataflow.getDfType(psiExpression) - val nullability = DfaNullability.fromDfType(dfType).toUNullability() - @Suppress("UNCHECKED_CAST") - return nullability as T + return getNullability(sourcePsi) as T } } } - override val language: JavaLanguage = JavaLanguage.INSTANCE + private fun getNullability(psiElement: PsiElement?): UNullability? = + when (psiElement) { + is PsiTypeElement -> getNullabilityForTypeReference(psiElement) + is PsiExpression -> getNullabilityForExpression(psiElement) + else -> null + } + + private fun getNullabilityForTypeReference(typeElement: PsiTypeElement): UNullability? { + val modifierListOwner = typeElement.findParentOfType() ?: return null + return when (NullableNotNullManager.getNullability(modifierListOwner)) { + Nullability.NOT_NULL -> UNullability.NOT_NULL + Nullability.NULLABLE -> UNullability.NULLABLE + Nullability.UNKNOWN -> UNullability.UNKNOWN + } + } + + private fun getNullabilityForExpression(expression: PsiExpression): UNullability? { + val dfType = CommonDataflow.getDfType(expression) + return DfaNullability.fromDfType(dfType).toUNullability() + } private fun DfaNullability.toUNullability() = when (this) { DfaNullability.NULL -> UNullability.NULL diff --git a/uast/uast-tests/test/org/jetbrains/uast/test/java/analysis/JavaUastAnalysisPluginTest.kt b/uast/uast-tests/test/org/jetbrains/uast/test/java/analysis/JavaUastAnalysisPluginTest.kt index ba1012b4b2c7..a83b1cc05c65 100644 --- a/uast/uast-tests/test/org/jetbrains/uast/test/java/analysis/JavaUastAnalysisPluginTest.kt +++ b/uast/uast-tests/test/org/jetbrains/uast/test/java/analysis/JavaUastAnalysisPluginTest.kt @@ -5,6 +5,7 @@ import com.intellij.lang.java.JavaLanguage import com.intellij.testFramework.fixtures.LightJavaCodeInsightFixtureTestCase import org.intellij.lang.annotations.Language import org.jetbrains.uast.UExpression +import org.jetbrains.uast.UField import org.jetbrains.uast.UastLanguagePlugin import org.jetbrains.uast.analysis.UExpressionFact import org.jetbrains.uast.analysis.UNullability @@ -144,7 +145,61 @@ internal class JavaUastAnalysisPluginTest : LightJavaCodeInsightFixtureTestCase( } } """.trimIndent()) + + @Test + fun `test nullable field with primitive types`() = doTest(""" + import org.jetbrains.annotations.Nullable; + class SomeClass { + private @Nullable /*NULLABLE*/ String a; + private final @Nullable /*NULLABLE*/ Integer b; + private final @Nullable /*NULLABLE*/ Integer c; + + SomeClass(@Nullable String a, @Nullable Integer b, @Nullable Integer c) { + this.a = a; + this.b = b; + this.c = c; + } + } + """.trimIndent()) + + @Test + fun `test non nullable fields with primitive types`() = doTest(""" + import org.jetbrains.annotations.NotNull; + + class SomeClass { + private final @NotNull /*NOT_NULL*/ String a; + private final @NotNull /*NOT_NULL*/ Integer b; + + SomeClass(@NotNull String a, @NotNull Integer b) { + this.a = a; + this.b = b; + } + } + """.trimIndent()) + + @Test + fun `test complex fields`() = doTest(""" + import org.jetbrains.annotations.NotNull; + import org.jetbrains.annotations.Nullable; + + class SomeClass { + private final @NotNull /*NOT_NULL*/ String a; + private final @Nullable /*NULLABLE*/ Integer b; + private final @Nullable /*NULLABLE*/ D c; + private final @NotNull /*NOT_NULL*/ D d; + + SomeClass(@NotNull String a, @Nullable Integer b, @Nullable D c, @NotNull D d) { + this.a = a; + this.b = b; + this.c = c; + this.d = d; + } + } + + class D {} + """.trimIndent()) + private fun doTest(@Language("java") source: String) { val uastAnalysisPlugin = UastLanguagePlugin.byLanguage(JavaLanguage.INSTANCE)?.analysisPlugin assertInstanceOf(uastAnalysisPlugin, JavaUastAnalysisPlugin::class.java) @@ -155,6 +210,12 @@ internal class JavaUastAnalysisPluginTest : LightJavaCodeInsightFixtureTestCase( var visitAny = false file.accept(object : AbstractUastVisitor() { + + override fun visitField(node: UField): Boolean { + val typeReference = node.typeReference ?: return super.visitField(node) + return visitExpression(typeReference) + } + override fun visitExpression(node: UExpression): Boolean { val uNullability = node.comments.firstOrNull()?.text ?.removePrefix("/*")