/*
* SPDX-FileCopyrightText: 2021-2023 The Refinery Authors
*
* SPDX-License-Identifier: EPL-2.0
*/
package tools.refinery.language.web;
import org.eclipse.jetty.ee10.servlet.ServletContextHandler;
import org.eclipse.jetty.ee10.servlet.ServletHolder;
import org.eclipse.jetty.ee10.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.HttpStatus;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.util.thread.QueuedThreadPool;
import org.eclipse.jetty.websocket.api.Callback;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.StatusCode;
import org.eclipse.jetty.websocket.api.annotations.WebSocket;
import org.eclipse.jetty.websocket.api.exceptions.UpgradeException;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.xtext.testing.GlobalRegistries;
import org.eclipse.xtext.testing.GlobalRegistries.GlobalStateMemento;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInfo;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import tools.refinery.language.web.tests.WebSocketIntegrationTestClient;
import tools.refinery.language.web.xtext.servlet.XtextStatusCode;
import tools.refinery.language.web.xtext.servlet.XtextWebSocketServlet;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.URI;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.*;
import static org.junit.jupiter.api.Assertions.assertThrows;
class ProblemWebSocketServletIntegrationTest {
/** Loopback address that the test server binds to and that the client connects to. */
private static final String HOSTNAME = "127.0.0.1";
/** Path under which the web socket servlet is registered on the test server. */
private static final String SERVLET_URI = "/xtext-service";
/** Copy of the global registry state taken before each test and restored after it. */
private GlobalStateMemento stateBeforeInjectorCreation;
/** JUnit metadata of the running test; its display name is used to name the server thread pool. */
private TestInfo testInfo;
/** Port probed before each test; the server listens on it and the client connects to it. */
private int serverPort;
/** Jetty server of the current test; remains {@code null} until a test starts one. */
private Server server;
/** Web socket client that is started before and stopped after each test. */
private WebSocketClient client;
/**
 * Prepares a test: remembers the test metadata, reserves a port number for the server,
 * snapshots the global registry state and starts the web socket client.
 *
 * @param testInfo metadata of the test that is about to run
 * @throws Exception if no port could be probed or the client could not be started
 */
@BeforeEach
void beforeEach(TestInfo testInfo) throws Exception {
    this.testInfo = testInfo;
    serverPort = findFreePort();
    stateBeforeInjectorCreation = GlobalRegistries.makeCopyOfGlobalState();
    client = new WebSocketClient();
    client.start();
}

/**
 * Asks the operating system for a currently unused port on {@link #HOSTNAME}.
 * See e.g., https://stackoverflow.com/a/65937797
 *
 * @return the port number that the probing socket was bound to
 * @throws IOException if the probing socket could not be configured, bound or closed
 */
private static int findFreePort() throws IOException {
    try (var probe = new ServerSocket()) {
        probe.setReuseAddress(true);
        // Binding to port 0 lets the operating system choose the port.
        probe.bind(new InetSocketAddress(HOSTNAME, 0));
        return probe.getLocalPort();
    }
}
/**
 * Tears down a test: stops the client, stops the server if one was started, and restores
 * the global registry state.
 * <p>
 * Every step runs even if an earlier one throws, and steps whose resource was never created
 * (because set-up failed part-way) are skipped, so a tear-down failure cannot leak the server
 * or the global state into later tests or hide the original set-up error behind a
 * {@link NullPointerException}.
 *
 * @throws Exception if stopping the client or the server fails
 */
@AfterEach
void afterEach() throws Exception {
    try {
        if (client != null) {
            client.stop();
        }
    } finally {
        client = null;
        try {
            if (server != null) {
                server.stop();
            }
        } finally {
            server = null;
            if (stateBeforeInjectorCreation != null) {
                stateBeforeInjectorCreation.restoreGlobalState();
                stateBeforeInjectorCreation = null;
            }
        }
    }
}
/**
 * Sends a full-text update followed by a delta update and checks the five messages
 * that the server sends back, as well as the negotiated subprotocol and the close status.
 */
@Test
void updateTest() {
    startServer(null);
    var clientSocket = new UpdateTestClient();
    // Scope the session like the other tests do, so it is closed even if an assertion fails
    // before the client gets to close it.
    try (var session = connect(clientSocket, null, XtextWebSocketServlet.XTEXT_SUBPROTOCOL_V1)) {
        assertThat(session.getUpgradeResponse().getAcceptedSubProtocol(),
                equalTo(XtextWebSocketServlet.XTEXT_SUBPROTOCOL_V1));
        clientSocket.waitForTestResult();
        assertThat(clientSocket.getCloseStatusCode(), equalTo(StatusCode.NORMAL));
    }
    var responses = clientSocket.getResponses();
    assertThat(responses, hasSize(5));
    // Response to the full-text update, followed by the pushed highlighting and validation results.
    assertThat(responses.get(0), equalTo("{\"id\":\"foo\",\"response\":{\"stateId\":\"-80000000\"}}"));
    assertThat(responses.get(1), startsWith(
            "{\"resource\":\"test.problem\",\"stateId\":\"-80000000\",\"service\":\"highlight\"," +
                    "\"push\":{\"regions\":["));
    assertThat(responses.get(2), equalTo(
            "{\"resource\":\"test.problem\",\"stateId\":\"-80000000\",\"service\":\"validate\"," +
                    "\"push\":{\"issues\":[]}}"));
    // Response to the delta update, followed by the pushed highlighting for the new state.
    assertThat(responses.get(3), equalTo("{\"id\":\"bar\",\"response\":{\"stateId\":\"-7fffffff\"}}"));
    assertThat(responses.get(4), startsWith(
            "{\"resource\":\"test.problem\",\"stateId\":\"-7fffffff\",\"service\":\"highlight\"," +
                    "\"push\":{\"regions\":["));
}
/**
 * Test client that drives the update scenario: it sends a full-text update first, a delta
 * update once three responses have been received, and closes the session after five responses.
 */
@WebSocket
public static class UpdateTestClient extends WebSocketIntegrationTestClient {
    /** Request {@code foo}: sets the full text of {@code test.problem}. */
    private static final String FULL_TEXT_UPDATE_REQUEST =
            "{\"id\":\"foo\",\"request\":{\"resource\":\"test.problem\",\"serviceType\":\"update\"," +
                    "\"fullText\":\"class Person.\n\"}}";

    /** Request {@code bar}: inserts text at offset 0 on top of state {@code -80000000}. */
    //noinspection TextBlockMigration
    private static final String DELTA_UPDATE_REQUEST =
            "{\"id\":\"bar\",\"request\":{\"resource\":\"test.problem\",\"serviceType\":\"update\"," +
                    "\"requiredStateId\":\"-80000000\",\"deltaText\":\"indiv q.\nnode(q).\n\"," +
                    "\"deltaOffset\":\"0\",\"deltaReplaceLength\":\"0\"}}";

    /**
     * Performs the next step of the scenario for the given number of received responses.
     * Any count other than 0, 3 or 5 results in no action.
     *
     * @param session the open web socket session
     * @param responsesReceived number of responses received so far
     */
    @Override
    protected void arrange(Session session, int responsesReceived) {
        if (responsesReceived == 0) {
            session.sendText(FULL_TEXT_UPDATE_REQUEST, Callback.NOOP);
        } else if (responsesReceived == 3) {
            session.sendText(DELTA_UPDATE_REQUEST, Callback.NOOP);
        } else if (responsesReceived == 5) {
            session.close();
        }
    }
}
/**
 * Connects while offering only the empty string as subprotocol and checks that the server
 * accepts no subprotocol and that the connection is still closed normally.
 */
@Test
void badSubProtocolTest() {
    startServer(null);
    var clientSocket = new CloseImmediatelyTestClient();
    // Scope the session like the other tests do, so it is closed even if an assertion fails.
    try (var session = connect(clientSocket, null, "")) {
        assertThat(session.getUpgradeResponse().getAcceptedSubProtocol(), nullValue());
        clientSocket.waitForTestResult();
        assertThat(clientSocket.getCloseStatusCode(), equalTo(StatusCode.NORMAL));
    }
}
/**
 * Test client that sends nothing and closes the session as soon as it is asked to act.
 */
@WebSocket
public static class CloseImmediatelyTestClient extends WebSocketIntegrationTestClient {
    /**
     * Closes the session regardless of how many responses were received.
     *
     * @param session the open web socket session
     * @param ignoredResponseCount number of responses received so far; not used
     */
    @Override
    protected void arrange(Session session, int ignoredResponseCount) {
        session.close();
    }
}
/**
 * Offers the empty string and the Xtext v1 subprotocol and checks that the server picks
 * the Xtext v1 subprotocol and that the connection is closed normally.
 */
@Test
void subProtocolNegotiationTest() {
    startServer(null);
    var closingClient = new CloseImmediatelyTestClient();
    try (var session = connect(closingClient, null, "", XtextWebSocketServlet.XTEXT_SUBPROTOCOL_V1)) {
        var acceptedSubProtocol = session.getUpgradeResponse().getAcceptedSubProtocol();
        assertThat(acceptedSubProtocol, equalTo(XtextWebSocketServlet.XTEXT_SUBPROTOCOL_V1));

        closingClient.waitForTestResult();
        var closeStatus = closingClient.getCloseStatusCode();
        assertThat(closeStatus, equalTo(StatusCode.NORMAL));
    }
}
/**
 * Sends a message that is not valid JSON and checks that the connection is closed
 * with the {@code INVALID_JSON} status code.
 */
@Test
void invalidJsonTest() {
    startServer(null);
    var invalidJsonClient = new InvalidJsonTestClient();
    try (var ignored = connect(invalidJsonClient, null, XtextWebSocketServlet.XTEXT_SUBPROTOCOL_V1)) {
        invalidJsonClient.waitForTestResult();
        var closeStatus = invalidJsonClient.getCloseStatusCode();
        assertThat(closeStatus, equalTo(XtextStatusCode.INVALID_JSON));
    }
}
/**
 * Test client that sends an empty text message, which is not a valid JSON document.
 */
@WebSocket
public static class InvalidJsonTestClient extends WebSocketIntegrationTestClient {
    /**
     * Sends the empty message regardless of how many responses were received.
     *
     * @param session the open web socket session
     * @param ignoredResponseCount number of responses received so far; not used
     */
    @Override
    protected void arrange(Session session, int ignoredResponseCount) {
        var emptyMessage = "";
        session.sendText(emptyMessage, Callback.NOOP);
    }
}
/**
 * Checks that a connection whose {@code Origin} header is on the allow list (with or
 * without explicit port, in any letter case) is accepted and closed normally.
 *
 * @param origin value of the {@code Origin} header sent with the upgrade request
 */
@ParameterizedTest(name = "validOriginTest(\"{0}\")")
@ValueSource(strings = {"https://refinery.example", "https://refinery.example:443", "HTTPS://REFINERY.EXAMPLE"})
void validOriginTest(String origin) {
    var allowedOrigins = "https://refinery.example,https://refinery.example:443";
    startServer(allowedOrigins);
    var closingClient = new CloseImmediatelyTestClient();
    try (var ignored = connect(closingClient, origin, XtextWebSocketServlet.XTEXT_SUBPROTOCOL_V1)) {
        closingClient.waitForTestResult();
        var closeStatus = closingClient.getCloseStatusCode();
        assertThat(closeStatus, equalTo(StatusCode.NORMAL));
    }
}
/**
 * Checks that a connection whose {@code Origin} header is not on the allow list is
 * rejected during the upgrade with HTTP status 403.
 */
@Test
void invalidOriginTest() {
    var allowedOrigins = "https://refinery.example,https://refinery.example:443";
    startServer(allowedOrigins);
    var closingClient = new CloseImmediatelyTestClient();
    // The close call also has to live inside the lambda, so that the session is always closed
    // should the connection attempt unexpectedly succeed.
    @SuppressWarnings("squid:S5778")
    var completionException = assertThrows(CompletionException.class,
            () -> connect(closingClient, "https://invalid.example",
                    XtextWebSocketServlet.XTEXT_SUBPROTOCOL_V1).close());
    var cause = completionException.getCause();
    assertThat(cause, instanceOf(UpgradeException.class));
    var upgradeException = (UpgradeException) cause;
    assertThat(upgradeException.getResponseStatusCode(), equalTo(HttpStatus.FORBIDDEN_403));
}
/**
 * Starts a Jetty server on {@link #HOSTNAME} and {@link #serverPort} that serves
 * {@link ProblemWebSocketServlet} under {@link #SERVLET_URI}.
 *
 * @param allowedOrigins value for the servlet's allowed origins init parameter, or
 *                       {@code null} to leave the parameter unset
 * @throws RuntimeException if the server fails to start; the original failure is kept as the cause
 */
private void startServer(String allowedOrigins) {
    var testName = getClass().getSimpleName() + "-" + testInfo.getDisplayName();
    var listenAddress = new InetSocketAddress(HOSTNAME, serverPort);
    server = new Server(listenAddress);
    // Name the pool after the test so that its threads can be told apart.
    ((QueuedThreadPool) server.getThreadPool()).setName(testName);
    var handler = new ServletContextHandler();
    var holder = new ServletHolder(ProblemWebSocketServlet.class);
    if (allowedOrigins != null) {
        holder.setInitParameter(ProblemWebSocketServlet.ALLOWED_ORIGINS_INIT_PARAM, allowedOrigins);
    }
    handler.addServlet(holder, SERVLET_URI);
    JettyWebSocketServletContainerInitializer.configure(handler, null);
    server.setHandler(handler);
    try {
        server.start();
    } catch (Exception e) {
        // Keep the cause: without it the reason for the failed start (e.g. a port clash) is lost.
        throw new RuntimeException("Failed to start websocket server", e);
    }
}
/**
 * Opens a web socket connection to the test server and blocks until the upgrade has finished.
 *
 * @param webSocketClient the endpoint object that handles the web socket events
 * @param origin value of the {@code Origin} header to send, or {@code null} to send none
 * @param subProtocols subprotocols to offer to the server, in order
 * @return the established session
 * @throws java.util.concurrent.CompletionException if the connection attempt fails, e.g. because
 *         the server rejects the upgrade; the failure is available as the cause
 * @throws AssertionError if the connection attempt could not even be initiated
 */
private Session connect(Object webSocketClient, String origin, String... subProtocols) {
    var upgradeRequest = new ClientUpgradeRequest();
    if (origin != null) {
        upgradeRequest.setHeader(HttpHeader.ORIGIN.name(), origin);
    }
    upgradeRequest.setSubProtocols(subProtocols);
    // Must be parameterized: with the raw type, join() returns Object instead of Session.
    CompletableFuture<Session> sessionFuture;
    try {
        sessionFuture = client.connect(webSocketClient,
                URI.create("ws://%s:%d%s".formatted(HOSTNAME, serverPort, SERVLET_URI)),
                upgradeRequest);
    } catch (IOException e) {
        throw new AssertionError("Unexpected exception while connection to websocket", e);
    }
    return sessionFuture.join();
}
}