[uast] IDEA-324930 Support providing nullability facts for UTypeReferenceExpression

It helps to get information about nullability for fields


(cherry picked from commit b396cfac3bd2586cd6e1de2fda131dcd135b1ecb)

IJ-CR-149011

GitOrigin-RevId: d1fa105b06c2444def5f801a26c436980b27c61f
This commit is contained in:
Marat Dinmukhametov
2024-11-08 19:34:23 +02:00
committed by intellij-monorepo-bot
parent 7703245b39
commit 9421bfc3bb
6 changed files with 234 additions and 50 deletions

View File

@@ -2,8 +2,13 @@
package org.jetbrains.uast.kotlin
import com.intellij.lang.Language
import org.jetbrains.kotlin.analysis.api.KaSession
import org.jetbrains.kotlin.analysis.api.types.KaType
import org.jetbrains.kotlin.analysis.api.types.KaTypeNullability
import org.jetbrains.kotlin.idea.KotlinLanguage
import org.jetbrains.kotlin.psi.KtElement
import org.jetbrains.kotlin.psi.KtExpression
import org.jetbrains.kotlin.psi.KtTypeReference
import org.jetbrains.uast.UExpression
import org.jetbrains.uast.analysis.UExpressionFact
import org.jetbrains.uast.analysis.UNullability
@@ -14,19 +19,38 @@ class FirKotlinUastAnalysisPlugin : UastAnalysisPlugin {
override val language: Language get() = KotlinLanguage.INSTANCE
override fun <T : Any> UExpression.getExpressionFact(fact: UExpressionFact<T>): T? {
val psiExpression = (sourcePsi as? KtExpression)?.unwrapBlockOrParenthesis() ?: return null
val ktElement = (sourcePsi as? KtElement) ?: return null
@Suppress("UNCHECKED_CAST")
return when (fact) {
UExpressionFact.UNullabilityFact -> {
analyzeForUast(psiExpression) {
when {
psiExpression.isDefinitelyNotNull -> UNullability.NOT_NULL
psiExpression.isDefinitelyNull -> UNullability.NULL
psiExpression.expressionType?.isMarkedNullable == true -> UNullability.NULLABLE
else -> UNullability.UNKNOWN
}
}
}
UExpressionFact.UNullabilityFact -> checkNullability(ktElement)
} as T
}
private fun checkNullability(ktElement: KtElement): UNullability? {
return analyzeForUast(ktElement) {
when (ktElement) {
is KtExpression -> checkNullabilityForExpression(ktElement)
is KtTypeReference -> checkNullabilityForType(ktElement.type)
else -> null
}
}
}
private fun KaSession.checkNullabilityForType(kaType: KaType): UNullability? {
return when (kaType.nullability) {
KaTypeNullability.NULLABLE -> UNullability.NULLABLE
KaTypeNullability.NON_NULLABLE -> UNullability.NOT_NULL
KaTypeNullability.UNKNOWN -> UNullability.UNKNOWN
}
}
private fun KaSession.checkNullabilityForExpression(expression: KtExpression): UNullability? {
val unwrappedExpression = expression.unwrapBlockOrParenthesis()
return when {
unwrappedExpression.isDefinitelyNotNull -> UNullability.NOT_NULL
unwrappedExpression.isDefinitelyNull -> UNullability.NULL
else -> unwrappedExpression.expressionType?.let { checkNullabilityForType(it) }
}
}
}

View File

@@ -7,6 +7,7 @@ import org.jetbrains.kotlin.idea.KotlinLanguage
import org.jetbrains.kotlin.idea.base.plugin.KotlinPluginMode
import org.jetbrains.kotlin.idea.test.KotlinLightCodeInsightFixtureTestCase
import org.jetbrains.uast.UExpression
import org.jetbrains.uast.UField
import org.jetbrains.uast.UastLanguagePlugin
import org.jetbrains.uast.analysis.UExpressionFact
import org.jetbrains.uast.analysis.UNullability
@@ -14,7 +15,6 @@ import org.jetbrains.uast.kotlin.FirKotlinUastAnalysisPlugin
import org.jetbrains.uast.test.common.kotlin.orFail
import org.jetbrains.uast.toUElement
import org.jetbrains.uast.visitor.AbstractUastVisitor
import kotlin.text.trimIndent
class FirUastAnalysisPluginTest : KotlinLightCodeInsightFixtureTestCase() {
@@ -164,12 +164,36 @@ class FirUastAnalysisPluginTest : KotlinLightCodeInsightFixtureTestCase() {
}
""".trimIndent())
fun `test nullable properties with primitive types`() = doTest("""
data class SomeClass(val a:/*NULLABLE*/String?, var b:/*NULLABLE*/Int? = null, val c:/*NULLABLE*/Int? = 1)
""".trimIndent())
fun `test non nullable properties with primitive types`() = doTest("""
data class SomeClass(val a:/*NOT_NULL*/String, var b:/*NOT_NULL*/Int = 1)
""".trimIndent())
fun `test complex properties`() = doTest("""
data class SomeClass(
val a:/*NOT_NULL*/String,
var b:/*NULLABLE*/Int? = null,
val c:/*NULLABLE*/D?,
val d:/*NOT_NULL*/D
)
class D
""".trimIndent())
private fun doTest(@Language("kotlin") source: String) {
val uastAnalysisPlugin = UastLanguagePlugin.byLanguage(KotlinLanguage.INSTANCE)?.analysisPlugin.orFail("Can not find analysis plugin for Kotlin")
assertInstanceOf(uastAnalysisPlugin, FirKotlinUastAnalysisPlugin::class.java)
val file = myFixture.configureByText("file.kt", source).toUElement().orFail("Cannot create UFile")
var visitAny = false
file.accept(object : AbstractUastVisitor() {
override fun visitField(node: UField): Boolean {
val typeReference = node.typeReference ?: return super.visitField(node)
return visitExpression(typeReference)
}
override fun visitExpression(node: UExpression): Boolean {
val uNullability = node.comments.firstOrNull()?.text
?.removePrefix("/*")

View File

@@ -7,14 +7,19 @@ import io.vavr.control.Option
import org.jetbrains.kotlin.descriptors.VariableDescriptor
import org.jetbrains.kotlin.idea.KotlinLanguage
import org.jetbrains.kotlin.idea.caches.resolve.analyze
import org.jetbrains.kotlin.psi.KtElement
import org.jetbrains.kotlin.psi.KtExpression
import org.jetbrains.kotlin.psi.KtReferenceExpression
import org.jetbrains.kotlin.psi.KtTypeReference
import org.jetbrains.kotlin.resolve.BindingContext
import org.jetbrains.kotlin.resolve.calls.smartcasts.DataFlowValue
import org.jetbrains.kotlin.resolve.calls.smartcasts.IdentifierInfo
import org.jetbrains.kotlin.resolve.calls.smartcasts.Nullability
import org.jetbrains.kotlin.types.FlexibleType
import org.jetbrains.kotlin.types.KotlinType
import org.jetbrains.kotlin.types.SimpleType
import org.jetbrains.kotlin.types.typeUtil.TypeNullability
import org.jetbrains.kotlin.types.typeUtil.nullability
import org.jetbrains.uast.UExpression
import org.jetbrains.uast.analysis.UExpressionFact
import org.jetbrains.uast.analysis.UNullability
@@ -38,47 +43,69 @@ class KotlinUastAnalysisPlugin : UastAnalysisPlugin {
}
override fun <T : Any> UExpression.getExpressionFact(fact: UExpressionFact<T>): T? {
val psiExpression = (sourcePsi as? KtExpression)?.unwrapBlockOrParenthesis() ?: return null
val ktElement = (sourcePsi as? KtElement) ?: return null
@Suppress("UNCHECKED_CAST")
return when (fact) {
UExpressionFact.UNullabilityFact -> {
val typeInfo = psiExpression.analyze()[BindingContext.EXPRESSION_TYPE_INFO, psiExpression]
val dfaInfo = typeInfo?.dataFlowInfo
val variableDescriptor = (psiExpression as? KtReferenceExpression)?.getVariableDescriptor(psiExpression.analyze())
val variableNullability = dfaInfo?.completeNullabilityInfo?.find { (value, _) ->
isValueCorrespondsToDescriptor(value, variableDescriptor)
}
when {
typeInfo?.type is SimpleType && typeInfo.type?.isMarkedNullable == false -> UNullability.NOT_NULL
variableNullability is Option.Some<*> -> {
val (_, info) = variableNullability.get()
when (info) {
Nullability.NULL -> UNullability.NULL
Nullability.NOT_NULL -> UNullability.NOT_NULL
else -> UNullability.UNKNOWN
}
}
else -> {
val type = typeInfo?.type
when {
type is FlexibleType -> {
when {
!type.lowerBound.isMarkedNullable && type.upperBound.isMarkedNullable -> UNullability.UNKNOWN
type.upperBound.isMarkedNullable -> UNullability.NULLABLE
else -> UNullability.NOT_NULL
}
}
type?.isMarkedNullable == true -> UNullability.NULLABLE
else -> UNullability.UNKNOWN
}
}
when (ktElement) {
is KtExpression -> getNullabilityForExpression(ktElement)
is KtTypeReference -> getNullabilityForTypeReference(ktElement)
else -> null
}
}
} as T
}
private fun getNullabilityForExpression(psiExpression: KtExpression): UNullability? {
val ktExpression = psiExpression.unwrapBlockOrParenthesis()
val typeInfo = ktExpression.analyze()[BindingContext.EXPRESSION_TYPE_INFO, ktExpression]
val dfaInfo = typeInfo?.dataFlowInfo
val variableDescriptor = (ktExpression as? KtReferenceExpression)?.getVariableDescriptor(ktExpression.analyze())
val variableNullability = dfaInfo?.completeNullabilityInfo?.find { (value, _) ->
isValueCorrespondsToDescriptor(value, variableDescriptor)
}
val type = typeInfo?.type ?: return null
return when {
type is SimpleType && type.isMarkedNullable == false -> UNullability.NOT_NULL
variableNullability is Option.Some<*> -> {
val (_, info) = variableNullability.get()
when (info) {
Nullability.NULL -> UNullability.NULL
Nullability.NOT_NULL -> UNullability.NOT_NULL
else -> UNullability.UNKNOWN
}
}
else -> getNullabilityForType(type)
}
}
private fun getNullabilityForTypeReference(typeReference: KtTypeReference): UNullability? {
val type = typeReference.analyze()[BindingContext.TYPE, typeReference] ?: return null
return getNullabilityForType(type)
}
private fun getNullabilityForType(type: KotlinType): UNullability {
return when {
type is FlexibleType -> {
when {
!type.lowerBound.isMarkedNullable && type.upperBound.isMarkedNullable -> UNullability.UNKNOWN
type.upperBound.isMarkedNullable -> UNullability.NULLABLE
else -> UNullability.NOT_NULL
}
}
type.nullability() == TypeNullability.NULLABLE -> UNullability.NULLABLE
type.nullability() == TypeNullability.NOT_NULL -> UNullability.NOT_NULL
else -> UNullability.UNKNOWN
}
}
private operator fun <T1, T2> Tuple2<T1, T2>.component1(): T1 = this._1
private operator fun <T1, T2> Tuple2<T1, T2>.component2(): T2 = this._2
}

View File

@@ -7,6 +7,7 @@ import org.jetbrains.kotlin.idea.KotlinLanguage
import org.jetbrains.kotlin.idea.base.plugin.KotlinPluginMode
import org.jetbrains.kotlin.idea.test.KotlinLightCodeInsightFixtureTestCase
import org.jetbrains.uast.UExpression
import org.jetbrains.uast.UField
import org.jetbrains.uast.UastLanguagePlugin
import org.jetbrains.uast.analysis.UExpressionFact
import org.jetbrains.uast.analysis.UNullability
@@ -164,6 +165,25 @@ class KotlinUastAnalysisPluginTest : KotlinLightCodeInsightFixtureTestCase() {
}
""".trimIndent())
fun `test nullable properties with primitive types`() = doTest("""
data class SomeClass(val a:/*NULLABLE*/String?, var b:/*NULLABLE*/Int? = null, val c:/*NULLABLE*/Int? = 1)
""".trimIndent())
fun `test non nullable properties with primitive types`() = doTest("""
data class SomeClass(val a:/*NOT_NULL*/String, var b:/*NOT_NULL*/Int = 1)
""".trimIndent())
fun `test complex properties`() = doTest("""
data class SomeClass(
val a:/*NOT_NULL*/String,
var b:/*NULLABLE*/Int? = null,
val c:/*NULLABLE*/D?,
val d:/*NOT_NULL*/D
)
class D
""".trimIndent())
private fun doTest(
@Language("kotlin") source: String
) {
@@ -173,6 +193,11 @@ class KotlinUastAnalysisPluginTest : KotlinLightCodeInsightFixtureTestCase() {
val file = myFixture.configureByText("file.kt", source).toUElement() ?: kFail("Cannot create UFile")
var visitAny = false
file.accept(object : AbstractUastVisitor() {
override fun visitField(node: UField): Boolean {
val typeReference = node.typeReference ?: return super.visitField(node)
return visitExpression(typeReference)
}
override fun visitExpression(node: UExpression): Boolean {
val uNullability = node.comments.firstOrNull()?.text
?.removePrefix("/*")

View File

@@ -1,30 +1,53 @@
// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package org.jetbrains.uast.java.analysis
import com.intellij.codeInsight.Nullability
import com.intellij.codeInsight.NullableNotNullManager
import com.intellij.codeInspection.dataFlow.CommonDataflow
import com.intellij.codeInspection.dataFlow.DfaNullability
import com.intellij.lang.java.JavaLanguage
import com.intellij.psi.PsiElement
import com.intellij.psi.PsiExpression
import com.intellij.psi.PsiModifierListOwner
import com.intellij.psi.PsiTypeElement
import com.intellij.psi.util.findParentOfType
import org.jetbrains.uast.UExpression
import org.jetbrains.uast.analysis.UExpressionFact
import org.jetbrains.uast.analysis.UNullability
import org.jetbrains.uast.analysis.UastAnalysisPlugin
class JavaUastAnalysisPlugin : UastAnalysisPlugin {
override val language: JavaLanguage = JavaLanguage.INSTANCE
override fun <T : Any> UExpression.getExpressionFact(fact: UExpressionFact<T>): T? {
when (fact) {
is UExpressionFact.UNullabilityFact -> {
val psiExpression = sourcePsi as? PsiExpression ?: return null
val dfType = CommonDataflow.getDfType(psiExpression)
val nullability = DfaNullability.fromDfType(dfType).toUNullability()
@Suppress("UNCHECKED_CAST")
return nullability as T
return getNullability(sourcePsi) as T
}
}
}
override val language: JavaLanguage = JavaLanguage.INSTANCE
private fun getNullability(psiElement: PsiElement?): UNullability? =
when (psiElement) {
is PsiTypeElement -> getNullabilityForTypeReference(psiElement)
is PsiExpression -> getNullabilityForExpression(psiElement)
else -> null
}
private fun getNullabilityForTypeReference(typeElement: PsiTypeElement): UNullability? {
val modifierListOwner = typeElement.findParentOfType<PsiModifierListOwner>() ?: return null
return when (NullableNotNullManager.getNullability(modifierListOwner)) {
Nullability.NOT_NULL -> UNullability.NOT_NULL
Nullability.NULLABLE -> UNullability.NULLABLE
Nullability.UNKNOWN -> UNullability.UNKNOWN
}
}
private fun getNullabilityForExpression(expression: PsiExpression): UNullability? {
val dfType = CommonDataflow.getDfType(expression)
return DfaNullability.fromDfType(dfType).toUNullability()
}
private fun DfaNullability.toUNullability() = when (this) {
DfaNullability.NULL -> UNullability.NULL

View File

@@ -5,6 +5,7 @@ import com.intellij.lang.java.JavaLanguage
import com.intellij.testFramework.fixtures.LightJavaCodeInsightFixtureTestCase
import org.intellij.lang.annotations.Language
import org.jetbrains.uast.UExpression
import org.jetbrains.uast.UField
import org.jetbrains.uast.UastLanguagePlugin
import org.jetbrains.uast.analysis.UExpressionFact
import org.jetbrains.uast.analysis.UNullability
@@ -144,7 +145,61 @@ internal class JavaUastAnalysisPluginTest : LightJavaCodeInsightFixtureTestCase(
}
}
""".trimIndent())
@Test
fun `test nullable field with primitive types`() = doTest("""
import org.jetbrains.annotations.Nullable;
class SomeClass {
private @Nullable /*NULLABLE*/ String a;
private final @Nullable /*NULLABLE*/ Integer b;
private final @Nullable /*NULLABLE*/ Integer c;
SomeClass(@Nullable String a, @Nullable Integer b, @Nullable Integer c) {
this.a = a;
this.b = b;
this.c = c;
}
}
""".trimIndent())
@Test
fun `test non nullable fields with primitive types`() = doTest("""
import org.jetbrains.annotations.NotNull;
class SomeClass {
private final @NotNull /*NOT_NULL*/ String a;
private final @NotNull /*NOT_NULL*/ Integer b;
SomeClass(@NotNull String a, @NotNull Integer b) {
this.a = a;
this.b = b;
}
}
""".trimIndent())
@Test
fun `test complex fields`() = doTest("""
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
class SomeClass {
private final @NotNull /*NOT_NULL*/ String a;
private final @Nullable /*NULLABLE*/ Integer b;
private final @Nullable /*NULLABLE*/ D c;
private final @NotNull /*NOT_NULL*/ D d;
SomeClass(@NotNull String a, @Nullable Integer b, @Nullable D c, @NotNull D d) {
this.a = a;
this.b = b;
this.c = c;
this.d = d;
}
}
class D {}
""".trimIndent())
private fun doTest(@Language("java") source: String) {
val uastAnalysisPlugin = UastLanguagePlugin.byLanguage(JavaLanguage.INSTANCE)?.analysisPlugin
assertInstanceOf(uastAnalysisPlugin, JavaUastAnalysisPlugin::class.java)
@@ -155,6 +210,12 @@ internal class JavaUastAnalysisPluginTest : LightJavaCodeInsightFixtureTestCase(
var visitAny = false
file.accept(object : AbstractUastVisitor() {
override fun visitField(node: UField): Boolean {
val typeReference = node.typeReference ?: return super.visitField(node)
return visitExpression(typeReference)
}
override fun visitExpression(node: UExpression): Boolean {
val uNullability = node.comments.firstOrNull()?.text
?.removePrefix("/*")