summaryrefslogtreecommitdiff
path: root/net/test/netlink.py
blob: b5efe11c82f6afece819f3eaff0359f7a21435d2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
#!/usr/bin/python3
#
# Copyright 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Partial Python implementation of iproute functionality."""

# pylint: disable=g-bad-todo

import os
import socket
import struct
import sys

import cstruct
import util

### Base netlink constants. See include/uapi/linux/netlink.h.
# Netlink protocol numbers (e.g. for the `family` argument of NetlinkSocket,
# which passes it as the protocol to socket.socket()).
NETLINK_ROUTE = 0
NETLINK_SOCK_DIAG = 4
NETLINK_XFRM = 6
NETLINK_GENERIC = 16

# Request constants: bits for the NLMsgHdr `flags` field.
NLM_F_REQUEST = 1
NLM_F_ACK = 4
NLM_F_REPLACE = 0x100
NLM_F_EXCL = 0x200
NLM_F_CREATE = 0x400
# Note that this value reuses the NLM_F_REPLACE | NLM_F_EXCL bits; see
# netlink.h for which flags apply to which kind of request.
NLM_F_DUMP = 0x300

# Message types.
NLMSG_ERROR = 2
NLMSG_DONE = 3

# Data structure formats.
# These aren't constants, they're classes. So, pylint: disable=invalid-name
# Netlink message header: total message length, type, flags, sequence number,
# and sender port ID ("pid").
NLMsgHdr = cstruct.Struct("NLMsgHdr", "=LHHLL", "length type flags seq pid")
# Body of an NLMSG_ERROR message; NetlinkSocket._ParseAck negates `error` to
# obtain an errno value (0 means the message is an ACK).
NLMsgErr = cstruct.Struct("NLMsgErr", "=i", "error")
# Netlink attribute header. nla_len counts this header plus the attribute data,
# excluding any alignment padding that follows.
NLAttr = cstruct.Struct("NLAttr", "=HH", "nla_len nla_type")

# Alignment / padding.
# Attribute data is padded with zero bytes to a multiple of this many bytes.
NLA_ALIGNTO = 4

# List of attributes that can appear more than once in a given netlink message.
# These can appear more than once but don't seem to contain any data.
# NetlinkSocket._ParseAttributes raises ValueError for any other duplicate.
DUP_ATTRS_OK = ["INET_DIAG_NONE", "IFLA_PAD"]


def MakeConstantPrefixes(prefixes):
  """Orders constant-name prefixes from longest to shortest.

  NetlinkSocket._GetConstantName relies on this ordering: the first prefix in
  the result that matches a constant name is also the longest one that does.

  Args:
    prefixes: An iterable of prefix strings, e.g. ["IFA_", "IFA_F_"].

  Returns:
    A new list of the prefixes, longest first. Prefixes of equal length keep
    their original relative order.
  """
  return sorted(prefixes, key=lambda candidate: -len(candidate))


class NetlinkSocket(object):
  """A basic netlink socket object.

  Wraps a raw AF_NETLINK socket connected to the kernel and provides helpers
  to pack netlink attributes, send requests, and parse replies. _Decode and
  MaybeDebugCommand are no-op hooks intended to be overridden by subclasses.
  """

  BUFSIZE = 65536
  DEBUG = False
  # List of netlink messages to print, e.g., [], ["NEIGH", "ROUTE"], or ["ALL"]
  NL_DEBUG = []

  def _Debug(self, s):
    """Prints `s` when DEBUG is enabled."""
    if self.DEBUG:
      print(s)

  def _NlAttr(self, nla_type, data):
    """Packs one netlink attribute: NLAttr header, data, then zero padding.

    Args:
      nla_type: An integer, the attribute type.
      data: A bytes object, the attribute payload.

    Returns:
      The packed attribute as bytes.
    """
    assert isinstance(data, bytes)
    datalen = len(data)
    # Pad the data if it's not a multiple of NLA_ALIGNTO bytes long.
    padding = b"\x00" * util.GetPadLength(NLA_ALIGNTO, datalen)
    # nla_len covers the header and the data, but not the padding.
    nla_len = datalen + len(NLAttr)
    return NLAttr((nla_len, nla_type)).Pack() + data + padding

  def _NlAttrIPAddress(self, nla_type, family, address):
    """Packs an IP address attribute from its textual form."""
    return self._NlAttr(nla_type, socket.inet_pton(family, address))

  def _NlAttrStr(self, nla_type, value):
    """Packs a NUL-terminated, UTF-8 encoded string attribute."""
    value = value + "\x00"
    return self._NlAttr(nla_type, value.encode("UTF-8"))

  def _NlAttrU32(self, nla_type, value):
    """Packs a 32-bit unsigned integer attribute in native byte order."""
    return self._NlAttr(nla_type, struct.pack("=I", value))

  @staticmethod
  def _GetConstantName(module, value, prefix):
    """Returns the name of the constant in `module` that equals `value`.

    Args:
      module: A string, the name of a module present in sys.modules.
      value: The constant value to look up.
      prefix: A string such as "IFA_"; only names with this prefix match.

    Returns:
      The first (alphabetically) matching uppercase name, or `value` itself
      if no constant matches.
    """

    def FirstMatching(name, prefixlist):
      """Returns the first entry of `prefixlist` that `name` starts with."""
      for candidate in prefixlist:
        if name.startswith(candidate):
          return candidate
      return None

    thismodule = sys.modules[module]
    constant_prefixes = getattr(thismodule, "CONSTANT_PREFIXES", [])
    for name in dir(thismodule):
      # Check the cheap name test first so values are only compared against
      # constants, not against every function/class/module in the namespace.
      if not name.isupper() or value != getattr(thismodule, name):
        continue
      # If the module explicitly specifies prefixes, only return this name if
      # the passed-in prefix is the longest prefix that matches the name.
      # This ensures, for example, that passing in a prefix of "IFA_" and a
      # value of 1 returns "IFA_ADDRESS" instead of "IFA_F_SECONDARY".
      # The longest matching prefix is always the first matching prefix because
      # CONSTANT_PREFIXES must be sorted longest first.
      if constant_prefixes and prefix != FirstMatching(name, constant_prefixes):
        continue
      if name.startswith(prefix):
        return name
    return value

  def _Decode(self, command, msg, nla_type, nla_data, nested):
    """No-op, nonspecific version of decode.

    Returns:
      The tuple (nla_type, nla_data) unchanged.
    """
    return nla_type, nla_data

  def _ReadNlAttr(self, data):
    """Reads one netlink attribute from the front of `data`.

    Args:
      data: bytes beginning with an NLAttr header.

    Returns:
      A tuple (nla, nla_data, rest): the NLAttr header, the attribute payload
      without padding, and the data following the attribute and its padding.

    Raises:
      ValueError: The attribute's nla_len is smaller than its own header.
    """
    # Read the nlattr header.
    nla, data = cstruct.Read(data, NLAttr)

    # Read the data.
    datalen = nla.nla_len - len(nla)
    if datalen < 0:
      raise ValueError("Invalid attribute length %d" % nla.nla_len)
    padded_len = util.GetPadLength(NLA_ALIGNTO, datalen) + datalen
    nla_data, data = data[:datalen], data[padded_len:]

    return nla, nla_data, data

  def _ParseAttributes(self, command, msg, data, nested):
    """Parses and decodes netlink attributes.

    Takes a block of NLAttr data structures, decodes each one using _Decode,
    and returns the results in a dict.

    Args:
      command: An integer, the rtnetlink command being carried out.
      msg: A Struct, the type of the data after the netlink header.
      data: A byte string containing a sequence of NLAttr data structures.
      nested: A list, outermost first, of each of the attributes the NLAttrs are
              nested inside. Empty for non-nested attributes.

    Returns:
      A dictionary mapping the attribute key returned by _Decode (the raw
      integer type for the base implementation) to the decoded value.

    Raises:
      ValueError: There was a duplicate attribute type, or a malformed
        attribute header.
    """
    attributes = {}
    while data:
      nla, nla_data, data = self._ReadNlAttr(data)

      # If it's an attribute we know about, try to decode it.
      nla_name, nla_data = self._Decode(command, msg, nla.nla_type, nla_data,
                                        nested)

      if nla_name in attributes and nla_name not in DUP_ATTRS_OK:
        raise ValueError("Duplicate attribute %s" % nla_name)

      attributes[nla_name] = nla_data
      if not nested:
        self._Debug("      %s" % (str((nla_name, nla_data))))

    return attributes

  def _OpenNetlinkSocket(self, family, groups):
    """Opens a raw netlink socket for protocol `family`, connected to the kernel.

    Args:
      family: An integer netlink protocol number, e.g. NETLINK_ROUTE.
      groups: The multicast groups value to bind() to, or a falsy value to
        skip the explicit bind.

    Returns:
      The connected socket. If bind or connect fails, the socket is closed
      before the exception propagates.
    """
    sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, family)
    try:
      if groups:
        sock.bind((0, groups))
      sock.connect((0, 0))  # The kernel.
    except BaseException:
      sock.close()
      raise
    return sock

  def __init__(self, family, groups=None):
    # Global sequence number: placed in each request header, bumped by _Send.
    self.seq = 0
    self.sock = self._OpenNetlinkSocket(family, groups)
    # AF_NETLINK socket addresses are (pid, groups) tuples; element 0 is the
    # port ID assigned to this socket.
    self.pid = self.sock.getsockname()[0]

  def close(self):
    """Closes the socket. Safe to call more than once."""
    sock = getattr(self, "sock", None)
    if sock is not None:
      sock.close()
    self.sock = None

  def __del__(self):
    # self.sock may never have been set if __init__ raised.
    if getattr(self, "sock", None):
      self.close()

  def MaybeDebugCommand(self, command, flags, data):
    # Default no-op implementation to be overridden by subclasses.
    pass

  def _Send(self, msg):
    """Sends raw bytes `msg` and advances the sequence number."""
    self.seq += 1
    self.sock.send(msg)

  def _Recv(self):
    """Receives up to BUFSIZE bytes from the socket."""
    return self.sock.recv(self.BUFSIZE)

  def _ExpectDone(self):
    """Receives one buffer and raises ValueError unless it is NLMSG_DONE."""
    response = self._Recv()
    hdr = NLMsgHdr(response)
    if hdr.type != NLMSG_DONE:
      raise ValueError("Expected DONE, got type %d" % hdr.type)

  def _ParseAck(self, response):
    """Checks that `response` is an NLMSG_ERROR with a zero error code.

    Raises:
      IOError: The kernel reported an error; args are (errno, strerror).
      ValueError: `response` is not an NLMSG_ERROR message.
    """
    # Find the error code.
    hdr, data = cstruct.Read(response, NLMsgHdr)
    if hdr.type == NLMSG_ERROR:
      error = -NLMsgErr(data).error
      if error:
        raise IOError(error, os.strerror(error))
    else:
      raise ValueError("Expected ACK, got type %d" % hdr.type)

  def _ExpectAck(self):
    """Receives one buffer and checks it with _ParseAck."""
    response = self._Recv()
    self._ParseAck(response)

  def _SendNlRequest(self, command, data, flags):
    """Sends a netlink request and, if NLM_F_ACK is set, expects an ack."""
    length = len(NLMsgHdr) + len(data)
    nlmsg = NLMsgHdr((length, command, flags, self.seq, self.pid)).Pack()

    self.MaybeDebugCommand(command, flags, nlmsg + data)

    # Send the message.
    self._Send(nlmsg + data)

    if flags & NLM_F_ACK:
      self._ExpectAck()

  def _ParseNLMsg(self, data, msgtype):
    """Parses one netlink message into a header and a dictionary of attributes.

    Args:
      data: bytes beginning with an NLMsgHdr.
      msgtype: A cstruct.Struct, the fixed header that follows NLMsgHdr.

    Returns:
      A tuple ((nlmsg, attributes), rest). For NLMSG_ERROR and NLMSG_DONE
      messages (nlmsg, attributes) is (None, None). `rest` is the data that
      follows this message.
    """
    nlmsghdr, data = cstruct.Read(data, NLMsgHdr)
    self._Debug("  %s" % nlmsghdr)
    # Number of bytes in this message after the NLMsgHdr.
    payload_len = nlmsghdr.length - len(nlmsghdr)

    if nlmsghdr.type == NLMSG_ERROR or nlmsghdr.type == NLMSG_DONE:
      self._Debug("done")
      # Skip the error/done body so `rest` starts at the next message.
      return (None, None), data[max(payload_len, 0):]

    nlmsg, data = cstruct.Read(data, msgtype)
    self._Debug("    %s" % nlmsg)

    # Parse the attributes in the nlmsg.
    attrlen = payload_len - len(nlmsg)
    attributes = self._ParseAttributes(nlmsghdr.type, nlmsg, data[:attrlen], [])
    data = data[attrlen:]
    return (nlmsg, attributes), data

  def _GetMsg(self, msgtype):
    """Receives and parses a single message; raises on an error reply."""
    data = self._Recv()
    if NLMsgHdr(data).type == NLMSG_ERROR:
      self._ParseAck(data)
    return self._ParseNLMsg(data, msgtype)[0]

  def _ParseMsgBuffer(self, msgtype, data):
    """Parses consecutive messages from one received buffer.

    Parsing stops at the first NLMSG_DONE or NLMSG_ERROR message.

    Args:
      msgtype: A cstruct.Struct, the fixed header of each data message.
      data: bytes holding zero or more netlink messages.

    Returns:
      A tuple (msgs, end_type): msgs is a list of (msg, attrs) tuples, and
      end_type is NLMSG_DONE or NLMSG_ERROR if parsing stopped at such a
      message, or None if the buffer held only data messages.

    Raises:
      IOError: An NLMSG_ERROR message carried a nonzero error code.
    """
    msgs = []
    while data:
      msg_type = NLMsgHdr(data).type
      if msg_type == NLMSG_ERROR:
        # Raises IOError unless the error code is 0 (a plain ACK).
        self._ParseAck(data)
        return msgs, msg_type
      if msg_type == NLMSG_DONE:
        return msgs, msg_type
      msg, data = self._ParseNLMsg(data, msgtype)
      msgs.append(msg)
    return msgs, None

  def _GetMsgList(self, msgtype, data, expect_done):
    """Parses the messages in `data`, stopping at NLMSG_DONE/NLMSG_ERROR.

    Args:
      msgtype: A cstruct.Struct, the fixed header of each data message.
      data: bytes holding zero or more netlink messages.
      expect_done: If True, require an NLMSG_DONE. It is read from the socket
        unless one was already found in `data`.

    Returns:
      A list of (msg, attrs) tuples.

    Raises:
      IOError: `data` contains an NLMSG_ERROR with a nonzero error code.
      ValueError: expect_done is True and the next message read is not DONE.
    """
    out, end_type = self._ParseMsgBuffer(msgtype, data)
    if expect_done and end_type != NLMSG_DONE:
      self._ExpectDone()
    return out

  def _Dump(self, command, msg, msgtype, attrs=b""):
    """Sends a dump request and returns a list of decoded messages.

    Args:
      command: An integer, the command to run (e.g., RTM_NEWADDR).
      msg: A struct, the request (e.g., a RTMsg). May be None.
      msgtype: A cstruct.Struct, the data type to parse the dump results as.
      attrs: A bytes object, the raw bytes of any request attributes to
        include.

    Returns:
      A list of (msg, attrs) tuples where msg is of type msgtype and attrs is
      a dict of attributes.

    Raises:
      IOError: The kernel replied with an error.
      ValueError: The socket returned no data before the dump completed.
    """
    # Create a netlink dump request containing the msg.
    flags = NLM_F_DUMP | NLM_F_REQUEST
    msg = b"" if msg is None else msg.Pack()
    length = len(NLMsgHdr) + len(msg) + len(attrs)
    nlmsghdr = NLMsgHdr((length, command, flags, self.seq, self.pid))

    # Send the request.
    request = nlmsghdr.Pack() + msg + attrs
    self.MaybeDebugCommand(command, flags, request)
    self._Send(request)

    # Keep reading until a buffer ends the dump. NLMSG_DONE may arrive on its
    # own or appended to the last batch of data messages.
    out = []
    end_type = None
    while end_type is None:
      data = self._Recv()
      if not data:
        raise ValueError("Netlink dump ended without NLMSG_DONE")
      msgs, end_type = self._ParseMsgBuffer(msgtype, data)
      out.extend(msgs)

    return out