[platform] IJ-CR-152034 Improve exception handling in context propagation

(cherry picked from commit 97892c3562f2df8a139aad1237505f4702078bab)

GitOrigin-RevId: b476ebb9ab69dc4a29b83fa5e8d924d0b1c0b09c
This commit is contained in:
Konstantin.Nisht
2024-12-23 11:53:16 +01:00
committed by intellij-monorepo-bot
parent 8722839bd2
commit 16ce6f60f8
2 changed files with 80 additions and 7 deletions

View File

@@ -371,4 +371,50 @@ class ThreadContextPropagationTest {
delay(100) delay(100)
Assertions.assertTrue(tracker.get()) Assertions.assertTrue(tracker.get())
} }
class MyFaultyIjElement1(val e: Throwable) : IntelliJContextElement, AbstractCoroutineContextElement(MyFaultyIjElement1) {
companion object Key : CoroutineContext.Key<MyFaultyIjElement1>
override fun produceChildElement(parentContext: CoroutineContext, isStructured: Boolean): IntelliJContextElement? = throw e
}
@Test
fun `faulty produceChildElement`() = timeoutRunBlocking {
val tracker = AtomicBoolean(false)
val ise = IllegalStateException("Boom")
withContext(MyCancellableIjElement(tracker) + MyFaultyIjElement1(ise)) {
val exception = Assertions.assertThrows(IllegalStateException::class.java) {
application.executeOnPooledThread { Assertions.fail() }
}
Assertions.assertEquals(ise, exception)
Assertions.assertTrue(tracker.get())
}
}
class MyFaultyIjElement2(val e: Throwable) : IntelliJContextElement, AbstractCoroutineContextElement(MyFaultyIjElement2) {
companion object Key : CoroutineContext.Key<MyFaultyIjElement2>
override fun produceChildElement(parentContext: CoroutineContext, isStructured: Boolean): IntelliJContextElement? = this
override fun beforeChildStarted(context: CoroutineContext) = throw e
}
@Test
fun `faulty beforeChildStarted`() = timeoutRunBlocking {
val tracker = AtomicBoolean(false)
val ise = Exception("Boom")
val rootJob = Job()
try {
withContext(MyIjElement(tracker) + MyFaultyIjElement2(ise) + rootJob) {
val exception = Assertions.assertThrows(RuntimeException::class.java) {
application.invokeAndWait { Assertions.fail() }
}
Assertions.assertEquals(ise, exception.cause)
Assertions.assertTrue(tracker.get())
}
}
catch (e: Throwable) {
Assertions.assertEquals(ise.message, e.message)
}
}
} }

View File

@@ -19,6 +19,7 @@ import com.intellij.openapi.util.Ref
import com.intellij.util.SmartList import com.intellij.util.SmartList
import com.intellij.util.SystemProperties import com.intellij.util.SystemProperties
import com.intellij.util.concurrency.SchedulingWrapper.MyScheduledFutureTask import com.intellij.util.concurrency.SchedulingWrapper.MyScheduledFutureTask
import com.intellij.util.containers.forEachGuaranteed
import kotlinx.coroutines.* import kotlinx.coroutines.*
import org.jetbrains.annotations.ApiStatus import org.jetbrains.annotations.ApiStatus
import org.jetbrains.annotations.ApiStatus.Internal import org.jetbrains.annotations.ApiStatus.Internal
@@ -115,15 +116,24 @@ data class ChildContext internal constructor(
@DelicateCoroutinesApi @DelicateCoroutinesApi
fun applyContextActions(installThreadContext: Boolean = true): AccessToken { fun applyContextActions(installThreadContext: Boolean = true): AccessToken {
val alreadyAppliedElements = mutableListOf<IntelliJContextElement>()
try {
for (elem in ijElements) {
elem.beforeChildStarted(context)
alreadyAppliedElements.add(elem)
}
}
catch (e: Throwable) {
cleanupList(e, alreadyAppliedElements.reversed()) {
it.afterChildCompleted(context)
}
}
val installToken = if (installThreadContext) { val installToken = if (installThreadContext) {
installThreadContext(context, replace = false) installThreadContext(context, replace = false)
} }
else { else {
AccessToken.EMPTY_ACCESS_TOKEN AccessToken.EMPTY_ACCESS_TOKEN
} }
for (elem in ijElements) {
elem.beforeChildStarted(context)
}
return object : AccessToken() { return object : AccessToken() {
override fun finish() { override fun finish() {
installToken.finish() installToken.finish()
@@ -190,10 +200,27 @@ private fun doCreateChildContext(debugName: @NonNls String, unconditionalCancell
private fun gatherAppliedChildContext(parentContext: CoroutineContext, isStructured: Boolean): Pair<CoroutineContext, List<IntelliJContextElement>> { private fun gatherAppliedChildContext(parentContext: CoroutineContext, isStructured: Boolean): Pair<CoroutineContext, List<IntelliJContextElement>> {
val ijElements = SmartList<IntelliJContextElement>() val ijElements = SmartList<IntelliJContextElement>()
try {
val newContext = parentContext.fold<CoroutineContext>(EmptyCoroutineContext) { old, elem -> val newContext = parentContext.fold<CoroutineContext>(EmptyCoroutineContext) { old, elem ->
old + produceChildContextElement(parentContext, elem, isStructured, ijElements) old + produceChildContextElement(parentContext, elem, isStructured, ijElements)
} }
return Pair(newContext, ijElements) return Pair(newContext, ijElements)
}
catch (e: Throwable) {
cleanupList(e, ijElements.reversed()) {
it.childCanceled(parentContext)
}
}
}
private fun <T> cleanupList(original: Throwable, list: List<T>, action: (T) -> Unit): Nothing {
try {
list.forEachGuaranteed(action)
}
catch (e: Throwable) {
original.addSuppressed(e)
}
throw original
} }
private fun produceChildContextElement(parentContext: CoroutineContext, element: CoroutineContext.Element, isStructured: Boolean, ijElements: MutableList<IntelliJContextElement>): CoroutineContext { private fun produceChildContextElement(parentContext: CoroutineContext, element: CoroutineContext.Element, isStructured: Boolean, ijElements: MutableList<IntelliJContextElement>): CoroutineContext {
@@ -239,7 +266,7 @@ internal fun captureRunnableThreadContext(command: Runnable): Runnable {
internal fun <V> captureCallableThreadContext(callable: Callable<V>): Callable<V> { internal fun <V> captureCallableThreadContext(callable: Callable<V>): Callable<V> {
val childContext = createChildContext(callable.toString()) val childContext = createChildContext(callable.toString())
var callable = captureClientIdInCallable(callable) var callable = captureClientIdInCallable(callable)
callable = ContextCallable(true, childContext, callable) callable = ContextCallable(true, childContext, callable, AtomicBoolean(false))
return callable return callable
} }