Release 260111

Comma Device
2026-01-11 18:23:29 +08:00
commit 3721ecbf8a
2601 changed files with 855070 additions and 0 deletions

842
system/athena/athenad.py Executable file
@@ -0,0 +1,842 @@
#!/usr/bin/env python3
from __future__ import annotations
import base64
import hashlib
import io
import json
import os
import queue
import random
import select
import socket
import sys
import tempfile
import threading
import time
from dataclasses import asdict, dataclass, replace
from datetime import datetime
from functools import partial, total_ordering
from queue import Queue
from typing import cast
from collections.abc import Callable
import requests
from jsonrpc import JSONRPCResponseManager, dispatcher
from websocket import (ABNF, WebSocket, WebSocketException, WebSocketTimeoutException,
create_connection)
import cereal.messaging as messaging
from cereal import log
from cereal.services import SERVICE_LIST
from openpilot.common.api import Api
from openpilot.common.file_helpers import CallbackReader, get_upload_stream
from openpilot.common.params import Params
from openpilot.common.realtime import set_core_affinity
from openpilot.system.hardware import HARDWARE, PC
from openpilot.system.loggerd.xattr_cache import getxattr, setxattr
from openpilot.common.swaglog import cloudlog
from openpilot.system.version import get_build_metadata
from openpilot.system.hardware.hw import Paths
ATHENA_HOST = os.getenv('ATHENA_HOST', 'wss://athena.comma.ai')
HANDLER_THREADS = int(os.getenv('HANDLER_THREADS', "4"))
LOCAL_PORT_WHITELIST = {22, } # SSH
LOG_ATTR_NAME = 'user.upload'
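# xattr value marking a log as permanently uploaded: the max 32-bit unix time (2038-01-19)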
LOG_ATTR_VALUE_MAX_UNIX_TIME = int.to_bytes(2147483647, 4, sys.byteorder)
RECONNECT_TIMEOUT_S = 70
RETRY_DELAY = 10 # seconds
MAX_RETRY_COUNT = 30 # Try for at most 5 minutes if upload fails immediately
MAX_AGE = 31 * 24 * 3600 # seconds
WS_FRAME_SIZE = 4096
DEVICE_STATE_UPDATE_INTERVAL = 1.0 # in seconds
DEFAULT_UPLOAD_PRIORITY = 99 # higher number = lower priority
NetworkType = log.DeviceState.NetworkType
UploadFileDict = dict[str, str | int | float | bool]
UploadItemDict = dict[str, str | bool | int | float | dict[str, str]]
UploadFilesToUrlResponse = dict[str, int | list[UploadItemDict] | list[str]]
@dataclass
class UploadFile:
fn: str
url: str
headers: dict[str, str]
allow_cellular: bool
priority: int = DEFAULT_UPLOAD_PRIORITY
@classmethod
def from_dict(cls, d: dict) -> UploadFile:
return cls(d.get("fn", ""), d.get("url", ""), d.get("headers", {}), d.get("allow_cellular", False), d.get("priority", DEFAULT_UPLOAD_PRIORITY))
@dataclass
@total_ordering
class UploadItem:
path: str
url: str
headers: dict[str, str]
created_at: int
id: str | None
retry_count: int = 0
current: bool = False
progress: float = 0
allow_cellular: bool = False
priority: int = DEFAULT_UPLOAD_PRIORITY
@classmethod
def from_dict(cls, d: dict) -> UploadItem:
return cls(d["path"], d["url"], d["headers"], d["created_at"], d["id"], d["retry_count"], d["current"],
d["progress"], d["allow_cellular"], d["priority"])
def __lt__(self, other):
if not isinstance(other, UploadItem):
return NotImplemented
return self.priority < other.priority
def __eq__(self, other):
if not isinstance(other, UploadItem):
return NotImplemented
return self.priority == other.priority
dispatcher["echo"] = lambda s: s
recv_queue: Queue[str] = queue.Queue()
send_queue: Queue[str] = queue.Queue()
upload_queue: Queue[UploadItem] = queue.PriorityQueue()
low_priority_send_queue: Queue[str] = queue.Queue()
log_recv_queue: Queue[str] = queue.Queue()
cancelled_uploads: set[str] = set()
cur_upload_items: dict[int, UploadItem | None] = {}
def strip_zst_extension(fn: str) -> str:
if fn.endswith('.zst'):
return fn[:-4]
return fn
class AbortTransferException(Exception):
pass
class UploadQueueCache:
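  # persists the upload queue in Params so queued uploads survive athenad restarts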
@staticmethod
def initialize(upload_queue: Queue[UploadItem]) -> None:
try:
upload_queue_json = Params().get("AthenadUploadQueue")
if upload_queue_json is not None:
for item in json.loads(upload_queue_json):
upload_queue.put(UploadItem.from_dict(item))
except Exception:
cloudlog.exception("athena.UploadQueueCache.initialize.exception")
@staticmethod
def cache(upload_queue: Queue[UploadItem]) -> None:
try:
      queued: list[UploadItem | None] = list(upload_queue.queue)
      items = [asdict(i) for i in queued if i is not None and (i.id not in cancelled_uploads)]
Params().put("AthenadUploadQueue", json.dumps(items))
except Exception:
cloudlog.exception("athena.UploadQueueCache.cache.exception")
def handle_long_poll(ws: WebSocket, exit_event: threading.Event | None) -> None:
end_event = threading.Event()
threads = [
threading.Thread(target=ws_manage, args=(ws, end_event), name='ws_manage'),
threading.Thread(target=ws_recv, args=(ws, end_event), name='ws_recv'),
threading.Thread(target=ws_send, args=(ws, end_event), name='ws_send'),
threading.Thread(target=upload_handler, args=(end_event,), name='upload_handler'),
threading.Thread(target=upload_handler, args=(end_event,), name='upload_handler2'),
threading.Thread(target=upload_handler, args=(end_event,), name='upload_handler3'),
threading.Thread(target=upload_handler, args=(end_event,), name='upload_handler4'),
threading.Thread(target=log_handler, args=(end_event,), name='log_handler'),
threading.Thread(target=stat_handler, args=(end_event,), name='stat_handler'),
] + [
threading.Thread(target=jsonrpc_handler, args=(end_event,), name=f'worker_{x}')
for x in range(HANDLER_THREADS)
]
for thread in threads:
thread.start()
try:
while not end_event.wait(0.1):
if exit_event is not None and exit_event.is_set():
end_event.set()
except (KeyboardInterrupt, SystemExit):
end_event.set()
raise
finally:
for thread in threads:
cloudlog.debug(f"athena.joining {thread.name}")
thread.join()
def jsonrpc_handler(end_event: threading.Event) -> None:
dispatcher["startLocalProxy"] = partial(startLocalProxy, end_event)
while not end_event.is_set():
try:
data = recv_queue.get(timeout=1)
if "method" in data:
cloudlog.event("athena.jsonrpc_handler.call_method", data=data)
response = JSONRPCResponseManager.handle(data, dispatcher)
send_queue.put_nowait(response.json)
elif "id" in data and ("result" in data or "error" in data):
log_recv_queue.put_nowait(data)
else:
raise Exception("not a valid request or response")
except queue.Empty:
pass
except Exception as e:
cloudlog.exception("athena jsonrpc handler failed")
send_queue.put_nowait(json.dumps({"error": str(e)}))
def retry_upload(tid: int, end_event: threading.Event, increase_count: bool = True) -> None:
item = cur_upload_items[tid]
if item is not None and item.retry_count < MAX_RETRY_COUNT:
new_retry_count = item.retry_count + 1 if increase_count else item.retry_count
item = replace(
item,
retry_count=new_retry_count,
progress=0,
current=False
)
upload_queue.put_nowait(item)
UploadQueueCache.cache(upload_queue)
cur_upload_items[tid] = None
for _ in range(RETRY_DELAY):
time.sleep(1)
if end_event.is_set():
break
def cb(sm, item, tid, end_event: threading.Event, sz: int, cur: int) -> None:
# Abort transfer if connection changed to metered after starting upload
# or if athenad is shutting down to re-connect the websocket
if not item.allow_cellular:
if (time.monotonic() - sm.recv_time['deviceState']) > DEVICE_STATE_UPDATE_INTERVAL:
sm.update(0)
if sm['deviceState'].networkMetered:
raise AbortTransferException
if end_event.is_set():
raise AbortTransferException
cur_upload_items[tid] = replace(item, progress=cur / sz if sz else 1)
def upload_handler(end_event: threading.Event) -> None:
sm = messaging.SubMaster(['deviceState'])
tid = threading.get_ident()
while not end_event.is_set():
cur_upload_items[tid] = None
try:
cur_upload_items[tid] = item = replace(upload_queue.get(timeout=1), current=True)
if item.id in cancelled_uploads:
cancelled_uploads.remove(item.id)
continue
# Remove item if too old
age = datetime.now() - datetime.fromtimestamp(item.created_at / 1000)
if age.total_seconds() > MAX_AGE:
cloudlog.event("athena.upload_handler.expired", item=item, error=True)
continue
# Check if uploading over metered connection is allowed
sm.update(0)
metered = sm['deviceState'].networkMetered
network_type = sm['deviceState'].networkType.raw
if metered and (not item.allow_cellular):
retry_upload(tid, end_event, False)
continue
try:
fn = item.path
try:
sz = os.path.getsize(fn)
except OSError:
sz = -1
cloudlog.event("athena.upload_handler.upload_start", fn=fn, sz=sz, network_type=network_type, metered=metered, retry_count=item.retry_count)
with _do_upload(item, partial(cb, sm, item, tid, end_event)) as response:
if response.status_code not in (200, 201, 401, 403, 412):
cloudlog.event("athena.upload_handler.retry", status_code=response.status_code, fn=fn, sz=sz, network_type=network_type, metered=metered)
retry_upload(tid, end_event)
else:
cloudlog.event("athena.upload_handler.success", fn=fn, sz=sz, network_type=network_type, metered=metered)
UploadQueueCache.cache(upload_queue)
except (requests.exceptions.Timeout, requests.exceptions.ConnectionError, requests.exceptions.SSLError):
cloudlog.event("athena.upload_handler.timeout", fn=fn, sz=sz, network_type=network_type, metered=metered)
retry_upload(tid, end_event)
except AbortTransferException:
cloudlog.event("athena.upload_handler.abort", fn=fn, sz=sz, network_type=network_type, metered=metered)
retry_upload(tid, end_event, False)
except queue.Empty:
pass
except Exception:
cloudlog.exception("athena.upload_handler.exception")
def _do_upload(upload_item: UploadItem, callback: Callable | None = None) -> requests.Response:
path = upload_item.path
compress = False
  # If the file doesn't exist but does exist without the .zst extension, compress it on the fly
if not os.path.exists(path) and os.path.exists(strip_zst_extension(path)):
path = strip_zst_extension(path)
compress = True
stream = None
try:
stream, content_length = get_upload_stream(path, compress)
response = requests.put(upload_item.url,
data=CallbackReader(stream, callback, content_length) if callback else stream,
headers={**upload_item.headers, 'Content-Length': str(content_length)},
timeout=30)
return response
finally:
if stream:
stream.close()
# security: user should be able to request any message from their car
@dispatcher.add_method
def getMessage(service: str, timeout: int = 1000) -> dict:
if service is None or service not in SERVICE_LIST:
raise Exception("invalid service")
socket = messaging.sub_sock(service, timeout=timeout)
try:
ret = messaging.recv_one(socket)
if ret is None:
raise TimeoutError
# this is because capnp._DynamicStructReader doesn't have typing information
return cast(dict, ret.to_dict())
finally:
del socket
@dispatcher.add_method
def getVersion() -> dict[str, str]:
build_metadata = get_build_metadata()
return {
"version": build_metadata.openpilot.version,
"remote": build_metadata.openpilot.git_normalized_origin,
"branch": build_metadata.channel,
"commit": build_metadata.openpilot.git_commit,
}
def scan_dir(path: str, prefix: str) -> list[str]:
files = []
# only walk directories that match the prefix
# (glob and friends traverse entire dir tree)
with os.scandir(path) as i:
for e in i:
rel_path = os.path.relpath(e.path, Paths.log_root())
if e.is_dir(follow_symlinks=False):
# add trailing slash
rel_path = os.path.join(rel_path, '')
# if prefix is a partial dir name, current dir will start with prefix
        # if prefix is a partial file name, prefix will start with dir name
if rel_path.startswith(prefix) or prefix.startswith(rel_path):
files.extend(scan_dir(e.path, prefix))
else:
if rel_path.startswith(prefix):
files.append(rel_path)
return files
@dispatcher.add_method
def listDataDirectory(prefix='') -> list[str]:
return scan_dir(Paths.log_root(), prefix)
@dispatcher.add_method
def uploadFileToUrl(fn: str, url: str, headers: dict[str, str]) -> UploadFilesToUrlResponse:
# this is because mypy doesn't understand that the decorator doesn't change the return type
response: UploadFilesToUrlResponse = uploadFilesToUrls([{
"fn": fn,
"url": url,
"headers": headers,
}])
return response
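# Example request over the websocket (illustrative values only):
#   {"method": "uploadFilesToUrls", "params": [[{"fn": "2021-03-29--13-32-47--0/qlog.zst",
#     "url": "https://...", "headers": {}, "allow_cellular": false, "priority": 99}]],
#    "jsonrpc": "2.0", "id": 1}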
@dispatcher.add_method
def uploadFilesToUrls(files_data: list[UploadFileDict]) -> UploadFilesToUrlResponse:
files = map(UploadFile.from_dict, files_data)
items: list[UploadItemDict] = []
failed: list[str] = []
for file in files:
if len(file.fn) == 0 or file.fn[0] == '/' or '..' in file.fn or len(file.url) == 0:
failed.append(file.fn)
continue
path = os.path.join(Paths.log_root(), file.fn)
if not os.path.exists(path) and not os.path.exists(strip_zst_extension(path)):
failed.append(file.fn)
continue
# Skip item if already in queue
url = file.url.split('?')[0]
if any(url == item['url'].split('?')[0] for item in listUploadQueue()):
continue
item = UploadItem(
path=path,
url=file.url,
headers=file.headers,
created_at=int(time.time() * 1000),
id=None,
allow_cellular=file.allow_cellular,
priority=file.priority,
)
upload_id = hashlib.sha1(str(item).encode()).hexdigest()
item = replace(item, id=upload_id)
upload_queue.put_nowait(item)
items.append(asdict(item))
UploadQueueCache.cache(upload_queue)
resp: UploadFilesToUrlResponse = {"enqueued": len(items), "items": items}
if failed:
cloudlog.event("athena.uploadFilesToUrls.failed", failed=failed, error=True)
resp["failed"] = failed
return resp
@dispatcher.add_method
def listUploadQueue() -> list[UploadItemDict]:
items = list(upload_queue.queue) + list(cur_upload_items.values())
return [asdict(i) for i in items if (i is not None) and (i.id not in cancelled_uploads)]
@dispatcher.add_method
def cancelUpload(upload_id: str | list[str]) -> dict[str, int | str]:
if not isinstance(upload_id, list):
upload_id = [upload_id]
uploading_ids = {item.id for item in list(upload_queue.queue)}
cancelled_ids = uploading_ids.intersection(upload_id)
if len(cancelled_ids) == 0:
return {"success": 0, "error": "not found"}
cancelled_uploads.update(cancelled_ids)
return {"success": 1}
@dispatcher.add_method
def setRouteViewed(route: str) -> dict[str, int | str]:
# maintain a list of the last 10 routes viewed in connect
params = Params()
r = params.get("AthenadRecentlyViewedRoutes", encoding="utf8")
routes = [] if r is None else r.split(",")
routes.append(route)
# remove duplicates
routes = list(dict.fromkeys(routes))
params.put("AthenadRecentlyViewedRoutes", ",".join(routes[-10:]))
return {"success": 1}
def startLocalProxy(global_end_event: threading.Event, remote_ws_uri: str, local_port: int) -> dict[str, int]:
try:
# migration, can be removed once 0.9.8 is out for a while
if local_port == 8022:
local_port = 22
if local_port not in LOCAL_PORT_WHITELIST:
raise Exception("Requested local port not whitelisted")
cloudlog.debug("athena.startLocalProxy.starting")
dongle_id = Params().get("DongleId").decode('utf8')
identity_token = Api(dongle_id).get_token()
ws = create_connection(remote_ws_uri,
cookie="jwt=" + identity_token,
enable_multithread=True)
# Set TOS to keep connection responsive while under load.
# DSCP of 36/HDD_LINUX_AC_VI with the minimum delay flag
ws.sock.setsockopt(socket.IPPROTO_IP, socket.IP_TOS, 0x90)
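    # The socketpair is only a shutdown signal: ws_proxy_recv closes ssock on exit,
    # which makes csock readable and wakes the select() in ws_proxy_send.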
ssock, csock = socket.socketpair()
local_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
local_sock.connect(('127.0.0.1', local_port))
local_sock.setblocking(False)
proxy_end_event = threading.Event()
threads = [
threading.Thread(target=ws_proxy_recv, args=(ws, local_sock, ssock, proxy_end_event, global_end_event)),
threading.Thread(target=ws_proxy_send, args=(ws, local_sock, csock, proxy_end_event))
]
for thread in threads:
thread.start()
cloudlog.debug("athena.startLocalProxy.started")
return {"success": 1}
except Exception as e:
cloudlog.exception("athenad.startLocalProxy.exception")
raise e
@dispatcher.add_method
def getPublicKey() -> str | None:
if not os.path.isfile(Paths.persist_root() + '/comma/id_rsa.pub'):
return None
with open(Paths.persist_root() + '/comma/id_rsa.pub') as f:
return f.read()
@dispatcher.add_method
def getSshAuthorizedKeys() -> str:
return Params().get("GithubSshKeys", encoding='utf8') or ''
@dispatcher.add_method
def getGithubUsername() -> str:
return Params().get("GithubUsername", encoding='utf8') or ''
@dispatcher.add_method
def getSimInfo():
return HARDWARE.get_sim_info()
@dispatcher.add_method
def getNetworkType():
return HARDWARE.get_network_type()
@dispatcher.add_method
def getNetworkMetered() -> bool:
network_type = HARDWARE.get_network_type()
return HARDWARE.get_network_metered(network_type)
@dispatcher.add_method
def getNetworks():
return HARDWARE.get_networks()
@dispatcher.add_method
def takeSnapshot() -> str | dict[str, str] | None:
from openpilot.system.camerad.snapshot.snapshot import jpeg_write, snapshot
ret = snapshot()
if ret is not None:
def b64jpeg(x):
if x is not None:
f = io.BytesIO()
jpeg_write(f, x)
return base64.b64encode(f.getvalue()).decode("utf-8")
else:
return None
return {'jpegBack': b64jpeg(ret[0]),
'jpegFront': b64jpeg(ret[1])}
else:
raise Exception("not available while camerad is started")
def get_logs_to_send_sorted() -> list[str]:
# TODO: scan once then use inotify to detect file creation/deletion
curr_time = int(time.time())
logs = []
for log_entry in os.listdir(Paths.swaglog_root()):
log_path = os.path.join(Paths.swaglog_root(), log_entry)
time_sent = 0
try:
value = getxattr(log_path, LOG_ATTR_NAME)
if value is not None:
time_sent = int.from_bytes(value, sys.byteorder)
except (ValueError, TypeError):
pass
# assume send failed and we lost the response if sent more than one hour ago
if not time_sent or curr_time - time_sent > 3600:
logs.append(log_entry)
# excluding most recent (active) log file
return sorted(logs)[:-1]
def log_handler(end_event: threading.Event) -> None:
if PC:
return
log_files = []
last_scan = 0.
while not end_event.is_set():
try:
curr_scan = time.monotonic()
if curr_scan - last_scan > 10:
log_files = get_logs_to_send_sorted()
last_scan = curr_scan
# send one log
curr_log = None
if len(log_files) > 0:
log_entry = log_files.pop() # newest log file
cloudlog.debug(f"athena.log_handler.forward_request {log_entry}")
try:
curr_time = int(time.time())
log_path = os.path.join(Paths.swaglog_root(), log_entry)
setxattr(log_path, LOG_ATTR_NAME, int.to_bytes(curr_time, 4, sys.byteorder))
with open(log_path) as f:
jsonrpc = {
"method": "forwardLogs",
"params": {
"logs": f.read()
},
"jsonrpc": "2.0",
"id": log_entry
}
low_priority_send_queue.put_nowait(json.dumps(jsonrpc))
curr_log = log_entry
except OSError:
pass # file could be deleted by log rotation
# wait for response up to ~100 seconds
# always read queue at least once to process any old responses that arrive
for _ in range(100):
if end_event.is_set():
break
try:
log_resp = json.loads(log_recv_queue.get(timeout=1))
log_entry = log_resp.get("id")
log_success = "result" in log_resp and log_resp["result"].get("success")
cloudlog.debug(f"athena.log_handler.forward_response {log_entry} {log_success}")
if log_entry and log_success:
log_path = os.path.join(Paths.swaglog_root(), log_entry)
try:
setxattr(log_path, LOG_ATTR_NAME, LOG_ATTR_VALUE_MAX_UNIX_TIME)
except OSError:
pass # file could be deleted by log rotation
if curr_log == log_entry:
break
except queue.Empty:
if curr_log is None:
break
except Exception:
cloudlog.exception("athena.log_handler.exception")
def stat_handler(end_event: threading.Event) -> None:
STATS_DIR = Paths.stats_root()
last_scan = 0.0
while not end_event.is_set():
curr_scan = time.monotonic()
try:
if curr_scan - last_scan > 10:
stat_filenames = list(filter(lambda name: not name.startswith(tempfile.gettempprefix()), os.listdir(STATS_DIR)))
if len(stat_filenames) > 0:
stat_path = os.path.join(STATS_DIR, stat_filenames[0])
with open(stat_path) as f:
jsonrpc = {
"method": "storeStats",
"params": {
"stats": f.read()
},
"jsonrpc": "2.0",
"id": stat_filenames[0]
}
low_priority_send_queue.put_nowait(json.dumps(jsonrpc))
os.remove(stat_path)
last_scan = curr_scan
except Exception:
cloudlog.exception("athena.stat_handler.exception")
time.sleep(0.1)
def ws_proxy_recv(ws: WebSocket, local_sock: socket.socket, ssock: socket.socket, end_event: threading.Event, global_end_event: threading.Event) -> None:
while not (end_event.is_set() or global_end_event.is_set()):
try:
r = select.select((ws.sock,), (), (), 30)
if r[0]:
data = ws.recv()
if isinstance(data, str):
data = data.encode("utf-8")
local_sock.sendall(data)
except WebSocketTimeoutException:
pass
except Exception:
cloudlog.exception("athenad.ws_proxy_recv.exception")
break
cloudlog.debug("athena.ws_proxy_recv closing sockets")
ssock.close()
local_sock.close()
ws.close()
cloudlog.debug("athena.ws_proxy_recv done closing sockets")
end_event.set()
def ws_proxy_send(ws: WebSocket, local_sock: socket.socket, signal_sock: socket.socket, end_event: threading.Event) -> None:
while not end_event.is_set():
try:
r, _, _ = select.select((local_sock, signal_sock), (), ())
if r:
if r[0].fileno() == signal_sock.fileno():
# got end signal from ws_proxy_recv
end_event.set()
break
data = local_sock.recv(4096)
if not data:
# local_sock is dead
end_event.set()
break
ws.send(data, ABNF.OPCODE_BINARY)
except Exception:
cloudlog.exception("athenad.ws_proxy_send.exception")
end_event.set()
cloudlog.debug("athena.ws_proxy_send closing sockets")
signal_sock.close()
cloudlog.debug("athena.ws_proxy_send done closing sockets")
def ws_recv(ws: WebSocket, end_event: threading.Event) -> None:
last_ping = int(time.monotonic() * 1e9)
while not end_event.is_set():
try:
opcode, data = ws.recv_data(control_frame=True)
if opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
if opcode == ABNF.OPCODE_TEXT:
data = data.decode("utf-8")
recv_queue.put_nowait(data)
elif opcode == ABNF.OPCODE_PING:
last_ping = int(time.monotonic() * 1e9)
Params().put("LastAthenaPingTime", str(last_ping))
except WebSocketTimeoutException:
ns_since_last_ping = int(time.monotonic() * 1e9) - last_ping
if ns_since_last_ping > RECONNECT_TIMEOUT_S * 1e9:
cloudlog.exception("athenad.ws_recv.timeout")
end_event.set()
except Exception:
cloudlog.exception("athenad.ws_recv.exception")
end_event.set()
def ws_send(ws: WebSocket, end_event: threading.Event) -> None:
while not end_event.is_set():
try:
try:
data = send_queue.get_nowait()
except queue.Empty:
data = low_priority_send_queue.get(timeout=1)
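      # fragment large payloads into 4 KiB websocket frames: the first frame is
      # TEXT, the rest are CONT, and fin is only set on the final fragment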
for i in range(0, len(data), WS_FRAME_SIZE):
frame = data[i:i+WS_FRAME_SIZE]
last = i + WS_FRAME_SIZE >= len(data)
opcode = ABNF.OPCODE_TEXT if i == 0 else ABNF.OPCODE_CONT
ws.send_frame(ABNF.create_frame(frame, opcode, last))
except queue.Empty:
pass
except Exception:
cloudlog.exception("athenad.ws_send.exception")
end_event.set()
def ws_manage(ws: WebSocket, end_event: threading.Event) -> None:
params = Params()
onroad_prev = None
sock = ws.sock
while True:
onroad = params.get_bool("IsOnroad")
if onroad != onroad_prev:
onroad_prev = onroad
if sock is not None:
# While not sending data, onroad, we can expect to time out in 7 + (7 * 2) = 21s
# offroad, we can expect to time out in 30 + (10 * 3) = 60s
# FIXME: TCP_USER_TIMEOUT is effectively 2x for some reason (32s), so it's mostly unused
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_USER_TIMEOUT, 16000 if onroad else 0)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 7 if onroad else 30)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 7 if onroad else 10)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 2 if onroad else 3)
if end_event.wait(5):
break
def backoff(retries: int) -> int:
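  # full-jitter exponential backoff: retry n sleeps uniformly in [0, min(2**n, 128)) seconds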
return random.randrange(0, min(128, int(2 ** retries)))
def main(exit_event: threading.Event | None = None):
try:
set_core_affinity([0, 1, 2, 3])
except Exception:
cloudlog.exception("failed to set core affinity")
params = Params()
dongle_id = params.get("DongleId", encoding='utf-8')
UploadQueueCache.initialize(upload_queue)
ws_uri = ATHENA_HOST + "/ws/v2/" + dongle_id
api = Api(dongle_id)
conn_start = None
conn_retries = 0
while exit_event is None or not exit_event.is_set():
try:
if conn_start is None:
conn_start = time.monotonic()
cloudlog.event("athenad.main.connecting_ws", ws_uri=ws_uri, retries=conn_retries)
ws = create_connection(ws_uri,
cookie="jwt=" + api.get_token(),
enable_multithread=True,
timeout=30.0)
cloudlog.event("athenad.main.connected_ws", ws_uri=ws_uri, retries=conn_retries,
duration=time.monotonic() - conn_start)
conn_start = None
conn_retries = 0
cur_upload_items.clear()
handle_long_poll(ws, exit_event)
ws.close()
except (KeyboardInterrupt, SystemExit):
break
except (ConnectionError, TimeoutError, WebSocketException):
conn_retries += 1
params.remove("LastAthenaPingTime")
    except Exception:
      cloudlog.exception("athenad.main.exception")
      conn_retries += 1
      params.remove("LastAthenaPingTime")

    time.sleep(backoff(conn_retries))
if __name__ == "__main__":
main()
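
For reference, a minimal sketch of the JSON-RPC 2.0 framing that jsonrpc_handler expects on recv_queue. Driving the dispatcher directly like this mirrors what the backend sends down the websocket; the method and id values are illustrative:

import json
from jsonrpc import JSONRPCResponseManager
from openpilot.system.athena.athenad import dispatcher

# same dispatch path as jsonrpc_handler: parse the request, invoke the method, serialize the response
req = json.dumps({"method": "echo", "params": ["hello"], "jsonrpc": "2.0", "id": 1})
resp = JSONRPCResponseManager.handle(req, dispatcher)
print(resp.json)  # e.g. {"result": "hello", "id": 1, "jsonrpc": "2.0"}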

43
system/athena/manage_athenad.py Executable file
@@ -0,0 +1,43 @@
#!/usr/bin/env python3
import time
from multiprocessing import Process
from openpilot.common.params import Params
from openpilot.system.manager.process import launcher
from openpilot.common.swaglog import cloudlog
from openpilot.system.hardware import HARDWARE
from openpilot.system.version import get_build_metadata
ATHENA_MGR_PID_PARAM = "AthenadPid"
def main():
params = Params()
dongle_id = params.get("DongleId").decode('utf-8')
build_metadata = get_build_metadata()
cloudlog.bind_global(dongle_id=dongle_id,
version=build_metadata.openpilot.version,
origin=build_metadata.openpilot.git_normalized_origin,
branch=build_metadata.channel,
commit=build_metadata.openpilot.git_commit,
dirty=build_metadata.openpilot.is_dirty,
device=HARDWARE.get_device_type())
try:
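    # simple supervisor: respawn athenad whenever it exits; the 5s sleep avoids a tight crash loop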
while 1:
cloudlog.info("starting athena daemon")
proc = Process(name='athenad', target=launcher, args=('system.athena.athenad', 'athenad'))
proc.start()
proc.join()
cloudlog.event("athenad exited", exitcode=proc.exitcode)
time.sleep(5)
except Exception:
cloudlog.exception("manage_athenad.exception")
finally:
params.remove(ATHENA_MGR_PID_PARAM)
if __name__ == '__main__':
main()

120
system/athena/registration.py Executable file
@@ -0,0 +1,120 @@
#!/usr/bin/env python3
import time
import json
import jwt
from pathlib import Path
from datetime import datetime, timedelta, UTC
from openpilot.common.api import api_get
from openpilot.common.params import Params
from openpilot.common.spinner import Spinner
from openpilot.selfdrive.selfdrived.alertmanager import set_offroad_alert
from openpilot.system.hardware import HARDWARE, PC
from openpilot.system.hardware.hw import Paths
from openpilot.common.swaglog import cloudlog
UNREGISTERED_DONGLE_ID = "UnregisteredDevice"
DUMMY_IMEI1 = '865420071781912'
DUMMY_IMEI2 = '865420071781904'
def is_registered_device() -> bool:
dongle = Params().get("DongleId", encoding='utf-8')
return dongle not in (None, UNREGISTERED_DONGLE_ID)
def register(show_spinner=False) -> str | None:
"""
All devices built since March 2024 come with all
info stored in /persist/. This is kept around
only for devices built before then.
With a backend update to take serial number instead
of dongle ID to some endpoints, this can be removed
entirely.
"""
params = Params()
#return UNREGISTERED_DONGLE_ID # for c3lite, clone
dongle_id: str | None = params.get("DongleId", encoding='utf8')
if dongle_id is None and Path(Paths.persist_root()+"/comma/dongle_id").is_file():
# not all devices will have this; added early in comma 3X production (2/28/24)
with open(Paths.persist_root()+"/comma/dongle_id") as f:
dongle_id = f.read().strip()
pubkey = Path(Paths.persist_root()+"/comma/id_rsa.pub")
if not pubkey.is_file():
dongle_id = UNREGISTERED_DONGLE_ID
cloudlog.warning(f"missing public key: {pubkey}")
elif dongle_id is None:
if show_spinner:
spinner = Spinner()
spinner.update("registering device")
    # Create registration token; in the future, this key will make JWTs directly
with open(Paths.persist_root()+"/comma/id_rsa.pub") as f1, open(Paths.persist_root()+"/comma/id_rsa") as f2:
public_key = f1.read()
private_key = f2.read()
# Block until we get the imei
serial = HARDWARE.get_serial()
start_time = time.monotonic()
imei1: str | None = None
imei2: str | None = None
while imei1 is None and imei2 is None:
try:
imei1, imei2 = HARDWARE.get_imei(0), HARDWARE.get_imei(1)
except Exception:
cloudlog.exception("Error getting imei, trying again...")
time.sleep(1)
if time.monotonic() - start_time > 30 and show_spinner:
spinner.update(f"registering device - serial: {serial}, IMEI: ({imei1}, {imei2})")
imei1 = DUMMY_IMEI1
imei2 = DUMMY_IMEI2
break
backoff = 0
start_time = time.monotonic()
while True:
try:
register_token = jwt.encode({'register': True, 'exp': datetime.now(UTC).replace(tzinfo=None) + timedelta(hours=1)}, private_key, algorithm='RS256')
cloudlog.info("getting pilotauth")
resp = api_get("v2/pilotauth/", method='POST', timeout=15,
imei=imei1, imei2=imei2, serial=serial, public_key=public_key, register_token=register_token)
if resp.status_code in (402, 403):
cloudlog.info(f"Unable to register device, got {resp.status_code}")
dongle_id = UNREGISTERED_DONGLE_ID
else:
dongleauth = json.loads(resp.text)
dongle_id = dongleauth["dongle_id"]
break
except Exception:
cloudlog.exception("failed to authenticate")
backoff = min(backoff + 1, 15)
time.sleep(backoff)
if time.monotonic() - start_time > 14:
cloudlog.error("pilotauth timed out; continuing as UNREGISTERED")
dongle_id = UNREGISTERED_DONGLE_ID
break
if time.monotonic() - start_time > 60 and show_spinner:
spinner.update(f"registering device - serial: {serial}, IMEI: ({imei1}, {imei2})")
if show_spinner:
spinner.close()
if dongle_id:
params.put("DongleId", dongle_id)
#set_offroad_alert("Offroad_UnofficialHardware", (dongle_id == UNREGISTERED_DONGLE_ID) and not PC)
return dongle_id
if __name__ == "__main__":
print(register())
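
For reference, a minimal sketch of creating and verifying the registration token the same way register() does above, assuming PyJWT with the cryptography backend and the /persist key layout described in the docstring:

import jwt
from datetime import datetime, timedelta, UTC
from openpilot.system.hardware.hw import Paths

# sign a short-lived token with the device's persistent private key (as in register())
with open(Paths.persist_root() + "/comma/id_rsa") as f:
  private_key = f.read()
register_token = jwt.encode({'register': True, 'exp': datetime.now(UTC).replace(tzinfo=None) + timedelta(hours=1)},
                            private_key, algorithm='RS256')
# the backend can verify the signature against the matching public key
with open(Paths.persist_root() + "/comma/id_rsa.pub") as f:
  claims = jwt.decode(register_token, f.read(), algorithms=["RS256"])
assert claims["register"] is True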

70
system/athena/tests/helpers.py
@@ -0,0 +1,70 @@
import http.server
import socket
class MockResponse:
def __init__(self, json, status_code):
self.json = json
self.text = json
self.status_code = status_code
class EchoSocket:
def __init__(self, port):
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.bind(('127.0.0.1', port))
self.socket.listen(1)
def run(self):
conn, _ = self.socket.accept()
conn.settimeout(5.0)
try:
while True:
data = conn.recv(4096)
if data:
print(f'EchoSocket got {data}')
conn.sendall(data)
else:
break
finally:
conn.shutdown(0)
conn.close()
self.socket.shutdown(0)
self.socket.close()
class MockApi:
def __init__(self, dongle_id):
pass
def get_token(self):
return "fake-token"
class MockWebsocket:
sock = socket.socket()
def __init__(self, recv_queue, send_queue):
self.recv_queue = recv_queue
self.send_queue = send_queue
def recv(self):
data = self.recv_queue.get()
if isinstance(data, Exception):
raise data
return data
def send(self, data, opcode):
self.send_queue.put_nowait((data, opcode))
def close(self):
pass
class HTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
def do_PUT(self):
length = int(self.headers['Content-Length'])
self.rfile.read(length)
self.send_response(201, "Created")
self.end_headers()

428
system/athena/tests/test_athenad.py
@@ -0,0 +1,428 @@
import pytest
from functools import wraps
import json
import multiprocessing
import os
import requests
import shutil
import time
import threading
import queue
from dataclasses import asdict, replace
from datetime import datetime, timedelta
from websocket import ABNF
from websocket._exceptions import WebSocketConnectionClosedException
from cereal import messaging
from openpilot.common.params import Params
from openpilot.common.timeout import Timeout
from openpilot.system.athena import athenad
from openpilot.system.athena.athenad import MAX_RETRY_COUNT, dispatcher
from openpilot.system.athena.tests.helpers import HTTPRequestHandler, MockWebsocket, MockApi, EchoSocket
from openpilot.selfdrive.test.helpers import http_server_context
from openpilot.system.hardware.hw import Paths
def seed_athena_server(host, port):
with Timeout(2, 'HTTP Server seeding failed'):
while True:
try:
requests.put(f'http://{host}:{port}/qlog.zst', data='', timeout=10)
break
except requests.exceptions.ConnectionError:
time.sleep(0.1)
def with_upload_handler(func):
@wraps(func)
def wrapper(*args, **kwargs):
end_event = threading.Event()
thread = threading.Thread(target=athenad.upload_handler, args=(end_event,))
thread.start()
try:
return func(*args, **kwargs)
finally:
end_event.set()
thread.join()
return wrapper
@pytest.fixture
def mock_create_connection(mocker):
return mocker.patch('openpilot.system.athena.athenad.create_connection')
@pytest.fixture
def host():
with http_server_context(handler=HTTPRequestHandler, setup=seed_athena_server) as (host, port):
yield f"http://{host}:{port}"
class TestAthenadMethods:
@classmethod
def setup_class(cls):
cls.SOCKET_PORT = 45454
athenad.Api = MockApi
athenad.LOCAL_PORT_WHITELIST = {cls.SOCKET_PORT}
def setup_method(self):
self.default_params = {
"DongleId": "0000000000000000",
"GithubSshKeys": b"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQC307aE+nuHzTAgaJhzSf5v7ZZQW9gaperjhCmyPyl4PzY7T1mDGenTlVTN7yoVFZ9UfO9oMQqo0n1OwDIiqbIFxqnhrHU0cYfj88rI85m5BEKlNu5RdaVTj1tcbaPpQc5kZEolaI1nDDjzV0lwS7jo5VYDHseiJHlik3HH1SgtdtsuamGR2T80q1SyW+5rHoMOJG73IH2553NnWuikKiuikGHUYBd00K1ilVAK2xSiMWJp55tQfZ0ecr9QjEsJ+J/efL4HqGNXhffxvypCXvbUYAFSddOwXUPo5BTKevpxMtH+2YrkpSjocWA04VnTYFiPG6U4ItKmbLOTFZtPzoez private", # noqa: E501
"GithubUsername": b"commaci",
"AthenadUploadQueue": '[]',
}
self.params = Params()
for k, v in self.default_params.items():
self.params.put(k, v)
self.params.put_bool("GsmMetered", True)
athenad.upload_queue = queue.Queue()
athenad.cur_upload_items.clear()
athenad.cancelled_uploads.clear()
for i in os.listdir(Paths.log_root()):
p = os.path.join(Paths.log_root(), i)
if os.path.isdir(p):
shutil.rmtree(p)
else:
os.unlink(p)
# *** test helpers ***
@staticmethod
def _wait_for_upload():
now = time.time()
while time.time() - now < 5:
if athenad.upload_queue.qsize() == 0:
break
@staticmethod
  def _create_file(file: str, parent: str | None = None, data: bytes = b'') -> str:
fn = os.path.join(Paths.log_root() if parent is None else parent, file)
os.makedirs(os.path.dirname(fn), exist_ok=True)
with open(fn, 'wb') as f:
f.write(data)
return fn
# *** test cases ***
def test_echo(self):
assert dispatcher["echo"]("bob") == "bob"
def test_get_message(self):
    with pytest.raises(TimeoutError):
dispatcher["getMessage"]("controlsState")
end_event = multiprocessing.Event()
pub_sock = messaging.pub_sock("deviceState")
def send_deviceState():
while not end_event.is_set():
msg = messaging.new_message('deviceState')
pub_sock.send(msg.to_bytes())
time.sleep(0.01)
p = multiprocessing.Process(target=send_deviceState)
p.start()
time.sleep(0.1)
try:
deviceState = dispatcher["getMessage"]("deviceState")
assert deviceState['deviceState']
finally:
end_event.set()
p.join()
def test_list_data_directory(self):
route = '2021-03-29--13-32-47'
segments = [0, 1, 2, 3, 11]
filenames = ['qlog.zst', 'qcamera.ts', 'rlog.zst', 'fcamera.hevc', 'ecamera.hevc', 'dcamera.hevc']
files = [f'{route}--{s}/{f}' for s in segments for f in filenames]
for file in files:
self._create_file(file)
resp = dispatcher["listDataDirectory"]()
assert resp, 'list empty!'
assert len(resp) == len(files)
resp = dispatcher["listDataDirectory"](f'{route}--123')
assert len(resp) == 0
prefix = f'{route}'
expected = list(filter(lambda f: f.startswith(prefix), files))
resp = dispatcher["listDataDirectory"](prefix)
assert resp, 'list empty!'
assert len(resp) == len(expected)
prefix = f'{route}--1'
expected = list(filter(lambda f: f.startswith(prefix), files))
resp = dispatcher["listDataDirectory"](prefix)
assert resp, 'list empty!'
assert len(resp) == len(expected)
prefix = f'{route}--1/'
expected = list(filter(lambda f: f.startswith(prefix), files))
resp = dispatcher["listDataDirectory"](prefix)
assert resp, 'list empty!'
assert len(resp) == len(expected)
prefix = f'{route}--1/q'
expected = list(filter(lambda f: f.startswith(prefix), files))
resp = dispatcher["listDataDirectory"](prefix)
assert resp, 'list empty!'
assert len(resp) == len(expected)
def test_strip_extension(self):
# any requested log file with an invalid extension won't return as existing
fn = self._create_file('qlog.bz2')
if fn.endswith('.bz2'):
assert athenad.strip_zst_extension(fn) == fn
fn = self._create_file('qlog.zst')
if fn.endswith('.zst'):
assert athenad.strip_zst_extension(fn) == fn[:-4]
@pytest.mark.parametrize("compress", [True, False])
def test_do_upload(self, host, compress):
    # random bytes to ensure the object is still fairly large after compression
fn = self._create_file('qlog', data=os.urandom(10000 * 1024))
upload_fn = fn + ('.zst' if compress else '')
item = athenad.UploadItem(path=upload_fn, url="http://localhost:1238", headers={}, created_at=int(time.time()*1000), id='')
with pytest.raises(requests.exceptions.ConnectionError):
athenad._do_upload(item)
item = athenad.UploadItem(path=upload_fn, url=f"{host}/qlog.zst", headers={}, created_at=int(time.time()*1000), id='')
resp = athenad._do_upload(item)
assert resp.status_code == 201
def test_upload_file_to_url(self, host):
fn = self._create_file('qlog.zst')
resp = dispatcher["uploadFileToUrl"]("qlog.zst", f"{host}/qlog.zst", {})
assert resp['enqueued'] == 1
assert 'failed' not in resp
assert {"path": fn, "url": f"{host}/qlog.zst", "headers": {}}.items() <= resp['items'][0].items()
assert resp['items'][0].get('id') is not None
assert athenad.upload_queue.qsize() == 1
def test_upload_file_to_url_duplicate(self, host):
self._create_file('qlog.zst')
url1 = f"{host}/qlog.zst?sig=sig1"
dispatcher["uploadFileToUrl"]("qlog.zst", url1, {})
# Upload same file again, but with different signature
url2 = f"{host}/qlog.zst?sig=sig2"
resp = dispatcher["uploadFileToUrl"]("qlog.zst", url2, {})
assert resp == {'enqueued': 0, 'items': []}
def test_upload_file_to_url_does_not_exist(self, host):
not_exists_resp = dispatcher["uploadFileToUrl"]("does_not_exist.zst", "http://localhost:1238", {})
assert not_exists_resp == {'enqueued': 0, 'items': [], 'failed': ['does_not_exist.zst']}
@with_upload_handler
def test_upload_handler(self, host):
fn = self._create_file('qlog.zst')
item = athenad.UploadItem(path=fn, url=f"{host}/qlog.zst", headers={}, created_at=int(time.time()*1000), id='', allow_cellular=True)
athenad.upload_queue.put_nowait(item)
self._wait_for_upload()
time.sleep(0.1)
# TODO: verify that upload actually succeeded
# TODO: also check that end_event and metered network raises AbortTransferException
assert athenad.upload_queue.qsize() == 0
@pytest.mark.parametrize("status,retry", [(500,True), (412,False)])
@with_upload_handler
def test_upload_handler_retry(self, mocker, host, status, retry):
mock_put = mocker.patch('requests.put')
mock_put.return_value.__enter__.return_value.status_code = status
fn = self._create_file('qlog.zst')
item = athenad.UploadItem(path=fn, url=f"{host}/qlog.zst", headers={}, created_at=int(time.time()*1000), id='', allow_cellular=True)
athenad.upload_queue.put_nowait(item)
self._wait_for_upload()
time.sleep(0.1)
assert athenad.upload_queue.qsize() == (1 if retry else 0)
if retry:
assert athenad.upload_queue.get().retry_count == 1
@with_upload_handler
def test_upload_handler_timeout(self):
"""When an upload times out or fails to connect it should be placed back in the queue"""
fn = self._create_file('qlog.zst')
item = athenad.UploadItem(path=fn, url="http://localhost:44444/qlog.zst", headers={}, created_at=int(time.time()*1000), id='', allow_cellular=True)
item_no_retry = replace(item, retry_count=MAX_RETRY_COUNT)
athenad.upload_queue.put_nowait(item_no_retry)
self._wait_for_upload()
time.sleep(0.1)
# Check that upload with retry count exceeded is not put back
assert athenad.upload_queue.qsize() == 0
athenad.upload_queue.put_nowait(item)
self._wait_for_upload()
time.sleep(0.1)
# Check that upload item was put back in the queue with incremented retry count
assert athenad.upload_queue.qsize() == 1
assert athenad.upload_queue.get().retry_count == 1
@with_upload_handler
def test_cancel_upload(self):
item = athenad.UploadItem(path="qlog.zst", url="http://localhost:44444/qlog.zst", headers={},
created_at=int(time.time()*1000), id='id', allow_cellular=True)
athenad.upload_queue.put_nowait(item)
dispatcher["cancelUpload"](item.id)
assert item.id in athenad.cancelled_uploads
self._wait_for_upload()
time.sleep(0.1)
assert athenad.upload_queue.qsize() == 0
assert len(athenad.cancelled_uploads) == 0
@with_upload_handler
def test_cancel_expiry(self):
    t_past = datetime.now() - timedelta(days=40)
    ts = int(t_past.strftime("%s")) * 1000
# Item that would time out if actually uploaded
fn = self._create_file('qlog.zst')
item = athenad.UploadItem(path=fn, url="http://localhost:44444/qlog.zst", headers={}, created_at=ts, id='', allow_cellular=True)
athenad.upload_queue.put_nowait(item)
self._wait_for_upload()
time.sleep(0.1)
assert athenad.upload_queue.qsize() == 0
def test_list_upload_queue_empty(self):
items = dispatcher["listUploadQueue"]()
assert len(items) == 0
@with_upload_handler
def test_list_upload_queue_current(self, host: str):
fn = self._create_file('qlog.zst')
item = athenad.UploadItem(path=fn, url=f"{host}/qlog.zst", headers={}, created_at=int(time.time()*1000), id='', allow_cellular=True)
athenad.upload_queue.put_nowait(item)
self._wait_for_upload()
items = dispatcher["listUploadQueue"]()
assert len(items) == 1
assert items[0]['current']
def test_list_upload_queue(self):
item = athenad.UploadItem(path="qlog.zst", url="http://localhost:44444/qlog.zst", headers={},
created_at=int(time.time()*1000), id='id', allow_cellular=True)
athenad.upload_queue.put_nowait(item)
items = dispatcher["listUploadQueue"]()
assert len(items) == 1
assert items[0] == asdict(item)
assert not items[0]['current']
athenad.cancelled_uploads.add(item.id)
items = dispatcher["listUploadQueue"]()
assert len(items) == 0
def test_upload_queue_persistence(self):
item1 = athenad.UploadItem(path="_", url="_", headers={}, created_at=int(time.time()), id='id1')
item2 = athenad.UploadItem(path="_", url="_", headers={}, created_at=int(time.time()), id='id2')
athenad.upload_queue.put_nowait(item1)
athenad.upload_queue.put_nowait(item2)
# Ensure canceled items are not persisted
athenad.cancelled_uploads.add(item2.id)
# serialize item
athenad.UploadQueueCache.cache(athenad.upload_queue)
# deserialize item
athenad.upload_queue.queue.clear()
athenad.UploadQueueCache.initialize(athenad.upload_queue)
assert athenad.upload_queue.qsize() == 1
assert asdict(athenad.upload_queue.queue[-1]) == asdict(item1)
def test_start_local_proxy(self, mock_create_connection):
end_event = threading.Event()
ws_recv = queue.Queue()
ws_send = queue.Queue()
mock_ws = MockWebsocket(ws_recv, ws_send)
mock_create_connection.return_value = mock_ws
echo_socket = EchoSocket(self.SOCKET_PORT)
socket_thread = threading.Thread(target=echo_socket.run)
socket_thread.start()
athenad.startLocalProxy(end_event, 'ws://localhost:1234', self.SOCKET_PORT)
ws_recv.put_nowait(b'ping')
try:
recv = ws_send.get(timeout=5)
assert recv == (b'ping', ABNF.OPCODE_BINARY), recv
finally:
# signal websocket close to athenad.ws_proxy_recv
ws_recv.put_nowait(WebSocketConnectionClosedException())
socket_thread.join()
def test_get_ssh_authorized_keys(self):
keys = dispatcher["getSshAuthorizedKeys"]()
assert keys == self.default_params["GithubSshKeys"].decode('utf-8')
def test_get_github_username(self):
keys = dispatcher["getGithubUsername"]()
assert keys == self.default_params["GithubUsername"].decode('utf-8')
def test_get_version(self):
resp = dispatcher["getVersion"]()
keys = ["version", "remote", "branch", "commit"]
assert list(resp.keys()) == keys
for k in keys:
assert isinstance(resp[k], str), f"{k} is not a string"
assert len(resp[k]) > 0, f"{k} has no value"
def test_jsonrpc_handler(self):
end_event = threading.Event()
thread = threading.Thread(target=athenad.jsonrpc_handler, args=(end_event,))
thread.daemon = True
thread.start()
try:
# with params
athenad.recv_queue.put_nowait(json.dumps({"method": "echo", "params": ["hello"], "jsonrpc": "2.0", "id": 0}))
resp = athenad.send_queue.get(timeout=3)
assert json.loads(resp) == {'result': 'hello', 'id': 0, 'jsonrpc': '2.0'}
# without params
athenad.recv_queue.put_nowait(json.dumps({"method": "getNetworkType", "jsonrpc": "2.0", "id": 0}))
resp = athenad.send_queue.get(timeout=3)
assert json.loads(resp) == {'result': 1, 'id': 0, 'jsonrpc': '2.0'}
# log forwarding
athenad.recv_queue.put_nowait(json.dumps({'result': {'success': 1}, 'id': 0, 'jsonrpc': '2.0'}))
resp = athenad.log_recv_queue.get(timeout=3)
assert json.loads(resp) == {'result': {'success': 1}, 'id': 0, 'jsonrpc': '2.0'}
finally:
end_event.set()
thread.join()
def test_get_logs_to_send_sorted(self):
fl = list()
for i in range(10):
file = f'swaglog.{i:010}'
self._create_file(file, Paths.swaglog_root())
fl.append(file)
# ensure the list is all logs except most recent
sl = athenad.get_logs_to_send_sorted()
assert sl == fl[:-1]

102
system/athena/tests/test_athenad_ping.py
@@ -0,0 +1,102 @@
import pytest
import subprocess
import threading
import time
from typing import cast
from openpilot.common.params import Params
from openpilot.common.timeout import Timeout
from openpilot.system.athena import athenad
from openpilot.system.manager.helpers import write_onroad_params
from openpilot.system.hardware import TICI
TIMEOUT_TOLERANCE = 20 # seconds
def wifi_radio(on: bool) -> None:
if not TICI:
return
print(f"wifi {'on' if on else 'off'}")
subprocess.run(["nmcli", "radio", "wifi", "on" if on else "off"], check=True)
class TestAthenadPing:
params: Params
dongle_id: str
athenad: threading.Thread
exit_event: threading.Event
def _get_ping_time(self) -> str | None:
return cast(str | None, self.params.get("LastAthenaPingTime", encoding="utf-8"))
def _clear_ping_time(self) -> None:
self.params.remove("LastAthenaPingTime")
def _received_ping(self) -> bool:
return self._get_ping_time() is not None
@classmethod
def teardown_class(cls) -> None:
wifi_radio(True)
def setup_method(self) -> None:
self.params = Params()
self.dongle_id = self.params.get("DongleId", encoding="utf-8")
wifi_radio(True)
self._clear_ping_time()
self.exit_event = threading.Event()
self.athenad = threading.Thread(target=athenad.main, args=(self.exit_event,))
def teardown_method(self) -> None:
if self.athenad.is_alive():
self.exit_event.set()
self.athenad.join()
def assertTimeout(self, reconnect_time: float, subtests, mocker) -> None:
self.athenad.start()
mock_create_connection = mocker.patch('openpilot.system.athena.athenad.create_connection',
new_callable=lambda: mocker.MagicMock(wraps=athenad.create_connection))
time.sleep(1)
mock_create_connection.assert_called_once()
mock_create_connection.reset_mock()
# check normal behavior, server pings on connection
with subtests.test("Wi-Fi: receives ping"), Timeout(70, "no ping received"):
while not self._received_ping():
time.sleep(0.1)
print("ping received")
mock_create_connection.assert_not_called()
# websocket should attempt reconnect after short time
with subtests.test("LTE: attempt reconnect"):
wifi_radio(False)
print("waiting for reconnect attempt")
start_time = time.monotonic()
with Timeout(reconnect_time, "no reconnect attempt"):
while not mock_create_connection.called:
time.sleep(0.1)
print(f"reconnect attempt after {time.monotonic() - start_time:.2f}s")
self._clear_ping_time()
# check ping received after reconnect
with subtests.test("LTE: receives ping"), Timeout(70, "no ping received"):
while not self._received_ping():
time.sleep(0.1)
print("ping received")
@pytest.mark.skipif(not TICI, reason="only run on desk")
def test_offroad(self, subtests, mocker) -> None:
write_onroad_params(False, self.params)
    self.assertTimeout(60 + TIMEOUT_TOLERANCE, subtests, mocker)  # based on the offroad TCP keepalive settings (see ws_manage)
@pytest.mark.skipif(not TICI, reason="only run on desk")
def test_onroad(self, subtests, mocker) -> None:
write_onroad_params(True, self.params)
self.assertTimeout(21 + TIMEOUT_TOLERANCE, subtests, mocker)

76
system/athena/tests/test_registration.py
@@ -0,0 +1,76 @@
import json
from Crypto.PublicKey import RSA
from pathlib import Path
from openpilot.common.params import Params
from openpilot.system.athena.registration import register, UNREGISTERED_DONGLE_ID
from openpilot.system.athena.tests.helpers import MockResponse
from openpilot.system.hardware.hw import Paths
class TestRegistration:
def setup_method(self):
# clear params and setup key paths
self.params = Params()
persist_dir = Path(Paths.persist_root()) / "comma"
persist_dir.mkdir(parents=True, exist_ok=True)
self.priv_key = persist_dir / "id_rsa"
self.pub_key = persist_dir / "id_rsa.pub"
self.dongle_id = persist_dir / "dongle_id"
def _generate_keys(self):
self.pub_key.touch()
k = RSA.generate(2048)
with open(self.priv_key, "wb") as f:
f.write(k.export_key())
with open(self.pub_key, "wb") as f:
f.write(k.publickey().export_key())
def test_valid_cache(self, mocker):
# if all params are written, return the cached dongle id.
# should work with a dongle ID on either /persist/ or normal params
self._generate_keys()
dongle = "DONGLE_ID_123"
m = mocker.patch("openpilot.system.athena.registration.api_get", autospec=True)
for persist, params in [(True, True), (True, False), (False, True)]:
self.params.put("DongleId", dongle if params else "")
with open(self.dongle_id, "w") as f:
f.write(dongle if persist else "")
assert register() == dongle
assert not m.called
def test_no_keys(self, mocker):
# missing pubkey
m = mocker.patch("openpilot.system.athena.registration.api_get", autospec=True)
dongle = register()
assert m.call_count == 0
assert dongle == UNREGISTERED_DONGLE_ID
assert self.params.get("DongleId", encoding='utf-8') == dongle
def test_missing_cache(self, mocker):
# keys exist but no dongle id
self._generate_keys()
m = mocker.patch("openpilot.system.athena.registration.api_get", autospec=True)
dongle = "DONGLE_ID_123"
m.return_value = MockResponse(json.dumps({'dongle_id': dongle}), 200)
assert register() == dongle
assert m.call_count == 1
# call again, shouldn't hit the API this time
assert register() == dongle
assert m.call_count == 1
assert self.params.get("DongleId", encoding='utf-8') == dongle
def test_unregistered(self, mocker):
# keys exist, but unregistered
self._generate_keys()
m = mocker.patch("openpilot.system.athena.registration.api_get", autospec=True)
m.return_value = MockResponse(None, 402)
dongle = register()
assert m.call_count == 1
assert dongle == UNREGISTERED_DONGLE_ID
assert self.params.get("DongleId", encoding='utf-8') == dongle