summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNeil Fuller <nfuller@google.com>2014-10-27 11:24:09 +0000
committerAndroid Git Automerger <android-git-automerger@android.com>2014-10-27 11:24:09 +0000
commit6f5c8be5e4f247437f6e0f195b4e3712851e4bef (patch)
treeb3a5ffe5caaddb9ac91a15025a89c8e84d64bf2e
parenta09bacce76b6ece0de2ca71c397dd4b96d101904 (diff)
parentd5e25502a3ed333011753d5f2e1484072a7f5617 (diff)
downloadmockwebserver-6f5c8be5e4f247437f6e0f195b4e3712851e4bef.tar.gz
am d5e25502: Add throttling support to MockWebServer
* commit 'd5e25502a3ed333011753d5f2e1484072a7f5617': Add throttling support to MockWebServer
-rw-r--r--src/main/java/com/google/mockwebserver/Dispatcher.java12
-rw-r--r--src/main/java/com/google/mockwebserver/MockResponse.java55
-rw-r--r--src/main/java/com/google/mockwebserver/MockWebServer.java188
-rw-r--r--src/main/java/com/google/mockwebserver/QueueDispatcher.java11
4 files changed, 154 insertions, 112 deletions
diff --git a/src/main/java/com/google/mockwebserver/Dispatcher.java b/src/main/java/com/google/mockwebserver/Dispatcher.java
index 0456025..48541a4 100644
--- a/src/main/java/com/google/mockwebserver/Dispatcher.java
+++ b/src/main/java/com/google/mockwebserver/Dispatcher.java
@@ -26,11 +26,13 @@ public abstract class Dispatcher {
public abstract MockResponse dispatch(RecordedRequest request) throws InterruptedException;
/**
- * Returns the socket policy of the next request. Default implementation
- * returns {@link SocketPolicy#KEEP_OPEN}. Mischievous implementations can
- * return other values to test HTTP edge cases.
+ * Returns an early guess of the next response, used for policy on how an
+ * incoming request should be received. The default implementation returns an
+ * empty response. Mischievous implementations can return other values to test
+ * HTTP edge cases, such as unhappy socket policies or throttled request
+ * bodies.
*/
- public SocketPolicy peekSocketPolicy() {
- return SocketPolicy.KEEP_OPEN;
+ public MockResponse peek() {
+ return new MockResponse().setSocketPolicy(SocketPolicy.KEEP_OPEN);
}
}
diff --git a/src/main/java/com/google/mockwebserver/MockResponse.java b/src/main/java/com/google/mockwebserver/MockResponse.java
index 7bca741..665d85a 100644
--- a/src/main/java/com/google/mockwebserver/MockResponse.java
+++ b/src/main/java/com/google/mockwebserver/MockResponse.java
@@ -16,7 +16,6 @@
package com.google.mockwebserver;
-import static com.google.mockwebserver.MockWebServer.ASCII;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
@@ -25,6 +24,9 @@ import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
+import java.util.concurrent.TimeUnit;
+
+import static java.nio.charset.StandardCharsets.US_ASCII;
/**
* A scripted response to be replayed by the mock web server.
@@ -40,9 +42,14 @@ public final class MockResponse implements Cloneable {
/** The response body content, or null if {@code body} is set. */
private InputStream bodyStream;
- private int bytesPerSecond = Integer.MAX_VALUE;
+ private int throttleBytesPerPeriod = Integer.MAX_VALUE;
+ private long throttlePeriod = 1;
+ private TimeUnit throttleUnit = TimeUnit.SECONDS;
+
private SocketPolicy socketPolicy = SocketPolicy.KEEP_OPEN;
+ private int bodyDelayTimeMs = 0;
+
/**
* Creates a new mock response with an empty body.
*/
@@ -185,13 +192,13 @@ public final class MockResponse implements Cloneable {
int pos = 0;
while (pos < body.length) {
int chunkSize = Math.min(body.length - pos, maxChunkSize);
- bytesOut.write(Integer.toHexString(chunkSize).getBytes(ASCII));
- bytesOut.write("\r\n".getBytes(ASCII));
+ bytesOut.write(Integer.toHexString(chunkSize).getBytes(US_ASCII));
+ bytesOut.write("\r\n".getBytes(US_ASCII));
bytesOut.write(body, pos, chunkSize);
- bytesOut.write("\r\n".getBytes(ASCII));
+ bytesOut.write("\r\n".getBytes(US_ASCII));
pos += chunkSize;
}
- bytesOut.write("0\r\n\r\n".getBytes(ASCII)); // last chunk + empty trailer + crlf
+ bytesOut.write("0\r\n\r\n".getBytes(US_ASCII)); // last chunk + empty trailer + crlf
this.body = bytesOut.toByteArray();
return this;
@@ -221,19 +228,43 @@ public final class MockResponse implements Cloneable {
return this;
}
- public int getBytesPerSecond() {
- return bytesPerSecond;
+ /**
+ * Throttles the response body writer to sleep for the given period after each
+ * series of {@code bytesPerPeriod} bytes are written. Use this to simulate
+ * network behavior.
+ */
+ public MockResponse throttleBody(int bytesPerPeriod, long period, TimeUnit unit) {
+ this.throttleBytesPerPeriod = bytesPerPeriod;
+ this.throttlePeriod = period;
+ this.throttleUnit = unit;
+ return this;
+ }
+
+ public int getThrottleBytesPerPeriod() {
+ return throttleBytesPerPeriod;
+ }
+
+ public long getThrottlePeriod() {
+ return throttlePeriod;
+ }
+
+ public TimeUnit getThrottleUnit() {
+ return throttleUnit;
}
/**
- * Set simulated network speed, in bytes per second. This applies to the
- * response body only; response headers are not throttled.
+ * Set the delayed time of the response body to {@code delay}. This applies to the
+ * response body only; response headers are not affected.
*/
- public MockResponse setBytesPerSecond(int bytesPerSecond) {
- this.bytesPerSecond = bytesPerSecond;
+ public MockResponse setBodyDelayTimeMs(int delay) {
+ bodyDelayTimeMs = delay;
return this;
}
+ public int getBodyDelayTimeMs() {
+ return bodyDelayTimeMs;
+ }
+
@Override public String toString() {
return "MockResponse{" + status + "}";
}
diff --git a/src/main/java/com/google/mockwebserver/MockWebServer.java b/src/main/java/com/google/mockwebserver/MockWebServer.java
index afcacc5..13a4597 100644
--- a/src/main/java/com/google/mockwebserver/MockWebServer.java
+++ b/src/main/java/com/google/mockwebserver/MockWebServer.java
@@ -33,11 +33,14 @@ import java.net.Socket;
import java.net.SocketException;
import java.net.URL;
import java.net.UnknownHostException;
+import java.nio.charset.StandardCharsets;
+import java.security.SecureRandom;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
+import java.util.Locale;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
@@ -58,12 +61,26 @@ import javax.net.ssl.X509TrustManager;
* replays them upon request in sequence.
*/
public final class MockWebServer {
+ private static final X509TrustManager UNTRUSTED_TRUST_MANAGER = new X509TrustManager() {
+ @Override public void checkClientTrusted(X509Certificate[] chain, String authType)
+ throws CertificateException {
+ throw new CertificateException();
+ }
+
+ @Override public void checkServerTrusted(X509Certificate[] chain, String authType) {
+ throw new AssertionError();
+ }
- static final String ASCII = "US-ASCII";
+ @Override public X509Certificate[] getAcceptedIssuers() {
+ throw new AssertionError();
+ }
+ };
private static final Logger logger = Logger.getLogger(MockWebServer.class.getName());
+
private final BlockingQueue<RecordedRequest> requestQueue
= new LinkedBlockingQueue<RecordedRequest>();
+
/** All map values are Boolean.TRUE. (Collections.newSetFromMap isn't available in Froyo) */
private final Map<Socket, Boolean> openClientSockets = new ConcurrentHashMap<Socket, Boolean>();
private final AtomicInteger requestCount = new AtomicInteger();
@@ -78,7 +95,6 @@ public final class MockWebServer {
private int port = -1;
private int workerThreads = Integer.MAX_VALUE;
-
public int getPort() {
if (port == -1) {
throw new IllegalStateException("Cannot retrieve port before calling play()");
@@ -90,7 +106,7 @@ public final class MockWebServer {
try {
return InetAddress.getLocalHost().getHostName();
} catch (UnknownHostException e) {
- throw new AssertionError();
+ throw new AssertionError(e);
}
}
@@ -250,7 +266,7 @@ public final class MockWebServer {
} catch (SocketException e) {
return;
}
- final SocketPolicy socketPolicy = dispatcher.peekSocketPolicy();
+ SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy();
if (socketPolicy == DISCONNECT_AT_START) {
dispatchBookkeepingRequest(0, socket);
socket.close();
@@ -288,16 +304,20 @@ public final class MockWebServer {
if (tunnelProxy) {
createTunnel();
}
- final SocketPolicy socketPolicy = dispatcher.peekSocketPolicy();
+ SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy();
if (socketPolicy == FAIL_HANDSHAKE) {
dispatchBookkeepingRequest(sequenceNumber, raw);
- processHandshakeFailure(raw, sequenceNumber++);
+ processHandshakeFailure(raw);
return;
}
socket = sslSocketFactory.createSocket(
raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
- ((SSLSocket) socket).setUseClientMode(false);
+ SSLSocket sslSocket = (SSLSocket) socket;
+ sslSocket.setUseClientMode(false);
openClientSockets.put(socket, true);
+
+ sslSocket.startHandshake();
+
openClientSockets.remove(raw);
} else {
socket = raw;
@@ -325,13 +345,11 @@ public final class MockWebServer {
*/
private void createTunnel() throws IOException, InterruptedException {
while (true) {
- final SocketPolicy socketPolicy = dispatcher.peekSocketPolicy();
+ SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy();
if (!processOneRequest(raw, raw.getInputStream(), raw.getOutputStream())) {
throw new IllegalStateException("Tunnel without any CONNECT!");
}
- if (socketPolicy == SocketPolicy.UPGRADE_TO_SSL_AT_END) {
- return;
- }
+ if (socketPolicy == SocketPolicy.UPGRADE_TO_SSL_AT_END) return;
}
}
@@ -341,7 +359,7 @@ public final class MockWebServer {
*/
private boolean processOneRequest(Socket socket, InputStream in, OutputStream out)
throws IOException, InterruptedException {
- RecordedRequest request = readRequest(socket, in, sequenceNumber);
+ RecordedRequest request = readRequest(socket, in, out, sequenceNumber);
if (request == null) {
return false;
}
@@ -385,21 +403,9 @@ public final class MockWebServer {
}));
}
- private void processHandshakeFailure(Socket raw, int sequenceNumber) throws Exception {
- X509TrustManager untrusted = new X509TrustManager() {
- @Override public void checkClientTrusted(X509Certificate[] chain, String authType)
- throws CertificateException {
- throw new CertificateException();
- }
- @Override public void checkServerTrusted(X509Certificate[] chain, String authType) {
- throw new AssertionError();
- }
- @Override public X509Certificate[] getAcceptedIssuers() {
- throw new AssertionError();
- }
- };
+ private void processHandshakeFailure(Socket raw) throws Exception {
SSLContext context = SSLContext.getInstance("TLS");
- context.init(null, new TrustManager[] { untrusted }, new java.security.SecureRandom());
+ context.init(null, new TrustManager[] { UNTRUSTED_TRUST_MANAGER }, new SecureRandom());
SSLSocketFactory sslSocketFactory = context.getSocketFactory();
SSLSocket socket = (SSLSocket) sslSocketFactory.createSocket(
raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
@@ -416,14 +422,11 @@ public final class MockWebServer {
RecordedRequest request = new RecordedRequest(null, null, null, -1, null, sequenceNumber,
socket);
dispatcher.dispatch(request);
- requestQueue.add(request);
}
- /**
- * @param sequenceNumber the index of this request on this connection.
- */
- private RecordedRequest readRequest(Socket socket, InputStream in, int sequenceNumber)
- throws IOException {
+ /** @param sequenceNumber the index of this request on this connection. */
+ private RecordedRequest readRequest(Socket socket, InputStream in, OutputStream out,
+ int sequenceNumber) throws IOException {
String request;
try {
request = readAsciiUntilCrlf(in);
@@ -435,27 +438,40 @@ public final class MockWebServer {
}
List<String> headers = new ArrayList<String>();
- int contentLength = -1;
+ long contentLength = -1;
boolean chunked = false;
+ boolean expectContinue = false;
String header;
while ((header = readAsciiUntilCrlf(in)).length() != 0) {
headers.add(header);
- String lowercaseHeader = header.toLowerCase();
+ String lowercaseHeader = header.toLowerCase(Locale.US);
if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) {
- contentLength = Integer.parseInt(header.substring(15).trim());
+ contentLength = Long.parseLong(header.substring(15).trim());
}
- if (lowercaseHeader.startsWith("transfer-encoding:") &&
- lowercaseHeader.substring(18).trim().equals("chunked")) {
+ if (lowercaseHeader.startsWith("transfer-encoding:")
+ && lowercaseHeader.substring(18).trim().equals("chunked")) {
chunked = true;
}
+ if (lowercaseHeader.startsWith("expect:")
+ && lowercaseHeader.substring(7).trim().equals("100-continue")) {
+ expectContinue = true;
+ }
+ }
+
+ if (expectContinue) {
+ out.write(("HTTP/1.1 100 Continue\r\n").getBytes(StandardCharsets.US_ASCII));
+ out.write(("Content-Length: 0\r\n").getBytes(StandardCharsets.US_ASCII));
+ out.write(("\r\n").getBytes(StandardCharsets.US_ASCII));
+ out.flush();
}
boolean hasBody = false;
TruncatingOutputStream requestBody = new TruncatingOutputStream();
List<Integer> chunkSizes = new ArrayList<Integer>();
+ MockResponse throttlePolicy = dispatcher.peek();
if (contentLength != -1) {
hasBody = true;
- transfer(contentLength, in, requestBody);
+ throttledTransfer(throttlePolicy, in, requestBody, contentLength);
} else if (chunked) {
hasBody = true;
while (true) {
@@ -465,79 +481,75 @@ public final class MockWebServer {
break;
}
chunkSizes.add(chunkSize);
- transfer(chunkSize, in, requestBody);
+ throttledTransfer(throttlePolicy, in, requestBody, chunkSize);
readEmptyLine(in);
}
}
- if (request.startsWith("OPTIONS ") || request.startsWith("GET ")
- || request.startsWith("HEAD ") || request.startsWith("DELETE ")
- || request.startsWith("TRACE ") || request.startsWith("CONNECT ")) {
+ if (request.startsWith("OPTIONS ")
+ || request.startsWith("GET ")
+ || request.startsWith("HEAD ")
+ || request.startsWith("TRACE ")
+ || request.startsWith("CONNECT ")) {
if (hasBody) {
throw new IllegalArgumentException("Request must not have a body: " + request);
}
- } else if (!request.startsWith("POST ") && !request.startsWith("PUT ")) {
+ } else if (!request.startsWith("POST ")
+ && !request.startsWith("PUT ")
+ && !request.startsWith("PATCH ")
+ && !request.startsWith("DELETE ")) { // Permitted as spec is ambiguous.
throw new UnsupportedOperationException("Unexpected method: " + request);
}
- return new RecordedRequest(request, headers, chunkSizes,
- requestBody.numBytesReceived, requestBody.toByteArray(), sequenceNumber, socket);
+ return new RecordedRequest(request, headers, chunkSizes, requestBody.numBytesReceived,
+ requestBody.toByteArray(), sequenceNumber, socket);
}
private void writeResponse(OutputStream out, MockResponse response) throws IOException {
- out.write((response.getStatus() + "\r\n").getBytes(ASCII));
- for (String header : response.getHeaders()) {
- out.write((header + "\r\n").getBytes(ASCII));
+ out.write((response.getStatus() + "\r\n").getBytes(StandardCharsets.US_ASCII));
+ List<String> headers = response.getHeaders();
+ for (int i = 0, size = headers.size(); i < size; i++) {
+ String header = headers.get(i);
+ out.write((header + "\r\n").getBytes(StandardCharsets.US_ASCII));
}
- out.write(("\r\n").getBytes(ASCII));
+ out.write(("\r\n").getBytes(StandardCharsets.US_ASCII));
out.flush();
- final InputStream in = response.getBodyStream();
- if (in == null) {
- return;
- }
- final int bytesPerSecond = response.getBytesPerSecond();
-
- // Stream data in MTU-sized increments
- final byte[] buffer = new byte[1452];
- final long delayMs;
- if (bytesPerSecond == Integer.MAX_VALUE) {
- delayMs = 0;
- } else {
- delayMs = (1000 * buffer.length) / bytesPerSecond;
- }
-
- int read;
- long sinceDelay = 0;
- while ((read = in.read(buffer)) != -1) {
- out.write(buffer, 0, read);
- out.flush();
-
- sinceDelay += read;
- if (sinceDelay >= buffer.length && delayMs > 0) {
- sinceDelay %= buffer.length;
- try {
- Thread.sleep(delayMs);
- } catch (InterruptedException e) {
- throw new AssertionError();
- }
- }
- }
+ InputStream in = response.getBodyStream();
+ if (in == null) return;
+ throttledTransfer(response, in, out, Long.MAX_VALUE);
}
/**
* Transfer bytes from {@code in} to {@code out} until either {@code length}
- * bytes have been transferred or {@code in} is exhausted.
+ * bytes have been transferred or {@code in} is exhausted. The transfer is
+ * throttled according to {@code throttlePolicy}.
*/
- private void transfer(int length, InputStream in, OutputStream out) throws IOException {
+ private void throttledTransfer(MockResponse throttlePolicy, InputStream in, OutputStream out,
+ long limit) throws IOException {
byte[] buffer = new byte[1024];
- while (length > 0) {
- int count = in.read(buffer, 0, Math.min(buffer.length, length));
- if (count == -1) {
- return;
+ int bytesPerPeriod = throttlePolicy.getThrottleBytesPerPeriod();
+ long delayMs = throttlePolicy.getThrottleUnit().toMillis(throttlePolicy.getThrottlePeriod());
+
+ while (true) {
+ for (int b = 0; b < bytesPerPeriod; ) {
+ int toRead = (int) Math.min(Math.min(buffer.length, limit), bytesPerPeriod - b);
+ int read = in.read(buffer, 0, toRead);
+ if (read == -1) return;
+
+ out.write(buffer, 0, read);
+ out.flush();
+ b += read;
+ limit -= read;
+
+ if (limit == 0) return;
+ }
+
+ try {
+ if (delayMs != 0) Thread.sleep(delayMs);
+ } catch (InterruptedException e) {
+ throw new AssertionError();
}
- out.write(buffer, 0, count);
- length -= count;
}
}
diff --git a/src/main/java/com/google/mockwebserver/QueueDispatcher.java b/src/main/java/com/google/mockwebserver/QueueDispatcher.java
index bc26694..a95089b 100644
--- a/src/main/java/com/google/mockwebserver/QueueDispatcher.java
+++ b/src/main/java/com/google/mockwebserver/QueueDispatcher.java
@@ -45,14 +45,11 @@ public class QueueDispatcher extends Dispatcher {
return responseQueue.take();
}
- @Override public SocketPolicy peekSocketPolicy() {
+ @Override public MockResponse peek() {
MockResponse peek = responseQueue.peek();
- if (peek == null) {
- return failFastResponse != null
- ? failFastResponse.getSocketPolicy()
- : SocketPolicy.KEEP_OPEN;
- }
- return peek.getSocketPolicy();
+ if (peek != null) return peek;
+ if (failFastResponse != null) return failFastResponse;
+ return super.peek();
}
public void enqueueResponse(MockResponse response) {