diff --git a/platform/platform-tests/testSrc/com/intellij/util/concurrency/ThreadContextPropagationTest.kt b/platform/platform-tests/testSrc/com/intellij/util/concurrency/ThreadContextPropagationTest.kt index 8440fce56524..746715aecb86 100644 --- a/platform/platform-tests/testSrc/com/intellij/util/concurrency/ThreadContextPropagationTest.kt +++ b/platform/platform-tests/testSrc/com/intellij/util/concurrency/ThreadContextPropagationTest.kt @@ -371,4 +371,50 @@ class ThreadContextPropagationTest { delay(100) Assertions.assertTrue(tracker.get()) } + + + class MyFaultyIjElement1(val e: Throwable) : IntelliJContextElement, AbstractCoroutineContextElement(MyFaultyIjElement1) { + companion object Key : CoroutineContext.Key + + override fun produceChildElement(parentContext: CoroutineContext, isStructured: Boolean): IntelliJContextElement? = throw e + } + + @Test + fun `faulty produceChildElement`() = timeoutRunBlocking { + val tracker = AtomicBoolean(false) + val ise = IllegalStateException("Boom") + withContext(MyCancellableIjElement(tracker) + MyFaultyIjElement1(ise)) { + val exception = Assertions.assertThrows(IllegalStateException::class.java) { + application.executeOnPooledThread { Assertions.fail() } + } + Assertions.assertEquals(ise, exception) + Assertions.assertTrue(tracker.get()) + } + } + + class MyFaultyIjElement2(val e: Throwable) : IntelliJContextElement, AbstractCoroutineContextElement(MyFaultyIjElement2) { + companion object Key : CoroutineContext.Key + + override fun produceChildElement(parentContext: CoroutineContext, isStructured: Boolean): IntelliJContextElement? = this + override fun beforeChildStarted(context: CoroutineContext) = throw e + } + + @Test + fun `faulty beforeChildStarted`() = timeoutRunBlocking { + val tracker = AtomicBoolean(false) + val ise = Exception("Boom") + val rootJob = Job() + try { + withContext(MyIjElement(tracker) + MyFaultyIjElement2(ise) + rootJob) { + val exception = Assertions.assertThrows(RuntimeException::class.java) { + application.invokeAndWait { Assertions.fail() } + } + Assertions.assertEquals(ise, exception.cause) + Assertions.assertTrue(tracker.get()) + } + } + catch (e: Throwable) { + Assertions.assertEquals(ise.message, e.message) + } + } } diff --git a/platform/util/src/com/intellij/util/concurrency/propagation.kt b/platform/util/src/com/intellij/util/concurrency/propagation.kt index 6215b9fbcd17..0f86663d134b 100644 --- a/platform/util/src/com/intellij/util/concurrency/propagation.kt +++ b/platform/util/src/com/intellij/util/concurrency/propagation.kt @@ -19,6 +19,7 @@ import com.intellij.openapi.util.Ref import com.intellij.util.SmartList import com.intellij.util.SystemProperties import com.intellij.util.concurrency.SchedulingWrapper.MyScheduledFutureTask +import com.intellij.util.containers.forEachGuaranteed import kotlinx.coroutines.* import org.jetbrains.annotations.ApiStatus import org.jetbrains.annotations.ApiStatus.Internal @@ -115,15 +116,24 @@ data class ChildContext internal constructor( @DelicateCoroutinesApi fun applyContextActions(installThreadContext: Boolean = true): AccessToken { + val alreadyAppliedElements = mutableListOf() + try { + for (elem in ijElements) { + elem.beforeChildStarted(context) + alreadyAppliedElements.add(elem) + } + } + catch (e: Throwable) { + cleanupList(e, alreadyAppliedElements.reversed()) { + it.afterChildCompleted(context) + } + } val installToken = if (installThreadContext) { installThreadContext(context, replace = false) } else { AccessToken.EMPTY_ACCESS_TOKEN } - for (elem in ijElements) { - elem.beforeChildStarted(context) - } return object : AccessToken() { override fun finish() { installToken.finish() @@ -190,10 +200,27 @@ private fun doCreateChildContext(debugName: @NonNls String, unconditionalCancell private fun gatherAppliedChildContext(parentContext: CoroutineContext, isStructured: Boolean): Pair> { val ijElements = SmartList() - val newContext = parentContext.fold(EmptyCoroutineContext) { old, elem -> - old + produceChildContextElement(parentContext, elem, isStructured, ijElements) + try { + val newContext = parentContext.fold(EmptyCoroutineContext) { old, elem -> + old + produceChildContextElement(parentContext, elem, isStructured, ijElements) + } + return Pair(newContext, ijElements) } - return Pair(newContext, ijElements) + catch (e: Throwable) { + cleanupList(e, ijElements.reversed()) { + it.childCanceled(parentContext) + } + } +} + +private fun cleanupList(original: Throwable, list: List, action: (T) -> Unit): Nothing { + try { + list.forEachGuaranteed(action) + } + catch (e: Throwable) { + original.addSuppressed(e) + } + throw original } private fun produceChildContextElement(parentContext: CoroutineContext, element: CoroutineContext.Element, isStructured: Boolean, ijElements: MutableList): CoroutineContext { @@ -239,7 +266,7 @@ internal fun captureRunnableThreadContext(command: Runnable): Runnable { internal fun captureCallableThreadContext(callable: Callable): Callable { val childContext = createChildContext(callable.toString()) var callable = captureClientIdInCallable(callable) - callable = ContextCallable(true, childContext, callable) + callable = ContextCallable(true, childContext, callable, AtomicBoolean(false)) return callable }