summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndroid Build Coastguard Worker <android-build-coastguard-worker@google.com>2022-11-04 00:35:56 +0000
committerAndroid Build Coastguard Worker <android-build-coastguard-worker@google.com>2022-11-04 00:35:56 +0000
commit1f309d8bad45209ee43fd3a6600d10461eabf71e (patch)
tree3feb7e66dff1c99fbccbc8055f526ee38bd68259
parenta54582e7acc3eb5ef2bab7dee0c3a81009533a72 (diff)
parente74812d8a96c8b9a0aec84ad914783de24751315 (diff)
downloadnet-1f309d8bad45209ee43fd3a6600d10461eabf71e.tar.gz
Snap for 9254005 from e74812d8a96c8b9a0aec84ad914783de24751315 to mainline-ipsec-releaseaml_ips_331910010aml_ips_331312000aml_ips_331310000android13-mainline-ipsec-release
Change-Id: I335eafe4388651891bdc606939ffa9ab219049b4
-rw-r--r--common/Android.bp5
-rw-r--r--common/device/com/android/net/module/util/BpfDump.java20
-rw-r--r--common/framework/com/android/net/module/util/BitUtils.java110
-rw-r--r--common/framework/com/android/net/module/util/ByteUtils.java67
-rw-r--r--common/framework/com/android/net/module/util/DnsPacket.java412
-rw-r--r--common/framework/com/android/net/module/util/DnsPacketUtils.java59
-rw-r--r--common/framework/com/android/net/module/util/IpUtils.java10
-rw-r--r--common/framework/com/android/net/module/util/NetworkCapabilitiesUtils.java48
-rw-r--r--common/framework/com/android/net/module/util/NetworkStackConstants.java8
-rw-r--r--common/native/bpf_headers/include/bpf/BpfUtils.h15
-rw-r--r--common/native/bpf_headers/include/bpf/bpf_helpers.h26
-rw-r--r--common/tests/unit/src/com/android/net/module/util/BitUtilsTests.kt61
-rw-r--r--common/tests/unit/src/com/android/net/module/util/BpfDumpTest.java42
-rw-r--r--common/tests/unit/src/com/android/net/module/util/ByteUtilsTests.kt54
-rw-r--r--common/tests/unit/src/com/android/net/module/util/DnsPacketTest.java273
-rw-r--r--common/tests/unit/src/com/android/net/module/util/DnsPacketUtilsTest.java37
-rw-r--r--common/tests/unit/src/com/android/net/module/util/IpUtilsTest.java49
-rw-r--r--common/tests/unit/src/com/android/net/module/util/NetworkCapabilitiesUtilsTest.kt20
-rw-r--r--common/tests/unit/src/com/android/net/module/util/TrackRecordTest.kt3
-rw-r--r--common/tests/unit/src/com/android/net/module/util/netlink/NetlinkSocketTest.java4
-rw-r--r--common/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt477
-rw-r--r--common/tests/unit/src/com/android/testutils/TestableNetworkCallbackTestJava.java76
-rw-r--r--common/testutils/devicetests/com/android/testutils/ConcurrentInterpreter.kt9
-rw-r--r--common/testutils/devicetests/com/android/testutils/ConnectUtil.kt3
-rw-r--r--common/testutils/devicetests/com/android/testutils/DeviceInfoUtils.java3
-rw-r--r--common/testutils/devicetests/com/android/testutils/TestNetworkTracker.kt4
-rw-r--r--common/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt284
27 files changed, 1985 insertions, 194 deletions
diff --git a/common/Android.bp b/common/Android.bp
index 4dfc752b..f4997a16 100644
--- a/common/Android.bp
+++ b/common/Android.bp
@@ -96,7 +96,10 @@ java_library {
visibility: [
"//packages/services/Iwlan:__subpackages__",
],
- libs: ["framework-annotations-lib"],
+ libs: [
+ "framework-annotations-lib",
+ "framework-connectivity.stubs.module_lib",
+ ],
}
filegroup {
diff --git a/common/device/com/android/net/module/util/BpfDump.java b/common/device/com/android/net/module/util/BpfDump.java
index 2534d616..7549e712 100644
--- a/common/device/com/android/net/module/util/BpfDump.java
+++ b/common/device/com/android/net/module/util/BpfDump.java
@@ -15,6 +15,8 @@
*/
package com.android.net.module.util;
+import static android.system.OsConstants.R_OK;
+
import android.system.ErrnoException;
import android.system.Os;
import android.util.Base64;
@@ -110,6 +112,24 @@ public class BpfDump {
}
}
+ /**
+ * Dump the BpfMap status
+ */
+ public static <K extends Struct, V extends Struct> void dumpMapStatus(IBpfMap<K, V> map,
+ PrintWriter pw, String mapName, String path) {
+ if (map != null) {
+ pw.println(mapName + ": OK");
+ return;
+ }
+ try {
+ Os.access(path, R_OK);
+ pw.println(mapName + ": NULL(map is pinned to " + path + ")");
+ } catch (ErrnoException e) {
+ pw.println(mapName + ": NULL(map is not pinned to " + path + ": "
+ + Os.strerror(e.errno) + ")");
+ }
+ }
+
// TODO: add a helper to dump bpf map content with the map name, the header line
// (ex: "BPF ingress map: iif nat64Prefix v6Addr -> v4Addr oif"), a lambda that
// knows how to dump each line, and the PrintWriter.
diff --git a/common/framework/com/android/net/module/util/BitUtils.java b/common/framework/com/android/net/module/util/BitUtils.java
new file mode 100644
index 00000000..2b32e860
--- /dev/null
+++ b/common/framework/com/android/net/module/util/BitUtils.java
@@ -0,0 +1,110 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.net.module.util;
+
+import android.annotation.NonNull;
+
+/**
+ * @hide
+ */
+public class BitUtils {
+ /**
+ * Unpacks long value into an array of bits.
+ */
+ public static int[] unpackBits(long val) {
+ int size = Long.bitCount(val);
+ int[] result = new int[size];
+ int index = 0;
+ int bitPos = 0;
+ while (val != 0) {
+ if ((val & 1) == 1) result[index++] = bitPos;
+ val = val >>> 1;
+ bitPos++;
+ }
+ return result;
+ }
+
+ /**
+ * Packs a list of ints in the same way as packBits()
+ *
+ * Each passed int is the rank of a bit that should be set in the returned long.
+ * Example : passing (1,3) will return in 0b00001010 and passing (5,6,0) will return 0b01100001
+ *
+ * @param bits bits to pack
+ * @return a long with the specified bits set.
+ */
+ public static long packBitList(int... bits) {
+ return packBits(bits);
+ }
+
+ /**
+ * Packs array of bits into a long value.
+ *
+ * Each passed int is the rank of a bit that should be set in the returned long.
+ * Example : passing [1,3] will return in 0b00001010 and passing [5,6,0] will return 0b01100001
+ *
+ * @param bits bits to pack
+ * @return a long with the specified bits set.
+ */
+ public static long packBits(int[] bits) {
+ long packed = 0;
+ for (int b : bits) {
+ packed |= (1L << b);
+ }
+ return packed;
+ }
+
+ /**
+ * An interface for a function that can retrieve a name associated with an int.
+ *
+ * This is useful for bitfields like network capabilities or network score policies.
+ */
+ @FunctionalInterface
+ public interface NameOf {
+ /** Retrieve the name associated with the passed value */
+ String nameOf(int value);
+ }
+
+ /**
+ * Given a bitmask and a name fetcher, append names of all set bits to the builder
+ *
+ * This method takes all bit sets in the passed bitmask, will figure out the name associated
+ * with the weight of each bit with the passed name fetcher, and append each name to the
+ * passed StringBuilder, separated by the passed separator.
+ *
+ * For example, if the bitmask is 0110, and the name fetcher return "BIT_1" to "BIT_4" for
+ * numbers from 1 to 4, and the separator is "&", this method appends "BIT_2&BIT3" to the
+ * StringBuilder.
+ */
+ public static void appendStringRepresentationOfBitMaskToStringBuilder(@NonNull StringBuilder sb,
+ long bitMask, @NonNull NameOf nameFetcher, @NonNull String separator) {
+ int bitPos = 0;
+ boolean firstElementAdded = false;
+ while (bitMask != 0) {
+ if ((bitMask & 1) != 0) {
+ if (firstElementAdded) {
+ sb.append(separator);
+ } else {
+ firstElementAdded = true;
+ }
+ sb.append(nameFetcher.nameOf(bitPos));
+ }
+ bitMask >>>= 1;
+ ++bitPos;
+ }
+ }
+}
diff --git a/common/framework/com/android/net/module/util/ByteUtils.java b/common/framework/com/android/net/module/util/ByteUtils.java
new file mode 100644
index 00000000..290ed465
--- /dev/null
+++ b/common/framework/com/android/net/module/util/ByteUtils.java
@@ -0,0 +1,67 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.net.module.util;
+
+import android.annotation.NonNull;
+
+/**
+ * Byte utility functions.
+ * @hide
+ */
+public class ByteUtils {
+ /**
+ * Returns the index of the first appearance of the value {@code target} in {@code array}.
+ *
+ * @param array an array of {@code byte} values, possibly empty
+ * @param target a primitive {@code byte} value
+ * @return the least index {@code i} for which {@code array[i] == target}, or {@code -1} if no
+ * such index exists.
+ */
+ public static int indexOf(@NonNull byte[] array, byte target) {
+ return indexOf(array, target, 0, array.length);
+ }
+
+ private static int indexOf(byte[] array, byte target, int start, int end) {
+ for (int i = start; i < end; i++) {
+ if (array[i] == target) {
+ return i;
+ }
+ }
+ return -1;
+ }
+
+ /**
+ * Returns the values from each provided array combined into a single array. For example, {@code
+ * concat(new byte[] {a, b}, new byte[] {}, new byte[] {c}} returns the array {@code {a, b, c}}.
+ *
+ * @param arrays zero or more {@code byte} arrays
+ * @return a single array containing all the values from the source arrays, in order
+ */
+ public static byte[] concat(@NonNull byte[]... arrays) {
+ int length = 0;
+ for (byte[] array : arrays) {
+ length += array.length;
+ }
+ byte[] result = new byte[length];
+ int pos = 0;
+ for (byte[] array : arrays) {
+ System.arraycopy(array, 0, result, pos, array.length);
+ pos += array.length;
+ }
+ return result;
+ }
+}
diff --git a/common/framework/com/android/net/module/util/DnsPacket.java b/common/framework/com/android/net/module/util/DnsPacket.java
index 080781c1..702d114e 100644
--- a/common/framework/com/android/net/module/util/DnsPacket.java
+++ b/common/framework/com/android/net/module/util/DnsPacket.java
@@ -16,15 +16,34 @@
package com.android.net.module.util;
+import static android.net.DnsResolver.TYPE_A;
+import static android.net.DnsResolver.TYPE_AAAA;
+
+import static com.android.internal.annotations.VisibleForTesting.Visibility.PACKAGE;
+import static com.android.internal.annotations.VisibleForTesting.Visibility.PRIVATE;
+import static com.android.net.module.util.DnsPacketUtils.DnsRecordParser.domainNameToLabels;
+
+import android.annotation.IntDef;
import android.annotation.NonNull;
import android.annotation.Nullable;
+import android.text.TextUtils;
+import com.android.internal.annotations.VisibleForTesting;
import com.android.net.module.util.DnsPacketUtils.DnsRecordParser;
+import java.io.ByteArrayOutputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.net.InetAddress;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
import java.util.List;
+import java.util.Objects;
/**
* Defines basic data for DNS protocol based on RFC 1035.
@@ -34,6 +53,12 @@ import java.util.List;
*/
public abstract class DnsPacket {
/**
+ * Type of the canonical name for an alias. Refer to RFC 1035 section 3.2.2.
+ */
+ // TODO: Define the constant as a public constant in DnsResolver since it can never change.
+ private static final int TYPE_CNAME = 5;
+
+ /**
* Thrown when parsing packet failed.
*/
public static class ParseException extends RuntimeException {
@@ -50,13 +75,29 @@ public abstract class DnsPacket {
}
/**
- * DNS header for DNS protocol based on RFC 1035.
+ * DNS header for DNS protocol based on RFC 1035 section 4.1.1.
+ *
+ * 1 1 1 1 1 1
+ * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
+ * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+ * | ID |
+ * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+ * |QR| Opcode |AA|TC|RD|RA| Z | RCODE |
+ * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+ * | QDCOUNT |
+ * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+ * | ANCOUNT |
+ * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+ * | NSCOUNT |
+ * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+ * | ARCOUNT |
+ * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
*/
- public class DnsHeader {
+ public static class DnsHeader {
private static final String TAG = "DnsHeader";
- public final int id;
- public final int flags;
- public final int rcode;
+ private static final int SIZE_IN_BYTES = 12;
+ private final int mId;
+ private final int mFlags;
private final int[] mRecordCount;
/* If this bit in the 'flags' field is set to 0, the DNS message corresponding to this
@@ -73,10 +114,11 @@ public abstract class DnsPacket {
* advanced to the end of the DNS header record.
* This is meant to chain with other methods reading a DNS response in sequence.
*/
- DnsHeader(@NonNull ByteBuffer buf) throws BufferUnderflowException {
- id = Short.toUnsignedInt(buf.getShort());
- flags = Short.toUnsignedInt(buf.getShort());
- rcode = flags & 0xF;
+ @VisibleForTesting
+ public DnsHeader(@NonNull ByteBuffer buf) throws BufferUnderflowException {
+ Objects.requireNonNull(buf);
+ mId = Short.toUnsignedInt(buf.getShort());
+ mFlags = Short.toUnsignedInt(buf.getShort());
mRecordCount = new int[NUM_SECTIONS];
for (int i = 0; i < NUM_SECTIONS; ++i) {
mRecordCount[i] = Short.toUnsignedInt(buf.getShort());
@@ -88,7 +130,23 @@ public abstract class DnsPacket {
* RFC 1035 Section 4.1.1.
*/
public boolean isResponse() {
- return (flags & (1 << FLAGS_SECTION_QR_BIT)) != 0;
+ return (mFlags & (1 << FLAGS_SECTION_QR_BIT)) != 0;
+ }
+
+ /**
+ * Create a new DnsHeader from specified parameters.
+ *
+ * This constructor only builds the question and answer sections. Authority
+ * and additional sections are not supported. Useful when synthesizing dns
+ * responses from query or reply packets.
+ */
+ @VisibleForTesting
+ public DnsHeader(int id, int flags, int qdcount, int ancount) {
+ this.mId = id;
+ this.mFlags = flags;
+ mRecordCount = new int[NUM_SECTIONS];
+ mRecordCount[QDSECTION] = qdcount;
+ mRecordCount[ANSECTION] = ancount;
}
/**
@@ -97,15 +155,103 @@ public abstract class DnsPacket {
public int getRecordCount(int type) {
return mRecordCount[type];
}
+
+ /**
+ * Get flags of this instance.
+ */
+ public int getFlags() {
+ return mFlags;
+ }
+
+ /**
+ * Get id of this instance.
+ */
+ public int getId() {
+ return mId;
+ }
+
+ @Override
+ public String toString() {
+ return "DnsHeader{" + "id=" + mId + ", flags=" + mFlags
+ + ", recordCounts=" + Arrays.toString(mRecordCount) + '}';
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o.getClass() != getClass()) return false;
+ final DnsHeader other = (DnsHeader) o;
+ return mId == other.mId
+ && mFlags == other.mFlags
+ && Arrays.equals(mRecordCount, other.mRecordCount);
+ }
+
+ @Override
+ public int hashCode() {
+ return 31 * mId + 37 * mFlags + Arrays.hashCode(mRecordCount);
+ }
+
+ /**
+ * Get DnsHeader as byte array.
+ */
+ @NonNull
+ public byte[] getBytes() {
+ // TODO: if this is called often, optimize the ByteBuffer out and write to the
+ // array directly.
+ final ByteBuffer buf = ByteBuffer.allocate(SIZE_IN_BYTES);
+ buf.putShort((short) mId);
+ buf.putShort((short) mFlags);
+ for (int i = 0; i < NUM_SECTIONS; ++i) {
+ buf.putShort((short) mRecordCount[i]);
+ }
+ return buf.array();
+ }
}
/**
* Superclass for DNS questions and DNS resource records.
*
- * DNS questions (No TTL/RDATA)
- * DNS resource records (With TTL/RDATA)
+ * DNS questions (No TTL/RDLENGTH/RDATA) based on RFC 1035 section 4.1.2.
+ * 1 1 1 1 1 1
+ * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
+ * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+ * | |
+ * / QNAME /
+ * / /
+ * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+ * | QTYPE |
+ * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+ * | QCLASS |
+ * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+ *
+ * DNS resource records (With TTL/RDLENGTH/RDATA) based on RFC 1035 section 4.1.3.
+ * 1 1 1 1 1 1
+ * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
+ * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+ * | |
+ * / /
+ * / NAME /
+ * | |
+ * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+ * | TYPE |
+ * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+ * | CLASS |
+ * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+ * | TTL |
+ * | |
+ * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+ * | RDLENGTH |
+ * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--|
+ * / RDATA /
+ * / /
+ * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
*/
- public class DnsRecord {
+ // TODO: Make DnsResourceRecord and DnsQuestion subclasses of DnsRecord, and construct
+ // corresponding object from factory methods.
+ public static class DnsRecord {
+ // Refer to RFC 1035 section 2.3.4 for MAXNAMESIZE.
+ // NAME_NORMAL and NAME_COMPRESSION are used for checking name compression,
+ // refer to rfc 1035 section 4.1.4.
private static final int MAXNAMESIZE = 255;
public static final int NAME_NORMAL = 0;
public static final int NAME_COMPRESSION = 0xC0;
@@ -117,22 +263,31 @@ public abstract class DnsPacket {
public final int nsClass;
public final long ttl;
private final byte[] mRdata;
+ /**
+ * Type of this DNS record.
+ */
+ @RecordType
+ public final int rType;
/**
* Create a new DnsRecord from a positioned ByteBuffer.
*
* Reads the passed ByteBuffer from its current position and decodes a DNS record.
* When this constructor returns, the reading position of the ByteBuffer has been
- * advanced to the end of the DNS header record.
+ * advanced to the end of the DNS resource record.
* This is meant to chain with other methods reading a DNS response in sequence.
*
+ * @param rType Type of the record.
* @param buf ByteBuffer input of record, must be in network byte order
* (which is the default).
*/
- DnsRecord(int recordType, @NonNull ByteBuffer buf)
+ @VisibleForTesting(visibility = PACKAGE)
+ public DnsRecord(@RecordType int rType, @NonNull ByteBuffer buf)
throws BufferUnderflowException, ParseException {
+ Objects.requireNonNull(buf);
+ this.rType = rType;
dName = DnsRecordParser.parseName(buf, 0 /* Parse depth */,
- /* isNameCompressionSupported= */ true);
+ true /* isNameCompressionSupported */);
if (dName.length() > MAXNAMESIZE) {
throw new ParseException(
"Parse name fail, name size is too long: " + dName.length());
@@ -140,7 +295,7 @@ public abstract class DnsPacket {
nsType = Short.toUnsignedInt(buf.getShort());
nsClass = Short.toUnsignedInt(buf.getShort());
- if (recordType != QDSECTION) {
+ if (rType != QDSECTION) {
ttl = Integer.toUnsignedLong(buf.getInt());
final int length = Short.toUnsignedInt(buf.getShort());
mRdata = new byte[length];
@@ -152,6 +307,95 @@ public abstract class DnsPacket {
}
/**
+ * Make an A or AAAA record based on the specified parameters.
+ *
+ * @param rType Type of the record, can be {@link #ANSECTION}, {@link #ARSECTION}
+ * or {@link #NSSECTION}.
+ * @param dName Domain name of the record.
+ * @param nsClass Class of the record. See RFC 1035 section 3.2.4.
+ * @param ttl time interval (in seconds) that the resource record may be
+ * cached before it should be discarded. Zero values are
+ * interpreted to mean that the RR can only be used for the
+ * transaction in progress, and should not be cached.
+ * @param address Instance of {@link InetAddress}
+ * @return A record if the {@code address} is an IPv4 address, or AAAA record if the
+ * {@code address} is an IPv6 address.
+ */
+ public static DnsRecord makeAOrAAAARecord(int rType, @NonNull String dName,
+ int nsClass, long ttl, @NonNull InetAddress address) throws IOException {
+ final int nsType = (address.getAddress().length == 4) ? TYPE_A : TYPE_AAAA;
+ return new DnsRecord(rType, dName, nsType, nsClass, ttl, address, null /* rDataStr */);
+ }
+
+ /**
+ * Make an CNAME record based on the specified parameters.
+ *
+ * @param rType Type of the record, can be {@link #ANSECTION}, {@link #ARSECTION}
+ * or {@link #NSSECTION}.
+ * @param dName Domain name of the record.
+ * @param nsClass Class of the record. See RFC 1035 section 3.2.4.
+ * @param ttl time interval (in seconds) that the resource record may be
+ * cached before it should be discarded. Zero values are
+ * interpreted to mean that the RR can only be used for the
+ * transaction in progress, and should not be cached.
+ * @param domainName Canonical name of the {@code dName}.
+ * @return A record if the {@code address} is an IPv4 address, or AAAA record if the
+ * {@code address} is an IPv6 address.
+ */
+ public static DnsRecord makeCNameRecord(int rType, @NonNull String dName, int nsClass,
+ long ttl, @NonNull String domainName) throws IOException {
+ return new DnsRecord(rType, dName, TYPE_CNAME, nsClass, ttl, null /* address */,
+ domainName);
+ }
+
+ /**
+ * Make a DNS question based on the specified parameters.
+ */
+ public static DnsRecord makeQuestion(@NonNull String dName, int nsType, int nsClass) {
+ return new DnsRecord(dName, nsType, nsClass);
+ }
+
+ private static String requireHostName(@NonNull String name) {
+ if (!DnsRecordParser.isHostName(name)) {
+ throw new IllegalArgumentException("Expected domain name but got " + name);
+ }
+ return name;
+ }
+
+ /**
+ * Create a new query DnsRecord from specified parameters, useful when synthesizing
+ * dns response.
+ */
+ private DnsRecord(@NonNull String dName, int nsType, int nsClass) {
+ this.rType = QDSECTION;
+ this.dName = requireHostName(dName);
+ this.nsType = nsType;
+ this.nsClass = nsClass;
+ mRdata = null;
+ this.ttl = 0;
+ }
+
+ /**
+ * Create a new CNAME/A/AAAA DnsRecord from specified parameters.
+ *
+ * @param address The address only used when synthesizing A or AAAA record.
+ * @param rDataStr The alias of the domain, only used when synthesizing CNAME record.
+ */
+ private DnsRecord(@RecordType int rType, @NonNull String dName, int nsType, int nsClass,
+ long ttl, @Nullable InetAddress address, @Nullable String rDataStr)
+ throws IOException {
+ this.rType = rType;
+ this.dName = requireHostName(dName);
+ this.nsType = nsType;
+ this.nsClass = nsClass;
+ if (rType < 0 || rType >= NUM_SECTIONS || rType == QDSECTION) {
+ throw new IllegalArgumentException("Unexpected record type: " + rType);
+ }
+ mRdata = nsType == TYPE_CNAME ? domainNameToLabels(rDataStr) : address.getAddress();
+ this.ttl = ttl;
+ }
+
+ /**
* Get a copy of rdata.
*/
@Nullable
@@ -159,13 +403,84 @@ public abstract class DnsPacket {
return (mRdata == null) ? null : mRdata.clone();
}
+ /**
+ * Get DnsRecord as byte array.
+ */
+ @NonNull
+ public byte[] getBytes() throws IOException {
+ final ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ final DataOutputStream dos = new DataOutputStream(baos);
+ dos.write(domainNameToLabels(dName));
+ dos.writeShort(nsType);
+ dos.writeShort(nsClass);
+ if (rType != QDSECTION) {
+ dos.writeInt((int) ttl);
+ if (mRdata == null) {
+ dos.writeShort(0);
+ } else {
+ dos.writeShort(mRdata.length);
+ dos.write(mRdata);
+ }
+ }
+ return baos.toByteArray();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o.getClass() != getClass()) return false;
+ final DnsRecord other = (DnsRecord) o;
+ return rType == other.rType
+ && nsType == other.nsType
+ && nsClass == other.nsClass
+ && ttl == other.ttl
+ && TextUtils.equals(dName, other.dName)
+ && Arrays.equals(mRdata, other.mRdata);
+ }
+
+ @Override
+ public int hashCode() {
+ return 31 * Objects.hash(dName)
+ + 37 * ((int) (ttl & 0xFFFFFFFF))
+ + 41 * ((int) (ttl >> 32))
+ + 43 * nsType
+ + 47 * nsClass
+ + 53 * rType
+ + Arrays.hashCode(mRdata);
+ }
+
+ @Override
+ public String toString() {
+ return "DnsRecord{"
+ + "rType=" + rType
+ + ", dName='" + dName + '\''
+ + ", nsType=" + nsType
+ + ", nsClass=" + nsClass
+ + ", ttl=" + ttl
+ + ", mRdata=" + Arrays.toString(mRdata)
+ + '}';
+ }
}
+ /**
+ * Header section types, refer to RFC 1035 section 4.1.1.
+ */
public static final int QDSECTION = 0;
public static final int ANSECTION = 1;
public static final int NSSECTION = 2;
public static final int ARSECTION = 3;
- private static final int NUM_SECTIONS = ARSECTION + 1;
+ @VisibleForTesting(visibility = PRIVATE)
+ static final int NUM_SECTIONS = ARSECTION + 1;
+
+ @Retention(RetentionPolicy.SOURCE)
+ @IntDef(value = {
+ QDSECTION,
+ ANSECTION,
+ NSSECTION,
+ ARSECTION,
+ })
+ public @interface RecordType {}
+
private static final String TAG = DnsPacket.class.getSimpleName();
@@ -189,9 +504,7 @@ public abstract class DnsPacket {
for (int i = 0; i < NUM_SECTIONS; ++i) {
final int count = mHeader.getRecordCount(i);
- if (count > 0) {
- mRecords[i] = new ArrayList(count);
- }
+ mRecords[i] = new ArrayList(count);
for (int j = 0; j < count; ++j) {
try {
mRecords[i].add(new DnsRecord(i, buffer));
@@ -201,4 +514,61 @@ public abstract class DnsPacket {
}
}
}
+
+ /**
+ * Create a new {@link #DnsPacket} from specified parameters.
+ *
+ * Note that authority records section and additional records section is not supported.
+ */
+ protected DnsPacket(@NonNull DnsHeader header, @NonNull List<DnsRecord> qd,
+ @NonNull List<DnsRecord> an) {
+ mHeader = Objects.requireNonNull(header);
+ mRecords = new List[NUM_SECTIONS];
+ mRecords[QDSECTION] = Collections.unmodifiableList(new ArrayList<>(qd));
+ mRecords[ANSECTION] = Collections.unmodifiableList(new ArrayList<>(an));
+ mRecords[NSSECTION] = new ArrayList<>();
+ mRecords[ARSECTION] = new ArrayList<>();
+ for (int i = 0; i < NUM_SECTIONS; i++) {
+ if (mHeader.mRecordCount[i] != mRecords[i].size()) {
+ throw new IllegalArgumentException("Record count mismatch: expected "
+ + mHeader.mRecordCount[i] + " but was " + mRecords[i]);
+ }
+ }
+ }
+
+ /**
+ * Get DnsPacket as byte array.
+ */
+ public @NonNull byte[] getBytes() throws IOException {
+ final ByteArrayOutputStream buf = new ByteArrayOutputStream();
+ buf.write(mHeader.getBytes());
+
+ for (int i = 0; i < NUM_SECTIONS; ++i) {
+ for (final DnsRecord record : mRecords[i]) {
+ buf.write(record.getBytes());
+ }
+ }
+ return buf.toByteArray();
+ }
+
+ @Override
+ public String toString() {
+ return "DnsPacket{" + "header=" + mHeader + ", records='" + Arrays.toString(mRecords) + '}';
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o.getClass() != getClass()) return false;
+ final DnsPacket other = (DnsPacket) o;
+ return Objects.equals(mHeader, other.mHeader)
+ && Arrays.deepEquals(mRecords, other.mRecords);
+ }
+
+ @Override
+ public int hashCode() {
+ int result = Objects.hash(mHeader);
+ result = 31 * result + Arrays.hashCode(mRecords);
+ return result;
+ }
}
diff --git a/common/framework/com/android/net/module/util/DnsPacketUtils.java b/common/framework/com/android/net/module/util/DnsPacketUtils.java
index 3448cadf..c47bfa07 100644
--- a/common/framework/com/android/net/module/util/DnsPacketUtils.java
+++ b/common/framework/com/android/net/module/util/DnsPacketUtils.java
@@ -20,10 +20,19 @@ import static com.android.net.module.util.DnsPacket.DnsRecord.NAME_COMPRESSION;
import static com.android.net.module.util.DnsPacket.DnsRecord.NAME_NORMAL;
import android.annotation.NonNull;
+import android.annotation.Nullable;
+import android.net.InetAddresses;
+import android.net.ParseException;
import android.text.TextUtils;
+import android.util.Patterns;
+import com.android.internal.annotations.VisibleForTesting;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
import java.text.DecimalFormat;
import java.text.FieldPosition;
@@ -38,6 +47,7 @@ public final class DnsPacketUtils {
*/
public static class DnsRecordParser {
private static final int MAXLABELSIZE = 63;
+ private static final int MAXNAMESIZE = 255;
private static final int MAXLABELCOUNT = 128;
private static final DecimalFormat sByteFormat = new DecimalFormat();
@@ -48,7 +58,8 @@ public final class DnsPacketUtils {
*
* <p>Follows the same conversion rules of the native code (ns_name.c in libc).
*/
- private static String labelToString(@NonNull byte[] label) {
+ @VisibleForTesting
+ static String labelToString(@NonNull byte[] label) {
final StringBuffer sb = new StringBuffer();
for (int i = 0; i < label.length; ++i) {
@@ -72,6 +83,52 @@ public final class DnsPacketUtils {
}
/**
+ * Converts domain name to labels according to RFC 1035.
+ *
+ * @param name Domain name as String that needs to be converted to labels.
+ * @return An encoded byte array that is constructed out of labels,
+ * and ends with zero-length label.
+ * @throws ParseException if failed to parse the given domain name or
+ * IOException if failed to output labels.
+ */
+ public static @NonNull byte[] domainNameToLabels(@NonNull String name) throws
+ IOException, ParseException {
+ if (name.length() > MAXNAMESIZE) {
+ throw new ParseException("Domain name exceeds max length: " + name.length());
+ }
+ if (!isHostName(name)) {
+ throw new ParseException("Failed to parse domain name: " + name);
+ }
+ final ByteArrayOutputStream buf = new ByteArrayOutputStream();
+ final String[] labels = name.split("\\.");
+ for (final String label : labels) {
+ if (label.length() > MAXLABELSIZE) {
+ throw new ParseException("label is too long: " + label);
+ }
+ buf.write(label.length());
+ // Encode as UTF-8 as suggested in RFC 6055 section 3.
+ buf.write(label.getBytes(StandardCharsets.UTF_8));
+ }
+ buf.write(0x00); // end with zero-length label
+ return buf.toByteArray();
+ }
+
+ /**
+ * Check whether the input is a valid hostname based on rfc 1035 section 3.3.
+ *
+ * @param hostName the target host name.
+ * @return true if the input is a valid hostname.
+ */
+ public static boolean isHostName(@Nullable String hostName) {
+ // TODO: Use {@code Patterns.HOST_NAME} if available.
+ // Patterns.DOMAIN_NAME accepts host names or IP addresses, so reject
+ // IP addresses.
+ return hostName != null
+ && Patterns.DOMAIN_NAME.matcher(hostName).matches()
+ && !InetAddresses.isNumericAddress(hostName);
+ }
+
+ /**
* Parses the domain / target name of a DNS record.
*
* As described in RFC 1035 Section 4.1.3, the NAME field of a DNS Resource Record always
diff --git a/common/framework/com/android/net/module/util/IpUtils.java b/common/framework/com/android/net/module/util/IpUtils.java
index dbc22854..569733ed 100644
--- a/common/framework/com/android/net/module/util/IpUtils.java
+++ b/common/framework/com/android/net/module/util/IpUtils.java
@@ -39,7 +39,7 @@ public class IpUtils {
/**
* Performs an IP checksum (used in IP header and across UDP
- * payload) on the specified portion of a ByteBuffer. The seed
+ * payload) or ICMP checksum on the specified portion of a ByteBuffer. The seed
* allows the checksum to commence with a specified value.
*/
private static int checksum(ByteBuffer buf, int seed, int start, int end) {
@@ -145,6 +145,14 @@ public class IpUtils {
}
/**
+ * Calculate the ICMP checksum for an ICMPv4 packet.
+ */
+ public static short icmpChecksum(ByteBuffer buf, int transportOffset, int transportLen) {
+ // ICMP checksum doesn't include pseudo-header. See RFC 792.
+ return (short) checksum(buf, 0, transportOffset, transportOffset + transportLen);
+ }
+
+ /**
* Calculate the ICMPv6 checksum for an ICMPv6 packet.
*/
public static short icmpv6Checksum(ByteBuffer buf, int ipOffset, int transportOffset,
diff --git a/common/framework/com/android/net/module/util/NetworkCapabilitiesUtils.java b/common/framework/com/android/net/module/util/NetworkCapabilitiesUtils.java
index 26c24f8d..54ce01ec 100644
--- a/common/framework/com/android/net/module/util/NetworkCapabilitiesUtils.java
+++ b/common/framework/com/android/net/module/util/NetworkCapabilitiesUtils.java
@@ -44,6 +44,9 @@ import static android.net.NetworkCapabilities.TRANSPORT_VPN;
import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
import static android.net.NetworkCapabilities.TRANSPORT_WIFI_AWARE;
+import static com.android.net.module.util.BitUtils.packBitList;
+import static com.android.net.module.util.BitUtils.unpackBits;
+
import android.annotation.NonNull;
import android.net.NetworkCapabilities;
@@ -181,49 +184,4 @@ public final class NetworkCapabilitiesUtils {
return false;
}
- /**
- * Unpacks long value into an array of bits.
- */
- public static int[] unpackBits(long val) {
- int size = Long.bitCount(val);
- int[] result = new int[size];
- int index = 0;
- int bitPos = 0;
- while (val != 0) {
- if ((val & 1) == 1) result[index++] = bitPos;
- val = val >>> 1;
- bitPos++;
- }
- return result;
- }
-
- /**
- * Packs a list of ints in the same way as packBits()
- *
- * Each passed int is the rank of a bit that should be set in the returned long.
- * Example : passing (1,3) will return in 0b00001010 and passing (5,6,0) will return 0b01100001
- *
- * @param bits bits to pack
- * @return a long with the specified bits set.
- */
- public static long packBitList(int... bits) {
- return packBits(bits);
- }
-
- /**
- * Packs array of bits into a long value.
- *
- * Each passed int is the rank of a bit that should be set in the returned long.
- * Example : passing [1,3] will return in 0b00001010 and passing [5,6,0] will return 0b01100001
- *
- * @param bits bits to pack
- * @return a long with the specified bits set.
- */
- public static long packBits(int[] bits) {
- long packed = 0;
- for (int b : bits) {
- packed |= (1L << b);
- }
- return packed;
- }
}
diff --git a/common/framework/com/android/net/module/util/NetworkStackConstants.java b/common/framework/com/android/net/module/util/NetworkStackConstants.java
index 1ad9c359..ebef1557 100644
--- a/common/framework/com/android/net/module/util/NetworkStackConstants.java
+++ b/common/framework/com/android/net/module/util/NetworkStackConstants.java
@@ -127,6 +127,14 @@ public final class NetworkStackConstants {
(Inet6Address) InetAddresses.parseNumericAddress("ff02::3");
/**
+ * ICMP constants.
+ *
+ * See also:
+ * - https://tools.ietf.org/html/rfc792
+ */
+ public static final int ICMP_CHECKSUM_OFFSET = 2;
+
+ /**
* ICMPv6 constants.
*
* See also:
diff --git a/common/native/bpf_headers/include/bpf/BpfUtils.h b/common/native/bpf_headers/include/bpf/BpfUtils.h
index 4429164e..157f210c 100644
--- a/common/native/bpf_headers/include/bpf/BpfUtils.h
+++ b/common/native/bpf_headers/include/bpf/BpfUtils.h
@@ -115,23 +115,16 @@ static inline bool isAtLeastKernelVersion(unsigned major, unsigned minor, unsign
return kernelVersion() >= KVER(major, minor, sub);
}
-#define SKIP_IF_BPF_SUPPORTED \
- do { \
- if (android::bpf::isAtLeastKernelVersion(4, 9, 0)) \
- GTEST_SKIP() << "Skip: bpf is supported."; \
- } while (0)
-
#define SKIP_IF_BPF_NOT_SUPPORTED \
do { \
if (!android::bpf::isAtLeastKernelVersion(4, 9, 0)) \
GTEST_SKIP() << "Skip: bpf is not supported."; \
} while (0)
-#define SKIP_IF_EXTENDED_BPF_NOT_SUPPORTED \
- do { \
- if (!android::bpf::isAtLeastKernelVersion(4, 14, 0)) \
- GTEST_SKIP() << "Skip: extended bpf feature not supported."; \
- } while (0)
+// Only used by tm-mainline-prod's system/netd/tests/bpf_base_test.cpp
+// but platform and platform tests aren't expected to build/work in tm-mainline-prod
+// so we can just trivialize this
+#define SKIP_IF_EXTENDED_BPF_NOT_SUPPORTED
#define SKIP_IF_XDP_NOT_SUPPORTED \
do { \
diff --git a/common/native/bpf_headers/include/bpf/bpf_helpers.h b/common/native/bpf_headers/include/bpf/bpf_helpers.h
index 35d9f31d..56bc7d01 100644
--- a/common/native/bpf_headers/include/bpf/bpf_helpers.h
+++ b/common/native/bpf_headers/include/bpf/bpf_helpers.h
@@ -138,6 +138,10 @@ static int (*bpf_map_update_elem_unsafe)(const struct bpf_map_def* map, const vo
static int (*bpf_map_delete_elem_unsafe)(const struct bpf_map_def* map,
const void* key) = (void*)BPF_FUNC_map_delete_elem;
+static int (*bpf_for_each_map_elem)(const struct bpf_map_def* map, void *callback_fn,
+ void *callback_ctx, unsigned long long flags) = (void*)
+ BPF_FUNC_for_each_map_elem;
+
#define BPF_ANNOTATE_KV_PAIR(name, type_key, type_val) \
struct ____btf_map_##name { \
type_key key; \
@@ -147,6 +151,23 @@ static int (*bpf_map_delete_elem_unsafe)(const struct bpf_map_def* map,
__attribute__ ((section(".maps." #name), used)) \
____btf_map_##name = { }
+/* There exist buggy kernels with pre-T OS, that due to
+ * kernel patch "[ALPS05162612] bpf: fix ubsan error"
+ * do not support userspace writes into non-zero index of bpf map arrays.
+ *
+ * We use this assert to prevent us from being able to define such a map.
+ */
+
+#ifdef THIS_BPF_PROGRAM_IS_FOR_TEST_PURPOSES_ONLY
+#define BPF_MAP_ASSERT_OK(type, entries, mode)
+#elif BPFLOADER_MIN_VER >= BPFLOADER_T_BETA3_VERSION
+#define BPF_MAP_ASSERT_OK(type, entries, mode)
+#else
+#define BPF_MAP_ASSERT_OK(type, entries, mode) \
+ _Static_assert(((type) != BPF_MAP_TYPE_ARRAY) || ((entries) <= 1) || !((mode) & 0222), \
+ "Writable arrays with more than 1 element not supported on pre-T devices.")
+#endif
+
/* type safe macro to declare a map and related accessor functions */
#define DEFINE_BPF_MAP_EXT(the_map, TYPE, KeyType, ValueType, num_entries, usr, grp, md, \
selinux, pindir, share) \
@@ -167,6 +188,7 @@ static int (*bpf_map_delete_elem_unsafe)(const struct bpf_map_def* map,
.pin_subdir = pindir, \
.shared = share, \
}; \
+ BPF_MAP_ASSERT_OK(BPF_MAP_TYPE_##TYPE, (num_entries), (md)); \
BPF_ANNOTATE_KV_PAIR(the_map, KeyType, ValueType); \
\
static inline __always_inline __unused ValueType* bpf_##the_map##_lookup_elem( \
@@ -205,6 +227,10 @@ static int (*bpf_map_delete_elem_unsafe)(const struct bpf_map_def* map,
DEFINE_BPF_MAP_UGM(the_map, TYPE, KeyType, ValueType, num_entries, \
DEFAULT_BPF_MAP_UID, AID_ROOT, 0600)
+#define DEFINE_BPF_MAP_RO(the_map, TYPE, KeyType, ValueType, num_entries, gid) \
+ DEFINE_BPF_MAP_UGM(the_map, TYPE, KeyType, ValueType, num_entries, \
+ DEFAULT_BPF_MAP_UID, gid, 0440)
+
#define DEFINE_BPF_MAP_GWO(the_map, TYPE, KeyType, ValueType, num_entries, gid) \
DEFINE_BPF_MAP_UGM(the_map, TYPE, KeyType, ValueType, num_entries, \
DEFAULT_BPF_MAP_UID, gid, 0620)
diff --git a/common/tests/unit/src/com/android/net/module/util/BitUtilsTests.kt b/common/tests/unit/src/com/android/net/module/util/BitUtilsTests.kt
new file mode 100644
index 00000000..02367162
--- /dev/null
+++ b/common/tests/unit/src/com/android/net/module/util/BitUtilsTests.kt
@@ -0,0 +1,61 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.net.module.util
+
+import com.android.net.module.util.BitUtils.appendStringRepresentationOfBitMaskToStringBuilder
+import com.android.net.module.util.BitUtils.packBits
+import com.android.net.module.util.BitUtils.unpackBits
+import org.junit.Test
+import kotlin.test.assertEquals
+import kotlin.test.assertTrue
+
+class BitUtilsTests {
+ @Test
+ fun testBitPackingTestCase() {
+ runBitPackingTestCase(0, intArrayOf())
+ runBitPackingTestCase(1, intArrayOf(0))
+ runBitPackingTestCase(3, intArrayOf(0, 1))
+ runBitPackingTestCase(4, intArrayOf(2))
+ runBitPackingTestCase(63, intArrayOf(0, 1, 2, 3, 4, 5))
+ runBitPackingTestCase(Long.MAX_VALUE.inv(), intArrayOf(63))
+ runBitPackingTestCase(Long.MAX_VALUE.inv() + 1, intArrayOf(0, 63))
+ runBitPackingTestCase(Long.MAX_VALUE.inv() + 2, intArrayOf(1, 63))
+ }
+
+ fun runBitPackingTestCase(packedBits: Long, bits: IntArray) {
+ assertEquals(packedBits, packBits(bits))
+ assertTrue(bits contentEquals unpackBits(packedBits))
+ }
+
+ @Test
+ fun testAppendStringRepresentationOfBitMaskToStringBuilder() {
+ runTestAppendStringRepresentationOfBitMaskToStringBuilder("", 0)
+ runTestAppendStringRepresentationOfBitMaskToStringBuilder("BIT0", 0b1)
+ runTestAppendStringRepresentationOfBitMaskToStringBuilder("BIT1&BIT2&BIT4", 0b10110)
+ runTestAppendStringRepresentationOfBitMaskToStringBuilder(
+ "BIT0&BIT60&BIT61&BIT62&BIT63",
+ (0b11110000_00000000_00000000_00000000 shl 32) +
+ 0b00000000_00000000_00000000_00000001)
+ }
+
+ fun runTestAppendStringRepresentationOfBitMaskToStringBuilder(expected: String, bitMask: Long) {
+ StringBuilder().let {
+ appendStringRepresentationOfBitMaskToStringBuilder(it, bitMask, { i -> "BIT$i" }, "&")
+ assertEquals(expected, it.toString())
+ }
+ }
+}
diff --git a/common/tests/unit/src/com/android/net/module/util/BpfDumpTest.java b/common/tests/unit/src/com/android/net/module/util/BpfDumpTest.java
index 39329258..4e483322 100644
--- a/common/tests/unit/src/com/android/net/module/util/BpfDumpTest.java
+++ b/common/tests/unit/src/com/android/net/module/util/BpfDumpTest.java
@@ -16,10 +16,19 @@
package com.android.net.module.util;
+import static android.system.OsConstants.EPERM;
+import static android.system.OsConstants.R_OK;
+
+import static com.android.dx.mockito.inline.extended.ExtendedMockito.doReturn;
+import static com.android.dx.mockito.inline.extended.ExtendedMockito.doThrow;
+import static com.android.dx.mockito.inline.extended.ExtendedMockito.mockitoSession;
+
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
+import android.system.ErrnoException;
+import android.system.Os;
import android.util.Pair;
import androidx.test.filters.SmallTest;
@@ -29,6 +38,7 @@ import com.android.testutils.TestBpfMap;
import org.junit.Test;
import org.junit.runner.RunWith;
+import org.mockito.MockitoSession;
import java.io.PrintWriter;
import java.io.StringWriter;
@@ -118,4 +128,36 @@ public class BpfDumpTest {
assertTrue(dump.contains("key=123, val=456"));
assertTrue(dump.contains("key=789, val=123"));
}
+
+ private String getDumpMapStatus(final IBpfMap<Struct.U32, Struct.U32> map) {
+ final StringWriter sw = new StringWriter();
+ BpfDump.dumpMapStatus(map, new PrintWriter(sw), "mapName", "mapPath");
+ return sw.toString();
+ }
+
+ @Test
+ public void testGetMapStatus() {
+ final IBpfMap<Struct.U32, Struct.U32> map =
+ new TestBpfMap<>(Struct.U32.class, Struct.U32.class);
+ assertEquals("mapName: OK\n", getDumpMapStatus(map));
+ }
+
+ @Test
+ public void testGetMapStatusNull() {
+ final MockitoSession session = mockitoSession()
+ .spyStatic(Os.class)
+ .startMocking();
+ try {
+ // Os.access succeeds
+ doReturn(true).when(() -> Os.access("mapPath", R_OK));
+ assertEquals("mapName: NULL(map is pinned to mapPath)\n", getDumpMapStatus(null));
+
+ // Os.access throws EPERM
+ doThrow(new ErrnoException("", EPERM)).when(() -> Os.access("mapPath", R_OK));
+ assertEquals("mapName: NULL(map is not pinned to mapPath: Operation not permitted)\n",
+ getDumpMapStatus(null));
+ } finally {
+ session.finishMocking();
+ }
+ }
}
diff --git a/common/tests/unit/src/com/android/net/module/util/ByteUtilsTests.kt b/common/tests/unit/src/com/android/net/module/util/ByteUtilsTests.kt
new file mode 100644
index 00000000..e58adad3
--- /dev/null
+++ b/common/tests/unit/src/com/android/net/module/util/ByteUtilsTests.kt
@@ -0,0 +1,54 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.net.module.util
+
+import com.android.net.module.util.ByteUtils.indexOf
+import com.android.net.module.util.ByteUtils.concat
+import org.junit.Test
+import kotlin.test.assertContentEquals
+import kotlin.test.assertEquals
+import kotlin.test.assertNotSame
+
+class ByteUtilsTests {
+ private val EMPTY = byteArrayOf()
+ private val ARRAY1 = byteArrayOf(1)
+ private val ARRAY234 = byteArrayOf(2, 3, 4)
+
+ @Test
+ fun testIndexOf() {
+ assertEquals(-1, indexOf(EMPTY, 1))
+ assertEquals(-1, indexOf(ARRAY1, 2))
+ assertEquals(-1, indexOf(ARRAY234, 1))
+ assertEquals(0, indexOf(byteArrayOf(-1), -1))
+ assertEquals(0, indexOf(ARRAY234, 2))
+ assertEquals(1, indexOf(ARRAY234, 3))
+ assertEquals(2, indexOf(ARRAY234, 4))
+ assertEquals(1, indexOf(byteArrayOf(2, 3, 2, 3), 3))
+ }
+
+ @Test
+ fun testConcat() {
+ assertContentEquals(EMPTY, concat())
+ assertContentEquals(EMPTY, concat(EMPTY))
+ assertContentEquals(EMPTY, concat(EMPTY, EMPTY, EMPTY))
+ assertContentEquals(ARRAY1, concat(ARRAY1))
+ assertNotSame(ARRAY1, concat(ARRAY1))
+ assertContentEquals(ARRAY1, concat(EMPTY, ARRAY1, EMPTY))
+ assertContentEquals(byteArrayOf(1, 1, 1), concat(ARRAY1, ARRAY1, ARRAY1))
+ assertContentEquals(byteArrayOf(1, 2, 3, 4), concat(ARRAY1, ARRAY234))
+ }
+} \ No newline at end of file
diff --git a/common/tests/unit/src/com/android/net/module/util/DnsPacketTest.java b/common/tests/unit/src/com/android/net/module/util/DnsPacketTest.java
index 1bbba088..409c1ebb 100644
--- a/common/tests/unit/src/com/android/net/module/util/DnsPacketTest.java
+++ b/common/tests/unit/src/com/android/net/module/util/DnsPacketTest.java
@@ -16,26 +16,45 @@
package com.android.net.module.util;
+import static android.net.DnsResolver.CLASS_IN;
+import static android.net.DnsResolver.TYPE_A;
+import static android.net.DnsResolver.TYPE_AAAA;
+
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
+import android.annotation.NonNull;
+import android.annotation.Nullable;
+
import androidx.test.filters.SmallTest;
import androidx.test.runner.AndroidJUnit4;
+import libcore.net.InetAddressUtils;
+
import org.junit.Test;
import org.junit.runner.RunWith;
+import java.io.IOException;
+import java.nio.BufferUnderflowException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
@RunWith(AndroidJUnit4.class)
@SmallTest
public class DnsPacketTest {
+ private static final int TEST_DNS_PACKET_ID = 0x7722;
+ private static final int TEST_DNS_PACKET_FLAGS = 0x8180;
+
private void assertHeaderParses(DnsPacket.DnsHeader header, int id, int flag,
int qCount, int aCount, int nsCount, int arCount) {
- assertEquals(header.id, id);
- assertEquals(header.flags, flag);
+ assertEquals(header.getId(), id);
+ assertEquals(header.getFlags(), flag);
assertEquals(header.getRecordCount(DnsPacket.QDSECTION), qCount);
assertEquals(header.getRecordCount(DnsPacket.ANSECTION), aCount);
assertEquals(header.getRecordCount(DnsPacket.NSSECTION), nsCount);
@@ -51,11 +70,16 @@ public class DnsPacketTest {
assertTrue(Arrays.equals(record.getRR(), rr));
}
- class TestDnsPacket extends DnsPacket {
+ static class TestDnsPacket extends DnsPacket {
TestDnsPacket(byte[] data) throws DnsPacket.ParseException {
super(data);
}
+ TestDnsPacket(@NonNull DnsHeader header, @Nullable ArrayList<DnsRecord> qd,
+ @Nullable ArrayList<DnsRecord> an) {
+ super(header, qd, an);
+ }
+
public DnsHeader getHeader() {
return mHeader;
}
@@ -156,4 +180,247 @@ public class DnsPacketTest {
new byte[]{ 0x24, 0x04, 0x68, 0x00, 0x40, 0x05, 0x08, 0x0d,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20, 0x04 });
}
+
+ /** Verifies that the synthesized {@link DnsPacket.DnsHeader} can be parsed correctly. */
+ @Test
+ public void testDnsHeaderSynthesize() {
+ final DnsPacket.DnsHeader testHeader = new DnsPacket.DnsHeader(TEST_DNS_PACKET_ID,
+ TEST_DNS_PACKET_FLAGS, 3 /* qcount */, 5 /* ancount */);
+ final DnsPacket.DnsHeader actualHeader = new DnsPacket.DnsHeader(
+ ByteBuffer.wrap(testHeader.getBytes()));
+ assertEquals(testHeader, actualHeader);
+ }
+
+ /** Verifies that the synthesized {@link DnsPacket.DnsRecord} can be parsed correctly. */
+ @Test
+ public void testDnsRecordSynthesize() throws IOException {
+ assertDnsRecordRoundTrip(
+ DnsPacket.DnsRecord.makeAOrAAAARecord(DnsPacket.ANSECTION,
+ "test.com", CLASS_IN, 5 /* ttl */,
+ InetAddressUtils.parseNumericAddress("abcd::fedc")));
+ assertDnsRecordRoundTrip(DnsPacket.DnsRecord.makeQuestion("test.com", TYPE_AAAA, CLASS_IN));
+ assertDnsRecordRoundTrip(DnsPacket.DnsRecord.makeCNameRecord(DnsPacket.ANSECTION,
+ "test.com", CLASS_IN, 0 /* ttl */, "example.com"));
+ }
+
+ /**
+ * Verifies ttl/rData error handling when parsing
+ * {@link DnsPacket.DnsRecord} from bytes.
+ */
+ @Test
+ public void testDnsRecordTTLRDataErrorHandling() throws IOException {
+ // Verify the constructor ignore ttl/rData of questions even if they are supplied.
+ final byte[] qdWithTTLRData = new byte[]{
+ 0x03, 0x77, 0x77, 0x77, 0x06, 0x67, 0x6F, 0x6F, 0x67, 0x6c, 0x65,
+ 0x03, 0x63, 0x6f, 0x6d, 0x00, /* Name */
+ 0x00, 0x00, /* Type */
+ 0x00, 0x01, /* Class */
+ 0x00, 0x00, 0x01, 0x2b, /* TTL */
+ 0x00, 0x04, /* Data length */
+ (byte) 0xac, (byte) 0xd9, (byte) 0xa1, (byte) 0x84 /* Address */};
+ final DnsPacket.DnsRecord questionsFromBytes =
+ new DnsPacket.DnsRecord(DnsPacket.QDSECTION, ByteBuffer.wrap(qdWithTTLRData));
+ assertEquals(0, questionsFromBytes.ttl);
+ assertNull(questionsFromBytes.getRR());
+
+ // Verify ANSECTION must have rData when constructing.
+ final byte[] anWithoutTTLRData = new byte[]{
+ 0x03, 0x77, 0x77, 0x77, 0x06, 0x67, 0x6F, 0x6F, 0x67, 0x6c, 0x65,
+ 0x03, 0x63, 0x6f, 0x6d, 0x00, /* Name */
+ 0x00, 0x01, /* Type */
+ 0x00, 0x01, /* Class */};
+ assertThrows(BufferUnderflowException.class, () ->
+ new DnsPacket.DnsRecord(DnsPacket.ANSECTION, ByteBuffer.wrap(anWithoutTTLRData)));
+ }
+
+ private void assertDnsRecordRoundTrip(DnsPacket.DnsRecord before)
+ throws IOException {
+ final DnsPacket.DnsRecord after = new DnsPacket.DnsRecord(before.rType,
+ ByteBuffer.wrap(before.getBytes()));
+ assertEquals(after, before);
+ }
+
+ /** Verifies that the synthesized {@link DnsPacket} can be parsed correctly. */
+ @Test
+ public void testDnsPacketSynthesize() throws IOException {
+ // Ipv4 dns response packet generated by scapy:
+ // dns_r = scapy.DNS(
+ // id=0xbeef,
+ // qr=1,
+ // qd=scapy.DNSQR(qname="hello.example.com"),
+ // an=scapy.DNSRR(rrname="hello.example.com", type="CNAME", rdata='test.com') /
+ // scapy.DNSRR(rrname="hello.example.com", rdata='1.2.3.4'))
+ // scapy.hexdump(dns_r)
+ // dns_r.show2()
+ // Note that since the synthesizing does not support name compression yet, the domain
+ // name of the sample need to be uncompressed when generating.
+ final byte[] v4BlobUncompressed = new byte[]{
+ /* Header */
+ (byte) 0xbe, (byte) 0xef, /* Transaction ID */
+ (byte) 0x81, 0x00, /* Flags */
+ 0x00, 0x01, /* Questions */
+ 0x00, 0x02, /* Answer RRs */
+ 0x00, 0x00, /* Authority RRs */
+ 0x00, 0x00, /* Additional RRs */
+ /* Queries */
+ 0x05, 0x68, 0x65, 0x6C, 0x6C, 0x6F, 0x07, 0x65, 0x78, 0x61,
+ 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, /* Name: hello.example.com */
+ 0x00, 0x01, /* Type */
+ 0x00, 0x01, /* Class */
+ /* Answers */
+ 0x05, 0x68, 0x65, 0x6C, 0x6C, 0x6F, 0x07, 0x65, 0x78, 0x61,
+ 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, /* Name: hello.example.com */
+ 0x00, 0x05, /* Type */
+ 0x00, 0x01, /* Class */
+ 0x00, 0x00, 0x00, 0x00, /* TTL */
+ 0x00, 0x0A, /* Data length */
+ 0x04, 0x74, 0x65, 0x73, 0x74, 0x03, 0x63, 0x6F, 0x6D, 0x00, /* Alias: test.com */
+ 0x05, 0x68, 0x65, 0x6C, 0x6C, 0x6F, 0x07, 0x65, 0x78, 0x61,
+ 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, /* Name: hello.example.com */
+ 0x00, 0x01, /* Type */
+ 0x00, 0x01, /* Class */
+ 0x00, 0x00, 0x00, 0x00, /* TTL */
+ 0x00, 0x04, /* Data length */
+ 0x01, 0x02, 0x03, 0x04, /* Address: 1.2.3.4 */
+ };
+
+ // Forge one via constructors.
+ final DnsPacket.DnsHeader testHeader = new DnsPacket.DnsHeader(0xbeef,
+ 0x8100, 1 /* qcount */, 2 /* ancount */);
+ final ArrayList<DnsPacket.DnsRecord> qlist = new ArrayList<>();
+ final ArrayList<DnsPacket.DnsRecord> alist = new ArrayList<>();
+ qlist.add(DnsPacket.DnsRecord.makeQuestion(
+ "hello.example.com", TYPE_A, CLASS_IN));
+ alist.add(DnsPacket.DnsRecord.makeCNameRecord(
+ DnsPacket.ANSECTION, "hello.example.com", CLASS_IN, 0 /* ttl */, "test.com"));
+ alist.add(DnsPacket.DnsRecord.makeAOrAAAARecord(
+ DnsPacket.ANSECTION, "hello.example.com", CLASS_IN, 0 /* ttl */,
+ InetAddressUtils.parseNumericAddress("1.2.3.4")));
+ final TestDnsPacket testPacket = new TestDnsPacket(testHeader, qlist, alist);
+
+ // Assert content equals in both ways.
+ assertTrue(Arrays.equals(v4BlobUncompressed, testPacket.getBytes()));
+ assertEquals(new TestDnsPacket(v4BlobUncompressed), testPacket);
+ }
+
+ @Test
+ public void testDnsPacketSynthesize_recordCountMismatch() throws IOException {
+ final DnsPacket.DnsHeader testHeader = new DnsPacket.DnsHeader(0xbeef,
+ 0x8100, 1 /* qcount */, 1 /* ancount */);
+ final ArrayList<DnsPacket.DnsRecord> qlist = new ArrayList<>();
+ final ArrayList<DnsPacket.DnsRecord> alist = new ArrayList<>();
+ qlist.add(DnsPacket.DnsRecord.makeQuestion(
+ "hello.example.com", TYPE_A, CLASS_IN));
+
+ // Assert throws if the supplied answer records fewer than the declared count.
+ assertThrows(IllegalArgumentException.class, () ->
+ new TestDnsPacket(testHeader, qlist, alist));
+
+ // Assert throws if the supplied answer records more than the declared count.
+ alist.add(DnsPacket.DnsRecord.makeCNameRecord(
+ DnsPacket.ANSECTION, "hello.example.com", CLASS_IN, 0 /* ttl */, "test.com"));
+ alist.add(DnsPacket.DnsRecord.makeAOrAAAARecord(
+ DnsPacket.ANSECTION, "hello.example.com", CLASS_IN, 0 /* ttl */,
+ InetAddressUtils.parseNumericAddress("1.2.3.4")));
+ assertThrows(IllegalArgumentException.class, () ->
+ new TestDnsPacket(testHeader, qlist, alist));
+
+ // Assert counts matched if the byte buffer still has data when parsing ended.
+ final byte[] blobTooMuchData = new byte[]{
+ /* Header */
+ (byte) 0xbe, (byte) 0xef, /* Transaction ID */
+ (byte) 0x81, 0x00, /* Flags */
+ 0x00, 0x00, /* Questions */
+ 0x00, 0x00, /* Answer RRs */
+ 0x00, 0x00, /* Authority RRs */
+ 0x00, 0x00, /* Additional RRs */
+ /* Queries */
+ 0x05, 0x68, 0x65, 0x6C, 0x6C, 0x6F, 0x07, 0x65, 0x78, 0x61,
+ 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, /* Name */
+ 0x00, 0x01, /* Type */
+ 0x00, 0x01, /* Class */
+ };
+ final TestDnsPacket packetFromTooMuchData = new TestDnsPacket(blobTooMuchData);
+ for (int i = 0; i < DnsPacket.NUM_SECTIONS; i++) {
+ assertEquals(0, packetFromTooMuchData.getRecordList(i).size());
+ assertEquals(0, packetFromTooMuchData.getHeader().getRecordCount(i));
+ }
+
+ // Assert throws if the byte buffer ended when expecting more records.
+ final byte[] blobNotEnoughData = new byte[]{
+ /* Header */
+ (byte) 0xbe, (byte) 0xef, /* Transaction ID */
+ (byte) 0x81, 0x00, /* Flags */
+ 0x00, 0x01, /* Questions */
+ 0x00, 0x02, /* Answer RRs */
+ 0x00, 0x00, /* Authority RRs */
+ 0x00, 0x00, /* Additional RRs */
+ /* Queries */
+ 0x05, 0x68, 0x65, 0x6C, 0x6C, 0x6F, 0x07, 0x65, 0x78, 0x61,
+ 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, /* Name */
+ 0x00, 0x01, /* Type */
+ 0x00, 0x01, /* Class */
+ /* Answers */
+ 0x05, 0x68, 0x65, 0x6C, 0x6C, 0x6F, 0x07, 0x65, 0x78, 0x61,
+ 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, /* Name */
+ 0x00, 0x01, /* Type */
+ 0x00, 0x01, /* Class */
+ 0x00, 0x00, 0x00, 0x00, /* TTL */
+ 0x00, 0x04, /* Data length */
+ 0x01, 0x02, 0x03, 0x04, /* Address */
+ };
+ assertThrows(DnsPacket.ParseException.class, () -> new TestDnsPacket(blobNotEnoughData));
+ }
+
+ @Test
+ public void testEqualsAndHashCode() throws IOException {
+ // Verify DnsHeader equals and hashCode.
+ final DnsPacket.DnsHeader testHeader = new DnsPacket.DnsHeader(TEST_DNS_PACKET_ID,
+ TEST_DNS_PACKET_FLAGS, 1 /* qcount */, 1 /* ancount */);
+ final DnsPacket.DnsHeader emptyHeader = new DnsPacket.DnsHeader(TEST_DNS_PACKET_ID + 1,
+ TEST_DNS_PACKET_FLAGS + 0x08, 0 /* qcount */, 0 /* ancount */);
+ final DnsPacket.DnsHeader headerFromBytes =
+ new DnsPacket.DnsHeader(ByteBuffer.wrap(testHeader.getBytes()));
+ assertEquals(testHeader, headerFromBytes);
+ assertEquals(testHeader.hashCode(), headerFromBytes.hashCode());
+ assertNotEquals(testHeader, emptyHeader);
+ assertNotEquals(testHeader.hashCode(), emptyHeader.hashCode());
+ assertNotEquals(headerFromBytes, emptyHeader);
+ assertNotEquals(headerFromBytes.hashCode(), emptyHeader.hashCode());
+
+ // Verify DnsRecord equals and hashCode.
+ final DnsPacket.DnsRecord testQuestion = DnsPacket.DnsRecord.makeQuestion(
+ "test.com", TYPE_AAAA, CLASS_IN);
+ final DnsPacket.DnsRecord testAnswer = DnsPacket.DnsRecord.makeCNameRecord(
+ DnsPacket.ANSECTION, "test.com", CLASS_IN, 9, "www.test.com");
+ final DnsPacket.DnsRecord questionFromBytes = new DnsPacket.DnsRecord(DnsPacket.QDSECTION,
+ ByteBuffer.wrap(testQuestion.getBytes()));
+ assertEquals(testQuestion, questionFromBytes);
+ assertEquals(testQuestion.hashCode(), questionFromBytes.hashCode());
+ assertNotEquals(testQuestion, testAnswer);
+ assertNotEquals(testQuestion.hashCode(), testAnswer.hashCode());
+ assertNotEquals(questionFromBytes, testAnswer);
+ assertNotEquals(questionFromBytes.hashCode(), testAnswer.hashCode());
+
+ // Verify DnsPacket equals and hashCode.
+ final ArrayList<DnsPacket.DnsRecord> qlist = new ArrayList<>();
+ final ArrayList<DnsPacket.DnsRecord> alist = new ArrayList<>();
+ qlist.add(testQuestion);
+ alist.add(testAnswer);
+ final TestDnsPacket testPacket = new TestDnsPacket(testHeader, qlist, alist);
+ final TestDnsPacket emptyPacket = new TestDnsPacket(
+ emptyHeader, new ArrayList<>(), new ArrayList<>());
+ final TestDnsPacket packetFromBytes = new TestDnsPacket(testPacket.getBytes());
+ assertEquals(testPacket, packetFromBytes);
+ assertEquals(testPacket.hashCode(), packetFromBytes.hashCode());
+ assertNotEquals(testPacket, emptyPacket);
+ assertNotEquals(testPacket.hashCode(), emptyPacket.hashCode());
+ assertNotEquals(packetFromBytes, emptyPacket);
+ assertNotEquals(packetFromBytes.hashCode(), emptyPacket.hashCode());
+
+ // Verify DnsPacket with empty list.
+ final TestDnsPacket emptyPacketFromBytes = new TestDnsPacket(emptyPacket.getBytes());
+ assertEquals(emptyPacket, emptyPacketFromBytes);
+ assertEquals(emptyPacket.hashCode(), emptyPacketFromBytes.hashCode());
+ }
}
diff --git a/common/tests/unit/src/com/android/net/module/util/DnsPacketUtilsTest.java b/common/tests/unit/src/com/android/net/module/util/DnsPacketUtilsTest.java
index 48777acc..9e1ab824 100644
--- a/common/tests/unit/src/com/android/net/module/util/DnsPacketUtilsTest.java
+++ b/common/tests/unit/src/com/android/net/module/util/DnsPacketUtilsTest.java
@@ -18,12 +18,20 @@ package com.android.net.module.util;
import static com.android.net.module.util.DnsPacketUtils.DnsRecordParser;
+import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
+import android.net.ParseException;
+import android.os.Build;
+
import androidx.test.filters.SmallTest;
import androidx.test.runner.AndroidJUnit4;
+import com.android.testutils.DevSdkIgnoreRule;
+
+import org.junit.Assert;
+import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -32,6 +40,8 @@ import java.nio.ByteBuffer;
@RunWith(AndroidJUnit4.class)
@SmallTest
public class DnsPacketUtilsTest {
+ @Rule
+ public final DevSdkIgnoreRule mIgnoreRule = new DevSdkIgnoreRule();
/**
* Verifies that the compressed NAME field in the answer section of the DNS message is parsed
@@ -116,4 +126,31 @@ public class DnsPacketUtilsTest {
notNameCompressedBuf, /* depth= */ 0, /* isNameCompressionSupported= */ false);
assertEquals(domainName, "www.google.com");
}
+
+ // Skip test on R- devices since ParseException only available on S+ devices.
+ @DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.R)
+ @Test
+ public void testDomainNameToLabels() throws Exception {
+ assertArrayEquals(
+ new byte[]{3, 'w', 'w', 'w', 6, 'g', 'o', 'o', 'g', 'l', 'e', 3, 'c', 'o', 'm', 0},
+ DnsRecordParser.domainNameToLabels("www.google.com"));
+ assertThrows(ParseException.class, () ->
+ DnsRecordParser.domainNameToLabels("aaa."));
+ assertThrows(ParseException.class, () ->
+ DnsRecordParser.domainNameToLabels("aaa"));
+ assertThrows(ParseException.class, () ->
+ DnsRecordParser.domainNameToLabels("."));
+ assertThrows(ParseException.class, () ->
+ DnsRecordParser.domainNameToLabels(""));
+ }
+
+ @Test
+ public void testIsHostName() {
+ Assert.assertTrue(DnsRecordParser.isHostName("www.google.com"));
+ Assert.assertFalse(DnsRecordParser.isHostName("com"));
+ Assert.assertFalse(DnsRecordParser.isHostName("1.2.3.4"));
+ Assert.assertFalse(DnsRecordParser.isHostName("1234::5678"));
+ Assert.assertFalse(DnsRecordParser.isHostName(null));
+ Assert.assertFalse(DnsRecordParser.isHostName(""));
+ }
}
diff --git a/common/tests/unit/src/com/android/net/module/util/IpUtilsTest.java b/common/tests/unit/src/com/android/net/module/util/IpUtilsTest.java
index 34683090..20555b35 100644
--- a/common/tests/unit/src/com/android/net/module/util/IpUtilsTest.java
+++ b/common/tests/unit/src/com/android/net/module/util/IpUtilsTest.java
@@ -16,6 +16,9 @@
package com.android.net.module.util;
+import static com.android.net.module.util.NetworkStackConstants.ICMP_CHECKSUM_OFFSET;
+import static com.android.net.module.util.NetworkStackConstants.IPV4_CHECKSUM_OFFSET;
+
import static org.junit.Assert.assertEquals;
import androidx.test.filters.SmallTest;
@@ -34,7 +37,6 @@ public class IpUtilsTest {
private static final int IPV6_HEADER_LENGTH = 40;
private static final int TCP_HEADER_LENGTH = 20;
private static final int UDP_HEADER_LENGTH = 8;
- private static final int IP_CHECKSUM_OFFSET = 10;
private static final int TCP_CHECKSUM_OFFSET = 16;
private static final int UDP_CHECKSUM_OFFSET = 6;
@@ -141,7 +143,7 @@ public class IpUtilsTest {
assertEquals((short) 0xffff, IpUtils.udpChecksum(packet, 0, IPV4_HEADER_LENGTH));
// Check that we can calculate the checksums from scratch.
- final int ipSumOffset = IP_CHECKSUM_OFFSET;
+ final int ipSumOffset = IPV4_CHECKSUM_OFFSET;
final int ipSum = getChecksum(packet, ipSumOffset);
assertEquals(0xf68b, ipSum);
@@ -163,4 +165,47 @@ public class IpUtilsTest {
assertEquals(0, IpUtils.ipChecksum(packet, 0));
assertEquals((short) 0xffff, IpUtils.udpChecksum(packet, 0, IPV4_HEADER_LENGTH));
}
+
+ @Test
+ public void testIpv4IcmpChecksum() throws Exception {
+ // packet = (scapy.IP(src="192.0.2.1", dst="192.0.2.2", tos=0x40) /
+ // scapy.ICMP(type=0x8, id=0x1234, seq=0x5678) /
+ // "hello, world")
+ ByteBuffer packet = ByteBuffer.wrap(new byte[] {
+ /* IPv4 */
+ (byte) 0x45, (byte) 0x40, (byte) 0x00, (byte) 0x28,
+ (byte) 0x00, (byte) 0x01, (byte) 0x00, (byte) 0x00,
+ (byte) 0x40, (byte) 0x01, (byte) 0xf6, (byte) 0x90,
+ (byte) 0xc0, (byte) 0x00, (byte) 0x02, (byte) 0x01,
+ (byte) 0xc0, (byte) 0x00, (byte) 0x02, (byte) 0x02,
+ /* ICMP */
+ (byte) 0x08, /* type: echo-request */
+ (byte) 0x00, /* code: 0 */
+ (byte) 0x4f, (byte) 0x07, /* chksum: 0x4f07 */
+ (byte) 0x12, (byte) 0x34, /* id: 0x1234 */
+ (byte) 0x56, (byte) 0x78, /* seq: 0x5678 */
+ (byte) 0x68, (byte) 0x65, (byte) 0x6c, (byte) 0x6c, /* data: hello, world */
+ (byte) 0x6f, (byte) 0x2c, (byte) 0x20, (byte) 0x77,
+ (byte) 0x6f, (byte) 0x72, (byte) 0x6c, (byte) 0x64
+ });
+
+ // Check that a valid packet has checksum 0.
+ int transportLen = packet.limit() - IPV4_HEADER_LENGTH;
+ assertEquals(0, IpUtils.icmpChecksum(packet, IPV4_HEADER_LENGTH, transportLen));
+
+ // Check that we can calculate the checksum from scratch.
+ int sumOffset = IPV4_HEADER_LENGTH + ICMP_CHECKSUM_OFFSET;
+ int sum = getUnsignedByte(packet, sumOffset) * 256 + getUnsignedByte(packet, sumOffset + 1);
+ assertEquals(0x4f07, sum);
+
+ packet.put(sumOffset, (byte) 0);
+ packet.put(sumOffset + 1, (byte) 0);
+ assertChecksumEquals(sum, IpUtils.icmpChecksum(packet, IPV4_HEADER_LENGTH, transportLen));
+
+ // Check that writing the checksum back into the packet results in a valid packet.
+ packet.putShort(
+ sumOffset,
+ IpUtils.icmpChecksum(packet, IPV4_HEADER_LENGTH, transportLen));
+ assertEquals(0, IpUtils.icmpChecksum(packet, IPV4_HEADER_LENGTH, transportLen));
+ }
}
diff --git a/common/tests/unit/src/com/android/net/module/util/NetworkCapabilitiesUtilsTest.kt b/common/tests/unit/src/com/android/net/module/util/NetworkCapabilitiesUtilsTest.kt
index 256ea1e8..958f45f4 100644
--- a/common/tests/unit/src/com/android/net/module/util/NetworkCapabilitiesUtilsTest.kt
+++ b/common/tests/unit/src/com/android/net/module/util/NetworkCapabilitiesUtilsTest.kt
@@ -22,7 +22,6 @@ import android.net.NetworkCapabilities.NET_CAPABILITY_BIP
import android.net.NetworkCapabilities.NET_CAPABILITY_CBS
import android.net.NetworkCapabilities.NET_CAPABILITY_EIMS
import android.net.NetworkCapabilities.NET_CAPABILITY_INTERNET
-import android.net.NetworkCapabilities.NET_CAPABILITY_INTERNET
import android.net.NetworkCapabilities.NET_CAPABILITY_OEM_PAID
import android.net.NetworkCapabilities.TRANSPORT_BLUETOOTH
import android.net.NetworkCapabilities.TRANSPORT_CELLULAR
@@ -38,8 +37,6 @@ import com.android.modules.utils.build.SdkLevel
import com.android.net.module.util.NetworkCapabilitiesUtils.RESTRICTED_CAPABILITIES
import com.android.net.module.util.NetworkCapabilitiesUtils.UNRESTRICTED_CAPABILITIES
import com.android.net.module.util.NetworkCapabilitiesUtils.getDisplayTransport
-import com.android.net.module.util.NetworkCapabilitiesUtils.packBits
-import com.android.net.module.util.NetworkCapabilitiesUtils.unpackBits
import org.junit.Test
import org.junit.runner.RunWith
import java.lang.IllegalArgumentException
@@ -75,23 +72,6 @@ class NetworkCapabilitiesUtilsTest {
}
}
- @Test
- fun testBitPackingTestCase() {
- runBitPackingTestCase(0, intArrayOf())
- runBitPackingTestCase(1, intArrayOf(0))
- runBitPackingTestCase(3, intArrayOf(0, 1))
- runBitPackingTestCase(4, intArrayOf(2))
- runBitPackingTestCase(63, intArrayOf(0, 1, 2, 3, 4, 5))
- runBitPackingTestCase(Long.MAX_VALUE.inv(), intArrayOf(63))
- runBitPackingTestCase(Long.MAX_VALUE.inv() + 1, intArrayOf(0, 63))
- runBitPackingTestCase(Long.MAX_VALUE.inv() + 2, intArrayOf(1, 63))
- }
-
- fun runBitPackingTestCase(packedBits: Long, bits: IntArray) {
- assertEquals(packedBits, packBits(bits))
- assertTrue(bits contentEquals unpackBits(packedBits))
- }
-
// NetworkCapabilities constructor and Builder are not available until R. Mark TargetApi to
// ignore the linter error since it's used in only unit test.
@Test @TargetApi(Build.VERSION_CODES.R)
diff --git a/common/tests/unit/src/com/android/net/module/util/TrackRecordTest.kt b/common/tests/unit/src/com/android/net/module/util/TrackRecordTest.kt
index 9fb4d8ca..8e320d0d 100644
--- a/common/tests/unit/src/com/android/net/module/util/TrackRecordTest.kt
+++ b/common/tests/unit/src/com/android/net/module/util/TrackRecordTest.kt
@@ -17,6 +17,7 @@
package com.android.net.module.util
import com.android.testutils.ConcurrentInterpreter
+import com.android.testutils.INTERPRET_TIME_UNIT
import com.android.testutils.InterpretException
import com.android.testutils.InterpretMatcher
import com.android.testutils.SyntaxException
@@ -420,7 +421,7 @@ private val interpretTable = listOf<InterpretMatcher<TrackRecord<Int>>>(
// the test code to not compile instead of throw, but it's vastly more complex and this will
// fail 100% at runtime any test that would not have compiled.
Regex("""poll\((\d+)?\)\s*(\{.*\})?""") to { i, t, r ->
- (if (r.strArg(1).isEmpty()) i.interpretTimeUnit else r.timeArg(1)).let { time ->
+ (if (r.strArg(1).isEmpty()) INTERPRET_TIME_UNIT else r.timeArg(1)).let { time ->
(t as ArrayTrackRecord<Int>.ReadHead).poll(time, makePredicate(r.strArg(2)))
}
},
diff --git a/common/tests/unit/src/com/android/net/module/util/netlink/NetlinkSocketTest.java b/common/tests/unit/src/com/android/net/module/util/netlink/NetlinkSocketTest.java
index 33ea6023..f4073b6d 100644
--- a/common/tests/unit/src/com/android/net/module/util/netlink/NetlinkSocketTest.java
+++ b/common/tests/unit/src/com/android/net/module/util/netlink/NetlinkSocketTest.java
@@ -236,7 +236,7 @@ public class NetlinkSocketTest {
@Field(order = 3, type = Type.U8)
public short scope;
- @Field(order = 4, type = Type.U32)
- public long index;
+ @Field(order = 4, type = Type.S32)
+ public int index;
}
}
diff --git a/common/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt b/common/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt
new file mode 100644
index 00000000..f8f2da08
--- /dev/null
+++ b/common/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt
@@ -0,0 +1,477 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.annotation.SuppressLint
+import android.net.LinkAddress
+import android.net.LinkProperties
+import android.net.Network
+import android.net.NetworkCapabilities
+import com.android.testutils.RecorderCallback.CallbackEntry
+import com.android.testutils.RecorderCallback.CallbackEntry.Companion.AVAILABLE
+import com.android.testutils.RecorderCallback.CallbackEntry.Companion.BLOCKED_STATUS
+import com.android.testutils.RecorderCallback.CallbackEntry.Companion.LINK_PROPERTIES_CHANGED
+import com.android.testutils.RecorderCallback.CallbackEntry.Companion.LOSING
+import com.android.testutils.RecorderCallback.CallbackEntry.Companion.NETWORK_CAPS_UPDATED
+import com.android.testutils.RecorderCallback.CallbackEntry.Companion.LOST
+import com.android.testutils.RecorderCallback.CallbackEntry.Companion.RESUMED
+import com.android.testutils.RecorderCallback.CallbackEntry.Companion.SUSPENDED
+import com.android.testutils.RecorderCallback.CallbackEntry.Companion.UNAVAILABLE
+import com.android.testutils.RecorderCallback.CallbackEntry.Available
+import com.android.testutils.RecorderCallback.CallbackEntry.BlockedStatus
+import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged
+import org.junit.Before
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+import org.junit.Assume.assumeTrue
+import kotlin.reflect.KClass
+import kotlin.test.assertEquals
+import kotlin.test.assertFails
+import kotlin.test.assertNull
+import kotlin.test.assertTrue
+import kotlin.test.fail
+
+const val SHORT_TIMEOUT_MS = 20L
+const val DEFAULT_LINGER_DELAY_MS = 30000
+const val NOT_METERED = NetworkCapabilities.NET_CAPABILITY_NOT_METERED
+const val WIFI = NetworkCapabilities.TRANSPORT_WIFI
+const val CELLULAR = NetworkCapabilities.TRANSPORT_CELLULAR
+const val TEST_INTERFACE_NAME = "testInterfaceName"
+
+@RunWith(JUnit4::class)
+@SuppressLint("NewApi") // Uses hidden APIs, which the linter would identify as missing APIs.
+class TestableNetworkCallbackTest {
+ private lateinit var mCallback: TestableNetworkCallback
+
+ private fun makeHasNetwork(netId: Int) = object : TestableNetworkCallback.HasNetwork {
+ override val network: Network = Network(netId)
+ }
+
+ @Before
+ fun setUp() {
+ mCallback = TestableNetworkCallback()
+ }
+
+ @Test
+ fun testLastAvailableNetwork() {
+ // Make sure there is no last available network at first, then the last available network
+ // is returned after onAvailable is called.
+ val net2097 = Network(2097)
+ assertNull(mCallback.lastAvailableNetwork)
+ mCallback.onAvailable(net2097)
+ assertEquals(mCallback.lastAvailableNetwork, net2097)
+
+ // Make sure calling onCapsChanged/onLinkPropertiesChanged don't affect the last available
+ // network.
+ mCallback.onCapabilitiesChanged(net2097, NetworkCapabilities())
+ mCallback.onLinkPropertiesChanged(net2097, LinkProperties())
+ assertEquals(mCallback.lastAvailableNetwork, net2097)
+
+ // Make sure onLost clears the last available network.
+ mCallback.onLost(net2097)
+ assertNull(mCallback.lastAvailableNetwork)
+
+ // Do the same but with a different network after onLost : make sure the last available
+ // network is the new one, not the original one.
+ val net2098 = Network(2098)
+ mCallback.onAvailable(net2098)
+ mCallback.onCapabilitiesChanged(net2098, NetworkCapabilities())
+ mCallback.onLinkPropertiesChanged(net2098, LinkProperties())
+ assertEquals(mCallback.lastAvailableNetwork, net2098)
+
+ // Make sure onAvailable changes the last available network even if onLost was not called.
+ val net2099 = Network(2099)
+ mCallback.onAvailable(net2099)
+ assertEquals(mCallback.lastAvailableNetwork, net2099)
+
+ // For legacy reasons, lastAvailableNetwork is null as soon as any is lost, not necessarily
+ // the last available one. Check that behavior.
+ mCallback.onLost(net2098)
+ assertNull(mCallback.lastAvailableNetwork)
+
+ // Make sure that losing the really last available one still results in null.
+ mCallback.onLost(net2099)
+ assertNull(mCallback.lastAvailableNetwork)
+
+ // Make sure multiple onAvailable in a row then onLost still results in null.
+ mCallback.onAvailable(net2097)
+ mCallback.onAvailable(net2098)
+ mCallback.onAvailable(net2099)
+ mCallback.onLost(net2097)
+ assertNull(mCallback.lastAvailableNetwork)
+ }
+
+ @Test
+ fun testAssertNoCallback() {
+ mCallback.assertNoCallback(SHORT_TIMEOUT_MS)
+ mCallback.onAvailable(Network(100))
+ assertFails { mCallback.assertNoCallback(SHORT_TIMEOUT_MS) }
+ }
+
+ @Test
+ fun testAssertNoCallbackThat() {
+ val net = Network(101)
+ mCallback.assertNoCallbackThat { it is Available }
+ mCallback.onAvailable(net)
+ // Expect no blocked status change. Receive other callback does not fail the test.
+ mCallback.assertNoCallbackThat { it is BlockedStatus }
+ mCallback.onBlockedStatusChanged(net, true)
+ assertFails { mCallback.assertNoCallbackThat { it is BlockedStatus } }
+ mCallback.onBlockedStatusChanged(net, false)
+ mCallback.onCapabilitiesChanged(net, NetworkCapabilities())
+ assertFails { mCallback.assertNoCallbackThat { it is CapabilitiesChanged } }
+ }
+
+ @Test
+ fun testCapabilitiesWithAndWithout() {
+ val net = Network(101)
+ val matcher = makeHasNetwork(101)
+ val meteredNc = NetworkCapabilities()
+ val unmeteredNc = NetworkCapabilities().addCapability(NOT_METERED)
+ // Check that expecting caps (with or without) fails when no callback has been received.
+ assertFails { mCallback.expectCapabilitiesWith(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
+ assertFails { mCallback.expectCapabilitiesWithout(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
+
+ // Add NOT_METERED and check that With succeeds and Without fails.
+ mCallback.onCapabilitiesChanged(net, unmeteredNc)
+ mCallback.expectCapabilitiesWith(NOT_METERED, matcher)
+ mCallback.onCapabilitiesChanged(net, unmeteredNc)
+ assertFails { mCallback.expectCapabilitiesWithout(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
+
+ // Don't add NOT_METERED and check that With fails and Without succeeds.
+ mCallback.onCapabilitiesChanged(net, meteredNc)
+ assertFails { mCallback.expectCapabilitiesWith(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
+ mCallback.onCapabilitiesChanged(net, meteredNc)
+ mCallback.expectCapabilitiesWithout(NOT_METERED, matcher)
+ }
+
+ @Test
+ fun testExpectWithPredicate() {
+ val net = Network(193)
+ val netCaps = NetworkCapabilities().addTransportType(CELLULAR)
+ // Check that expecting callbackThat anything fails when no callback has been received.
+ assertFails { mCallback.expect<CallbackEntry>(timeoutMs = SHORT_TIMEOUT_MS) { true } }
+
+ // Basic test for true and false
+ mCallback.onAvailable(net)
+ mCallback.expect<Available> { true }
+ mCallback.onAvailable(net)
+ assertFails { mCallback.expect<CallbackEntry>(timeoutMs = SHORT_TIMEOUT_MS) { false } }
+
+ // Try a positive and a negative case
+ mCallback.onBlockedStatusChanged(net, true)
+ mCallback.expect<CallbackEntry> { cb -> cb is BlockedStatus && cb.blocked }
+ mCallback.onCapabilitiesChanged(net, netCaps)
+ assertFails { mCallback.expect<CallbackEntry>(timeoutMs = SHORT_TIMEOUT_MS) { cb ->
+ cb is CapabilitiesChanged && cb.caps.hasTransport(WIFI)
+ } }
+ }
+
+ @Test
+ fun testCapabilitiesThat() {
+ val net = Network(101)
+ val netCaps = NetworkCapabilities().addCapability(NOT_METERED).addTransportType(WIFI)
+ // Check that expecting capabilitiesThat anything fails when no callback has been received.
+ assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { true } }
+
+ // Basic test for true and false
+ mCallback.onCapabilitiesChanged(net, netCaps)
+ mCallback.expectCapabilitiesThat(net) { true }
+ mCallback.onCapabilitiesChanged(net, netCaps)
+ assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { false } }
+
+ // Try a positive and a negative case
+ mCallback.onCapabilitiesChanged(net, netCaps)
+ mCallback.expectCapabilitiesThat(net) { caps ->
+ caps.hasCapability(NOT_METERED) &&
+ caps.hasTransport(WIFI) &&
+ !caps.hasTransport(CELLULAR)
+ }
+ mCallback.onCapabilitiesChanged(net, netCaps)
+ assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { caps ->
+ caps.hasTransport(CELLULAR)
+ } }
+
+ // Try a matching callback on the wrong network
+ mCallback.onCapabilitiesChanged(net, netCaps)
+ assertFails { mCallback.expectCapabilitiesThat(Network(100), SHORT_TIMEOUT_MS) { true } }
+ }
+
+ @Test
+ fun testLinkPropertiesThat() {
+ val net = Network(112)
+ val linkAddress = LinkAddress("fe80::ace:d00d/64")
+ val mtu = 1984
+ val linkProps = LinkProperties().apply {
+ this.mtu = mtu
+ interfaceName = TEST_INTERFACE_NAME
+ addLinkAddress(linkAddress)
+ }
+
+ // Check that expecting linkPropsThat anything fails when no callback has been received.
+ assertFails { mCallback.expectLinkPropertiesThat(net, SHORT_TIMEOUT_MS) { true } }
+
+ // Basic test for true and false
+ mCallback.onLinkPropertiesChanged(net, linkProps)
+ mCallback.expectLinkPropertiesThat(net) { true }
+ mCallback.onLinkPropertiesChanged(net, linkProps)
+ assertFails { mCallback.expectLinkPropertiesThat(net, SHORT_TIMEOUT_MS) { false } }
+
+ // Try a positive and negative case
+ mCallback.onLinkPropertiesChanged(net, linkProps)
+ mCallback.expectLinkPropertiesThat(net) { lp ->
+ lp.interfaceName == TEST_INTERFACE_NAME &&
+ lp.linkAddresses.contains(linkAddress) &&
+ lp.mtu == mtu
+ }
+ mCallback.onLinkPropertiesChanged(net, linkProps)
+ assertFails { mCallback.expectLinkPropertiesThat(net, SHORT_TIMEOUT_MS) { lp ->
+ lp.interfaceName != TEST_INTERFACE_NAME
+ } }
+
+ // Try a matching callback on the wrong network
+ mCallback.onLinkPropertiesChanged(net, linkProps)
+ assertFails { mCallback.expectLinkPropertiesThat(Network(114), SHORT_TIMEOUT_MS) { lp ->
+ lp.interfaceName == TEST_INTERFACE_NAME
+ } }
+ }
+
+ @Test
+ fun testExpect() {
+ val net = Network(103)
+ // Test expectCallback fails when nothing was sent.
+ assertFails { mCallback.expect<BlockedStatus>(net, SHORT_TIMEOUT_MS) }
+
+ // Test onAvailable is seen and can be expected
+ mCallback.onAvailable(net)
+ mCallback.expect<Available>(net, SHORT_TIMEOUT_MS)
+
+ // Test onAvailable won't return calls with a different network
+ mCallback.onAvailable(Network(106))
+ assertFails { mCallback.expect<Available>(net, SHORT_TIMEOUT_MS) }
+
+ // Test onAvailable won't return calls with a different callback
+ mCallback.onAvailable(net)
+ assertFails { mCallback.expect<BlockedStatus>(net, SHORT_TIMEOUT_MS) }
+ }
+
+ @Test
+ fun testAllExpectOverloads() {
+ // This test should never run, it only checks that all overloads exist and build
+ assumeTrue(false)
+ val hn = object : TestableNetworkCallback.HasNetwork { override val network = ANY_NETWORK }
+
+ // Method with all arguments (version that takes a Network)
+ mCallback.expect(AVAILABLE, ANY_NETWORK, 10, "error") { true }
+
+ // Java overloads omitting one argument. One line for omitting each argument, in positional
+ // order. Versions that take a Network.
+ mCallback.expect(AVAILABLE, 10, "error") { true }
+ mCallback.expect(AVAILABLE, ANY_NETWORK, "error") { true }
+ mCallback.expect(AVAILABLE, ANY_NETWORK, 10) { true }
+ mCallback.expect(AVAILABLE, ANY_NETWORK, 10, "error")
+
+ // Java overloads for omitting two arguments. One line for omitting each pair of arguments.
+ // Versions that take a Network.
+ mCallback.expect(AVAILABLE, "error") { true }
+ mCallback.expect(AVAILABLE, 10) { true }
+ mCallback.expect(AVAILABLE, 10, "error")
+ mCallback.expect(AVAILABLE, ANY_NETWORK) { true }
+ mCallback.expect(AVAILABLE, ANY_NETWORK, "error")
+ mCallback.expect(AVAILABLE, ANY_NETWORK, 10)
+
+ // Java overloads for omitting three arguments. One line for each remaining argument.
+ // Versions that take a Network.
+ mCallback.expect(AVAILABLE) { true }
+ mCallback.expect(AVAILABLE, "error")
+ mCallback.expect(AVAILABLE, 10)
+ mCallback.expect(AVAILABLE, ANY_NETWORK)
+
+ // Java overload for omitting all four arguments.
+ mCallback.expect(AVAILABLE)
+
+ // Same orders as above, but versions that take a HasNetwork. Except overloads that
+ // were already tested because they omitted the Network argument
+ mCallback.expect(AVAILABLE, hn, 10, "error") { true }
+ mCallback.expect(AVAILABLE, hn, "error") { true }
+ mCallback.expect(AVAILABLE, hn, 10) { true }
+ mCallback.expect(AVAILABLE, hn, 10, "error")
+
+ mCallback.expect(AVAILABLE, hn) { true }
+ mCallback.expect(AVAILABLE, hn, "error")
+ mCallback.expect(AVAILABLE, hn, 10)
+
+ mCallback.expect(AVAILABLE, hn)
+
+ // Same as above but for reified versions.
+ mCallback.expect<Available>(ANY_NETWORK, 10, "error") { true }
+ mCallback.expect<Available>(timeoutMs = 10, errorMsg = "error") { true }
+ mCallback.expect<Available>(network = ANY_NETWORK, errorMsg = "error") { true }
+ mCallback.expect<Available>(network = ANY_NETWORK, timeoutMs = 10) { true }
+ mCallback.expect<Available>(network = ANY_NETWORK, timeoutMs = 10, errorMsg = "error")
+
+ mCallback.expect<Available>(errorMsg = "error") { true }
+ mCallback.expect<Available>(timeoutMs = 10) { true }
+ mCallback.expect<Available>(timeoutMs = 10, errorMsg = "error")
+ mCallback.expect<Available>(network = ANY_NETWORK) { true }
+ mCallback.expect<Available>(network = ANY_NETWORK, errorMsg = "error")
+ mCallback.expect<Available>(network = ANY_NETWORK, timeoutMs = 10)
+
+ mCallback.expect<Available> { true }
+ mCallback.expect<Available>(errorMsg = "error")
+ mCallback.expect<Available>(timeoutMs = 10)
+ mCallback.expect<Available>(network = ANY_NETWORK)
+ mCallback.expect<Available>()
+
+ mCallback.expect<Available>(hn, 10, "error") { true }
+ mCallback.expect<Available>(network = hn, errorMsg = "error") { true }
+ mCallback.expect<Available>(network = hn, timeoutMs = 10) { true }
+ mCallback.expect<Available>(network = hn, timeoutMs = 10, errorMsg = "error")
+
+ mCallback.expect<Available>(network = hn) { true }
+ mCallback.expect<Available>(network = hn, errorMsg = "error")
+ mCallback.expect<Available>(network = hn, timeoutMs = 10)
+
+ mCallback.expect<Available>(network = hn)
+ }
+
+ @Test
+ fun testPoll() {
+ assertNull(mCallback.poll(SHORT_TIMEOUT_MS))
+ TNCInterpreter.interpretTestSpec(initial = mCallback, lineShift = 1,
+ threadTransform = { cb -> cb.createLinkedCopy() }, spec = """
+ sleep; onAvailable(133) | poll(2) = Available(133) time 1..4
+ | poll(1) = null
+ onCapabilitiesChanged(108) | poll(1) = CapabilitiesChanged(108) time 0..3
+ onBlockedStatus(199) | poll(1) = BlockedStatus(199) time 0..3
+ """)
+ }
+
+ @Test
+ fun testPollOrThrow() {
+ assertFails { mCallback.pollOrThrow(SHORT_TIMEOUT_MS) }
+ TNCInterpreter.interpretTestSpec(initial = mCallback, lineShift = 1,
+ threadTransform = { cb -> cb.createLinkedCopy() }, spec = """
+ sleep; onAvailable(133) | pollOrThrow(2) = Available(133) time 1..4
+ | pollOrThrow(1) fails
+ onCapabilitiesChanged(108) | pollOrThrow(1) = CapabilitiesChanged(108) time 0..3
+ onBlockedStatus(199) | pollOrThrow(1) = BlockedStatus(199) time 0..3
+ """)
+ }
+
+ @Test
+ fun testEventuallyExpect() {
+ // TODO: Current test does not verify the inline one. Also verify the behavior after
+ // aligning two eventuallyExpect()
+ val net1 = Network(100)
+ val net2 = Network(101)
+ mCallback.onAvailable(net1)
+ mCallback.onCapabilitiesChanged(net1, NetworkCapabilities())
+ mCallback.onLinkPropertiesChanged(net1, LinkProperties())
+ mCallback.eventuallyExpect(LINK_PROPERTIES_CHANGED) {
+ net1.equals(it.network)
+ }
+ // No further new callback. Expect no callback.
+ assertFails { mCallback.eventuallyExpect(LINK_PROPERTIES_CHANGED, SHORT_TIMEOUT_MS) }
+
+ // Verify no predicate set.
+ mCallback.onAvailable(net2)
+ mCallback.onLinkPropertiesChanged(net2, LinkProperties())
+ mCallback.onBlockedStatusChanged(net1, false)
+ mCallback.eventuallyExpect(BLOCKED_STATUS) { net1.equals(it.network) }
+ // Verify no callback received if the callback does not happen.
+ assertFails { mCallback.eventuallyExpect(LOSING, SHORT_TIMEOUT_MS) }
+ }
+
+ @Test
+ fun testEventuallyExpectOnMultiThreads() {
+ TNCInterpreter.interpretTestSpec(initial = mCallback, lineShift = 1,
+ threadTransform = { cb -> cb.createLinkedCopy() }, spec = """
+ onAvailable(100) | eventually(CapabilitiesChanged(100), 1) fails
+ sleep ; onCapabilitiesChanged(100) | eventually(CapabilitiesChanged(100), 3)
+ onAvailable(101) ; onBlockedStatus(101) | eventually(BlockedStatus(100), 2) fails
+ onSuspended(100) ; sleep ; onLost(100) | eventually(Lost(100), 3)
+ """)
+ }
+}
+
+private object TNCInterpreter : ConcurrentInterpreter<TestableNetworkCallback>(interpretTable)
+
+val EntryList = CallbackEntry::class.sealedSubclasses.map { it.simpleName }.joinToString("|")
+private fun callbackEntryFromString(name: String): KClass<out CallbackEntry> {
+ return CallbackEntry::class.sealedSubclasses.first { it.simpleName == name }
+}
+
+@SuppressLint("NewApi") // Uses hidden APIs, which the linter would identify as missing APIs.
+private val interpretTable = listOf<InterpretMatcher<TestableNetworkCallback>>(
+ // Interpret "Available(xx)" as "call to onAvailable with netId xx", and likewise for
+ // all callback types. This is implemented above by enumerating the subclasses of
+ // CallbackEntry and reading their simpleName.
+ Regex("""(.*)\s+=\s+($EntryList)\((\d+)\)""") to { i, cb, t ->
+ val record = i.interpret(t.strArg(1), cb)
+ assertTrue(callbackEntryFromString(t.strArg(2)).isInstance(record))
+ // Strictly speaking testing for is CallbackEntry is useless as it's been tested above
+ // but the compiler can't figure things out from the isInstance call. It does understand
+ // from the assertTrue(is CallbackEntry) that this is true, which allows to access
+ // the 'network' member below.
+ assertTrue(record is CallbackEntry)
+ assertEquals(record.network.netId, t.intArg(3))
+ },
+ // Interpret "onAvailable(xx)" as calling "onAvailable" with a netId of xx, and likewise for
+ // all callback types. NetworkCapabilities and LinkProperties just get an empty object
+ // as their argument. Losing gets the default linger timer. Blocked gets false.
+ Regex("""on($EntryList)\((\d+)\)""") to { i, cb, t ->
+ val net = Network(t.intArg(2))
+ when (t.strArg(1)) {
+ "Available" -> cb.onAvailable(net)
+ // PreCheck not used in tests. Add it here if it becomes useful.
+ "CapabilitiesChanged" -> cb.onCapabilitiesChanged(net, NetworkCapabilities())
+ "LinkPropertiesChanged" -> cb.onLinkPropertiesChanged(net, LinkProperties())
+ "Suspended" -> cb.onNetworkSuspended(net)
+ "Resumed" -> cb.onNetworkResumed(net)
+ "Losing" -> cb.onLosing(net, DEFAULT_LINGER_DELAY_MS)
+ "Lost" -> cb.onLost(net)
+ "Unavailable" -> cb.onUnavailable()
+ "BlockedStatus" -> cb.onBlockedStatusChanged(net, false)
+ else -> fail("Unknown callback type")
+ }
+ },
+ Regex("""poll\((\d+)\)""") to { i, cb, t -> cb.poll(t.timeArg(1)) },
+ Regex("""pollOrThrow\((\d+)\)""") to { i, cb, t -> cb.pollOrThrow(t.timeArg(1)) },
+ // Interpret "eventually(Available(xx), timeout)" as calling eventuallyExpect that expects
+ // CallbackEntry.AVAILABLE with netId of xx within timeout*INTERPRET_TIME_UNIT timeout, and
+ // likewise for all callback types.
+ Regex("""eventually\(($EntryList)\((\d+)\),\s+(\d+)\)""") to { i, cb, t ->
+ val net = Network(t.intArg(2))
+ val timeout = t.timeArg(3)
+ when (t.strArg(1)) {
+ "Available" -> cb.eventuallyExpect(AVAILABLE, timeout) { net == it.network }
+ "Suspended" -> cb.eventuallyExpect(SUSPENDED, timeout) { net == it.network }
+ "Resumed" -> cb.eventuallyExpect(RESUMED, timeout) { net == it.network }
+ "Losing" -> cb.eventuallyExpect(LOSING, timeout) { net == it.network }
+ "Lost" -> cb.eventuallyExpect(LOST, timeout) { net == it.network }
+ "Unavailable" -> cb.eventuallyExpect(UNAVAILABLE, timeout) { net == it.network }
+ "BlockedStatus" -> cb.eventuallyExpect(BLOCKED_STATUS, timeout) { net == it.network }
+ "CapabilitiesChanged" ->
+ cb.eventuallyExpect(NETWORK_CAPS_UPDATED, timeout) { net == it.network }
+ "LinkPropertiesChanged" ->
+ cb.eventuallyExpect(LINK_PROPERTIES_CHANGED, timeout) { net == it.network }
+ else -> fail("Unknown callback type")
+ }
+ }
+)
diff --git a/common/tests/unit/src/com/android/testutils/TestableNetworkCallbackTestJava.java b/common/tests/unit/src/com/android/testutils/TestableNetworkCallbackTestJava.java
new file mode 100644
index 00000000..4570d0ae
--- /dev/null
+++ b/common/tests/unit/src/com/android/testutils/TestableNetworkCallbackTestJava.java
@@ -0,0 +1,76 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils;
+
+import static com.android.testutils.RecorderCallback.CallbackEntry.AVAILABLE;
+import static com.android.testutils.TestableNetworkCallbackKt.anyNetwork;
+
+import static org.junit.Assume.assumeTrue;
+
+import org.junit.Test;
+
+public class TestableNetworkCallbackTestJava {
+ @Test
+ void testAllExpectOverloads() {
+ // This test should never run, it only checks that all overloads exist and build
+ assumeTrue(false);
+ final TestableNetworkCallback callback = new TestableNetworkCallback();
+ TestableNetworkCallback.HasNetwork hn = TestableNetworkCallbackKt::anyNetwork;
+
+ // Method with all arguments (version that takes a Network)
+ callback.expect(AVAILABLE, anyNetwork(), 10, "error", cb -> true);
+
+ // Overloads omitting one argument. One line for omitting each argument, in positional
+ // order. Versions that take a Network.
+ callback.expect(AVAILABLE, 10, "error", cb -> true);
+ callback.expect(AVAILABLE, anyNetwork(), "error", cb -> true);
+ callback.expect(AVAILABLE, anyNetwork(), 10, cb -> true);
+ callback.expect(AVAILABLE, anyNetwork(), 10, "error");
+
+ // Overloads for omitting two arguments. One line for omitting each pair of arguments.
+ // Versions that take a Network.
+ callback.expect(AVAILABLE, "error", cb -> true);
+ callback.expect(AVAILABLE, 10, cb -> true);
+ callback.expect(AVAILABLE, 10, "error");
+ callback.expect(AVAILABLE, anyNetwork(), cb -> true);
+ callback.expect(AVAILABLE, anyNetwork(), "error");
+ callback.expect(AVAILABLE, anyNetwork(), 10);
+
+ // Overloads for omitting three arguments. One line for each remaining argument.
+ // Versions that take a Network.
+ callback.expect(AVAILABLE, cb -> true);
+ callback.expect(AVAILABLE, "error");
+ callback.expect(AVAILABLE, 10);
+ callback.expect(AVAILABLE, anyNetwork());
+
+ // Java overload for omitting all four arguments.
+ callback.expect(AVAILABLE);
+
+ // Same orders as above, but versions that take a HasNetwork. Except overloads that
+ // were already tested because they omitted the Network argument
+ callback.expect(AVAILABLE, hn, 10, "error", cb -> true);
+ callback.expect(AVAILABLE, hn, "error", cb -> true);
+ callback.expect(AVAILABLE, hn, 10, cb -> true);
+ callback.expect(AVAILABLE, hn, 10, "error");
+
+ callback.expect(AVAILABLE, hn, cb -> true);
+ callback.expect(AVAILABLE, hn, "error");
+ callback.expect(AVAILABLE, hn, 10);
+
+ callback.expect(AVAILABLE, hn);
+ }
+}
diff --git a/common/testutils/devicetests/com/android/testutils/ConcurrentInterpreter.kt b/common/testutils/devicetests/com/android/testutils/ConcurrentInterpreter.kt
index cbdc0173..9e72f4b4 100644
--- a/common/testutils/devicetests/com/android/testutils/ConcurrentInterpreter.kt
+++ b/common/testutils/devicetests/com/android/testutils/ConcurrentInterpreter.kt
@@ -13,7 +13,7 @@ import kotlin.test.assertTrue
typealias InterpretMatcher<T> = Pair<Regex, (ConcurrentInterpreter<T>, T, MatchResult) -> Any?>
// The default unit of time for interpreted tests
-val INTERPRET_TIME_UNIT = 40L // ms
+const val INTERPRET_TIME_UNIT = 60L // ms
/**
* A small interpreter for testing parallel code.
@@ -40,10 +40,7 @@ val INTERPRET_TIME_UNIT = 40L // ms
* Some expressions already exist by default and can be used by all interpreters. Refer to
* getDefaultInstructions() below for a list and documentation.
*/
-open class ConcurrentInterpreter<T>(
- localInterpretTable: List<InterpretMatcher<T>>,
- val interpretTimeUnit: Long = INTERPRET_TIME_UNIT
-) {
+open class ConcurrentInterpreter<T>(localInterpretTable: List<InterpretMatcher<T>>) {
private val interpretTable: List<InterpretMatcher<T>> =
localInterpretTable + getDefaultInstructions()
// The last time the thread became blocked, with base System.currentTimeMillis(). This should
@@ -211,7 +208,7 @@ private fun <T> getDefaultInstructions() = listOf<InterpretMatcher<T>>(
},
// Interpret sleep. Optional argument for the count, in INTERPRET_TIME_UNIT units.
Regex("""sleep(\((\d+)\))?""") to { i, t, r ->
- SystemClock.sleep(if (r.strArg(2).isEmpty()) i.interpretTimeUnit else r.timeArg(2))
+ SystemClock.sleep(if (r.strArg(2).isEmpty()) INTERPRET_TIME_UNIT else r.timeArg(2))
},
Regex("""(.*)\s*fails""") to { i, t, r ->
assertFails { i.interpret(r.strArg(1), t) }
diff --git a/common/testutils/devicetests/com/android/testutils/ConnectUtil.kt b/common/testutils/devicetests/com/android/testutils/ConnectUtil.kt
index fc951d86..7b5ad010 100644
--- a/common/testutils/devicetests/com/android/testutils/ConnectUtil.kt
+++ b/common/testutils/devicetests/com/android/testutils/ConnectUtil.kt
@@ -63,6 +63,7 @@ class ConnectUtil(private val context: Context) {
try {
val connInfo = wifiManager.connectionInfo
+ Log.d(TAG, "connInfo=" + connInfo)
if (connInfo == null || connInfo.networkId == -1) {
clearWifiBlocklist()
val pfd = getInstrumentation().uiAutomation.executeShellCommand("svc wifi enable")
@@ -75,7 +76,7 @@ class ConnectUtil(private val context: Context) {
timeoutMs = WIFI_CONNECT_TIMEOUT_MS)
assertNotNull(cb, "Could not connect to a wifi access point within " +
- "$WIFI_CONNECT_INTERVAL_MS ms. Check that the test device has a wifi network " +
+ "$WIFI_CONNECT_TIMEOUT_MS ms. Check that the test device has a wifi network " +
"configured, and that the test access point is functioning properly.")
return cb.network
} finally {
diff --git a/common/testutils/devicetests/com/android/testutils/DeviceInfoUtils.java b/common/testutils/devicetests/com/android/testutils/DeviceInfoUtils.java
index ea89edaa..ce55fdc1 100644
--- a/common/testutils/devicetests/com/android/testutils/DeviceInfoUtils.java
+++ b/common/testutils/devicetests/com/android/testutils/DeviceInfoUtils.java
@@ -20,6 +20,7 @@ import android.os.VintfRuntimeInfo;
import android.text.TextUtils;
import android.util.Pair;
+import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -128,7 +129,7 @@ public class DeviceInfoUtils {
final Pair<Integer, Integer> v1 = getMajorMinorVersion(s1);
final Pair<Integer, Integer> v2 = getMajorMinorVersion(s2);
- if (v1.first == v2.first) {
+ if (Objects.equals(v1.first, v2.first)) {
return Integer.compare(v1.second, v2.second);
} else {
return Integer.compare(v1.first, v2.first);
diff --git a/common/testutils/devicetests/com/android/testutils/TestNetworkTracker.kt b/common/testutils/devicetests/com/android/testutils/TestNetworkTracker.kt
index 5ae2439d..b743b6c9 100644
--- a/common/testutils/devicetests/com/android/testutils/TestNetworkTracker.kt
+++ b/common/testutils/devicetests/com/android/testutils/TestNetworkTracker.kt
@@ -111,12 +111,12 @@ class TestNetworkTracker internal constructor(
val tnm: TestNetworkManager,
val lp: LinkProperties?,
setupTimeoutMs: Long
-) {
+) : TestableNetworkCallback.HasNetwork {
private val cm = context.getSystemService(ConnectivityManager::class.java)
private val binder = Binder()
private val networkCallback: NetworkCallback
- val network: Network
+ override val network: Network
val testIface: TestNetworkInterface
init {
diff --git a/common/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt b/common/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
index dffdbe8b..406a179c 100644
--- a/common/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
+++ b/common/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
@@ -36,15 +36,12 @@ import com.android.testutils.RecorderCallback.CallbackEntry.Unavailable
import kotlin.reflect.KClass
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
-import kotlin.test.assertTrue
import kotlin.test.fail
object NULL_NETWORK : Network(-1)
object ANY_NETWORK : Network(-2)
fun anyNetwork() = ANY_NETWORK
-private val Int.capabilityName get() = NetworkCapabilities.capabilityNameOf(this)
-
open class RecorderCallback private constructor(
private val backingRecord: ArrayTrackRecord<CallbackEntry>
) : NetworkCallback() {
@@ -168,16 +165,22 @@ open class RecorderCallback private constructor(
}
}
-private const val DEFAULT_TIMEOUT = 200L // ms
+private const val DEFAULT_TIMEOUT = 30_000L // ms
+private const val DEFAULT_NO_CALLBACK_TIMEOUT = 200L // ms
open class TestableNetworkCallback private constructor(
src: TestableNetworkCallback?,
- val defaultTimeoutMs: Long = DEFAULT_TIMEOUT
+ val defaultTimeoutMs: Long = DEFAULT_TIMEOUT,
+ val defaultNoCallbackTimeoutMs: Long = DEFAULT_NO_CALLBACK_TIMEOUT
) : RecorderCallback(src) {
@JvmOverloads
- constructor(timeoutMs: Long = DEFAULT_TIMEOUT): this(null, timeoutMs)
+ constructor(
+ timeoutMs: Long = DEFAULT_TIMEOUT,
+ noCallbackTimeoutMs: Long = DEFAULT_NO_CALLBACK_TIMEOUT
+ ): this(null, timeoutMs, noCallbackTimeoutMs)
- fun createLinkedCopy() = TestableNetworkCallback(this, defaultTimeoutMs)
+ fun createLinkedCopy() = TestableNetworkCallback(
+ this, defaultTimeoutMs, defaultNoCallbackTimeoutMs)
// The last available network, or null if any network was lost since the last call to
// onAvailable. TODO : fix this by fixing the tests that rely on this behavior
@@ -187,14 +190,189 @@ open class TestableNetworkCallback private constructor(
else -> null
}
- fun pollForNextCallback(timeoutMs: Long = defaultTimeoutMs): CallbackEntry {
- return history.poll(timeoutMs) ?: fail("Did not receive callback after ${timeoutMs}ms")
- }
+ /**
+ * Get the next callback or null if timeout.
+ *
+ * With no argument, this method waits out the default timeout. To wait forever, pass
+ * Long.MAX_VALUE.
+ */
+ @JvmOverloads
+ fun poll(timeoutMs: Long = defaultTimeoutMs): CallbackEntry? = history.poll(timeoutMs)
+
+ /**
+ * Get the next callback or throw if timeout.
+ *
+ * With no argument, this method waits out the default timeout. To wait forever, pass
+ * Long.MAX_VALUE.
+ */
+ @JvmOverloads
+ fun pollOrThrow(
+ timeoutMs: Long = defaultTimeoutMs,
+ errorMsg: String = "Did not receive callback after $timeoutMs"
+ ): CallbackEntry = poll(timeoutMs) ?: fail(errorMsg)
+
+ /*****
+ * expect family of methods.
+ * These methods fetch the next callback and assert it matches the conditions : type,
+ * passed predicate. If no callback is received within the timeout, these methods fail.
+ */
+ @JvmOverloads
+ fun <T : CallbackEntry> expect(
+ type: KClass<T>,
+ network: Network = ANY_NETWORK,
+ timeoutMs: Long = defaultTimeoutMs,
+ errorMsg: String? = null,
+ test: (T) -> Boolean = { true }
+ ) = expect<CallbackEntry>(network, timeoutMs, errorMsg) {
+ test(it as? T ?: fail("Expected callback ${type.simpleName}, got $it"))
+ } as T
+
+ @JvmOverloads
+ fun <T : CallbackEntry> expect(
+ type: KClass<T>,
+ network: HasNetwork,
+ timeoutMs: Long = defaultTimeoutMs,
+ errorMsg: String? = null,
+ test: (T) -> Boolean = { true }
+ ) = expect(type, network.network, timeoutMs, errorMsg, test)
+
+ // Java needs an explicit overload to let it omit arguments in the middle, so define these
+ // here. Note that @JvmOverloads give us the versions without the last arguments too, so
+ // there is no need to explicitly define versions without the test predicate.
+ // Without |network|
+ @JvmOverloads
+ fun <T : CallbackEntry> expect(
+ type: KClass<T>,
+ timeoutMs: Long,
+ errorMsg: String?,
+ test: (T) -> Boolean = { true }
+ ) = expect(type, ANY_NETWORK, timeoutMs, errorMsg, test)
+
+ // Without |timeout|, in Network and HasNetwork versions
+ @JvmOverloads
+ fun <T : CallbackEntry> expect(
+ type: KClass<T>,
+ network: Network,
+ errorMsg: String?,
+ test: (T) -> Boolean = { true }
+ ) = expect(type, network, defaultTimeoutMs, errorMsg, test)
+
+ @JvmOverloads
+ fun <T : CallbackEntry> expect(
+ type: KClass<T>,
+ network: HasNetwork,
+ errorMsg: String?,
+ test: (T) -> Boolean = { true }
+ ) = expect(type, network.network, defaultTimeoutMs, errorMsg, test)
+
+ // Without |errorMsg|, in Network and HasNetwork versions
+ @JvmOverloads
+ fun <T : CallbackEntry> expect(
+ type: KClass<T>,
+ network: Network,
+ timeoutMs: Long,
+ test: (T) -> Boolean
+ ) = expect(type, network, timeoutMs, null, test)
+
+ @JvmOverloads
+ fun <T : CallbackEntry> expect(
+ type: KClass<T>,
+ network: HasNetwork,
+ timeoutMs: Long,
+ test: (T) -> Boolean
+ ) = expect(type, network.network, timeoutMs, null, test)
+
+ // Without |network| or |timeout|
+ @JvmOverloads
+ fun <T : CallbackEntry> expect(
+ type: KClass<T>,
+ errorMsg: String?,
+ test: (T) -> Boolean = { true }
+ ) = expect(type, ANY_NETWORK, defaultTimeoutMs, errorMsg, test)
+
+ // Without |network| or |errorMsg|
+ @JvmOverloads
+ fun <T : CallbackEntry> expect(
+ type: KClass<T>,
+ timeoutMs: Long,
+ test: (T) -> Boolean = { true }
+ ) = expect(type, ANY_NETWORK, timeoutMs, null, test)
+
+ // Without |timeout| or |errorMsg|, in Network and HasNetwork versions
+ @JvmOverloads
+ fun <T : CallbackEntry> expect(
+ type: KClass<T>,
+ network: Network,
+ test: (T) -> Boolean
+ ) = expect(type, network, defaultTimeoutMs, null, test)
+
+ @JvmOverloads
+ fun <T : CallbackEntry> expect(
+ type: KClass<T>,
+ network: HasNetwork,
+ test: (T) -> Boolean
+ ) = expect(type, network.network, defaultTimeoutMs, null, test)
+
+ // Without |network| or |timeout| or |errorMsg|
+ @JvmOverloads
+ fun <T : CallbackEntry> expect(
+ type: KClass<T>,
+ test: (T) -> Boolean
+ ) = expect(type, ANY_NETWORK, defaultTimeoutMs, null, test)
+
+ // Kotlin reified versions. Don't call methods above, or the predicate would need to be noinline
+ inline fun <reified T : CallbackEntry> expect(
+ network: Network = ANY_NETWORK,
+ timeoutMs: Long = defaultTimeoutMs,
+ errorMsg: String? = null,
+ test: (T) -> Boolean = { true }
+ ) = pollOrThrow(timeoutMs).also {
+ if (it !is T) fail("Expected callback ${T::class.simpleName}, got $it")
+ if (ANY_NETWORK !== network && it.network != network) {
+ fail("Expected network $network for callback : $it")
+ }
+ if (!test(it)) {
+ fail("${errorMsg ?: "Callback doesn't match predicate"} : $it")
+ }
+ } as T
+
+ inline fun <reified T : CallbackEntry> expect(
+ network: HasNetwork,
+ timeoutMs: Long = defaultTimeoutMs,
+ errorMsg: String? = null,
+ test: (T) -> Boolean = { true }
+ ) = expect(network.network, timeoutMs, errorMsg, test)
+
+ // TODO : remove all expectCallback and expectCallbackThat methods after all callers have been
+ // migrated to expect().
+ inline fun <reified T : CallbackEntry> expectCallback(
+ network: Network = ANY_NETWORK,
+ timeoutMs: Long = defaultTimeoutMs
+ ): T = expect(network, timeoutMs)
+
+ @JvmOverloads
+ open fun <T : CallbackEntry> expectCallback(
+ type: KClass<T>,
+ n: Network?,
+ timeoutMs: Long = defaultTimeoutMs
+ ) = expect(type, n ?: ANY_NETWORK, timeoutMs)
+
+ @JvmOverloads
+ open fun <T : CallbackEntry> expectCallback(
+ type: KClass<T>,
+ n: HasNetwork?,
+ timeoutMs: Long = defaultTimeoutMs
+ ) = expect(type, n?.network ?: ANY_NETWORK, timeoutMs)
+
+ fun expectCallbackThat(
+ timeoutMs: Long = defaultTimeoutMs,
+ valid: (CallbackEntry) -> Boolean
+ ) = expect(timeoutMs = timeoutMs, test = valid)
// Make open for use in ConnectivityServiceTest which is the only one knowing its handlers.
// TODO : remove the necessity to overload this, remove the open qualifier, and give a
// default argument to assertNoCallback instead, possibly with @JvmOverloads if necessary.
- open fun assertNoCallback() = assertNoCallback(defaultTimeoutMs)
+ open fun assertNoCallback() = assertNoCallback(defaultNoCallbackTimeoutMs)
fun assertNoCallback(timeoutMs: Long) {
val cb = history.poll(timeoutMs)
@@ -202,7 +380,7 @@ open class TestableNetworkCallback private constructor(
}
fun assertNoCallbackThat(
- timeoutMs: Long = defaultTimeoutMs,
+ timeoutMs: Long = defaultNoCallbackTimeoutMs,
valid: (CallbackEntry) -> Boolean
) {
val cb = history.poll(timeoutMs) { valid(it) }.let {
@@ -210,19 +388,6 @@ open class TestableNetworkCallback private constructor(
}
}
- // Expects a callback of the specified type on the specified network within the timeout.
- // If no callback arrives, or a different callback arrives, fail. Returns the callback.
- inline fun <reified T : CallbackEntry> expectCallback(
- network: Network = ANY_NETWORK,
- timeoutMs: Long = defaultTimeoutMs
- ): T = pollForNextCallback(timeoutMs).let {
- if (it !is T || (ANY_NETWORK !== network && it.network != network)) {
- fail("Unexpected callback : $it, expected ${T::class} with Network[$network]")
- } else {
- it
- }
- }
-
// Expects a callback of the specified type matching the predicate within the timeout.
// Any callback that doesn't match the predicate will be skipped. Fails only if
// no matching callback is received within the timeout.
@@ -249,30 +414,19 @@ open class TestableNetworkCallback private constructor(
crossinline predicate: (T) -> Boolean = { true }
) = history.poll(timeoutMs, from) { it is T && predicate(it) } as T?
- fun expectCallbackThat(
- timeoutMs: Long = defaultTimeoutMs,
- valid: (CallbackEntry) -> Boolean
- ) = pollForNextCallback(timeoutMs).also { assertTrue(valid(it), "Unexpected callback : $it") }
-
- fun expectCapabilitiesThat(
+ inline fun expectCapabilitiesThat(
net: Network,
tmt: Long = defaultTimeoutMs,
valid: (NetworkCapabilities) -> Boolean
- ): CapabilitiesChanged {
- return expectCallback<CapabilitiesChanged>(net, tmt).also {
- assertTrue(valid(it.caps), "Capabilities don't match expectations ${it.caps}")
- }
- }
+ ): CapabilitiesChanged =
+ expect(net, tmt, "Capabilities don't match expectations") { valid(it.caps) }
- fun expectLinkPropertiesThat(
+ inline fun expectLinkPropertiesThat(
net: Network,
tmt: Long = defaultTimeoutMs,
valid: (LinkProperties) -> Boolean
- ): LinkPropertiesChanged {
- return expectCallback<LinkPropertiesChanged>(net, tmt).also {
- assertTrue(valid(it.lp), "LinkProperties don't match expectations ${it.lp}")
- }
- }
+ ): LinkPropertiesChanged =
+ expect(net, tmt, "LinkProperties don't match expectations") { valid(it.lp) }
// Expects onAvailable and the callbacks that follow it. These are:
// - onSuspended, iff the network was suspended when the callbacks fire.
@@ -313,16 +467,16 @@ open class TestableNetworkCallback private constructor(
validated: Boolean?,
tmt: Long
) {
- expectCallback<Available>(net, tmt)
+ expect<Available>(net, tmt)
if (suspended) {
- expectCallback<Suspended>(net, tmt)
+ expect<Suspended>(net, tmt)
}
expectCapabilitiesThat(net, tmt) {
validated == null || validated == it.hasCapability(
NET_CAPABILITY_VALIDATED
)
}
- expectCallback<LinkPropertiesChanged>(net, tmt)
+ expect<LinkPropertiesChanged>(net, tmt)
}
// Backward compatibility for existing Java code. Use named arguments instead and remove all
@@ -333,17 +487,15 @@ open class TestableNetworkCallback private constructor(
tmt: Long = defaultTimeoutMs
) = expectAvailableCallbacks(net, suspended = true, validated = validated, tmt = tmt)
- fun expectBlockedStatusCallback(blocked: Boolean, net: Network, tmt: Long = defaultTimeoutMs) {
- expectCallback<BlockedStatus>(net, tmt).also {
- assertEquals(blocked, it.blocked, "Unexpected blocked status ${it.blocked}")
- }
- }
+ fun expectBlockedStatusCallback(blocked: Boolean, net: Network, tmt: Long = defaultTimeoutMs) =
+ expect<BlockedStatus>(net, tmt, "Unexpected blocked status") {
+ it.blocked == blocked
+ }
- fun expectBlockedStatusCallback(blocked: Int, net: Network, tmt: Long = defaultTimeoutMs) {
- expectCallback<BlockedStatusInt>(net, tmt).also {
- assertEquals(blocked, it.blocked, "Unexpected blocked status ${it.blocked}")
- }
- }
+ fun expectBlockedStatusCallback(blocked: Int, net: Network, tmt: Long = defaultTimeoutMs) =
+ expect<BlockedStatusInt>(net, tmt, "Unexpected blocked status") {
+ it.blocked == blocked
+ }
// Expects the available callbacks (where the onCapabilitiesChanged must contain the
// VALIDATED capability), plus another onCapabilitiesChanged which is identical to the
@@ -353,7 +505,7 @@ open class TestableNetworkCallback private constructor(
val mark = history.mark
expectAvailableCallbacks(net, tmt = tmt)
val firstCaps = history.poll(tmt, mark) { it is CapabilitiesChanged }
- assertEquals(firstCaps, expectCallback<CapabilitiesChanged>(net, tmt))
+ assertEquals(firstCaps, expect<CapabilitiesChanged>(net, tmt))
}
// Expects the available callbacks where the onCapabilitiesChanged must not have validated,
@@ -381,26 +533,6 @@ open class TestableNetworkCallback private constructor(
val network: Network
}
- @JvmOverloads
- open fun <T : CallbackEntry> expectCallback(
- type: KClass<T>,
- n: Network?,
- timeoutMs: Long = defaultTimeoutMs
- ) = pollForNextCallback(timeoutMs).also {
- val network = n ?: NULL_NETWORK
- // TODO : remove this .java access if the tests ever use kotlin-reflect. At the time of
- // this writing this would be the only use of this library in the tests.
- assertTrue(type.java.isInstance(it) && (ANY_NETWORK === n || it.network == network),
- "Unexpected callback : $it, expected ${type.java} with Network[$network]")
- } as T
-
- @JvmOverloads
- open fun <T : CallbackEntry> expectCallback(
- type: KClass<T>,
- n: HasNetwork?,
- timeoutMs: Long = defaultTimeoutMs
- ) = expectCallback(type, n?.network, timeoutMs)
-
fun expectAvailableCallbacks(
n: HasNetwork,
suspended: Boolean,