summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNeil Fuller <nfuller@google.com>2014-10-24 15:58:48 +0100
committerNeil Fuller <nfuller@google.com>2014-10-24 17:19:59 +0100
commitd8f241c21b3e2e8f94648040b1b62d7b12491d4d (patch)
treeb3a5ffe5caaddb9ac91a15025a89c8e84d64bf2e
parenta09bacce76b6ece0de2ca71c397dd4b96d101904 (diff)
downloadmockwebserver-d8f241c21b3e2e8f94648040b1b62d7b12491d4d.tar.gz
Add throttling support to MockWebServer
This adds non-blocking throttling support to MockWebServer. Most of the changes are patched across from OkHttp's version (minus SPDY changes). The motivation is to make an upstream OkHttp change easier to apply, but having fewer differences with the OkHttp version should be beneficial. Bug: 18083851 Change-Id: I63367baa46897d02ca5d3fa86f3ab83712b8addf
-rw-r--r--src/main/java/com/google/mockwebserver/Dispatcher.java12
-rw-r--r--src/main/java/com/google/mockwebserver/MockResponse.java55
-rw-r--r--src/main/java/com/google/mockwebserver/MockWebServer.java188
-rw-r--r--src/main/java/com/google/mockwebserver/QueueDispatcher.java11
4 files changed, 154 insertions, 112 deletions
diff --git a/src/main/java/com/google/mockwebserver/Dispatcher.java b/src/main/java/com/google/mockwebserver/Dispatcher.java
index 0456025..48541a4 100644
--- a/src/main/java/com/google/mockwebserver/Dispatcher.java
+++ b/src/main/java/com/google/mockwebserver/Dispatcher.java
@@ -26,11 +26,13 @@ public abstract class Dispatcher {
public abstract MockResponse dispatch(RecordedRequest request) throws InterruptedException;
/**
- * Returns the socket policy of the next request. Default implementation
- * returns {@link SocketPolicy#KEEP_OPEN}. Mischievous implementations can
- * return other values to test HTTP edge cases.
+ * Returns an early guess of the next response, used for policy on how an
+ * incoming request should be received. The default implementation returns an
+ * empty response. Mischievous implementations can return other values to test
+ * HTTP edge cases, such as unhappy socket policies or throttled request
+ * bodies.
*/
- public SocketPolicy peekSocketPolicy() {
- return SocketPolicy.KEEP_OPEN;
+ public MockResponse peek() {
+ return new MockResponse().setSocketPolicy(SocketPolicy.KEEP_OPEN);
}
}
diff --git a/src/main/java/com/google/mockwebserver/MockResponse.java b/src/main/java/com/google/mockwebserver/MockResponse.java
index 7bca741..665d85a 100644
--- a/src/main/java/com/google/mockwebserver/MockResponse.java
+++ b/src/main/java/com/google/mockwebserver/MockResponse.java
@@ -16,7 +16,6 @@
package com.google.mockwebserver;
-import static com.google.mockwebserver.MockWebServer.ASCII;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
@@ -25,6 +24,9 @@ import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
+import java.util.concurrent.TimeUnit;
+
+import static java.nio.charset.StandardCharsets.US_ASCII;
/**
* A scripted response to be replayed by the mock web server.
@@ -40,9 +42,14 @@ public final class MockResponse implements Cloneable {
/** The response body content, or null if {@code body} is set. */
private InputStream bodyStream;
- private int bytesPerSecond = Integer.MAX_VALUE;
+ private int throttleBytesPerPeriod = Integer.MAX_VALUE;
+ private long throttlePeriod = 1;
+ private TimeUnit throttleUnit = TimeUnit.SECONDS;
+
private SocketPolicy socketPolicy = SocketPolicy.KEEP_OPEN;
+ private int bodyDelayTimeMs = 0;
+
/**
* Creates a new mock response with an empty body.
*/
@@ -185,13 +192,13 @@ public final class MockResponse implements Cloneable {
int pos = 0;
while (pos < body.length) {
int chunkSize = Math.min(body.length - pos, maxChunkSize);
- bytesOut.write(Integer.toHexString(chunkSize).getBytes(ASCII));
- bytesOut.write("\r\n".getBytes(ASCII));
+ bytesOut.write(Integer.toHexString(chunkSize).getBytes(US_ASCII));
+ bytesOut.write("\r\n".getBytes(US_ASCII));
bytesOut.write(body, pos, chunkSize);
- bytesOut.write("\r\n".getBytes(ASCII));
+ bytesOut.write("\r\n".getBytes(US_ASCII));
pos += chunkSize;
}
- bytesOut.write("0\r\n\r\n".getBytes(ASCII)); // last chunk + empty trailer + crlf
+ bytesOut.write("0\r\n\r\n".getBytes(US_ASCII)); // last chunk + empty trailer + crlf
this.body = bytesOut.toByteArray();
return this;
@@ -221,19 +228,43 @@ public final class MockResponse implements Cloneable {
return this;
}
- public int getBytesPerSecond() {
- return bytesPerSecond;
+ /**
+ * Throttles the response body writer to sleep for the given period after each
+ * series of {@code bytesPerPeriod} bytes are written. Use this to simulate
+ * network behavior.
+ */
+ public MockResponse throttleBody(int bytesPerPeriod, long period, TimeUnit unit) {
+ this.throttleBytesPerPeriod = bytesPerPeriod;
+ this.throttlePeriod = period;
+ this.throttleUnit = unit;
+ return this;
+ }
+
+ public int getThrottleBytesPerPeriod() {
+ return throttleBytesPerPeriod;
+ }
+
+ public long getThrottlePeriod() {
+ return throttlePeriod;
+ }
+
+ public TimeUnit getThrottleUnit() {
+ return throttleUnit;
}
/**
- * Set simulated network speed, in bytes per second. This applies to the
- * response body only; response headers are not throttled.
+ * Set the delayed time of the response body to {@code delay}. This applies to the
+ * response body only; response headers are not affected.
*/
- public MockResponse setBytesPerSecond(int bytesPerSecond) {
- this.bytesPerSecond = bytesPerSecond;
+ public MockResponse setBodyDelayTimeMs(int delay) {
+ bodyDelayTimeMs = delay;
return this;
}
+ public int getBodyDelayTimeMs() {
+ return bodyDelayTimeMs;
+ }
+
@Override public String toString() {
return "MockResponse{" + status + "}";
}
diff --git a/src/main/java/com/google/mockwebserver/MockWebServer.java b/src/main/java/com/google/mockwebserver/MockWebServer.java
index afcacc5..13a4597 100644
--- a/src/main/java/com/google/mockwebserver/MockWebServer.java
+++ b/src/main/java/com/google/mockwebserver/MockWebServer.java
@@ -33,11 +33,14 @@ import java.net.Socket;
import java.net.SocketException;
import java.net.URL;
import java.net.UnknownHostException;
+import java.nio.charset.StandardCharsets;
+import java.security.SecureRandom;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
+import java.util.Locale;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
@@ -58,12 +61,26 @@ import javax.net.ssl.X509TrustManager;
* replays them upon request in sequence.
*/
public final class MockWebServer {
+ private static final X509TrustManager UNTRUSTED_TRUST_MANAGER = new X509TrustManager() {
+ @Override public void checkClientTrusted(X509Certificate[] chain, String authType)
+ throws CertificateException {
+ throw new CertificateException();
+ }
+
+ @Override public void checkServerTrusted(X509Certificate[] chain, String authType) {
+ throw new AssertionError();
+ }
- static final String ASCII = "US-ASCII";
+ @Override public X509Certificate[] getAcceptedIssuers() {
+ throw new AssertionError();
+ }
+ };
private static final Logger logger = Logger.getLogger(MockWebServer.class.getName());
+
private final BlockingQueue<RecordedRequest> requestQueue
= new LinkedBlockingQueue<RecordedRequest>();
+
/** All map values are Boolean.TRUE. (Collections.newSetFromMap isn't available in Froyo) */
private final Map<Socket, Boolean> openClientSockets = new ConcurrentHashMap<Socket, Boolean>();
private final AtomicInteger requestCount = new AtomicInteger();
@@ -78,7 +95,6 @@ public final class MockWebServer {
private int port = -1;
private int workerThreads = Integer.MAX_VALUE;
-
public int getPort() {
if (port == -1) {
throw new IllegalStateException("Cannot retrieve port before calling play()");
@@ -90,7 +106,7 @@ public final class MockWebServer {
try {
return InetAddress.getLocalHost().getHostName();
} catch (UnknownHostException e) {
- throw new AssertionError();
+ throw new AssertionError(e);
}
}
@@ -250,7 +266,7 @@ public final class MockWebServer {
} catch (SocketException e) {
return;
}
- final SocketPolicy socketPolicy = dispatcher.peekSocketPolicy();
+ SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy();
if (socketPolicy == DISCONNECT_AT_START) {
dispatchBookkeepingRequest(0, socket);
socket.close();
@@ -288,16 +304,20 @@ public final class MockWebServer {
if (tunnelProxy) {
createTunnel();
}
- final SocketPolicy socketPolicy = dispatcher.peekSocketPolicy();
+ SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy();
if (socketPolicy == FAIL_HANDSHAKE) {
dispatchBookkeepingRequest(sequenceNumber, raw);
- processHandshakeFailure(raw, sequenceNumber++);
+ processHandshakeFailure(raw);
return;
}
socket = sslSocketFactory.createSocket(
raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
- ((SSLSocket) socket).setUseClientMode(false);
+ SSLSocket sslSocket = (SSLSocket) socket;
+ sslSocket.setUseClientMode(false);
openClientSockets.put(socket, true);
+
+ sslSocket.startHandshake();
+
openClientSockets.remove(raw);
} else {
socket = raw;
@@ -325,13 +345,11 @@ public final class MockWebServer {
*/
private void createTunnel() throws IOException, InterruptedException {
while (true) {
- final SocketPolicy socketPolicy = dispatcher.peekSocketPolicy();
+ SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy();
if (!processOneRequest(raw, raw.getInputStream(), raw.getOutputStream())) {
throw new IllegalStateException("Tunnel without any CONNECT!");
}
- if (socketPolicy == SocketPolicy.UPGRADE_TO_SSL_AT_END) {
- return;
- }
+ if (socketPolicy == SocketPolicy.UPGRADE_TO_SSL_AT_END) return;
}
}
@@ -341,7 +359,7 @@ public final class MockWebServer {
*/
private boolean processOneRequest(Socket socket, InputStream in, OutputStream out)
throws IOException, InterruptedException {
- RecordedRequest request = readRequest(socket, in, sequenceNumber);
+ RecordedRequest request = readRequest(socket, in, out, sequenceNumber);
if (request == null) {
return false;
}
@@ -385,21 +403,9 @@ public final class MockWebServer {
}));
}
- private void processHandshakeFailure(Socket raw, int sequenceNumber) throws Exception {
- X509TrustManager untrusted = new X509TrustManager() {
- @Override public void checkClientTrusted(X509Certificate[] chain, String authType)
- throws CertificateException {
- throw new CertificateException();
- }
- @Override public void checkServerTrusted(X509Certificate[] chain, String authType) {
- throw new AssertionError();
- }
- @Override public X509Certificate[] getAcceptedIssuers() {
- throw new AssertionError();
- }
- };
+ private void processHandshakeFailure(Socket raw) throws Exception {
SSLContext context = SSLContext.getInstance("TLS");
- context.init(null, new TrustManager[] { untrusted }, new java.security.SecureRandom());
+ context.init(null, new TrustManager[] { UNTRUSTED_TRUST_MANAGER }, new SecureRandom());
SSLSocketFactory sslSocketFactory = context.getSocketFactory();
SSLSocket socket = (SSLSocket) sslSocketFactory.createSocket(
raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
@@ -416,14 +422,11 @@ public final class MockWebServer {
RecordedRequest request = new RecordedRequest(null, null, null, -1, null, sequenceNumber,
socket);
dispatcher.dispatch(request);
- requestQueue.add(request);
}
- /**
- * @param sequenceNumber the index of this request on this connection.
- */
- private RecordedRequest readRequest(Socket socket, InputStream in, int sequenceNumber)
- throws IOException {
+ /** @param sequenceNumber the index of this request on this connection. */
+ private RecordedRequest readRequest(Socket socket, InputStream in, OutputStream out,
+ int sequenceNumber) throws IOException {
String request;
try {
request = readAsciiUntilCrlf(in);
@@ -435,27 +438,40 @@ public final class MockWebServer {
}
List<String> headers = new ArrayList<String>();
- int contentLength = -1;
+ long contentLength = -1;
boolean chunked = false;
+ boolean expectContinue = false;
String header;
while ((header = readAsciiUntilCrlf(in)).length() != 0) {
headers.add(header);
- String lowercaseHeader = header.toLowerCase();
+ String lowercaseHeader = header.toLowerCase(Locale.US);
if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) {
- contentLength = Integer.parseInt(header.substring(15).trim());
+ contentLength = Long.parseLong(header.substring(15).trim());
}
- if (lowercaseHeader.startsWith("transfer-encoding:") &&
- lowercaseHeader.substring(18).trim().equals("chunked")) {
+ if (lowercaseHeader.startsWith("transfer-encoding:")
+ && lowercaseHeader.substring(18).trim().equals("chunked")) {
chunked = true;
}
+ if (lowercaseHeader.startsWith("expect:")
+ && lowercaseHeader.substring(7).trim().equals("100-continue")) {
+ expectContinue = true;
+ }
+ }
+
+ if (expectContinue) {
+ out.write(("HTTP/1.1 100 Continue\r\n").getBytes(StandardCharsets.US_ASCII));
+ out.write(("Content-Length: 0\r\n").getBytes(StandardCharsets.US_ASCII));
+ out.write(("\r\n").getBytes(StandardCharsets.US_ASCII));
+ out.flush();
}
boolean hasBody = false;
TruncatingOutputStream requestBody = new TruncatingOutputStream();
List<Integer> chunkSizes = new ArrayList<Integer>();
+ MockResponse throttlePolicy = dispatcher.peek();
if (contentLength != -1) {
hasBody = true;
- transfer(contentLength, in, requestBody);
+ throttledTransfer(throttlePolicy, in, requestBody, contentLength);
} else if (chunked) {
hasBody = true;
while (true) {
@@ -465,79 +481,75 @@ public final class MockWebServer {
break;
}
chunkSizes.add(chunkSize);
- transfer(chunkSize, in, requestBody);
+ throttledTransfer(throttlePolicy, in, requestBody, chunkSize);
readEmptyLine(in);
}
}
- if (request.startsWith("OPTIONS ") || request.startsWith("GET ")
- || request.startsWith("HEAD ") || request.startsWith("DELETE ")
- || request.startsWith("TRACE ") || request.startsWith("CONNECT ")) {
+ if (request.startsWith("OPTIONS ")
+ || request.startsWith("GET ")
+ || request.startsWith("HEAD ")
+ || request.startsWith("TRACE ")
+ || request.startsWith("CONNECT ")) {
if (hasBody) {
throw new IllegalArgumentException("Request must not have a body: " + request);
}
- } else if (!request.startsWith("POST ") && !request.startsWith("PUT ")) {
+ } else if (!request.startsWith("POST ")
+ && !request.startsWith("PUT ")
+ && !request.startsWith("PATCH ")
+ && !request.startsWith("DELETE ")) { // Permitted as spec is ambiguous.
throw new UnsupportedOperationException("Unexpected method: " + request);
}
- return new RecordedRequest(request, headers, chunkSizes,
- requestBody.numBytesReceived, requestBody.toByteArray(), sequenceNumber, socket);
+ return new RecordedRequest(request, headers, chunkSizes, requestBody.numBytesReceived,
+ requestBody.toByteArray(), sequenceNumber, socket);
}
private void writeResponse(OutputStream out, MockResponse response) throws IOException {
- out.write((response.getStatus() + "\r\n").getBytes(ASCII));
- for (String header : response.getHeaders()) {
- out.write((header + "\r\n").getBytes(ASCII));
+ out.write((response.getStatus() + "\r\n").getBytes(StandardCharsets.US_ASCII));
+ List<String> headers = response.getHeaders();
+ for (int i = 0, size = headers.size(); i < size; i++) {
+ String header = headers.get(i);
+ out.write((header + "\r\n").getBytes(StandardCharsets.US_ASCII));
}
- out.write(("\r\n").getBytes(ASCII));
+ out.write(("\r\n").getBytes(StandardCharsets.US_ASCII));
out.flush();
- final InputStream in = response.getBodyStream();
- if (in == null) {
- return;
- }
- final int bytesPerSecond = response.getBytesPerSecond();
-
- // Stream data in MTU-sized increments
- final byte[] buffer = new byte[1452];
- final long delayMs;
- if (bytesPerSecond == Integer.MAX_VALUE) {
- delayMs = 0;
- } else {
- delayMs = (1000 * buffer.length) / bytesPerSecond;
- }
-
- int read;
- long sinceDelay = 0;
- while ((read = in.read(buffer)) != -1) {
- out.write(buffer, 0, read);
- out.flush();
-
- sinceDelay += read;
- if (sinceDelay >= buffer.length && delayMs > 0) {
- sinceDelay %= buffer.length;
- try {
- Thread.sleep(delayMs);
- } catch (InterruptedException e) {
- throw new AssertionError();
- }
- }
- }
+ InputStream in = response.getBodyStream();
+ if (in == null) return;
+ throttledTransfer(response, in, out, Long.MAX_VALUE);
}
/**
* Transfer bytes from {@code in} to {@code out} until either {@code length}
- * bytes have been transferred or {@code in} is exhausted.
+ * bytes have been transferred or {@code in} is exhausted. The transfer is
+ * throttled according to {@code throttlePolicy}.
*/
- private void transfer(int length, InputStream in, OutputStream out) throws IOException {
+ private void throttledTransfer(MockResponse throttlePolicy, InputStream in, OutputStream out,
+ long limit) throws IOException {
byte[] buffer = new byte[1024];
- while (length > 0) {
- int count = in.read(buffer, 0, Math.min(buffer.length, length));
- if (count == -1) {
- return;
+ int bytesPerPeriod = throttlePolicy.getThrottleBytesPerPeriod();
+ long delayMs = throttlePolicy.getThrottleUnit().toMillis(throttlePolicy.getThrottlePeriod());
+
+ while (true) {
+ for (int b = 0; b < bytesPerPeriod; ) {
+ int toRead = (int) Math.min(Math.min(buffer.length, limit), bytesPerPeriod - b);
+ int read = in.read(buffer, 0, toRead);
+ if (read == -1) return;
+
+ out.write(buffer, 0, read);
+ out.flush();
+ b += read;
+ limit -= read;
+
+ if (limit == 0) return;
+ }
+
+ try {
+ if (delayMs != 0) Thread.sleep(delayMs);
+ } catch (InterruptedException e) {
+ throw new AssertionError();
}
- out.write(buffer, 0, count);
- length -= count;
}
}
diff --git a/src/main/java/com/google/mockwebserver/QueueDispatcher.java b/src/main/java/com/google/mockwebserver/QueueDispatcher.java
index bc26694..a95089b 100644
--- a/src/main/java/com/google/mockwebserver/QueueDispatcher.java
+++ b/src/main/java/com/google/mockwebserver/QueueDispatcher.java
@@ -45,14 +45,11 @@ public class QueueDispatcher extends Dispatcher {
return responseQueue.take();
}
- @Override public SocketPolicy peekSocketPolicy() {
+ @Override public MockResponse peek() {
MockResponse peek = responseQueue.peek();
- if (peek == null) {
- return failFastResponse != null
- ? failFastResponse.getSocketPolicy()
- : SocketPolicy.KEEP_OPEN;
- }
- return peek.getSocketPolicy();
+ if (peek != null) return peek;
+ if (failFastResponse != null) return failFastResponse;
+ return super.peek();
}
public void enqueueResponse(MockResponse response) {