//
//  ========================================================================
//  Copyright (c) 1995-2019 Mort Bay Consulting Pty. Ltd.
//  ------------------------------------------------------------------------
//  All rights reserved. This program and the accompanying materials
//  are made available under the terms of the Eclipse Public License v1.0
//  and Apache License v2.0 which accompanies this distribution.
//
//      The Eclipse Public License is available at
//      http://www.eclipse.org/legal/epl-v10.html
//
//      The Apache License v2.0 is available at
//      http://www.opensource.org/licenses/apache2.0.php
//
//  You may elect to redistribute this code under either of these licenses.
//  ========================================================================
//

package org.eclipse.jetty.websocket.tests.client;


import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import org.eclipse.jetty.io.EndPoint;
import org.eclipse.jetty.io.EofException;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.server.handler.DefaultHandler;
import org.eclipse.jetty.server.handler.HandlerList;
import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.servlet.ServletHolder;
import org.eclipse.jetty.util.BlockingArrayQueue;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.websocket.api.CloseException;
import org.eclipse.jetty.websocket.api.MessageTooLargeException;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.StatusCode;
import org.eclipse.jetty.websocket.api.WebSocketFrameListener;
import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.eclipse.jetty.websocket.api.extensions.Frame;
import org.eclipse.jetty.websocket.api.util.WSURI;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.common.CloseInfo;
import org.eclipse.jetty.websocket.common.OpCode;
import org.eclipse.jetty.websocket.servlet.WebSocketServlet;
import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory;
import org.eclipse.jetty.websocket.tests.CloseTrackingEndpoint;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import static java.nio.charset.StandardCharsets.UTF_8;
import static java.time.Duration.ofSeconds;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
import static org.junit.jupiter.api.Assertions.assertTrue;

/**
 * Close-handshake and disconnect tests for {@link WebSocketClient}.
 * <p>
 * Each test connects a {@link CloseTrackingEndpoint} to a {@link ServerEndpoint} served at {@code /ws}
 * and then drives one close scenario. The server endpoint reacts to special text messages and to
 * special CLOSE-frame reasons (see {@link ServerEndpoint}) so that tests can simulate a misbehaving peer.
 */
public class ClientCloseTest
{
    private Server server;
    private WebSocketClient client;
    // Every ServerEndpoint created by the servlet's creator, queued in the order they were offered.
    private final BlockingArrayQueue<ServerEndpoint> serverEndpoints = new BlockingArrayQueue<>();

    /**
     * Waits for the client connection and proves it is usable with an echo round-trip.
     * <p>
     * The client's message queue is cleared afterwards, so subsequent assertions only see new messages.
     *
     * @param clientSocket the client endpoint that was passed to {@code client.connect(...)}
     * @param clientFuture the future returned by {@code client.connect(...)}
     * @return the connected client session
     * @throws Exception if connecting, sending or receiving fails or times out
     */
    private Session confirmConnection(CloseTrackingEndpoint clientSocket, Future<Session> clientFuture) throws Exception
    {
        // Wait for client connect on via future
        Session session = clientFuture.get(30, SECONDS);

        try
        {
            // Send message from client to server
            final String echoMsg = "echo-test";
            Future<Void> testFut = clientSocket.getRemote().sendStringByFuture(echoMsg);

            // Wait for send future
            testFut.get(5, SECONDS);

            // Verify received message
            String recvMsg = clientSocket.messageQueue.poll(5, SECONDS);
            assertThat("Received message", recvMsg, is(echoMsg));

            // Verify that there are no errors
            assertThat("Error events", clientSocket.error.get(), nullValue());
        }
        finally
        {
            clientSocket.clearQueues();
        }

        return session;
    }

    /**
     * Starts a client limited to 1024-byte text messages.
     * This limit is what makes the server's oversized reply in {@link #testMessageTooLargeException()} fail.
     */
    @BeforeEach
    public void startClient() throws Exception
    {
        client = new WebSocketClient();
        client.setMaxTextMessageBufferSize(1024);
        client.getPolicy().setMaxTextMessageSize(1024);
        client.start();
    }

    /**
     * Starts a server on an ephemeral port.
     * It serves a {@link ServerEndpoint} at {@code /ws} and records each created endpoint in {@link #serverEndpoints}.
     */
    @BeforeEach
    public void startServer() throws Exception
    {
        server = new Server();

        ServerConnector connector = new ServerConnector(server);
        connector.setPort(0);
        server.addConnector(connector);

        ServletContextHandler context = new ServletContextHandler();
        context.setContextPath("/");
        ServletHolder holder = new ServletHolder(new WebSocketServlet()
        {
            @Override
            public void configure(WebSocketServletFactory factory)
            {
                factory.getPolicy().setIdleTimeout(10000);
                factory.getPolicy().setMaxTextMessageSize(1024 * 1024 * 2);
                factory.setCreator((req,resp)->
                {
                    ServerEndpoint endpoint = new ServerEndpoint();
                    serverEndpoints.offer(endpoint);
                    return endpoint;
                });
            }
        });
        context.addServlet(holder, "/ws");

        HandlerList handlers = new HandlerList();
        handlers.addHandler(context);
        handlers.addHandler(new DefaultHandler());
        server.setHandler(handlers);

        server.start();
    }

    @AfterEach
    public void stopClient() throws Exception
    {
        client.stop();
    }

    @AfterEach
    public void stopServer() throws Exception
    {
        server.stop();
    }

    /**
     * The client initiates a close, and the server sends two more text messages after receiving it.
     * The client must still deliver those messages and then see a NORMAL close.
     */
    @Test
    public void testHalfClose() throws Exception
    {
        // Set client timeout
        final int timeout = 5000;
        client.setMaxIdleTimeout(timeout);

        ClientOpenSessionTracker clientSessionTracker = new ClientOpenSessionTracker(1);
        clientSessionTracker.addTo(client);

        // Client connects
        URI wsUri = WSURI.toWebsocket(server.getURI().resolve("/ws"));
        CloseTrackingEndpoint clientSocket = new CloseTrackingEndpoint();
        Future<Session> clientConnectFuture = client.connect(clientSocket, wsUri);

        try (Session session = confirmConnection(clientSocket, clientConnectFuture))
        {
            // client confirms connection via echo

            // client sends close frame (code 1000, normal)
            final String origCloseReason = "send-more-frames";
            clientSocket.getSession().close(StatusCode.NORMAL, origCloseReason);

            // Verify received messages
            String recvMsg = clientSocket.messageQueue.poll(5, SECONDS);
            assertThat("Received message 1", recvMsg, is("Hello"));
            recvMsg = clientSocket.messageQueue.poll(5, SECONDS);
            assertThat("Received message 2", recvMsg, is("World"));

            // Verify that there are no errors
            assertThat("Error events", clientSocket.error.get(), nullValue());

            // client close event on ws-endpoint; containsString("") accepts any non-null reason
            clientSocket.assertReceivedCloseEvent(timeout, is(StatusCode.NORMAL), containsString(""));
        }

        clientSessionTracker.assertClosedProperly(client);
    }

    /**
     * The server answers with a 1 MiB text message, which exceeds the client's 1024-byte limit.
     * The client must report a {@link MessageTooLargeException} and close with MESSAGE_TOO_LARGE.
     */
    @Test
    public void testMessageTooLargeException() throws Exception
    {
        // Set client timeout
        final int timeout = 3000;
        client.setMaxIdleTimeout(timeout);

        ClientOpenSessionTracker clientSessionTracker = new ClientOpenSessionTracker(1);
        clientSessionTracker.addTo(client);

        // Client connects
        URI wsUri = WSURI.toWebsocket(server.getURI().resolve("/ws"));
        CloseTrackingEndpoint clientSocket = new CloseTrackingEndpoint();
        Future<Session> clientConnectFuture = client.connect(clientSocket, wsUri);

        try (Session session = confirmConnection(clientSocket, clientConnectFuture))
        {
            // client confirms connection via echo

            session.getRemote().sendString("too-large-message");

            clientSocket.assertReceivedCloseEvent(timeout, is(StatusCode.MESSAGE_TOO_LARGE), containsString("exceeds maximum size"));

            // client should have noticed the error
            assertThat("OnError Latch", clientSocket.errorLatch.await(2, SECONDS), is(true));
            assertThat("OnError", clientSocket.error.get(), instanceOf(MessageTooLargeException.class));
        }

        // client triggers close event on client ws-endpoint
        clientSessionTracker.assertClosedProperly(client);
    }

    /**
     * The client initiates a close, and the server drops the connection instead of completing the handshake.
     * The client must report SHUTDOWN or ABNORMAL.
     */
    @Test
    public void testRemoteDisconnect() throws Exception
    {
        // Set client timeout
        final int clientTimeout = 1000;
        client.setMaxIdleTimeout(clientTimeout);

        ClientOpenSessionTracker clientSessionTracker = new ClientOpenSessionTracker(1);
        clientSessionTracker.addTo(client);

        // Client connects
        URI wsUri = WSURI.toWebsocket(server.getURI().resolve("/ws"));
        CloseTrackingEndpoint clientSocket = new CloseTrackingEndpoint();
        Future<Session> clientConnectFuture = client.connect(clientSocket, wsUri);

        try (Session ignored = confirmConnection(clientSocket, clientConnectFuture))
        {
            // client confirms connection via echo

            // client sends close frame (triggering server connection abort)
            final String origCloseReason = "abort";
            clientSocket.getSession().close(StatusCode.NORMAL, origCloseReason);

            // client reads -1 (EOF)
            // client triggers close event on client ws-endpoint
            // assert - close code==1006 (abnormal) or code==1001 (shutdown)
            clientSocket.assertReceivedCloseEvent(clientTimeout * 2, anyOf(is(StatusCode.SHUTDOWN), is(StatusCode.ABNORMAL)));
        }

        clientSessionTracker.assertClosedProperly(client);
    }

    /**
     * The client initiates a close, and the server stalls longer than the client's idle timeout.
     * The client must time out and report a {@link CloseException} caused by a {@link TimeoutException}.
     */
    @Test
    public void testServerNoCloseHandshake() throws Exception
    {
        // Set client timeout
        final int clientTimeout = 1000;
        client.setMaxIdleTimeout(clientTimeout);

        ClientOpenSessionTracker clientSessionTracker = new ClientOpenSessionTracker(1);
        clientSessionTracker.addTo(client);

        // Client connects
        URI wsUri = WSURI.toWebsocket(server.getURI().resolve("/ws"));
        CloseTrackingEndpoint clientSocket = new CloseTrackingEndpoint();
        Future<Session> clientConnectFuture = client.connect(clientSocket, wsUri);

        try (Session ignored = confirmConnection(clientSocket, clientConnectFuture))
        {
            // client confirms connection via echo

            // client sends close frame; the server sleeps 5000 ms on receipt, past the 1000 ms client idle timeout
            final String origCloseReason = "sleep|5000";
            clientSocket.getSession().close(StatusCode.NORMAL, origCloseReason);

            // client close should occur
            clientSocket.assertReceivedCloseEvent(clientTimeout * 2,
                    is(StatusCode.SHUTDOWN),
                    containsString("timeout"));

            // client idle timeout triggers close event on client ws-endpoint
            assertThat("OnError Latch", clientSocket.errorLatch.await(2, SECONDS), is(true));
            assertThat("OnError", clientSocket.error.get(), instanceOf(CloseException.class));
            assertThat("OnError.cause", clientSocket.error.get().getCause(), instanceOf(TimeoutException.class));
        }

        clientSessionTracker.assertClosedProperly(client);
    }

    /**
     * Stopping the client lifecycle with several open sessions must finish within 5 seconds.
     * Every session must then report an ABNORMAL "Disconnected" close.
     */
    @Test
    public void testStopLifecycle() throws Exception
    {
        // Set client timeout
        final int timeout = 1000;
        client.setMaxIdleTimeout(timeout);

        int sessionCount = 3;
        ClientOpenSessionTracker clientSessionTracker = new ClientOpenSessionTracker(sessionCount);
        clientSessionTracker.addTo(client);

        URI wsUri = WSURI.toWebsocket(server.getURI().resolve("/ws"));
        List<CloseTrackingEndpoint> clientSockets = new ArrayList<>();

        // Open Multiple Clients
        for (int i = 0; i < sessionCount; i++)
        {
            // Client Request Upgrade
            CloseTrackingEndpoint clientSocket = new CloseTrackingEndpoint();
            clientSockets.add(clientSocket);
            Future<Session> clientConnectFuture = client.connect(clientSocket, wsUri);

            // client confirms connection via echo
            confirmConnection(clientSocket, clientConnectFuture);
        }

        assertTimeoutPreemptively(ofSeconds(5), () -> {
            // client lifecycle stop (the meat of this test)
            client.stop();
        });

        // clients disconnect
        for (CloseTrackingEndpoint clientSocket : clientSockets)
        {
            clientSocket.assertReceivedCloseEvent(timeout, is(StatusCode.ABNORMAL), containsString("Disconnected"));
        }

        // ensure all Sessions are gone. connections are gone. etc. (client and server)
        // ensure ConnectionListener onClose is called 3 times
        clientSessionTracker.assertClosedProperly(client);
    }

    /**
     * The client's output is shut down before it sends its close frame, so the close-frame write fails.
     * The client must report an {@link EofException} and an ABNORMAL close.
     * <p>
     * The server endpoint is parked on its {@code block} latch so that it does not detect the failure first.
     * The latch is always released in {@code finally}.
     */
    @Test
    public void testWriteException() throws Exception
    {
        // Set client timeout
        final int timeout = 2000;
        client.setMaxIdleTimeout(timeout);

        ClientOpenSessionTracker clientSessionTracker = new ClientOpenSessionTracker(1);
        clientSessionTracker.addTo(client);

        // Client connects
        URI wsUri = WSURI.toWebsocket(server.getURI().resolve("/ws"));
        CloseTrackingEndpoint clientSocket = new CloseTrackingEndpoint();
        Future<Session> clientConnectFuture = client.connect(clientSocket, wsUri);

        // client confirms connection via echo
        confirmConnection(clientSocket, clientConnectFuture);

        try
        {
            // Block on the server so that the server does not detect a read failure
            clientSocket.getSession().getRemote().sendString("block");

            // setup client endpoint for write failure (test only)
            EndPoint endp = clientSocket.getEndPoint();
            endp.shutdownOutput();

            // client enqueue close frame
            // should result in a client write failure
            final String origCloseReason = "Normal Close from Client";
            clientSocket.getSession().close(StatusCode.NORMAL, origCloseReason);

            assertThat("OnError Latch", clientSocket.errorLatch.await(2, SECONDS), is(true));
            assertThat("OnError", clientSocket.error.get(), instanceOf(EofException.class));

            // client triggers close event on client ws-endpoint
            // assert - close code==1006 (abnormal)
            clientSocket.assertReceivedCloseEvent(timeout, is(StatusCode.ABNORMAL), null);
            clientSessionTracker.assertClosedProperly(client);

            assertThat(serverEndpoints.size(), is(1));
        }
        finally
        {
            // Always release the parked server thread(s), even if an assertion above failed.
            for (ServerEndpoint endpoint : serverEndpoints)
                endpoint.block.countDown();
        }
    }

    /**
     * Server-side test endpoint whose behaviour is driven by the client.
     * <p>
     * Text messages:
     * <ul>
     *   <li>{@code "too-large-message"}: replies with a 1 MiB text message of {@code 'x'} characters.</li>
     *   <li>{@code "block"}: parks the calling thread on {@link #block} for up to 5 minutes.</li>
     *   <li>anything else: echoed back unchanged.</li>
     * </ul>
     * CLOSE-frame reasons:
     * <ul>
     *   <li>{@code "send-more-frames"}: sends {@code "Hello"} and {@code "World"} after receiving the close.</li>
     *   <li>{@code "abort"}: sleeps 1 second, then calls {@link Session#disconnect()}.</li>
     *   <li>{@code "sleep|N"}: sleeps N milliseconds in the frame callback.</li>
     * </ul>
     */
    public static class ServerEndpoint implements WebSocketFrameListener, WebSocketListener
    {
        private static final Logger LOG = Log.getLogger(ServerEndpoint.class);
        private Session session;
        // Released by the test (see testWriteException) to unpark a thread blocked in the "block" message handler.
        final CountDownLatch block = new CountDownLatch(1);

        @Override
        public void onWebSocketBinary(byte[] payload, int offset, int len)
        {
        }

        @Override
        public void onWebSocketText(String message)
        {
            try
            {
                if (message.equals("too-large-message"))
                {
                    // send extra large message
                    byte[] buf = new byte[1024 * 1024];
                    Arrays.fill(buf, (byte) 'x');
                    String bigmsg = new String(buf, UTF_8);
                    session.getRemote().sendString(bigmsg);
                }
                else if (message.equals("block"))
                {
                    LOG.debug("blocking");
                    assertTrue(block.await(5, TimeUnit.MINUTES));
                    LOG.debug("unblocked");
                }
                else
                {
                    // simple echo
                    session.getRemote().sendString(message);
                }
            }
            catch (Throwable t)
            {
                // Don't lose the interrupt status when the "block" wait is interrupted.
                if (t instanceof InterruptedException)
                    Thread.currentThread().interrupt();
                LOG.debug(t);
                throw new RuntimeException(t);
            }
        }

        @Override
        public void onWebSocketClose(int statusCode, String reason)
        {
        }

        @Override
        public void onWebSocketConnect(Session session)
        {
            this.session = session;
        }

        @Override
        public void onWebSocketError(Throwable cause)
        {
            if (LOG.isDebugEnabled())
                LOG.debug("onWebSocketError(): ", cause);
        }

        @Override
        public void onWebSocketFrame(Frame frame)
        {
            if (frame.getOpCode() != OpCode.CLOSE)
                return;

            CloseInfo closeInfo = new CloseInfo(frame);
            String reason = closeInfo.getReason();

            // The reason may be null, e.g. for a CLOSE frame without reason text.
            // There is then no test instruction to act on, and dereferencing it would throw an NPE.
            if (reason == null)
                return;

            if (reason.equals("send-more-frames"))
            {
                try
                {
                    session.getRemote().sendString("Hello");
                    session.getRemote().sendString("World");
                }
                catch (Throwable ignore)
                {
                    LOG.debug("OOPS", ignore);
                }
            }
            else if (reason.equals("abort"))
            {
                try
                {
                    SECONDS.sleep(1);
                    LOG.info("Server aborting session abruptly");
                    session.disconnect();
                }
                catch (InterruptedException e)
                {
                    // Restore the interrupt status; the abort is skipped.
                    Thread.currentThread().interrupt();
                    LOG.ignore(e);
                }
                catch (Throwable ignore)
                {
                    LOG.ignore(ignore);
                }
            }
            else if (reason.startsWith("sleep|"))
            {
                int idx = reason.indexOf('|');
                int timeMs = Integer.parseInt(reason.substring(idx + 1));
                try
                {
                    LOG.info("Server Sleeping for {} ms", timeMs);
                    TimeUnit.MILLISECONDS.sleep(timeMs);
                }
                catch (InterruptedException ignore)
                {
                    // Restore the interrupt status rather than silently discarding it.
                    Thread.currentThread().interrupt();
                    LOG.ignore(ignore);
                }
            }
        }
    }
}