diff app/src/main/java/ch/ethz/ssh2/transport/TransportManager.java @ 438:d29cce60f393

migrate from Eclipse to Android Studio
author Carl Byington <carl@five-ten-sg.com>
date Thu, 03 Dec 2015 11:23:55 -0800
parents src/ch/ethz/ssh2/transport/TransportManager.java@126af684034e
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/app/src/main/java/ch/ethz/ssh2/transport/TransportManager.java	Thu Dec 03 11:23:55 2015 -0800
@@ -0,0 +1,505 @@
+/*
+ * Copyright (c) 2006-2013 Christian Plattner. All rights reserved.
+ * Please refer to the LICENSE.txt for licensing details.
+ */
+
+package ch.ethz.ssh2.transport;
+
+import java.io.IOException;
+import java.io.InterruptedIOException;
+import java.net.Socket;
+import java.security.KeyPair;
+import java.util.ArrayList;
+import java.util.List;
+
+import ch.ethz.ssh2.ConnectionInfo;
+import ch.ethz.ssh2.ConnectionMonitor;
+import ch.ethz.ssh2.DHGexParameters;
+import ch.ethz.ssh2.PacketTypeException;
+import ch.ethz.ssh2.compression.Compressor;
+import ch.ethz.ssh2.crypto.CryptoWishList;
+import ch.ethz.ssh2.crypto.cipher.BlockCipher;
+import ch.ethz.ssh2.crypto.digest.MAC;
+import ch.ethz.ssh2.log.Logger;
+import ch.ethz.ssh2.packets.PacketDisconnect;
+import ch.ethz.ssh2.packets.Packets;
+import ch.ethz.ssh2.packets.TypesReader;
+
+/**
+ * Yes, the "standard" is a big mess. On one side, the say that arbitrary channel
+ * packets are allowed during kex exchange, on the other side we need to blindly
+ * ignore the next _packet_ if the KEX guess was wrong. Where do we know from that
+ * the next packet is not a channel data packet? Yes, we could check if it is in
+ * the KEX range. But the standard says nothing about this. The OpenSSH guys
+ * block local "normal" traffic during KEX. That's fine - however, they assume
+ * that the other side is doing the same. During re-key, if they receive traffic
+ * other than KEX, they become horribly irritated and kill the connection. Since
+ * we are very likely going to communicate with OpenSSH servers, we have to play
+ * the same game - even though we could do better.
+ *
+ * @author Christian Plattner
+ * @version $Id: TransportManager.java 161 2014-05-01 18:01:55Z dkocher@sudo.ch $
+ */
+public abstract class TransportManager {
+    protected static final Logger log = Logger.getLogger(TransportManager.class);
+
+    private static final class HandlerEntry {
+        MessageHandler mh;
+        int low;
+        int high;
+    }
+
+    /**
+     * Advertised maximum SSH packet size that the other side can send to us.
+     */
+    public static final int MAX_PACKET_SIZE = 64 * 1024;
+
+    private final List<AsynchronousEntry> asynchronousQueue
+        = new ArrayList<AsynchronousEntry>();
+
+    private Thread asynchronousThread = null;
+    private boolean asynchronousPending = false;
+
+    private Socket socket;
+
+    protected TransportManager(final Socket socket) {
+        this.socket = socket;
+    }
+
+    private static final class AsynchronousEntry {
+        public byte[] message;
+
+        public AsynchronousEntry(byte[] message) {
+            this.message = message;
+        }
+    }
+
+    private final class AsynchronousWorker implements Runnable {
+        public void run() {
+            while (true) {
+                final AsynchronousEntry item;
+
+                synchronized (asynchronousQueue) {
+                    if (asynchronousQueue.size() == 0) {
+                        // Only now we may reset the flag, since we are sure that all queued items
+                        // have been sent (there is a slight delay between de-queuing and sending,
+                        // this is why we need this flag! See code below. Sending takes place outside
+                        // of this lock, this is why a test for size()==0 (from another thread) does not ensure
+                        // that all messages have been sent.
+                        asynchronousPending = false;
+                        // Notify any senders that they can proceed, all async messages have been delivered
+                        asynchronousQueue.notifyAll();
+
+                        // After the queue is empty for about 2 seconds, stop this thread
+                        try {
+                            asynchronousQueue.wait(2000);
+                        }
+                        catch (InterruptedException ignore) {
+                            //
+                        }
+
+                        if (asynchronousQueue.size() == 0) {
+                            asynchronousThread = null;
+                            return;
+                        }
+                    }
+
+                    item = asynchronousQueue.remove(0);
+                }
+
+                try {
+                    sendMessageImmediate(item.message);
+                }
+                catch (IOException e) {
+                    // There is no point in handling it - it simply means that the connection has a problem and we should stop
+                    // sending asynchronously messages. We do not need to signal that we have exited (asynchronousThread = null):
+                    // further messages in the queue cannot be sent by this or any other thread.
+                    // Other threads will sooner or later (when receiving or sending the next message) get the
+                    // same IOException and get to the same conclusion.
+                    log.warning(e.getMessage());
+                    return;
+                }
+            }
+        }
+    }
+
+    private final Object connectionSemaphore = new Object();
+
+    private boolean flagKexOngoing;
+
+    private boolean   connectionClosed;
+    private Throwable reasonClosedCause;
+
+    private TransportConnection tc;
+    private KexManager km;
+
+    private final List<HandlerEntry> messageHandlers = new ArrayList<HandlerEntry>();
+
+    private List<ConnectionMonitor> connectionMonitors = new ArrayList<ConnectionMonitor>();
+    boolean monitorsWereInformed = false;
+
+    protected void init(TransportConnection tc, KexManager km) {
+        this.tc = tc;
+        this.km = km;
+    }
+
+    public int getPacketOverheadEstimate() {
+        return tc.getPacketOverheadEstimate();
+    }
+
+    public ConnectionInfo getConnectionInfo(int kexNumber) throws IOException {
+        return km.getOrWaitForConnectionInfo(kexNumber);
+    }
+
+    public Throwable getReasonClosedCause() {
+        synchronized (connectionSemaphore) {
+            return reasonClosedCause;
+        }
+    }
+
+    public byte[] getSessionIdentifier() {
+        return km.sessionId;
+    }
+
+    public void close(Throwable cause, boolean useDisconnectPacket) {
+        if (useDisconnectPacket == false) {
+            // OK, hard shutdown - do not acquire the semaphore,
+            // perhaps somebody is inside (and waits until
+            // the remote side is ready to accept new data).
+            try {
+                socket.close();
+            }
+            catch (IOException ignore) {
+            }
+
+            // OK, whoever tried to send data, should now agree that
+            // there is no point in further waiting =)
+            // It is safe now to acquire the semaphore.
+        }
+
+        synchronized (connectionSemaphore) {
+            if (!connectionClosed) {
+                if (useDisconnectPacket == true) {
+                    try {
+                        if (tc != null)
+                            tc.sendMessage(new PacketDisconnect(PacketDisconnect.Reason.SSH_DISCONNECT_BY_APPLICATION, "").getPayload());
+                    }
+                    catch (IOException ignore) {
+                    }
+
+                    try {
+                        socket.close();
+                    }
+                    catch (IOException ignore) {
+                    }
+                }
+
+                connectionClosed = true;
+                reasonClosedCause = cause;
+            }
+
+            connectionSemaphore.notifyAll();
+        }
+
+        // check if we need to inform the monitors
+        List<ConnectionMonitor> monitors = null;
+
+        synchronized (this) {
+            // Short term lock to protect "connectionMonitors"
+            // and "monitorsWereInformed"
+            // (they may be modified concurrently)
+            if (monitorsWereInformed == false) {
+                monitorsWereInformed = true;
+                monitors = new ArrayList<ConnectionMonitor>(connectionMonitors);
+            }
+        }
+
+        if (monitors != null) {
+            for (ConnectionMonitor cmon : monitors) {
+                try {
+                    cmon.connectionLost(reasonClosedCause);
+                }
+                catch (Exception ignore) {
+                }
+            }
+        }
+    }
+
+    protected void startReceiver() throws IOException {
+        final Thread receiveThread = new Thread(new Runnable() {
+            public void run() {
+                try {
+                    receiveLoop();
+                    // Can only exit with exception
+                }
+                catch (IOException e) {
+                    close(e, false);
+                    log.warning(e.getMessage());
+
+                    // Tell all handlers that it is time to say goodbye
+                    if (km != null) {
+                        km.handleFailure(e);
+                    }
+
+                    for (HandlerEntry he : messageHandlers) {
+                        he.mh.handleFailure(e);
+                    }
+                }
+
+                if (log.isDebugEnabled()) {
+                    log.debug("Receive thread: back from receiveLoop");
+                }
+            }
+        });
+        receiveThread.setName("Transport Manager");
+        receiveThread.setDaemon(true);
+        receiveThread.start();
+    }
+
+    public void registerMessageHandler(MessageHandler mh, int low, int high) {
+        HandlerEntry he = new HandlerEntry();
+        he.mh = mh;
+        he.low = low;
+        he.high = high;
+
+        synchronized (messageHandlers) {
+            messageHandlers.add(he);
+        }
+    }
+
+    public void removeMessageHandler(MessageHandler handler) {
+        synchronized (messageHandlers) {
+            for (int i = 0; i < messageHandlers.size(); i++) {
+                HandlerEntry he = messageHandlers.get(i);
+
+                if (he.mh == handler) {
+                    messageHandlers.remove(i);
+                    break;
+                }
+            }
+        }
+    }
+
+    public void sendKexMessage(byte[] msg) throws IOException {
+        synchronized (connectionSemaphore) {
+            if (connectionClosed) {
+                throw(IOException) new IOException("Sorry, this connection is closed.").initCause(reasonClosedCause);
+            }
+
+            flagKexOngoing = true;
+
+            try {
+                tc.sendMessage(msg);
+            }
+            catch (IOException e) {
+                close(e, false);
+                throw e;
+            }
+        }
+    }
+
+    public void kexFinished() throws IOException {
+        synchronized (connectionSemaphore) {
+            flagKexOngoing = false;
+            connectionSemaphore.notifyAll();
+        }
+    }
+
+    /**
+     * @param cwl   Crypto wishlist
+     * @param dhgex Diffie-hellman group exchange
+     * @param dsa   may be null if this is a client connection
+     * @param rsa   may be null if this is a client connection
+     * @throws IOException
+     */
+    public void forceKeyExchange(CryptoWishList cwl, DHGexParameters dhgex, KeyPair dsa, KeyPair rsa, KeyPair ec)
+    throws IOException {
+        synchronized (connectionSemaphore) {
+            if (connectionClosed) {
+                throw(IOException) new IOException("Sorry, this connection is closed.").initCause(reasonClosedCause);
+            }
+        }
+
+        km.initiateKEX(cwl, dhgex, dsa, rsa, ec);
+    }
+
+    public void changeRecvCipher(BlockCipher bc, MAC mac) {
+        tc.changeRecvCipher(bc, mac);
+    }
+
+    public void changeSendCipher(BlockCipher bc, MAC mac) {
+        tc.changeSendCipher(bc, mac);
+    }
+
+    public void changeRecvCompression(Compressor comp) {
+        tc.changeRecvCompression(comp);
+    }
+
+    public void changeSendCompression(Compressor comp) {
+        tc.changeSendCompression(comp);
+    }
+
+    public void sendAsynchronousMessage(byte[] msg) throws IOException {
+        synchronized (asynchronousQueue) {
+            asynchronousQueue.add(new AsynchronousEntry(msg));
+            asynchronousPending = true;
+
+            /* This limit should be flexible enough. We need this, otherwise the peer
+             * can flood us with global requests (and other stuff where we have to reply
+             * with an asynchronous message) and (if the server just sends data and does not
+             * read what we send) this will probably put us in a low memory situation
+             * (our send queue would grow and grow and...) */
+
+            if (asynchronousQueue.size() > 100) {
+                throw new IOException("The peer is not consuming our asynchronous replies.");
+            }
+
+            // Check if we have an asynchronous sending thread
+            if (asynchronousThread == null) {
+                asynchronousThread = new Thread(new AsynchronousWorker());
+                asynchronousThread.setDaemon(true);
+                asynchronousThread.start();
+                // The thread will stop after 2 seconds of inactivity (i.e., empty queue)
+            }
+
+            asynchronousQueue.notifyAll();
+        }
+    }
+
+    public void setConnectionMonitors(List<ConnectionMonitor> monitors) {
+        synchronized (this) {
+            connectionMonitors = new ArrayList<ConnectionMonitor>(monitors);
+        }
+    }
+
+    /**
+     * Send a message but ensure that all queued messages are being sent first.
+     *
+     * @param msg Message
+     * @throws IOException
+     */
+    public void sendMessage(byte[] msg) throws IOException {
+        synchronized (asynchronousQueue) {
+            while (asynchronousPending) {
+                try {
+                    asynchronousQueue.wait();
+                }
+                catch (InterruptedException e) {
+                    throw new InterruptedIOException(e.getMessage());
+                }
+            }
+        }
+
+        sendMessageImmediate(msg);
+    }
+
+    /**
+     * Send message, ignore queued async messages that have not been delivered yet.
+     * Will be called directly from the asynchronousThread thread.
+     *
+     * @param msg Message
+     * @throws IOException
+     */
+    public void sendMessageImmediate(byte[] msg) throws IOException {
+        synchronized (connectionSemaphore) {
+            while (true) {
+                if (connectionClosed) {
+                    throw(IOException) new IOException("Sorry, this connection is closed.").initCause(reasonClosedCause);
+                }
+
+                if (!flagKexOngoing) {
+                    break;
+                }
+
+                try {
+                    connectionSemaphore.wait();
+                }
+                catch (InterruptedException e) {
+                    throw new InterruptedIOException(e.getMessage());
+                }
+            }
+
+            try {
+                tc.sendMessage(msg);
+            }
+            catch (IOException e) {
+                close(e, false);
+                throw e;
+            }
+        }
+    }
+
+    private void receiveLoop() throws IOException {
+        while (true) {
+            final byte[] buffer = new byte[MAX_PACKET_SIZE];
+            final int length = tc.receiveMessage(buffer, 0, buffer.length);
+            final byte[] packet = new byte[length];
+            System.arraycopy(buffer, 0, packet, 0, length);
+            final int type = packet[0] & 0xff;
+            log.debug(String.format("transport manager receive loop type %d", type));
+
+            switch (type) {
+                case Packets.SSH_MSG_IGNORE:
+                    break;
+
+                case Packets.SSH_MSG_DEBUG: {
+                        TypesReader tr = new TypesReader(packet);
+                        tr.readByte();
+                        // always_display
+                        tr.readBoolean();
+                        String message = tr.readString();
+
+                        if (log.isDebugEnabled()) {
+                            log.debug(String.format("Debug message from remote: '%s'", message));
+                        }
+
+                        break;
+                    }
+
+                case Packets.SSH_MSG_UNIMPLEMENTED:
+                    throw new PacketTypeException(type);
+
+                case Packets.SSH_MSG_DISCONNECT: {
+                        final PacketDisconnect disconnect = new PacketDisconnect(packet);
+                        throw new DisconnectException(disconnect.getReason(), disconnect.getMessage());
+                    }
+
+                case Packets.SSH_MSG_KEXINIT:
+                case Packets.SSH_MSG_NEWKEYS:
+                case Packets.SSH_MSG_KEXDH_INIT:
+                case Packets.SSH_MSG_KEXDH_REPLY:
+                case Packets.SSH_MSG_KEX_DH_GEX_REQUEST:
+                case Packets.SSH_MSG_KEX_DH_GEX_INIT:
+                case Packets.SSH_MSG_KEX_DH_GEX_REPLY:
+                    // Is it a KEX Packet
+                    km.handleMessage(packet);
+                    break;
+
+                case Packets.SSH_MSG_USERAUTH_SUCCESS:
+                    tc.startCompression();
+
+                // Continue with message handlers
+                default:
+                    boolean handled = false;
+
+                    for (HandlerEntry handler : messageHandlers) {
+                        if ((handler.low <= type) && (type <= handler.high)) {
+                            handler.mh.handleMessage(packet);
+                            handled = true;
+                            break;
+                        }
+                    }
+
+                    if (!handled) {
+                        throw new PacketTypeException(type);
+                    }
+
+                    break;
+            }
+
+            if (log.isDebugEnabled()) {
+                log.debug(String.format("Handled packet %d", type));
+            }
+        }
+    }
+}