Release 260111

0  system/athena/__init__.py  Normal file

842  system/athena/athenad.py  Executable file
@@ -0,0 +1,842 @@
#!/usr/bin/env python3
from __future__ import annotations

import base64
import hashlib
import io
import json
import os
import queue
import random
import select
import socket
import sys
import tempfile
import threading
import time
from dataclasses import asdict, dataclass, replace
from datetime import datetime
from functools import partial, total_ordering
from queue import Queue
from typing import cast
from collections.abc import Callable

import requests
from jsonrpc import JSONRPCResponseManager, dispatcher
from websocket import (ABNF, WebSocket, WebSocketException, WebSocketTimeoutException,
                       create_connection)

import cereal.messaging as messaging
from cereal import log
from cereal.services import SERVICE_LIST
from openpilot.common.api import Api
from openpilot.common.file_helpers import CallbackReader, get_upload_stream
from openpilot.common.params import Params
from openpilot.common.realtime import set_core_affinity
from openpilot.system.hardware import HARDWARE, PC
from openpilot.system.loggerd.xattr_cache import getxattr, setxattr
from openpilot.common.swaglog import cloudlog
from openpilot.system.version import get_build_metadata
from openpilot.system.hardware.hw import Paths


ATHENA_HOST = os.getenv('ATHENA_HOST', 'wss://athena.comma.ai')
HANDLER_THREADS = int(os.getenv('HANDLER_THREADS', "4"))
LOCAL_PORT_WHITELIST = {22, }  # SSH

LOG_ATTR_NAME = 'user.upload'
LOG_ATTR_VALUE_MAX_UNIX_TIME = int.to_bytes(2147483647, 4, sys.byteorder)
RECONNECT_TIMEOUT_S = 70

RETRY_DELAY = 10  # seconds
MAX_RETRY_COUNT = 30  # Try for at most 5 minutes if upload fails immediately
MAX_AGE = 31 * 24 * 3600  # seconds
WS_FRAME_SIZE = 4096
DEVICE_STATE_UPDATE_INTERVAL = 1.0  # in seconds
DEFAULT_UPLOAD_PRIORITY = 99  # higher number = lower priority

NetworkType = log.DeviceState.NetworkType

UploadFileDict = dict[str, str | int | float | bool]
UploadItemDict = dict[str, str | bool | int | float | dict[str, str]]

UploadFilesToUrlResponse = dict[str, int | list[UploadItemDict] | list[str]]


@dataclass
class UploadFile:
  fn: str
  url: str
  headers: dict[str, str]
  allow_cellular: bool
  priority: int = DEFAULT_UPLOAD_PRIORITY

  @classmethod
  def from_dict(cls, d: dict) -> UploadFile:
    return cls(d.get("fn", ""), d.get("url", ""), d.get("headers", {}), d.get("allow_cellular", False), d.get("priority", DEFAULT_UPLOAD_PRIORITY))


@dataclass
@total_ordering
class UploadItem:
  path: str
  url: str
  headers: dict[str, str]
  created_at: int
  id: str | None
  retry_count: int = 0
  current: bool = False
  progress: float = 0
  allow_cellular: bool = False
  priority: int = DEFAULT_UPLOAD_PRIORITY

  @classmethod
  def from_dict(cls, d: dict) -> UploadItem:
    return cls(d["path"], d["url"], d["headers"], d["created_at"], d["id"], d["retry_count"], d["current"],
               d["progress"], d["allow_cellular"], d["priority"])

  def __lt__(self, other):
    if not isinstance(other, UploadItem):
      return NotImplemented
    return self.priority < other.priority

  def __eq__(self, other):
    if not isinstance(other, UploadItem):
      return NotImplemented
    return self.priority == other.priority

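# Note: upload_queue below is a PriorityQueue of UploadItem; with @total_ordering above,
# __lt__/__eq__ compare only `priority`, so lower numbers are popped (and uploaded)
# first, ahead of the DEFAULT_UPLOAD_PRIORITY (99) items.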
dispatcher["echo"] = lambda s: s
|
||||
recv_queue: Queue[str] = queue.Queue()
|
||||
send_queue: Queue[str] = queue.Queue()
|
||||
upload_queue: Queue[UploadItem] = queue.PriorityQueue()
|
||||
low_priority_send_queue: Queue[str] = queue.Queue()
|
||||
log_recv_queue: Queue[str] = queue.Queue()
|
||||
cancelled_uploads: set[str] = set()
|
||||
|
||||
cur_upload_items: dict[int, UploadItem | None] = {}
|
||||
|
||||
|
||||
def strip_zst_extension(fn: str) -> str:
|
||||
if fn.endswith('.zst'):
|
||||
return fn[:-4]
|
||||
return fn
|
||||
|
||||
|
||||
class AbortTransferException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class UploadQueueCache:
|
||||
|
||||
@staticmethod
|
||||
def initialize(upload_queue: Queue[UploadItem]) -> None:
|
||||
try:
|
||||
upload_queue_json = Params().get("AthenadUploadQueue")
|
||||
if upload_queue_json is not None:
|
||||
for item in json.loads(upload_queue_json):
|
||||
upload_queue.put(UploadItem.from_dict(item))
|
||||
except Exception:
|
||||
cloudlog.exception("athena.UploadQueueCache.initialize.exception")
|
||||
|
||||
@staticmethod
|
||||
def cache(upload_queue: Queue[UploadItem]) -> None:
|
||||
try:
|
||||
queue: list[UploadItem | None] = list(upload_queue.queue)
|
||||
items = [asdict(i) for i in queue if i is not None and (i.id not in cancelled_uploads)]
|
||||
Params().put("AthenadUploadQueue", json.dumps(items))
|
||||
except Exception:
|
||||
cloudlog.exception("athena.UploadQueueCache.cache.exception")
|
||||
|
||||
|
||||
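# handle_long_poll fans one websocket out into worker threads: ws_manage/ws_recv/ws_send
# for the socket itself, four upload_handler workers sharing the priority queue, log and
# stat forwarders, and HANDLER_THREADS JSON-RPC workers. They all stop together via the
# shared end_event.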
def handle_long_poll(ws: WebSocket, exit_event: threading.Event | None) -> None:
  end_event = threading.Event()

  threads = [
    threading.Thread(target=ws_manage, args=(ws, end_event), name='ws_manage'),
    threading.Thread(target=ws_recv, args=(ws, end_event), name='ws_recv'),
    threading.Thread(target=ws_send, args=(ws, end_event), name='ws_send'),
    threading.Thread(target=upload_handler, args=(end_event,), name='upload_handler'),
    threading.Thread(target=upload_handler, args=(end_event,), name='upload_handler2'),
    threading.Thread(target=upload_handler, args=(end_event,), name='upload_handler3'),
    threading.Thread(target=upload_handler, args=(end_event,), name='upload_handler4'),
    threading.Thread(target=log_handler, args=(end_event,), name='log_handler'),
    threading.Thread(target=stat_handler, args=(end_event,), name='stat_handler'),
  ] + [
    threading.Thread(target=jsonrpc_handler, args=(end_event,), name=f'worker_{x}')
    for x in range(HANDLER_THREADS)
  ]

  for thread in threads:
    thread.start()
  try:
    while not end_event.wait(0.1):
      if exit_event is not None and exit_event.is_set():
        end_event.set()
  except (KeyboardInterrupt, SystemExit):
    end_event.set()
    raise
  finally:
    for thread in threads:
      cloudlog.debug(f"athena.joining {thread.name}")
      thread.join()


def jsonrpc_handler(end_event: threading.Event) -> None:
  dispatcher["startLocalProxy"] = partial(startLocalProxy, end_event)
  while not end_event.is_set():
    try:
      data = recv_queue.get(timeout=1)
      if "method" in data:
        cloudlog.event("athena.jsonrpc_handler.call_method", data=data)
        response = JSONRPCResponseManager.handle(data, dispatcher)
        send_queue.put_nowait(response.json)
      elif "id" in data and ("result" in data or "error" in data):
        log_recv_queue.put_nowait(data)
      else:
        raise Exception("not a valid request or response")
    except queue.Empty:
      pass
    except Exception as e:
      cloudlog.exception("athena jsonrpc handler failed")
      send_queue.put_nowait(json.dumps({"error": str(e)}))


def retry_upload(tid: int, end_event: threading.Event, increase_count: bool = True) -> None:
  item = cur_upload_items[tid]
  if item is not None and item.retry_count < MAX_RETRY_COUNT:
    new_retry_count = item.retry_count + 1 if increase_count else item.retry_count

    item = replace(
      item,
      retry_count=new_retry_count,
      progress=0,
      current=False
    )
    upload_queue.put_nowait(item)
    UploadQueueCache.cache(upload_queue)

  cur_upload_items[tid] = None

  for _ in range(RETRY_DELAY):
    time.sleep(1)
    if end_event.is_set():
      break


def cb(sm, item, tid, end_event: threading.Event, sz: int, cur: int) -> None:
  # Abort transfer if connection changed to metered after starting upload
  # or if athenad is shutting down to re-connect the websocket
  if not item.allow_cellular:
    if (time.monotonic() - sm.recv_time['deviceState']) > DEVICE_STATE_UPDATE_INTERVAL:
      sm.update(0)
    if sm['deviceState'].networkMetered:
      raise AbortTransferException

  if end_event.is_set():
    raise AbortTransferException

  cur_upload_items[tid] = replace(item, progress=cur / sz if sz else 1)

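# Each upload_handler worker owns at most one in-flight item (tracked per thread id in
# cur_upload_items). Responses 200/201/401/403/412 are treated as terminal, so the item
# is dropped; anything else is requeued via retry_upload with an incremented retry_count,
# up to MAX_RETRY_COUNT.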
def upload_handler(end_event: threading.Event) -> None:
  sm = messaging.SubMaster(['deviceState'])
  tid = threading.get_ident()

  while not end_event.is_set():
    cur_upload_items[tid] = None

    try:
      cur_upload_items[tid] = item = replace(upload_queue.get(timeout=1), current=True)

      if item.id in cancelled_uploads:
        cancelled_uploads.remove(item.id)
        continue

      # Remove item if too old
      age = datetime.now() - datetime.fromtimestamp(item.created_at / 1000)
      if age.total_seconds() > MAX_AGE:
        cloudlog.event("athena.upload_handler.expired", item=item, error=True)
        continue

      # Check if uploading over metered connection is allowed
      sm.update(0)
      metered = sm['deviceState'].networkMetered
      network_type = sm['deviceState'].networkType.raw
      if metered and (not item.allow_cellular):
        retry_upload(tid, end_event, False)
        continue

      try:
        fn = item.path
        try:
          sz = os.path.getsize(fn)
        except OSError:
          sz = -1

        cloudlog.event("athena.upload_handler.upload_start", fn=fn, sz=sz, network_type=network_type, metered=metered, retry_count=item.retry_count)

        with _do_upload(item, partial(cb, sm, item, tid, end_event)) as response:
          if response.status_code not in (200, 201, 401, 403, 412):
            cloudlog.event("athena.upload_handler.retry", status_code=response.status_code, fn=fn, sz=sz, network_type=network_type, metered=metered)
            retry_upload(tid, end_event)
          else:
            cloudlog.event("athena.upload_handler.success", fn=fn, sz=sz, network_type=network_type, metered=metered)

        UploadQueueCache.cache(upload_queue)
      except (requests.exceptions.Timeout, requests.exceptions.ConnectionError, requests.exceptions.SSLError):
        cloudlog.event("athena.upload_handler.timeout", fn=fn, sz=sz, network_type=network_type, metered=metered)
        retry_upload(tid, end_event)
      except AbortTransferException:
        cloudlog.event("athena.upload_handler.abort", fn=fn, sz=sz, network_type=network_type, metered=metered)
        retry_upload(tid, end_event, False)

    except queue.Empty:
      pass
    except Exception:
      cloudlog.exception("athena.upload_handler.exception")


def _do_upload(upload_item: UploadItem, callback: Callable = None) -> requests.Response:
  path = upload_item.path
  compress = False

  # If the file does not exist, but does exist without the .zst extension, we will compress on the fly
  if not os.path.exists(path) and os.path.exists(strip_zst_extension(path)):
    path = strip_zst_extension(path)
    compress = True

  stream = None
  try:
    stream, content_length = get_upload_stream(path, compress)
    response = requests.put(upload_item.url,
                            data=CallbackReader(stream, callback, content_length) if callback else stream,
                            headers={**upload_item.headers, 'Content-Length': str(content_length)},
                            timeout=30)
    return response
  finally:
    if stream:
      stream.close()

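# For reference, requests arrive over the websocket as JSON-RPC 2.0 strings, e.g.
#   {"method": "getMessage", "params": {"service": "deviceState"}, "jsonrpc": "2.0", "id": 0}
# and the matching response is pushed onto send_queue by jsonrpc_handler.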
# security: user should be able to request any message from their car
@dispatcher.add_method
def getMessage(service: str, timeout: int = 1000) -> dict:
  if service is None or service not in SERVICE_LIST:
    raise Exception("invalid service")

  socket = messaging.sub_sock(service, timeout=timeout)
  try:
    ret = messaging.recv_one(socket)

    if ret is None:
      raise TimeoutError

    # this is because capnp._DynamicStructReader doesn't have typing information
    return cast(dict, ret.to_dict())
  finally:
    del socket


@dispatcher.add_method
def getVersion() -> dict[str, str]:
  build_metadata = get_build_metadata()
  return {
    "version": build_metadata.openpilot.version,
    "remote": build_metadata.openpilot.git_normalized_origin,
    "branch": build_metadata.channel,
    "commit": build_metadata.openpilot.git_commit,
  }


def scan_dir(path: str, prefix: str) -> list[str]:
  files = []
  # only walk directories that match the prefix
  # (glob and friends traverse entire dir tree)
  with os.scandir(path) as i:
    for e in i:
      rel_path = os.path.relpath(e.path, Paths.log_root())
      if e.is_dir(follow_symlinks=False):
        # add trailing slash
        rel_path = os.path.join(rel_path, '')
        # if prefix is a partial dir name, current dir will start with prefix
        # if prefix is a partial file name, prefix will start with dir name
        if rel_path.startswith(prefix) or prefix.startswith(rel_path):
          files.extend(scan_dir(e.path, prefix))
      else:
        if rel_path.startswith(prefix):
          files.append(rel_path)
  return files

@dispatcher.add_method
def listDataDirectory(prefix='') -> list[str]:
  return scan_dir(Paths.log_root(), prefix)


@dispatcher.add_method
def uploadFileToUrl(fn: str, url: str, headers: dict[str, str]) -> UploadFilesToUrlResponse:
  # this is because mypy doesn't understand that the decorator doesn't change the return type
  response: UploadFilesToUrlResponse = uploadFilesToUrls([{
    "fn": fn,
    "url": url,
    "headers": headers,
  }])
  return response


@dispatcher.add_method
def uploadFilesToUrls(files_data: list[UploadFileDict]) -> UploadFilesToUrlResponse:
  files = map(UploadFile.from_dict, files_data)

  items: list[UploadItemDict] = []
  failed: list[str] = []
  for file in files:
    if len(file.fn) == 0 or file.fn[0] == '/' or '..' in file.fn or len(file.url) == 0:
      failed.append(file.fn)
      continue

    path = os.path.join(Paths.log_root(), file.fn)
    if not os.path.exists(path) and not os.path.exists(strip_zst_extension(path)):
      failed.append(file.fn)
      continue

    # Skip item if already in queue
    url = file.url.split('?')[0]
    if any(url == item['url'].split('?')[0] for item in listUploadQueue()):
      continue

    item = UploadItem(
      path=path,
      url=file.url,
      headers=file.headers,
      created_at=int(time.time() * 1000),
      id=None,
      allow_cellular=file.allow_cellular,
      priority=file.priority,
    )
    upload_id = hashlib.sha1(str(item).encode()).hexdigest()
    item = replace(item, id=upload_id)
    upload_queue.put_nowait(item)
    items.append(asdict(item))

  UploadQueueCache.cache(upload_queue)

  resp: UploadFilesToUrlResponse = {"enqueued": len(items), "items": items}
  if failed:
    cloudlog.event("athena.uploadFilesToUrls.failed", failed=failed, error=True)
    resp["failed"] = failed

  return resp

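# Deduplication note: uploadFilesToUrls compares URLs with the query string stripped,
# so re-requesting the same file with a fresh presigned signature won't enqueue a
# duplicate while the original item is still queued or in flight.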
@dispatcher.add_method
def listUploadQueue() -> list[UploadItemDict]:
  items = list(upload_queue.queue) + list(cur_upload_items.values())
  return [asdict(i) for i in items if (i is not None) and (i.id not in cancelled_uploads)]


@dispatcher.add_method
def cancelUpload(upload_id: str | list[str]) -> dict[str, int | str]:
  if not isinstance(upload_id, list):
    upload_id = [upload_id]

  uploading_ids = {item.id for item in list(upload_queue.queue)}
  cancelled_ids = uploading_ids.intersection(upload_id)
  if len(cancelled_ids) == 0:
    return {"success": 0, "error": "not found"}

  cancelled_uploads.update(cancelled_ids)
  return {"success": 1}

@dispatcher.add_method
def setRouteViewed(route: str) -> dict[str, int | str]:
  # maintain a list of the last 10 routes viewed in connect
  params = Params()

  r = params.get("AthenadRecentlyViewedRoutes", encoding="utf8")
  routes = [] if r is None else r.split(",")
  routes.append(route)

  # remove duplicates
  routes = list(dict.fromkeys(routes))

  params.put("AthenadRecentlyViewedRoutes", ",".join(routes[-10:]))
  return {"success": 1}


def startLocalProxy(global_end_event: threading.Event, remote_ws_uri: str, local_port: int) -> dict[str, int]:
  try:
    # migration, can be removed once 0.9.8 is out for a while
    if local_port == 8022:
      local_port = 22

    if local_port not in LOCAL_PORT_WHITELIST:
      raise Exception("Requested local port not whitelisted")

    cloudlog.debug("athena.startLocalProxy.starting")

    dongle_id = Params().get("DongleId").decode('utf8')
    identity_token = Api(dongle_id).get_token()
    ws = create_connection(remote_ws_uri,
                           cookie="jwt=" + identity_token,
                           enable_multithread=True)

    # Set TOS to keep connection responsive while under load.
    # DSCP of 36/HDD_LINUX_AC_VI with the minimum delay flag
    ws.sock.setsockopt(socket.IPPROTO_IP, socket.IP_TOS, 0x90)

    ssock, csock = socket.socketpair()
    local_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    local_sock.connect(('127.0.0.1', local_port))
    local_sock.setblocking(False)

    proxy_end_event = threading.Event()
    threads = [
      threading.Thread(target=ws_proxy_recv, args=(ws, local_sock, ssock, proxy_end_event, global_end_event)),
      threading.Thread(target=ws_proxy_send, args=(ws, local_sock, csock, proxy_end_event))
    ]
    for thread in threads:
      thread.start()

    cloudlog.debug("athena.startLocalProxy.started")
    return {"success": 1}
  except Exception as e:
    cloudlog.exception("athenad.startLocalProxy.exception")
    raise e


@dispatcher.add_method
def getPublicKey() -> str | None:
  if not os.path.isfile(Paths.persist_root() + '/comma/id_rsa.pub'):
    return None

  with open(Paths.persist_root() + '/comma/id_rsa.pub') as f:
    return f.read()


@dispatcher.add_method
def getSshAuthorizedKeys() -> str:
  return Params().get("GithubSshKeys", encoding='utf8') or ''


@dispatcher.add_method
def getGithubUsername() -> str:
  return Params().get("GithubUsername", encoding='utf8') or ''

@dispatcher.add_method
def getSimInfo():
  return HARDWARE.get_sim_info()


@dispatcher.add_method
def getNetworkType():
  return HARDWARE.get_network_type()


@dispatcher.add_method
def getNetworkMetered() -> bool:
  network_type = HARDWARE.get_network_type()
  return HARDWARE.get_network_metered(network_type)


@dispatcher.add_method
def getNetworks():
  return HARDWARE.get_networks()


@dispatcher.add_method
def takeSnapshot() -> str | dict[str, str] | None:
  from openpilot.system.camerad.snapshot.snapshot import jpeg_write, snapshot
  ret = snapshot()
  if ret is not None:
    def b64jpeg(x):
      if x is not None:
        f = io.BytesIO()
        jpeg_write(f, x)
        return base64.b64encode(f.getvalue()).decode("utf-8")
      else:
        return None
    return {'jpegBack': b64jpeg(ret[0]),
            'jpegFront': b64jpeg(ret[1])}
  else:
    raise Exception("not available while camerad is started")


def get_logs_to_send_sorted() -> list[str]:
  # TODO: scan once then use inotify to detect file creation/deletion
  curr_time = int(time.time())
  logs = []
  for log_entry in os.listdir(Paths.swaglog_root()):
    log_path = os.path.join(Paths.swaglog_root(), log_entry)
    time_sent = 0
    try:
      value = getxattr(log_path, LOG_ATTR_NAME)
      if value is not None:
        time_sent = int.from_bytes(value, sys.byteorder)
    except (ValueError, TypeError):
      pass
    # assume send failed and we lost the response if sent more than one hour ago
    if not time_sent or curr_time - time_sent > 3600:
      logs.append(log_entry)
  # excluding most recent (active) log file
  return sorted(logs)[:-1]


def log_handler(end_event: threading.Event) -> None:
  if PC:
    return

  log_files = []
  last_scan = 0.
  while not end_event.is_set():
    try:
      curr_scan = time.monotonic()
      if curr_scan - last_scan > 10:
        log_files = get_logs_to_send_sorted()
        last_scan = curr_scan

      # send one log
      curr_log = None
      if len(log_files) > 0:
        log_entry = log_files.pop()  # newest log file
        cloudlog.debug(f"athena.log_handler.forward_request {log_entry}")
        try:
          curr_time = int(time.time())
          log_path = os.path.join(Paths.swaglog_root(), log_entry)
          setxattr(log_path, LOG_ATTR_NAME, int.to_bytes(curr_time, 4, sys.byteorder))
          with open(log_path) as f:
            jsonrpc = {
              "method": "forwardLogs",
              "params": {
                "logs": f.read()
              },
              "jsonrpc": "2.0",
              "id": log_entry
            }
          low_priority_send_queue.put_nowait(json.dumps(jsonrpc))
          curr_log = log_entry
        except OSError:
          pass  # file could be deleted by log rotation

      # wait for response up to ~100 seconds
      # always read queue at least once to process any old responses that arrive
      for _ in range(100):
        if end_event.is_set():
          break
        try:
          log_resp = json.loads(log_recv_queue.get(timeout=1))
          log_entry = log_resp.get("id")
          log_success = "result" in log_resp and log_resp["result"].get("success")
          cloudlog.debug(f"athena.log_handler.forward_response {log_entry} {log_success}")
          if log_entry and log_success:
            log_path = os.path.join(Paths.swaglog_root(), log_entry)
            try:
              setxattr(log_path, LOG_ATTR_NAME, LOG_ATTR_VALUE_MAX_UNIX_TIME)
            except OSError:
              pass  # file could be deleted by log rotation
          if curr_log == log_entry:
            break
        except queue.Empty:
          if curr_log is None:
            break

    except Exception:
      cloudlog.exception("athena.log_handler.exception")


def stat_handler(end_event: threading.Event) -> None:
  STATS_DIR = Paths.stats_root()
  last_scan = 0.0

  while not end_event.is_set():
    curr_scan = time.monotonic()
    try:
      if curr_scan - last_scan > 10:
        stat_filenames = list(filter(lambda name: not name.startswith(tempfile.gettempprefix()), os.listdir(STATS_DIR)))
        if len(stat_filenames) > 0:
          stat_path = os.path.join(STATS_DIR, stat_filenames[0])
          with open(stat_path) as f:
            jsonrpc = {
              "method": "storeStats",
              "params": {
                "stats": f.read()
              },
              "jsonrpc": "2.0",
              "id": stat_filenames[0]
            }
          low_priority_send_queue.put_nowait(json.dumps(jsonrpc))
          os.remove(stat_path)
        last_scan = curr_scan
    except Exception:
      cloudlog.exception("athena.stat_handler.exception")
    time.sleep(0.1)


def ws_proxy_recv(ws: WebSocket, local_sock: socket.socket, ssock: socket.socket, end_event: threading.Event, global_end_event: threading.Event) -> None:
  while not (end_event.is_set() or global_end_event.is_set()):
    try:
      r = select.select((ws.sock,), (), (), 30)
      if r[0]:
        data = ws.recv()
        if isinstance(data, str):
          data = data.encode("utf-8")
        local_sock.sendall(data)
    except WebSocketTimeoutException:
      pass
    except Exception:
      cloudlog.exception("athenad.ws_proxy_recv.exception")
      break

  cloudlog.debug("athena.ws_proxy_recv closing sockets")
  ssock.close()
  local_sock.close()
  ws.close()
  cloudlog.debug("athena.ws_proxy_recv done closing sockets")

  end_event.set()


def ws_proxy_send(ws: WebSocket, local_sock: socket.socket, signal_sock: socket.socket, end_event: threading.Event) -> None:
  while not end_event.is_set():
    try:
      r, _, _ = select.select((local_sock, signal_sock), (), ())
      if r:
        if r[0].fileno() == signal_sock.fileno():
          # got end signal from ws_proxy_recv
          end_event.set()
          break
        data = local_sock.recv(4096)
        if not data:
          # local_sock is dead
          end_event.set()
          break

        ws.send(data, ABNF.OPCODE_BINARY)
    except Exception:
      cloudlog.exception("athenad.ws_proxy_send.exception")
      end_event.set()

  cloudlog.debug("athena.ws_proxy_send closing sockets")
  signal_sock.close()
  cloudlog.debug("athena.ws_proxy_send done closing sockets")

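# ws_recv doubles as the liveness check: every server PING refreshes LastAthenaPingTime,
# and if no data or ping arrives within RECONNECT_TIMEOUT_S (70 s) the handler tears the
# connection down so main() can reconnect.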
def ws_recv(ws: WebSocket, end_event: threading.Event) -> None:
  last_ping = int(time.monotonic() * 1e9)
  while not end_event.is_set():
    try:
      opcode, data = ws.recv_data(control_frame=True)
      if opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
        if opcode == ABNF.OPCODE_TEXT:
          data = data.decode("utf-8")
        recv_queue.put_nowait(data)
      elif opcode == ABNF.OPCODE_PING:
        last_ping = int(time.monotonic() * 1e9)
        Params().put("LastAthenaPingTime", str(last_ping))
    except WebSocketTimeoutException:
      ns_since_last_ping = int(time.monotonic() * 1e9) - last_ping
      if ns_since_last_ping > RECONNECT_TIMEOUT_S * 1e9:
        cloudlog.exception("athenad.ws_recv.timeout")
        end_event.set()
    except Exception:
      cloudlog.exception("athenad.ws_recv.exception")
      end_event.set()


def ws_send(ws: WebSocket, end_event: threading.Event) -> None:
  while not end_event.is_set():
    try:
      try:
        data = send_queue.get_nowait()
      except queue.Empty:
        data = low_priority_send_queue.get(timeout=1)
      for i in range(0, len(data), WS_FRAME_SIZE):
        frame = data[i:i+WS_FRAME_SIZE]
        last = i + WS_FRAME_SIZE >= len(data)
        opcode = ABNF.OPCODE_TEXT if i == 0 else ABNF.OPCODE_CONT
        ws.send_frame(ABNF.create_frame(frame, opcode, last))
    except queue.Empty:
      pass
    except Exception:
      cloudlog.exception("athenad.ws_send.exception")
      end_event.set()


def ws_manage(ws: WebSocket, end_event: threading.Event) -> None:
  params = Params()
  onroad_prev = None
  sock = ws.sock

  while True:
    onroad = params.get_bool("IsOnroad")
    if onroad != onroad_prev:
      onroad_prev = onroad

      if sock is not None:
        # While not sending data, onroad, we can expect to time out in 7 + (7 * 2) = 21s
        # offroad, we can expect to time out in 30 + (10 * 3) = 60s
        # FIXME: TCP_USER_TIMEOUT is effectively 2x for some reason (32s), so it's mostly unused
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_USER_TIMEOUT, 16000 if onroad else 0)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 7 if onroad else 30)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 7 if onroad else 10)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 2 if onroad else 3)

    if end_event.wait(5):
      break

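# backoff() below returns a random wait of up to min(128, 2**retries) seconds,
# e.g. retries=3 gives 0-7 s and anything past retries=7 is capped at 0-127 s,
# so repeated connection failures spread the reconnect attempts out.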
def backoff(retries: int) -> int:
  return random.randrange(0, min(128, int(2 ** retries)))


def main(exit_event: threading.Event = None):
  try:
    set_core_affinity([0, 1, 2, 3])
  except Exception:
    cloudlog.exception("failed to set core affinity")

  params = Params()
  dongle_id = params.get("DongleId", encoding='utf-8')
  UploadQueueCache.initialize(upload_queue)

  ws_uri = ATHENA_HOST + "/ws/v2/" + dongle_id
  api = Api(dongle_id)

  conn_start = None
  conn_retries = 0
  while exit_event is None or not exit_event.is_set():
    try:
      if conn_start is None:
        conn_start = time.monotonic()

      cloudlog.event("athenad.main.connecting_ws", ws_uri=ws_uri, retries=conn_retries)
      ws = create_connection(ws_uri,
                             cookie="jwt=" + api.get_token(),
                             enable_multithread=True,
                             timeout=30.0)
      cloudlog.event("athenad.main.connected_ws", ws_uri=ws_uri, retries=conn_retries,
                     duration=time.monotonic() - conn_start)
      conn_start = None

      conn_retries = 0
      cur_upload_items.clear()

      handle_long_poll(ws, exit_event)

      ws.close()
    except (KeyboardInterrupt, SystemExit):
      break
    except (ConnectionError, TimeoutError, WebSocketException):
      conn_retries += 1
      params.remove("LastAthenaPingTime")
    except Exception:
      cloudlog.exception("athenad.main.exception")

      conn_retries += 1
      params.remove("LastAthenaPingTime")

    time.sleep(backoff(conn_retries))


if __name__ == "__main__":
  main()

43  system/athena/manage_athenad.py  Executable file
@@ -0,0 +1,43 @@
#!/usr/bin/env python3

import time
from multiprocessing import Process

from openpilot.common.params import Params
from openpilot.system.manager.process import launcher
from openpilot.common.swaglog import cloudlog
from openpilot.system.hardware import HARDWARE
from openpilot.system.version import get_build_metadata

ATHENA_MGR_PID_PARAM = "AthenadPid"


def main():
  params = Params()
  dongle_id = params.get("DongleId").decode('utf-8')
  build_metadata = get_build_metadata()

  cloudlog.bind_global(dongle_id=dongle_id,
                       version=build_metadata.openpilot.version,
                       origin=build_metadata.openpilot.git_normalized_origin,
                       branch=build_metadata.channel,
                       commit=build_metadata.openpilot.git_commit,
                       dirty=build_metadata.openpilot.is_dirty,
                       device=HARDWARE.get_device_type())

  try:
    while 1:
      cloudlog.info("starting athena daemon")
      proc = Process(name='athenad', target=launcher, args=('system.athena.athenad', 'athenad'))
      proc.start()
      proc.join()
      cloudlog.event("athenad exited", exitcode=proc.exitcode)
      time.sleep(5)
  except Exception:
    cloudlog.exception("manage_athenad.exception")
  finally:
    params.remove(ATHENA_MGR_PID_PARAM)


if __name__ == '__main__':
  main()

120  system/athena/registration.py  Executable file
@@ -0,0 +1,120 @@
#!/usr/bin/env python3
import time
import json
import jwt
from pathlib import Path

from datetime import datetime, timedelta, UTC
from openpilot.common.api import api_get
from openpilot.common.params import Params
from openpilot.common.spinner import Spinner
from openpilot.selfdrive.selfdrived.alertmanager import set_offroad_alert
from openpilot.system.hardware import HARDWARE, PC
from openpilot.system.hardware.hw import Paths
from openpilot.common.swaglog import cloudlog


UNREGISTERED_DONGLE_ID = "UnregisteredDevice"

DUMMY_IMEI1 = '865420071781912'
DUMMY_IMEI2 = '865420071781904'


def is_registered_device() -> bool:
  dongle = Params().get("DongleId", encoding='utf-8')
  return dongle not in (None, UNREGISTERED_DONGLE_ID)


def register(show_spinner=False) -> str | None:
  """
  All devices built since March 2024 come with all
  info stored in /persist/. This is kept around
  only for devices built before then.

  With a backend update to take serial number instead
  of dongle ID to some endpoints, this can be removed
  entirely.
  """
  params = Params()


  #return UNREGISTERED_DONGLE_ID # for c3lite, clone
  dongle_id: str | None = params.get("DongleId", encoding='utf8')
  if dongle_id is None and Path(Paths.persist_root()+"/comma/dongle_id").is_file():
    # not all devices will have this; added early in comma 3X production (2/28/24)
    with open(Paths.persist_root()+"/comma/dongle_id") as f:
      dongle_id = f.read().strip()

  pubkey = Path(Paths.persist_root()+"/comma/id_rsa.pub")
  if not pubkey.is_file():
    dongle_id = UNREGISTERED_DONGLE_ID
    cloudlog.warning(f"missing public key: {pubkey}")
  elif dongle_id is None:
    if show_spinner:
      spinner = Spinner()
      spinner.update("registering device")

    # Create registration token, in the future, this key will make JWTs directly
    with open(Paths.persist_root()+"/comma/id_rsa.pub") as f1, open(Paths.persist_root()+"/comma/id_rsa") as f2:
      public_key = f1.read()
      private_key = f2.read()

    # Block until we get the imei
    serial = HARDWARE.get_serial()
    start_time = time.monotonic()
    imei1: str | None = None
    imei2: str | None = None

    while imei1 is None and imei2 is None:
      try:
        imei1, imei2 = HARDWARE.get_imei(0), HARDWARE.get_imei(1)
      except Exception:
        cloudlog.exception("Error getting imei, trying again...")
        time.sleep(1)

      if time.monotonic() - start_time > 30 and show_spinner:
        spinner.update(f"registering device - serial: {serial}, IMEI: ({imei1}, {imei2})")
        imei1 = DUMMY_IMEI1
        imei2 = DUMMY_IMEI2
        break

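    # The register_token below is a short-lived RS256 JWT signed with the device's
    # persisted private key, so the backend can check it against the uploaded
    # public key before issuing a dongle_id.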
    backoff = 0
    start_time = time.monotonic()
    while True:
      try:
        register_token = jwt.encode({'register': True, 'exp': datetime.now(UTC).replace(tzinfo=None) + timedelta(hours=1)}, private_key, algorithm='RS256')
        cloudlog.info("getting pilotauth")
        resp = api_get("v2/pilotauth/", method='POST', timeout=15,
                       imei=imei1, imei2=imei2, serial=serial, public_key=public_key, register_token=register_token)

        if resp.status_code in (402, 403):
          cloudlog.info(f"Unable to register device, got {resp.status_code}")
          dongle_id = UNREGISTERED_DONGLE_ID
        else:
          dongleauth = json.loads(resp.text)
          dongle_id = dongleauth["dongle_id"]
        break
      except Exception:
        cloudlog.exception("failed to authenticate")
        backoff = min(backoff + 1, 15)
        time.sleep(backoff)

      if time.monotonic() - start_time > 14:
        cloudlog.error("pilotauth timed out; continuing as UNREGISTERED")
        dongle_id = UNREGISTERED_DONGLE_ID
        break

      if time.monotonic() - start_time > 60 and show_spinner:
        spinner.update(f"registering device - serial: {serial}, IMEI: ({imei1}, {imei2})")

    if show_spinner:
      spinner.close()

  if dongle_id:
    params.put("DongleId", dongle_id)
    #set_offroad_alert("Offroad_UnofficialHardware", (dongle_id == UNREGISTERED_DONGLE_ID) and not PC)
  return dongle_id


if __name__ == "__main__":
  print(register())

0  system/athena/tests/__init__.py  Normal file
70  system/athena/tests/helpers.py  Normal file
@@ -0,0 +1,70 @@
import http.server
import socket


class MockResponse:
  def __init__(self, json, status_code):
    self.json = json
    self.text = json
    self.status_code = status_code


class EchoSocket:
  def __init__(self, port):
    self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    self.socket.bind(('127.0.0.1', port))
    self.socket.listen(1)

  def run(self):
    conn, _ = self.socket.accept()
    conn.settimeout(5.0)

    try:
      while True:
        data = conn.recv(4096)
        if data:
          print(f'EchoSocket got {data}')
          conn.sendall(data)
        else:
          break
    finally:
      conn.shutdown(0)
      conn.close()
      self.socket.shutdown(0)
      self.socket.close()


class MockApi:
  def __init__(self, dongle_id):
    pass

  def get_token(self):
    return "fake-token"

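# MockWebsocket drives the proxy threads from plain queues: recv() returns whatever the
# test enqueued, and enqueueing an Exception instance makes recv() raise it, which is
# how the tests simulate a dropped websocket connection.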
class MockWebsocket:
  sock = socket.socket()

  def __init__(self, recv_queue, send_queue):
    self.recv_queue = recv_queue
    self.send_queue = send_queue

  def recv(self):
    data = self.recv_queue.get()
    if isinstance(data, Exception):
      raise data
    return data

  def send(self, data, opcode):
    self.send_queue.put_nowait((data, opcode))

  def close(self):
    pass


class HTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
  def do_PUT(self):
    length = int(self.headers['Content-Length'])
    self.rfile.read(length)
    self.send_response(201, "Created")
    self.end_headers()

428  system/athena/tests/test_athenad.py  Normal file
@@ -0,0 +1,428 @@
import pytest
from functools import wraps
import json
import multiprocessing
import os
import requests
import shutil
import time
import threading
import queue
from dataclasses import asdict, replace
from datetime import datetime, timedelta

from websocket import ABNF
from websocket._exceptions import WebSocketConnectionClosedException

from cereal import messaging

from openpilot.common.params import Params
from openpilot.common.timeout import Timeout
from openpilot.system.athena import athenad
from openpilot.system.athena.athenad import MAX_RETRY_COUNT, dispatcher
from openpilot.system.athena.tests.helpers import HTTPRequestHandler, MockWebsocket, MockApi, EchoSocket
from openpilot.selfdrive.test.helpers import http_server_context
from openpilot.system.hardware.hw import Paths


def seed_athena_server(host, port):
  with Timeout(2, 'HTTP Server seeding failed'):
    while True:
      try:
        requests.put(f'http://{host}:{port}/qlog.zst', data='', timeout=10)
        break
      except requests.exceptions.ConnectionError:
        time.sleep(0.1)

def with_upload_handler(func):
  @wraps(func)
  def wrapper(*args, **kwargs):
    end_event = threading.Event()
    thread = threading.Thread(target=athenad.upload_handler, args=(end_event,))
    thread.start()
    try:
      return func(*args, **kwargs)
    finally:
      end_event.set()
      thread.join()
  return wrapper

@pytest.fixture
def mock_create_connection(mocker):
  return mocker.patch('openpilot.system.athena.athenad.create_connection')

@pytest.fixture
def host():
  with http_server_context(handler=HTTPRequestHandler, setup=seed_athena_server) as (host, port):
    yield f"http://{host}:{port}"

class TestAthenadMethods:
  @classmethod
  def setup_class(cls):
    cls.SOCKET_PORT = 45454
    athenad.Api = MockApi
    athenad.LOCAL_PORT_WHITELIST = {cls.SOCKET_PORT}

  def setup_method(self):
    self.default_params = {
      "DongleId": "0000000000000000",
      "GithubSshKeys": b"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQC307aE+nuHzTAgaJhzSf5v7ZZQW9gaperjhCmyPyl4PzY7T1mDGenTlVTN7yoVFZ9UfO9oMQqo0n1OwDIiqbIFxqnhrHU0cYfj88rI85m5BEKlNu5RdaVTj1tcbaPpQc5kZEolaI1nDDjzV0lwS7jo5VYDHseiJHlik3HH1SgtdtsuamGR2T80q1SyW+5rHoMOJG73IH2553NnWuikKiuikGHUYBd00K1ilVAK2xSiMWJp55tQfZ0ecr9QjEsJ+J/efL4HqGNXhffxvypCXvbUYAFSddOwXUPo5BTKevpxMtH+2YrkpSjocWA04VnTYFiPG6U4ItKmbLOTFZtPzoez private",  # noqa: E501
      "GithubUsername": b"commaci",
      "AthenadUploadQueue": '[]',
    }

    self.params = Params()
    for k, v in self.default_params.items():
      self.params.put(k, v)
    self.params.put_bool("GsmMetered", True)

    athenad.upload_queue = queue.Queue()
    athenad.cur_upload_items.clear()
    athenad.cancelled_uploads.clear()

    for i in os.listdir(Paths.log_root()):
      p = os.path.join(Paths.log_root(), i)
      if os.path.isdir(p):
        shutil.rmtree(p)
      else:
        os.unlink(p)

  # *** test helpers ***

  @staticmethod
  def _wait_for_upload():
    now = time.time()
    while time.time() - now < 5:
      if athenad.upload_queue.qsize() == 0:
        break

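  # _wait_for_upload busy-waits up to 5 s for the queue to drain; the tests still
  # sleep briefly afterwards so the handler can finish the item it already popped.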
  @staticmethod
  def _create_file(file: str, parent: str = None, data: bytes = b'') -> str:
    fn = os.path.join(Paths.log_root() if parent is None else parent, file)
    os.makedirs(os.path.dirname(fn), exist_ok=True)
    with open(fn, 'wb') as f:
      f.write(data)
    return fn


  # *** test cases ***

  def test_echo(self):
    assert dispatcher["echo"]("bob") == "bob"

  def test_get_message(self):
    with pytest.raises(TimeoutError) as _:
      dispatcher["getMessage"]("controlsState")

    end_event = multiprocessing.Event()

    pub_sock = messaging.pub_sock("deviceState")

    def send_deviceState():
      while not end_event.is_set():
        msg = messaging.new_message('deviceState')
        pub_sock.send(msg.to_bytes())
        time.sleep(0.01)

    p = multiprocessing.Process(target=send_deviceState)
    p.start()
    time.sleep(0.1)
    try:
      deviceState = dispatcher["getMessage"]("deviceState")
      assert deviceState['deviceState']
    finally:
      end_event.set()
      p.join()

  def test_list_data_directory(self):
    route = '2021-03-29--13-32-47'
    segments = [0, 1, 2, 3, 11]

    filenames = ['qlog.zst', 'qcamera.ts', 'rlog.zst', 'fcamera.hevc', 'ecamera.hevc', 'dcamera.hevc']
    files = [f'{route}--{s}/{f}' for s in segments for f in filenames]
    for file in files:
      self._create_file(file)

    resp = dispatcher["listDataDirectory"]()
    assert resp, 'list empty!'
    assert len(resp) == len(files)

    resp = dispatcher["listDataDirectory"](f'{route}--123')
    assert len(resp) == 0

    prefix = f'{route}'
    expected = list(filter(lambda f: f.startswith(prefix), files))
    resp = dispatcher["listDataDirectory"](prefix)
    assert resp, 'list empty!'
    assert len(resp) == len(expected)

    prefix = f'{route}--1'
    expected = list(filter(lambda f: f.startswith(prefix), files))
    resp = dispatcher["listDataDirectory"](prefix)
    assert resp, 'list empty!'
    assert len(resp) == len(expected)

    prefix = f'{route}--1/'
    expected = list(filter(lambda f: f.startswith(prefix), files))
    resp = dispatcher["listDataDirectory"](prefix)
    assert resp, 'list empty!'
    assert len(resp) == len(expected)

    prefix = f'{route}--1/q'
    expected = list(filter(lambda f: f.startswith(prefix), files))
    resp = dispatcher["listDataDirectory"](prefix)
    assert resp, 'list empty!'
    assert len(resp) == len(expected)

  def test_strip_extension(self):
    # any requested log file with an invalid extension won't return as existing
    fn = self._create_file('qlog.bz2')
    if fn.endswith('.bz2'):
      assert athenad.strip_zst_extension(fn) == fn

    fn = self._create_file('qlog.zst')
    if fn.endswith('.zst'):
      assert athenad.strip_zst_extension(fn) == fn[:-4]

  @pytest.mark.parametrize("compress", [True, False])
  def test_do_upload(self, host, compress):
    # random bytes to ensure rather large object post-compression
    fn = self._create_file('qlog', data=os.urandom(10000 * 1024))

    upload_fn = fn + ('.zst' if compress else '')
    item = athenad.UploadItem(path=upload_fn, url="http://localhost:1238", headers={}, created_at=int(time.time()*1000), id='')
    with pytest.raises(requests.exceptions.ConnectionError):
      athenad._do_upload(item)

    item = athenad.UploadItem(path=upload_fn, url=f"{host}/qlog.zst", headers={}, created_at=int(time.time()*1000), id='')
    resp = athenad._do_upload(item)
    assert resp.status_code == 201

  def test_upload_file_to_url(self, host):
    fn = self._create_file('qlog.zst')

    resp = dispatcher["uploadFileToUrl"]("qlog.zst", f"{host}/qlog.zst", {})
    assert resp['enqueued'] == 1
    assert 'failed' not in resp
    assert {"path": fn, "url": f"{host}/qlog.zst", "headers": {}}.items() <= resp['items'][0].items()
    assert resp['items'][0].get('id') is not None
    assert athenad.upload_queue.qsize() == 1

  def test_upload_file_to_url_duplicate(self, host):
    self._create_file('qlog.zst')

    url1 = f"{host}/qlog.zst?sig=sig1"
    dispatcher["uploadFileToUrl"]("qlog.zst", url1, {})

    # Upload same file again, but with different signature
    url2 = f"{host}/qlog.zst?sig=sig2"
    resp = dispatcher["uploadFileToUrl"]("qlog.zst", url2, {})
    assert resp == {'enqueued': 0, 'items': []}

  def test_upload_file_to_url_does_not_exist(self, host):
    not_exists_resp = dispatcher["uploadFileToUrl"]("does_not_exist.zst", "http://localhost:1238", {})
    assert not_exists_resp == {'enqueued': 0, 'items': [], 'failed': ['does_not_exist.zst']}

  @with_upload_handler
  def test_upload_handler(self, host):
    fn = self._create_file('qlog.zst')
    item = athenad.UploadItem(path=fn, url=f"{host}/qlog.zst", headers={}, created_at=int(time.time()*1000), id='', allow_cellular=True)

    athenad.upload_queue.put_nowait(item)
    self._wait_for_upload()
    time.sleep(0.1)

    # TODO: verify that upload actually succeeded
    # TODO: also check that end_event and metered network raises AbortTransferException
    assert athenad.upload_queue.qsize() == 0

  @pytest.mark.parametrize("status,retry", [(500, True), (412, False)])
  @with_upload_handler
  def test_upload_handler_retry(self, mocker, host, status, retry):
    mock_put = mocker.patch('requests.put')
    mock_put.return_value.__enter__.return_value.status_code = status
    fn = self._create_file('qlog.zst')
    item = athenad.UploadItem(path=fn, url=f"{host}/qlog.zst", headers={}, created_at=int(time.time()*1000), id='', allow_cellular=True)

    athenad.upload_queue.put_nowait(item)
    self._wait_for_upload()
    time.sleep(0.1)

    assert athenad.upload_queue.qsize() == (1 if retry else 0)

    if retry:
      assert athenad.upload_queue.get().retry_count == 1

  @with_upload_handler
  def test_upload_handler_timeout(self):
    """When an upload times out or fails to connect it should be placed back in the queue"""
    fn = self._create_file('qlog.zst')
    item = athenad.UploadItem(path=fn, url="http://localhost:44444/qlog.zst", headers={}, created_at=int(time.time()*1000), id='', allow_cellular=True)
    item_no_retry = replace(item, retry_count=MAX_RETRY_COUNT)

    athenad.upload_queue.put_nowait(item_no_retry)
    self._wait_for_upload()
    time.sleep(0.1)

    # Check that upload with retry count exceeded is not put back
    assert athenad.upload_queue.qsize() == 0

    athenad.upload_queue.put_nowait(item)
    self._wait_for_upload()
    time.sleep(0.1)

    # Check that upload item was put back in the queue with incremented retry count
    assert athenad.upload_queue.qsize() == 1
    assert athenad.upload_queue.get().retry_count == 1

  @with_upload_handler
  def test_cancel_upload(self):
    item = athenad.UploadItem(path="qlog.zst", url="http://localhost:44444/qlog.zst", headers={},
                              created_at=int(time.time()*1000), id='id', allow_cellular=True)
    athenad.upload_queue.put_nowait(item)
    dispatcher["cancelUpload"](item.id)

    assert item.id in athenad.cancelled_uploads

    self._wait_for_upload()
    time.sleep(0.1)

    assert athenad.upload_queue.qsize() == 0
    assert len(athenad.cancelled_uploads) == 0

  @with_upload_handler
  def test_cancel_expiry(self):
    t_future = datetime.now() - timedelta(days=40)
    ts = int(t_future.strftime("%s")) * 1000

    # Item that would time out if actually uploaded
    fn = self._create_file('qlog.zst')
    item = athenad.UploadItem(path=fn, url="http://localhost:44444/qlog.zst", headers={}, created_at=ts, id='', allow_cellular=True)

    athenad.upload_queue.put_nowait(item)
    self._wait_for_upload()
    time.sleep(0.1)

    assert athenad.upload_queue.qsize() == 0

  def test_list_upload_queue_empty(self):
    items = dispatcher["listUploadQueue"]()
    assert len(items) == 0

  @with_upload_handler
  def test_list_upload_queue_current(self, host: str):
    fn = self._create_file('qlog.zst')
    item = athenad.UploadItem(path=fn, url=f"{host}/qlog.zst", headers={}, created_at=int(time.time()*1000), id='', allow_cellular=True)

    athenad.upload_queue.put_nowait(item)
    self._wait_for_upload()

    items = dispatcher["listUploadQueue"]()
    assert len(items) == 1
    assert items[0]['current']

  def test_list_upload_queue(self):
    item = athenad.UploadItem(path="qlog.zst", url="http://localhost:44444/qlog.zst", headers={},
                              created_at=int(time.time()*1000), id='id', allow_cellular=True)
    athenad.upload_queue.put_nowait(item)

    items = dispatcher["listUploadQueue"]()
    assert len(items) == 1
    assert items[0] == asdict(item)
    assert not items[0]['current']

    athenad.cancelled_uploads.add(item.id)
    items = dispatcher["listUploadQueue"]()
    assert len(items) == 0

  def test_upload_queue_persistence(self):
    item1 = athenad.UploadItem(path="_", url="_", headers={}, created_at=int(time.time()), id='id1')
    item2 = athenad.UploadItem(path="_", url="_", headers={}, created_at=int(time.time()), id='id2')

    athenad.upload_queue.put_nowait(item1)
    athenad.upload_queue.put_nowait(item2)

    # Ensure canceled items are not persisted
    athenad.cancelled_uploads.add(item2.id)

    # serialize item
    athenad.UploadQueueCache.cache(athenad.upload_queue)

    # deserialize item
    athenad.upload_queue.queue.clear()
    athenad.UploadQueueCache.initialize(athenad.upload_queue)

    assert athenad.upload_queue.qsize() == 1
    assert asdict(athenad.upload_queue.queue[-1]) == asdict(item1)

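  # test_start_local_proxy wires MockWebsocket to a local EchoSocket, exercising both
  # ws_proxy_recv and ws_proxy_send end to end without a real remote endpoint.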
  def test_start_local_proxy(self, mock_create_connection):
    end_event = threading.Event()

    ws_recv = queue.Queue()
    ws_send = queue.Queue()
    mock_ws = MockWebsocket(ws_recv, ws_send)
    mock_create_connection.return_value = mock_ws

    echo_socket = EchoSocket(self.SOCKET_PORT)
    socket_thread = threading.Thread(target=echo_socket.run)
    socket_thread.start()

    athenad.startLocalProxy(end_event, 'ws://localhost:1234', self.SOCKET_PORT)

    ws_recv.put_nowait(b'ping')
    try:
      recv = ws_send.get(timeout=5)
      assert recv == (b'ping', ABNF.OPCODE_BINARY), recv
    finally:
      # signal websocket close to athenad.ws_proxy_recv
      ws_recv.put_nowait(WebSocketConnectionClosedException())
      socket_thread.join()

  def test_get_ssh_authorized_keys(self):
    keys = dispatcher["getSshAuthorizedKeys"]()
    assert keys == self.default_params["GithubSshKeys"].decode('utf-8')

  def test_get_github_username(self):
    keys = dispatcher["getGithubUsername"]()
    assert keys == self.default_params["GithubUsername"].decode('utf-8')

  def test_get_version(self):
    resp = dispatcher["getVersion"]()
    keys = ["version", "remote", "branch", "commit"]
    assert list(resp.keys()) == keys
    for k in keys:
      assert isinstance(resp[k], str), f"{k} is not a string"
      assert len(resp[k]) > 0, f"{k} has no value"

  def test_jsonrpc_handler(self):
    end_event = threading.Event()
    thread = threading.Thread(target=athenad.jsonrpc_handler, args=(end_event,))
    thread.daemon = True
    thread.start()
    try:
      # with params
      athenad.recv_queue.put_nowait(json.dumps({"method": "echo", "params": ["hello"], "jsonrpc": "2.0", "id": 0}))
      resp = athenad.send_queue.get(timeout=3)
      assert json.loads(resp) == {'result': 'hello', 'id': 0, 'jsonrpc': '2.0'}
      # without params
      athenad.recv_queue.put_nowait(json.dumps({"method": "getNetworkType", "jsonrpc": "2.0", "id": 0}))
      resp = athenad.send_queue.get(timeout=3)
      assert json.loads(resp) == {'result': 1, 'id': 0, 'jsonrpc': '2.0'}
      # log forwarding
      athenad.recv_queue.put_nowait(json.dumps({'result': {'success': 1}, 'id': 0, 'jsonrpc': '2.0'}))
      resp = athenad.log_recv_queue.get(timeout=3)
      assert json.loads(resp) == {'result': {'success': 1}, 'id': 0, 'jsonrpc': '2.0'}
    finally:
      end_event.set()
      thread.join()

  def test_get_logs_to_send_sorted(self):
    fl = list()
    for i in range(10):
      file = f'swaglog.{i:010}'
      self._create_file(file, Paths.swaglog_root())
      fl.append(file)

    # ensure the list is all logs except most recent
    sl = athenad.get_logs_to_send_sorted()
    assert sl == fl[:-1]

102  system/athena/tests/test_athenad_ping.py  Normal file
@@ -0,0 +1,102 @@
import pytest
import subprocess
import threading
import time
from typing import cast

from openpilot.common.params import Params
from openpilot.common.timeout import Timeout
from openpilot.system.athena import athenad
from openpilot.system.manager.helpers import write_onroad_params
from openpilot.system.hardware import TICI

TIMEOUT_TOLERANCE = 20  # seconds


def wifi_radio(on: bool) -> None:
  if not TICI:
    return
  print(f"wifi {'on' if on else 'off'}")
  subprocess.run(["nmcli", "radio", "wifi", "on" if on else "off"], check=True)


class TestAthenadPing:
  params: Params
  dongle_id: str

  athenad: threading.Thread
  exit_event: threading.Event

  def _get_ping_time(self) -> str | None:
    return cast(str | None, self.params.get("LastAthenaPingTime", encoding="utf-8"))

  def _clear_ping_time(self) -> None:
    self.params.remove("LastAthenaPingTime")

  def _received_ping(self) -> bool:
    return self._get_ping_time() is not None

  @classmethod
  def teardown_class(cls) -> None:
    wifi_radio(True)

  def setup_method(self) -> None:
    self.params = Params()
    self.dongle_id = self.params.get("DongleId", encoding="utf-8")

    wifi_radio(True)
    self._clear_ping_time()

    self.exit_event = threading.Event()
    self.athenad = threading.Thread(target=athenad.main, args=(self.exit_event,))

  def teardown_method(self) -> None:
    if self.athenad.is_alive():
      self.exit_event.set()
      self.athenad.join()

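  # The reconnect windows asserted below come from ws_manage's keepalive settings:
  # onroad roughly 7 + (7 * 2) = 21 s, offroad roughly 30 + (10 * 3) = 60 s,
  # plus TIMEOUT_TOLERANCE of slack.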
  def assertTimeout(self, reconnect_time: float, subtests, mocker) -> None:
    self.athenad.start()

    mock_create_connection = mocker.patch('openpilot.system.athena.athenad.create_connection',
                                          new_callable=lambda: mocker.MagicMock(wraps=athenad.create_connection))

    time.sleep(1)
    mock_create_connection.assert_called_once()
    mock_create_connection.reset_mock()

    # check normal behavior, server pings on connection
    with subtests.test("Wi-Fi: receives ping"), Timeout(70, "no ping received"):
      while not self._received_ping():
        time.sleep(0.1)
      print("ping received")

    mock_create_connection.assert_not_called()

    # websocket should attempt reconnect after short time
    with subtests.test("LTE: attempt reconnect"):
      wifi_radio(False)
      print("waiting for reconnect attempt")
      start_time = time.monotonic()
      with Timeout(reconnect_time, "no reconnect attempt"):
        while not mock_create_connection.called:
          time.sleep(0.1)
      print(f"reconnect attempt after {time.monotonic() - start_time:.2f}s")

    self._clear_ping_time()

    # check ping received after reconnect
    with subtests.test("LTE: receives ping"), Timeout(70, "no ping received"):
      while not self._received_ping():
        time.sleep(0.1)
      print("ping received")

  @pytest.mark.skipif(not TICI, reason="only run on desk")
  def test_offroad(self, subtests, mocker) -> None:
    write_onroad_params(False, self.params)
    self.assertTimeout(60 + TIMEOUT_TOLERANCE, subtests, mocker)  # based on TCP keepalive settings

  @pytest.mark.skipif(not TICI, reason="only run on desk")
  def test_onroad(self, subtests, mocker) -> None:
    write_onroad_params(True, self.params)
    self.assertTimeout(21 + TIMEOUT_TOLERANCE, subtests, mocker)

76  system/athena/tests/test_registration.py  Normal file
@@ -0,0 +1,76 @@
import json
from Crypto.PublicKey import RSA
from pathlib import Path

from openpilot.common.params import Params
from openpilot.system.athena.registration import register, UNREGISTERED_DONGLE_ID
from openpilot.system.athena.tests.helpers import MockResponse
from openpilot.system.hardware.hw import Paths


class TestRegistration:

  def setup_method(self):
    # clear params and setup key paths
    self.params = Params()

    persist_dir = Path(Paths.persist_root()) / "comma"
    persist_dir.mkdir(parents=True, exist_ok=True)

    self.priv_key = persist_dir / "id_rsa"
    self.pub_key = persist_dir / "id_rsa.pub"
    self.dongle_id = persist_dir / "dongle_id"

  def _generate_keys(self):
    self.pub_key.touch()
    k = RSA.generate(2048)
    with open(self.priv_key, "wb") as f:
      f.write(k.export_key())
    with open(self.pub_key, "wb") as f:
      f.write(k.publickey().export_key())

  def test_valid_cache(self, mocker):
    # if all params are written, return the cached dongle id.
    # should work with a dongle ID on either /persist/ or normal params
    self._generate_keys()

    dongle = "DONGLE_ID_123"
    m = mocker.patch("openpilot.system.athena.registration.api_get", autospec=True)
    for persist, params in [(True, True), (True, False), (False, True)]:
      self.params.put("DongleId", dongle if params else "")
      with open(self.dongle_id, "w") as f:
        f.write(dongle if persist else "")
      assert register() == dongle
      assert not m.called

  def test_no_keys(self, mocker):
    # missing pubkey
    m = mocker.patch("openpilot.system.athena.registration.api_get", autospec=True)
    dongle = register()
    assert m.call_count == 0
    assert dongle == UNREGISTERED_DONGLE_ID
    assert self.params.get("DongleId", encoding='utf-8') == dongle

  def test_missing_cache(self, mocker):
    # keys exist but no dongle id
    self._generate_keys()
    m = mocker.patch("openpilot.system.athena.registration.api_get", autospec=True)
    dongle = "DONGLE_ID_123"
    m.return_value = MockResponse(json.dumps({'dongle_id': dongle}), 200)
    assert register() == dongle
    assert m.call_count == 1

    # call again, shouldn't hit the API this time
    assert register() == dongle
    assert m.call_count == 1
    assert self.params.get("DongleId", encoding='utf-8') == dongle

  def test_unregistered(self, mocker):
    # keys exist, but unregistered
    self._generate_keys()
    m = mocker.patch("openpilot.system.athena.registration.api_get", autospec=True)
    m.return_value = MockResponse(None, 402)
    dongle = register()
    assert m.call_count == 1
    assert dongle == UNREGISTERED_DONGLE_ID
    assert self.params.get("DongleId", encoding='utf-8') == dongle