diff options
Diffstat (limited to 'java/src/com/google/polo/pairing/PairingSession.java')
-rw-r--r-- | java/src/com/google/polo/pairing/PairingSession.java | 753 |
1 files changed, 753 insertions, 0 deletions
diff --git a/java/src/com/google/polo/pairing/PairingSession.java b/java/src/com/google/polo/pairing/PairingSession.java new file mode 100644 index 0000000..8baccf4 --- /dev/null +++ b/java/src/com/google/polo/pairing/PairingSession.java @@ -0,0 +1,753 @@ +/* + * Copyright (C) 2009 Google Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.polo.pairing; + +import com.google.polo.encoding.HexadecimalEncoder; +import com.google.polo.encoding.SecretEncoder; +import com.google.polo.exception.BadSecretException; +import com.google.polo.exception.NoConfigurationException; +import com.google.polo.exception.PoloException; +import com.google.polo.exception.ProtocolErrorException; +import com.google.polo.pairing.PairingListener.LogLevel; +import com.google.polo.pairing.message.ConfigurationMessage; +import com.google.polo.pairing.message.EncodingOption; +import com.google.polo.pairing.message.OptionsMessage; +import com.google.polo.pairing.message.OptionsMessage.ProtocolRole; +import com.google.polo.pairing.message.PoloMessage; +import com.google.polo.pairing.message.PoloMessage.PoloMessageType; +import com.google.polo.pairing.message.SecretAckMessage; +import com.google.polo.pairing.message.SecretMessage; +import com.google.polo.wire.PoloWireInterface; + +import java.io.IOException; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.security.cert.Certificate; +import java.util.Arrays; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + + +/** + * Implements the logic of and holds state for a single occurrence of the + * pairing protocol. + * <p> + * This abstract class implements the logic common to both client and server + * perspectives of the protocol. Notably, the 'pairing' phase of the + * protocol has the same logic regardless of client/server status + * ({link PairingSession#doPairingPhase()}). Other phases of the protocol are + * specific to client/server status; see {@link ServerPairingSession} and + * {@link ClientPairingSession}. + * <p> + * The protocol is initiated by called + * {@link PairingSession#doPair(PairingListener)} + * The listener implementation is responsible for showing the shared secret + * to the user + * ({@link PairingListener#onPerformOutputDeviceRole(PairingSession, byte[])}), + * or in accepting the user input + * ({@link PairingListener#onPerformInputDeviceRole(PairingSession)}), + * depending on the role negotiated during initialization. + * <p> + * When operating in the input role, the session will block execution after + * calling {@link PairingListener#onPerformInputDeviceRole(PairingSession)} to + * wait for the secret. The listener, or some activity resulting from it, must + * publish the input secret to the session via + * {@link PairingSession#setSecret(byte[])}. + */ +public abstract class PairingSession { + + protected enum ProtocolState { + STATE_UNINITIALIZED, + STATE_INITIALIZING, + STATE_CONFIGURING, + STATE_PAIRING, + STATE_SUCCESS, + STATE_FAILURE, + } + + /** + * Enable extra verbose debug logging. + */ + private static final boolean DEBUG_VERBOSE = false; + + /** + * Controls whether to verify the secret portion of the SecretAck message. + * <p> + * NOTE(mikey): One implementation does not send the secret back in + * the SecretAck. This should be fixed, but in the meantime it is not + * essential that we verify it, since *any* acknowledgment from the + * sender is enough to indicate protocol success. + */ + private static final boolean VERIFY_SECRET_ACK = false; + + /** + * Timeout, in milliseconds, for polling the secret queue for a response from + * the listener. This timeout is relevant only to periodically check the + * mAbort flag to terminate the protocol, which is set by calling teardown(). + */ + private static final int SECRET_POLL_TIMEOUT_MS = 500; + + /** + * Performs the initialization phase of the protocol. + * + * @throws PoloException if a protocol error occurred + * @throws IOException if an error occurred in input/output + */ + protected abstract void doInitializationPhase() + throws PoloException, IOException; + + /** + * Performs the configuration phase of the protocol. + * + * @throws PoloException if a protocol error occurred + * @throws IOException if an error occurred in input/output + */ + protected abstract void doConfigurationPhase() + throws PoloException, IOException; + + /** + * Internal representation of challenge-response. + */ + protected PoloChallengeResponse mChallenge; + + /** + * Implementation of the transport layer. + */ + private final PoloWireInterface mProtocol; + + /** + * Context for the pairing session. + */ + protected final PairingContext mPairingContext; + + /** + * Local endpoint's supported options. + * <p> + * If this session is acting as a server, this message will be sent to the + * client in the Initialization phase. If acting as a client, this member is + * used to store local options and compute the Configuration message (but + * is never transmitted directly). + */ + protected OptionsMessage mLocalOptions; + + /** + * Encoding scheme used for the session. + */ + protected SecretEncoder mEncoder; + + /** + * Name of the service being paired. + */ + protected String mServiceName; + + /** + * Name of the peer. + */ + protected String mPeerName; + + /** + * Configuration message for current session. + * <p> + * This is computed by the client and sent to the server. + */ + protected ConfigurationMessage mSessionConfig; + + /** + * Listener that will receive callbacks upon protocol events. + */ + protected PairingListener mListener; + + /** + * Internal state of the pairing session. + */ + protected ProtocolState mState; + + /** + * Threadsafe queue for receiving the messages sent by peer, user-given secret + * from the listener, or exceptions caught by async threads. + */ + protected BlockingQueue<QueueMessage> mMessageQueue; + + /** + * Flag set when the session should be aborted. + */ + protected boolean mAbort; + + /** + * Reader thread. + */ + private final Thread mThread; + + /** + * Constructor. + * + * @param protocol the wire interface to operate against + * @param pairingContext a PairingContext for the session + */ + public PairingSession(PoloWireInterface protocol, + PairingContext pairingContext) { + mProtocol = protocol; + mPairingContext = pairingContext; + mState = ProtocolState.STATE_UNINITIALIZED; + mMessageQueue = new LinkedBlockingQueue<QueueMessage>(); + + Certificate clientCert = mPairingContext.getClientCertificate(); + Certificate serverCert = mPairingContext.getServerCertificate(); + + mChallenge = new PoloChallengeResponse(clientCert, serverCert, + new PoloChallengeResponse.DebugLogger() { + public void debug(String message) { + logDebug(message); + } + public void verbose(String message) { + if (DEBUG_VERBOSE) { + logDebug(message); + } + } + }); + + mLocalOptions = new OptionsMessage(); + + if (mPairingContext.isServer()) { + mLocalOptions.setProtocolRolePreference(ProtocolRole.DISPLAY_DEVICE); + } else { + mLocalOptions.setProtocolRolePreference(ProtocolRole.INPUT_DEVICE); + } + + mThread = new Thread(new Runnable() { + public void run() { + logDebug("Starting reader"); + try { + while (!mAbort) { + try { + PoloMessage message = mProtocol.getNextMessage(); + logDebug("Received: " + message.getClass()); + mMessageQueue.put(new QueueMessage(message)); + } catch (PoloException exception) { + logDebug("Exception while getting message: " + exception); + mMessageQueue.put(new QueueMessage(exception)); + break; + } catch (IOException exception) { + logDebug("Exception while getting message: " + exception); + mMessageQueue.put(new QueueMessage(new PoloException(exception))); + break; + } + } + } catch (InterruptedException ie) { + logDebug("Interrupted: " + ie); + } finally { + logDebug("Reader is done"); + } + } + }); + mThread.start(); + } + + public void teardown() { + try { + // Send any error. + mProtocol.sendErrorMessage(new Exception()); + mPairingContext.getPeerInputStream().close(); + mPairingContext.getPeerOutputStream().close(); + } catch (IOException e) { + // oh well. + } + + // Unblock the blocking wait on the secret queue. + mAbort = true; + mThread.interrupt(); + } + + protected void log(LogLevel level, String message) { + if (mListener != null) { + mListener.onLogMessage(level, message); + } + } + + /** + * Logs a debug message to the active listener. + */ + public void logDebug(String message) { + log(LogLevel.LOG_DEBUG, message); + } + + /** + * Logs an informational message to the active listener. + */ + public void logInfo(String message) { + log(LogLevel.LOG_INFO, message); + } + + /** + * Logs an error message to the active listener. + */ + public void logError(String message) { + log(LogLevel.LOG_ERROR, message); + } + + /** + * Adds an encoding to the supported input role encodings. This method can + * only be called before the session has started. + * <p> + * If no input encodings have been added, then this endpoint cannot act as + * the input device protocol role. + * + * @param encoding the {@link EncodingOption} to add + */ + public void addInputEncoding(EncodingOption encoding) { + if (mState != ProtocolState.STATE_UNINITIALIZED) { + throw new IllegalStateException("Cannot add encodings once session " + + "has been started."); + } + // Legal values of GAMMALEN must be: + // - an even number of bytes + // - at least 2 bytes + if ((encoding.getSymbolLength() < 2) || + ((encoding.getSymbolLength() % 2) != 0)) { + throw new IllegalArgumentException("Bad symbol length: " + + encoding.getSymbolLength()); + } + mLocalOptions.addInputEncoding(encoding); + } + + /** + * Adds an encoding to the supported output role encodings. This method can + * only be called before the session has started. + * <p> + * If no output encodings have been added, then this endpoint cannot act as + * the output device protocol role. + * + * @param encoding the {@link EncodingOption} to add + */ + public void addOutputEncoding(EncodingOption encoding) { + if (mState != ProtocolState.STATE_UNINITIALIZED) { + throw new IllegalStateException("Cannot add encodings once session " + + "has been started."); + } + mLocalOptions.addOutputEncoding(encoding); + } + + /** + * Changes the internal state. + * + * @param newState the new state + */ + private void setState(ProtocolState newState) { + logInfo("New state: " + newState); + mState = newState; + } + + /** + * Runs the pairing protocol. + * <p> + * Supported input and output encodings must be specified + * first, using + * {@link PairingSession#addInputEncoding(EncodingOption)} and + * {@link PairingSession#addOutputEncoding(EncodingOption)}, + * respectively. + * + * @param listener the {@link PairingListener} for the session + * @return {@code true} if pairing was successful + */ + public boolean doPair(PairingListener listener) { + mListener = listener; + mListener.onSessionCreated(this); + + if (mPairingContext.isServer()) { + logDebug("Protocol started (SERVER mode)"); + } else { + logDebug("Protocol started (CLIENT mode)"); + } + + logDebug("Local options: " + mLocalOptions.toString()); + + Certificate clientCert = mPairingContext.getClientCertificate(); + if (DEBUG_VERBOSE) { + logDebug("Client certificate:"); + logDebug(clientCert.toString()); + } + + Certificate serverCert = mPairingContext.getServerCertificate(); + + if (DEBUG_VERBOSE) { + logDebug("Server certificate:"); + logDebug(serverCert.toString()); + } + + boolean success = false; + + try { + setState(ProtocolState.STATE_INITIALIZING); + doInitializationPhase(); + + setState(ProtocolState.STATE_CONFIGURING); + doConfigurationPhase(); + + setState(ProtocolState.STATE_PAIRING); + doPairingPhase(); + + success = true; + } catch (ProtocolErrorException e) { + logDebug("Remote protocol failure: " + e); + } catch (PoloException e) { + try { + logDebug("Local protocol failure, attempting to send error: " + e); + mProtocol.sendErrorMessage(e); + } catch (IOException e1) { + logDebug("Error message send failed"); + } + } catch (IOException e) { + logDebug("IOException: " + e); + } + + if (success) { + setState(ProtocolState.STATE_SUCCESS); + } else { + setState(ProtocolState.STATE_FAILURE); + } + + mListener.onSessionEnded(this); + return success; + } + + /** + * Returns {@code true} if the session is in a terminal state (success or + * failure). + */ + public boolean hasCompleted() { + switch (mState) { + case STATE_SUCCESS: + case STATE_FAILURE: + return true; + default: + return false; + } + } + + public boolean hasSucceeded() { + return mState == ProtocolState.STATE_SUCCESS; + } + + public String getServiceName() { + return mServiceName; + } + + /** + * Sets the secret, as received from a user. This method is only meaningful + * when the endpoint is acting as the input device role. + * + * @param secret the secret, as a byte sequence + * @return {@code true} if the secret was captured + */ + public boolean setSecret(byte[] secret) { + if (!isInputDevice()) { + throw new IllegalStateException("Secret can only be set for " + + "input role session."); + } else if (mState != ProtocolState.STATE_PAIRING) { + throw new IllegalStateException("Secret can only be set while " + + "in pairing state."); + } + return mMessageQueue.offer(new QueueMessage(secret)); + } + + /** + * Executes the pairing phase of the protocol. + * + * @throws PoloException if a protocol error occurred + * @throws IOException if an error in the input/output occurred + */ + protected void doPairingPhase() throws PoloException, IOException { + if (isInputDevice()) { + new Thread(new Runnable() { + public void run() { + logDebug("Calling listener for user input..."); + try { + mListener.onPerformInputDeviceRole(PairingSession.this); + } catch (PoloException exception) { + logDebug("Sending exception: " + exception); + mMessageQueue.offer(new QueueMessage(exception)); + } finally { + logDebug("Listener finished."); + } + } + }).start(); + + logDebug("Waiting for secret from Listener or ..."); + QueueMessage message = waitForMessage(); + if (message == null || !message.hasSecret()) { + throw new PoloException( + "Illegal state - no secret available: " + message); + } + byte[] userGamma = message.mSecret; + if (userGamma == null) { + throw new PoloException("Invalid secret."); + } + + boolean match = mChallenge.checkGamma(userGamma); + if (match != true) { + throw new BadSecretException("Secret failed local check."); + } + + byte[] userNonce = mChallenge.extractNonce(userGamma); + byte[] genAlpha = mChallenge.getAlpha(userNonce); + + logDebug("Sending Secret reply..."); + SecretMessage secretMessage = new SecretMessage(genAlpha); + mProtocol.sendMessage(secretMessage); + + logDebug("Waiting for SecretAck..."); + SecretAckMessage secretAck = + (SecretAckMessage) getNextMessage(PoloMessageType.SECRET_ACK); + + if (VERIFY_SECRET_ACK) { + byte[] inbandAlpha = secretAck.getSecret(); + if (!Arrays.equals(inbandAlpha, genAlpha)) { + throw new BadSecretException("Inband secret did not match. " + + "Expected [" + PoloUtil.bytesToHexString(genAlpha) + + "], got [" + PoloUtil.bytesToHexString(inbandAlpha) + "]"); + } + } + } else { + int symbolLength = mSessionConfig.getEncoding().getSymbolLength(); + int nonceLength = symbolLength / 2; + int bytesNeeded = nonceLength / mEncoder.symbolsPerByte(); + + byte[] nonce = new byte[bytesNeeded]; + SecureRandom random; + try { + random = SecureRandom.getInstance("SHA1PRNG"); + } catch (NoSuchAlgorithmException e) { + throw new PoloException(e); + } + random.nextBytes(nonce); + + // Display gamma + logDebug("Calling listener to display output..."); + byte[] gamma = mChallenge.getGamma(nonce); + mListener.onPerformOutputDeviceRole(this, gamma); + + logDebug("Waiting for Secret..."); + SecretMessage secretMessage = + (SecretMessage) getNextMessage(PoloMessageType.SECRET); + + byte[] localAlpha = mChallenge.getAlpha(nonce); + byte[] inbandAlpha = secretMessage.getSecret(); + boolean matched = Arrays.equals(localAlpha, inbandAlpha); + + if (!matched) { + throw new BadSecretException("Inband secret did not match. " + + "Expected [" + PoloUtil.bytesToHexString(localAlpha) + + "], got [" + PoloUtil.bytesToHexString(inbandAlpha) + "]"); + } + + logDebug("Sending SecretAck..."); + byte[] genAlpha = mChallenge.getAlpha(nonce); + SecretAckMessage secretAck = new SecretAckMessage(inbandAlpha); + mProtocol.sendMessage(secretAck); + } + } + + public SecretEncoder getEncoder() { + return mEncoder; + } + + /** + * Sets the current session's configuration from a + * {@link ConfigurationMessage}. + * + * @param message the session's config + * @throws PoloException if the config was not valid for some reason + */ + protected void setConfiguration(ConfigurationMessage message) + throws PoloException { + if (message == null || message.getEncoding() == null) { + throw new NoConfigurationException("No configuration is possible."); + } + if (message.getEncoding().getSymbolLength() % 2 != 0) { + throw new PoloException("Symbol length must be even."); + } + if (message.getEncoding().getSymbolLength() < 2) { + throw new PoloException("Symbol length must be >= 2 symbols."); + } + switch (message.getEncoding().getType()) { + case ENCODING_HEXADECIMAL: + mEncoder = new HexadecimalEncoder(); + break; + default: + throw new PoloException("Unsupported encoding type."); + } + mSessionConfig = message; + } + + /** + * Returns the role of this endpoint in the current session. + */ + protected ProtocolRole getLocalRole() { + assert (mSessionConfig != null); + if (!mPairingContext.isServer()) { + return mSessionConfig.getClientRole(); + } else { + return (mSessionConfig.getClientRole() == ProtocolRole.DISPLAY_DEVICE) ? + ProtocolRole.INPUT_DEVICE : ProtocolRole.DISPLAY_DEVICE; + } + } + + /** + * Returns {@code true} if this endpoint will act as the input device. + */ + protected boolean isInputDevice() { + return (getLocalRole() == ProtocolRole.INPUT_DEVICE); + } + + /** + * Returns {@code true} if peer's name is set. + */ + public boolean hasPeerName() { + return mPeerName != null; + } + + /** + * Returns peer's name if set, {@code null} otherwise. + */ + public String getPeerName() { + return mPeerName; + } + + protected PoloMessage getNextMessage(PoloMessageType type) + throws PoloException { + QueueMessage message = waitForMessage(); + if (message != null && message.hasPoloMessage()) { + if (!type.equals(message.mPoloMessage.getType())) { + throw new PoloException( + "Unexpected message type: " + message.mPoloMessage.getType()); + } + return message.mPoloMessage; + } + throw new PoloException("Invalid state - expected polo message"); + } + + /** + * Returns next queued message. The method blocks until the secret or the + * polo message is available. + * + * @return the queued message, or null on error + * @throws PoloException if exception was queued + */ + private QueueMessage waitForMessage() throws PoloException { + while (!mAbort) { + try { + QueueMessage message = mMessageQueue.poll(SECRET_POLL_TIMEOUT_MS, + TimeUnit.MILLISECONDS); + + if (message != null) { + if (message.hasPoloException()) { + throw new PoloException(message.mPoloException); + } + return message; + } + } catch (InterruptedException e) { + break; + } + } + + // Aborted or interrupted. + return null; + } + + /** + * Sends message to the peer. + * + * @param message the message + * @throws PoloException if a protocol error occurred + * @throws IOException if an error in the input/output occurred + */ + protected void sendMessage(PoloMessage message) + throws IOException, PoloException { + mProtocol.sendMessage(message); + } + + /** + * Queued message, that can carry information about secret, next read message, + * or exception caught by reader or input threads. + */ + private static final class QueueMessage { + final PoloMessage mPoloMessage; + final PoloException mPoloException; + final byte[] mSecret; + + private QueueMessage( + PoloMessage message, byte[] secret, PoloException exception) { + int nonNullCount = 0; + if (message != null) { + ++nonNullCount; + } + mPoloMessage = message; + if (exception != null) { + assert(nonNullCount == 0); + ++nonNullCount; + } + mPoloException = exception; + if (secret != null) { + assert(nonNullCount == 0); + ++nonNullCount; + } + mSecret = secret; + assert(nonNullCount == 1); + } + + public QueueMessage(PoloMessage message) { + this(message, null, null); + } + + public QueueMessage(byte[] secret) { + this(null, secret, null); + } + + public QueueMessage(PoloException exception) { + this(null, null, exception); + } + + public boolean hasPoloMessage() { + return mPoloMessage != null; + } + + public boolean hasPoloException() { + return mPoloException != null; + } + + public boolean hasSecret() { + return mSecret != null; + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder("QueueMessage("); + if (hasPoloMessage()) { + builder.append("poloMessage = " + mPoloMessage); + } + if (hasPoloException()) { + builder.append("poloException = " + mPoloException); + } + if (hasSecret()) { + builder.append("secret = " + Arrays.toString(mSecret)); + } + return builder.append(")").toString(); + } + } + +} |