Mercurial > 510Connectbot
diff src/ch/ethz/ssh2/transport/TransportManager.java @ 273:91a31873c42a ganymed
start conversion from trilead to ganymed
author | Carl Byington <carl@five-ten-sg.com> |
---|---|
date | Fri, 18 Jul 2014 11:21:46 -0700 |
parents | |
children | d7e088fa2123 |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/ch/ethz/ssh2/transport/TransportManager.java Fri Jul 18 11:21:46 2014 -0700 @@ -0,0 +1,469 @@ +/* + * Copyright (c) 2006-2013 Christian Plattner. All rights reserved. + * Please refer to the LICENSE.txt for licensing details. + */ + +package ch.ethz.ssh2.transport; + +import java.io.IOException; +import java.io.InterruptedIOException; +import java.net.Socket; +import java.util.ArrayList; +import java.util.List; + +import ch.ethz.ssh2.ConnectionInfo; +import ch.ethz.ssh2.ConnectionMonitor; +import ch.ethz.ssh2.DHGexParameters; +import ch.ethz.ssh2.PacketTypeException; +import ch.ethz.ssh2.compression.Compressor; +import ch.ethz.ssh2.crypto.CryptoWishList; +import ch.ethz.ssh2.crypto.cipher.BlockCipher; +import ch.ethz.ssh2.crypto.digest.MAC; +import ch.ethz.ssh2.log.Logger; +import ch.ethz.ssh2.packets.PacketDisconnect; +import ch.ethz.ssh2.packets.Packets; +import ch.ethz.ssh2.packets.TypesReader; +import ch.ethz.ssh2.signature.DSAPrivateKey; +import ch.ethz.ssh2.signature.RSAPrivateKey; + +/** + * Yes, the "standard" is a big mess. On one side, the say that arbitrary channel + * packets are allowed during kex exchange, on the other side we need to blindly + * ignore the next _packet_ if the KEX guess was wrong. Where do we know from that + * the next packet is not a channel data packet? Yes, we could check if it is in + * the KEX range. But the standard says nothing about this. The OpenSSH guys + * block local "normal" traffic during KEX. That's fine - however, they assume + * that the other side is doing the same. During re-key, if they receive traffic + * other than KEX, they become horribly irritated and kill the connection. Since + * we are very likely going to communicate with OpenSSH servers, we have to play + * the same game - even though we could do better. + * + * @author Christian Plattner + * @version $Id: TransportManager.java 161 2014-05-01 18:01:55Z dkocher@sudo.ch $ + */ +public abstract class TransportManager { + private static final Logger log = Logger.getLogger(TransportManager.class); + + private static final class HandlerEntry { + MessageHandler mh; + int low; + int high; + } + + /** + * Advertised maximum SSH packet size that the other side can send to us. + */ + public static final int MAX_PACKET_SIZE = 64 * 1024; + + private final List<AsynchronousEntry> asynchronousQueue + = new ArrayList<AsynchronousEntry>(); + + private Thread asynchronousThread = null; + private boolean asynchronousPending = false; + + private Socket socket; + + protected TransportManager(final Socket socket) { + this.socket = socket; + } + + private static final class AsynchronousEntry { + public byte[] message; + + public AsynchronousEntry(byte[] message) { + this.message = message; + } + } + + private final class AsynchronousWorker implements Runnable { + @Override + public void run() { + while(true) { + final AsynchronousEntry item; + synchronized(asynchronousQueue) { + if(asynchronousQueue.size() == 0) { + // Only now we may reset the flag, since we are sure that all queued items + // have been sent (there is a slight delay between de-queuing and sending, + // this is why we need this flag! See code below. Sending takes place outside + // of this lock, this is why a test for size()==0 (from another thread) does not ensure + // that all messages have been sent. + + asynchronousPending = false; + + // Notify any senders that they can proceed, all async messages have been delivered + + asynchronousQueue.notifyAll(); + + // After the queue is empty for about 2 seconds, stop this thread + try { + asynchronousQueue.wait(2000); + } + catch(InterruptedException ignore) { + // + } + if(asynchronousQueue.size() == 0) { + asynchronousThread = null; + return; + } + } + item = asynchronousQueue.remove(0); + } + try { + sendMessageImmediate(item.message); + } + catch(IOException e) { + // There is no point in handling it - it simply means that the connection has a problem and we should stop + // sending asynchronously messages. We do not need to signal that we have exited (asynchronousThread = null): + // further messages in the queue cannot be sent by this or any other thread. + // Other threads will sooner or later (when receiving or sending the next message) get the + // same IOException and get to the same conclusion. + log.warning(e.getMessage()); + return; + } + } + } + } + + private final Object connectionSemaphore = new Object(); + + private boolean flagKexOngoing; + + private boolean connectionClosed; + private IOException reasonClosedCause; + + private TransportConnection tc; + private KexManager km; + + private final List<HandlerEntry> messageHandlers + = new ArrayList<HandlerEntry>(); + + private List<ConnectionMonitor> connectionMonitors + = new ArrayList<ConnectionMonitor>(); + + protected void init(TransportConnection tc, KexManager km) { + this.tc = tc; + this.km = km; + } + + public int getPacketOverheadEstimate() { + return tc.getPacketOverheadEstimate(); + } + + public ConnectionInfo getConnectionInfo(int kexNumber) throws IOException { + return km.getOrWaitForConnectionInfo(kexNumber); + } + + public IOException getReasonClosedCause() { + synchronized(connectionSemaphore) { + return reasonClosedCause; + } + } + + public byte[] getSessionIdentifier() { + return km.sessionId; + } + + public void close() { + // It is safe now to acquire the semaphore. + synchronized(connectionSemaphore) { + if(!connectionClosed) { + try { + tc.sendMessage(new PacketDisconnect( + PacketDisconnect.Reason.SSH_DISCONNECT_BY_APPLICATION, "").getPayload()); + } + catch(IOException ignore) { + // + } + try { + socket.close(); + } + catch(IOException ignore) { + // + } + connectionClosed = true; + synchronized(this) { + for(ConnectionMonitor cmon : connectionMonitors) { + cmon.connectionLost(reasonClosedCause); + } + } + } + connectionSemaphore.notifyAll(); + } + } + + public void close(IOException cause) { + // Do not acquire the semaphore, perhaps somebody is inside (and waits until + // the remote side is ready to accept new data + try { + socket.close(); + } + catch(IOException ignore) { + } + // It is safe now to acquire the semaphore. + synchronized(connectionSemaphore) { + connectionClosed = true; + reasonClosedCause = cause; + connectionSemaphore.notifyAll(); + } + synchronized(this) { + for(ConnectionMonitor cmon : connectionMonitors) { + cmon.connectionLost(reasonClosedCause); + } + } + } + + protected void startReceiver() throws IOException { + final Thread receiveThread = new Thread(new Runnable() { + public void run() { + try { + receiveLoop(); + // Can only exit with exception + } + catch(IOException e) { + close(e); + log.warning(e.getMessage()); + // Tell all handlers that it is time to say goodbye + if(km != null) { + km.handleFailure(e); + } + for(HandlerEntry he : messageHandlers) { + he.mh.handleFailure(e); + } + } + if(log.isDebugEnabled()) { + log.debug("Receive thread: back from receiveLoop"); + } + } + }); + receiveThread.setName("Transport Manager"); + receiveThread.setDaemon(true); + receiveThread.start(); + } + + public void registerMessageHandler(MessageHandler mh, int low, int high) { + HandlerEntry he = new HandlerEntry(); + he.mh = mh; + he.low = low; + he.high = high; + + synchronized(messageHandlers) { + messageHandlers.add(he); + } + } + + public void removeMessageHandler(MessageHandler handler) { + synchronized(messageHandlers) { + for(int i = 0; i < messageHandlers.size(); i++) { + HandlerEntry he = messageHandlers.get(i); + if(he.mh == handler) { + messageHandlers.remove(i); + break; + } + } + } + } + + public void sendKexMessage(byte[] msg) throws IOException { + synchronized(connectionSemaphore) { + if(connectionClosed) { + throw reasonClosedCause; + } + flagKexOngoing = true; + try { + tc.sendMessage(msg); + } + catch(IOException e) { + close(e); + throw e; + } + } + } + + public void kexFinished() throws IOException { + synchronized(connectionSemaphore) { + flagKexOngoing = false; + connectionSemaphore.notifyAll(); + } + } + + /** + * @param cwl Crypto wishlist + * @param dhgex Diffie-hellman group exchange + * @param dsa may be null if this is a client connection + * @param rsa may be null if this is a client connection + * @throws IOException + */ + public void forceKeyExchange(CryptoWishList cwl, DHGexParameters dhgex, DSAPrivateKey dsa, RSAPrivateKey rsa) + throws IOException { + synchronized(connectionSemaphore) { + if(connectionClosed) { + // Inform the caller that there is no point in triggering a new kex + throw reasonClosedCause; + } + } + km.initiateKEX(cwl, dhgex, dsa, rsa); + } + + public void changeRecvCipher(BlockCipher bc, MAC mac) { + tc.changeRecvCipher(bc, mac); + } + + public void changeSendCipher(BlockCipher bc, MAC mac) { + tc.changeSendCipher(bc, mac); + } + + public void changeRecvCompression(Compressor comp) { + tc.changeRecvCompression(comp); + } + + public void changeSendCompression(Compressor comp) { + tc.changeSendCompression(comp); + } + + public void sendAsynchronousMessage(byte[] msg) throws IOException { + synchronized(asynchronousQueue) { + asynchronousQueue.add(new AsynchronousEntry(msg)); + asynchronousPending = true; + + /* This limit should be flexible enough. We need this, otherwise the peer + * can flood us with global requests (and other stuff where we have to reply + * with an asynchronous message) and (if the server just sends data and does not + * read what we send) this will probably put us in a low memory situation + * (our send queue would grow and grow and...) */ + + if(asynchronousQueue.size() > 100) { + throw new IOException("The peer is not consuming our asynchronous replies."); + } + + // Check if we have an asynchronous sending thread + if(asynchronousThread == null) { + asynchronousThread = new Thread(new AsynchronousWorker()); + asynchronousThread.setDaemon(true); + asynchronousThread.start(); + // The thread will stop after 2 seconds of inactivity (i.e., empty queue) + } + asynchronousQueue.notifyAll(); + } + } + + public void setConnectionMonitors(List<ConnectionMonitor> monitors) { + synchronized(this) { + connectionMonitors = new ArrayList<ConnectionMonitor>(); + connectionMonitors.addAll(monitors); + } + } + + /** + * Send a message but ensure that all queued messages are being sent first. + * + * @param msg Message + * @throws IOException + */ + public void sendMessage(byte[] msg) throws IOException { + synchronized(asynchronousQueue) { + while(asynchronousPending) { + try { + asynchronousQueue.wait(); + } + catch(InterruptedException e) { + throw new InterruptedIOException(e.getMessage()); + } + } + } + sendMessageImmediate(msg); + } + + /** + * Send message, ignore queued async messages that have not been delivered yet. + * Will be called directly from the asynchronousThread thread. + * + * @param msg Message + * @throws IOException + */ + public void sendMessageImmediate(byte[] msg) throws IOException { + synchronized(connectionSemaphore) { + while(true) { + if(connectionClosed) { + throw reasonClosedCause; + } + if(!flagKexOngoing) { + break; + } + try { + connectionSemaphore.wait(); + } + catch(InterruptedException e) { + throw new InterruptedIOException(e.getMessage()); + } + } + + try { + tc.sendMessage(msg); + } + catch(IOException e) { + close(e); + throw e; + } + } + } + + private void receiveLoop() throws IOException { + while(true) { + final byte[] buffer = new byte[MAX_PACKET_SIZE]; + final int length = tc.receiveMessage(buffer, 0, buffer.length); + final byte[] packet = new byte[length]; + System.arraycopy(buffer, 0, packet, 0, length); + final int type = packet[0] & 0xff; + switch(type) { + case Packets.SSH_MSG_IGNORE: + break; + case Packets.SSH_MSG_DEBUG: { + TypesReader tr = new TypesReader(packet); + tr.readByte(); + // always_display + tr.readBoolean(); + String message = tr.readString(); + if(log.isDebugEnabled()) { + log.debug(String.format("Debug message from remote: '%s'", message)); + } + break; + } + case Packets.SSH_MSG_UNIMPLEMENTED: + throw new PacketTypeException(type); + case Packets.SSH_MSG_DISCONNECT: { + final PacketDisconnect disconnect = new PacketDisconnect(packet); + throw new DisconnectException(disconnect.getReason(), disconnect.getMessage()); + } + case Packets.SSH_MSG_KEXINIT: + case Packets.SSH_MSG_NEWKEYS: + case Packets.SSH_MSG_KEXDH_INIT: + case Packets.SSH_MSG_KEXDH_REPLY: + case Packets.SSH_MSG_KEX_DH_GEX_REQUEST: + case Packets.SSH_MSG_KEX_DH_GEX_INIT: + case Packets.SSH_MSG_KEX_DH_GEX_REPLY: + // Is it a KEX Packet + km.handleMessage(packet); + break; + case Packets.SSH_MSG_USERAUTH_SUCCESS: + tc.startCompression(); + // Continue with message handlers + default: + boolean handled = false; + for(HandlerEntry handler : messageHandlers) { + if((handler.low <= type) && (type <= handler.high)) { + handler.mh.handleMessage(packet); + handled = true; + break; + } + } + if(!handled) { + throw new PacketTypeException(type); + } + break; + } + if(log.isDebugEnabled()) { + log.debug(String.format("Handled packet %d", type)); + } + } + } +}