[platform] IJ-CR-152034: Add an ability to react to cancellation in IntelliJContextElement

(cherry picked from commit acf9427bd33648f34406796643025a30c70e9e5f)

GitOrigin-RevId: 4b68e73f865626ec21a793d4cc310226aa463fcc
This commit is contained in:
Konstantin Nisht
2024-12-23 11:52:32 +01:00
committed by intellij-monorepo-bot
parent d173a80534
commit 8722839bd2
8 changed files with 72 additions and 25 deletions

View File

@@ -122,6 +122,10 @@ internal class PlatformActivityTrackerService(private val scope: CoroutineScope)
removeObservedComputation(currentJob)
currentJob.complete()
}
override fun childCanceled(context: CoroutineContext) {
afterChildCompleted(context)
}
}
private fun enterConfiguration(kind: ActivityKey) : Any {

View File

@@ -341,11 +341,18 @@ class ThreadContextPropagationTest {
assertTrue(tracker.get())
}
class MyCancellableIjElement(val eventTracker: AtomicBoolean) : IntelliJContextElement, AbstractCoroutineContextElement(MyIjElement) {
companion object Key : CoroutineContext.Key<MyIjElement>
override fun produceChildElement(parentContext: CoroutineContext, isStructured: Boolean): IntelliJContextElement = this
override fun childCanceled(context: CoroutineContext) = eventTracker.set(true)
}
@Test
fun `cancellation of scheduled task triggers cleanup events`() = timeoutRunBlocking {
val service = AppExecutorUtil.createBoundedScheduledExecutorService("Test service", 1);
val tracker = AtomicBoolean(false)
withContext(MyIjElement(tracker)) {
withContext(MyCancellableIjElement(tracker)) {
val future = service.schedule(Callable<Unit> { Assertions.fail() }, 10, TimeUnit.SECONDS) // should never be executed
future.cancel(false)
}
@@ -353,10 +360,10 @@ class ThreadContextPropagationTest {
}
@Test
fun `cancellation of invokeLater triggers cleanup events`() = timeoutRunBlocking {
fun `expiration of invokeLater triggers cleanup events`() = timeoutRunBlocking {
val tracker = AtomicBoolean(false)
val expiration = AtomicBoolean(false)
withContext(Dispatchers.EDT + MyIjElement(tracker)) {
withContext(Dispatchers.EDT + MyCancellableIjElement(tracker)) {
@Suppress("ForbiddenInSuspectContextMethod")
application.invokeLater({ Assertions.fail() }, { expiration.get() })
expiration.set(true)

View File

@@ -2,6 +2,7 @@
- kotlin.coroutines.CoroutineContext$Element
- afterChildCompleted(kotlin.coroutines.CoroutineContext):V
- beforeChildStarted(kotlin.coroutines.CoroutineContext):V
- childCanceled(kotlin.coroutines.CoroutineContext):V
- produceChildElement(kotlin.coroutines.CoroutineContext,Z):com.intellij.concurrency.IntelliJContextElement
*f:com.intellij.concurrency.IntelliJContextElement$DefaultImpls
- s:fold(com.intellij.concurrency.IntelliJContextElement,java.lang.Object,kotlin.jvm.functions.Function2):java.lang.Object

View File

@@ -34,6 +34,7 @@ import kotlin.coroutines.CoroutineContext
* installThreadContext(initialContext + childElement) {
* // before the execution of a scheduled runnable,
* // the created element performs computations
* childElement.beforeChildStarted(currentThreadContext())
* try {
* runSomething()
* } finally {
@@ -46,6 +47,19 @@ import kotlin.coroutines.CoroutineContext
* }
* ```
*
* If `queueAsyncActivity` gets canceled for some reason, then [childCanceled] will be called, i.e:
* ```kotlin
* withContext(myIntelliJElement) {
* val initialContext = currentThreadContext()
* // the creation of a child context happens during the queueing
* val childElement = myIntelliJElement.produceChildElement(initialContext, ...)
* platformScheduler.queueAsyncActivity {
* // no `beforeChildStarted` is called here.
* childElement.childCanceled(currentThreadContext())
* }
* }
* ```
*
* ## Structured propagation
*
* Sometimes it is known that the parent process lives strictly longer than the child computation.
@@ -107,7 +121,8 @@ interface IntelliJContextElement : CoroutineContext.Element {
}
/**
* Called before the child computation is started
* Called before the child computation is started.
* The platform maintains an invariant that **only one** of [beforeChildStarted] and [childCanceled] will be called.
*
* @param context the context of the executing computation
*/
@@ -115,8 +130,15 @@ interface IntelliJContextElement : CoroutineContext.Element {
/**
* Called when the child computation ends its execution.
* [afterChildCompleted] will be called if there was a preceding [beforeChildStarted].
*
* @param context the context of the executing computation
*/
fun afterChildCompleted(context: CoroutineContext) {}
/**
* Called when the child computation was canceled without any attempt to execute it.
* The platform maintains an invariant that **only one** of [beforeChildStarted] and [childCanceled] will be called.
*/
fun childCanceled(context: CoroutineContext) {}
}

View File

@@ -3,9 +3,9 @@ package com.intellij.util.concurrency
import kotlinx.coroutines.InternalCoroutinesApi
import kotlinx.coroutines.Job
import java.util.concurrent.Callable
import java.util.concurrent.CancellationException
import java.util.concurrent.FutureTask
import java.util.concurrent.atomic.AtomicBoolean
/**
* A FutureTask, which cancels the given job when it's cancelled.
@@ -13,7 +13,9 @@ import java.util.concurrent.FutureTask
@OptIn(InternalCoroutinesApi::class)
internal class CancellationFutureTask<V>(
private val job: Job,
callable: Callable<V>,
callable: ContextCallable<V>,
val executionTracker: AtomicBoolean,
val context: ChildContext,
) : FutureTask<V>(callable) {
init {
@@ -27,8 +29,12 @@ internal class CancellationFutureTask<V>(
}
override fun cancel(mayInterruptIfRunning: Boolean): Boolean {
val isCurrentlyRunning = executionTracker.getAndSet(true)
val result = super.cancel(mayInterruptIfRunning)
job.cancel(null)
if (!isCurrentlyRunning) {
context.cancelAllIntelliJElements()
}
return result
}
}

View File

@@ -45,9 +45,7 @@ final class CancellationScheduledFutureTask<V> extends SchedulingWrapper.MySched
myJob.cancel(null);
}
if (!myExecutionTracker.getAndSet(true)) {
// todo: do we really need to trigger beforeChildStarted here?
//noinspection resource
myChildContext.applyContextActions(false).finish();
myChildContext.cancelAllIntelliJElements();
}
return result;
}

View File

@@ -4,11 +4,13 @@ package com.intellij.util.concurrency;
import com.intellij.concurrency.ThreadContext;
import com.intellij.openapi.application.AccessToken;
import kotlin.Unit;
import com.intellij.openapi.progress.ProcessCanceledException;
import kotlin.coroutines.Continuation;
import org.jetbrains.annotations.Async;
import org.jetbrains.annotations.NotNull;
import java.util.concurrent.Callable;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;
final class ContextCallable<V> implements Callable<V> {
@@ -19,6 +21,7 @@ final class ContextCallable<V> implements Callable<V> {
private final boolean myRoot;
private final @NotNull ChildContext myChildContext;
private final @NotNull Callable<? extends V> myCallable;
private final @NotNull AtomicBoolean myTracker;
static class RunResult<V, E extends Exception> {
Object result;
@@ -46,15 +49,23 @@ final class ContextCallable<V> implements Callable<V> {
}
@Async.Schedule
ContextCallable(boolean root, @NotNull ChildContext context, @NotNull Callable<? extends V> callable) {
ContextCallable(boolean root,
@NotNull ChildContext context,
@NotNull Callable<? extends V> callable,
@NotNull AtomicBoolean cancellationTracker) {
myRoot = root;
myChildContext = context;
myCallable = callable;
myTracker = cancellationTracker;
}
@Async.Execute
@Override
public V call() throws Exception {
if (myTracker.getAndSet(true)) {
// todo: add a cause of cancellation here as a suppressed runnable?
throw new ProcessCanceledException();
}
RunResult<V, Exception> result;
if (myRoot) {
result = myChildContext.runInChildContext(true, () -> {

View File

@@ -127,12 +127,18 @@ data class ChildContext internal constructor(
return object : AccessToken() {
override fun finish() {
installToken.finish()
for (elem in ijElements.reversed()) {
elem.afterChildCompleted(context)
ijElements.reversed().forEachGuaranteed {
it.afterChildCompleted(context)
}
}
}
}
fun cancelAllIntelliJElements() {
ijElements.forEachGuaranteed {
it.childCanceled(context)
}
}
}
@Internal
@@ -357,7 +363,7 @@ private fun <T> cleanupIfExpired(expiredCondition: Condition<in T>, childContext
if (expired) {
// Cancel to avoid a hanging child job which will prevent completion of the parent one.
childJob?.cancel(null)
childContext.applyContextActions(false).finish()
childContext.cancelAllIntelliJElements()
true
}
else {
@@ -373,11 +379,12 @@ internal fun <V> capturePropagationContext(c: Callable<V>): FutureTask<V> {
}
val callable = captureClientIdInCallable(c)
val childContext = createChildContext(c.toString())
val wrappedCallable = ContextCallable(false, childContext, callable)
val executionTracker = AtomicBoolean(false)
val wrappedCallable = ContextCallable(false, childContext, callable, executionTracker)
val cont = childContext.continuation
if (cont != null) {
val childJob = cont.context.job
return CancellationFutureTask(childJob, wrappedCallable)
return CancellationFutureTask(childJob, wrappedCallable, executionTracker, childContext)
}
else {
return FutureTask(wrappedCallable)
@@ -398,10 +405,7 @@ internal fun <V> capturePropagationContext(wrapper: SchedulingWrapper, c: Callab
val callable = captureClientIdInCallable(c)
val childContext = createChildContext("$c (scheduled: $ns)")
val cancellationTracker = AtomicBoolean(false)
val wrappedCallable = ContextCallable(false, childContext, Callable<V> {
cancellationTracker.takeOrThrowCancellationException()
callable.call()
})
val wrappedCallable = ContextCallable(false, childContext, callable, cancellationTracker)
val cont = childContext.continuation
return CancellationScheduledFutureTask(wrapper, childContext, cont?.context?.job, cancellationTracker, wrappedCallable, ns)
@@ -435,12 +439,6 @@ internal fun capturePropagationContext(
return CancellationScheduledFutureTask<Void>(wrapper, childContext, job, finalCapturedRunnable, ns, period)
}
private fun AtomicBoolean.takeOrThrowCancellationException() {
if (getAndSet(true)) {
throw ProcessCanceledException()
}
}
@ApiStatus.Internal
fun contextAwareCallable(r: Runnable): Callable<*> = ContextAwareCallable {
r.run()