diff options
Diffstat (limited to 'adservices/tests/perf/src/android/adservices/test/scenario/adservices/utils/MockWebServerRule.java')
-rw-r--r-- | adservices/tests/perf/src/android/adservices/test/scenario/adservices/utils/MockWebServerRule.java | 289 |
1 files changed, 289 insertions, 0 deletions
diff --git a/adservices/tests/perf/src/android/adservices/test/scenario/adservices/utils/MockWebServerRule.java b/adservices/tests/perf/src/android/adservices/test/scenario/adservices/utils/MockWebServerRule.java new file mode 100644 index 0000000000..ecc447977f --- /dev/null +++ b/adservices/tests/perf/src/android/adservices/test/scenario/adservices/utils/MockWebServerRule.java @@ -0,0 +1,289 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.adservices.test.scenario.adservices.utils; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; + +import android.content.Context; +import android.net.Uri; + +import com.google.mockwebserver.Dispatcher; +import com.google.mockwebserver.MockResponse; +import com.google.mockwebserver.MockWebServer; +import com.google.mockwebserver.RecordedRequest; + +import org.junit.Assert; +import org.junit.rules.TestRule; +import org.junit.runner.Description; +import org.junit.runners.model.Statement; + +import java.io.IOException; +import java.io.InputStream; +import java.net.ServerSocket; +import java.security.GeneralSecurityException; +import java.security.KeyStore; +import java.util.HashSet; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.function.Function; + +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSocketFactory; + +/** Instances of this class are not thread safe. */ +public class MockWebServerRule implements TestRule { + private static final int UNINITIALIZED = -1; + private final InputStream mCertificateInputStream; + private final char[] mKeyStorePassword; + private int mPort = UNINITIALIZED; + private MockWebServer mMockWebServer; + + private MockWebServerRule(InputStream inputStream, String keyStorePassword) { + mCertificateInputStream = inputStream; + mKeyStorePassword = keyStorePassword == null ? null : keyStorePassword.toCharArray(); + } + + public static MockWebServerRule forHttp() { + return new MockWebServerRule(null, null); + } + + /** + * Builds an instance of the MockWebServerRule configured for HTTPS traffic. + * + * @param context The app context used to load the PKCS12 key store + * @param assetName The name of the key store under the app assets folder + * @param keyStorePassword The password of the keystore + */ + public static MockWebServerRule forHttps( + Context context, String assetName, String keyStorePassword) { + try { + return new MockWebServerRule(context.getAssets().open(assetName), keyStorePassword); + } catch (IOException ioException) { + throw new RuntimeException("Unable to initialize MockWebServerRule", ioException); + } + } + + /** + * Builds an instance of the MockWebServerRule configured for HTTPS traffic. + * + * @param certificateInputStream An input stream to load the content of a PKCS12 key store + * @param keyStorePassword The password of the keystore + */ + public static MockWebServerRule forHttps( + InputStream certificateInputStream, String keyStorePassword) { + return new MockWebServerRule(certificateInputStream, keyStorePassword); + } + + private boolean useHttps() { + return Objects.nonNull(mCertificateInputStream); + } + + public MockWebServer startMockWebServer(List<MockResponse> responses) throws Exception { + if (mPort == UNINITIALIZED) { + reserveServerListeningPort(); + } + + mMockWebServer = new MockWebServer(); + if (useHttps()) { + mMockWebServer.useHttps(getTestingSslSocketFactory(), false); + } + for (MockResponse response : responses) { + mMockWebServer.enqueue(response); + } + mMockWebServer.play(mPort); + return mMockWebServer; + } + + public MockWebServer startMockWebServer(Function<RecordedRequest, MockResponse> lambda) + throws Exception { + Dispatcher dispatcher = + new Dispatcher() { + @Override + public MockResponse dispatch(RecordedRequest request) { + return lambda.apply(request); + } + }; + return startMockWebServer(dispatcher); + } + + public MockWebServer startMockWebServer(Dispatcher dispatcher) throws Exception { + if (mPort == UNINITIALIZED) { + reserveServerListeningPort(); + } + + mMockWebServer = new MockWebServer(); + if (useHttps()) { + mMockWebServer.useHttps(getTestingSslSocketFactory(), false); + } + mMockWebServer.setDispatcher(dispatcher); + + mMockWebServer.play(mPort); + return mMockWebServer; + } + + public MockWebServer createMockWebServer() throws Exception { + if (mPort == UNINITIALIZED) { + reserveServerListeningPort(); + } + + mMockWebServer = new MockWebServer(); + if (useHttps()) { + mMockWebServer.useHttps(getTestingSslSocketFactory(), false); + } + return mMockWebServer; + } + + public MockWebServer startCreatedMockWebServer(Dispatcher dispatcher) throws Exception { + if (mMockWebServer == null || mPort == UNINITIALIZED) { + throw new IllegalStateException( + "MockWebServer is not created or the port is not reserved."); + } + mMockWebServer.setDispatcher(dispatcher); + + mMockWebServer.play(mPort); + return mMockWebServer; + } + + /** + * @return the mock web server for this rull and {@code null} if it hasn't been started yet by + * calling {@link #startMockWebServer(List)}. + */ + public MockWebServer getMockWebServer() { + return mMockWebServer; + } + + /** @return the base address the mock web server will be listening to when started. */ + public String getServerBaseAddress() { + return String.format("%s://localhost:%d", useHttps() ? "https" : "http", mPort); + } + + /** + * This method is equivalent to {@link MockWebServer#getUrl(String)} but it can be used before + * you prepare and start the server if you need to prepare responses that will reference the + * same test server. + * + * @return an Uri to use to reach the given {@code @path} on the mock web server. + */ + public Uri uriForPath(String path) { + return Uri.parse( + String.format( + "%s%s%s", getServerBaseAddress(), path.startsWith("/") ? "" : "/", path)); + } + + private void reserveServerListeningPort() throws IOException { + ServerSocket serverSocket = new ServerSocket(38383); + serverSocket.setReuseAddress(true); + mPort = serverSocket.getLocalPort(); + serverSocket.close(); + } + + private SSLSocketFactory getTestingSslSocketFactory() + throws GeneralSecurityException, IOException { + final KeyManagerFactory keyManagerFactory = + KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + KeyStore keyStore = KeyStore.getInstance("PKCS12"); + keyStore.load(mCertificateInputStream, mKeyStorePassword); + keyManagerFactory.init(keyStore, mKeyStorePassword); + SSLContext sslContext = SSLContext.getInstance("TLS"); + sslContext.init(keyManagerFactory.getKeyManagers(), null, null); + return sslContext.getSocketFactory(); + } + + /** + * A utility that validates that the mock web server got the expected traffic. + * + * @param mockWebServer server instance used for making requests + * @param expectedRequestCount the number of requests expected to be received by the server + * @param expectedRequests the list of URLs that should have been requested, in case of repeat + * requests the size of expectedRequests list could be less than the expectedRequestCount + * @param requestMatcher A custom matcher that dictates if the request meets the criteria of + * being hit or not. This allows tests to do partial match of URLs in case of params or + * other sub path of URL. + */ + public void verifyMockServerRequests( + final MockWebServer mockWebServer, + final int expectedRequestCount, + final List<String> expectedRequests, + final RequestMatcher<String> requestMatcher) { + + assertEquals( + "Number of expected requests does not match actual request count", + expectedRequestCount, + mockWebServer.getRequestCount()); + + // For parallel executions requests should be checked agnostic of order + final Set<String> actualRequests = new HashSet<>(); + for (int i = 0; i < expectedRequestCount; i++) { + try { + actualRequests.add(mockWebServer.takeRequest().getPath()); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + + assertFalse( + String.format( + "Expected requests cannot be empty, actual requests <%s>", actualRequests), + expectedRequestCount != 0 && expectedRequests.isEmpty()); + + for (String request : expectedRequests) { + Assert.assertTrue( + String.format( + "Actual requests <%s> do not contain request <%s>", + actualRequests, request), + wasPathRequested(actualRequests, request, requestMatcher)); + } + } + + private boolean wasPathRequested( + final Set<String> actualRequests, + final String request, + final RequestMatcher<String> requestMatcher) { + for (String actualRequest : actualRequests) { + // Passing a custom comparator allows tests to do partial match of URLs in case of + // params or other sub path of URL + if (requestMatcher.wasRequestMade(actualRequest, request)) { + return true; + } + } + return false; + } + + @Override + public Statement apply(Statement base, Description description) { + return new Statement() { + @Override + public void evaluate() throws Throwable { + reserveServerListeningPort(); + try { + base.evaluate(); + } finally { + if (mMockWebServer != null) { + mMockWebServer.shutdown(); + } + } + } + }; + } + + public interface RequestMatcher<T> { + boolean wasRequestMade(T actualRequest, T expectedRequest); + } +} |