diff options
author | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2022-11-04 00:35:56 +0000 |
---|---|---|
committer | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2022-11-04 00:35:56 +0000 |
commit | 1f309d8bad45209ee43fd3a6600d10461eabf71e (patch) | |
tree | 3feb7e66dff1c99fbccbc8055f526ee38bd68259 | |
parent | a54582e7acc3eb5ef2bab7dee0c3a81009533a72 (diff) | |
parent | e74812d8a96c8b9a0aec84ad914783de24751315 (diff) | |
download | net-1f309d8bad45209ee43fd3a6600d10461eabf71e.tar.gz |
Snap for 9254005 from e74812d8a96c8b9a0aec84ad914783de24751315 to mainline-ipsec-releaseaml_ips_331910010aml_ips_331312000aml_ips_331310000android13-mainline-ipsec-release
Change-Id: I335eafe4388651891bdc606939ffa9ab219049b4
27 files changed, 1985 insertions, 194 deletions
diff --git a/common/Android.bp b/common/Android.bp index 4dfc752b..f4997a16 100644 --- a/common/Android.bp +++ b/common/Android.bp @@ -96,7 +96,10 @@ java_library { visibility: [ "//packages/services/Iwlan:__subpackages__", ], - libs: ["framework-annotations-lib"], + libs: [ + "framework-annotations-lib", + "framework-connectivity.stubs.module_lib", + ], } filegroup { diff --git a/common/device/com/android/net/module/util/BpfDump.java b/common/device/com/android/net/module/util/BpfDump.java index 2534d616..7549e712 100644 --- a/common/device/com/android/net/module/util/BpfDump.java +++ b/common/device/com/android/net/module/util/BpfDump.java @@ -15,6 +15,8 @@ */ package com.android.net.module.util; +import static android.system.OsConstants.R_OK; + import android.system.ErrnoException; import android.system.Os; import android.util.Base64; @@ -110,6 +112,24 @@ public class BpfDump { } } + /** + * Dump the BpfMap status + */ + public static <K extends Struct, V extends Struct> void dumpMapStatus(IBpfMap<K, V> map, + PrintWriter pw, String mapName, String path) { + if (map != null) { + pw.println(mapName + ": OK"); + return; + } + try { + Os.access(path, R_OK); + pw.println(mapName + ": NULL(map is pinned to " + path + ")"); + } catch (ErrnoException e) { + pw.println(mapName + ": NULL(map is not pinned to " + path + ": " + + Os.strerror(e.errno) + ")"); + } + } + // TODO: add a helper to dump bpf map content with the map name, the header line // (ex: "BPF ingress map: iif nat64Prefix v6Addr -> v4Addr oif"), a lambda that // knows how to dump each line, and the PrintWriter. diff --git a/common/framework/com/android/net/module/util/BitUtils.java b/common/framework/com/android/net/module/util/BitUtils.java new file mode 100644 index 00000000..2b32e860 --- /dev/null +++ b/common/framework/com/android/net/module/util/BitUtils.java @@ -0,0 +1,110 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.net.module.util; + +import android.annotation.NonNull; + +/** + * @hide + */ +public class BitUtils { + /** + * Unpacks long value into an array of bits. + */ + public static int[] unpackBits(long val) { + int size = Long.bitCount(val); + int[] result = new int[size]; + int index = 0; + int bitPos = 0; + while (val != 0) { + if ((val & 1) == 1) result[index++] = bitPos; + val = val >>> 1; + bitPos++; + } + return result; + } + + /** + * Packs a list of ints in the same way as packBits() + * + * Each passed int is the rank of a bit that should be set in the returned long. + * Example : passing (1,3) will return in 0b00001010 and passing (5,6,0) will return 0b01100001 + * + * @param bits bits to pack + * @return a long with the specified bits set. + */ + public static long packBitList(int... bits) { + return packBits(bits); + } + + /** + * Packs array of bits into a long value. + * + * Each passed int is the rank of a bit that should be set in the returned long. + * Example : passing [1,3] will return in 0b00001010 and passing [5,6,0] will return 0b01100001 + * + * @param bits bits to pack + * @return a long with the specified bits set. + */ + public static long packBits(int[] bits) { + long packed = 0; + for (int b : bits) { + packed |= (1L << b); + } + return packed; + } + + /** + * An interface for a function that can retrieve a name associated with an int. + * + * This is useful for bitfields like network capabilities or network score policies. + */ + @FunctionalInterface + public interface NameOf { + /** Retrieve the name associated with the passed value */ + String nameOf(int value); + } + + /** + * Given a bitmask and a name fetcher, append names of all set bits to the builder + * + * This method takes all bit sets in the passed bitmask, will figure out the name associated + * with the weight of each bit with the passed name fetcher, and append each name to the + * passed StringBuilder, separated by the passed separator. + * + * For example, if the bitmask is 0110, and the name fetcher return "BIT_1" to "BIT_4" for + * numbers from 1 to 4, and the separator is "&", this method appends "BIT_2&BIT3" to the + * StringBuilder. + */ + public static void appendStringRepresentationOfBitMaskToStringBuilder(@NonNull StringBuilder sb, + long bitMask, @NonNull NameOf nameFetcher, @NonNull String separator) { + int bitPos = 0; + boolean firstElementAdded = false; + while (bitMask != 0) { + if ((bitMask & 1) != 0) { + if (firstElementAdded) { + sb.append(separator); + } else { + firstElementAdded = true; + } + sb.append(nameFetcher.nameOf(bitPos)); + } + bitMask >>>= 1; + ++bitPos; + } + } +} diff --git a/common/framework/com/android/net/module/util/ByteUtils.java b/common/framework/com/android/net/module/util/ByteUtils.java new file mode 100644 index 00000000..290ed465 --- /dev/null +++ b/common/framework/com/android/net/module/util/ByteUtils.java @@ -0,0 +1,67 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.net.module.util; + +import android.annotation.NonNull; + +/** + * Byte utility functions. + * @hide + */ +public class ByteUtils { + /** + * Returns the index of the first appearance of the value {@code target} in {@code array}. + * + * @param array an array of {@code byte} values, possibly empty + * @param target a primitive {@code byte} value + * @return the least index {@code i} for which {@code array[i] == target}, or {@code -1} if no + * such index exists. + */ + public static int indexOf(@NonNull byte[] array, byte target) { + return indexOf(array, target, 0, array.length); + } + + private static int indexOf(byte[] array, byte target, int start, int end) { + for (int i = start; i < end; i++) { + if (array[i] == target) { + return i; + } + } + return -1; + } + + /** + * Returns the values from each provided array combined into a single array. For example, {@code + * concat(new byte[] {a, b}, new byte[] {}, new byte[] {c}} returns the array {@code {a, b, c}}. + * + * @param arrays zero or more {@code byte} arrays + * @return a single array containing all the values from the source arrays, in order + */ + public static byte[] concat(@NonNull byte[]... arrays) { + int length = 0; + for (byte[] array : arrays) { + length += array.length; + } + byte[] result = new byte[length]; + int pos = 0; + for (byte[] array : arrays) { + System.arraycopy(array, 0, result, pos, array.length); + pos += array.length; + } + return result; + } +} diff --git a/common/framework/com/android/net/module/util/DnsPacket.java b/common/framework/com/android/net/module/util/DnsPacket.java index 080781c1..702d114e 100644 --- a/common/framework/com/android/net/module/util/DnsPacket.java +++ b/common/framework/com/android/net/module/util/DnsPacket.java @@ -16,15 +16,34 @@ package com.android.net.module.util; +import static android.net.DnsResolver.TYPE_A; +import static android.net.DnsResolver.TYPE_AAAA; + +import static com.android.internal.annotations.VisibleForTesting.Visibility.PACKAGE; +import static com.android.internal.annotations.VisibleForTesting.Visibility.PRIVATE; +import static com.android.net.module.util.DnsPacketUtils.DnsRecordParser.domainNameToLabels; + +import android.annotation.IntDef; import android.annotation.NonNull; import android.annotation.Nullable; +import android.text.TextUtils; +import com.android.internal.annotations.VisibleForTesting; import com.android.net.module.util.DnsPacketUtils.DnsRecordParser; +import java.io.ByteArrayOutputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.net.InetAddress; import java.nio.BufferUnderflowException; import java.nio.ByteBuffer; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; import java.util.List; +import java.util.Objects; /** * Defines basic data for DNS protocol based on RFC 1035. @@ -34,6 +53,12 @@ import java.util.List; */ public abstract class DnsPacket { /** + * Type of the canonical name for an alias. Refer to RFC 1035 section 3.2.2. + */ + // TODO: Define the constant as a public constant in DnsResolver since it can never change. + private static final int TYPE_CNAME = 5; + + /** * Thrown when parsing packet failed. */ public static class ParseException extends RuntimeException { @@ -50,13 +75,29 @@ public abstract class DnsPacket { } /** - * DNS header for DNS protocol based on RFC 1035. + * DNS header for DNS protocol based on RFC 1035 section 4.1.1. + * + * 1 1 1 1 1 1 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | ID | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * |QR| Opcode |AA|TC|RD|RA| Z | RCODE | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | QDCOUNT | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | ANCOUNT | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | NSCOUNT | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | ARCOUNT | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ */ - public class DnsHeader { + public static class DnsHeader { private static final String TAG = "DnsHeader"; - public final int id; - public final int flags; - public final int rcode; + private static final int SIZE_IN_BYTES = 12; + private final int mId; + private final int mFlags; private final int[] mRecordCount; /* If this bit in the 'flags' field is set to 0, the DNS message corresponding to this @@ -73,10 +114,11 @@ public abstract class DnsPacket { * advanced to the end of the DNS header record. * This is meant to chain with other methods reading a DNS response in sequence. */ - DnsHeader(@NonNull ByteBuffer buf) throws BufferUnderflowException { - id = Short.toUnsignedInt(buf.getShort()); - flags = Short.toUnsignedInt(buf.getShort()); - rcode = flags & 0xF; + @VisibleForTesting + public DnsHeader(@NonNull ByteBuffer buf) throws BufferUnderflowException { + Objects.requireNonNull(buf); + mId = Short.toUnsignedInt(buf.getShort()); + mFlags = Short.toUnsignedInt(buf.getShort()); mRecordCount = new int[NUM_SECTIONS]; for (int i = 0; i < NUM_SECTIONS; ++i) { mRecordCount[i] = Short.toUnsignedInt(buf.getShort()); @@ -88,7 +130,23 @@ public abstract class DnsPacket { * RFC 1035 Section 4.1.1. */ public boolean isResponse() { - return (flags & (1 << FLAGS_SECTION_QR_BIT)) != 0; + return (mFlags & (1 << FLAGS_SECTION_QR_BIT)) != 0; + } + + /** + * Create a new DnsHeader from specified parameters. + * + * This constructor only builds the question and answer sections. Authority + * and additional sections are not supported. Useful when synthesizing dns + * responses from query or reply packets. + */ + @VisibleForTesting + public DnsHeader(int id, int flags, int qdcount, int ancount) { + this.mId = id; + this.mFlags = flags; + mRecordCount = new int[NUM_SECTIONS]; + mRecordCount[QDSECTION] = qdcount; + mRecordCount[ANSECTION] = ancount; } /** @@ -97,15 +155,103 @@ public abstract class DnsPacket { public int getRecordCount(int type) { return mRecordCount[type]; } + + /** + * Get flags of this instance. + */ + public int getFlags() { + return mFlags; + } + + /** + * Get id of this instance. + */ + public int getId() { + return mId; + } + + @Override + public String toString() { + return "DnsHeader{" + "id=" + mId + ", flags=" + mFlags + + ", recordCounts=" + Arrays.toString(mRecordCount) + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o.getClass() != getClass()) return false; + final DnsHeader other = (DnsHeader) o; + return mId == other.mId + && mFlags == other.mFlags + && Arrays.equals(mRecordCount, other.mRecordCount); + } + + @Override + public int hashCode() { + return 31 * mId + 37 * mFlags + Arrays.hashCode(mRecordCount); + } + + /** + * Get DnsHeader as byte array. + */ + @NonNull + public byte[] getBytes() { + // TODO: if this is called often, optimize the ByteBuffer out and write to the + // array directly. + final ByteBuffer buf = ByteBuffer.allocate(SIZE_IN_BYTES); + buf.putShort((short) mId); + buf.putShort((short) mFlags); + for (int i = 0; i < NUM_SECTIONS; ++i) { + buf.putShort((short) mRecordCount[i]); + } + return buf.array(); + } } /** * Superclass for DNS questions and DNS resource records. * - * DNS questions (No TTL/RDATA) - * DNS resource records (With TTL/RDATA) + * DNS questions (No TTL/RDLENGTH/RDATA) based on RFC 1035 section 4.1.2. + * 1 1 1 1 1 1 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | | + * / QNAME / + * / / + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | QTYPE | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | QCLASS | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * + * DNS resource records (With TTL/RDLENGTH/RDATA) based on RFC 1035 section 4.1.3. + * 1 1 1 1 1 1 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | | + * / / + * / NAME / + * | | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | TYPE | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | CLASS | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | TTL | + * | | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + * | RDLENGTH | + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--| + * / RDATA / + * / / + * +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ */ - public class DnsRecord { + // TODO: Make DnsResourceRecord and DnsQuestion subclasses of DnsRecord, and construct + // corresponding object from factory methods. + public static class DnsRecord { + // Refer to RFC 1035 section 2.3.4 for MAXNAMESIZE. + // NAME_NORMAL and NAME_COMPRESSION are used for checking name compression, + // refer to rfc 1035 section 4.1.4. private static final int MAXNAMESIZE = 255; public static final int NAME_NORMAL = 0; public static final int NAME_COMPRESSION = 0xC0; @@ -117,22 +263,31 @@ public abstract class DnsPacket { public final int nsClass; public final long ttl; private final byte[] mRdata; + /** + * Type of this DNS record. + */ + @RecordType + public final int rType; /** * Create a new DnsRecord from a positioned ByteBuffer. * * Reads the passed ByteBuffer from its current position and decodes a DNS record. * When this constructor returns, the reading position of the ByteBuffer has been - * advanced to the end of the DNS header record. + * advanced to the end of the DNS resource record. * This is meant to chain with other methods reading a DNS response in sequence. * + * @param rType Type of the record. * @param buf ByteBuffer input of record, must be in network byte order * (which is the default). */ - DnsRecord(int recordType, @NonNull ByteBuffer buf) + @VisibleForTesting(visibility = PACKAGE) + public DnsRecord(@RecordType int rType, @NonNull ByteBuffer buf) throws BufferUnderflowException, ParseException { + Objects.requireNonNull(buf); + this.rType = rType; dName = DnsRecordParser.parseName(buf, 0 /* Parse depth */, - /* isNameCompressionSupported= */ true); + true /* isNameCompressionSupported */); if (dName.length() > MAXNAMESIZE) { throw new ParseException( "Parse name fail, name size is too long: " + dName.length()); @@ -140,7 +295,7 @@ public abstract class DnsPacket { nsType = Short.toUnsignedInt(buf.getShort()); nsClass = Short.toUnsignedInt(buf.getShort()); - if (recordType != QDSECTION) { + if (rType != QDSECTION) { ttl = Integer.toUnsignedLong(buf.getInt()); final int length = Short.toUnsignedInt(buf.getShort()); mRdata = new byte[length]; @@ -152,6 +307,95 @@ public abstract class DnsPacket { } /** + * Make an A or AAAA record based on the specified parameters. + * + * @param rType Type of the record, can be {@link #ANSECTION}, {@link #ARSECTION} + * or {@link #NSSECTION}. + * @param dName Domain name of the record. + * @param nsClass Class of the record. See RFC 1035 section 3.2.4. + * @param ttl time interval (in seconds) that the resource record may be + * cached before it should be discarded. Zero values are + * interpreted to mean that the RR can only be used for the + * transaction in progress, and should not be cached. + * @param address Instance of {@link InetAddress} + * @return A record if the {@code address} is an IPv4 address, or AAAA record if the + * {@code address} is an IPv6 address. + */ + public static DnsRecord makeAOrAAAARecord(int rType, @NonNull String dName, + int nsClass, long ttl, @NonNull InetAddress address) throws IOException { + final int nsType = (address.getAddress().length == 4) ? TYPE_A : TYPE_AAAA; + return new DnsRecord(rType, dName, nsType, nsClass, ttl, address, null /* rDataStr */); + } + + /** + * Make an CNAME record based on the specified parameters. + * + * @param rType Type of the record, can be {@link #ANSECTION}, {@link #ARSECTION} + * or {@link #NSSECTION}. + * @param dName Domain name of the record. + * @param nsClass Class of the record. See RFC 1035 section 3.2.4. + * @param ttl time interval (in seconds) that the resource record may be + * cached before it should be discarded. Zero values are + * interpreted to mean that the RR can only be used for the + * transaction in progress, and should not be cached. + * @param domainName Canonical name of the {@code dName}. + * @return A record if the {@code address} is an IPv4 address, or AAAA record if the + * {@code address} is an IPv6 address. + */ + public static DnsRecord makeCNameRecord(int rType, @NonNull String dName, int nsClass, + long ttl, @NonNull String domainName) throws IOException { + return new DnsRecord(rType, dName, TYPE_CNAME, nsClass, ttl, null /* address */, + domainName); + } + + /** + * Make a DNS question based on the specified parameters. + */ + public static DnsRecord makeQuestion(@NonNull String dName, int nsType, int nsClass) { + return new DnsRecord(dName, nsType, nsClass); + } + + private static String requireHostName(@NonNull String name) { + if (!DnsRecordParser.isHostName(name)) { + throw new IllegalArgumentException("Expected domain name but got " + name); + } + return name; + } + + /** + * Create a new query DnsRecord from specified parameters, useful when synthesizing + * dns response. + */ + private DnsRecord(@NonNull String dName, int nsType, int nsClass) { + this.rType = QDSECTION; + this.dName = requireHostName(dName); + this.nsType = nsType; + this.nsClass = nsClass; + mRdata = null; + this.ttl = 0; + } + + /** + * Create a new CNAME/A/AAAA DnsRecord from specified parameters. + * + * @param address The address only used when synthesizing A or AAAA record. + * @param rDataStr The alias of the domain, only used when synthesizing CNAME record. + */ + private DnsRecord(@RecordType int rType, @NonNull String dName, int nsType, int nsClass, + long ttl, @Nullable InetAddress address, @Nullable String rDataStr) + throws IOException { + this.rType = rType; + this.dName = requireHostName(dName); + this.nsType = nsType; + this.nsClass = nsClass; + if (rType < 0 || rType >= NUM_SECTIONS || rType == QDSECTION) { + throw new IllegalArgumentException("Unexpected record type: " + rType); + } + mRdata = nsType == TYPE_CNAME ? domainNameToLabels(rDataStr) : address.getAddress(); + this.ttl = ttl; + } + + /** * Get a copy of rdata. */ @Nullable @@ -159,13 +403,84 @@ public abstract class DnsPacket { return (mRdata == null) ? null : mRdata.clone(); } + /** + * Get DnsRecord as byte array. + */ + @NonNull + public byte[] getBytes() throws IOException { + final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + final DataOutputStream dos = new DataOutputStream(baos); + dos.write(domainNameToLabels(dName)); + dos.writeShort(nsType); + dos.writeShort(nsClass); + if (rType != QDSECTION) { + dos.writeInt((int) ttl); + if (mRdata == null) { + dos.writeShort(0); + } else { + dos.writeShort(mRdata.length); + dos.write(mRdata); + } + } + return baos.toByteArray(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o.getClass() != getClass()) return false; + final DnsRecord other = (DnsRecord) o; + return rType == other.rType + && nsType == other.nsType + && nsClass == other.nsClass + && ttl == other.ttl + && TextUtils.equals(dName, other.dName) + && Arrays.equals(mRdata, other.mRdata); + } + + @Override + public int hashCode() { + return 31 * Objects.hash(dName) + + 37 * ((int) (ttl & 0xFFFFFFFF)) + + 41 * ((int) (ttl >> 32)) + + 43 * nsType + + 47 * nsClass + + 53 * rType + + Arrays.hashCode(mRdata); + } + + @Override + public String toString() { + return "DnsRecord{" + + "rType=" + rType + + ", dName='" + dName + '\'' + + ", nsType=" + nsType + + ", nsClass=" + nsClass + + ", ttl=" + ttl + + ", mRdata=" + Arrays.toString(mRdata) + + '}'; + } } + /** + * Header section types, refer to RFC 1035 section 4.1.1. + */ public static final int QDSECTION = 0; public static final int ANSECTION = 1; public static final int NSSECTION = 2; public static final int ARSECTION = 3; - private static final int NUM_SECTIONS = ARSECTION + 1; + @VisibleForTesting(visibility = PRIVATE) + static final int NUM_SECTIONS = ARSECTION + 1; + + @Retention(RetentionPolicy.SOURCE) + @IntDef(value = { + QDSECTION, + ANSECTION, + NSSECTION, + ARSECTION, + }) + public @interface RecordType {} + private static final String TAG = DnsPacket.class.getSimpleName(); @@ -189,9 +504,7 @@ public abstract class DnsPacket { for (int i = 0; i < NUM_SECTIONS; ++i) { final int count = mHeader.getRecordCount(i); - if (count > 0) { - mRecords[i] = new ArrayList(count); - } + mRecords[i] = new ArrayList(count); for (int j = 0; j < count; ++j) { try { mRecords[i].add(new DnsRecord(i, buffer)); @@ -201,4 +514,61 @@ public abstract class DnsPacket { } } } + + /** + * Create a new {@link #DnsPacket} from specified parameters. + * + * Note that authority records section and additional records section is not supported. + */ + protected DnsPacket(@NonNull DnsHeader header, @NonNull List<DnsRecord> qd, + @NonNull List<DnsRecord> an) { + mHeader = Objects.requireNonNull(header); + mRecords = new List[NUM_SECTIONS]; + mRecords[QDSECTION] = Collections.unmodifiableList(new ArrayList<>(qd)); + mRecords[ANSECTION] = Collections.unmodifiableList(new ArrayList<>(an)); + mRecords[NSSECTION] = new ArrayList<>(); + mRecords[ARSECTION] = new ArrayList<>(); + for (int i = 0; i < NUM_SECTIONS; i++) { + if (mHeader.mRecordCount[i] != mRecords[i].size()) { + throw new IllegalArgumentException("Record count mismatch: expected " + + mHeader.mRecordCount[i] + " but was " + mRecords[i]); + } + } + } + + /** + * Get DnsPacket as byte array. + */ + public @NonNull byte[] getBytes() throws IOException { + final ByteArrayOutputStream buf = new ByteArrayOutputStream(); + buf.write(mHeader.getBytes()); + + for (int i = 0; i < NUM_SECTIONS; ++i) { + for (final DnsRecord record : mRecords[i]) { + buf.write(record.getBytes()); + } + } + return buf.toByteArray(); + } + + @Override + public String toString() { + return "DnsPacket{" + "header=" + mHeader + ", records='" + Arrays.toString(mRecords) + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o.getClass() != getClass()) return false; + final DnsPacket other = (DnsPacket) o; + return Objects.equals(mHeader, other.mHeader) + && Arrays.deepEquals(mRecords, other.mRecords); + } + + @Override + public int hashCode() { + int result = Objects.hash(mHeader); + result = 31 * result + Arrays.hashCode(mRecords); + return result; + } } diff --git a/common/framework/com/android/net/module/util/DnsPacketUtils.java b/common/framework/com/android/net/module/util/DnsPacketUtils.java index 3448cadf..c47bfa07 100644 --- a/common/framework/com/android/net/module/util/DnsPacketUtils.java +++ b/common/framework/com/android/net/module/util/DnsPacketUtils.java @@ -20,10 +20,19 @@ import static com.android.net.module.util.DnsPacket.DnsRecord.NAME_COMPRESSION; import static com.android.net.module.util.DnsPacket.DnsRecord.NAME_NORMAL; import android.annotation.NonNull; +import android.annotation.Nullable; +import android.net.InetAddresses; +import android.net.ParseException; import android.text.TextUtils; +import android.util.Patterns; +import com.android.internal.annotations.VisibleForTesting; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; import java.nio.BufferUnderflowException; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.text.DecimalFormat; import java.text.FieldPosition; @@ -38,6 +47,7 @@ public final class DnsPacketUtils { */ public static class DnsRecordParser { private static final int MAXLABELSIZE = 63; + private static final int MAXNAMESIZE = 255; private static final int MAXLABELCOUNT = 128; private static final DecimalFormat sByteFormat = new DecimalFormat(); @@ -48,7 +58,8 @@ public final class DnsPacketUtils { * * <p>Follows the same conversion rules of the native code (ns_name.c in libc). */ - private static String labelToString(@NonNull byte[] label) { + @VisibleForTesting + static String labelToString(@NonNull byte[] label) { final StringBuffer sb = new StringBuffer(); for (int i = 0; i < label.length; ++i) { @@ -72,6 +83,52 @@ public final class DnsPacketUtils { } /** + * Converts domain name to labels according to RFC 1035. + * + * @param name Domain name as String that needs to be converted to labels. + * @return An encoded byte array that is constructed out of labels, + * and ends with zero-length label. + * @throws ParseException if failed to parse the given domain name or + * IOException if failed to output labels. + */ + public static @NonNull byte[] domainNameToLabels(@NonNull String name) throws + IOException, ParseException { + if (name.length() > MAXNAMESIZE) { + throw new ParseException("Domain name exceeds max length: " + name.length()); + } + if (!isHostName(name)) { + throw new ParseException("Failed to parse domain name: " + name); + } + final ByteArrayOutputStream buf = new ByteArrayOutputStream(); + final String[] labels = name.split("\\."); + for (final String label : labels) { + if (label.length() > MAXLABELSIZE) { + throw new ParseException("label is too long: " + label); + } + buf.write(label.length()); + // Encode as UTF-8 as suggested in RFC 6055 section 3. + buf.write(label.getBytes(StandardCharsets.UTF_8)); + } + buf.write(0x00); // end with zero-length label + return buf.toByteArray(); + } + + /** + * Check whether the input is a valid hostname based on rfc 1035 section 3.3. + * + * @param hostName the target host name. + * @return true if the input is a valid hostname. + */ + public static boolean isHostName(@Nullable String hostName) { + // TODO: Use {@code Patterns.HOST_NAME} if available. + // Patterns.DOMAIN_NAME accepts host names or IP addresses, so reject + // IP addresses. + return hostName != null + && Patterns.DOMAIN_NAME.matcher(hostName).matches() + && !InetAddresses.isNumericAddress(hostName); + } + + /** * Parses the domain / target name of a DNS record. * * As described in RFC 1035 Section 4.1.3, the NAME field of a DNS Resource Record always diff --git a/common/framework/com/android/net/module/util/IpUtils.java b/common/framework/com/android/net/module/util/IpUtils.java index dbc22854..569733ed 100644 --- a/common/framework/com/android/net/module/util/IpUtils.java +++ b/common/framework/com/android/net/module/util/IpUtils.java @@ -39,7 +39,7 @@ public class IpUtils { /** * Performs an IP checksum (used in IP header and across UDP - * payload) on the specified portion of a ByteBuffer. The seed + * payload) or ICMP checksum on the specified portion of a ByteBuffer. The seed * allows the checksum to commence with a specified value. */ private static int checksum(ByteBuffer buf, int seed, int start, int end) { @@ -145,6 +145,14 @@ public class IpUtils { } /** + * Calculate the ICMP checksum for an ICMPv4 packet. + */ + public static short icmpChecksum(ByteBuffer buf, int transportOffset, int transportLen) { + // ICMP checksum doesn't include pseudo-header. See RFC 792. + return (short) checksum(buf, 0, transportOffset, transportOffset + transportLen); + } + + /** * Calculate the ICMPv6 checksum for an ICMPv6 packet. */ public static short icmpv6Checksum(ByteBuffer buf, int ipOffset, int transportOffset, diff --git a/common/framework/com/android/net/module/util/NetworkCapabilitiesUtils.java b/common/framework/com/android/net/module/util/NetworkCapabilitiesUtils.java index 26c24f8d..54ce01ec 100644 --- a/common/framework/com/android/net/module/util/NetworkCapabilitiesUtils.java +++ b/common/framework/com/android/net/module/util/NetworkCapabilitiesUtils.java @@ -44,6 +44,9 @@ import static android.net.NetworkCapabilities.TRANSPORT_VPN; import static android.net.NetworkCapabilities.TRANSPORT_WIFI; import static android.net.NetworkCapabilities.TRANSPORT_WIFI_AWARE; +import static com.android.net.module.util.BitUtils.packBitList; +import static com.android.net.module.util.BitUtils.unpackBits; + import android.annotation.NonNull; import android.net.NetworkCapabilities; @@ -181,49 +184,4 @@ public final class NetworkCapabilitiesUtils { return false; } - /** - * Unpacks long value into an array of bits. - */ - public static int[] unpackBits(long val) { - int size = Long.bitCount(val); - int[] result = new int[size]; - int index = 0; - int bitPos = 0; - while (val != 0) { - if ((val & 1) == 1) result[index++] = bitPos; - val = val >>> 1; - bitPos++; - } - return result; - } - - /** - * Packs a list of ints in the same way as packBits() - * - * Each passed int is the rank of a bit that should be set in the returned long. - * Example : passing (1,3) will return in 0b00001010 and passing (5,6,0) will return 0b01100001 - * - * @param bits bits to pack - * @return a long with the specified bits set. - */ - public static long packBitList(int... bits) { - return packBits(bits); - } - - /** - * Packs array of bits into a long value. - * - * Each passed int is the rank of a bit that should be set in the returned long. - * Example : passing [1,3] will return in 0b00001010 and passing [5,6,0] will return 0b01100001 - * - * @param bits bits to pack - * @return a long with the specified bits set. - */ - public static long packBits(int[] bits) { - long packed = 0; - for (int b : bits) { - packed |= (1L << b); - } - return packed; - } } diff --git a/common/framework/com/android/net/module/util/NetworkStackConstants.java b/common/framework/com/android/net/module/util/NetworkStackConstants.java index 1ad9c359..ebef1557 100644 --- a/common/framework/com/android/net/module/util/NetworkStackConstants.java +++ b/common/framework/com/android/net/module/util/NetworkStackConstants.java @@ -127,6 +127,14 @@ public final class NetworkStackConstants { (Inet6Address) InetAddresses.parseNumericAddress("ff02::3"); /** + * ICMP constants. + * + * See also: + * - https://tools.ietf.org/html/rfc792 + */ + public static final int ICMP_CHECKSUM_OFFSET = 2; + + /** * ICMPv6 constants. * * See also: diff --git a/common/native/bpf_headers/include/bpf/BpfUtils.h b/common/native/bpf_headers/include/bpf/BpfUtils.h index 4429164e..157f210c 100644 --- a/common/native/bpf_headers/include/bpf/BpfUtils.h +++ b/common/native/bpf_headers/include/bpf/BpfUtils.h @@ -115,23 +115,16 @@ static inline bool isAtLeastKernelVersion(unsigned major, unsigned minor, unsign return kernelVersion() >= KVER(major, minor, sub); } -#define SKIP_IF_BPF_SUPPORTED \ - do { \ - if (android::bpf::isAtLeastKernelVersion(4, 9, 0)) \ - GTEST_SKIP() << "Skip: bpf is supported."; \ - } while (0) - #define SKIP_IF_BPF_NOT_SUPPORTED \ do { \ if (!android::bpf::isAtLeastKernelVersion(4, 9, 0)) \ GTEST_SKIP() << "Skip: bpf is not supported."; \ } while (0) -#define SKIP_IF_EXTENDED_BPF_NOT_SUPPORTED \ - do { \ - if (!android::bpf::isAtLeastKernelVersion(4, 14, 0)) \ - GTEST_SKIP() << "Skip: extended bpf feature not supported."; \ - } while (0) +// Only used by tm-mainline-prod's system/netd/tests/bpf_base_test.cpp +// but platform and platform tests aren't expected to build/work in tm-mainline-prod +// so we can just trivialize this +#define SKIP_IF_EXTENDED_BPF_NOT_SUPPORTED #define SKIP_IF_XDP_NOT_SUPPORTED \ do { \ diff --git a/common/native/bpf_headers/include/bpf/bpf_helpers.h b/common/native/bpf_headers/include/bpf/bpf_helpers.h index 35d9f31d..56bc7d01 100644 --- a/common/native/bpf_headers/include/bpf/bpf_helpers.h +++ b/common/native/bpf_headers/include/bpf/bpf_helpers.h @@ -138,6 +138,10 @@ static int (*bpf_map_update_elem_unsafe)(const struct bpf_map_def* map, const vo static int (*bpf_map_delete_elem_unsafe)(const struct bpf_map_def* map, const void* key) = (void*)BPF_FUNC_map_delete_elem; +static int (*bpf_for_each_map_elem)(const struct bpf_map_def* map, void *callback_fn, + void *callback_ctx, unsigned long long flags) = (void*) + BPF_FUNC_for_each_map_elem; + #define BPF_ANNOTATE_KV_PAIR(name, type_key, type_val) \ struct ____btf_map_##name { \ type_key key; \ @@ -147,6 +151,23 @@ static int (*bpf_map_delete_elem_unsafe)(const struct bpf_map_def* map, __attribute__ ((section(".maps." #name), used)) \ ____btf_map_##name = { } +/* There exist buggy kernels with pre-T OS, that due to + * kernel patch "[ALPS05162612] bpf: fix ubsan error" + * do not support userspace writes into non-zero index of bpf map arrays. + * + * We use this assert to prevent us from being able to define such a map. + */ + +#ifdef THIS_BPF_PROGRAM_IS_FOR_TEST_PURPOSES_ONLY +#define BPF_MAP_ASSERT_OK(type, entries, mode) +#elif BPFLOADER_MIN_VER >= BPFLOADER_T_BETA3_VERSION +#define BPF_MAP_ASSERT_OK(type, entries, mode) +#else +#define BPF_MAP_ASSERT_OK(type, entries, mode) \ + _Static_assert(((type) != BPF_MAP_TYPE_ARRAY) || ((entries) <= 1) || !((mode) & 0222), \ + "Writable arrays with more than 1 element not supported on pre-T devices.") +#endif + /* type safe macro to declare a map and related accessor functions */ #define DEFINE_BPF_MAP_EXT(the_map, TYPE, KeyType, ValueType, num_entries, usr, grp, md, \ selinux, pindir, share) \ @@ -167,6 +188,7 @@ static int (*bpf_map_delete_elem_unsafe)(const struct bpf_map_def* map, .pin_subdir = pindir, \ .shared = share, \ }; \ + BPF_MAP_ASSERT_OK(BPF_MAP_TYPE_##TYPE, (num_entries), (md)); \ BPF_ANNOTATE_KV_PAIR(the_map, KeyType, ValueType); \ \ static inline __always_inline __unused ValueType* bpf_##the_map##_lookup_elem( \ @@ -205,6 +227,10 @@ static int (*bpf_map_delete_elem_unsafe)(const struct bpf_map_def* map, DEFINE_BPF_MAP_UGM(the_map, TYPE, KeyType, ValueType, num_entries, \ DEFAULT_BPF_MAP_UID, AID_ROOT, 0600) +#define DEFINE_BPF_MAP_RO(the_map, TYPE, KeyType, ValueType, num_entries, gid) \ + DEFINE_BPF_MAP_UGM(the_map, TYPE, KeyType, ValueType, num_entries, \ + DEFAULT_BPF_MAP_UID, gid, 0440) + #define DEFINE_BPF_MAP_GWO(the_map, TYPE, KeyType, ValueType, num_entries, gid) \ DEFINE_BPF_MAP_UGM(the_map, TYPE, KeyType, ValueType, num_entries, \ DEFAULT_BPF_MAP_UID, gid, 0620) diff --git a/common/tests/unit/src/com/android/net/module/util/BitUtilsTests.kt b/common/tests/unit/src/com/android/net/module/util/BitUtilsTests.kt new file mode 100644 index 00000000..02367162 --- /dev/null +++ b/common/tests/unit/src/com/android/net/module/util/BitUtilsTests.kt @@ -0,0 +1,61 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.net.module.util + +import com.android.net.module.util.BitUtils.appendStringRepresentationOfBitMaskToStringBuilder +import com.android.net.module.util.BitUtils.packBits +import com.android.net.module.util.BitUtils.unpackBits +import org.junit.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class BitUtilsTests { + @Test + fun testBitPackingTestCase() { + runBitPackingTestCase(0, intArrayOf()) + runBitPackingTestCase(1, intArrayOf(0)) + runBitPackingTestCase(3, intArrayOf(0, 1)) + runBitPackingTestCase(4, intArrayOf(2)) + runBitPackingTestCase(63, intArrayOf(0, 1, 2, 3, 4, 5)) + runBitPackingTestCase(Long.MAX_VALUE.inv(), intArrayOf(63)) + runBitPackingTestCase(Long.MAX_VALUE.inv() + 1, intArrayOf(0, 63)) + runBitPackingTestCase(Long.MAX_VALUE.inv() + 2, intArrayOf(1, 63)) + } + + fun runBitPackingTestCase(packedBits: Long, bits: IntArray) { + assertEquals(packedBits, packBits(bits)) + assertTrue(bits contentEquals unpackBits(packedBits)) + } + + @Test + fun testAppendStringRepresentationOfBitMaskToStringBuilder() { + runTestAppendStringRepresentationOfBitMaskToStringBuilder("", 0) + runTestAppendStringRepresentationOfBitMaskToStringBuilder("BIT0", 0b1) + runTestAppendStringRepresentationOfBitMaskToStringBuilder("BIT1&BIT2&BIT4", 0b10110) + runTestAppendStringRepresentationOfBitMaskToStringBuilder( + "BIT0&BIT60&BIT61&BIT62&BIT63", + (0b11110000_00000000_00000000_00000000 shl 32) + + 0b00000000_00000000_00000000_00000001) + } + + fun runTestAppendStringRepresentationOfBitMaskToStringBuilder(expected: String, bitMask: Long) { + StringBuilder().let { + appendStringRepresentationOfBitMaskToStringBuilder(it, bitMask, { i -> "BIT$i" }, "&") + assertEquals(expected, it.toString()) + } + } +} diff --git a/common/tests/unit/src/com/android/net/module/util/BpfDumpTest.java b/common/tests/unit/src/com/android/net/module/util/BpfDumpTest.java index 39329258..4e483322 100644 --- a/common/tests/unit/src/com/android/net/module/util/BpfDumpTest.java +++ b/common/tests/unit/src/com/android/net/module/util/BpfDumpTest.java @@ -16,10 +16,19 @@ package com.android.net.module.util; +import static android.system.OsConstants.EPERM; +import static android.system.OsConstants.R_OK; + +import static com.android.dx.mockito.inline.extended.ExtendedMockito.doReturn; +import static com.android.dx.mockito.inline.extended.ExtendedMockito.doThrow; +import static com.android.dx.mockito.inline.extended.ExtendedMockito.mockitoSession; + import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; +import android.system.ErrnoException; +import android.system.Os; import android.util.Pair; import androidx.test.filters.SmallTest; @@ -29,6 +38,7 @@ import com.android.testutils.TestBpfMap; import org.junit.Test; import org.junit.runner.RunWith; +import org.mockito.MockitoSession; import java.io.PrintWriter; import java.io.StringWriter; @@ -118,4 +128,36 @@ public class BpfDumpTest { assertTrue(dump.contains("key=123, val=456")); assertTrue(dump.contains("key=789, val=123")); } + + private String getDumpMapStatus(final IBpfMap<Struct.U32, Struct.U32> map) { + final StringWriter sw = new StringWriter(); + BpfDump.dumpMapStatus(map, new PrintWriter(sw), "mapName", "mapPath"); + return sw.toString(); + } + + @Test + public void testGetMapStatus() { + final IBpfMap<Struct.U32, Struct.U32> map = + new TestBpfMap<>(Struct.U32.class, Struct.U32.class); + assertEquals("mapName: OK\n", getDumpMapStatus(map)); + } + + @Test + public void testGetMapStatusNull() { + final MockitoSession session = mockitoSession() + .spyStatic(Os.class) + .startMocking(); + try { + // Os.access succeeds + doReturn(true).when(() -> Os.access("mapPath", R_OK)); + assertEquals("mapName: NULL(map is pinned to mapPath)\n", getDumpMapStatus(null)); + + // Os.access throws EPERM + doThrow(new ErrnoException("", EPERM)).when(() -> Os.access("mapPath", R_OK)); + assertEquals("mapName: NULL(map is not pinned to mapPath: Operation not permitted)\n", + getDumpMapStatus(null)); + } finally { + session.finishMocking(); + } + } } diff --git a/common/tests/unit/src/com/android/net/module/util/ByteUtilsTests.kt b/common/tests/unit/src/com/android/net/module/util/ByteUtilsTests.kt new file mode 100644 index 00000000..e58adad3 --- /dev/null +++ b/common/tests/unit/src/com/android/net/module/util/ByteUtilsTests.kt @@ -0,0 +1,54 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.net.module.util + +import com.android.net.module.util.ByteUtils.indexOf +import com.android.net.module.util.ByteUtils.concat +import org.junit.Test +import kotlin.test.assertContentEquals +import kotlin.test.assertEquals +import kotlin.test.assertNotSame + +class ByteUtilsTests { + private val EMPTY = byteArrayOf() + private val ARRAY1 = byteArrayOf(1) + private val ARRAY234 = byteArrayOf(2, 3, 4) + + @Test + fun testIndexOf() { + assertEquals(-1, indexOf(EMPTY, 1)) + assertEquals(-1, indexOf(ARRAY1, 2)) + assertEquals(-1, indexOf(ARRAY234, 1)) + assertEquals(0, indexOf(byteArrayOf(-1), -1)) + assertEquals(0, indexOf(ARRAY234, 2)) + assertEquals(1, indexOf(ARRAY234, 3)) + assertEquals(2, indexOf(ARRAY234, 4)) + assertEquals(1, indexOf(byteArrayOf(2, 3, 2, 3), 3)) + } + + @Test + fun testConcat() { + assertContentEquals(EMPTY, concat()) + assertContentEquals(EMPTY, concat(EMPTY)) + assertContentEquals(EMPTY, concat(EMPTY, EMPTY, EMPTY)) + assertContentEquals(ARRAY1, concat(ARRAY1)) + assertNotSame(ARRAY1, concat(ARRAY1)) + assertContentEquals(ARRAY1, concat(EMPTY, ARRAY1, EMPTY)) + assertContentEquals(byteArrayOf(1, 1, 1), concat(ARRAY1, ARRAY1, ARRAY1)) + assertContentEquals(byteArrayOf(1, 2, 3, 4), concat(ARRAY1, ARRAY234)) + } +}
\ No newline at end of file diff --git a/common/tests/unit/src/com/android/net/module/util/DnsPacketTest.java b/common/tests/unit/src/com/android/net/module/util/DnsPacketTest.java index 1bbba088..409c1ebb 100644 --- a/common/tests/unit/src/com/android/net/module/util/DnsPacketTest.java +++ b/common/tests/unit/src/com/android/net/module/util/DnsPacketTest.java @@ -16,26 +16,45 @@ package com.android.net.module.util; +import static android.net.DnsResolver.CLASS_IN; +import static android.net.DnsResolver.TYPE_A; +import static android.net.DnsResolver.TYPE_AAAA; + import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import android.annotation.NonNull; +import android.annotation.Nullable; + import androidx.test.filters.SmallTest; import androidx.test.runner.AndroidJUnit4; +import libcore.net.InetAddressUtils; + import org.junit.Test; import org.junit.runner.RunWith; +import java.io.IOException; +import java.nio.BufferUnderflowException; +import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; @RunWith(AndroidJUnit4.class) @SmallTest public class DnsPacketTest { + private static final int TEST_DNS_PACKET_ID = 0x7722; + private static final int TEST_DNS_PACKET_FLAGS = 0x8180; + private void assertHeaderParses(DnsPacket.DnsHeader header, int id, int flag, int qCount, int aCount, int nsCount, int arCount) { - assertEquals(header.id, id); - assertEquals(header.flags, flag); + assertEquals(header.getId(), id); + assertEquals(header.getFlags(), flag); assertEquals(header.getRecordCount(DnsPacket.QDSECTION), qCount); assertEquals(header.getRecordCount(DnsPacket.ANSECTION), aCount); assertEquals(header.getRecordCount(DnsPacket.NSSECTION), nsCount); @@ -51,11 +70,16 @@ public class DnsPacketTest { assertTrue(Arrays.equals(record.getRR(), rr)); } - class TestDnsPacket extends DnsPacket { + static class TestDnsPacket extends DnsPacket { TestDnsPacket(byte[] data) throws DnsPacket.ParseException { super(data); } + TestDnsPacket(@NonNull DnsHeader header, @Nullable ArrayList<DnsRecord> qd, + @Nullable ArrayList<DnsRecord> an) { + super(header, qd, an); + } + public DnsHeader getHeader() { return mHeader; } @@ -156,4 +180,247 @@ public class DnsPacketTest { new byte[]{ 0x24, 0x04, 0x68, 0x00, 0x40, 0x05, 0x08, 0x0d, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20, 0x04 }); } + + /** Verifies that the synthesized {@link DnsPacket.DnsHeader} can be parsed correctly. */ + @Test + public void testDnsHeaderSynthesize() { + final DnsPacket.DnsHeader testHeader = new DnsPacket.DnsHeader(TEST_DNS_PACKET_ID, + TEST_DNS_PACKET_FLAGS, 3 /* qcount */, 5 /* ancount */); + final DnsPacket.DnsHeader actualHeader = new DnsPacket.DnsHeader( + ByteBuffer.wrap(testHeader.getBytes())); + assertEquals(testHeader, actualHeader); + } + + /** Verifies that the synthesized {@link DnsPacket.DnsRecord} can be parsed correctly. */ + @Test + public void testDnsRecordSynthesize() throws IOException { + assertDnsRecordRoundTrip( + DnsPacket.DnsRecord.makeAOrAAAARecord(DnsPacket.ANSECTION, + "test.com", CLASS_IN, 5 /* ttl */, + InetAddressUtils.parseNumericAddress("abcd::fedc"))); + assertDnsRecordRoundTrip(DnsPacket.DnsRecord.makeQuestion("test.com", TYPE_AAAA, CLASS_IN)); + assertDnsRecordRoundTrip(DnsPacket.DnsRecord.makeCNameRecord(DnsPacket.ANSECTION, + "test.com", CLASS_IN, 0 /* ttl */, "example.com")); + } + + /** + * Verifies ttl/rData error handling when parsing + * {@link DnsPacket.DnsRecord} from bytes. + */ + @Test + public void testDnsRecordTTLRDataErrorHandling() throws IOException { + // Verify the constructor ignore ttl/rData of questions even if they are supplied. + final byte[] qdWithTTLRData = new byte[]{ + 0x03, 0x77, 0x77, 0x77, 0x06, 0x67, 0x6F, 0x6F, 0x67, 0x6c, 0x65, + 0x03, 0x63, 0x6f, 0x6d, 0x00, /* Name */ + 0x00, 0x00, /* Type */ + 0x00, 0x01, /* Class */ + 0x00, 0x00, 0x01, 0x2b, /* TTL */ + 0x00, 0x04, /* Data length */ + (byte) 0xac, (byte) 0xd9, (byte) 0xa1, (byte) 0x84 /* Address */}; + final DnsPacket.DnsRecord questionsFromBytes = + new DnsPacket.DnsRecord(DnsPacket.QDSECTION, ByteBuffer.wrap(qdWithTTLRData)); + assertEquals(0, questionsFromBytes.ttl); + assertNull(questionsFromBytes.getRR()); + + // Verify ANSECTION must have rData when constructing. + final byte[] anWithoutTTLRData = new byte[]{ + 0x03, 0x77, 0x77, 0x77, 0x06, 0x67, 0x6F, 0x6F, 0x67, 0x6c, 0x65, + 0x03, 0x63, 0x6f, 0x6d, 0x00, /* Name */ + 0x00, 0x01, /* Type */ + 0x00, 0x01, /* Class */}; + assertThrows(BufferUnderflowException.class, () -> + new DnsPacket.DnsRecord(DnsPacket.ANSECTION, ByteBuffer.wrap(anWithoutTTLRData))); + } + + private void assertDnsRecordRoundTrip(DnsPacket.DnsRecord before) + throws IOException { + final DnsPacket.DnsRecord after = new DnsPacket.DnsRecord(before.rType, + ByteBuffer.wrap(before.getBytes())); + assertEquals(after, before); + } + + /** Verifies that the synthesized {@link DnsPacket} can be parsed correctly. */ + @Test + public void testDnsPacketSynthesize() throws IOException { + // Ipv4 dns response packet generated by scapy: + // dns_r = scapy.DNS( + // id=0xbeef, + // qr=1, + // qd=scapy.DNSQR(qname="hello.example.com"), + // an=scapy.DNSRR(rrname="hello.example.com", type="CNAME", rdata='test.com') / + // scapy.DNSRR(rrname="hello.example.com", rdata='1.2.3.4')) + // scapy.hexdump(dns_r) + // dns_r.show2() + // Note that since the synthesizing does not support name compression yet, the domain + // name of the sample need to be uncompressed when generating. + final byte[] v4BlobUncompressed = new byte[]{ + /* Header */ + (byte) 0xbe, (byte) 0xef, /* Transaction ID */ + (byte) 0x81, 0x00, /* Flags */ + 0x00, 0x01, /* Questions */ + 0x00, 0x02, /* Answer RRs */ + 0x00, 0x00, /* Authority RRs */ + 0x00, 0x00, /* Additional RRs */ + /* Queries */ + 0x05, 0x68, 0x65, 0x6C, 0x6C, 0x6F, 0x07, 0x65, 0x78, 0x61, + 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, /* Name: hello.example.com */ + 0x00, 0x01, /* Type */ + 0x00, 0x01, /* Class */ + /* Answers */ + 0x05, 0x68, 0x65, 0x6C, 0x6C, 0x6F, 0x07, 0x65, 0x78, 0x61, + 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, /* Name: hello.example.com */ + 0x00, 0x05, /* Type */ + 0x00, 0x01, /* Class */ + 0x00, 0x00, 0x00, 0x00, /* TTL */ + 0x00, 0x0A, /* Data length */ + 0x04, 0x74, 0x65, 0x73, 0x74, 0x03, 0x63, 0x6F, 0x6D, 0x00, /* Alias: test.com */ + 0x05, 0x68, 0x65, 0x6C, 0x6C, 0x6F, 0x07, 0x65, 0x78, 0x61, + 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, /* Name: hello.example.com */ + 0x00, 0x01, /* Type */ + 0x00, 0x01, /* Class */ + 0x00, 0x00, 0x00, 0x00, /* TTL */ + 0x00, 0x04, /* Data length */ + 0x01, 0x02, 0x03, 0x04, /* Address: 1.2.3.4 */ + }; + + // Forge one via constructors. + final DnsPacket.DnsHeader testHeader = new DnsPacket.DnsHeader(0xbeef, + 0x8100, 1 /* qcount */, 2 /* ancount */); + final ArrayList<DnsPacket.DnsRecord> qlist = new ArrayList<>(); + final ArrayList<DnsPacket.DnsRecord> alist = new ArrayList<>(); + qlist.add(DnsPacket.DnsRecord.makeQuestion( + "hello.example.com", TYPE_A, CLASS_IN)); + alist.add(DnsPacket.DnsRecord.makeCNameRecord( + DnsPacket.ANSECTION, "hello.example.com", CLASS_IN, 0 /* ttl */, "test.com")); + alist.add(DnsPacket.DnsRecord.makeAOrAAAARecord( + DnsPacket.ANSECTION, "hello.example.com", CLASS_IN, 0 /* ttl */, + InetAddressUtils.parseNumericAddress("1.2.3.4"))); + final TestDnsPacket testPacket = new TestDnsPacket(testHeader, qlist, alist); + + // Assert content equals in both ways. + assertTrue(Arrays.equals(v4BlobUncompressed, testPacket.getBytes())); + assertEquals(new TestDnsPacket(v4BlobUncompressed), testPacket); + } + + @Test + public void testDnsPacketSynthesize_recordCountMismatch() throws IOException { + final DnsPacket.DnsHeader testHeader = new DnsPacket.DnsHeader(0xbeef, + 0x8100, 1 /* qcount */, 1 /* ancount */); + final ArrayList<DnsPacket.DnsRecord> qlist = new ArrayList<>(); + final ArrayList<DnsPacket.DnsRecord> alist = new ArrayList<>(); + qlist.add(DnsPacket.DnsRecord.makeQuestion( + "hello.example.com", TYPE_A, CLASS_IN)); + + // Assert throws if the supplied answer records fewer than the declared count. + assertThrows(IllegalArgumentException.class, () -> + new TestDnsPacket(testHeader, qlist, alist)); + + // Assert throws if the supplied answer records more than the declared count. + alist.add(DnsPacket.DnsRecord.makeCNameRecord( + DnsPacket.ANSECTION, "hello.example.com", CLASS_IN, 0 /* ttl */, "test.com")); + alist.add(DnsPacket.DnsRecord.makeAOrAAAARecord( + DnsPacket.ANSECTION, "hello.example.com", CLASS_IN, 0 /* ttl */, + InetAddressUtils.parseNumericAddress("1.2.3.4"))); + assertThrows(IllegalArgumentException.class, () -> + new TestDnsPacket(testHeader, qlist, alist)); + + // Assert counts matched if the byte buffer still has data when parsing ended. + final byte[] blobTooMuchData = new byte[]{ + /* Header */ + (byte) 0xbe, (byte) 0xef, /* Transaction ID */ + (byte) 0x81, 0x00, /* Flags */ + 0x00, 0x00, /* Questions */ + 0x00, 0x00, /* Answer RRs */ + 0x00, 0x00, /* Authority RRs */ + 0x00, 0x00, /* Additional RRs */ + /* Queries */ + 0x05, 0x68, 0x65, 0x6C, 0x6C, 0x6F, 0x07, 0x65, 0x78, 0x61, + 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, /* Name */ + 0x00, 0x01, /* Type */ + 0x00, 0x01, /* Class */ + }; + final TestDnsPacket packetFromTooMuchData = new TestDnsPacket(blobTooMuchData); + for (int i = 0; i < DnsPacket.NUM_SECTIONS; i++) { + assertEquals(0, packetFromTooMuchData.getRecordList(i).size()); + assertEquals(0, packetFromTooMuchData.getHeader().getRecordCount(i)); + } + + // Assert throws if the byte buffer ended when expecting more records. + final byte[] blobNotEnoughData = new byte[]{ + /* Header */ + (byte) 0xbe, (byte) 0xef, /* Transaction ID */ + (byte) 0x81, 0x00, /* Flags */ + 0x00, 0x01, /* Questions */ + 0x00, 0x02, /* Answer RRs */ + 0x00, 0x00, /* Authority RRs */ + 0x00, 0x00, /* Additional RRs */ + /* Queries */ + 0x05, 0x68, 0x65, 0x6C, 0x6C, 0x6F, 0x07, 0x65, 0x78, 0x61, + 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, /* Name */ + 0x00, 0x01, /* Type */ + 0x00, 0x01, /* Class */ + /* Answers */ + 0x05, 0x68, 0x65, 0x6C, 0x6C, 0x6F, 0x07, 0x65, 0x78, 0x61, + 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, /* Name */ + 0x00, 0x01, /* Type */ + 0x00, 0x01, /* Class */ + 0x00, 0x00, 0x00, 0x00, /* TTL */ + 0x00, 0x04, /* Data length */ + 0x01, 0x02, 0x03, 0x04, /* Address */ + }; + assertThrows(DnsPacket.ParseException.class, () -> new TestDnsPacket(blobNotEnoughData)); + } + + @Test + public void testEqualsAndHashCode() throws IOException { + // Verify DnsHeader equals and hashCode. + final DnsPacket.DnsHeader testHeader = new DnsPacket.DnsHeader(TEST_DNS_PACKET_ID, + TEST_DNS_PACKET_FLAGS, 1 /* qcount */, 1 /* ancount */); + final DnsPacket.DnsHeader emptyHeader = new DnsPacket.DnsHeader(TEST_DNS_PACKET_ID + 1, + TEST_DNS_PACKET_FLAGS + 0x08, 0 /* qcount */, 0 /* ancount */); + final DnsPacket.DnsHeader headerFromBytes = + new DnsPacket.DnsHeader(ByteBuffer.wrap(testHeader.getBytes())); + assertEquals(testHeader, headerFromBytes); + assertEquals(testHeader.hashCode(), headerFromBytes.hashCode()); + assertNotEquals(testHeader, emptyHeader); + assertNotEquals(testHeader.hashCode(), emptyHeader.hashCode()); + assertNotEquals(headerFromBytes, emptyHeader); + assertNotEquals(headerFromBytes.hashCode(), emptyHeader.hashCode()); + + // Verify DnsRecord equals and hashCode. + final DnsPacket.DnsRecord testQuestion = DnsPacket.DnsRecord.makeQuestion( + "test.com", TYPE_AAAA, CLASS_IN); + final DnsPacket.DnsRecord testAnswer = DnsPacket.DnsRecord.makeCNameRecord( + DnsPacket.ANSECTION, "test.com", CLASS_IN, 9, "www.test.com"); + final DnsPacket.DnsRecord questionFromBytes = new DnsPacket.DnsRecord(DnsPacket.QDSECTION, + ByteBuffer.wrap(testQuestion.getBytes())); + assertEquals(testQuestion, questionFromBytes); + assertEquals(testQuestion.hashCode(), questionFromBytes.hashCode()); + assertNotEquals(testQuestion, testAnswer); + assertNotEquals(testQuestion.hashCode(), testAnswer.hashCode()); + assertNotEquals(questionFromBytes, testAnswer); + assertNotEquals(questionFromBytes.hashCode(), testAnswer.hashCode()); + + // Verify DnsPacket equals and hashCode. + final ArrayList<DnsPacket.DnsRecord> qlist = new ArrayList<>(); + final ArrayList<DnsPacket.DnsRecord> alist = new ArrayList<>(); + qlist.add(testQuestion); + alist.add(testAnswer); + final TestDnsPacket testPacket = new TestDnsPacket(testHeader, qlist, alist); + final TestDnsPacket emptyPacket = new TestDnsPacket( + emptyHeader, new ArrayList<>(), new ArrayList<>()); + final TestDnsPacket packetFromBytes = new TestDnsPacket(testPacket.getBytes()); + assertEquals(testPacket, packetFromBytes); + assertEquals(testPacket.hashCode(), packetFromBytes.hashCode()); + assertNotEquals(testPacket, emptyPacket); + assertNotEquals(testPacket.hashCode(), emptyPacket.hashCode()); + assertNotEquals(packetFromBytes, emptyPacket); + assertNotEquals(packetFromBytes.hashCode(), emptyPacket.hashCode()); + + // Verify DnsPacket with empty list. + final TestDnsPacket emptyPacketFromBytes = new TestDnsPacket(emptyPacket.getBytes()); + assertEquals(emptyPacket, emptyPacketFromBytes); + assertEquals(emptyPacket.hashCode(), emptyPacketFromBytes.hashCode()); + } } diff --git a/common/tests/unit/src/com/android/net/module/util/DnsPacketUtilsTest.java b/common/tests/unit/src/com/android/net/module/util/DnsPacketUtilsTest.java index 48777acc..9e1ab824 100644 --- a/common/tests/unit/src/com/android/net/module/util/DnsPacketUtilsTest.java +++ b/common/tests/unit/src/com/android/net/module/util/DnsPacketUtilsTest.java @@ -18,12 +18,20 @@ package com.android.net.module.util; import static com.android.net.module.util.DnsPacketUtils.DnsRecordParser; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; +import android.net.ParseException; +import android.os.Build; + import androidx.test.filters.SmallTest; import androidx.test.runner.AndroidJUnit4; +import com.android.testutils.DevSdkIgnoreRule; + +import org.junit.Assert; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -32,6 +40,8 @@ import java.nio.ByteBuffer; @RunWith(AndroidJUnit4.class) @SmallTest public class DnsPacketUtilsTest { + @Rule + public final DevSdkIgnoreRule mIgnoreRule = new DevSdkIgnoreRule(); /** * Verifies that the compressed NAME field in the answer section of the DNS message is parsed @@ -116,4 +126,31 @@ public class DnsPacketUtilsTest { notNameCompressedBuf, /* depth= */ 0, /* isNameCompressionSupported= */ false); assertEquals(domainName, "www.google.com"); } + + // Skip test on R- devices since ParseException only available on S+ devices. + @DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.R) + @Test + public void testDomainNameToLabels() throws Exception { + assertArrayEquals( + new byte[]{3, 'w', 'w', 'w', 6, 'g', 'o', 'o', 'g', 'l', 'e', 3, 'c', 'o', 'm', 0}, + DnsRecordParser.domainNameToLabels("www.google.com")); + assertThrows(ParseException.class, () -> + DnsRecordParser.domainNameToLabels("aaa.")); + assertThrows(ParseException.class, () -> + DnsRecordParser.domainNameToLabels("aaa")); + assertThrows(ParseException.class, () -> + DnsRecordParser.domainNameToLabels(".")); + assertThrows(ParseException.class, () -> + DnsRecordParser.domainNameToLabels("")); + } + + @Test + public void testIsHostName() { + Assert.assertTrue(DnsRecordParser.isHostName("www.google.com")); + Assert.assertFalse(DnsRecordParser.isHostName("com")); + Assert.assertFalse(DnsRecordParser.isHostName("1.2.3.4")); + Assert.assertFalse(DnsRecordParser.isHostName("1234::5678")); + Assert.assertFalse(DnsRecordParser.isHostName(null)); + Assert.assertFalse(DnsRecordParser.isHostName("")); + } } diff --git a/common/tests/unit/src/com/android/net/module/util/IpUtilsTest.java b/common/tests/unit/src/com/android/net/module/util/IpUtilsTest.java index 34683090..20555b35 100644 --- a/common/tests/unit/src/com/android/net/module/util/IpUtilsTest.java +++ b/common/tests/unit/src/com/android/net/module/util/IpUtilsTest.java @@ -16,6 +16,9 @@ package com.android.net.module.util; +import static com.android.net.module.util.NetworkStackConstants.ICMP_CHECKSUM_OFFSET; +import static com.android.net.module.util.NetworkStackConstants.IPV4_CHECKSUM_OFFSET; + import static org.junit.Assert.assertEquals; import androidx.test.filters.SmallTest; @@ -34,7 +37,6 @@ public class IpUtilsTest { private static final int IPV6_HEADER_LENGTH = 40; private static final int TCP_HEADER_LENGTH = 20; private static final int UDP_HEADER_LENGTH = 8; - private static final int IP_CHECKSUM_OFFSET = 10; private static final int TCP_CHECKSUM_OFFSET = 16; private static final int UDP_CHECKSUM_OFFSET = 6; @@ -141,7 +143,7 @@ public class IpUtilsTest { assertEquals((short) 0xffff, IpUtils.udpChecksum(packet, 0, IPV4_HEADER_LENGTH)); // Check that we can calculate the checksums from scratch. - final int ipSumOffset = IP_CHECKSUM_OFFSET; + final int ipSumOffset = IPV4_CHECKSUM_OFFSET; final int ipSum = getChecksum(packet, ipSumOffset); assertEquals(0xf68b, ipSum); @@ -163,4 +165,47 @@ public class IpUtilsTest { assertEquals(0, IpUtils.ipChecksum(packet, 0)); assertEquals((short) 0xffff, IpUtils.udpChecksum(packet, 0, IPV4_HEADER_LENGTH)); } + + @Test + public void testIpv4IcmpChecksum() throws Exception { + // packet = (scapy.IP(src="192.0.2.1", dst="192.0.2.2", tos=0x40) / + // scapy.ICMP(type=0x8, id=0x1234, seq=0x5678) / + // "hello, world") + ByteBuffer packet = ByteBuffer.wrap(new byte[] { + /* IPv4 */ + (byte) 0x45, (byte) 0x40, (byte) 0x00, (byte) 0x28, + (byte) 0x00, (byte) 0x01, (byte) 0x00, (byte) 0x00, + (byte) 0x40, (byte) 0x01, (byte) 0xf6, (byte) 0x90, + (byte) 0xc0, (byte) 0x00, (byte) 0x02, (byte) 0x01, + (byte) 0xc0, (byte) 0x00, (byte) 0x02, (byte) 0x02, + /* ICMP */ + (byte) 0x08, /* type: echo-request */ + (byte) 0x00, /* code: 0 */ + (byte) 0x4f, (byte) 0x07, /* chksum: 0x4f07 */ + (byte) 0x12, (byte) 0x34, /* id: 0x1234 */ + (byte) 0x56, (byte) 0x78, /* seq: 0x5678 */ + (byte) 0x68, (byte) 0x65, (byte) 0x6c, (byte) 0x6c, /* data: hello, world */ + (byte) 0x6f, (byte) 0x2c, (byte) 0x20, (byte) 0x77, + (byte) 0x6f, (byte) 0x72, (byte) 0x6c, (byte) 0x64 + }); + + // Check that a valid packet has checksum 0. + int transportLen = packet.limit() - IPV4_HEADER_LENGTH; + assertEquals(0, IpUtils.icmpChecksum(packet, IPV4_HEADER_LENGTH, transportLen)); + + // Check that we can calculate the checksum from scratch. + int sumOffset = IPV4_HEADER_LENGTH + ICMP_CHECKSUM_OFFSET; + int sum = getUnsignedByte(packet, sumOffset) * 256 + getUnsignedByte(packet, sumOffset + 1); + assertEquals(0x4f07, sum); + + packet.put(sumOffset, (byte) 0); + packet.put(sumOffset + 1, (byte) 0); + assertChecksumEquals(sum, IpUtils.icmpChecksum(packet, IPV4_HEADER_LENGTH, transportLen)); + + // Check that writing the checksum back into the packet results in a valid packet. + packet.putShort( + sumOffset, + IpUtils.icmpChecksum(packet, IPV4_HEADER_LENGTH, transportLen)); + assertEquals(0, IpUtils.icmpChecksum(packet, IPV4_HEADER_LENGTH, transportLen)); + } } diff --git a/common/tests/unit/src/com/android/net/module/util/NetworkCapabilitiesUtilsTest.kt b/common/tests/unit/src/com/android/net/module/util/NetworkCapabilitiesUtilsTest.kt index 256ea1e8..958f45f4 100644 --- a/common/tests/unit/src/com/android/net/module/util/NetworkCapabilitiesUtilsTest.kt +++ b/common/tests/unit/src/com/android/net/module/util/NetworkCapabilitiesUtilsTest.kt @@ -22,7 +22,6 @@ import android.net.NetworkCapabilities.NET_CAPABILITY_BIP import android.net.NetworkCapabilities.NET_CAPABILITY_CBS import android.net.NetworkCapabilities.NET_CAPABILITY_EIMS import android.net.NetworkCapabilities.NET_CAPABILITY_INTERNET -import android.net.NetworkCapabilities.NET_CAPABILITY_INTERNET import android.net.NetworkCapabilities.NET_CAPABILITY_OEM_PAID import android.net.NetworkCapabilities.TRANSPORT_BLUETOOTH import android.net.NetworkCapabilities.TRANSPORT_CELLULAR @@ -38,8 +37,6 @@ import com.android.modules.utils.build.SdkLevel import com.android.net.module.util.NetworkCapabilitiesUtils.RESTRICTED_CAPABILITIES import com.android.net.module.util.NetworkCapabilitiesUtils.UNRESTRICTED_CAPABILITIES import com.android.net.module.util.NetworkCapabilitiesUtils.getDisplayTransport -import com.android.net.module.util.NetworkCapabilitiesUtils.packBits -import com.android.net.module.util.NetworkCapabilitiesUtils.unpackBits import org.junit.Test import org.junit.runner.RunWith import java.lang.IllegalArgumentException @@ -75,23 +72,6 @@ class NetworkCapabilitiesUtilsTest { } } - @Test - fun testBitPackingTestCase() { - runBitPackingTestCase(0, intArrayOf()) - runBitPackingTestCase(1, intArrayOf(0)) - runBitPackingTestCase(3, intArrayOf(0, 1)) - runBitPackingTestCase(4, intArrayOf(2)) - runBitPackingTestCase(63, intArrayOf(0, 1, 2, 3, 4, 5)) - runBitPackingTestCase(Long.MAX_VALUE.inv(), intArrayOf(63)) - runBitPackingTestCase(Long.MAX_VALUE.inv() + 1, intArrayOf(0, 63)) - runBitPackingTestCase(Long.MAX_VALUE.inv() + 2, intArrayOf(1, 63)) - } - - fun runBitPackingTestCase(packedBits: Long, bits: IntArray) { - assertEquals(packedBits, packBits(bits)) - assertTrue(bits contentEquals unpackBits(packedBits)) - } - // NetworkCapabilities constructor and Builder are not available until R. Mark TargetApi to // ignore the linter error since it's used in only unit test. @Test @TargetApi(Build.VERSION_CODES.R) diff --git a/common/tests/unit/src/com/android/net/module/util/TrackRecordTest.kt b/common/tests/unit/src/com/android/net/module/util/TrackRecordTest.kt index 9fb4d8ca..8e320d0d 100644 --- a/common/tests/unit/src/com/android/net/module/util/TrackRecordTest.kt +++ b/common/tests/unit/src/com/android/net/module/util/TrackRecordTest.kt @@ -17,6 +17,7 @@ package com.android.net.module.util import com.android.testutils.ConcurrentInterpreter +import com.android.testutils.INTERPRET_TIME_UNIT import com.android.testutils.InterpretException import com.android.testutils.InterpretMatcher import com.android.testutils.SyntaxException @@ -420,7 +421,7 @@ private val interpretTable = listOf<InterpretMatcher<TrackRecord<Int>>>( // the test code to not compile instead of throw, but it's vastly more complex and this will // fail 100% at runtime any test that would not have compiled. Regex("""poll\((\d+)?\)\s*(\{.*\})?""") to { i, t, r -> - (if (r.strArg(1).isEmpty()) i.interpretTimeUnit else r.timeArg(1)).let { time -> + (if (r.strArg(1).isEmpty()) INTERPRET_TIME_UNIT else r.timeArg(1)).let { time -> (t as ArrayTrackRecord<Int>.ReadHead).poll(time, makePredicate(r.strArg(2))) } }, diff --git a/common/tests/unit/src/com/android/net/module/util/netlink/NetlinkSocketTest.java b/common/tests/unit/src/com/android/net/module/util/netlink/NetlinkSocketTest.java index 33ea6023..f4073b6d 100644 --- a/common/tests/unit/src/com/android/net/module/util/netlink/NetlinkSocketTest.java +++ b/common/tests/unit/src/com/android/net/module/util/netlink/NetlinkSocketTest.java @@ -236,7 +236,7 @@ public class NetlinkSocketTest { @Field(order = 3, type = Type.U8) public short scope; - @Field(order = 4, type = Type.U32) - public long index; + @Field(order = 4, type = Type.S32) + public int index; } } diff --git a/common/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt b/common/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt new file mode 100644 index 00000000..f8f2da08 --- /dev/null +++ b/common/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt @@ -0,0 +1,477 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.testutils + +import android.annotation.SuppressLint +import android.net.LinkAddress +import android.net.LinkProperties +import android.net.Network +import android.net.NetworkCapabilities +import com.android.testutils.RecorderCallback.CallbackEntry +import com.android.testutils.RecorderCallback.CallbackEntry.Companion.AVAILABLE +import com.android.testutils.RecorderCallback.CallbackEntry.Companion.BLOCKED_STATUS +import com.android.testutils.RecorderCallback.CallbackEntry.Companion.LINK_PROPERTIES_CHANGED +import com.android.testutils.RecorderCallback.CallbackEntry.Companion.LOSING +import com.android.testutils.RecorderCallback.CallbackEntry.Companion.NETWORK_CAPS_UPDATED +import com.android.testutils.RecorderCallback.CallbackEntry.Companion.LOST +import com.android.testutils.RecorderCallback.CallbackEntry.Companion.RESUMED +import com.android.testutils.RecorderCallback.CallbackEntry.Companion.SUSPENDED +import com.android.testutils.RecorderCallback.CallbackEntry.Companion.UNAVAILABLE +import com.android.testutils.RecorderCallback.CallbackEntry.Available +import com.android.testutils.RecorderCallback.CallbackEntry.BlockedStatus +import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.junit.Assume.assumeTrue +import kotlin.reflect.KClass +import kotlin.test.assertEquals +import kotlin.test.assertFails +import kotlin.test.assertNull +import kotlin.test.assertTrue +import kotlin.test.fail + +const val SHORT_TIMEOUT_MS = 20L +const val DEFAULT_LINGER_DELAY_MS = 30000 +const val NOT_METERED = NetworkCapabilities.NET_CAPABILITY_NOT_METERED +const val WIFI = NetworkCapabilities.TRANSPORT_WIFI +const val CELLULAR = NetworkCapabilities.TRANSPORT_CELLULAR +const val TEST_INTERFACE_NAME = "testInterfaceName" + +@RunWith(JUnit4::class) +@SuppressLint("NewApi") // Uses hidden APIs, which the linter would identify as missing APIs. +class TestableNetworkCallbackTest { + private lateinit var mCallback: TestableNetworkCallback + + private fun makeHasNetwork(netId: Int) = object : TestableNetworkCallback.HasNetwork { + override val network: Network = Network(netId) + } + + @Before + fun setUp() { + mCallback = TestableNetworkCallback() + } + + @Test + fun testLastAvailableNetwork() { + // Make sure there is no last available network at first, then the last available network + // is returned after onAvailable is called. + val net2097 = Network(2097) + assertNull(mCallback.lastAvailableNetwork) + mCallback.onAvailable(net2097) + assertEquals(mCallback.lastAvailableNetwork, net2097) + + // Make sure calling onCapsChanged/onLinkPropertiesChanged don't affect the last available + // network. + mCallback.onCapabilitiesChanged(net2097, NetworkCapabilities()) + mCallback.onLinkPropertiesChanged(net2097, LinkProperties()) + assertEquals(mCallback.lastAvailableNetwork, net2097) + + // Make sure onLost clears the last available network. + mCallback.onLost(net2097) + assertNull(mCallback.lastAvailableNetwork) + + // Do the same but with a different network after onLost : make sure the last available + // network is the new one, not the original one. + val net2098 = Network(2098) + mCallback.onAvailable(net2098) + mCallback.onCapabilitiesChanged(net2098, NetworkCapabilities()) + mCallback.onLinkPropertiesChanged(net2098, LinkProperties()) + assertEquals(mCallback.lastAvailableNetwork, net2098) + + // Make sure onAvailable changes the last available network even if onLost was not called. + val net2099 = Network(2099) + mCallback.onAvailable(net2099) + assertEquals(mCallback.lastAvailableNetwork, net2099) + + // For legacy reasons, lastAvailableNetwork is null as soon as any is lost, not necessarily + // the last available one. Check that behavior. + mCallback.onLost(net2098) + assertNull(mCallback.lastAvailableNetwork) + + // Make sure that losing the really last available one still results in null. + mCallback.onLost(net2099) + assertNull(mCallback.lastAvailableNetwork) + + // Make sure multiple onAvailable in a row then onLost still results in null. + mCallback.onAvailable(net2097) + mCallback.onAvailable(net2098) + mCallback.onAvailable(net2099) + mCallback.onLost(net2097) + assertNull(mCallback.lastAvailableNetwork) + } + + @Test + fun testAssertNoCallback() { + mCallback.assertNoCallback(SHORT_TIMEOUT_MS) + mCallback.onAvailable(Network(100)) + assertFails { mCallback.assertNoCallback(SHORT_TIMEOUT_MS) } + } + + @Test + fun testAssertNoCallbackThat() { + val net = Network(101) + mCallback.assertNoCallbackThat { it is Available } + mCallback.onAvailable(net) + // Expect no blocked status change. Receive other callback does not fail the test. + mCallback.assertNoCallbackThat { it is BlockedStatus } + mCallback.onBlockedStatusChanged(net, true) + assertFails { mCallback.assertNoCallbackThat { it is BlockedStatus } } + mCallback.onBlockedStatusChanged(net, false) + mCallback.onCapabilitiesChanged(net, NetworkCapabilities()) + assertFails { mCallback.assertNoCallbackThat { it is CapabilitiesChanged } } + } + + @Test + fun testCapabilitiesWithAndWithout() { + val net = Network(101) + val matcher = makeHasNetwork(101) + val meteredNc = NetworkCapabilities() + val unmeteredNc = NetworkCapabilities().addCapability(NOT_METERED) + // Check that expecting caps (with or without) fails when no callback has been received. + assertFails { mCallback.expectCapabilitiesWith(NOT_METERED, matcher, SHORT_TIMEOUT_MS) } + assertFails { mCallback.expectCapabilitiesWithout(NOT_METERED, matcher, SHORT_TIMEOUT_MS) } + + // Add NOT_METERED and check that With succeeds and Without fails. + mCallback.onCapabilitiesChanged(net, unmeteredNc) + mCallback.expectCapabilitiesWith(NOT_METERED, matcher) + mCallback.onCapabilitiesChanged(net, unmeteredNc) + assertFails { mCallback.expectCapabilitiesWithout(NOT_METERED, matcher, SHORT_TIMEOUT_MS) } + + // Don't add NOT_METERED and check that With fails and Without succeeds. + mCallback.onCapabilitiesChanged(net, meteredNc) + assertFails { mCallback.expectCapabilitiesWith(NOT_METERED, matcher, SHORT_TIMEOUT_MS) } + mCallback.onCapabilitiesChanged(net, meteredNc) + mCallback.expectCapabilitiesWithout(NOT_METERED, matcher) + } + + @Test + fun testExpectWithPredicate() { + val net = Network(193) + val netCaps = NetworkCapabilities().addTransportType(CELLULAR) + // Check that expecting callbackThat anything fails when no callback has been received. + assertFails { mCallback.expect<CallbackEntry>(timeoutMs = SHORT_TIMEOUT_MS) { true } } + + // Basic test for true and false + mCallback.onAvailable(net) + mCallback.expect<Available> { true } + mCallback.onAvailable(net) + assertFails { mCallback.expect<CallbackEntry>(timeoutMs = SHORT_TIMEOUT_MS) { false } } + + // Try a positive and a negative case + mCallback.onBlockedStatusChanged(net, true) + mCallback.expect<CallbackEntry> { cb -> cb is BlockedStatus && cb.blocked } + mCallback.onCapabilitiesChanged(net, netCaps) + assertFails { mCallback.expect<CallbackEntry>(timeoutMs = SHORT_TIMEOUT_MS) { cb -> + cb is CapabilitiesChanged && cb.caps.hasTransport(WIFI) + } } + } + + @Test + fun testCapabilitiesThat() { + val net = Network(101) + val netCaps = NetworkCapabilities().addCapability(NOT_METERED).addTransportType(WIFI) + // Check that expecting capabilitiesThat anything fails when no callback has been received. + assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { true } } + + // Basic test for true and false + mCallback.onCapabilitiesChanged(net, netCaps) + mCallback.expectCapabilitiesThat(net) { true } + mCallback.onCapabilitiesChanged(net, netCaps) + assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { false } } + + // Try a positive and a negative case + mCallback.onCapabilitiesChanged(net, netCaps) + mCallback.expectCapabilitiesThat(net) { caps -> + caps.hasCapability(NOT_METERED) && + caps.hasTransport(WIFI) && + !caps.hasTransport(CELLULAR) + } + mCallback.onCapabilitiesChanged(net, netCaps) + assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { caps -> + caps.hasTransport(CELLULAR) + } } + + // Try a matching callback on the wrong network + mCallback.onCapabilitiesChanged(net, netCaps) + assertFails { mCallback.expectCapabilitiesThat(Network(100), SHORT_TIMEOUT_MS) { true } } + } + + @Test + fun testLinkPropertiesThat() { + val net = Network(112) + val linkAddress = LinkAddress("fe80::ace:d00d/64") + val mtu = 1984 + val linkProps = LinkProperties().apply { + this.mtu = mtu + interfaceName = TEST_INTERFACE_NAME + addLinkAddress(linkAddress) + } + + // Check that expecting linkPropsThat anything fails when no callback has been received. + assertFails { mCallback.expectLinkPropertiesThat(net, SHORT_TIMEOUT_MS) { true } } + + // Basic test for true and false + mCallback.onLinkPropertiesChanged(net, linkProps) + mCallback.expectLinkPropertiesThat(net) { true } + mCallback.onLinkPropertiesChanged(net, linkProps) + assertFails { mCallback.expectLinkPropertiesThat(net, SHORT_TIMEOUT_MS) { false } } + + // Try a positive and negative case + mCallback.onLinkPropertiesChanged(net, linkProps) + mCallback.expectLinkPropertiesThat(net) { lp -> + lp.interfaceName == TEST_INTERFACE_NAME && + lp.linkAddresses.contains(linkAddress) && + lp.mtu == mtu + } + mCallback.onLinkPropertiesChanged(net, linkProps) + assertFails { mCallback.expectLinkPropertiesThat(net, SHORT_TIMEOUT_MS) { lp -> + lp.interfaceName != TEST_INTERFACE_NAME + } } + + // Try a matching callback on the wrong network + mCallback.onLinkPropertiesChanged(net, linkProps) + assertFails { mCallback.expectLinkPropertiesThat(Network(114), SHORT_TIMEOUT_MS) { lp -> + lp.interfaceName == TEST_INTERFACE_NAME + } } + } + + @Test + fun testExpect() { + val net = Network(103) + // Test expectCallback fails when nothing was sent. + assertFails { mCallback.expect<BlockedStatus>(net, SHORT_TIMEOUT_MS) } + + // Test onAvailable is seen and can be expected + mCallback.onAvailable(net) + mCallback.expect<Available>(net, SHORT_TIMEOUT_MS) + + // Test onAvailable won't return calls with a different network + mCallback.onAvailable(Network(106)) + assertFails { mCallback.expect<Available>(net, SHORT_TIMEOUT_MS) } + + // Test onAvailable won't return calls with a different callback + mCallback.onAvailable(net) + assertFails { mCallback.expect<BlockedStatus>(net, SHORT_TIMEOUT_MS) } + } + + @Test + fun testAllExpectOverloads() { + // This test should never run, it only checks that all overloads exist and build + assumeTrue(false) + val hn = object : TestableNetworkCallback.HasNetwork { override val network = ANY_NETWORK } + + // Method with all arguments (version that takes a Network) + mCallback.expect(AVAILABLE, ANY_NETWORK, 10, "error") { true } + + // Java overloads omitting one argument. One line for omitting each argument, in positional + // order. Versions that take a Network. + mCallback.expect(AVAILABLE, 10, "error") { true } + mCallback.expect(AVAILABLE, ANY_NETWORK, "error") { true } + mCallback.expect(AVAILABLE, ANY_NETWORK, 10) { true } + mCallback.expect(AVAILABLE, ANY_NETWORK, 10, "error") + + // Java overloads for omitting two arguments. One line for omitting each pair of arguments. + // Versions that take a Network. + mCallback.expect(AVAILABLE, "error") { true } + mCallback.expect(AVAILABLE, 10) { true } + mCallback.expect(AVAILABLE, 10, "error") + mCallback.expect(AVAILABLE, ANY_NETWORK) { true } + mCallback.expect(AVAILABLE, ANY_NETWORK, "error") + mCallback.expect(AVAILABLE, ANY_NETWORK, 10) + + // Java overloads for omitting three arguments. One line for each remaining argument. + // Versions that take a Network. + mCallback.expect(AVAILABLE) { true } + mCallback.expect(AVAILABLE, "error") + mCallback.expect(AVAILABLE, 10) + mCallback.expect(AVAILABLE, ANY_NETWORK) + + // Java overload for omitting all four arguments. + mCallback.expect(AVAILABLE) + + // Same orders as above, but versions that take a HasNetwork. Except overloads that + // were already tested because they omitted the Network argument + mCallback.expect(AVAILABLE, hn, 10, "error") { true } + mCallback.expect(AVAILABLE, hn, "error") { true } + mCallback.expect(AVAILABLE, hn, 10) { true } + mCallback.expect(AVAILABLE, hn, 10, "error") + + mCallback.expect(AVAILABLE, hn) { true } + mCallback.expect(AVAILABLE, hn, "error") + mCallback.expect(AVAILABLE, hn, 10) + + mCallback.expect(AVAILABLE, hn) + + // Same as above but for reified versions. + mCallback.expect<Available>(ANY_NETWORK, 10, "error") { true } + mCallback.expect<Available>(timeoutMs = 10, errorMsg = "error") { true } + mCallback.expect<Available>(network = ANY_NETWORK, errorMsg = "error") { true } + mCallback.expect<Available>(network = ANY_NETWORK, timeoutMs = 10) { true } + mCallback.expect<Available>(network = ANY_NETWORK, timeoutMs = 10, errorMsg = "error") + + mCallback.expect<Available>(errorMsg = "error") { true } + mCallback.expect<Available>(timeoutMs = 10) { true } + mCallback.expect<Available>(timeoutMs = 10, errorMsg = "error") + mCallback.expect<Available>(network = ANY_NETWORK) { true } + mCallback.expect<Available>(network = ANY_NETWORK, errorMsg = "error") + mCallback.expect<Available>(network = ANY_NETWORK, timeoutMs = 10) + + mCallback.expect<Available> { true } + mCallback.expect<Available>(errorMsg = "error") + mCallback.expect<Available>(timeoutMs = 10) + mCallback.expect<Available>(network = ANY_NETWORK) + mCallback.expect<Available>() + + mCallback.expect<Available>(hn, 10, "error") { true } + mCallback.expect<Available>(network = hn, errorMsg = "error") { true } + mCallback.expect<Available>(network = hn, timeoutMs = 10) { true } + mCallback.expect<Available>(network = hn, timeoutMs = 10, errorMsg = "error") + + mCallback.expect<Available>(network = hn) { true } + mCallback.expect<Available>(network = hn, errorMsg = "error") + mCallback.expect<Available>(network = hn, timeoutMs = 10) + + mCallback.expect<Available>(network = hn) + } + + @Test + fun testPoll() { + assertNull(mCallback.poll(SHORT_TIMEOUT_MS)) + TNCInterpreter.interpretTestSpec(initial = mCallback, lineShift = 1, + threadTransform = { cb -> cb.createLinkedCopy() }, spec = """ + sleep; onAvailable(133) | poll(2) = Available(133) time 1..4 + | poll(1) = null + onCapabilitiesChanged(108) | poll(1) = CapabilitiesChanged(108) time 0..3 + onBlockedStatus(199) | poll(1) = BlockedStatus(199) time 0..3 + """) + } + + @Test + fun testPollOrThrow() { + assertFails { mCallback.pollOrThrow(SHORT_TIMEOUT_MS) } + TNCInterpreter.interpretTestSpec(initial = mCallback, lineShift = 1, + threadTransform = { cb -> cb.createLinkedCopy() }, spec = """ + sleep; onAvailable(133) | pollOrThrow(2) = Available(133) time 1..4 + | pollOrThrow(1) fails + onCapabilitiesChanged(108) | pollOrThrow(1) = CapabilitiesChanged(108) time 0..3 + onBlockedStatus(199) | pollOrThrow(1) = BlockedStatus(199) time 0..3 + """) + } + + @Test + fun testEventuallyExpect() { + // TODO: Current test does not verify the inline one. Also verify the behavior after + // aligning two eventuallyExpect() + val net1 = Network(100) + val net2 = Network(101) + mCallback.onAvailable(net1) + mCallback.onCapabilitiesChanged(net1, NetworkCapabilities()) + mCallback.onLinkPropertiesChanged(net1, LinkProperties()) + mCallback.eventuallyExpect(LINK_PROPERTIES_CHANGED) { + net1.equals(it.network) + } + // No further new callback. Expect no callback. + assertFails { mCallback.eventuallyExpect(LINK_PROPERTIES_CHANGED, SHORT_TIMEOUT_MS) } + + // Verify no predicate set. + mCallback.onAvailable(net2) + mCallback.onLinkPropertiesChanged(net2, LinkProperties()) + mCallback.onBlockedStatusChanged(net1, false) + mCallback.eventuallyExpect(BLOCKED_STATUS) { net1.equals(it.network) } + // Verify no callback received if the callback does not happen. + assertFails { mCallback.eventuallyExpect(LOSING, SHORT_TIMEOUT_MS) } + } + + @Test + fun testEventuallyExpectOnMultiThreads() { + TNCInterpreter.interpretTestSpec(initial = mCallback, lineShift = 1, + threadTransform = { cb -> cb.createLinkedCopy() }, spec = """ + onAvailable(100) | eventually(CapabilitiesChanged(100), 1) fails + sleep ; onCapabilitiesChanged(100) | eventually(CapabilitiesChanged(100), 3) + onAvailable(101) ; onBlockedStatus(101) | eventually(BlockedStatus(100), 2) fails + onSuspended(100) ; sleep ; onLost(100) | eventually(Lost(100), 3) + """) + } +} + +private object TNCInterpreter : ConcurrentInterpreter<TestableNetworkCallback>(interpretTable) + +val EntryList = CallbackEntry::class.sealedSubclasses.map { it.simpleName }.joinToString("|") +private fun callbackEntryFromString(name: String): KClass<out CallbackEntry> { + return CallbackEntry::class.sealedSubclasses.first { it.simpleName == name } +} + +@SuppressLint("NewApi") // Uses hidden APIs, which the linter would identify as missing APIs. +private val interpretTable = listOf<InterpretMatcher<TestableNetworkCallback>>( + // Interpret "Available(xx)" as "call to onAvailable with netId xx", and likewise for + // all callback types. This is implemented above by enumerating the subclasses of + // CallbackEntry and reading their simpleName. + Regex("""(.*)\s+=\s+($EntryList)\((\d+)\)""") to { i, cb, t -> + val record = i.interpret(t.strArg(1), cb) + assertTrue(callbackEntryFromString(t.strArg(2)).isInstance(record)) + // Strictly speaking testing for is CallbackEntry is useless as it's been tested above + // but the compiler can't figure things out from the isInstance call. It does understand + // from the assertTrue(is CallbackEntry) that this is true, which allows to access + // the 'network' member below. + assertTrue(record is CallbackEntry) + assertEquals(record.network.netId, t.intArg(3)) + }, + // Interpret "onAvailable(xx)" as calling "onAvailable" with a netId of xx, and likewise for + // all callback types. NetworkCapabilities and LinkProperties just get an empty object + // as their argument. Losing gets the default linger timer. Blocked gets false. + Regex("""on($EntryList)\((\d+)\)""") to { i, cb, t -> + val net = Network(t.intArg(2)) + when (t.strArg(1)) { + "Available" -> cb.onAvailable(net) + // PreCheck not used in tests. Add it here if it becomes useful. + "CapabilitiesChanged" -> cb.onCapabilitiesChanged(net, NetworkCapabilities()) + "LinkPropertiesChanged" -> cb.onLinkPropertiesChanged(net, LinkProperties()) + "Suspended" -> cb.onNetworkSuspended(net) + "Resumed" -> cb.onNetworkResumed(net) + "Losing" -> cb.onLosing(net, DEFAULT_LINGER_DELAY_MS) + "Lost" -> cb.onLost(net) + "Unavailable" -> cb.onUnavailable() + "BlockedStatus" -> cb.onBlockedStatusChanged(net, false) + else -> fail("Unknown callback type") + } + }, + Regex("""poll\((\d+)\)""") to { i, cb, t -> cb.poll(t.timeArg(1)) }, + Regex("""pollOrThrow\((\d+)\)""") to { i, cb, t -> cb.pollOrThrow(t.timeArg(1)) }, + // Interpret "eventually(Available(xx), timeout)" as calling eventuallyExpect that expects + // CallbackEntry.AVAILABLE with netId of xx within timeout*INTERPRET_TIME_UNIT timeout, and + // likewise for all callback types. + Regex("""eventually\(($EntryList)\((\d+)\),\s+(\d+)\)""") to { i, cb, t -> + val net = Network(t.intArg(2)) + val timeout = t.timeArg(3) + when (t.strArg(1)) { + "Available" -> cb.eventuallyExpect(AVAILABLE, timeout) { net == it.network } + "Suspended" -> cb.eventuallyExpect(SUSPENDED, timeout) { net == it.network } + "Resumed" -> cb.eventuallyExpect(RESUMED, timeout) { net == it.network } + "Losing" -> cb.eventuallyExpect(LOSING, timeout) { net == it.network } + "Lost" -> cb.eventuallyExpect(LOST, timeout) { net == it.network } + "Unavailable" -> cb.eventuallyExpect(UNAVAILABLE, timeout) { net == it.network } + "BlockedStatus" -> cb.eventuallyExpect(BLOCKED_STATUS, timeout) { net == it.network } + "CapabilitiesChanged" -> + cb.eventuallyExpect(NETWORK_CAPS_UPDATED, timeout) { net == it.network } + "LinkPropertiesChanged" -> + cb.eventuallyExpect(LINK_PROPERTIES_CHANGED, timeout) { net == it.network } + else -> fail("Unknown callback type") + } + } +) diff --git a/common/tests/unit/src/com/android/testutils/TestableNetworkCallbackTestJava.java b/common/tests/unit/src/com/android/testutils/TestableNetworkCallbackTestJava.java new file mode 100644 index 00000000..4570d0ae --- /dev/null +++ b/common/tests/unit/src/com/android/testutils/TestableNetworkCallbackTestJava.java @@ -0,0 +1,76 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.testutils; + +import static com.android.testutils.RecorderCallback.CallbackEntry.AVAILABLE; +import static com.android.testutils.TestableNetworkCallbackKt.anyNetwork; + +import static org.junit.Assume.assumeTrue; + +import org.junit.Test; + +public class TestableNetworkCallbackTestJava { + @Test + void testAllExpectOverloads() { + // This test should never run, it only checks that all overloads exist and build + assumeTrue(false); + final TestableNetworkCallback callback = new TestableNetworkCallback(); + TestableNetworkCallback.HasNetwork hn = TestableNetworkCallbackKt::anyNetwork; + + // Method with all arguments (version that takes a Network) + callback.expect(AVAILABLE, anyNetwork(), 10, "error", cb -> true); + + // Overloads omitting one argument. One line for omitting each argument, in positional + // order. Versions that take a Network. + callback.expect(AVAILABLE, 10, "error", cb -> true); + callback.expect(AVAILABLE, anyNetwork(), "error", cb -> true); + callback.expect(AVAILABLE, anyNetwork(), 10, cb -> true); + callback.expect(AVAILABLE, anyNetwork(), 10, "error"); + + // Overloads for omitting two arguments. One line for omitting each pair of arguments. + // Versions that take a Network. + callback.expect(AVAILABLE, "error", cb -> true); + callback.expect(AVAILABLE, 10, cb -> true); + callback.expect(AVAILABLE, 10, "error"); + callback.expect(AVAILABLE, anyNetwork(), cb -> true); + callback.expect(AVAILABLE, anyNetwork(), "error"); + callback.expect(AVAILABLE, anyNetwork(), 10); + + // Overloads for omitting three arguments. One line for each remaining argument. + // Versions that take a Network. + callback.expect(AVAILABLE, cb -> true); + callback.expect(AVAILABLE, "error"); + callback.expect(AVAILABLE, 10); + callback.expect(AVAILABLE, anyNetwork()); + + // Java overload for omitting all four arguments. + callback.expect(AVAILABLE); + + // Same orders as above, but versions that take a HasNetwork. Except overloads that + // were already tested because they omitted the Network argument + callback.expect(AVAILABLE, hn, 10, "error", cb -> true); + callback.expect(AVAILABLE, hn, "error", cb -> true); + callback.expect(AVAILABLE, hn, 10, cb -> true); + callback.expect(AVAILABLE, hn, 10, "error"); + + callback.expect(AVAILABLE, hn, cb -> true); + callback.expect(AVAILABLE, hn, "error"); + callback.expect(AVAILABLE, hn, 10); + + callback.expect(AVAILABLE, hn); + } +} diff --git a/common/testutils/devicetests/com/android/testutils/ConcurrentInterpreter.kt b/common/testutils/devicetests/com/android/testutils/ConcurrentInterpreter.kt index cbdc0173..9e72f4b4 100644 --- a/common/testutils/devicetests/com/android/testutils/ConcurrentInterpreter.kt +++ b/common/testutils/devicetests/com/android/testutils/ConcurrentInterpreter.kt @@ -13,7 +13,7 @@ import kotlin.test.assertTrue typealias InterpretMatcher<T> = Pair<Regex, (ConcurrentInterpreter<T>, T, MatchResult) -> Any?> // The default unit of time for interpreted tests -val INTERPRET_TIME_UNIT = 40L // ms +const val INTERPRET_TIME_UNIT = 60L // ms /** * A small interpreter for testing parallel code. @@ -40,10 +40,7 @@ val INTERPRET_TIME_UNIT = 40L // ms * Some expressions already exist by default and can be used by all interpreters. Refer to * getDefaultInstructions() below for a list and documentation. */ -open class ConcurrentInterpreter<T>( - localInterpretTable: List<InterpretMatcher<T>>, - val interpretTimeUnit: Long = INTERPRET_TIME_UNIT -) { +open class ConcurrentInterpreter<T>(localInterpretTable: List<InterpretMatcher<T>>) { private val interpretTable: List<InterpretMatcher<T>> = localInterpretTable + getDefaultInstructions() // The last time the thread became blocked, with base System.currentTimeMillis(). This should @@ -211,7 +208,7 @@ private fun <T> getDefaultInstructions() = listOf<InterpretMatcher<T>>( }, // Interpret sleep. Optional argument for the count, in INTERPRET_TIME_UNIT units. Regex("""sleep(\((\d+)\))?""") to { i, t, r -> - SystemClock.sleep(if (r.strArg(2).isEmpty()) i.interpretTimeUnit else r.timeArg(2)) + SystemClock.sleep(if (r.strArg(2).isEmpty()) INTERPRET_TIME_UNIT else r.timeArg(2)) }, Regex("""(.*)\s*fails""") to { i, t, r -> assertFails { i.interpret(r.strArg(1), t) } diff --git a/common/testutils/devicetests/com/android/testutils/ConnectUtil.kt b/common/testutils/devicetests/com/android/testutils/ConnectUtil.kt index fc951d86..7b5ad010 100644 --- a/common/testutils/devicetests/com/android/testutils/ConnectUtil.kt +++ b/common/testutils/devicetests/com/android/testutils/ConnectUtil.kt @@ -63,6 +63,7 @@ class ConnectUtil(private val context: Context) { try { val connInfo = wifiManager.connectionInfo + Log.d(TAG, "connInfo=" + connInfo) if (connInfo == null || connInfo.networkId == -1) { clearWifiBlocklist() val pfd = getInstrumentation().uiAutomation.executeShellCommand("svc wifi enable") @@ -75,7 +76,7 @@ class ConnectUtil(private val context: Context) { timeoutMs = WIFI_CONNECT_TIMEOUT_MS) assertNotNull(cb, "Could not connect to a wifi access point within " + - "$WIFI_CONNECT_INTERVAL_MS ms. Check that the test device has a wifi network " + + "$WIFI_CONNECT_TIMEOUT_MS ms. Check that the test device has a wifi network " + "configured, and that the test access point is functioning properly.") return cb.network } finally { diff --git a/common/testutils/devicetests/com/android/testutils/DeviceInfoUtils.java b/common/testutils/devicetests/com/android/testutils/DeviceInfoUtils.java index ea89edaa..ce55fdc1 100644 --- a/common/testutils/devicetests/com/android/testutils/DeviceInfoUtils.java +++ b/common/testutils/devicetests/com/android/testutils/DeviceInfoUtils.java @@ -20,6 +20,7 @@ import android.os.VintfRuntimeInfo; import android.text.TextUtils; import android.util.Pair; +import java.util.Objects; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -128,7 +129,7 @@ public class DeviceInfoUtils { final Pair<Integer, Integer> v1 = getMajorMinorVersion(s1); final Pair<Integer, Integer> v2 = getMajorMinorVersion(s2); - if (v1.first == v2.first) { + if (Objects.equals(v1.first, v2.first)) { return Integer.compare(v1.second, v2.second); } else { return Integer.compare(v1.first, v2.first); diff --git a/common/testutils/devicetests/com/android/testutils/TestNetworkTracker.kt b/common/testutils/devicetests/com/android/testutils/TestNetworkTracker.kt index 5ae2439d..b743b6c9 100644 --- a/common/testutils/devicetests/com/android/testutils/TestNetworkTracker.kt +++ b/common/testutils/devicetests/com/android/testutils/TestNetworkTracker.kt @@ -111,12 +111,12 @@ class TestNetworkTracker internal constructor( val tnm: TestNetworkManager, val lp: LinkProperties?, setupTimeoutMs: Long -) { +) : TestableNetworkCallback.HasNetwork { private val cm = context.getSystemService(ConnectivityManager::class.java) private val binder = Binder() private val networkCallback: NetworkCallback - val network: Network + override val network: Network val testIface: TestNetworkInterface init { diff --git a/common/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt b/common/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt index dffdbe8b..406a179c 100644 --- a/common/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt +++ b/common/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt @@ -36,15 +36,12 @@ import com.android.testutils.RecorderCallback.CallbackEntry.Unavailable import kotlin.reflect.KClass import kotlin.test.assertEquals import kotlin.test.assertNotNull -import kotlin.test.assertTrue import kotlin.test.fail object NULL_NETWORK : Network(-1) object ANY_NETWORK : Network(-2) fun anyNetwork() = ANY_NETWORK -private val Int.capabilityName get() = NetworkCapabilities.capabilityNameOf(this) - open class RecorderCallback private constructor( private val backingRecord: ArrayTrackRecord<CallbackEntry> ) : NetworkCallback() { @@ -168,16 +165,22 @@ open class RecorderCallback private constructor( } } -private const val DEFAULT_TIMEOUT = 200L // ms +private const val DEFAULT_TIMEOUT = 30_000L // ms +private const val DEFAULT_NO_CALLBACK_TIMEOUT = 200L // ms open class TestableNetworkCallback private constructor( src: TestableNetworkCallback?, - val defaultTimeoutMs: Long = DEFAULT_TIMEOUT + val defaultTimeoutMs: Long = DEFAULT_TIMEOUT, + val defaultNoCallbackTimeoutMs: Long = DEFAULT_NO_CALLBACK_TIMEOUT ) : RecorderCallback(src) { @JvmOverloads - constructor(timeoutMs: Long = DEFAULT_TIMEOUT): this(null, timeoutMs) + constructor( + timeoutMs: Long = DEFAULT_TIMEOUT, + noCallbackTimeoutMs: Long = DEFAULT_NO_CALLBACK_TIMEOUT + ): this(null, timeoutMs, noCallbackTimeoutMs) - fun createLinkedCopy() = TestableNetworkCallback(this, defaultTimeoutMs) + fun createLinkedCopy() = TestableNetworkCallback( + this, defaultTimeoutMs, defaultNoCallbackTimeoutMs) // The last available network, or null if any network was lost since the last call to // onAvailable. TODO : fix this by fixing the tests that rely on this behavior @@ -187,14 +190,189 @@ open class TestableNetworkCallback private constructor( else -> null } - fun pollForNextCallback(timeoutMs: Long = defaultTimeoutMs): CallbackEntry { - return history.poll(timeoutMs) ?: fail("Did not receive callback after ${timeoutMs}ms") - } + /** + * Get the next callback or null if timeout. + * + * With no argument, this method waits out the default timeout. To wait forever, pass + * Long.MAX_VALUE. + */ + @JvmOverloads + fun poll(timeoutMs: Long = defaultTimeoutMs): CallbackEntry? = history.poll(timeoutMs) + + /** + * Get the next callback or throw if timeout. + * + * With no argument, this method waits out the default timeout. To wait forever, pass + * Long.MAX_VALUE. + */ + @JvmOverloads + fun pollOrThrow( + timeoutMs: Long = defaultTimeoutMs, + errorMsg: String = "Did not receive callback after $timeoutMs" + ): CallbackEntry = poll(timeoutMs) ?: fail(errorMsg) + + /***** + * expect family of methods. + * These methods fetch the next callback and assert it matches the conditions : type, + * passed predicate. If no callback is received within the timeout, these methods fail. + */ + @JvmOverloads + fun <T : CallbackEntry> expect( + type: KClass<T>, + network: Network = ANY_NETWORK, + timeoutMs: Long = defaultTimeoutMs, + errorMsg: String? = null, + test: (T) -> Boolean = { true } + ) = expect<CallbackEntry>(network, timeoutMs, errorMsg) { + test(it as? T ?: fail("Expected callback ${type.simpleName}, got $it")) + } as T + + @JvmOverloads + fun <T : CallbackEntry> expect( + type: KClass<T>, + network: HasNetwork, + timeoutMs: Long = defaultTimeoutMs, + errorMsg: String? = null, + test: (T) -> Boolean = { true } + ) = expect(type, network.network, timeoutMs, errorMsg, test) + + // Java needs an explicit overload to let it omit arguments in the middle, so define these + // here. Note that @JvmOverloads give us the versions without the last arguments too, so + // there is no need to explicitly define versions without the test predicate. + // Without |network| + @JvmOverloads + fun <T : CallbackEntry> expect( + type: KClass<T>, + timeoutMs: Long, + errorMsg: String?, + test: (T) -> Boolean = { true } + ) = expect(type, ANY_NETWORK, timeoutMs, errorMsg, test) + + // Without |timeout|, in Network and HasNetwork versions + @JvmOverloads + fun <T : CallbackEntry> expect( + type: KClass<T>, + network: Network, + errorMsg: String?, + test: (T) -> Boolean = { true } + ) = expect(type, network, defaultTimeoutMs, errorMsg, test) + + @JvmOverloads + fun <T : CallbackEntry> expect( + type: KClass<T>, + network: HasNetwork, + errorMsg: String?, + test: (T) -> Boolean = { true } + ) = expect(type, network.network, defaultTimeoutMs, errorMsg, test) + + // Without |errorMsg|, in Network and HasNetwork versions + @JvmOverloads + fun <T : CallbackEntry> expect( + type: KClass<T>, + network: Network, + timeoutMs: Long, + test: (T) -> Boolean + ) = expect(type, network, timeoutMs, null, test) + + @JvmOverloads + fun <T : CallbackEntry> expect( + type: KClass<T>, + network: HasNetwork, + timeoutMs: Long, + test: (T) -> Boolean + ) = expect(type, network.network, timeoutMs, null, test) + + // Without |network| or |timeout| + @JvmOverloads + fun <T : CallbackEntry> expect( + type: KClass<T>, + errorMsg: String?, + test: (T) -> Boolean = { true } + ) = expect(type, ANY_NETWORK, defaultTimeoutMs, errorMsg, test) + + // Without |network| or |errorMsg| + @JvmOverloads + fun <T : CallbackEntry> expect( + type: KClass<T>, + timeoutMs: Long, + test: (T) -> Boolean = { true } + ) = expect(type, ANY_NETWORK, timeoutMs, null, test) + + // Without |timeout| or |errorMsg|, in Network and HasNetwork versions + @JvmOverloads + fun <T : CallbackEntry> expect( + type: KClass<T>, + network: Network, + test: (T) -> Boolean + ) = expect(type, network, defaultTimeoutMs, null, test) + + @JvmOverloads + fun <T : CallbackEntry> expect( + type: KClass<T>, + network: HasNetwork, + test: (T) -> Boolean + ) = expect(type, network.network, defaultTimeoutMs, null, test) + + // Without |network| or |timeout| or |errorMsg| + @JvmOverloads + fun <T : CallbackEntry> expect( + type: KClass<T>, + test: (T) -> Boolean + ) = expect(type, ANY_NETWORK, defaultTimeoutMs, null, test) + + // Kotlin reified versions. Don't call methods above, or the predicate would need to be noinline + inline fun <reified T : CallbackEntry> expect( + network: Network = ANY_NETWORK, + timeoutMs: Long = defaultTimeoutMs, + errorMsg: String? = null, + test: (T) -> Boolean = { true } + ) = pollOrThrow(timeoutMs).also { + if (it !is T) fail("Expected callback ${T::class.simpleName}, got $it") + if (ANY_NETWORK !== network && it.network != network) { + fail("Expected network $network for callback : $it") + } + if (!test(it)) { + fail("${errorMsg ?: "Callback doesn't match predicate"} : $it") + } + } as T + + inline fun <reified T : CallbackEntry> expect( + network: HasNetwork, + timeoutMs: Long = defaultTimeoutMs, + errorMsg: String? = null, + test: (T) -> Boolean = { true } + ) = expect(network.network, timeoutMs, errorMsg, test) + + // TODO : remove all expectCallback and expectCallbackThat methods after all callers have been + // migrated to expect(). + inline fun <reified T : CallbackEntry> expectCallback( + network: Network = ANY_NETWORK, + timeoutMs: Long = defaultTimeoutMs + ): T = expect(network, timeoutMs) + + @JvmOverloads + open fun <T : CallbackEntry> expectCallback( + type: KClass<T>, + n: Network?, + timeoutMs: Long = defaultTimeoutMs + ) = expect(type, n ?: ANY_NETWORK, timeoutMs) + + @JvmOverloads + open fun <T : CallbackEntry> expectCallback( + type: KClass<T>, + n: HasNetwork?, + timeoutMs: Long = defaultTimeoutMs + ) = expect(type, n?.network ?: ANY_NETWORK, timeoutMs) + + fun expectCallbackThat( + timeoutMs: Long = defaultTimeoutMs, + valid: (CallbackEntry) -> Boolean + ) = expect(timeoutMs = timeoutMs, test = valid) // Make open for use in ConnectivityServiceTest which is the only one knowing its handlers. // TODO : remove the necessity to overload this, remove the open qualifier, and give a // default argument to assertNoCallback instead, possibly with @JvmOverloads if necessary. - open fun assertNoCallback() = assertNoCallback(defaultTimeoutMs) + open fun assertNoCallback() = assertNoCallback(defaultNoCallbackTimeoutMs) fun assertNoCallback(timeoutMs: Long) { val cb = history.poll(timeoutMs) @@ -202,7 +380,7 @@ open class TestableNetworkCallback private constructor( } fun assertNoCallbackThat( - timeoutMs: Long = defaultTimeoutMs, + timeoutMs: Long = defaultNoCallbackTimeoutMs, valid: (CallbackEntry) -> Boolean ) { val cb = history.poll(timeoutMs) { valid(it) }.let { @@ -210,19 +388,6 @@ open class TestableNetworkCallback private constructor( } } - // Expects a callback of the specified type on the specified network within the timeout. - // If no callback arrives, or a different callback arrives, fail. Returns the callback. - inline fun <reified T : CallbackEntry> expectCallback( - network: Network = ANY_NETWORK, - timeoutMs: Long = defaultTimeoutMs - ): T = pollForNextCallback(timeoutMs).let { - if (it !is T || (ANY_NETWORK !== network && it.network != network)) { - fail("Unexpected callback : $it, expected ${T::class} with Network[$network]") - } else { - it - } - } - // Expects a callback of the specified type matching the predicate within the timeout. // Any callback that doesn't match the predicate will be skipped. Fails only if // no matching callback is received within the timeout. @@ -249,30 +414,19 @@ open class TestableNetworkCallback private constructor( crossinline predicate: (T) -> Boolean = { true } ) = history.poll(timeoutMs, from) { it is T && predicate(it) } as T? - fun expectCallbackThat( - timeoutMs: Long = defaultTimeoutMs, - valid: (CallbackEntry) -> Boolean - ) = pollForNextCallback(timeoutMs).also { assertTrue(valid(it), "Unexpected callback : $it") } - - fun expectCapabilitiesThat( + inline fun expectCapabilitiesThat( net: Network, tmt: Long = defaultTimeoutMs, valid: (NetworkCapabilities) -> Boolean - ): CapabilitiesChanged { - return expectCallback<CapabilitiesChanged>(net, tmt).also { - assertTrue(valid(it.caps), "Capabilities don't match expectations ${it.caps}") - } - } + ): CapabilitiesChanged = + expect(net, tmt, "Capabilities don't match expectations") { valid(it.caps) } - fun expectLinkPropertiesThat( + inline fun expectLinkPropertiesThat( net: Network, tmt: Long = defaultTimeoutMs, valid: (LinkProperties) -> Boolean - ): LinkPropertiesChanged { - return expectCallback<LinkPropertiesChanged>(net, tmt).also { - assertTrue(valid(it.lp), "LinkProperties don't match expectations ${it.lp}") - } - } + ): LinkPropertiesChanged = + expect(net, tmt, "LinkProperties don't match expectations") { valid(it.lp) } // Expects onAvailable and the callbacks that follow it. These are: // - onSuspended, iff the network was suspended when the callbacks fire. @@ -313,16 +467,16 @@ open class TestableNetworkCallback private constructor( validated: Boolean?, tmt: Long ) { - expectCallback<Available>(net, tmt) + expect<Available>(net, tmt) if (suspended) { - expectCallback<Suspended>(net, tmt) + expect<Suspended>(net, tmt) } expectCapabilitiesThat(net, tmt) { validated == null || validated == it.hasCapability( NET_CAPABILITY_VALIDATED ) } - expectCallback<LinkPropertiesChanged>(net, tmt) + expect<LinkPropertiesChanged>(net, tmt) } // Backward compatibility for existing Java code. Use named arguments instead and remove all @@ -333,17 +487,15 @@ open class TestableNetworkCallback private constructor( tmt: Long = defaultTimeoutMs ) = expectAvailableCallbacks(net, suspended = true, validated = validated, tmt = tmt) - fun expectBlockedStatusCallback(blocked: Boolean, net: Network, tmt: Long = defaultTimeoutMs) { - expectCallback<BlockedStatus>(net, tmt).also { - assertEquals(blocked, it.blocked, "Unexpected blocked status ${it.blocked}") - } - } + fun expectBlockedStatusCallback(blocked: Boolean, net: Network, tmt: Long = defaultTimeoutMs) = + expect<BlockedStatus>(net, tmt, "Unexpected blocked status") { + it.blocked == blocked + } - fun expectBlockedStatusCallback(blocked: Int, net: Network, tmt: Long = defaultTimeoutMs) { - expectCallback<BlockedStatusInt>(net, tmt).also { - assertEquals(blocked, it.blocked, "Unexpected blocked status ${it.blocked}") - } - } + fun expectBlockedStatusCallback(blocked: Int, net: Network, tmt: Long = defaultTimeoutMs) = + expect<BlockedStatusInt>(net, tmt, "Unexpected blocked status") { + it.blocked == blocked + } // Expects the available callbacks (where the onCapabilitiesChanged must contain the // VALIDATED capability), plus another onCapabilitiesChanged which is identical to the @@ -353,7 +505,7 @@ open class TestableNetworkCallback private constructor( val mark = history.mark expectAvailableCallbacks(net, tmt = tmt) val firstCaps = history.poll(tmt, mark) { it is CapabilitiesChanged } - assertEquals(firstCaps, expectCallback<CapabilitiesChanged>(net, tmt)) + assertEquals(firstCaps, expect<CapabilitiesChanged>(net, tmt)) } // Expects the available callbacks where the onCapabilitiesChanged must not have validated, @@ -381,26 +533,6 @@ open class TestableNetworkCallback private constructor( val network: Network } - @JvmOverloads - open fun <T : CallbackEntry> expectCallback( - type: KClass<T>, - n: Network?, - timeoutMs: Long = defaultTimeoutMs - ) = pollForNextCallback(timeoutMs).also { - val network = n ?: NULL_NETWORK - // TODO : remove this .java access if the tests ever use kotlin-reflect. At the time of - // this writing this would be the only use of this library in the tests. - assertTrue(type.java.isInstance(it) && (ANY_NETWORK === n || it.network == network), - "Unexpected callback : $it, expected ${type.java} with Network[$network]") - } as T - - @JvmOverloads - open fun <T : CallbackEntry> expectCallback( - type: KClass<T>, - n: HasNetwork?, - timeoutMs: Long = defaultTimeoutMs - ) = expectCallback(type, n?.network, timeoutMs) - fun expectAvailableCallbacks( n: HasNetwork, suspended: Boolean, |