-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathnetstack.py
More file actions
571 lines (520 loc) · 18.8 KB
/
netstack.py
File metadata and controls
571 lines (520 loc) · 18.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
# Minimal network stack for USB-side networking
import struct
import socket
import select
import time
import gc
import config
# --- Ethernet II ---
ETH_HDR_LEN = 6 + 6 + 2  # destination MAC + source MAC + EtherType
ETH_TYPE_ARP = 0x0806
ETH_TYPE_IP = 0x0800
BROADCAST_MAC = b"\xff" * 6

# --- IPv4 ---
IP_HDR_MIN = 20  # header size without options (IHL = 5)
IP_PROTO_ICMP = 1
IP_PROTO_TCP = 6
IP_PROTO_UDP = 17

# --- ICMP ---
ICMP_ECHO_REQUEST = 8
ICMP_ECHO_REPLY = 0

# --- ARP (Ethernet / IPv4) ---
ARP_REQUEST = 1
ARP_REPLY = 2
ARP_PKT_LEN = 8 + 6 + 4 + 6 + 4  # fixed part + sender/target MAC and IP

# --- UDP services ---
DNS_PORT = 53
DHCP_SERVER_PORT = 67
DHCP_CLIENT_PORT = 68
DHCP_MAGIC = bytes((0x63, 0x82, 0x53, 0x63))  # DHCP magic cookie
DHCP_DISCOVER = 1
DHCP_OFFER = 2
DHCP_REQUEST = 3
DHCP_ACK = 5

# --- TCP header flag bits ---
TCP_FIN = 1 << 0
TCP_SYN = 1 << 1
TCP_RST = 1 << 2
TCP_PSH = 1 << 3
TCP_ACK = 1 << 4
def _ip2bytes(ip_str):
return bytes(int(x) for x in ip_str.split("."))
def _bytes2ip(b):
return "{}.{}.{}.{}".format(b[0], b[1], b[2], b[3])
def inet_checksum(data):
    """Return the 16-bit one's-complement Internet checksum of ``data``.

    Odd-length input is treated as if padded with one trailing zero byte.
    The result is 0 when ``data`` already contains a correct checksum field.
    """
    padded = data + b"\x00" if len(data) % 2 else data
    total = 0
    pos = 0
    end = len(padded)
    while pos < end:
        # Big-endian 16-bit words.
        total += padded[pos] << 8 | padded[pos + 1]
        pos += 2
    # Fold carries back into the low 16 bits until none remain.
    while total > 0xFFFF:
        total = (total & 0xFFFF) + (total >> 16)
    return ~total & 0xFFFF
def _build_eth(dst, src, ethertype, payload):
return dst + src + struct.pack("!H", ethertype) + payload
def _parse_eth(frame):
    """Split an Ethernet II frame into ``(dst_mac, src_mac, ethertype, payload)``.

    Returns None when the frame is shorter than ``ETH_HDR_LEN`` bytes.
    """
    if len(frame) < ETH_HDR_LEN:
        return None
    header, payload = frame[:ETH_HDR_LEN], frame[ETH_HDR_LEN:]
    (ethertype,) = struct.unpack("!H", header[12:14])
    return header[:6], header[6:12], ethertype, payload
def _build_ip(src, dst, proto, payload, ttl=64):
    """Wrap ``payload`` in a 20-byte IPv4 header and return the packet.

    The header uses version 4 / IHL 5, TOS 0, identification 0 and no
    fragmentation flags; its checksum is computed over the header only.
    """
    total_len = IP_HDR_MIN + len(payload)
    # First 12 bytes of the header, checksum field (bytes 10-11) zeroed.
    fixed = struct.pack("!BBHHHBBH", 0x45, 0, total_len, 0, 0, ttl, proto, 0)
    checksum = inet_checksum(fixed + src + dst)
    header = fixed[:10] + struct.pack("!H", checksum) + src + dst
    return header + payload
def _parse_ip(data):
    """Split an IPv4 packet into ``(src, dst, proto, payload)``.

    ``src`` and ``dst`` are the 4-byte addresses. ``payload`` is bounded by
    the header's total-length field, so trailing link-layer padding is
    dropped.

    Returns None for packets this stack cannot use:
    - too short, or not IPv4;
    - header length below 20 bytes;
    - total length shorter than the header, or longer than the bytes
      actually received;
    - IP fragments. There is no reassembly, so a fragment's bytes would
      otherwise be misread as a TCP/UDP header.
    """
    if len(data) < IP_HDR_MIN:
        return None
    if data[0] >> 4 != 4:
        return None
    ihl = (data[0] & 0x0F) * 4
    total_len = struct.unpack("!H", data[2:4])[0]
    if ihl < IP_HDR_MIN or total_len < ihl or total_len > len(data):
        return None
    # More-fragments flag (0x2000) or a non-zero fragment offset (low 13 bits).
    if struct.unpack("!H", data[6:8])[0] & 0x3FFF:
        return None
    return data[12:16], data[16:20], data[9], data[ihl:total_len]
def _build_udp(src_port, dst_port, payload):
length = 8 + len(payload)
hdr = struct.pack("!HHHH", src_port, dst_port, length, 0)
return hdr + payload
def _parse_udp(data):
if len(data) < 8:
return None
src_port, dst_port, length = struct.unpack("!HHH", data[0:6])
payload = data[8:length]
return src_port, dst_port, payload
def _tcp_checksum(src_ip, dst_ip, tcp_data):
    """Compute the TCP checksum of ``tcp_data`` including the IPv4 pseudo-header.

    The pseudo-header is: source IP, destination IP, a zero byte, the
    protocol number, and the TCP segment length.
    """
    pseudo_header = src_ip + dst_ip + struct.pack("!BBH", 0, IP_PROTO_TCP, len(tcp_data))
    return inet_checksum(pseudo_header + tcp_data)
def _build_tcp(src_port, dst_port, seq, ack, flags, window, payload, src_ip, dst_ip):
    """Build a TCP segment (20-byte header, no options) with a valid checksum.

    ``seq`` and ``ack`` are reduced modulo 2**32 before packing. TCP sequence
    space wraps, and callers compute values such as ``seq + len(payload)``
    that can exceed 32 bits; packing those unreduced raises ``struct.error``.
    The urgent pointer is always 0. ``src_ip``/``dst_ip`` are only used for
    the pseudo-header checksum.
    """
    offset_flags = (5 << 12) | flags  # data offset of 5 words = 20-byte header
    header = struct.pack(
        "!HHIIHHHH",
        src_port,
        dst_port,
        seq & 0xFFFFFFFF,
        ack & 0xFFFFFFFF,
        offset_flags,
        window,
        0,  # checksum placeholder, filled in below
        0,  # urgent pointer
    )
    segment = header + payload
    csum = _tcp_checksum(src_ip, dst_ip, segment)
    # The checksum occupies bytes 16-17 of the TCP header.
    return segment[:16] + struct.pack("!H", csum) + segment[18:]
def _parse_tcp(data):
if len(data) < 20:
return None
src_port, dst_port = struct.unpack("!HH", data[0:4])
seq, ack = struct.unpack("!II", data[4:12])
offset_flags = struct.unpack("!H", data[12:14])[0]
data_offset = (offset_flags >> 12) * 4
flags = offset_flags & 0x3F
window = struct.unpack("!H", data[14:16])[0]
payload = data[data_offset:]
return src_port, dst_port, seq, ack, flags, window, payload
class Bridge:
    """Gateway between the USB (ECM) host and this device's own sockets.

    To the host, this object poses as the gateway at ``config.USB_IP``:
    - It answers ARP requests and ICMP echo requests for that address.
    - It runs a minimal DHCP exchange that hands out ``config.USB_HOST_IP``.
    - It relays DNS queries sent to it through one UDP socket to
      ``dns_server``.
    - Every other TCP/UDP flow from the host is terminated here and
      re-originated through a device socket to the original destination.
      Replies are written back to the host with that destination as the IP
      source.

    ``ecm`` must provide ``recv_frame()`` (falsy when no frame is waiting)
    and ``send_frame(frame)``. ``process()`` performs one service pass and
    is presumably called from a main loop.
    """

    # Largest TCP payload placed in one segment towards the host. Assumes a
    # 1500-byte IP MTU on the USB link: 1500 - 20 (IPv4) - 20 (TCP header).
    _TCP_MSS = 1460

    # TCP sequence-number arithmetic is modulo 2**32.
    _SEQ_MASK = 0xFFFFFFFF

    def __init__(self, ecm, dns_server=None):
        self.ecm = ecm
        self.our_ip = _ip2bytes(config.USB_IP)
        self.our_mac = config.USB_MAC
        self.host_ip = _ip2bytes(config.USB_HOST_IP)
        self.host_mac = None  # learned from host frames, ARP or DHCP
        self.mask = _ip2bytes(config.USB_MASK)
        self.dns_server = dns_server or config.DNS_FALLBACK
        self.tcp_conns = {}  # (host_port, dst_ip, dst_port) -> TCPConn
        self.udp_conns = {}  # (host_port, dst_ip, dst_port) -> UDPConn
        self.poller = select.poll()
        self._isn = int(time.ticks_ms()) & 0xFFFFFFFF
        self._dns_sock = None
        self._dns_pending = {}  # DNS txid -> (host source port, ticks_ms sent)
        sock = None
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            sock.setblocking(False)
            sock.connect((_bytes2ip(_ip2bytes(self.dns_server)), DNS_PORT))
            self.poller.register(sock, select.POLLIN)
            self._dns_sock = sock
            print("[bridge] DNS socket ready ->", self.dns_server)
        except OSError as e:
            print("[bridge] DNS socket failed:", e)
            # Don't leak a half-initialised socket.
            if sock is not None:
                try:
                    sock.close()
                except OSError:
                    pass
            self._dns_sock = None

    def process(self):
        """Run one service pass: handle all queued frames, then sockets and timers."""
        while True:
            frame = self.ecm.recv_frame()
            if not frame:
                break
            self._handle_frame(frame)
        self._poll_wifi_sockets()
        self._cleanup_stale()
        self._drain_tx_queue()

    def _handle_frame(self, frame):
        """Dispatch one Ethernet frame from the host by EtherType."""
        parsed = _parse_eth(frame)
        if not parsed:
            return
        dst, src, etype, payload = parsed
        # Learn the host's MAC from the first frame not sent by ourselves.
        if self.host_mac is None and src != self.our_mac:
            self.host_mac = src
        if etype == ETH_TYPE_ARP:
            self._handle_arp(src, payload)
        elif etype == ETH_TYPE_IP:
            self._handle_ip(src, payload)

    def _handle_arp(self, eth_src, data):
        """Answer ARP requests for our IP with a reply carrying our MAC."""
        if len(data) < ARP_PKT_LEN:
            return
        op = struct.unpack("!H", data[6:8])[0]
        sender_mac = data[8:14]
        sender_ip = data[14:18]
        target_ip = data[24:28]
        if op == ARP_REQUEST and target_ip == self.our_ip:
            self.host_mac = sender_mac
            # hardware type 1 (Ethernet), protocol IPv4, hlen 6, plen 4, op REPLY
            reply = struct.pack("!HHBBH", 0x0001, 0x0800, 6, 4, ARP_REPLY)
            reply += self.our_mac + self.our_ip
            reply += sender_mac + sender_ip
            eth = _build_eth(sender_mac, self.our_mac, ETH_TYPE_ARP, reply)
            self.ecm.send_frame(eth)

    def _handle_ip(self, eth_src, data):
        """Dispatch an IPv4 packet from the host by protocol."""
        parsed = _parse_ip(data)
        if not parsed:
            return
        src_ip, dst_ip, proto, payload = parsed
        if proto == IP_PROTO_ICMP:
            self._handle_icmp(src_ip, dst_ip, payload)
        elif proto == IP_PROTO_UDP:
            self._handle_udp(src_ip, dst_ip, payload)
        elif proto == IP_PROTO_TCP:
            self._handle_tcp(src_ip, dst_ip, payload)

    def _handle_icmp(self, src_ip, dst_ip, data):
        """Reply to ICMP echo requests addressed to our IP; ignore everything else."""
        if len(data) < 8:
            return
        icmp_type = data[0]
        if icmp_type == ICMP_ECHO_REQUEST and dst_ip == self.our_ip:
            # Echo reply: same identifier/sequence/data, checksum recomputed.
            reply = bytes([ICMP_ECHO_REPLY, 0]) + b"\x00\x00" + data[4:]
            csum = inet_checksum(reply)
            reply = reply[:2] + struct.pack("!H", csum) + reply[4:]
            ip = _build_ip(self.our_ip, src_ip, IP_PROTO_ICMP, reply)
            eth = _build_eth(
                self.host_mac or BROADCAST_MAC, self.our_mac, ETH_TYPE_IP, ip
            )
            self.ecm.send_frame(eth)

    def _handle_udp(self, src_ip, dst_ip, data):
        """Serve DHCP, relay DNS to the gateway, and proxy other unicast UDP flows."""
        parsed = _parse_udp(data)
        if not parsed:
            return
        src_port, dst_port, payload = parsed
        if dst_port == DHCP_SERVER_PORT and (
            dst_ip == self.our_ip or dst_ip == b"\xff\xff\xff\xff"
        ):
            self._handle_dhcp(payload)
            return
        if dst_port == DNS_PORT and dst_ip == self.our_ip:
            if self._dns_sock and len(payload) >= 2:
                # Remember which host port asked, keyed by DNS transaction id.
                txid = struct.unpack("!H", payload[:2])[0]
                self._dns_pending[txid] = (src_port, time.ticks_ms())
                try:
                    self._dns_sock.send(payload)
                except OSError:
                    pass  # best-effort; the host's resolver will retry
            return
        # Multicast and broadcast are not proxied.
        if dst_ip[0] >= 224 or dst_ip == b"\xff\xff\xff\xff":
            return
        key = (src_port, bytes(dst_ip), dst_port)
        conn = self.udp_conns.get(key)
        if not conn:
            if len(self.udp_conns) >= config.MAX_UDP_CONNS:
                oldest = min(
                    self.udp_conns, key=lambda k: self.udp_conns[k].last_active
                )
                self._remove_udp(oldest)
            conn = UDPConn(key, src_port, dst_ip, dst_port, self)
            if not conn.sock:
                return
            self.udp_conns[key] = conn
            self.poller.register(conn.sock, select.POLLIN)
        conn.send(payload)

    def _handle_tcp(self, src_ip, dst_ip, data):
        """Run the host side of a proxied TCP flow for one incoming segment."""
        parsed = _parse_tcp(data)
        if not parsed:
            return
        src_port, dst_port, seq, ack, flags, window, payload = parsed
        key = (src_port, bytes(dst_ip), dst_port)
        mask = self._SEQ_MASK

        if flags & TCP_RST:
            # Host aborted the flow: drop our side without answering.
            if key in self.tcp_conns:
                self._remove_tcp(key)
            return

        if flags & TCP_SYN and not (flags & TCP_ACK):
            # New connection request; any existing state for the key is replaced.
            if key in self.tcp_conns:
                self._remove_tcp(key)
            if len(self.tcp_conns) >= config.MAX_TCP_CONNS:
                oldest = min(
                    self.tcp_conns, key=lambda k: self.tcp_conns[k].last_active
                )
                self._remove_tcp(oldest)
            self._isn = (self._isn + 1) & mask
            conn = TCPConn(key, src_port, dst_ip, dst_port, seq, self._isn, self)
            if conn.sock is None:
                return
            self.tcp_conns[key] = conn
            self.poller.register(conn.sock, select.POLLIN)
            # The SYN-ACK goes out before the upstream connect completes;
            # data the host sends meanwhile is buffered in conn.send_buf.
            self._send_tcp_to_host(conn, TCP_SYN | TCP_ACK, b"")
            return

        conn = self.tcp_conns.get(key)
        if not conn:
            # No such flow: reset it. A FIN occupies one sequence number.
            seg_len = len(payload) + (1 if flags & TCP_FIN else 0)
            self._send_tcp_rst(
                src_ip, dst_ip, src_port, dst_port, ack, (seq + seg_len) & mask
            )
            return

        conn.last_active = time.ticks_ms()
        expected = conn.host_ack & mask

        if payload and seq != expected:
            # A retransmission may overlap bytes already forwarded upstream;
            # keep only the unseen tail. A segment starting beyond `expected`
            # (a gap) yields an overlap larger than the payload and is refused.
            overlap = (expected - seq) & mask
            if overlap < len(payload):
                payload = payload[overlap:]
                seq = expected
        if (payload or flags & TCP_FIN) and seq != expected:
            # Duplicate or out-of-order: never forward it twice or out of
            # order; repeat our ACK so the host resends from `expected`.
            self._send_tcp_to_host(conn, TCP_ACK, b"")
            return

        if payload:
            conn.send_buf += payload
            if not self._flush_send_buf(key, conn):
                return  # upstream failed; the flow was reset and removed
            conn.host_ack = (seq + len(payload)) & mask

        if flags & TCP_FIN:
            # FIN consumes one sequence number. Both sides are closed now;
            # bytes still buffered because upstream was not ready are lost.
            conn.host_ack = (conn.host_ack + 1) & mask
            self._send_tcp_to_host(conn, TCP_ACK | TCP_FIN, b"")
            self._remove_tcp(key)
            return

        if payload:
            self._send_tcp_to_host(conn, TCP_ACK, b"")

    @staticmethod
    def _would_block(exc):
        """Return True if ``exc`` is a non-blocking "try again later" socket error."""
        import errno

        code = exc.args[0] if exc.args else None
        return code in (errno.EAGAIN, getattr(errno, "EWOULDBLOCK", errno.EAGAIN))

    def _flush_send_buf(self, key, conn):
        """Write as much of ``conn.send_buf`` upstream as the socket accepts now.

        Before the first successful send (``conn.connected`` False), errors are
        taken to mean the connect is still in progress and are retried until
        ``conn.connect_deadline``. After that, only would-block errors are
        retried. Returns False when the flow was reset (RST sent to the host)
        and removed, True otherwise.
        """
        if not conn.send_buf:
            return True
        try:
            sent = conn.sock.send(conn.send_buf)
        except OSError as e:
            if conn.connected:
                give_up = not self._would_block(e)
            else:
                give_up = (
                    time.ticks_diff(time.ticks_ms(), conn.connect_deadline) > 0
                )
            if give_up:
                self._send_tcp_to_host(conn, TCP_RST, b"")
                self._remove_tcp(key)
                return False
            return True
        conn.connected = True
        if sent:
            # Non-blocking sends may be partial: keep the unsent remainder.
            conn.send_buf = conn.send_buf[sent:]
        return True

    def _remove_tcp(self, key):
        """Forget a TCP flow and close its upstream socket (no-op if unknown)."""
        conn = self.tcp_conns.pop(key, None)
        if conn and conn.sock:
            try:
                self.poller.unregister(conn.sock)
            except Exception:
                pass  # best-effort: the socket may already be unregistered
            conn.close()

    def _remove_udp(self, key):
        """Forget a UDP flow and close its upstream socket (no-op if unknown)."""
        conn = self.udp_conns.pop(key, None)
        if conn and conn.sock:
            try:
                self.poller.unregister(conn.sock)
            except Exception:
                pass  # best-effort: the socket may already be unregistered
            conn.close()

    def _send_tcp_to_host(self, conn, flags, payload):
        """Send one TCP segment to the host on behalf of ``conn``'s remote end.

        Advances ``conn.our_seq`` by the payload length, plus one for SYN or
        FIN. NOTE(review): nothing retransmits a segment the host never
        receives.
        """
        mask = self._SEQ_MASK
        tcp = _build_tcp(
            conn.dst_port,
            conn.src_port,
            conn.our_seq & mask,
            conn.host_ack & mask,
            flags,
            16384,
            payload,
            conn.dst_ip,
            self.host_ip,
        )
        ip = _build_ip(conn.dst_ip, self.host_ip, IP_PROTO_TCP, tcp)
        eth = _build_eth(self.host_mac or BROADCAST_MAC, self.our_mac, ETH_TYPE_IP, ip)
        self.ecm.send_frame(eth)
        if payload:
            conn.our_seq = (conn.our_seq + len(payload)) & mask
        if flags & (TCP_SYN | TCP_FIN):
            conn.our_seq = (conn.our_seq + 1) & mask

    def _send_tcp_rst(self, src_ip, dst_ip, src_port, dst_port, seq, ack):
        """Send RST|ACK to the host for a segment that matches no known flow."""
        mask = self._SEQ_MASK
        tcp = _build_tcp(
            dst_port,
            src_port,
            seq & mask,
            ack & mask,
            TCP_RST | TCP_ACK,
            0,
            b"",
            dst_ip,
            self.host_ip,
        )
        ip = _build_ip(dst_ip, self.host_ip, IP_PROTO_TCP, tcp)
        eth = _build_eth(self.host_mac or BROADCAST_MAC, self.our_mac, ETH_TYPE_IP, ip)
        self.ecm.send_frame(eth)

    def _drain_tx_queue(self):
        """Placeholder hook; there is no transmit queue in this version."""
        pass

    def _poll_wifi_sockets(self):
        """Retry pending upstream writes, then service readable or closed sockets."""
        for key, conn in list(self.tcp_conns.items()):
            if conn.send_buf:
                self._flush_send_buf(key, conn)
        for entry in self.poller.poll(0):
            # MicroPython may append extra fields to each result tuple.
            sock, event = entry[0], entry[1]
            if event & select.POLLIN:
                self._handle_sock_data(sock)
            if event & (select.POLLHUP | select.POLLERR):
                self._handle_sock_close(sock)

    def _handle_sock_data(self, sock):
        """Read from a readable socket and forward the data to the host."""
        if self._dns_sock and sock is self._dns_sock:
            try:
                data = sock.recv(1500)
            except OSError:
                return
            if data and len(data) >= 2:
                txid = struct.unpack("!H", data[:2])[0]
                entry = self._dns_pending.pop(txid, None)
                if entry:
                    client_port = entry[0]
                    udp = _build_udp(DNS_PORT, client_port, data)
                    ip = _build_ip(self.our_ip, self.host_ip, IP_PROTO_UDP, udp)
                    eth = _build_eth(
                        self.host_mac or BROADCAST_MAC, self.our_mac, ETH_TYPE_IP, ip
                    )
                    self.ecm.send_frame(eth)
            return
        for key, conn in list(self.tcp_conns.items()):
            if conn.sock is sock:
                try:
                    # Read at most one MSS so each segment fits the link MTU.
                    data = sock.recv(self._TCP_MSS)
                except OSError:
                    return
                if data:
                    conn.last_active = time.ticks_ms()
                    self._send_tcp_to_host(conn, TCP_ACK | TCP_PSH, data)
                else:
                    # Upstream closed (recv returned no data): close the host side too.
                    self._send_tcp_to_host(conn, TCP_FIN | TCP_ACK, b"")
                    self._remove_tcp(key)
                return
        for key, conn in list(self.udp_conns.items()):
            if conn.sock is sock:
                try:
                    # NOTE(review): datagrams above 1472 bytes exceed a
                    # 1500-byte MTU once wrapped; nothing fragments them here.
                    data = sock.recv(1500)
                except OSError:
                    return
                if data:
                    conn.last_active = time.ticks_ms()
                    udp = _build_udp(conn.dst_port, conn.src_port, data)
                    ip = _build_ip(conn.dst_ip, self.host_ip, IP_PROTO_UDP, udp)
                    eth = _build_eth(
                        self.host_mac or BROADCAST_MAC, self.our_mac, ETH_TYPE_IP, ip
                    )
                    self.ecm.send_frame(eth)
                return

    def _handle_sock_close(self, sock):
        """Tear down the flow owning a socket that reported hang-up or error."""
        for key, conn in list(self.tcp_conns.items()):
            if conn.sock is sock:
                self._send_tcp_to_host(conn, TCP_FIN | TCP_ACK, b"")
                self._remove_tcp(key)
                return
        for key, conn in list(self.udp_conns.items()):
            if conn.sock is sock:
                self._remove_udp(key)
                return

    def _cleanup_stale(self):
        """Expire idle TCP/UDP flows and DNS queries unanswered for 5 seconds."""
        now = time.ticks_ms()
        for key in list(self.tcp_conns):
            conn = self.tcp_conns[key]
            if time.ticks_diff(now, conn.last_active) > config.TCP_TIMEOUT * 1000:
                # Tell the host the flow is gone rather than dropping it silently.
                self._send_tcp_to_host(conn, TCP_RST, b"")
                self._remove_tcp(key)
        for key in list(self.udp_conns):
            conn = self.udp_conns[key]
            if time.ticks_diff(now, conn.last_active) > config.UDP_TIMEOUT * 1000:
                self._remove_udp(key)
        for txid in list(self._dns_pending):
            if time.ticks_diff(now, self._dns_pending[txid][1]) > 5000:
                del self._dns_pending[txid]

    @staticmethod
    def _dhcp_msg_type(options):
        """Return the DHCP message type (option 53) found in ``options``, or None.

        Stops at the END option. Also stops at the first option whose length
        byte or body would run past the buffer, so truncated input cannot raise
        IndexError.
        """
        msg_type = None
        i = 0
        n = len(options)
        while i < n:
            code = options[i]
            if code == 255:  # END
                break
            if code == 0:  # PAD has no length byte
                i += 1
                continue
            if i + 1 >= n:
                break
            olen = options[i + 1]
            if i + 2 + olen > n:
                break
            if code == 53 and olen == 1:
                msg_type = options[i + 2]
            i += 2 + olen
        return msg_type

    def _handle_dhcp(self, data):
        """Answer DHCPDISCOVER with OFFER and DHCPREQUEST with ACK for the host IP."""
        # 236-byte BOOTP fixed part followed by the 4-byte magic cookie.
        if len(data) < 240:
            return
        if data[0] != 1:  # op 1 = BOOTREQUEST
            return
        if data[236:240] != DHCP_MAGIC:
            return
        xid = data[4:8]
        client_mac = data[28:34]
        self.host_mac = client_mac
        msg_type = self._dhcp_msg_type(data[240:])
        if msg_type not in (DHCP_DISCOVER, DHCP_REQUEST):
            return
        reply_type = DHCP_OFFER if msg_type == DHCP_DISCOVER else DHCP_ACK
        reply = bytearray(300)
        reply[0] = 2  # BOOTREPLY
        reply[1] = 1  # hardware type: Ethernet
        reply[2] = 6  # hardware address length
        reply[4:8] = xid
        reply[16:20] = self.host_ip  # yiaddr: address handed to the host
        reply[20:24] = self.our_ip  # siaddr
        reply[28:34] = client_mac
        reply[236:240] = DHCP_MAGIC
        opts = bytearray()
        opts += bytes([53, 1, reply_type])  # message type
        opts += bytes([54, 4]) + self.our_ip  # server identifier
        opts += bytes([51, 4, 0, 0, 0x1C, 0x20])  # lease time: 7200 s
        opts += bytes([1, 4]) + self.mask  # subnet mask
        opts += bytes([3, 4]) + self.our_ip  # router
        dns = _ip2bytes(self.dns_server)
        opts += bytes([6, 4]) + dns  # DNS server
        opts += bytes([255])  # END
        reply[240 : 240 + len(opts)] = opts
        reply = bytes(reply[: 240 + len(opts)])
        udp = _build_udp(DHCP_SERVER_PORT, DHCP_CLIENT_PORT, reply)
        ip = _build_ip(self.our_ip, b"\xff\xff\xff\xff", IP_PROTO_UDP, udp)
        eth = _build_eth(BROADCAST_MAC, self.our_mac, ETH_TYPE_IP, ip)
        self.ecm.send_frame(eth)
class TCPConn:
    """State for one host TCP flow proxied through a non-blocking device socket.

    ``host_ack`` is the next sequence number expected from the host, and
    ``our_seq`` is the next sequence number used towards the host.
    ``send_buf`` holds host bytes not yet written upstream. ``connected``
    starts False, and ``connect_deadline`` is 5 s after creation.
    ``sock`` is None if the socket could not be created or configured.
    ``bridge`` is accepted but not used here.
    """

    def __init__(self, key, src_port, dst_ip, dst_port, host_seq, our_isn, bridge):
        self.key = key
        self.src_port = src_port
        self.dst_ip = dst_ip
        self.dst_port = dst_port
        # The host's SYN consumes one sequence number. Wrap at 2**32 so that a
        # SYN with seq 0xFFFFFFFF still yields an ACK number that can be packed.
        self.host_ack = (host_seq + 1) & 0xFFFFFFFF
        self.our_seq = our_isn
        self.last_active = time.ticks_ms()
        self.connected = False
        self.send_buf = b""
        self.connect_deadline = time.ticks_add(time.ticks_ms(), 5000)
        self.sock = None
        sock = None
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.setblocking(False)
            try:
                # A non-blocking connect typically raises while still in
                # progress; completion or failure surfaces on later I/O.
                sock.connect((_bytes2ip(dst_ip), dst_port))
            except OSError:
                pass
        except OSError:
            # Don't leak a socket that was created but could not be configured.
            if sock is not None:
                try:
                    sock.close()
                except OSError:
                    pass
            return
        self.sock = sock

    def close(self):
        """Close the upstream socket (idempotent; close errors are ignored)."""
        if self.sock:
            try:
                self.sock.close()
            except OSError:
                pass
            self.sock = None
class UDPConn:
    """One host UDP flow relayed through a connected, non-blocking device socket.

    ``sock`` is None when the socket could not be created, configured or
    connected. ``bridge`` and ``override_src_ip`` are not read by this class;
    ``override_src_ip`` is only stored.
    """

    def __init__(self, key, src_port, dst_ip, dst_port, bridge, override_src_ip=None):
        self.key = key
        self.src_port = src_port
        self.dst_ip = dst_ip
        self.dst_port = dst_port
        self.last_active = time.ticks_ms()
        self.override_src_ip = override_src_ip
        self.sock = None
        sock = None
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            sock.setblocking(False)
            sock.connect((_bytes2ip(dst_ip), dst_port))
        except OSError:
            # Don't leak a socket that was created but could not be set up.
            if sock is not None:
                try:
                    sock.close()
                except OSError:
                    pass
            return
        self.sock = sock

    def send(self, data):
        """Forward ``data`` upstream; send errors are dropped (UDP is best-effort).

        ``last_active`` is refreshed only after a successful send.
        """
        if not self.sock:
            return
        try:
            self.sock.send(data)
        except OSError:
            return
        self.last_active = time.ticks_ms()

    def close(self):
        """Close the upstream socket (idempotent; close errors are ignored)."""
        if self.sock:
            try:
                self.sock.close()
            except OSError:
                pass
            self.sock = None