Release 260111

Comma Device
2026-01-11 18:23:29 +08:00
commit 3721ecbf8a
2601 changed files with 855070 additions and 0 deletions

tools/lib/README.md Normal file

@@ -0,0 +1,59 @@
## LogReader
`Route` is a class for conveniently accessing all the [logs](/system/loggerd/) from your routes. The `LogReader` class reads the non-video logs, i.e. rlog.bz2 and qlog.bz2. There's also a matching `FrameReader` class for reading the videos.
```python
from openpilot.tools.lib.route import Route
from openpilot.tools.lib.logreader import LogReader
r = Route("a2a0ccea32023010|2023-07-27--13-01-19")
# get a list of paths for the route's rlog files
print(r.log_paths())
# and road camera (fcamera.hevc) files
print(r.camera_paths())
# setup a LogReader to read the route's first rlog
lr = LogReader(r.log_paths()[0])
# print out all the messages in the log
import codecs
codecs.register_error("strict", codecs.backslashreplace_errors)
for msg in lr:
print(msg)
# setup a LogReader to read the route's second qlog
lr = LogReader(r.qlog_paths()[1])
# print all the steering angles values from the log
for msg in lr:
if msg.which() == "carState":
print(msg.carState.steeringAngleDeg)
```
### Segment Ranges
We also support a new format called a "segment range":
```
344c5c15b34f2d8a|2024-01-03--09-37-12/2:6/q
[   dongle id  ]|[     timestamp    ]/[selector]/[query type]
```
You can specify which segments of a route to load:
```python
lr = LogReader("a2a0ccea32023010|2023-07-27--13-01-19/4") # 4th segment
lr = LogReader("a2a0ccea32023010|2023-07-27--13-01-19/4:6") # 4th and 5th segment
lr = LogReader("a2a0ccea32023010|2023-07-27--13-01-19/-1") # last segment
lr = LogReader("a2a0ccea32023010|2023-07-27--13-01-19/:5") # first 5 segments
lr = LogReader("a2a0ccea32023010|2023-07-27--13-01-19/1:") # all except first segment
```
and select which type of logs to read:
```python
lr = LogReader("a2a0ccea32023010|2023-07-27--13-01-19/4/q") # get qlogs
lr = LogReader("a2a0ccea32023010|2023-07-27--13-01-19/4/r") # get rlogs (default)
```
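Under the hood these strings are parsed by the `SegmentRange` class from `tools/lib/route.py`. A minimal sketch of inspecting one directly (slices with a non-negative end resolve locally, without an API call):
```python
from openpilot.tools.lib.route import SegmentRange

sr = SegmentRange("a2a0ccea32023010|2023-07-27--13-01-19/4:6/q")
print(sr.route_name)  # a2a0ccea32023010|2023-07-27--13-01-19
print(sr.seg_idxs)    # [4, 5]
print(sr.selector)    # q
```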

tools/lib/__init__.py Normal file

tools/lib/api.py Normal file

@@ -0,0 +1,34 @@
import os
import requests
API_HOST = os.getenv('API_HOST', 'https://api.commadotai.com')
class CommaApi:
def __init__(self, token=None):
self.session = requests.Session()
self.session.headers['User-agent'] = 'OpenpilotTools'
if token:
self.session.headers['Authorization'] = 'JWT ' + token
def request(self, method, endpoint, **kwargs):
with self.session.request(method, API_HOST + '/' + endpoint, **kwargs) as resp:
resp_json = resp.json()
if isinstance(resp_json, dict) and resp_json.get('error'):
if resp.status_code in [401, 403]:
raise UnauthorizedError('Unauthorized. Authenticate with tools/lib/auth.py')
e = APIError(str(resp.status_code) + ":" + resp_json.get('description', str(resp_json['error'])))
e.status_code = resp.status_code
raise e
return resp_json
def get(self, endpoint, **kwargs):
return self.request('GET', endpoint, **kwargs)
def post(self, endpoint, **kwargs):
return self.request('POST', endpoint, **kwargs)
class APIError(Exception):
pass
class UnauthorizedError(Exception):
pass
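For context, this client is what route.py and bootlog.py below build on; a short sketch (assumes you've logged in with tools/lib/auth.py so a token is on disk):

```python
from openpilot.tools.lib.api import CommaApi
from openpilot.tools.lib.auth_config import get_token

api = CommaApi(get_token())
# list the uploaded files for a route, as Route._get_segments_remote does
files = api.get('v1/route/' + 'a2a0ccea32023010|2023-07-27--13-01-19' + '/files')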

tools/lib/auth.py Executable file

@@ -0,0 +1,145 @@
#!/usr/bin/env python3
"""
Usage::
usage: auth.py [-h] [{google,apple,github,jwt}] [jwt]
Login to your comma account
positional arguments:
{google,apple,github,jwt}
jwt
optional arguments:
-h, --help show this help message and exit
Examples::
./auth.py # Log in with google account
./auth.py github # Log in with GitHub Account
./auth.py jwt ey......hw # Log in with a JWT from https://jwt.comma.ai, for use in CI
"""
import argparse
import sys
import pprint
import webbrowser
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any
from urllib.parse import parse_qs, urlencode
from openpilot.tools.lib.api import APIError, CommaApi, UnauthorizedError
from openpilot.tools.lib.auth_config import set_token, get_token
PORT = 3000
class ClientRedirectServer(HTTPServer):
query_params: dict[str, Any] = {}
class ClientRedirectHandler(BaseHTTPRequestHandler):
def do_GET(self):
if not self.path.startswith('/auth'):
self.send_response(204)
return
query = self.path.split('?', 1)[-1]
query_parsed = parse_qs(query, keep_blank_values=True)
self.server.query_params = query_parsed
self.send_response(200)
self.send_header('Content-type', 'text/plain')
self.end_headers()
self.wfile.write(b'Return to the CLI to continue')
def log_message(self, *args):
    pass  # this prevents the HTTP server from dumping messages to stdout
def auth_redirect_link(method):
provider_id = {
'google': 'g',
'apple': 'a',
'github': 'h',
}[method]
params = {
'redirect_uri': f"https://api.comma.ai/v2/auth/{provider_id}/redirect/",
'state': f'service,localhost:{PORT}',
}
if method == 'google':
params.update({
'type': 'web_server',
'client_id': '45471411055-ornt4svd2miog6dnopve7qtmh5mnu6id.apps.googleusercontent.com',
'response_type': 'code',
'scope': 'https://www.googleapis.com/auth/userinfo.email',
'prompt': 'select_account',
})
return 'https://accounts.google.com/o/oauth2/auth?' + urlencode(params)
elif method == 'github':
params.update({
'client_id': '28c4ecb54bb7272cb5a4',
'scope': 'read:user',
})
return 'https://github.com/login/oauth/authorize?' + urlencode(params)
elif method == 'apple':
params.update({
'client_id': 'ai.comma.login',
'response_type': 'code',
'response_mode': 'form_post',
'scope': 'name email',
})
return 'https://appleid.apple.com/auth/authorize?' + urlencode(params)
else:
raise NotImplementedError(f"no redirect implemented for method {method}")
def login(method):
oauth_uri = auth_redirect_link(method)
web_server = ClientRedirectServer(('localhost', PORT), ClientRedirectHandler)
print(f'To sign in, use your browser and navigate to {oauth_uri}')
webbrowser.open(oauth_uri, new=2)
while True:
web_server.handle_request()
if 'code' in web_server.query_params:
break
elif 'error' in web_server.query_params:
print('Authentication Error: "{}". Description: "{}" '.format(
web_server.query_params['error'],
web_server.query_params.get('error_description')), file=sys.stderr)
break
try:
auth_resp = CommaApi().post('v2/auth/', data={'code': web_server.query_params['code'], 'provider': web_server.query_params['provider']})
set_token(auth_resp['access_token'])
except APIError as e:
print(f'Authentication Error: {e}', file=sys.stderr)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Login to your comma account')
parser.add_argument('method', default='google', const='google', nargs='?', choices=['google', 'apple', 'github', 'jwt'])
parser.add_argument('jwt', nargs='?')
args = parser.parse_args()
if args.method == 'jwt':
if args.jwt is None:
print("method JWT selected, but no JWT was provided")
exit(1)
set_token(args.jwt)
else:
login(args.method)
try:
me = CommaApi(token=get_token()).get('/v1/me')
print("Authenticated!")
pprint.pprint(me)
except UnauthorizedError:
print("Got invalid JWT")
exit(1)

tools/lib/auth_config.py Normal file

@@ -0,0 +1,29 @@
import json
import os
from openpilot.system.hardware.hw import Paths
class MissingAuthConfigError(Exception):
pass
def get_token():
try:
with open(os.path.join(Paths.config_root(), 'auth.json')) as f:
auth = json.load(f)
return auth['access_token']
except Exception:
return None
def set_token(token):
os.makedirs(Paths.config_root(), exist_ok=True)
with open(os.path.join(Paths.config_root(), 'auth.json'), 'w') as f:
json.dump({'access_token': token}, f)
def clear_token():
try:
os.unlink(os.path.join(Paths.config_root(), 'auth.json'))
except FileNotFoundError:
pass

tools/lib/azure_container.py Normal file

@@ -0,0 +1,73 @@
import os
from datetime import datetime, timedelta, UTC
from functools import lru_cache
from pathlib import Path
from typing import IO
TOKEN_PATH = Path("/data/azure_token")
@lru_cache
def get_azure_credential():
if "AZURE_TOKEN" in os.environ:
return os.environ["AZURE_TOKEN"]
elif TOKEN_PATH.is_file():
return TOKEN_PATH.read_text().strip()
else:
from azure.identity import AzureCliCredential
return AzureCliCredential()
@lru_cache
def get_container_sas(account_name: str, container_name: str):
from azure.storage.blob import BlobServiceClient, ContainerSasPermissions, generate_container_sas
start_time = datetime.now(UTC).replace(tzinfo=None)
expiry_time = start_time + timedelta(hours=1)
blob_service = BlobServiceClient(
account_url=f"https://{account_name}.blob.core.windows.net",
credential=get_azure_credential(),
)
return generate_container_sas(
account_name,
container_name,
user_delegation_key=blob_service.get_user_delegation_key(start_time, expiry_time),
permission=ContainerSasPermissions(read=True, write=True, list=True),
expiry=expiry_time,
)
class AzureContainer:
def __init__(self, account, container):
self.ACCOUNT = account
self.CONTAINER = container
@property
def ACCOUNT_URL(self) -> str:
return f"https://{self.ACCOUNT}.blob.core.windows.net"
@property
def BASE_URL(self) -> str:
return f"{self.ACCOUNT_URL}/{self.CONTAINER}/"
def get_client_and_key(self):
from azure.storage.blob import ContainerClient
client = ContainerClient(self.ACCOUNT_URL, self.CONTAINER, credential=get_azure_credential())
key = get_container_sas(self.ACCOUNT, self.CONTAINER)
return client, key
def get_url(self, route_name: str, segment_num: str, filename: str) -> str:
return self.BASE_URL + f"{route_name.replace('|', '/')}/{segment_num}/{filename}"
def upload_bytes(self, data: bytes | IO, blob_name: str, overwrite=False) -> str:
from azure.storage.blob import BlobClient
blob = BlobClient(
account_url=self.ACCOUNT_URL,
container_name=self.CONTAINER,
blob_name=blob_name,
credential=get_azure_credential(),
overwrite=overwrite,
)
blob.upload_blob(data, overwrite=overwrite)
return self.BASE_URL + blob_name
def upload_file(self, path: str | os.PathLike, blob_name: str, overwrite=False) -> str:
with open(path, "rb") as f:
return self.upload_bytes(f, blob_name, overwrite)
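For reference, the containers instantiated in openpilotcontainers.py below use this class; get_url is pure string formatting, so it needs no credentials:

```python
from openpilot.tools.lib.azure_container import AzureContainer

container = AzureContainer("commadataci", "openpilotci")
print(container.get_url("a2a0ccea32023010|2023-07-27--13-01-19", "0", "rlog.bz2"))
# https://commadataci.blob.core.windows.net/openpilotci/a2a0ccea32023010/2023-07-27--13-01-19/0/rlog.bz2
```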

tools/lib/bootlog.py Normal file

@@ -0,0 +1,57 @@
import functools
import re
from openpilot.tools.lib.auth_config import get_token
from openpilot.tools.lib.api import CommaApi
from openpilot.tools.lib.helpers import RE
@functools.total_ordering
class Bootlog:
def __init__(self, url: str):
self._url = url
r = re.search(RE.BOOTLOG_NAME, url)
if not r:
raise Exception(f"Unable to parse: {url}")
self._id = r.group('log_id')
self._dongle_id = r.group('dongle_id')
@property
def url(self) -> str:
return self._url
@property
def dongle_id(self) -> str:
return self._dongle_id
@property
def id(self) -> str:
return self._id
def __str__(self):
return f"{self._dongle_id}/{self._id}"
def __eq__(self, b) -> bool:
if not isinstance(b, Bootlog):
return False
return self.id == b.id
def __lt__(self, b) -> bool:
if not isinstance(b, Bootlog):
return False
return self.id < b.id
def get_bootlog_from_id(bootlog_id: str) -> Bootlog | None:
# TODO: implement an API endpoint for this
bl = Bootlog(bootlog_id)
for b in get_bootlogs(bl.dongle_id):
if b == bl:
return b
return None
def get_bootlogs(dongle_id: str) -> list[Bootlog]:
api = CommaApi(get_token())
r = api.get(f'v1/devices/{dongle_id}/bootlogs')
return [Bootlog(b) for b in r]
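A usage sketch (assumes a valid auth token and a dongle ID you have access to; the one shown is hypothetical). Since Bootlog is total_ordering'd on its log id, the list sorts chronologically:

```python
from openpilot.tools.lib.bootlog import get_bootlogs

bootlogs = sorted(get_bootlogs("a2a0ccea32023010"))  # hypothetical dongle ID
for b in bootlogs[-3:]:
  print(b)  # "<dongle_id>/<log_id>"
```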

tools/lib/cache.py Normal file

@@ -0,0 +1,14 @@
import os
import urllib.parse
DEFAULT_CACHE_DIR = os.getenv("CACHE_ROOT", os.path.expanduser("~/.commacache"))
def cache_path_for_file_path(fn, cache_dir=DEFAULT_CACHE_DIR):
dir_ = os.path.join(cache_dir, "local")
os.makedirs(dir_, exist_ok=True)
fn_parsed = urllib.parse.urlparse(fn)
if fn_parsed.scheme == '':
cache_fn = os.path.abspath(fn).replace("/", "_")
else:
cache_fn = f'{fn_parsed.hostname}_{fn_parsed.path.replace("/", "_")}'
return os.path.join(dir_, cache_fn)
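A quick illustration of the key scheme (the URL is hypothetical; local paths are keyed by their absolute path instead):

```python
from openpilot.tools.lib.cache import cache_path_for_file_path

print(cache_path_for_file_path("https://example.com/route/0/rlog.bz2"))
# -> ~/.commacache/local/example.com__route_0_rlog.bz2
```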

tools/lib/comma_car_segments.py Normal file

@@ -0,0 +1,90 @@
import os
import requests
# Forks with additional car support can fork the commaCarSegments repo on huggingface or host the LFS files themselves
COMMA_CAR_SEGMENTS_REPO = os.environ.get("COMMA_CAR_SEGMENTS_REPO", "https://huggingface.co/datasets/commaai/commaCarSegments")
COMMA_CAR_SEGMENTS_BRANCH = os.environ.get("COMMA_CAR_SEGMENTS_BRANCH", "main")
COMMA_CAR_SEGMENTS_LFS_INSTANCE = os.environ.get("COMMA_CAR_SEGMENTS_LFS_INSTANCE", COMMA_CAR_SEGMENTS_REPO)
def get_comma_car_segments_database():
from opendbc.car.fingerprints import MIGRATION
database = requests.get(get_repo_raw_url("database.json")).json()
ret = {}
for platform in database:
ret[MIGRATION.get(platform, platform)] = database[platform]
return ret
# Helpers related to interfacing with the commaCarSegments repository, which contains a collection of public segments for users to perform validation on.
def parse_lfs_pointer(text):
header, lfs_version = text.splitlines()[0].split(" ")
assert header == "version"
assert lfs_version == "https://git-lfs.github.com/spec/v1"
header, oid_raw = text.splitlines()[1].split(" ")
assert header == "oid"
header, oid = oid_raw.split(":")
assert header == "sha256"
header, size = text.splitlines()[2].split(" ")
assert header == "size"
return oid, size
def get_lfs_file_url(oid, size):
data = {
"operation": "download",
"transfers": [ "basic" ],
"objects": [
{
"oid": oid,
"size": int(size)
}
],
"hash_algo": "sha256"
}
headers = {
"Accept": "application/vnd.git-lfs+json",
"Content-Type": "application/vnd.git-lfs+json"
}
response = requests.post(f"{COMMA_CAR_SEGMENTS_LFS_INSTANCE}.git/info/lfs/objects/batch", json=data, headers=headers)
assert response.ok
obj = response.json()["objects"][0]
assert "error" not in obj, obj
return obj["actions"]["download"]["href"]
def get_repo_raw_url(path):
if "huggingface" in COMMA_CAR_SEGMENTS_REPO:
return f"{COMMA_CAR_SEGMENTS_REPO}/raw/{COMMA_CAR_SEGMENTS_BRANCH}/{path}"
def get_repo_url(path):
# Automatically switch to LFS if we are requesting a file that is stored in LFS
response = requests.head(get_repo_raw_url(path))
if "text/plain" in response.headers.get("content-type"):
# This is an LFS pointer, so download the raw data from lfs
response = requests.get(get_repo_raw_url(path))
assert response.status_code == 200
oid, size = parse_lfs_pointer(response.text)
return get_lfs_file_url(oid, size)
else:
# File has not been uploaded to LFS yet
# (either we are on a fork where the data hasn't been pushed to LFS yet, or the CI job to push hasn't finished)
return get_repo_raw_url(path)
def get_url(route, segment, file="rlog.bz2"):
return get_repo_url(f"segments/{route.replace('|', '/')}/{segment}/{file}")
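A sketch of using the database and URL helpers (the route and segment here are hypothetical; network access to the repo is required):

```python
from openpilot.tools.lib.comma_car_segments import get_comma_car_segments_database, get_url

database = get_comma_car_segments_database()
print(len(database), "platforms with public segments")
# resolve the rlog URL for one segment of a route, following LFS pointers if needed
url = get_url("a2a0ccea32023010|2023-07-27--13-01-19", 4)
```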

tools/lib/exceptions.py Normal file

@@ -0,0 +1,2 @@
class DataUnreadableError(Exception):
pass

tools/lib/filereader.py Normal file

@@ -0,0 +1,44 @@
import os
import posixpath
import socket
from urllib.parse import urlparse
from openpilot.tools.lib.url_file import URLFile
DATA_ENDPOINT = os.getenv("DATA_ENDPOINT", "http://data-raw.comma.internal/")
def internal_source_available(url=DATA_ENDPOINT):
if os.path.isdir(url):
return True
try:
hostname = urlparse(url).hostname
port = urlparse(url).port or 80
with socket.socket(socket.AF_INET,socket.SOCK_STREAM) as s:
s.settimeout(0.5)
s.connect((hostname, port))
return True
except (socket.gaierror, ConnectionRefusedError):
pass
return False
def resolve_name(fn):
if fn.startswith("cd:/"):
return posixpath.join(DATA_ENDPOINT, fn[4:])
return fn
def file_exists(fn):
fn = resolve_name(fn)
if fn.startswith(("http://", "https://")):
return URLFile(fn).get_length_online() != -1
return os.path.exists(fn)
def FileReader(fn, debug=False):
fn = resolve_name(fn)
if fn.startswith(("http://", "https://")):
return URLFile(fn, debug=debug)
return open(fn, "rb")
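FileReader gives a single read interface over local paths, http(s) URLs, and cd:/ paths (the latter resolved against DATA_ENDPOINT); a sketch with a hypothetical URL:

```python
from openpilot.tools.lib.filereader import FileReader, file_exists

fn = "https://example.com/rlog.bz2"  # hypothetical; a local path or cd:/... works the same
if file_exists(fn):
  with FileReader(fn) as f:
    dat = f.read()
```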

tools/lib/framereader.py Normal file

@@ -0,0 +1,559 @@
import json
import os
import pickle
import struct
import subprocess
import threading
from enum import IntEnum
from functools import wraps
import numpy as np
from lru import LRU
import _io
from openpilot.tools.lib.cache import cache_path_for_file_path, DEFAULT_CACHE_DIR
from openpilot.tools.lib.exceptions import DataUnreadableError
from openpilot.tools.lib.vidindex import hevc_index
from openpilot.common.file_helpers import atomic_write_in_dir
from openpilot.tools.lib.filereader import FileReader, resolve_name
HEVC_SLICE_B = 0
HEVC_SLICE_P = 1
HEVC_SLICE_I = 2
class GOPReader:
def get_gop(self, num):
# returns (start_frame_num, num_frames, frames_to_skip, gop_data)
raise NotImplementedError
class DoNothingContextManager:
def __enter__(self):
return self
def __exit__(self, *x):
pass
class FrameType(IntEnum):
raw = 1
h265_stream = 2
def fingerprint_video(fn):
with FileReader(fn) as f:
header = f.read(4)
if len(header) == 0:
raise DataUnreadableError(f"{fn} is empty")
elif header == b"\x00\xc0\x12\x00":
return FrameType.raw
elif header == b"\x00\x00\x00\x01":
if 'hevc' in fn:
return FrameType.h265_stream
else:
raise NotImplementedError(fn)
else:
raise NotImplementedError(fn)
def ffprobe(fn, fmt=None):
fn = resolve_name(fn)
cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", "-show_format", "-show_streams"]
if fmt:
cmd += ["-f", fmt]
cmd += ["-i", "-"]
try:
with FileReader(fn) as f:
ffprobe_output = subprocess.check_output(cmd, input=f.read(4096))
except subprocess.CalledProcessError as e:
raise DataUnreadableError(fn) from e
return json.loads(ffprobe_output)
def cache_fn(func):
@wraps(func)
def cache_inner(fn, *args, **kwargs):
if kwargs.pop('no_cache', None):
cache_path = None
else:
cache_dir = kwargs.pop('cache_dir', DEFAULT_CACHE_DIR)
cache_path = cache_path_for_file_path(fn, cache_dir)
if cache_path and os.path.exists(cache_path):
with open(cache_path, "rb") as cache_file:
cache_value = pickle.load(cache_file)
else:
cache_value = func(fn, *args, **kwargs)
if cache_path:
with atomic_write_in_dir(cache_path, mode="wb", overwrite=True) as cache_file:
pickle.dump(cache_value, cache_file, -1)
return cache_value
return cache_inner
@cache_fn
def index_stream(fn, ft):
if ft != FrameType.h265_stream:
raise NotImplementedError("Only h265 supported")
frame_types, dat_len, prefix = hevc_index(fn)
index = np.array(frame_types + [(0xFFFFFFFF, dat_len)], dtype=np.uint32)
probe = ffprobe(fn, "hevc")
return {
'index': index,
'global_prefix': prefix,
'probe': probe
}
def get_video_index(fn, frame_type, cache_dir=DEFAULT_CACHE_DIR):
return index_stream(fn, frame_type, cache_dir=cache_dir)
def read_file_check_size(f, sz, cookie):
buff = bytearray(sz)
bytes_read = f.readinto(buff)
assert bytes_read == sz, (bytes_read, sz)
return buff
def rgb24toyuv(rgb):
yuv_from_rgb = np.array([[ 0.299 , 0.587 , 0.114 ],
[-0.14714119, -0.28886916, 0.43601035 ],
[ 0.61497538, -0.51496512, -0.10001026 ]])
img = np.dot(rgb.reshape(-1, 3), yuv_from_rgb.T).reshape(rgb.shape)
ys = img[:, :, 0]
us = (img[::2, ::2, 1] + img[1::2, ::2, 1] + img[::2, 1::2, 1] + img[1::2, 1::2, 1]) / 4 + 128
vs = (img[::2, ::2, 2] + img[1::2, ::2, 2] + img[::2, 1::2, 2] + img[1::2, 1::2, 2]) / 4 + 128
return ys, us, vs
def rgb24toyuv420(rgb):
ys, us, vs = rgb24toyuv(rgb)
y_len = rgb.shape[0] * rgb.shape[1]
uv_len = y_len // 4
yuv420 = np.empty(y_len + 2 * uv_len, dtype=rgb.dtype)
yuv420[:y_len] = ys.reshape(-1)
yuv420[y_len:y_len + uv_len] = us.reshape(-1)
yuv420[y_len + uv_len:y_len + 2 * uv_len] = vs.reshape(-1)
return yuv420.clip(0, 255).astype('uint8')
def rgb24tonv12(rgb):
ys, us, vs = rgb24toyuv(rgb)
y_len = rgb.shape[0] * rgb.shape[1]
uv_len = y_len // 4
nv12 = np.empty(y_len + 2 * uv_len, dtype=rgb.dtype)
nv12[:y_len] = ys.reshape(-1)
nv12[y_len::2] = us.reshape(-1)
nv12[y_len+1::2] = vs.reshape(-1)
return nv12.clip(0, 255).astype('uint8')
def decompress_video_data(rawdat, vid_fmt, w, h, pix_fmt):
threads = os.getenv("FFMPEG_THREADS", "0")
cuda = os.getenv("FFMPEG_CUDA", "0") == "1"
args = ["ffmpeg", "-v", "quiet",
"-threads", threads,
"-hwaccel", "none" if not cuda else "cuda",
"-c:v", "hevc",
"-vsync", "0",
"-f", vid_fmt,
"-flags2", "showall",
"-i", "-",
"-threads", threads,
"-f", "rawvideo",
"-pix_fmt", pix_fmt,
"-"]
dat = subprocess.check_output(args, input=rawdat)
if pix_fmt == "rgb24":
ret = np.frombuffer(dat, dtype=np.uint8).reshape(-1, h, w, 3)
elif pix_fmt == "nv12":
ret = np.frombuffer(dat, dtype=np.uint8).reshape(-1, (h*w*3//2))
elif pix_fmt == "yuv420p":
ret = np.frombuffer(dat, dtype=np.uint8).reshape(-1, (h*w*3//2))
elif pix_fmt == "yuv444p":
ret = np.frombuffer(dat, dtype=np.uint8).reshape(-1, 3, h, w)
else:
raise NotImplementedError
return ret
class BaseFrameReader:
# properties: frame_type, frame_count, w, h
def __enter__(self):
return self
def __exit__(self, *args):
self.close()
def close(self):
pass
def get(self, num, count=1, pix_fmt="yuv420p"):
raise NotImplementedError
def FrameReader(fn, cache_dir=DEFAULT_CACHE_DIR, readahead=False, readbehind=False, index_data=None):
frame_type = fingerprint_video(fn)
if frame_type == FrameType.raw:
return RawFrameReader(fn)
elif frame_type in (FrameType.h265_stream,):
if not index_data:
index_data = get_video_index(fn, frame_type, cache_dir)
return StreamFrameReader(fn, frame_type, index_data, readahead=readahead, readbehind=readbehind)
else:
raise NotImplementedError(frame_type)
class RawData:
def __init__(self, f):
self.f = _io.FileIO(f, 'rb')
self.lenn = struct.unpack("I", self.f.read(4))[0]
self.count = os.path.getsize(f) / (self.lenn+4)
def read(self, i):
self.f.seek((self.lenn+4)*i + 4)
return self.f.read(self.lenn)
class RawFrameReader(BaseFrameReader):
def __init__(self, fn):
# raw camera
self.fn = fn
self.frame_type = FrameType.raw
self.rawfile = RawData(self.fn)
self.frame_count = self.rawfile.count
self.w, self.h = 640, 480
def load_and_debayer(self, img):
img = np.frombuffer(img, dtype='uint8').reshape(960, 1280)
cimg = np.dstack([img[0::2, 1::2], ((img[0::2, 0::2].astype("uint16") + img[1::2, 1::2].astype("uint16")) >> 1).astype("uint8"), img[1::2, 0::2]])
return cimg
def get(self, num, count=1, pix_fmt="yuv420p"):
assert self.frame_count is not None
assert num+count <= self.frame_count
if pix_fmt not in ("nv12", "yuv420p", "rgb24"):
raise ValueError(f"Unsupported pixel format {pix_fmt!r}")
app = []
for i in range(num, num+count):
dat = self.rawfile.read(i)
rgb_dat = self.load_and_debayer(dat)
if pix_fmt == "rgb24":
app.append(rgb_dat)
elif pix_fmt == "nv12":
app.append(rgb24tonv12(rgb_dat))
elif pix_fmt == "yuv420p":
app.append(rgb24toyuv420(rgb_dat))
else:
raise NotImplementedError
return app
class VideoStreamDecompressor:
def __init__(self, fn, vid_fmt, w, h, pix_fmt):
self.fn = fn
self.vid_fmt = vid_fmt
self.w = w
self.h = h
self.pix_fmt = pix_fmt
if pix_fmt in ("nv12", "yuv420p"):
self.out_size = w*h*3//2 # yuv420p
elif pix_fmt in ("rgb24", "yuv444p"):
self.out_size = w*h*3
else:
raise NotImplementedError
self.proc = None
self.t = threading.Thread(target=self.write_thread)
self.t.daemon = True
def write_thread(self):
try:
with FileReader(self.fn) as f:
while True:
r = f.read(1024*1024)
if len(r) == 0:
break
self.proc.stdin.write(r)
except BrokenPipeError:
pass
finally:
self.proc.stdin.close()
def read(self):
threads = os.getenv("FFMPEG_THREADS", "0")
cuda = os.getenv("FFMPEG_CUDA", "0") == "1"
cmd = [
"ffmpeg",
"-threads", threads,
"-hwaccel", "none" if not cuda else "cuda",
"-c:v", "hevc",
# "-avioflags", "direct",
"-analyzeduration", "0",
"-probesize", "32",
"-flush_packets", "0",
# "-fflags", "nobuffer",
"-vsync", "0",
"-f", self.vid_fmt,
"-i", "pipe:0",
"-threads", threads,
"-f", "rawvideo",
"-pix_fmt", self.pix_fmt,
"pipe:1"
]
self.proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
try:
self.t.start()
while True:
dat = self.proc.stdout.read(self.out_size)
if len(dat) == 0:
break
assert len(dat) == self.out_size
if self.pix_fmt == "rgb24":
ret = np.frombuffer(dat, dtype=np.uint8).reshape((self.h, self.w, 3))
elif self.pix_fmt == "yuv420p":
ret = np.frombuffer(dat, dtype=np.uint8)
elif self.pix_fmt == "nv12":
ret = np.frombuffer(dat, dtype=np.uint8)
elif self.pix_fmt == "yuv444p":
ret = np.frombuffer(dat, dtype=np.uint8).reshape((3, self.h, self.w))
else:
raise RuntimeError(f"unknown pix_fmt: {self.pix_fmt}")
yield ret
result_code = self.proc.wait()
assert result_code == 0, result_code
finally:
self.proc.kill()
self.t.join()
class StreamGOPReader(GOPReader):
def __init__(self, fn, frame_type, index_data):
assert frame_type == FrameType.h265_stream
self.fn = fn
self.frame_type = frame_type
self.frame_count = None
self.w, self.h = None, None
self.prefix = None
self.index = None
self.index = index_data['index']
self.prefix = index_data['global_prefix']
probe = index_data['probe']
self.prefix_frame_data = None
self.num_prefix_frames = 0
self.vid_fmt = "hevc"
i = 0
while i < self.index.shape[0] and self.index[i, 0] != HEVC_SLICE_I:
i += 1
self.first_iframe = i
assert self.first_iframe == 0
self.frame_count = len(self.index) - 1
self.w = probe['streams'][0]['width']
self.h = probe['streams'][0]['height']
def _lookup_gop(self, num):
frame_b = num
while frame_b > 0 and self.index[frame_b, 0] != HEVC_SLICE_I:
frame_b -= 1
frame_e = num + 1
while frame_e < (len(self.index) - 1) and self.index[frame_e, 0] != HEVC_SLICE_I:
frame_e += 1
offset_b = self.index[frame_b, 1]
offset_e = self.index[frame_e, 1]
return (frame_b, frame_e, offset_b, offset_e)
def get_gop(self, num):
frame_b, frame_e, offset_b, offset_e = self._lookup_gop(num)
assert frame_b <= num < frame_e
num_frames = frame_e - frame_b
with FileReader(self.fn) as f:
f.seek(offset_b)
rawdat = f.read(offset_e - offset_b)
if num < self.first_iframe:
assert self.prefix_frame_data
rawdat = self.prefix_frame_data + rawdat
rawdat = self.prefix + rawdat
skip_frames = 0
if num < self.first_iframe:
skip_frames = self.num_prefix_frames
return frame_b, num_frames, skip_frames, rawdat
class GOPFrameReader(BaseFrameReader):
  # FrameReader with caching and readahead for formats that are group-of-picture based
def __init__(self, readahead=False, readbehind=False):
self.open_ = True
self.readahead = readahead
self.readbehind = readbehind
self.frame_cache = LRU(64)
if self.readahead:
self.cache_lock = threading.RLock()
self.readahead_last = None
self.readahead_len = 30
self.readahead_c = threading.Condition()
self.readahead_thread = threading.Thread(target=self._readahead_thread)
self.readahead_thread.daemon = True
self.readahead_thread.start()
else:
self.cache_lock = DoNothingContextManager()
def close(self):
if not self.open_:
return
self.open_ = False
if self.readahead:
self.readahead_c.acquire()
self.readahead_c.notify()
self.readahead_c.release()
self.readahead_thread.join()
def _readahead_thread(self):
while True:
self.readahead_c.acquire()
try:
if not self.open_:
break
self.readahead_c.wait()
finally:
self.readahead_c.release()
if not self.open_:
break
assert self.readahead_last
num, pix_fmt = self.readahead_last
if self.readbehind:
for k in range(num - 1, max(0, num - self.readahead_len), -1):
self._get_one(k, pix_fmt)
else:
for k in range(num, min(self.frame_count, num + self.readahead_len)):
self._get_one(k, pix_fmt)
def _get_one(self, num, pix_fmt):
assert num < self.frame_count
if (num, pix_fmt) in self.frame_cache:
return self.frame_cache[(num, pix_fmt)]
with self.cache_lock:
if (num, pix_fmt) in self.frame_cache:
return self.frame_cache[(num, pix_fmt)]
frame_b, num_frames, skip_frames, rawdat = self.get_gop(num)
ret = decompress_video_data(rawdat, self.vid_fmt, self.w, self.h, pix_fmt)
ret = ret[skip_frames:]
assert ret.shape[0] == num_frames
for i in range(ret.shape[0]):
self.frame_cache[(frame_b+i, pix_fmt)] = ret[i]
return self.frame_cache[(num, pix_fmt)]
def get(self, num, count=1, pix_fmt="yuv420p"):
assert self.frame_count is not None
if num + count > self.frame_count:
raise ValueError(f"{num + count} > {self.frame_count}")
if pix_fmt not in ("nv12", "yuv420p", "rgb24", "yuv444p"):
raise ValueError(f"Unsupported pixel format {pix_fmt!r}")
ret = [self._get_one(num + i, pix_fmt) for i in range(count)]
if self.readahead:
self.readahead_last = (num+count, pix_fmt)
self.readahead_c.acquire()
self.readahead_c.notify()
self.readahead_c.release()
return ret
class StreamFrameReader(StreamGOPReader, GOPFrameReader):
def __init__(self, fn, frame_type, index_data, readahead=False, readbehind=False):
StreamGOPReader.__init__(self, fn, frame_type, index_data)
GOPFrameReader.__init__(self, readahead, readbehind)
def GOPFrameIterator(gop_reader, pix_fmt):
dec = VideoStreamDecompressor(gop_reader.fn, gop_reader.vid_fmt, gop_reader.w, gop_reader.h, pix_fmt)
yield from dec.read()
def FrameIterator(fn, pix_fmt, **kwargs):
fr = FrameReader(fn, **kwargs)
if isinstance(fr, GOPReader):
yield from GOPFrameIterator(fr, pix_fmt)
else:
for i in range(fr.frame_count):
yield fr.get(i, pix_fmt=pix_fmt)[0]
class NumpyFrameReader:
def __init__(self, name, w, h, cache_size):
self.name = name
self.pos = -1
self.frames = None
self.w = w
self.h = h
self.cache_size = cache_size
def close(self):
pass
def get(self, num, count=1, pix_fmt="nv12"):
num -= 1
q = num // self.cache_size
if q != self.pos:
del self.frames
self.pos = q
self.frames = np.load(f'{self.name}_{self.pos}.npy')
return [self.frames[num % self.cache_size]]
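A sketch of reading decoded frames, assuming fcamera.hevc is a camera file downloaded from a route (the first open builds an index, cached under ~/.commacache):

```python
from openpilot.tools.lib.framereader import FrameReader

fr = FrameReader("fcamera.hevc")
print(fr.frame_count, fr.w, fr.h)
frame = fr.get(0, count=1, pix_fmt="rgb24")[0]  # numpy uint8 array of shape (h, w, 3)
```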

tools/lib/github_utils.py Normal file

@@ -0,0 +1,114 @@
import base64
import requests
from http import HTTPMethod
class GithubUtils:
def __init__(self, api_token, data_token, owner='commaai', api_repo='openpilot', data_repo='ci-artifacts'):
self.OWNER = owner
self.API_REPO = api_repo
self.DATA_REPO = data_repo
self.API_TOKEN = api_token
self.DATA_TOKEN = data_token
@property
def API_ROUTE(self):
return f"https://api.github.com/repos/{self.OWNER}/{self.API_REPO}"
@property
def DATA_ROUTE(self):
return f"https://api.github.com/repos/{self.OWNER}/{self.DATA_REPO}"
def api_call(self, path, data="", method=HTTPMethod.GET, accept="", data_call=False, raise_on_failure=True):
    token = self.DATA_TOKEN if data_call else self.API_TOKEN
    if token:
      headers = {"Authorization": f"Bearer {token}",
                 "Accept": f"application/vnd.github{accept}+json"}
    else:
      headers = {}
path = f'{self.DATA_ROUTE if data_call else self.API_ROUTE}/{path}'
r = requests.request(method, path, headers=headers, data=data)
if not r.ok and raise_on_failure:
raise Exception(f"Call to {path} failed with {r.status_code}")
else:
return r
def upload_file(self, bucket, path, file_name):
with open(path, "rb") as f:
encoded = base64.b64encode(f.read()).decode()
# check if file already exists
sha = self.get_file_sha(bucket, file_name)
sha = f'"sha":"{sha}",' if sha else ''
data = f'{{"message":"uploading {file_name}", \
"branch":"{bucket}", \
"committer":{{"name":"Vehicle Researcher", "email": "user@comma.ai"}}, \
{sha} \
"content":"{encoded}"}}'
github_path = f"contents/{file_name}"
self.api_call(github_path, data=data, method=HTTPMethod.PUT, data_call=True)
def upload_files(self, bucket, files):
self.create_bucket(bucket)
for file_name,path in files:
self.upload_file(bucket, path, file_name)
def create_bucket(self, bucket):
if self.get_bucket_sha(bucket):
return
master_sha = self.get_bucket_sha('master')
github_path = "git/refs"
data = f'{{"ref":"refs/heads/{bucket}", "sha":"{master_sha}"}}'
self.api_call(github_path, data=data, method=HTTPMethod.POST, data_call=True)
def get_bucket_sha(self, bucket):
github_path = f"git/refs/heads/{bucket}"
r = self.api_call(github_path, data_call=True, raise_on_failure=False)
return r.json()['object']['sha'] if r.ok else None
def get_file_url(self, bucket, file_name):
github_path = f"contents/{file_name}?ref={bucket}"
r = self.api_call(github_path, data_call=True)
return r.json()['download_url']
def get_file_sha(self, bucket, file_name):
github_path = f"contents/{file_name}?ref={bucket}"
r = self.api_call(github_path, data_call=True, raise_on_failure=False)
return r.json()['sha'] if r.ok else None
def get_pr_number(self, pr_branch):
github_path = f"commits/{pr_branch}/pulls"
r = self.api_call(github_path)
return r.json()[0]['number']
def get_bucket_link(self, bucket):
return f'https://raw.githubusercontent.com/{self.OWNER}/{self.DATA_REPO}/refs/heads/{bucket}'
def comment_on_pr(self, comment, pr_branch, commenter="", overwrite=False):
pr_number = self.get_pr_number(pr_branch)
data = f'{{"body": "{comment}"}}'
if overwrite:
github_path = f'issues/{pr_number}/comments'
r = self.api_call(github_path)
comments = [x['id'] for x in r.json() if x['user']['login'] == commenter]
if comments:
github_path = f'issues/comments/{comments[0]}'
self.api_call(github_path, data=data, method=HTTPMethod.PATCH)
return
github_path=f'issues/{pr_number}/comments'
self.api_call(github_path, data=data, method=HTTPMethod.POST)
# upload files to github and comment them on the pr
def comment_images_on_pr(self, title, commenter, pr_branch, bucket, images):
self.upload_files(bucket, images)
table = [f'<details><summary>{title}</summary><table>']
for i,f in enumerate(images):
if not (i % 2):
table.append('<tr>')
table.append(f'<td><img src=\\"https://raw.githubusercontent.com/{self.OWNER}/{self.DATA_REPO}/{bucket}/{f[0]}\\"></td>')
if (i % 2):
table.append('</tr>')
table.append('</table></details>')
table = ''.join(table)
    self.comment_on_pr(table, pr_branch, commenter, overwrite=True)

tools/lib/helpers.py Normal file

@@ -0,0 +1,17 @@
# regex patterns
class RE:
DONGLE_ID = r'(?P<dongle_id>[a-f0-9]{16})'
TIMESTAMP = r'(?P<timestamp>[0-9]{4}-[0-9]{2}-[0-9]{2}--[0-9]{2}-[0-9]{2}-[0-9]{2})'
LOG_ID_V2 = r'(?P<count>[a-f0-9]{8})--(?P<uid>[a-z0-9]{10})'
LOG_ID = fr'(?P<log_id>(?:{TIMESTAMP}|{LOG_ID_V2}))'
ROUTE_NAME = fr'(?P<route_name>{DONGLE_ID}[|_/]{LOG_ID})'
SEGMENT_NAME = fr'{ROUTE_NAME}(?:--|/)(?P<segment_num>[0-9]+)'
INDEX = r'-?[0-9]+'
SLICE = fr'(?P<start>{INDEX})?:?(?P<end>{INDEX})?:?(?P<step>{INDEX})?'
SEGMENT_RANGE = fr'{ROUTE_NAME}(?:(--|/)(?P<slice>({SLICE})))?(?:/(?P<selector>([qras])))?'
BOOTLOG_NAME = ROUTE_NAME
EXPLORER_FILE = fr'^(?P<segment_name>{SEGMENT_NAME})--(?P<file_name>[a-z]+\.[a-z0-9]+)$'
OP_SEGMENT_DIR = fr'^(?P<segment_name>{SEGMENT_NAME})$'
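A sketch of how the composed patterns are consumed; the named groups are what SegmentRange and SegmentName read out:

```python
import re
from openpilot.tools.lib.helpers import RE

m = re.fullmatch(RE.SEGMENT_RANGE, "a2a0ccea32023010|2023-07-27--13-01-19/4:6/q")
assert m is not None
print(m.group("dongle_id"), m.group("timestamp"), m.group("slice"), m.group("selector"))
# a2a0ccea32023010 2023-07-27--13-01-19 4:6 q
```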

tools/lib/kbhit.py Executable file

@@ -0,0 +1,81 @@
#!/usr/bin/env python3
import sys
import termios
import atexit
from select import select
class KBHit:
def __init__(self) -> None:
''' Creates a KBHit object that you can call to do various keyboard things.
'''
self.stdin_fd = sys.stdin.fileno()
self.set_kbhit_terminal()
def set_kbhit_terminal(self) -> None:
''' Save old terminal settings for closure, remove ICANON & ECHO flags.
'''
# Save the terminal settings
self.old_term = termios.tcgetattr(self.stdin_fd)
self.new_term = self.old_term.copy()
# New terminal setting unbuffered
self.new_term[3] &= ~(termios.ICANON | termios.ECHO)
termios.tcsetattr(self.stdin_fd, termios.TCSAFLUSH, self.new_term)
# Support normal-terminal reset at exit
atexit.register(self.set_normal_term)
def set_normal_term(self) -> None:
    ''' Resets the terminal to its saved settings.
    '''
termios.tcsetattr(self.stdin_fd, termios.TCSAFLUSH, self.old_term)
@staticmethod
def getch() -> str:
''' Returns a keyboard character after kbhit() has been called.
Should not be called in the same program as getarrow().
'''
return sys.stdin.read(1)
@staticmethod
def getarrow() -> int:
''' Returns an arrow-key code after kbhit() has been called. Codes are
0 : up
1 : right
2 : down
3 : left
Should not be called in the same program as getch().
'''
c = sys.stdin.read(3)[2]
vals = [65, 67, 66, 68]
return vals.index(ord(c))
@staticmethod
def kbhit():
''' Returns True if keyboard character was hit, False otherwise.
'''
return select([sys.stdin], [], [], 0)[0] != []
# Test
if __name__ == "__main__":
kb = KBHit()
print('Hit any key, or ESC to exit')
while True:
if kb.kbhit():
c = kb.getch()
if c == '\x1b': # ESC
break
print(c)
kb.set_normal_term()

tools/lib/live_logreader.py Normal file

@@ -0,0 +1,30 @@
import os
from cereal import log as capnp_log, messaging
from cereal.services import SERVICE_LIST
from openpilot.tools.lib.logreader import LogIterable, RawLogIterable
ALL_SERVICES = list(SERVICE_LIST.keys())
def raw_live_logreader(services: list[str] = ALL_SERVICES, addr: str = '127.0.0.1') -> RawLogIterable:
if addr != "127.0.0.1":
os.environ["ZMQ"] = "1"
messaging.reset_context()
poller = messaging.Poller()
for m in services:
messaging.sub_sock(m, poller, addr=addr)
while True:
polld = poller.poll(100)
for sock in polld:
msg = sock.receive()
yield msg
def live_logreader(services: list[str] = ALL_SERVICES, addr: str = '127.0.0.1') -> LogIterable:
for m in raw_live_logreader(services, addr):
with capnp_log.Event.from_bytes(m) as evt:
yield evt
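A usage sketch; this assumes a running openpilot (or a replay) is publishing messages at the given address:

```python
from openpilot.tools.lib.live_logreader import live_logreader

for msg in live_logreader(services=["carState"]):
  print(msg.which(), msg.carState.vEgo)
```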

tools/lib/log_time_series.py Normal file

@@ -0,0 +1,84 @@
import numpy as np
def flatten_type_dict(d, sep="/", prefix=None):
res = {}
if isinstance(d, dict):
for key, val in d.items():
if prefix is None:
res.update(flatten_type_dict(val, prefix=key))
else:
res.update(flatten_type_dict(val, prefix=prefix + sep + key))
return res
elif isinstance(d, list):
return {prefix: np.array(d)}
else:
return {prefix: d}
def get_message_dict(message, typ):
valid = message.valid
message = message._get(typ)
if not hasattr(message, 'to_dict') or typ in ('qcomGnss', 'ubloxGnss'):
# TODO: support these
#print("skipping", typ)
return
msg_dict = message.to_dict(verbose=True)
msg_dict = flatten_type_dict(msg_dict)
msg_dict['_valid'] = valid
return msg_dict
def append_dict(path, t, d, values):
if path not in values:
group = {}
group["t"] = []
for k in d:
group[k] = []
values[path] = group
else:
group = values[path]
group["t"].append(t)
for k, v in d.items():
group[k].append(v)
def potentially_ragged_array(arr, dtype=None, **kwargs):
# TODO: is there a better way to detect inhomogeneous shapes?
try:
return np.array(arr, dtype=dtype, **kwargs)
except ValueError:
return np.array(arr, dtype=object, **kwargs)
def msgs_to_time_series(msgs):
"""
Convert an iterable of canonical capnp messages into a dictionary of time series.
Each time series has a value with key "t" which consists of monotonically increasing timestamps
in seconds.
"""
values = {}
for msg in msgs:
typ = msg.which()
tm = msg.logMonoTime / 1.0e9
msg_dict = get_message_dict(msg, typ)
if msg_dict is not None:
append_dict(typ, tm, msg_dict, values)
# Sort values by time.
for group in values.values():
order = np.argsort(group["t"])
for name, group_values in group.items():
group[name] = potentially_ragged_array(group_values)[order]
return values
if __name__ == "__main__":
import sys
from openpilot.tools.lib.logreader import LogReader
m = msgs_to_time_series(LogReader(sys.argv[1]))
print(m['driverCameraState']['t'])
print(np.diff(m['driverCameraState']['timestampSof']))

tools/lib/logreader.py Executable file

@@ -0,0 +1,339 @@
#!/usr/bin/env python3
import bz2
from functools import cache, partial
import multiprocessing
import capnp
import enum
import os
import pathlib
import sys
import tqdm
import urllib.parse
import warnings
import zstandard as zstd
from collections.abc import Callable, Iterable, Iterator
from urllib.parse import parse_qs, urlparse
from cereal import log as capnp_log
from openpilot.common.swaglog import cloudlog
from openpilot.tools.lib.comma_car_segments import get_url as get_comma_segments_url
from openpilot.tools.lib.openpilotci import get_url
from openpilot.tools.lib.filereader import FileReader, file_exists, internal_source_available
from openpilot.tools.lib.route import Route, SegmentRange
from openpilot.tools.lib.log_time_series import msgs_to_time_series
LogMessage = type[capnp._DynamicStructReader]
LogIterable = Iterable[LogMessage]
RawLogIterable = Iterable[bytes]
def save_log(dest, log_msgs, compress=True):
dat = b"".join(msg.as_builder().to_bytes() for msg in log_msgs)
if compress and dest.endswith(".bz2"):
dat = bz2.compress(dat)
elif compress and dest.endswith(".zst"):
dat = zstd.compress(dat, 10)
with open(dest, "wb") as f:
f.write(dat)
def decompress_stream(data: bytes):
dctx = zstd.ZstdDecompressor()
decompressed_data = b""
with dctx.stream_reader(data) as reader:
decompressed_data = reader.read()
return decompressed_data
class _LogFileReader:
def __init__(self, fn, canonicalize=True, only_union_types=False, sort_by_time=False, dat=None):
self.data_version = None
self._only_union_types = only_union_types
ext = None
if not dat:
_, ext = os.path.splitext(urllib.parse.urlparse(fn).path)
if ext not in ('', '.bz2', '.zst'):
# old rlogs weren't compressed
raise ValueError(f"unknown extension {ext}")
with FileReader(fn) as f:
dat = f.read()
if ext == ".bz2" or dat.startswith(b'BZh9'):
dat = bz2.decompress(dat)
elif ext == ".zst" or dat.startswith(b'\x28\xB5\x2F\xFD'):
# https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#zstandard-frames
dat = decompress_stream(dat)
ents = capnp_log.Event.read_multiple_bytes(dat)
self._ents = []
try:
for e in ents:
self._ents.append(e)
except capnp.KjException:
warnings.warn("Corrupted events detected", RuntimeWarning, stacklevel=1)
if sort_by_time:
self._ents.sort(key=lambda x: x.logMonoTime)
def __iter__(self) -> Iterator[capnp._DynamicStructReader]:
for ent in self._ents:
if self._only_union_types:
try:
ent.which()
yield ent
except capnp.lib.capnp.KjException:
pass
else:
yield ent
class ReadMode(enum.StrEnum):
RLOG = "r" # only read rlogs
QLOG = "q" # only read qlogs
SANITIZED = "s" # read from the commaCarSegments database
AUTO = "a" # default to rlogs, fallback to qlogs
AUTO_INTERACTIVE = "i" # default to rlogs, fallback to qlogs with a prompt from the user
LogPath = str | None
ValidFileCallable = Callable[[LogPath], bool]
Source = Callable[[SegmentRange, ReadMode], list[LogPath]]
InternalUnavailableException = Exception("Internal source not available")
class LogsUnavailable(Exception):
pass
@cache
def default_valid_file(fn: LogPath) -> bool:
return fn is not None and file_exists(fn)
def auto_strategy(rlog_paths: list[LogPath], qlog_paths: list[LogPath], interactive: bool, valid_file: ValidFileCallable) -> list[LogPath]:
# auto select logs based on availability
missing_rlogs = [rlog is None or not valid_file(rlog) for rlog in rlog_paths].count(True)
if missing_rlogs != 0:
if interactive:
if input(f"{missing_rlogs}/{len(rlog_paths)} rlogs were not found, would you like to fallback to qlogs for those segments? (y/n) ").lower() != "y":
return rlog_paths
else:
cloudlog.warning(f"{missing_rlogs}/{len(rlog_paths)} rlogs were not found, falling back to qlogs for those segments...")
return [rlog if valid_file(rlog) else (qlog if valid_file(qlog) else None)
for (rlog, qlog) in zip(rlog_paths, qlog_paths, strict=True)]
return rlog_paths
def apply_strategy(mode: ReadMode, rlog_paths: list[LogPath], qlog_paths: list[LogPath], valid_file: ValidFileCallable = default_valid_file) -> list[LogPath]:
if mode == ReadMode.RLOG:
return rlog_paths
elif mode == ReadMode.QLOG:
return qlog_paths
elif mode == ReadMode.AUTO:
return auto_strategy(rlog_paths, qlog_paths, False, valid_file)
elif mode == ReadMode.AUTO_INTERACTIVE:
return auto_strategy(rlog_paths, qlog_paths, True, valid_file)
raise ValueError(f"invalid mode: {mode}")
def comma_api_source(sr: SegmentRange, mode: ReadMode) -> list[LogPath]:
route = Route(sr.route_name)
rlog_paths = [route.log_paths()[seg] for seg in sr.seg_idxs]
qlog_paths = [route.qlog_paths()[seg] for seg in sr.seg_idxs]
# comma api will have already checked if the file exists
def valid_file(fn):
return fn is not None
return apply_strategy(mode, rlog_paths, qlog_paths, valid_file=valid_file)
def internal_source(sr: SegmentRange, mode: ReadMode, file_ext: str = "bz2") -> list[LogPath]:
if not internal_source_available():
raise InternalUnavailableException
def get_internal_url(sr: SegmentRange, seg, file):
return f"cd:/{sr.dongle_id}/{sr.log_id}/{seg}/{file}.{file_ext}"
# TODO: list instead of using static URLs to support routes with multiple file extensions
rlog_paths = [get_internal_url(sr, seg, "rlog") for seg in sr.seg_idxs]
qlog_paths = [get_internal_url(sr, seg, "qlog") for seg in sr.seg_idxs]
return apply_strategy(mode, rlog_paths, qlog_paths)
def internal_source_zst(sr: SegmentRange, mode: ReadMode, file_ext: str = "zst") -> list[LogPath]:
return internal_source(sr, mode, file_ext)
def openpilotci_source(sr: SegmentRange, mode: ReadMode, file_ext: str = "bz2") -> list[LogPath]:
rlog_paths = [get_url(sr.route_name, seg, f"rlog.{file_ext}") for seg in sr.seg_idxs]
qlog_paths = [get_url(sr.route_name, seg, f"qlog.{file_ext}") for seg in sr.seg_idxs]
return apply_strategy(mode, rlog_paths, qlog_paths)
def openpilotci_source_zst(sr: SegmentRange, mode: ReadMode) -> list[LogPath]:
return openpilotci_source(sr, mode, "zst")
def comma_car_segments_source(sr: SegmentRange, mode=ReadMode.RLOG) -> list[LogPath]:
return [get_comma_segments_url(sr.route_name, seg) for seg in sr.seg_idxs]
def testing_closet_source(sr: SegmentRange, mode=ReadMode.RLOG) -> list[LogPath]:
if not internal_source_available('http://testing.comma.life'):
raise InternalUnavailableException
return [f"http://testing.comma.life/download/{sr.route_name.replace('|', '/')}/{seg}/rlog" for seg in sr.seg_idxs]
def direct_source(file_or_url: str) -> list[LogPath]:
return [file_or_url]
def get_invalid_files(files):
for f in files:
if f is None or not file_exists(f):
yield f
def check_source(source: Source, *args) -> list[LogPath]:
files = source(*args)
assert len(files) > 0, "No files on source"
assert next(get_invalid_files(files), False) is False, "Some files are invalid"
return files
def auto_source(sr: SegmentRange, mode=ReadMode.RLOG, sources: list[Source] | None = None) -> list[LogPath]:
if mode == ReadMode.SANITIZED:
return comma_car_segments_source(sr, mode)
if sources is None:
sources = [internal_source, internal_source_zst, openpilotci_source, openpilotci_source_zst,
comma_api_source, comma_car_segments_source, testing_closet_source]
exceptions = {}
# for automatic fallback modes, auto_source needs to first check if rlogs exist for any source
if mode in [ReadMode.AUTO, ReadMode.AUTO_INTERACTIVE]:
for source in sources:
try:
return check_source(source, sr, ReadMode.RLOG)
except Exception:
pass
# Automatically determine viable source
for source in sources:
try:
return check_source(source, sr, mode)
except Exception as e:
exceptions[source.__name__] = e
raise LogsUnavailable("auto_source could not find any valid source, exceptions for sources:\n - " +
"\n - ".join([f"{k}: {repr(v)}" for k, v in exceptions.items()]))
def parse_indirect(identifier: str) -> str:
if "useradmin.comma.ai" in identifier:
query = parse_qs(urlparse(identifier).query)
return query["onebox"][0]
return identifier
def parse_direct(identifier: str):
if identifier.startswith(("http://", "https://", "cd:/")) or pathlib.Path(identifier).exists():
return identifier
return None
class LogReader:
def _parse_identifier(self, identifier: str) -> list[LogPath]:
# useradmin, etc.
identifier = parse_indirect(identifier)
# direct url or file
direct_parsed = parse_direct(identifier)
if direct_parsed is not None:
return direct_source(identifier)
sr = SegmentRange(identifier)
mode = self.default_mode if sr.selector is None else ReadMode(sr.selector)
identifiers = self.source(sr, mode)
invalid_count = len(list(get_invalid_files(identifiers)))
    assert invalid_count == 0, (f"{invalid_count}/{len(identifiers)} invalid log(s) found; ensure all logs are uploaded, " +
                                "or append the '/a' selector to the route name to automatically fall back to qlogs.")
return identifiers
def __init__(self, identifier: str | list[str], default_mode: ReadMode = ReadMode.RLOG,
source: Source = auto_source, sort_by_time=False, only_union_types=False):
self.default_mode = default_mode
self.source = source
self.identifier = identifier
if isinstance(identifier, str):
self.identifier = [identifier]
self.sort_by_time = sort_by_time
self.only_union_types = only_union_types
self.__lrs: dict[int, _LogFileReader] = {}
self.reset()
def _get_lr(self, i):
if i not in self.__lrs:
self.__lrs[i] = _LogFileReader(self.logreader_identifiers[i], sort_by_time=self.sort_by_time, only_union_types=self.only_union_types)
return self.__lrs[i]
def __iter__(self):
for i in range(len(self.logreader_identifiers)):
yield from self._get_lr(i)
def _run_on_segment(self, func, i):
return func(self._get_lr(i))
def run_across_segments(self, num_processes, func, disable_tqdm=False, desc=None):
with multiprocessing.Pool(num_processes) as pool:
ret = []
num_segs = len(self.logreader_identifiers)
for p in tqdm.tqdm(pool.imap(partial(self._run_on_segment, func), range(num_segs)), total=num_segs, disable=disable_tqdm, desc=desc):
ret.extend(p)
return ret
def reset(self):
self.logreader_identifiers = []
for identifier in self.identifier:
self.logreader_identifiers.extend(self._parse_identifier(identifier))
@staticmethod
def from_bytes(dat):
return _LogFileReader("", dat=dat)
def filter(self, msg_type: str):
return (getattr(m, m.which()) for m in filter(lambda m: m.which() == msg_type, self))
def first(self, msg_type: str):
return next(self.filter(msg_type), None)
@property
def time_series(self):
return msgs_to_time_series(self)
if __name__ == "__main__":
import codecs
# capnproto <= 0.8.0 throws errors converting byte data to string
# below line catches those errors and replaces the bytes with \x__
codecs.register_error("strict", codecs.backslashreplace_errors)
log_path = sys.argv[1]
lr = LogReader(log_path, sort_by_time=True)
for msg in lr:
print(msg)
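Besides plain iteration, run_across_segments fans a function out over the per-segment readers with multiprocessing; a sketch (the function must be picklable, so define it at module top level):

```python
from openpilot.tools.lib.logreader import LogReader

def collect_speeds(lr):
  return [m.carState.vEgo for m in lr if m.which() == "carState"]

lr = LogReader("a2a0ccea32023010|2023-07-27--13-01-19")
speeds = lr.run_across_segments(4, collect_speeds)  # 4 worker processes
```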

tools/lib/openpilotci.py Normal file

@@ -0,0 +1,12 @@
from openpilot.tools.lib.openpilotcontainers import OpenpilotCIContainer
def get_url(*args, **kwargs):
return OpenpilotCIContainer.get_url(*args, **kwargs)
def upload_file(*args, **kwargs):
return OpenpilotCIContainer.upload_file(*args, **kwargs)
def upload_bytes(*args, **kwargs):
return OpenpilotCIContainer.upload_bytes(*args, **kwargs)
BASE_URL = OpenpilotCIContainer.BASE_URL

tools/lib/openpilotcontainers.py Normal file

@@ -0,0 +1,6 @@
#!/usr/bin/env python3
from openpilot.tools.lib.azure_container import AzureContainer
OpenpilotCIContainer = AzureContainer("commadataci", "openpilotci")
DataCIContainer = AzureContainer("commadataci", "commadataci")
DataProdContainer = AzureContainer("commadata2", "commadata2")

tools/lib/route.py Normal file

@@ -0,0 +1,298 @@
import os
import re
from functools import cache
from urllib.parse import urlparse
from collections import defaultdict
from itertools import chain
from openpilot.tools.lib.auth_config import get_token
from openpilot.tools.lib.api import CommaApi
from openpilot.tools.lib.helpers import RE
QLOG_FILENAMES = ['qlog', 'qlog.bz2', 'qlog.zst']
QCAMERA_FILENAMES = ['qcamera.ts']
LOG_FILENAMES = ['rlog', 'rlog.bz2', 'raw_log.bz2', 'rlog.zst']
CAMERA_FILENAMES = ['fcamera.hevc', 'video.hevc']
DCAMERA_FILENAMES = ['dcamera.hevc']
ECAMERA_FILENAMES = ['ecamera.hevc']
class Route:
def __init__(self, name, data_dir=None):
self._name = RouteName(name)
self.files = None
if data_dir is not None:
self._segments = self._get_segments_local(data_dir)
else:
self._segments = self._get_segments_remote()
self.max_seg_number = self._segments[-1].name.segment_num
@property
def name(self):
return self._name
@property
def segments(self):
return self._segments
def log_paths(self):
log_path_by_seg_num = {s.name.segment_num: s.log_path for s in self._segments}
return [log_path_by_seg_num.get(i, None) for i in range(self.max_seg_number + 1)]
def qlog_paths(self):
qlog_path_by_seg_num = {s.name.segment_num: s.qlog_path for s in self._segments}
return [qlog_path_by_seg_num.get(i, None) for i in range(self.max_seg_number + 1)]
def camera_paths(self):
camera_path_by_seg_num = {s.name.segment_num: s.camera_path for s in self._segments}
return [camera_path_by_seg_num.get(i, None) for i in range(self.max_seg_number + 1)]
def dcamera_paths(self):
dcamera_path_by_seg_num = {s.name.segment_num: s.dcamera_path for s in self._segments}
return [dcamera_path_by_seg_num.get(i, None) for i in range(self.max_seg_number + 1)]
def ecamera_paths(self):
ecamera_path_by_seg_num = {s.name.segment_num: s.ecamera_path for s in self._segments}
return [ecamera_path_by_seg_num.get(i, None) for i in range(self.max_seg_number + 1)]
def qcamera_paths(self):
qcamera_path_by_seg_num = {s.name.segment_num: s.qcamera_path for s in self._segments}
return [qcamera_path_by_seg_num.get(i, None) for i in range(self.max_seg_number + 1)]
# TODO: refactor this, it's super repetitive
def _get_segments_remote(self):
api = CommaApi(get_token())
route_files = api.get('v1/route/' + self.name.canonical_name + '/files')
self.files = list(chain.from_iterable(route_files.values()))
segments = {}
for url in self.files:
_, dongle_id, time_str, segment_num, fn = urlparse(url).path.rsplit('/', maxsplit=4)
segment_name = f'{dongle_id}|{time_str}--{segment_num}'
if segments.get(segment_name):
segments[segment_name] = Segment(
segment_name,
url if fn in LOG_FILENAMES else segments[segment_name].log_path,
url if fn in QLOG_FILENAMES else segments[segment_name].qlog_path,
url if fn in CAMERA_FILENAMES else segments[segment_name].camera_path,
url if fn in DCAMERA_FILENAMES else segments[segment_name].dcamera_path,
url if fn in ECAMERA_FILENAMES else segments[segment_name].ecamera_path,
url if fn in QCAMERA_FILENAMES else segments[segment_name].qcamera_path,
)
else:
segments[segment_name] = Segment(
segment_name,
url if fn in LOG_FILENAMES else None,
url if fn in QLOG_FILENAMES else None,
url if fn in CAMERA_FILENAMES else None,
url if fn in DCAMERA_FILENAMES else None,
url if fn in ECAMERA_FILENAMES else None,
url if fn in QCAMERA_FILENAMES else None,
)
return sorted(segments.values(), key=lambda seg: seg.name.segment_num)
def _get_segments_local(self, data_dir):
files = os.listdir(data_dir)
segment_files = defaultdict(list)
for f in files:
fullpath = os.path.join(data_dir, f)
explorer_match = re.match(RE.EXPLORER_FILE, f)
op_match = re.match(RE.OP_SEGMENT_DIR, f)
if explorer_match:
segment_name = explorer_match.group('segment_name')
fn = explorer_match.group('file_name')
if segment_name.replace('_', '|').startswith(self.name.canonical_name):
segment_files[segment_name].append((fullpath, fn))
elif op_match and os.path.isdir(fullpath):
segment_name = op_match.group('segment_name')
if segment_name.startswith(self.name.canonical_name):
for seg_f in os.listdir(fullpath):
segment_files[segment_name].append((os.path.join(fullpath, seg_f), seg_f))
elif f == self.name.canonical_name:
for seg_num in os.listdir(fullpath):
if not seg_num.isdigit():
continue
segment_name = f'{self.name.canonical_name}--{seg_num}'
for seg_f in os.listdir(os.path.join(fullpath, seg_num)):
segment_files[segment_name].append((os.path.join(fullpath, seg_num, seg_f), seg_f))
segments = []
for segment, files in segment_files.items():
try:
log_path = next(path for path, filename in files if filename in LOG_FILENAMES)
except StopIteration:
log_path = None
try:
qlog_path = next(path for path, filename in files if filename in QLOG_FILENAMES)
except StopIteration:
qlog_path = None
try:
camera_path = next(path for path, filename in files if filename in CAMERA_FILENAMES)
except StopIteration:
camera_path = None
try:
dcamera_path = next(path for path, filename in files if filename in DCAMERA_FILENAMES)
except StopIteration:
dcamera_path = None
try:
ecamera_path = next(path for path, filename in files if filename in ECAMERA_FILENAMES)
except StopIteration:
ecamera_path = None
try:
qcamera_path = next(path for path, filename in files if filename in QCAMERA_FILENAMES)
except StopIteration:
qcamera_path = None
segments.append(Segment(segment, log_path, qlog_path, camera_path, dcamera_path, ecamera_path, qcamera_path))
if len(segments) == 0:
raise ValueError(f'Could not find segments for route {self.name.canonical_name} in data directory {data_dir}')
return sorted(segments, key=lambda seg: seg.name.segment_num)
class Segment:
def __init__(self, name, log_path, qlog_path, camera_path, dcamera_path, ecamera_path, qcamera_path):
self._name = SegmentName(name)
self.log_path = log_path
self.qlog_path = qlog_path
self.camera_path = camera_path
self.dcamera_path = dcamera_path
self.ecamera_path = ecamera_path
self.qcamera_path = qcamera_path
@property
def name(self):
return self._name
class RouteName:
def __init__(self, name_str: str):
self._name_str = name_str
delim = next(c for c in self._name_str if c in ("|", "/"))
self._dongle_id, self._time_str = self._name_str.split(delim)
assert len(self._dongle_id) == 16, self._name_str
assert len(self._time_str) == 20, self._name_str
self._canonical_name = f"{self._dongle_id}|{self._time_str}"
@property
def canonical_name(self) -> str: return self._canonical_name
@property
def dongle_id(self) -> str: return self._dongle_id
@property
def time_str(self) -> str: return self._time_str
def __str__(self) -> str: return self._canonical_name
class SegmentName:
# TODO: add constructor that takes dongle_id, time_str, segment_num and then create instances
# of this class instead of manually constructing a segment name (use canonical_name prop instead)
def __init__(self, name_str: str, allow_route_name=False):
data_dir_path_separator_index = name_str.rsplit("|", 1)[0].rfind("/")
use_data_dir = (data_dir_path_separator_index != -1) and ("|" in name_str)
self._name_str = name_str[data_dir_path_separator_index + 1:] if use_data_dir else name_str
self._data_dir = name_str[:data_dir_path_separator_index] if use_data_dir else None
seg_num_delim = "--" if self._name_str.count("--") == 2 else "/"
name_parts = self._name_str.rsplit(seg_num_delim, 1)
if allow_route_name and len(name_parts) == 1:
name_parts.append("-1") # no segment number
self._route_name = RouteName(name_parts[0])
self._num = int(name_parts[1])
self._canonical_name = f"{self._route_name._dongle_id}|{self._route_name._time_str}--{self._num}"
@property
def canonical_name(self) -> str: return self._canonical_name
@property
def dongle_id(self) -> str: return self._route_name.dongle_id
@property
def time_str(self) -> str: return self._route_name.time_str
@property
def segment_num(self) -> int: return self._num
@property
def route_name(self) -> RouteName: return self._route_name
@property
def data_dir(self) -> str | None: return self._data_dir
def __str__(self) -> str: return self._canonical_name
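# Illustrative parses (values follow from the formats handled above):
#   SegmentName("a2a0ccea32023010|2023-07-27--13-01-19--3").canonical_name
#     -> "a2a0ccea32023010|2023-07-27--13-01-19--3"
#   SegmentName("/data/media/0/realdata/a2a0ccea32023010|2023-07-27--13-01-19--3").data_dir
#     -> "/data/media/0/realdata"
#   SegmentName("a2a0ccea32023010|2023-07-27--13-01-19", allow_route_name=True).segment_num
#     -> -1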
@cache
def get_max_seg_number_cached(sr: 'SegmentRange') -> int:
try:
api = CommaApi(get_token())
max_seg_number = api.get("/v1/route/" + sr.route_name.replace("/", "|"))["maxqlog"]
assert isinstance(max_seg_number, int)
return max_seg_number
except Exception as e:
raise Exception("unable to get max_segment_number. ensure you have access to this route or the route is public.") from e
class SegmentRange:
def __init__(self, segment_range: str):
m = re.fullmatch(RE.SEGMENT_RANGE, segment_range)
assert m is not None, f"Segment range is not valid {segment_range}"
self.m = m
@property
def route_name(self) -> str:
return self.m.group("route_name")
@property
def dongle_id(self) -> str:
return self.m.group("dongle_id")
@property
def log_id(self) -> str:
return self.m.group("log_id")
@property
def slice(self) -> str:
return self.m.group("slice") or ""
@property
def selector(self) -> str | None:
return self.m.group("selector")
@property
def seg_idxs(self) -> list[int]:
m = re.fullmatch(RE.SLICE, self.slice)
assert m is not None, f"Invalid slice: {self.slice}"
start, end, step = (None if s is None else int(s) for s in m.groups())
# one segment specified
if start is not None and end is None and ':' not in self.slice:
if start < 0:
start += get_max_seg_number_cached(self) + 1
return [start]
s = slice(start, end, step)
# no specified end or using relative indexing, need number of segments
if end is None or end < 0 or (start is not None and start < 0):
return list(range(get_max_seg_number_cached(self) + 1))[s]
else:
return list(range(end + 1))[s]
def __str__(self) -> str:
return f"{self.dongle_id}/{self.log_id}" + (f"/{self.slice}" if self.slice else "") + (f"/{self.selector}" if self.selector else "")
def __repr__(self) -> str:
return self.__str__()
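A minimal sketch of how a segment range resolves to segment indices, using the format above (bounded slices resolve locally; open-ended or negative slices ask the comma API for the route's segment count):

```python
from openpilot.tools.lib.route import SegmentRange

sr = SegmentRange("344c5c15b34f2d8a/2024-01-03--09-37-12/2:6/q")
print(sr.dongle_id)  # 344c5c15b34f2d8a
print(sr.slice)      # 2:6
print(sr.selector)   # q
print(sr.seg_idxs)   # [2, 3, 4, 5] -- bounded, no API call needed
```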

26
tools/lib/sanitizer.py Normal file
View File

@@ -0,0 +1,26 @@
# Utilities for sanitizing routes down to only the essential data needed for testing car ports and doing validation.
from openpilot.tools.lib.logreader import LogIterable, LogMessage
def sanitize_vin(vin: str):
# (last 6 digits of vin are serial number https://en.wikipedia.org/wiki/Vehicle_identification_number)
VIN_SENSITIVE = 6
return vin[:-VIN_SENSITIVE] + "X" * VIN_SENSITIVE
def sanitize_msg(msg: LogMessage) -> LogMessage:
if msg.which() == "carParams":
msg = msg.as_builder()
msg.carParams.carVin = sanitize_vin(msg.carParams.carVin)
msg = msg.as_reader()
return msg
PRESERVE_SERVICES = ["can", "carParams", "pandaStates", "pandaStateDEPRECATED"]
def sanitize(lr: LogIterable) -> LogIterable:
filtered = filter(lambda msg: msg.which() in PRESERVE_SERVICES, lr)
sanitized = map(sanitize_msg, filtered)
return sanitized
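A minimal usage sketch, assuming a route you can access:

```python
from openpilot.tools.lib.logreader import LogReader
from openpilot.tools.lib.sanitizer import sanitize

lr = LogReader("a2a0ccea32023010|2023-07-27--13-01-19/0")
for msg in sanitize(lr):
  print(msg.which())  # only can/carParams/pandaStates messages survive
```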

0
tools/lib/tests/__init__.py Normal file
View File

130
tools/lib/tests/test_caching.py Normal file
View File

@@ -0,0 +1,130 @@
import http.server
import os
import shutil
import socket
import pytest
from openpilot.selfdrive.test.helpers import http_server_context
from openpilot.system.hardware.hw import Paths
from openpilot.tools.lib.url_file import URLFile
class CachingTestRequestHandler(http.server.BaseHTTPRequestHandler):
FILE_EXISTS = True
  def do_GET(self):
    if self.FILE_EXISTS:
      self.send_response(206 if "Range" in self.headers else 200)
      self.end_headers()
      self.wfile.write(b'1234')
    else:
      self.send_response(404)
      self.end_headers()
def do_HEAD(self):
if self.FILE_EXISTS:
self.send_response(200)
self.send_header("Content-Length", "4")
else:
self.send_response(404)
self.end_headers()
@pytest.fixture
def host():
with http_server_context(handler=CachingTestRequestHandler) as (host, port):
yield f"http://{host}:{port}"
class TestFileDownload:
def test_pipeline_defaults(self, host):
# TODO: parameterize the defaults so we don't rely on hard-coded values in xx
    assert URLFile.pool_manager().pools._maxsize == 10  # PoolManager num_pools param
pool_manager_defaults = {
"maxsize": 100,
"socket_options": [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),],
}
for k, v in pool_manager_defaults.items():
assert URLFile.pool_manager().connection_pool_kw.get(k) == v
retry_defaults = {
"total": 5,
"backoff_factor": 0.5,
"status_forcelist": [409, 429, 503, 504],
}
for k, v in retry_defaults.items():
assert getattr(URLFile.pool_manager().connection_pool_kw["retries"], k) == v
# ensure caching off by default and cache dir doesn't get created
os.environ.pop("FILEREADER_CACHE", None)
if os.path.exists(Paths.download_cache_root()):
shutil.rmtree(Paths.download_cache_root())
URLFile(f"{host}/test.txt").get_length()
URLFile(f"{host}/test.txt").read()
assert not os.path.exists(Paths.download_cache_root())
def compare_loads(self, url, start=0, length=None):
"""Compares range between cached and non cached version"""
file_cached = URLFile(url, cache=True)
file_downloaded = URLFile(url, cache=False)
file_cached.seek(start)
file_downloaded.seek(start)
assert file_cached.get_length() == file_downloaded.get_length()
    assert (length + start if length is not None else 0) <= file_downloaded.get_length()
response_cached = file_cached.read(ll=length)
response_downloaded = file_downloaded.read(ll=length)
assert response_cached == response_downloaded
# Now test with cache in place
file_cached = URLFile(url, cache=True)
file_cached.seek(start)
response_cached = file_cached.read(ll=length)
assert file_cached.get_length() == file_downloaded.get_length()
assert response_cached == response_downloaded
def test_small_file(self):
# Make sure we don't force cache
os.environ["FILEREADER_CACHE"] = "0"
small_file_url = "https://raw.githubusercontent.com/commaai/openpilot/master/docs/SAFETY.md"
    # to test with a file larger than one cache chunk, use this URL instead:
    # large_file_url = "https://commadataci.blob.core.windows.net/openpilotci/0375fdf7b1ce594d/2019-06-13--08-32-25/3/fcamera.hevc"
# Load full small file
self.compare_loads(small_file_url)
file_small = URLFile(small_file_url)
length = file_small.get_length()
self.compare_loads(small_file_url, length - 100, 100)
self.compare_loads(small_file_url, 50, 100)
# Load small file 100 bytes at a time
for i in range(length // 100):
self.compare_loads(small_file_url, 100 * i, 100)
def test_large_file(self):
large_file_url = "https://commadataci.blob.core.windows.net/openpilotci/0375fdf7b1ce594d/2019-06-13--08-32-25/3/qlog.bz2"
    # Load the last 100 bytes with both the cached and non-cached paths
file_large = URLFile(large_file_url)
length = file_large.get_length()
self.compare_loads(large_file_url, length - 100, 100)
self.compare_loads(large_file_url)
@pytest.mark.parametrize("cache_enabled", [True, False])
def test_recover_from_missing_file(self, host, cache_enabled):
os.environ["FILEREADER_CACHE"] = "1" if cache_enabled else "0"
file_url = f"{host}/test.png"
CachingTestRequestHandler.FILE_EXISTS = False
length = URLFile(file_url).get_length()
assert length == -1
CachingTestRequestHandler.FILE_EXISTS = True
length = URLFile(file_url).get_length()
assert length == 4

34
tools/lib/tests/test_comma_car_segments.py Normal file
View File

@@ -0,0 +1,34 @@
import pytest
import requests
from opendbc.car.fingerprints import MIGRATION
from openpilot.tools.lib.comma_car_segments import get_comma_car_segments_database, get_url
from openpilot.tools.lib.logreader import LogReader
from openpilot.tools.lib.route import SegmentRange
@pytest.mark.skip(reason="huggingface is flaky, run this test manually to check for issues")
class TestCommaCarSegments:
def test_database(self):
database = get_comma_car_segments_database()
platforms = database.keys()
assert len(platforms) > 100
def test_download_segment(self):
database = get_comma_car_segments_database()
fp = "SUBARU_FORESTER"
segment = database[fp][0]
sr = SegmentRange(segment)
url = get_url(sr.route_name, sr.slice)
resp = requests.get(url)
assert resp.status_code == 200
lr = LogReader(url)
CP = lr.first("carParams")
assert MIGRATION.get(CP.carFingerprint, CP.carFingerprint) == fp
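The same lookup flow works outside the test; a sketch assuming network access to the database:

```python
from openpilot.tools.lib.comma_car_segments import get_comma_car_segments_database, get_url
from openpilot.tools.lib.logreader import LogReader
from openpilot.tools.lib.route import SegmentRange

database = get_comma_car_segments_database()
for segment in database.get("SUBARU_FORESTER", [])[:1]:
  sr = SegmentRange(segment)
  lr = LogReader(get_url(sr.route_name, sr.slice))
  # older routes may need opendbc's MIGRATION map to match current platform names
  print(lr.first("carParams").carFingerprint)
```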

256
tools/lib/tests/test_logreader.py Normal file
View File

@@ -0,0 +1,256 @@
import capnp
import contextlib
import io
import shutil
import tempfile
import os
import pytest
import requests
from parameterized import parameterized
from cereal import log as capnp_log
from openpilot.tools.lib.logreader import LogIterable, LogReader, comma_api_source, parse_indirect, ReadMode, InternalUnavailableException
from openpilot.tools.lib.route import SegmentRange
from openpilot.tools.lib.url_file import URLFileException
NUM_SEGS = 17 # number of segments in the test route
ALL_SEGS = list(range(NUM_SEGS))
TEST_ROUTE = "344c5c15b34f2d8a/2024-01-03--09-37-12"
QLOG_FILE = "https://commadataci.blob.core.windows.net/openpilotci/0375fdf7b1ce594d/2019-06-13--08-32-25/3/qlog.bz2"
def noop(segment: LogIterable):
return segment
@contextlib.contextmanager
def setup_source_scenario(mocker, is_internal=False):
internal_source_mock = mocker.patch("openpilot.tools.lib.logreader.internal_source")
internal_source_mock.__name__ = internal_source_mock._mock_name
openpilotci_source_mock = mocker.patch("openpilot.tools.lib.logreader.openpilotci_source")
openpilotci_source_mock.__name__ = openpilotci_source_mock._mock_name
comma_api_source_mock = mocker.patch("openpilot.tools.lib.logreader.comma_api_source")
comma_api_source_mock.__name__ = comma_api_source_mock._mock_name
if is_internal:
internal_source_mock.return_value = [QLOG_FILE]
else:
internal_source_mock.side_effect = InternalUnavailableException
openpilotci_source_mock.return_value = [None]
comma_api_source_mock.return_value = [QLOG_FILE]
yield
class TestLogReader:
@parameterized.expand([
(f"{TEST_ROUTE}", ALL_SEGS),
(f"{TEST_ROUTE.replace('/', '|')}", ALL_SEGS),
(f"{TEST_ROUTE}--0", [0]),
(f"{TEST_ROUTE}--5", [5]),
(f"{TEST_ROUTE}/0", [0]),
(f"{TEST_ROUTE}/5", [5]),
(f"{TEST_ROUTE}/0:10", ALL_SEGS[0:10]),
(f"{TEST_ROUTE}/0:0", []),
(f"{TEST_ROUTE}/4:6", ALL_SEGS[4:6]),
(f"{TEST_ROUTE}/0:-1", ALL_SEGS[0:-1]),
(f"{TEST_ROUTE}/:5", ALL_SEGS[:5]),
(f"{TEST_ROUTE}/2:", ALL_SEGS[2:]),
(f"{TEST_ROUTE}/2:-1", ALL_SEGS[2:-1]),
(f"{TEST_ROUTE}/-1", [ALL_SEGS[-1]]),
(f"{TEST_ROUTE}/-2", [ALL_SEGS[-2]]),
(f"{TEST_ROUTE}/-2:-1", ALL_SEGS[-2:-1]),
(f"{TEST_ROUTE}/-4:-2", ALL_SEGS[-4:-2]),
(f"{TEST_ROUTE}/:10:2", ALL_SEGS[:10:2]),
(f"{TEST_ROUTE}/5::2", ALL_SEGS[5::2]),
(f"https://useradmin.comma.ai/?onebox={TEST_ROUTE}", ALL_SEGS),
(f"https://useradmin.comma.ai/?onebox={TEST_ROUTE.replace('/', '|')}", ALL_SEGS),
(f"https://useradmin.comma.ai/?onebox={TEST_ROUTE.replace('/', '%7C')}", ALL_SEGS),
])
def test_indirect_parsing(self, identifier, expected):
parsed = parse_indirect(identifier)
sr = SegmentRange(parsed)
assert list(sr.seg_idxs) == expected, identifier
@parameterized.expand([
(f"{TEST_ROUTE}", f"{TEST_ROUTE}"),
(f"{TEST_ROUTE.replace('/', '|')}", f"{TEST_ROUTE}"),
(f"{TEST_ROUTE}--5", f"{TEST_ROUTE}/5"),
(f"{TEST_ROUTE}/0/q", f"{TEST_ROUTE}/0/q"),
(f"{TEST_ROUTE}/5:6/r", f"{TEST_ROUTE}/5:6/r"),
(f"{TEST_ROUTE}/5", f"{TEST_ROUTE}/5"),
])
def test_canonical_name(self, identifier, expected):
sr = SegmentRange(identifier)
assert str(sr) == expected
@pytest.mark.parametrize("cache_enabled", [True, False])
def test_direct_parsing(self, mocker, cache_enabled):
file_exists_mock = mocker.patch("openpilot.tools.lib.logreader.file_exists")
os.environ["FILEREADER_CACHE"] = "1" if cache_enabled else "0"
qlog = tempfile.NamedTemporaryFile(mode='wb', delete=False)
with requests.get(QLOG_FILE, stream=True) as r:
with qlog as f:
shutil.copyfileobj(r.raw, f)
for f in [QLOG_FILE, qlog.name]:
l = len(list(LogReader(f)))
assert l > 100
with pytest.raises(URLFileException) if not cache_enabled else pytest.raises(AssertionError):
l = len(list(LogReader(QLOG_FILE.replace("/3/", "/200/"))))
# file_exists should not be called for direct files
assert file_exists_mock.call_count == 0
@parameterized.expand([
(f"{TEST_ROUTE}///",),
(f"{TEST_ROUTE}---",),
(f"{TEST_ROUTE}/-4:--2",),
(f"{TEST_ROUTE}/-a",),
(f"{TEST_ROUTE}/j",),
(f"{TEST_ROUTE}/0:1:2:3",),
(f"{TEST_ROUTE}/:::3",),
(f"{TEST_ROUTE}3",),
(f"{TEST_ROUTE}-3",),
(f"{TEST_ROUTE}--3a",),
])
def test_bad_ranges(self, segment_range):
with pytest.raises(AssertionError):
_ = SegmentRange(segment_range).seg_idxs
@pytest.mark.parametrize("segment_range, api_call", [
(f"{TEST_ROUTE}/0", False),
(f"{TEST_ROUTE}/:2", False),
(f"{TEST_ROUTE}/0:", True),
(f"{TEST_ROUTE}/-1", True),
(f"{TEST_ROUTE}", True),
])
def test_slicing_api_call(self, mocker, segment_range, api_call):
max_seg_mock = mocker.patch("openpilot.tools.lib.route.get_max_seg_number_cached")
max_seg_mock.return_value = NUM_SEGS
_ = SegmentRange(segment_range).seg_idxs
assert api_call == max_seg_mock.called
@pytest.mark.slow
def test_modes(self):
qlog_len = len(list(LogReader(f"{TEST_ROUTE}/0", ReadMode.QLOG)))
rlog_len = len(list(LogReader(f"{TEST_ROUTE}/0", ReadMode.RLOG)))
assert qlog_len * 6 < rlog_len
@pytest.mark.slow
def test_modes_from_name(self):
qlog_len = len(list(LogReader(f"{TEST_ROUTE}/0/q")))
rlog_len = len(list(LogReader(f"{TEST_ROUTE}/0/r")))
assert qlog_len * 6 < rlog_len
@pytest.mark.slow
def test_list(self):
qlog_len = len(list(LogReader(f"{TEST_ROUTE}/0/q")))
qlog_len_2 = len(list(LogReader([f"{TEST_ROUTE}/0/q", f"{TEST_ROUTE}/0/q"])))
assert qlog_len * 2 == qlog_len_2
@pytest.mark.slow
def test_multiple_iterations(self, mocker):
init_mock = mocker.patch("openpilot.tools.lib.logreader._LogFileReader")
lr = LogReader(f"{TEST_ROUTE}/0/q")
qlog_len1 = len(list(lr))
qlog_len2 = len(list(lr))
# ensure we don't create multiple instances of _LogFileReader, which means downloading the files twice
assert init_mock.call_count == 1
assert qlog_len1 == qlog_len2
@pytest.mark.slow
def test_helpers(self):
lr = LogReader(f"{TEST_ROUTE}/0/q")
assert lr.first("carParams").carFingerprint == "SUBARU OUTBACK 6TH GEN"
assert 0 < len(list(lr.filter("carParams"))) < len(list(lr))
@parameterized.expand([(True,), (False,)])
@pytest.mark.slow
def test_run_across_segments(self, cache_enabled):
os.environ["FILEREADER_CACHE"] = "1" if cache_enabled else "0"
lr = LogReader(f"{TEST_ROUTE}/0:4")
assert len(lr.run_across_segments(4, noop)) == len(list(lr))
@pytest.mark.slow
def test_auto_mode(self, subtests, mocker):
lr = LogReader(f"{TEST_ROUTE}/0/q")
qlog_len = len(list(lr))
log_paths_mock = mocker.patch("openpilot.tools.lib.route.Route.log_paths")
log_paths_mock.return_value = [None] * NUM_SEGS
# Should fall back to qlogs since rlogs are not available
with subtests.test("interactive_yes"):
mocker.patch("sys.stdin", new=io.StringIO("y\n"))
lr = LogReader(f"{TEST_ROUTE}/0", default_mode=ReadMode.AUTO_INTERACTIVE, source=comma_api_source)
log_len = len(list(lr))
assert qlog_len == log_len
with subtests.test("interactive_no"):
mocker.patch("sys.stdin", new=io.StringIO("n\n"))
with pytest.raises(AssertionError):
lr = LogReader(f"{TEST_ROUTE}/0", default_mode=ReadMode.AUTO_INTERACTIVE, source=comma_api_source)
with subtests.test("non_interactive"):
lr = LogReader(f"{TEST_ROUTE}/0", default_mode=ReadMode.AUTO, source=comma_api_source)
log_len = len(list(lr))
assert qlog_len == log_len
@pytest.mark.parametrize("is_internal", [True, False])
@pytest.mark.slow
def test_auto_source_scenarios(self, mocker, is_internal):
lr = LogReader(QLOG_FILE)
qlog_len = len(list(lr))
with setup_source_scenario(mocker, is_internal=is_internal):
lr = LogReader(f"{TEST_ROUTE}/0/q")
log_len = len(list(lr))
assert qlog_len == log_len
@pytest.mark.slow
def test_sort_by_time(self):
msgs = list(LogReader(f"{TEST_ROUTE}/0/q"))
assert msgs != sorted(msgs, key=lambda m: m.logMonoTime)
msgs = list(LogReader(f"{TEST_ROUTE}/0/q", sort_by_time=True))
assert msgs == sorted(msgs, key=lambda m: m.logMonoTime)
def test_only_union_types(self):
with tempfile.NamedTemporaryFile() as qlog:
# write valid Event messages
num_msgs = 100
with open(qlog.name, "wb") as f:
f.write(b"".join(capnp_log.Event.new_message().to_bytes() for _ in range(num_msgs)))
msgs = list(LogReader(qlog.name))
assert len(msgs) == num_msgs
[m.which() for m in msgs]
# append non-union Event message
event_msg = capnp_log.Event.new_message()
non_union_bytes = bytearray(event_msg.to_bytes())
non_union_bytes[event_msg.total_size.word_count * 8] = 0xff # set discriminant value out of range using Event word offset
with open(qlog.name, "ab") as f:
f.write(non_union_bytes)
# ensure new message is added, but is not a union type
msgs = list(LogReader(qlog.name))
assert len(msgs) == num_msgs + 1
with pytest.raises(capnp.KjException):
[m.which() for m in msgs]
# should not be added when only_union_types=True
msgs = list(LogReader(qlog.name, only_union_types=True))
assert len(msgs) == num_msgs
[m.which() for m in msgs]
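The helpers exercised above compose naturally; a short sketch of typical interactive use on the test route:

```python
from openpilot.tools.lib.logreader import LogReader, ReadMode

lr = LogReader("344c5c15b34f2d8a/2024-01-03--09-37-12/0", default_mode=ReadMode.QLOG)
CP = lr.first("carParams")  # first matching message, or None
speeds = [m.carState.vEgo for m in lr.filter("carState")]
print(CP.carFingerprint, max(speeds))
```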

63
tools/lib/tests/test_readers.py Normal file
View File

@@ -0,0 +1,63 @@
import pytest
import requests
import tempfile
from collections import defaultdict
import numpy as np
from openpilot.tools.lib.framereader import FrameReader
from openpilot.tools.lib.logreader import LogReader
class TestReaders:
@pytest.mark.skip("skip for bandwidth reasons")
def test_logreader(self):
def _check_data(lr):
hist = defaultdict(int)
for l in lr:
hist[l.which()] += 1
assert hist['carControl'] == 6000
assert hist['logMessage'] == 6857
with tempfile.NamedTemporaryFile(suffix=".bz2") as fp:
r = requests.get("https://github.com/commaai/comma2k19/blob/master/Example_1/b0c9d2329ad1606b%7C2018-08-02--08-34-47/40/raw_log.bz2?raw=true", timeout=10)
fp.write(r.content)
fp.flush()
lr_file = LogReader(fp.name)
_check_data(lr_file)
lr_url = LogReader("https://github.com/commaai/comma2k19/blob/master/Example_1/b0c9d2329ad1606b%7C2018-08-02--08-34-47/40/raw_log.bz2?raw=true")
_check_data(lr_url)
@pytest.mark.skip("skip for bandwidth reasons")
def test_framereader(self):
def _check_data(f):
assert f.frame_count == 1200
assert f.w == 1164
assert f.h == 874
frame_first_30 = f.get(0, 30)
assert len(frame_first_30) == 30
print(frame_first_30[15])
print("frame_0")
frame_0 = f.get(0, 1)
frame_15 = f.get(15, 1)
print(frame_15[0])
assert np.all(frame_first_30[0] == frame_0[0])
assert np.all(frame_first_30[15] == frame_15[0])
with tempfile.NamedTemporaryFile(suffix=".hevc") as fp:
r = requests.get("https://github.com/commaai/comma2k19/blob/master/Example_1/b0c9d2329ad1606b%7C2018-08-02--08-34-47/40/video.hevc?raw=true", timeout=10)
fp.write(r.content)
fp.flush()
fr_file = FrameReader(fp.name)
_check_data(fr_file)
fr_url = FrameReader("https://github.com/commaai/comma2k19/blob/master/Example_1/b0c9d2329ad1606b%7C2018-08-02--08-34-47/40/video.hevc?raw=true")
_check_data(fr_url)

27
tools/lib/tests/test_route_library.py Normal file
View File

@@ -0,0 +1,27 @@
from collections import namedtuple
from openpilot.tools.lib.route import SegmentName
class TestRouteLibrary:
def test_segment_name_formats(self):
Case = namedtuple('Case', ['input', 'expected_route', 'expected_segment_num', 'expected_data_dir'])
cases = [ Case("a2a0ccea32023010|2023-07-27--13-01-19", "a2a0ccea32023010|2023-07-27--13-01-19", -1, None),
Case("a2a0ccea32023010/2023-07-27--13-01-19--1", "a2a0ccea32023010|2023-07-27--13-01-19", 1, None),
Case("a2a0ccea32023010|2023-07-27--13-01-19/2", "a2a0ccea32023010|2023-07-27--13-01-19", 2, None),
Case("a2a0ccea32023010/2023-07-27--13-01-19/3", "a2a0ccea32023010|2023-07-27--13-01-19", 3, None),
Case("/data/media/0/realdata/a2a0ccea32023010|2023-07-27--13-01-19", "a2a0ccea32023010|2023-07-27--13-01-19", -1, "/data/media/0/realdata"),
Case("/data/media/0/realdata/a2a0ccea32023010|2023-07-27--13-01-19--1", "a2a0ccea32023010|2023-07-27--13-01-19", 1, "/data/media/0/realdata"),
Case("/data/media/0/realdata/a2a0ccea32023010|2023-07-27--13-01-19/2", "a2a0ccea32023010|2023-07-27--13-01-19", 2, "/data/media/0/realdata") ]
def _validate(case):
route_or_segment_name = case.input
s = SegmentName(route_or_segment_name, allow_route_name=True)
assert str(s.route_name) == case.expected_route
assert s.segment_num == case.expected_segment_num
assert s.data_dir == case.expected_data_dir
for case in cases:
_validate(case)

163
tools/lib/url_file.py Normal file
View File

@@ -0,0 +1,163 @@
import logging
import os
import socket
import time
from hashlib import sha256
from urllib3 import PoolManager, Retry
from urllib3.response import BaseHTTPResponse
from urllib3.util import Timeout
from openpilot.common.file_helpers import atomic_write_in_dir
from openpilot.system.hardware.hw import Paths
# Cache chunk size
K = 1000
CHUNK_SIZE = 1000 * K
logging.getLogger("urllib3").setLevel(logging.WARNING)
def hash_256(link: str) -> str:
return sha256((link.split("?")[0]).encode('utf-8')).hexdigest()
class URLFileException(Exception):
pass
class URLFile:
_pool_manager: PoolManager|None = None
@staticmethod
def reset() -> None:
URLFile._pool_manager = None
@staticmethod
def pool_manager() -> PoolManager:
if URLFile._pool_manager is None:
socket_options = [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),]
retries = Retry(total=5, backoff_factor=0.5, status_forcelist=[409, 429, 503, 504])
URLFile._pool_manager = PoolManager(num_pools=10, maxsize=100, socket_options=socket_options, retries=retries)
return URLFile._pool_manager
def __init__(self, url: str, timeout: int=10, debug: bool=False, cache: bool|None=None):
self._url = url
self._timeout = Timeout(connect=timeout, read=timeout)
self._pos = 0
self._length: int|None = None
self._debug = debug
    # force_download defaults to True; setting FILEREADER_CACHE=1 makes it False, and the cache argument overrides both
self._force_download = not int(os.environ.get("FILEREADER_CACHE", "0"))
if cache is not None:
self._force_download = not cache
if not self._force_download:
os.makedirs(Paths.download_cache_root(), exist_ok=True)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback) -> None:
pass
def _request(self, method: str, url: str, headers: dict[str, str]|None=None) -> BaseHTTPResponse:
return URLFile.pool_manager().request(method, url, timeout=self._timeout, headers=headers)
def get_length_online(self) -> int:
response = self._request('HEAD', self._url)
if not (200 <= response.status <= 299):
return -1
length = response.headers.get('content-length', 0)
return int(length)
def get_length(self) -> int:
if self._length is not None:
return self._length
file_length_path = os.path.join(Paths.download_cache_root(), hash_256(self._url) + "_length")
if not self._force_download and os.path.exists(file_length_path):
with open(file_length_path) as file_length:
content = file_length.read()
self._length = int(content)
return self._length
self._length = self.get_length_online()
if not self._force_download and self._length != -1:
with atomic_write_in_dir(file_length_path, mode="w", overwrite=True) as file_length:
file_length.write(str(self._length))
return self._length
def read(self, ll: int|None=None) -> bytes:
if self._force_download:
return self.read_aux(ll=ll)
file_begin = self._pos
file_end = self._pos + ll if ll is not None else self.get_length()
assert file_end != -1, f"Remote file is empty or doesn't exist: {self._url}"
    # We have to align with the chunks we store. Position is the beginning of the latest chunk that starts at or before our read position
position = (file_begin // CHUNK_SIZE) * CHUNK_SIZE
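    # e.g. with CHUNK_SIZE = 1_000_000, reading bytes 1_500_000-1_699_999 aligns
    # position to 1_000_000 and slices data[500_000:700_000] out of that chunk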
response = b""
while True:
self._pos = position
chunk_number = self._pos / CHUNK_SIZE
file_name = hash_256(self._url) + "_" + str(chunk_number)
full_path = os.path.join(Paths.download_cache_root(), str(file_name))
data = None
# If we don't have a file, download it
if not os.path.exists(full_path):
data = self.read_aux(ll=CHUNK_SIZE)
with atomic_write_in_dir(full_path, mode="wb", overwrite=True) as new_cached_file:
new_cached_file.write(data)
else:
with open(full_path, "rb") as cached_file:
data = cached_file.read()
response += data[max(0, file_begin - position): min(CHUNK_SIZE, file_end - position)]
position += CHUNK_SIZE
if position >= file_end:
self._pos = file_end
return response
def read_aux(self, ll: int|None=None) -> bytes:
download_range = False
headers = {}
if self._pos != 0 or ll is not None:
if ll is None:
end = self.get_length() - 1
else:
end = min(self._pos + ll, self.get_length()) - 1
if self._pos >= end:
return b""
headers['Range'] = f"bytes={self._pos}-{end}"
download_range = True
if self._debug:
t1 = time.time()
response = self._request('GET', self._url, headers=headers)
ret = response.data
if self._debug:
t2 = time.time()
if t2 - t1 > 0.1:
print(f"get {self._url} {headers!r} {t2 - t1:.3f} slow")
response_code = response.status
if response_code == 416: # Requested Range Not Satisfiable
raise URLFileException(f"Error, range out of bounds {response_code} {headers} ({self._url}): {repr(ret)[:500]}")
if download_range and response_code != 206: # Partial Content
raise URLFileException(f"Error, requested range but got unexpected response {response_code} {headers} ({self._url}): {repr(ret)[:500]}")
if (not download_range) and response_code != 200: # OK
raise URLFileException(f"Error {response_code} {headers} ({self._url}): {repr(ret)[:500]}")
self._pos += len(ret)
return ret
def seek(self, pos:int) -> None:
self._pos = pos
@property
def name(self) -> str:
return self._url
os.register_at_fork(after_in_child=URLFile.reset)
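A minimal usage sketch (the URL is illustrative; with cache=True, chunks and lengths are stored under Paths.download_cache_root()):

```python
from openpilot.tools.lib.url_file import URLFile

with URLFile("https://example.com/qlog.bz2", cache=True) as f:
  print(f.get_length())  # total size via HEAD, cached next to the chunks
  f.seek(100)
  data = f.read(ll=50)   # bytes 100-149, served from 1 MB cache chunks
```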

311
tools/lib/vidindex.py Executable file
View File

@@ -0,0 +1,311 @@
#!/usr/bin/env python3
import argparse
import os
import struct
from enum import IntEnum
from openpilot.tools.lib.filereader import FileReader
DEBUG = int(os.getenv("DEBUG", "0"))
# compare to ffmpeg parsing
# ffmpeg -i <input.hevc> -c copy -bsf:v trace_headers -f null - 2>&1 | grep -B4 -A32 '] 0 '
# H.265 specification
# https://www.itu.int/rec/dologin_pub.asp?lang=e&id=T-REC-H.265-201802-S!!PDF-E&type=items
NAL_UNIT_START_CODE = b"\x00\x00\x01"
NAL_UNIT_START_CODE_SIZE = len(NAL_UNIT_START_CODE)
NAL_UNIT_HEADER_SIZE = 2
class HevcNalUnitType(IntEnum):
TRAIL_N = 0 # RBSP structure: slice_segment_layer_rbsp( )
TRAIL_R = 1 # RBSP structure: slice_segment_layer_rbsp( )
TSA_N = 2 # RBSP structure: slice_segment_layer_rbsp( )
TSA_R = 3 # RBSP structure: slice_segment_layer_rbsp( )
STSA_N = 4 # RBSP structure: slice_segment_layer_rbsp( )
STSA_R = 5 # RBSP structure: slice_segment_layer_rbsp( )
RADL_N = 6 # RBSP structure: slice_segment_layer_rbsp( )
RADL_R = 7 # RBSP structure: slice_segment_layer_rbsp( )
RASL_N = 8 # RBSP structure: slice_segment_layer_rbsp( )
RASL_R = 9 # RBSP structure: slice_segment_layer_rbsp( )
RSV_VCL_N10 = 10
RSV_VCL_R11 = 11
RSV_VCL_N12 = 12
RSV_VCL_R13 = 13
RSV_VCL_N14 = 14
RSV_VCL_R15 = 15
BLA_W_LP = 16 # RBSP structure: slice_segment_layer_rbsp( )
BLA_W_RADL = 17 # RBSP structure: slice_segment_layer_rbsp( )
BLA_N_LP = 18 # RBSP structure: slice_segment_layer_rbsp( )
IDR_W_RADL = 19 # RBSP structure: slice_segment_layer_rbsp( )
IDR_N_LP = 20 # RBSP structure: slice_segment_layer_rbsp( )
CRA_NUT = 21 # RBSP structure: slice_segment_layer_rbsp( )
RSV_IRAP_VCL22 = 22
RSV_IRAP_VCL23 = 23
RSV_VCL24 = 24
RSV_VCL25 = 25
RSV_VCL26 = 26
RSV_VCL27 = 27
RSV_VCL28 = 28
RSV_VCL29 = 29
RSV_VCL30 = 30
RSV_VCL31 = 31
VPS_NUT = 32 # RBSP structure: video_parameter_set_rbsp( )
SPS_NUT = 33 # RBSP structure: seq_parameter_set_rbsp( )
PPS_NUT = 34 # RBSP structure: pic_parameter_set_rbsp( )
AUD_NUT = 35
EOS_NUT = 36
EOB_NUT = 37
FD_NUT = 38
PREFIX_SEI_NUT = 39
SUFFIX_SEI_NUT = 40
RSV_NVCL41 = 41
RSV_NVCL42 = 42
RSV_NVCL43 = 43
RSV_NVCL44 = 44
RSV_NVCL45 = 45
RSV_NVCL46 = 46
RSV_NVCL47 = 47
UNSPEC48 = 48
UNSPEC49 = 49
UNSPEC50 = 50
UNSPEC51 = 51
UNSPEC52 = 52
UNSPEC53 = 53
UNSPEC54 = 54
UNSPEC55 = 55
UNSPEC56 = 56
UNSPEC57 = 57
UNSPEC58 = 58
UNSPEC59 = 59
UNSPEC60 = 60
UNSPEC61 = 61
UNSPEC62 = 62
UNSPEC63 = 63
# B.2.2 Byte stream NAL unit semantics
# - The nal_unit_type within the nal_unit( ) syntax structure is equal to VPS_NUT, SPS_NUT or PPS_NUT.
# - The byte stream NAL unit syntax structure contains the first NAL unit of an access unit in decoding
# order, as specified in clause 7.4.2.4.4.
HEVC_PARAMETER_SET_NAL_UNITS = (
HevcNalUnitType.VPS_NUT,
HevcNalUnitType.SPS_NUT,
HevcNalUnitType.PPS_NUT,
)
# 3.29 coded slice segment NAL unit: A NAL unit that has nal_unit_type in the range of TRAIL_N to RASL_R,
# inclusive, or in the range of BLA_W_LP to RSV_IRAP_VCL23, inclusive, which indicates that the NAL unit
# contains a coded slice segment
HEVC_CODED_SLICE_SEGMENT_NAL_UNITS = (
HevcNalUnitType.TRAIL_N,
HevcNalUnitType.TRAIL_R,
HevcNalUnitType.TSA_N,
HevcNalUnitType.TSA_R,
HevcNalUnitType.STSA_N,
HevcNalUnitType.STSA_R,
HevcNalUnitType.RADL_N,
HevcNalUnitType.RADL_R,
HevcNalUnitType.RASL_N,
HevcNalUnitType.RASL_R,
HevcNalUnitType.BLA_W_LP,
HevcNalUnitType.BLA_W_RADL,
HevcNalUnitType.BLA_N_LP,
HevcNalUnitType.IDR_W_RADL,
HevcNalUnitType.IDR_N_LP,
HevcNalUnitType.CRA_NUT,
)
class VideoFileInvalid(Exception):
pass
def get_ue(dat: bytes, start_idx: int, skip_bits: int) -> tuple[int, int]:
prefix_val = 0
prefix_len = 0
suffix_val = 0
suffix_len = 0
i = start_idx
while i < len(dat):
j = 7
while j >= 0:
if skip_bits > 0:
skip_bits -= 1
elif prefix_val == 0:
prefix_val = (dat[i] >> j) & 1
prefix_len += 1
else:
suffix_val = (suffix_val << 1) | ((dat[i] >> j) & 1)
suffix_len += 1
j -= 1
if prefix_val == 1 and prefix_len - 1 == suffix_len:
val = 2**(prefix_len-1) - 1 + suffix_val
size = prefix_len + suffix_len
return val, size
i += 1
raise VideoFileInvalid("invalid exponential-golomb code")
def require_nal_unit_start(dat: bytes, nal_unit_start: int) -> None:
if nal_unit_start < 1:
raise ValueError("start index must be greater than zero")
if dat[nal_unit_start:nal_unit_start + NAL_UNIT_START_CODE_SIZE] != NAL_UNIT_START_CODE:
raise VideoFileInvalid("data must begin with start code")
def get_hevc_nal_unit_length(dat: bytes, nal_unit_start: int) -> int:
try:
pos = dat.index(NAL_UNIT_START_CODE, nal_unit_start + NAL_UNIT_START_CODE_SIZE)
except ValueError:
pos = -1
# length of NAL unit is byte count up to next NAL unit start index
nal_unit_len = (pos if pos != -1 else len(dat)) - nal_unit_start
if DEBUG:
print(" nal_unit_len:", nal_unit_len)
return nal_unit_len
def get_hevc_nal_unit_type(dat: bytes, nal_unit_start: int) -> HevcNalUnitType:
# 7.3.1.2 NAL unit header syntax
# nal_unit_header( ) { // descriptor
# forbidden_zero_bit f(1)
# nal_unit_type u(6)
# nuh_layer_id u(6)
# nuh_temporal_id_plus1 u(3)
# }
header_start = nal_unit_start + NAL_UNIT_START_CODE_SIZE
nal_unit_header = dat[header_start:header_start + NAL_UNIT_HEADER_SIZE]
if len(nal_unit_header) != 2:
raise VideoFileInvalid("data to short to contain nal unit header")
nal_unit_type = HevcNalUnitType((nal_unit_header[0] >> 1) & 0x3F)
if DEBUG:
print(" nal_unit_type:", nal_unit_type.name, f"({nal_unit_type.value})")
return nal_unit_type
def get_hevc_slice_type(dat: bytes, nal_unit_start: int, nal_unit_type: HevcNalUnitType) -> tuple[int, bool]:
# 7.3.2.9 Slice segment layer RBSP syntax
# slice_segment_layer_rbsp( ) {
# slice_segment_header( )
# slice_segment_data( )
# rbsp_slice_segment_trailing_bits( )
# }
# ...
# 7.3.6.1 General slice segment header syntax
# slice_segment_header( ) { // descriptor
# first_slice_segment_in_pic_flag u(1)
# if( nal_unit_type >= BLA_W_LP && nal_unit_type <= RSV_IRAP_VCL23 )
# no_output_of_prior_pics_flag u(1)
# slice_pic_parameter_set_id ue(v)
# if( !first_slice_segment_in_pic_flag ) {
# if( dependent_slice_segments_enabled_flag )
# dependent_slice_segment_flag u(1)
# slice_segment_address u(v)
# }
# if( !dependent_slice_segment_flag ) {
# for( i = 0; i < num_extra_slice_header_bits; i++ )
# slice_reserved_flag[ i ] u(1)
# slice_type ue(v)
# ...
rbsp_start = nal_unit_start + NAL_UNIT_START_CODE_SIZE + NAL_UNIT_HEADER_SIZE
skip_bits = 0
# 7.4.7.1 General slice segment header semantics
# first_slice_segment_in_pic_flag equal to 1 specifies that the slice segment is the first slice segment of the picture in
# decoding order. first_slice_segment_in_pic_flag equal to 0 specifies that the slice segment is not the first slice segment
# of the picture in decoding order.
is_first_slice = dat[rbsp_start] >> 7 & 1 == 1
if not is_first_slice:
# TODO: parse dependent_slice_segment_flag and slice_segment_address and get real slice_type
# for now since we don't use it return -1 for slice_type
return (-1, is_first_slice)
skip_bits += 1 # skip past first_slice_segment_in_pic_flag
if nal_unit_type >= HevcNalUnitType.BLA_W_LP and nal_unit_type <= HevcNalUnitType.RSV_IRAP_VCL23:
# 7.4.7.1 General slice segment header semantics
# no_output_of_prior_pics_flag affects the output of previously-decoded pictures in the decoded picture buffer after the
# decoding of an IDR or a BLA picture that is not the first picture in the bitstream as specified in Annex C.
skip_bits += 1 # skip past no_output_of_prior_pics_flag
# 7.4.7.1 General slice segment header semantics
# slice_pic_parameter_set_id specifies the value of pps_pic_parameter_set_id for the PPS in use.
# The value of slice_pic_parameter_set_id shall be in the range of 0 to 63, inclusive.
_, size = get_ue(dat, rbsp_start, skip_bits)
skip_bits += size # skip past slice_pic_parameter_set_id
  # 7.4.3.3.1 General picture parameter set RBSP semantics
# num_extra_slice_header_bits specifies the number of extra slice header bits that are present in the slice header RBSP
# for coded pictures referring to the PPS. The value of num_extra_slice_header_bits shall be in the range of 0 to 2, inclusive,
# in bitstreams conforming to this version of this Specification. Other values for num_extra_slice_header_bits are reserved
# for future use by ITU-T | ISO/IEC. However, decoders shall allow num_extra_slice_header_bits to have any value.
# TODO: get from PPS_NUT pic_parameter_set_rbsp( ) for corresponding slice_pic_parameter_set_id
num_extra_slice_header_bits = 0
skip_bits += num_extra_slice_header_bits
# 7.4.7.1 General slice segment header semantics
# slice_type specifies the coding type of the slice according to Table 7-7.
# Table 7-7 - Name association to slice_type
# slice_type | Name of slice_type
# 0 | B (B slice)
# 1 | P (P slice)
# 2 | I (I slice)
# unsigned integer 0-th order Exp-Golomb-coded syntax element with the left bit first
slice_type, _ = get_ue(dat, rbsp_start, skip_bits)
if DEBUG:
print(" slice_type:", slice_type, f"(first slice: {is_first_slice})")
if slice_type > 2:
raise VideoFileInvalid("slice_type must be 0, 1, or 2")
return slice_type, is_first_slice
def hevc_index(hevc_file_name: str, allow_corrupt: bool=False) -> tuple[list, int, bytes]:
with FileReader(hevc_file_name) as f:
dat = f.read()
if len(dat) < NAL_UNIT_START_CODE_SIZE + 1:
raise VideoFileInvalid("data is too short")
if dat[0] != 0x00:
raise VideoFileInvalid("first byte must be 0x00")
prefix_dat = b""
frame_types = list()
i = 1 # skip past first byte 0x00
try:
while i < len(dat):
require_nal_unit_start(dat, i)
nal_unit_len = get_hevc_nal_unit_length(dat, i)
nal_unit_type = get_hevc_nal_unit_type(dat, i)
if nal_unit_type in HEVC_PARAMETER_SET_NAL_UNITS:
prefix_dat += dat[i:i+nal_unit_len]
elif nal_unit_type in HEVC_CODED_SLICE_SEGMENT_NAL_UNITS:
slice_type, is_first_slice = get_hevc_slice_type(dat, i, nal_unit_type)
if is_first_slice:
frame_types.append((slice_type, i))
i += nal_unit_len
except Exception as e:
if not allow_corrupt:
raise
print(f"ERROR: NAL unit skipped @ {i}\n", str(e))
return frame_types, len(dat), prefix_dat
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("input_file", type=str)
parser.add_argument("output_prefix_file", type=str)
parser.add_argument("output_index_file", type=str)
args = parser.parse_args()
frame_types, dat_len, prefix_dat = hevc_index(args.input_file)
with open(args.output_prefix_file, "wb") as f:
f.write(prefix_dat)
with open(args.output_index_file, "wb") as f:
for ft, fp in frame_types:
f.write(struct.pack("<II", ft, fp))
f.write(struct.pack("<II", 0xFFFFFFFF, dat_len))
if __name__ == "__main__":
main()
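The index file written above is a flat sequence of little-endian (frame_type, file_position) pairs terminated by a 0xFFFFFFFF sentinel, so it can be read back with a sketch like this (file name is a placeholder):

```python
import struct

with open("fcamera.hevc.index", "rb") as f:
  dat = f.read()

for i in range(0, len(dat), 8):
  frame_type, position = struct.unpack_from("<II", dat, i)
  if frame_type == 0xFFFFFFFF:
    print("end of stream at byte", position)
  else:
    print("slice_type", frame_type, "at byte", position)
```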