[maven][wsl] allow retrying sockets creation when communicating with WSL IDEA-340867

WSL mirrored network has a delay between socket opening in Linux and socket opening in Windows host.

GitOrigin-RevId: 1f5c1f0ad5329e989883ff732d6c480153085470
This commit is contained in:
Nikita.Skvortsov
2024-02-28 16:56:21 +01:00
committed by intellij-monorepo-bot
parent 456a327474
commit dc5b09b987
4 changed files with 85 additions and 40 deletions

View File

@@ -1,4 +1,4 @@
// Copyright 2000-2021 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.execution.rmi;
import com.intellij.execution.*;
@@ -31,6 +31,8 @@ import java.rmi.Remote;
import java.rmi.RemoteException;
import java.rmi.registry.LocateRegistry;
import java.rmi.registry.Registry;
import java.rmi.server.RMIClientSocketFactory;
import java.rmi.server.RMISocketFactory;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Future;
@@ -329,7 +331,7 @@ public abstract class RemoteProcessSupport<Target, EntryPoint, Parameters> {
private EntryPoint acquire(final RunningInfo info) throws Exception {
EntryPoint result = RemoteUtil.executeWithClassLoader(() -> {
Registry registry = LocateRegistry.getRegistry(info.host, info.port);
Registry registry = LocateRegistry.getRegistry(info.host, info.port, getClientSocketFactory());
Remote remote = Objects.requireNonNull(registry.lookup(info.name));
if (myValueClass.isInstance(remote)) {
@@ -345,6 +347,16 @@ public abstract class RemoteProcessSupport<Target, EntryPoint, Parameters> {
return result;
}
/**
* Override this method to use custom client socket factory.
*
* Default implementation returns null and uses {@link RMISocketFactory#getSocketFactory()}
* @return client socket factory to be used by this remote process support.
*/
protected RMIClientSocketFactory getClientSocketFactory() {
return null;
}
private ProcessListener getProcessListener(@NotNull final Pair<Target, Parameters> key) {
return new ProcessListener() {
@Override

View File

@@ -0,0 +1,42 @@
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.openapi.externalSystem.util.wsl
import java.net.ConnectException
/**
* Connects to a remote server with retrying mechanism.
*
* If the [action] causes [ConnectException] it will be retried after [step] milliseconds.
* If all retries fail after [timeoutMillis] , the first [ConnectException] will be rethrown
*
* @param timeoutMillis The maximum timeout in milliseconds to wait for a successful connection.
* @param step The interval in milliseconds between retries. Default value is 100 milliseconds.
* @param action The action to connect to the remote server. Must not return null in case of success
* @return The result of the action.
* @throws ConnectException if unable to connect to the server within the specified timeout.
* @throws java.lang.RuntimeException if an unexpected failure occurs while connecting to the remote server.
*/
fun <T> connectRetrying(timeoutMillis: Long, step: Long = 100, action: () -> T): T {
val start = System.currentTimeMillis()
var result: T?
var lastException: Exception? = null
do {
result = try {
action()
}
catch (e: ConnectException) {
lastException = e
Thread.sleep(step)
null
}
} while (result == null && (System.currentTimeMillis() - start < timeoutMillis))
if (result == null) {
if (lastException != null) {
throw lastException
} else {
throw RuntimeException("Unexpected failure while connecting to remote server")
}
}
return result
}

View File

@@ -12,6 +12,7 @@ import com.intellij.openapi.diagnostic.trace
import com.intellij.openapi.externalSystem.model.task.ExternalSystemTaskId
import com.intellij.openapi.externalSystem.model.task.ExternalSystemTaskNotificationListener
import com.intellij.openapi.externalSystem.service.remote.MultiLoaderObjectInputStream
import com.intellij.openapi.externalSystem.util.wsl.connectRetrying
import com.intellij.openapi.progress.EmptyProgressIndicator
import com.intellij.openapi.project.Project
import com.intellij.openapi.util.Key
@@ -20,8 +21,6 @@ import com.intellij.util.PlatformUtils
import com.intellij.util.io.BaseOutputReader
import com.intellij.util.text.nullize
import org.gradle.initialization.BuildEventConsumer
import org.gradle.internal.remote.internal.ConnectCompletion
import org.gradle.internal.remote.internal.ConnectException
import org.gradle.internal.remote.internal.RemoteConnection
import org.gradle.internal.remote.internal.inet.SocketInetAddress
import org.gradle.internal.remote.internal.inet.TcpOutgoingConnector
@@ -223,41 +222,6 @@ internal class GradleServerRunner(private val connection: TargetProjectConnectio
}
}
/**
* Connects to a remote server with retrying mechanism.
*
* @param timeoutMillis The maximum timeout in milliseconds to wait for a successful connection.
* @param step The interval in milliseconds between retries. Default value is 100 milliseconds.
* @param action The function to execute for connecting to the remote server.
* @return The result of the connection.
* @throws ConnectException if unable to connect to the server within the specified timeout.
* @throws java.lang.RuntimeException if an unexpected failure occurs while connecting to the remote server.
*/
private fun connectRetrying(timeoutMillis: Long, step: Long = 100, action: () -> ConnectCompletion): ConnectCompletion {
val start = System.currentTimeMillis()
var result: ConnectCompletion?
var lastException: Exception? = null
do {
result = try {
action()
}
catch (e: ConnectException) {
lastException = e
Thread.sleep(step)
null
}
} while (result == null && (System.currentTimeMillis() - start < timeoutMillis))
if (result == null) {
if (lastException != null) {
throw lastException
} else {
throw RuntimeException("Unexpected failure while connecting to remote server")
}
}
return result
}
private fun deserializeIfNeeded(value: Any?): Any? {
val bytes = value as? ByteArray ?: return value
val deserialized = MultiLoaderObjectInputStream(ByteArrayInputStream(bytes), classpathInferer.getClassloaders()).use {

View File

@@ -1,13 +1,19 @@
// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package org.jetbrains.idea.maven.server.wsl
import com.intellij.execution.Executor
import com.intellij.execution.configurations.RunProfileState
import com.intellij.execution.wsl.WSLDistribution
import com.intellij.openapi.externalSystem.util.wsl.connectRetrying
import com.intellij.openapi.project.Project
import com.intellij.openapi.projectRoots.Sdk
import org.jetbrains.idea.maven.server.AbstractMavenServerRemoteProcessSupport
import org.jetbrains.idea.maven.server.WslMavenDistribution
import java.io.IOException
import java.net.ServerSocket
import java.net.Socket
import java.rmi.server.RMIClientSocketFactory
import java.rmi.server.RMISocketFactory
internal class WslMavenServerRemoteProcessSupport(private val myWslDistribution: WSLDistribution,
jdk: Sdk,
@@ -17,13 +23,34 @@ internal class WslMavenServerRemoteProcessSupport(private val myWslDistribution:
debugPort: Int?) : AbstractMavenServerRemoteProcessSupport(jdk, vmOptions,
mavenDistribution,
project, debugPort) {
override fun getRunProfileState(target: Any, configuration: Any, executor: Executor): RunProfileState {
return WslMavenCmdState(myWslDistribution, myJdk, myOptions, myDistribution as WslMavenDistribution, myDebugPort, myProject, remoteHost)
}
override fun getRemoteHost(): String = myWslDistribution.wslIpAddress.hostAddress
override fun getClientSocketFactory(): RMIClientSocketFactory {
val delegate = RMISocketFactory.getSocketFactory() ?: RMISocketFactory.getDefaultSocketFactory()
return RetryingSocketFactory(delegate)
}
override fun type() = "WSL"
}
/**
* This factory will retry sockets creation.
*
* WSL mirrored network has a visible delay (hundreds of ms) in ports becoming available on host machine. So we have to retry a couple of times.
*/
class RetryingSocketFactory(val delegate: RMISocketFactory) : RMISocketFactory() {
@Throws(IOException::class)
override fun createSocket(host: String, port: Int): Socket {
return connectRetrying(3000) { delegate.createSocket(host, port) }
}
@Throws(IOException::class)
override fun createServerSocket(port: Int): ServerSocket {
return delegate.createServerSocket(port)
}
}