Release 260111

Comma Device
2026-01-11 18:23:29 +08:00
commit 3721ecbf8a
2601 changed files with 855070 additions and 0 deletions

341
tinygrad/engine/jit.py Normal file

@@ -0,0 +1,341 @@
from typing import TypeVar, Generic, Callable, cast, Any
import functools, collections
from tinygrad.tensor import Tensor
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, JIT_BATCH_SIZE, dedup, partition, unwrap
from tinygrad.device import Buffer, Compiled, Device, MultiBuffer
from tinygrad.dtype import DType
from tinygrad.uop.ops import UOp, Variable, sym_infer, Ops
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner, Estimates
from tinygrad.engine.memory import _internal_memory_planner
from tinygrad.nn.state import get_parameters
from dataclasses import dataclass
from weakref import WeakKeyDictionary
class GraphException(Exception): pass
def graph_class(dev): return dev.graph.func if isinstance(dev.graph, functools.partial) else dev.graph
def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int], max_batch_size=0) -> list[ExecItem]:
  # Split JIT cache into batches for faster graph execution.
  # This allows the accelerator to run some batches while subsequent graphs are still being updated.
  graphed_jit_cache: list[ExecItem] = []
  current_batch: list[ExecItem] = []
  current_batch_devs: list[Compiled] = []

  def flush_batch():
    nonlocal current_batch, current_batch_devs, max_batch_size
    try:
      if len(current_batch_devs) == 0: raise GraphException("no device for graph")
      if len(current_batch) <= 1 and not getenv("GRAPH_ONE_KERNEL"): raise GraphException("only one kernel doesn't graph")
      graph_runner = current_batch_devs[0].graph(current_batch, input_rawbuffers, var_vals)
      # clear jit inputs to allow their memory to be freed/reused
      for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None
      graphed_jit_cache.append(ExecItem(graph_runner, cast(list[Buffer|None], input_rawbuffers)))
      max_batch_size *= 2
      if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels on device {current_batch_devs[0]}")
    except GraphException as e:
      graphed_jit_cache.extend(current_batch)
      if DEBUG >= 2: print(f"JIT GRAPHing failed batch with {len(current_batch)} kernels on device {current_batch_devs[0]}: {e}")
    current_batch = []
    current_batch_devs = []

  for ji in jit_cache:
    match ji.prg:
      case CompiledRunner(): ji_graph_dev = ji.prg.dev
      case BufferXfer(): ji_graph_dev = Device[unwrap(ji.bufs[0]).device]
      case BufferCopy(): ji_graph_dev = next((Device[unwrap(b).device] for b in ji.bufs if unwrap(b).device != "CPU"), None)
      case ViewOp(): continue  # ViewOps are just ignored
      case _: ji_graph_dev = None  # everything else is not graphed and flushes the existing graph if one is being built

    # Check if this jit item can be graphed at all, i.e. whether a fresh graph on its device would support it.
    can_be_graphed = ji_graph_dev is not None and ji_graph_dev.graph is not None and graph_class(ji_graph_dev).supports_exec_item([ji_graph_dev], ji)
    # Check if the current batch can be extended with this item.
    can_share_graph = can_be_graphed and len(current_batch_devs) > 0 and \
      graph_class(current_batch_devs[0]).supports_exec_item(dedup(current_batch_devs + [ji_graph_dev]), ji)
    can_extend_graph_batch = can_share_graph and (max_batch_size == 0 or len(current_batch) < max_batch_size)
    # Flush the current batch if any, since it can't be extended or is full.
    if not can_extend_graph_batch and len(current_batch) > 0: flush_batch()

    (current_batch if can_be_graphed else graphed_jit_cache).append(ji)
    current_batch_devs = dedup(current_batch_devs + [ji_graph_dev]) if can_be_graphed else []
  if len(current_batch) > 0: flush_batch()
  return graphed_jit_cache
def get_input_replace(jit_cache: list[ExecItem], input_rawbuffers:list[Buffer]) -> dict[tuple[int, int], int]:
  input_replace: dict[tuple[int, int], int] = {}
  for j,ji in enumerate(jit_cache):
    for i,a in enumerate(ji.bufs):
      if a in input_rawbuffers:
        input_replace[(j,i)] = input_rawbuffers.index(a)
  return input_replace
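# e.g. with input_rawbuffers=[a, b], if jit_cache[3].bufs[1] is b then input_replace[(3, 1)] == 1;
# on replay, CapturedJit patches each captured buffer slot (j, i) with the freshly passed-in input.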
class GraphRunner(Runner):
  def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
    self.jit_cache = jit_cache  # NOTE: this is not used, but you have to keep these objects alive for the Graph
    self.input_replace:dict[tuple[int, int], int] = get_input_replace(jit_cache, input_rawbuffers)
    self.var_vals_replace:dict[int, list[tuple[int, int]]] = {}
    self.launch_dims_replace:dict[int, tuple[int|None, int|None]] = {}
    self.launch_dims_base:dict[int, tuple[tuple[int, ...], tuple[int, ...]]] = {}

    def is_sym_dim(dim) -> bool: return not all(isinstance(d, (int, float)) for d in dim)

    self.vars = sorted(var_vals.keys())
    self.symbolic_dims = dedup([tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.local_size) and is_sym_dim(d)] +
                               [tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.global_size) and is_sym_dim(d)])
    def find_symbolic_dim(dim): return self.symbolic_dims.index(tuple(dim)) if dim is not None and tuple(dim) in self.symbolic_dims else None

    estimates = Estimates()
    for j,ji in enumerate(jit_cache):
      estimates += ji.prg.estimates
      if isinstance(ji.prg, CompiledRunner):
        if ji.prg.p.vars: self.var_vals_replace[j] = [(i, self.vars.index(v.expr)) for i, v in enumerate(ji.prg.p.vars) if v.expr not in ji.fixedvars]

        global_dim_idx, local_dim_idx = find_symbolic_dim(ji.prg.p.global_size), find_symbolic_dim(ji.prg.p.local_size)
        if global_dim_idx is not None or local_dim_idx is not None:
          self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
          assert ji.prg.p.global_size is not None and ji.prg.p.local_size is not None
          self.launch_dims_base[j] = (tuple(ji.prg.p.global_size), tuple(ji.prg.p.local_size))

    # used in MultiGraphRunner. the ints are id() of _bufs
    self.w_dependency_map: dict[int, Any] = {}
    self.r_dependency_map: dict[int, list[Any]] = collections.defaultdict(list)

    super().__init__(colored(f"<batched {len(jit_cache)}>", "cyan"), jit_cache[0].prg.device.split(":")[0], estimates.simplify())

  def updated_vars(self, var_vals: dict[str, int]):
    vals = [var_vals[v] for v in self.vars]
    for j, vidxs in self.var_vals_replace.items():
      for i, v in vidxs: yield j, i, vals[v]

  def updated_launch_dims(self, var_vals: dict[str, int]):
    dims = [tuple(sym_infer(s, var_vals) for s in dim) for dim in self.symbolic_dims]
    for j, (gl, lc) in self.launch_dims_replace.items():
      yield j, (dims[gl] if gl is not None else self.launch_dims_base[j][0]), (dims[lc] if lc is not None else self.launch_dims_base[j][1])
  def _access_resources(self, rawbufs:list[Buffer], write:list[int], new_dependency:Any):
    # To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource,
    # whether for write or read operations. A resource can be accessed by either a single writer or multiple readers.
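    # e.g. for two back-to-back kernels writing the same buffer, the second waits on the first via w_dependency_map;
    # concurrent readers stack up in r_dependency_map and are all waited on (and popped) once a writer arrives.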
    wait_nodes = []
    for i,rawbuf in enumerate(rawbufs):
      if id(rawbuf.base._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(rawbuf.base._buf)])
      if i in write:
        if id(rawbuf.base._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(rawbuf.base._buf)))

    for i,rawbuf in enumerate(rawbufs):
      if i in write: self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency
      else: self.r_dependency_map[id(rawbuf.base._buf)].append(new_dependency)
    return list({id(x):x for x in wait_nodes}.values())
  @staticmethod
  def supports_exec_item(devs:list[Compiled], ei:ExecItem) -> bool: return isinstance(ei.prg, CompiledRunner) and len(dedup(devs)) == 1
# a marker for your graph supporting multiple devices of the same type
class MultiGraphRunner(GraphRunner):
  @staticmethod
  def supports_exec_item(devs:list[Compiled], ei:ExecItem) -> bool:
    # devices must be the same type
    return isinstance(ei.prg, (CompiledRunner, BufferXfer)) and len(dedup([type(Device[b.device]) for b in ei.bufs if b]+[type(d) for d in devs]))==1
def get_out_buffers_for_ei(ei:ExecItem) -> list[Buffer]:
  if isinstance(ei.prg, CompiledRunner): return [cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs if out not in ei.prg.p.ins]
  if isinstance(ei.prg, (BufferCopy, BufferXfer)): return [cast(Buffer, ei.bufs[0])]
  return []
def update_depends(depends:set[Buffer|None], jit_cache:list[ExecItem]):
  for ei in jit_cache:
    if any(b in depends for b in ei.bufs): depends.update(get_out_buffers_for_ei(ei))
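# NOTE: a single forward sweep suffices here because jit_cache is in execution order, so a consumer always appears after its producer.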
ReturnType = TypeVar('ReturnType')
@dataclass
class CapturedJit(Generic[ReturnType]):
  ret: Any  # includes the Tensors or any other returned object
  jit_cache: list[ExecItem]
  input_replace: dict[tuple[int, int], int]
  extra_view_inputs: list[tuple[int, int, str, int, DType]]
  expected_names: list[int|str]
  expected_st_vars_dtype_device: list[tuple[ShapeTracker, tuple[Variable, ...], DType, str]]

  def __reduce__(self):
    # TODO: free_intermediates here? replan_buffers_memory_layout here?
    return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs,
                            self.expected_names, self.expected_st_vars_dtype_device)

  def __post_init__(self):
    self._jit_cache: list[ExecItem] = self.jit_cache
    self._input_replace: dict[tuple[int, int], int] = self.input_replace
    self._first_run = True
    self._clear_inputs()

  def _clear_inputs(self):
    for (j,i) in self._input_replace.keys(): self._jit_cache[j].bufs[i] = None

  def free_intermediates(self):
    depends: set[Buffer|None] = set([None])
    update_depends(depends, self.jit_cache)
    for b in depends:
      if b is not None:
        if b.is_allocated(): b.deallocate()
        if b._base is not None and b._base.allocated_views == 0 and b._base.is_allocated(): b._base.deallocate()
    self.__post_init__()  # reset the graph state

  def replan_buffers_memory_layout(self):
    blacklist = [t.uop.buffer for t in get_parameters(self.ret)]
    asgn = _internal_memory_planner([[b for item in self.jit_cache for b in item.bufs if b is not None and b not in blacklist]], ignore_checks=True)
    self.jit_cache = [ExecItem(item.prg, [asgn.get(b,b) if b is not None else None for b in item.bufs]) for item in self.jit_cache]
    for old, new in asgn.items():
      if old.is_allocated(): new.ensure_allocated().copyin(old.as_buffer())
    self.__post_init__()

  # jit exec
  def __call__(self, input_buffers:list[Buffer], var_vals:dict[str, int]) -> ReturnType:
    # assign inputs
    for idx, offset, device, size, dtype in self.extra_view_inputs:
      input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
    for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx]

    # Condense the items into a graph executor.
    if self._first_run:
      # allocate intermediates if freed
      for ji in self.jit_cache:
        for b in ji.bufs:
          if b is not None: b.ensure_allocated()
      # create graph if needed
      if JIT < 2:
        self._jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals, max_batch_size=JIT_BATCH_SIZE.value)
        self._input_replace = get_input_replace(self._jit_cache, input_buffers)
      self._first_run = False

    if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels")
    for ei in self._jit_cache: ei.run(var_vals, jit=True)
    self._clear_inputs()
    return self.ret
def _prepare_jit_inputs(args, kwargs):
  input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
  names, tensors = [name for name,_ in input_tensors], [t for _,t in input_tensors]
  if len(unrealized_tensors := [x for x in tensors if not x.uop.is_realized]): Tensor.realize(*unrealized_tensors)
  # TODO: this multi unpack stuff is not well tested.
  lbs: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
  input_buffers: list[Buffer] = flatten([rb.bufs if isinstance(rb:=lb.base.realized, MultiBuffer) else [rb]
                                         for lb in lbs if lb.base.realized is not None])
  assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
  st_varval_dtype_device = [(*unwrap(lb.st).unbind(), lb.dtype, lb.device) for lb in lbs]
  _var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
  var_vals = {k.expr:v for k,v in _var_vals.items()}
  st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varval_dtype_device]
  return input_buffers, var_vals, names, st_vars_dtype_device
class TinyJit(Generic[ReturnType]):
  def __init__(self, fxn:Callable[..., ReturnType]|None, captured:CapturedJit|None=None, prune=False, optimize=False):
    assert fxn or captured, "need either a function or a CapturedJit"
    self.fxn = fxn
    self.captured: CapturedJit|None = captured
    self.cnt: int = 2 if self.fxn is None else 0
    self.prune = prune
    self.optimize = optimize

  def add_buffer(self, b:Buffer) -> Buffer:
    if found:=self._buffer_replace.get(b, None): return found
    if b.is_allocated() or b.uop_refcount > 0: return b
    if b._base is not None:
      self._buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, base=self.add_buffer(b._base), offset=b.offset)
    else:
      self._buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, options=b.options)
    return ret

  def add(self, ei:ExecItem):
    self._jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None], ei.metadata, ei.fixedvars))

  def reset(self):
    assert self.fxn is not None, "can't reset without function"
    self.cnt = 0
    self.captured = None

  def __reduce__(self):
    assert self.captured is not None, "can't pickle an uncaptured JIT"
    return self.__class__, (None, self.captured)

  # keep legacy code working
  @property
  def jit_cache(self) -> list[ExecItem]: return self.captured._jit_cache if self.captured is not None else []
  @property
  def input_replace(self) -> dict[tuple[int, int], int]: return self.captured._input_replace if self.captured is not None else {}

  def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)  # add support for instance methods

  def __call__(self, *args, **kwargs) -> ReturnType:
    input_buffers, var_vals, names, st_vars_dtype_device = _prepare_jit_inputs(args, kwargs)
    if not JIT or self.cnt == 0:
      # jit ignore
      assert self.fxn is not None
      with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value):
        ret = self.fxn(*args, **kwargs)
        if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:])
    elif self.cnt == 1:
      # jit capture
      assert self.fxn is not None
      if capturing: raise RuntimeError(f"having TinyJit inside another TinyJit is not supported {len(capturing)=} {capturing=}")
      self._jit_cache: list[ExecItem] = []
      self._buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary()
      # TODO: should we always disable the memory planner here? it must be off for prune
      with Context(BEAM=getenv("JITBEAM", BEAM.value), NO_MEMORY_PLANNER=int(self.prune)):
        capturing.append(self)
        try:
          ret = self.fxn(*args, **kwargs)
          if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:])
        except Exception as e: raise e
        finally: capturing.clear()
      jit_cache = self._jit_cache
      del self._buffer_replace, self._jit_cache
      assert len(jit_cache), "didn't JIT anything!"
      if DEBUG >= 1: print(f"JIT captured {len(jit_cache)} kernels with {len(input_buffers)} inputs")

      # track inputs that are views of buffers
      # TODO: eventually expected_buffers should live in ExecItem
      extra_view_inputs: list[tuple[int, int, str, int, DType]] = []
      for item in jit_cache:
        for b in item.bufs:
          if b is not None and b._base is not None and b._base in input_buffers:
            input_buffers.append(b)
            extra_view_inputs.append((input_buffers.index(b.base), b.offset, b.device, b.size, b.dtype))

      # prune independent kernels (optional)
      if self.prune:
        depends = set(input_buffers)
        update_depends(depends, jit_cache)
        pruned, onetime = partition(jit_cache, lambda ei: any(b in depends for b in get_out_buffers_for_ei(ei)))
        if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")
        # run the onetime kernels here
        for ei in onetime:
          for b in ei.bufs: cast(Buffer, b).ensure_allocated()
          ei.run(var_vals, jit=True)
        jit_cache = pruned

      # memory planning (optional)
      # Exclude buffers involved in transfer ops to preserve parallelism.
      noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, (BufferXfer, BufferCopy)) for b in ji.bufs}
      assigned = _internal_memory_planner([cast(list[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ")
      jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None],
                            item.metadata, item.fixedvars) for item in jit_cache]

      input_replace = get_input_replace(jit_cache, input_buffers)
      if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")

      # set this for next run
      self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, st_vars_dtype_device)
      if self.optimize: self.captured.replan_buffers_memory_layout()
    elif self.cnt >= 2:
      # jit exec
      assert self.captured is not None
      assert self.captured.expected_names == names, f"args mismatch in JIT: {self.captured.expected_names=} != {names}"
      assert self.captured.expected_st_vars_dtype_device == st_vars_dtype_device, \
        f"args mismatch in JIT: {self.captured.expected_st_vars_dtype_device=} != {st_vars_dtype_device=}"
      ret = self.captured(input_buffers, var_vals)

    self.cnt += 1
    return ret
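
A minimal usage sketch of the capture/replay lifecycle above (hedged: assumes a standard tinygrad install; the import paths are the ones used in this file):

from tinygrad.tensor import Tensor
from tinygrad.engine.jit import TinyJit

@TinyJit
def step(x: Tensor) -> Tensor:
  return (x * 2).sum().realize()

for _ in range(4):
  out = step(Tensor.rand(16, 16))  # call 1 runs eagerly (cnt==0), call 2 captures kernels (cnt==1), later calls replay them (cnt>=2)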

70
tinygrad/engine/memory.py Normal file

@@ -0,0 +1,70 @@
from typing import cast
from collections import defaultdict
from tinygrad.engine.schedule import ScheduleItem
from tinygrad.device import Device, Buffer
from tinygrad.helpers import NO_MEMORY_PLANNER, dedup, DEBUG, round_up
from tinygrad.uop.ops import Ops
from tinygrad.dtype import dtypes, ImageDType
from tinygrad.runtime.support.memory import TLSFAllocator
# **************** memory planning ****************
def _internal_memory_planner(buffers:list[list[Buffer]], noopt_buffers=None, ignore_checks=False, debug_prefix="") -> dict[Buffer, Buffer]:
  if NO_MEMORY_PLANNER: return {}
  first_appearance, last_appearance, buf_to_opt = {}, {}, set()
  for i,u in enumerate(buffers):
    for buf in u:
      should_skip = buf.is_allocated() or buf.base.is_allocated() or buf.uop_refcount > 0 or (noopt_buffers is not None and buf.base in noopt_buffers)
      if not ignore_checks and should_skip: continue
      if buf.base not in first_appearance: first_appearance[buf.base] = i
      last_appearance[buf.base] = i
      buf_to_opt.add(buf)

  # Sort buffer operations in timeline order. Two events: buffer is allocated or buffer is freed.
  buffer_requests = sorted([((first_appearance[buf], True), buf) for buf in first_appearance.keys()] + \
                           [((last_appearance[buf] + 1, False), buf) for buf in first_appearance.keys()], key=lambda x: x[0])
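  # NOTE: at the same timestep a free event (False) sorts before an alloc event (True), so a buffer that just died can donate its slot immediately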
  total_memory = sum(round_up(buf.nbytes, min_block_size:=0x1000) for buf in first_appearance.keys()) * 2  # *2 for fragmentation (which is about 15%)

  # Try to suballocate from a shared buffer managed by global_planner using TLSFAllocator.
  # Also track buffer replacements for buffers that do not support suballocation.
  buffer_replace:dict[Buffer, tuple[Buffer|None, int|None]] = {}
  reuse_buffers:dict[tuple, list[Buffer]] = defaultdict(list)
  global_planner:dict[str, tuple[int, TLSFAllocator]] = defaultdict(lambda: (0, TLSFAllocator(total_memory, block_size=min_block_size, lv2_cnt=32)))
  for (_, is_open_ev), buf in buffer_requests:
    # Check if suballocation is possible for the given buffer and device.
    if hasattr(Device[buf.device].allocator, "_offset") and not isinstance(buf.dtype, ImageDType):
      if is_open_ev: buffer_replace[buf] = (None, global_planner[buf.device][1].alloc(round_up(buf.nbytes, 0x1000)))
      else: global_planner[buf.device][1].free(cast(int, buffer_replace[buf][1]))
      global_planner[buf.device] = (max(global_planner[buf.device][0], buffer_replace[buf][1] + buf.nbytes), global_planner[buf.device][1])
    else:
      key = (buf.device, buf.dtype, buf.options, buf.nbytes)
      if is_open_ev: buffer_replace[buf] = (reuse_buffers[key].pop(), None) if key in reuse_buffers and len(reuse_buffers[key]) > 0 else (buf, None)
      else: reuse_buffers[key].append(cast(Buffer, buffer_replace[buf][0]))

  # Allocate global buffers based on the memory planner.
  global_buffers = {dev: Buffer(dev, round_up(sz, 0x1000), dtypes.int8) for dev, (sz, _) in global_planner.items()}
  buffer_resolve:dict[Buffer, tuple[Buffer, int|None]] = {buf: (base or global_buffers[buf.device], off) for buf,(base,off) in buffer_replace.items()}

  # Assign buffers. First, assign full buffers (not sub-buffers).
  assigned:dict[Buffer, Buffer] = {}
  for buf, (base, off) in buffer_resolve.items():
    if buf != base:
      assigned[buf] = base if off is None else Buffer(buf.device, buf.size, buf.dtype, base=base, offset=off)

  # Now assign sub-buffers.
  for buf in buf_to_opt:
    if buf._base is not None:
      assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=(pbuf:=assigned.get(buf.base, buf.base)).base, offset=pbuf.offset+buf.offset)

  if DEBUG >= 1:
    ak, av = dedup(x for x in assigned.keys() if x._base is None), dedup(x for x in assigned.values() if x._base is None)+list(global_buffers.values())
    omem, nmem = sum([x.nbytes for x in ak])/1e6, sum([x.nbytes for x in av])/1e6
    if omem != nmem: print(f"{debug_prefix}memory reduced from {omem:.2f} MB -> {nmem:.2f} MB,", f"{len(ak)} -> {len(av)} bufs")
  return assigned
def memory_planner(schedule:list[ScheduleItem]) -> list[ScheduleItem]:
  # Exclude buffers involved in load ops (e.g. transfers) to preserve parallelism in graphs.
  assigned = _internal_memory_planner([list(si.bufs) for si in schedule],
                                      noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs})
  return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata, si.fixedvars) for si in schedule]
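
A hedged usage sketch: memory_planner remaps intermediate buffers so non-overlapping lifetimes share storage, and runs on a schedule before execution (Tensor.schedule and run_schedule as in mainline tinygrad):

from tinygrad.tensor import Tensor
from tinygrad.engine.memory import memory_planner
from tinygrad.engine.realize import run_schedule

sched = (Tensor.rand(256) + 1).sum().schedule()  # list[ScheduleItem]
run_schedule(memory_planner(sched))              # intermediates with disjoint lifetimes may now alias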

233
tinygrad/engine/realize.py Normal file

@@ -0,0 +1,233 @@
from typing import cast, Generator, Callable
import time, pprint, random, itertools, math
from dataclasses import dataclass, replace, field
from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, getenv, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, Context
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, print_uops, track_rewrites, KernelInfo, pyrender
from tinygrad.device import Device, Buffer
from tinygrad.renderer import Renderer, ProgramSpec, Estimates
from tinygrad.engine.schedule import ScheduleItem
from tinygrad.codegen import full_rewrite
from tinygrad.codegen.opt import Opt
# **************** Program Creation ****************
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret), replay=True)
def get_program(ast:UOp, renderer:Renderer|None=None, opts:list[Opt]|None=None) -> ProgramSpec:
"""
Transform an AST into a ProgramSpec. May trigger BEAM search.
Args:
ast: The Ops.SINK rooted AST
renderer: The renderer used to generate the code
Returns:
The ProgramSpec of the program.
"""
if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
if DEBUG >= 5: print('\n'.join(pyrender(ast)))
# linearize
if renderer is None: renderer = Device.default.renderer
if opts is not None:
assert ast.arg is None, "can't apply opts if sink has an arg"
ast = ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts)))
try:
uops = full_rewrite(ast, renderer)
except RuntimeError as e:
print("***** LINEARIZE FAILURE *****")
print(e)
print('\n'.join(pyrender(ast)))
raise
assert uops[-1].op is Ops.SINK, "last uop must be sink"
# print and render
if DEBUG >= 6: print_uops(uops)
src = renderer.render(uops)
return ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test", src, renderer.device, ast, uops,
global_size=[1,1,1] if renderer.has_local or renderer.has_threads else None,
local_size=[1,1,1] if renderer.has_local else None)
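
A hedged sketch of calling get_program by hand on a kernel AST pulled from a schedule (assumes Tensor.schedule as in mainline tinygrad, and that the sampled ScheduleItem is a SINK kernel on the default device):

from tinygrad.tensor import Tensor
from tinygrad.engine.realize import get_program

si = (Tensor.ones(16) + 1).contiguous().schedule()[-1]
prg = get_program(si.ast)  # renderer defaults to Device.default.renderer
print(prg.name)
print(prg.src)             # the rendered kernel source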
# **************** Runners ****************
class Runner:
  def __init__(self, display_name:str, device:str, estimates=Estimates()):
    self.first_run, self.display_name, self.device, self.estimates = True, display_name, device, estimates
  @property
  def dev(self): return Device[self.device]
  def exec(self, rawbufs:list[Buffer], var_vals:dict[str, int]|None=None) -> float|None:
    return self(rawbufs, {} if var_vals is None else var_vals)
  def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False) -> float|None:
    raise NotImplementedError("override this")
def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffer]) -> list[int]:
  test_rawbuffers = [Buffer(rawbufs[0].device, rawbufs[0].size, rawbufs[0].dtype).allocate(), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
  MAX_WORKGROUP = 1024
  local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
  local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2  # try each valid size twice
  def try_exec(local_size):
    try:
      return _prg(*[x._buf for x in test_rawbuffers], global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)],
                  local_size=local_size, wait=True)
    except Exception: return float('inf')
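  # candidates are timed in random order and the min is kept; two timings per size smooth out scheduler noise,
  # and a size that fails to launch scores float('inf') so it can never win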
  ret = min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])
  assert not math.isinf(ret[0]), "all optimize_local_size exec failed"
  return ret[1]
class CompiledRunner(Runner):
  def __init__(self, p:ProgramSpec, precompiled:bytes|None=None, prg=None):
    if DEBUG >= 4: print(p.src)
    self.p:ProgramSpec = p
    if precompiled is not None: self.lib = precompiled
    else:
      with cpu_profile(TracingKey(f"compile {p.name}", (p.function_name,)), "TINY"):
        self.lib = Device[p.device].compiler.compile_cached(p.src)
    if DEBUG >= 7: Device[p.device].compiler.disassemble(self.lib)
    self._prg = Device[p.device].runtime(p.function_name, self.lib) if prg is None else prg
    super().__init__(p.name, p.device, p.estimates)

  def __reduce__(self): return self.__class__, (self.p, self.lib)

  def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False) -> float|None:
    has_local = Device[self.p.device].renderer.has_local
    global_size, local_size = self.p.launch_dims(var_vals)
    if has_local and global_size is not None and local_size is None and all_int(self.p.global_size):  # type: ignore[arg-type]
      local_size = optimize_local_size(self._prg, global_size, rawbufs)
      global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
      self.p = replace(self.p, global_size=global_size, local_size=local_size)
    lra = {}
    if global_size:
      lra['global_size'] = tuple(global_size)
      assert len(global_size) == 3, "global size must have len 3"
    if local_size:
      lra['local_size'] = tuple(local_size)
      assert len(local_size) == 3, "local size must have len 3"
    return self._prg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k.expr] for k in self.p.vars), wait=wait)
class ViewOp(Runner):
  def __init__(self, buf:Buffer): super().__init__(colored(f"view {buf.nbytes:8d} @ {buf.offset:<10d}", "yellow"), buf.device)
  def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False):
    assert rawbufs[0]._base is not None and rawbufs[0]._base == rawbufs[1].base, f"must be base {rawbufs}"
class BufferCopy(Runner):
  def __init__(self, total_sz, dest_device, src_device):
    if total_sz >= 1e6: name = f"{type(self).__name__[6:].lower()} {total_sz/1e6:7.2f}M, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
    else: name = f"{type(self).__name__[6:].lower()} {total_sz:8d}, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
    super().__init__(colored(name, "yellow"), dest_device, Estimates(lds=total_sz, mem=total_sz))
  def copy(self, dest, src):
    disk_supports_fast_copyout = src.device.startswith("DISK") and hasattr(src.allocator.dev, 'io_uring') and \
      getattr(src.allocator.dev, 'fd', None) is not None and dest.allocator.supports_copy_from_disk
    if src.device.startswith("DISK") and hasattr(dest.allocator, 'copy_from_disk') and disk_supports_fast_copyout and src.nbytes >= 4096:
      dest.allocator.copy_from_disk(dest._buf, src._buf, src.nbytes)
    elif src.device.startswith("DISK") and hasattr(dest.allocator, '_as_buffer'):
      # fast(ish) path, uses readinto in diskbuffers
      src.allocator._copyout(dest.allocator._as_buffer(dest._buf), src._buf)
    else:
      dest.copyin(src.as_buffer(allow_zero_copy=True))  # may allocate a CPU buffer depending on allow_zero_copy
  def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False):
    dest, src = rawbufs[0:2]
    assert dest.size == src.size and dest.dtype == src.dtype, f"buffer copy mismatch, {dest.size} != {src.size}, {dest.dtype} != {src.dtype}"
    st = time.perf_counter()
    self.copy(dest, src)
    if wait:
      Device[dest.device].synchronize()
      return time.perf_counter() - st

class BufferXfer(BufferCopy):
  def copy(self, dest, src): dest.allocator._transfer(dest._buf, src._buf, dest.nbytes, src_dev=src.allocator.dev, dest_dev=dest.allocator.dev)
# **************** method cache ****************
method_cache: dict[tuple[str, type, bytes, tuple[int, ...], bool], CompiledRunner] = {}
def get_runner(device:str, ast:UOp) -> CompiledRunner:
  # TODO: this should be all context relevant to rendering
  context = (BEAM.value, NOOPT.value, DEVECTORIZE.value)
  ckey = (device, type(Device[device].compiler), ast.key, context, False)
  if cret:=method_cache.get(ckey): return cret
  bkey = (device.split(":")[0], type(Device[device].compiler), ast.key, context, True)
  if bret:=method_cache.get(bkey):
    method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device), bret.lib)
  else:
    prg: ProgramSpec = get_program(ast, Device[device].renderer)
    method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, device=device))
  return ret
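# e.g. a kernel first compiled for "CUDA:0" is also stored under the base-device key ("CUDA", ...), so a later
# request for "CUDA:1" reuses the compiled lib and only the device field of the ProgramSpec is swapped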
# **************** lowering functions ****************
@dataclass(frozen=True)
class ExecItem:
  prg: Runner
  bufs: list[Buffer|None]
  metadata: tuple[Metadata, ...]|None = None
  fixedvars: dict[str, int] = field(default_factory=dict)
  def run(self, _var_vals:dict[str, int]|None=None, wait=False, jit=False, do_update_stats=True) -> float|None:
    var_vals = self.fixedvars if _var_vals is None else (_var_vals|self.fixedvars)
    bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs]
    if PROFILE: cpu_events.append(ProfilePointEvent(self.prg.device, "exec", self.prg.display_name, {"metadata":self.metadata, "var_vals":var_vals}))
    et = self.prg(bufs, var_vals, wait=wait or DEBUG >= 2)
    if do_update_stats:
      GlobalCounters.kernel_count += 1
      GlobalCounters.global_ops += (op_est:=sym_infer(self.prg.estimates.ops, var_vals))
      GlobalCounters.global_mem += (mem_est:=sym_infer(self.prg.estimates.mem, var_vals))
      if et is not None: GlobalCounters.time_sum_s += et
      if DEBUG >= 2:
        lds_est = sym_infer(self.prg.estimates.lds, var_vals)
        mem_est = min(mem_est, lds_est)  # there can't be more memory accessed than loads/stores. remove this when symbolic is fixed
        header_color = 'magenta' if jit else ('green' if self.prg.first_run else None)
        ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else ""
        flops, membw, ldsbw = op_est/(et or 1e-20), mem_est/(et or 1e-20), lds_est/(et or 1e-20)
        flops_str = f"{flops*1e-9:9.2f} GFLOPS" if flops < 1e14 else colored(f"{flops*1e-12:9.2f} TFLOPS", 'green')
        mem_str = f"{membw*1e-9:6.1f}|{ldsbw*1e-9:<7.1f} GB/s" if membw < 1e13 else colored(f"{membw*1e-12:6.1f}|{ldsbw*1e-12:<7.1f} TB/s", 'green')
        print(f"{colored(f'*** {self.prg.device[:7]:7s} {GlobalCounters.kernel_count:4d}', header_color)}"+
              f" {self.prg.display_name+' '*(44-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:5.2f} GB"+
              ("" if et is None else f" tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({flops_str} {mem_str})")+
              f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in self.metadata] if self.metadata else ''}")
      self.prg.first_run = False
    return et
# NOTE: ctx is the buffers
si_lowerer = PatternMatcher([
  (UPat(Ops.SINK, name="sink"), lambda ctx,sink: (runner:=get_runner(ctx[0].device, sink), [ctx[x] for x in runner.p.globals])),
  (UPat(Ops.BUFFER_VIEW), lambda ctx: (ViewOp(ctx[0]), list(ctx))),
  (UPat(Ops.COPY, name="copy"), lambda ctx,copy: ((BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
    if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \
    else BufferCopy(ctx[0].nbytes, ctx[0].device, ctx[1].device)), list(ctx))),
])
def lower_schedule_item(si:ScheduleItem) -> ExecItem:
  return ExecItem(*cast(tuple[Runner,list], si_lowerer.rewrite(si.ast, si.bufs)), si.metadata, si.fixedvars)
def lower_schedule(schedule:list[ScheduleItem]) -> Generator[tuple[ScheduleItem, ExecItem], None, None]:
  while len(schedule):
    si = schedule.pop(0)
    try: yield (si, lower_schedule_item(si))
    except Exception as e:
      if DEBUG >= 2:
        print(f"error lowering {si.ast.op}")
        print("tensor operations:")
        pprint.pprint(si.metadata, indent=2)
      raise e
# **************** main run function ****************
capturing: list = [] # put classes with an add method in here
def run_schedule(schedule:list[ScheduleItem], var_vals:dict[str, int]|None=None, do_update_stats=True):
  for si, ei in lower_schedule(schedule):
    if len(capturing) and CAPTURING: capturing[0].add(ei)
    if VALIDATE_WITH_CPU and si.ast.op is Ops.SINK:
      # copy in allocated buffers from the GPU
      nb: tuple[Buffer, ...] = tuple(Buffer("CPU", b.size, b.dtype) for b in si.bufs)
      for cpu_b, gpu_b in zip(nb, si.bufs):
        if gpu_b.is_allocated(): cpu_b.ensure_allocated().copyin(gpu_b.as_buffer())

      # run on GPU
      ei.run(var_vals, do_update_stats=do_update_stats)

      # validate the output buffers match (NOTE: this is assuming the output is buffer 0)
      with Context(BEAM=0): lower_schedule_item(ScheduleItem(si.ast, nb, si.metadata, si.fixedvars)).run(var_vals, do_update_stats=do_update_stats)
      import numpy as np
      np.testing.assert_allclose(si.bufs[0].numpy(), nb[0].numpy(), rtol=1e-3, atol=1e-3)
    else:
      ei.run(var_vals, do_update_stats=do_update_stats)
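
A hedged end-to-end sketch of the lowering path above (Tensor.schedule as in mainline tinygrad):

from tinygrad.tensor import Tensor
from tinygrad.engine.realize import lower_schedule

for si, ei in lower_schedule((Tensor.ones(4) * 3).contiguous().schedule()):
  ei.run()  # each ScheduleItem is lowered to an ExecItem (kernel, copy, or view) and executed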

83
tinygrad/engine/schedule.py Normal file

@@ -0,0 +1,83 @@
from typing import cast
from dataclasses import dataclass, field
from collections import deque, defaultdict
from tinygrad.uop.ops import UOp, Ops, buffers
from tinygrad.device import Device, Buffer, MultiBuffer
from tinygrad.helpers import Metadata, all_same
# **** ScheduleItem return type
@dataclass(frozen=True)
class ScheduleItem:
  ast: UOp
  bufs: tuple[Buffer, ...]
  metadata: tuple[Metadata, ...] = ()
  fixedvars: dict[str, int] = field(default_factory=dict)
# **** schedule linearizer
def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[str, int]]:
  # construct the KERNEL children graph based on assigns
  children: defaultdict[UOp, list[UOp]] = defaultdict(list)
  in_degree: dict[UOp, int] = {}
  var_vals: dict[str, int] = {}
  for u in sched_sink.toposort():
    if u.op is not Ops.ASSIGN: continue  # anything that's not an ASSIGN doesn't write a kernel, so we can skip it
    k = u.src[1]
    in_degree.setdefault(k, 0)
    for s in k.src:
      if s.op is Ops.ASSIGN:
        children[s.src[1]].append(k)
        in_degree[k] += 1
      elif s.op in {Ops.MSELECT, Ops.MSTACK}:
        for ss in s.src:
          if ss.op is Ops.MSELECT: ss = ss.src[0]
          if ss.op is not Ops.BUFFER:
            assert ss.op is Ops.ASSIGN, f"ss.op is not ASSIGN, it's {ss.op}"
            children[ss.src[1]].append(k)
            in_degree[k] += 1
      elif s.op is Ops.BUFFER:
        pass  # a BUFFER is already realized, nothing to do here
      elif s.op is Ops.BIND:
        var, val = s.unbind()
        assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch on {var}, {var_vals[var.expr]} != {val}"
        var_vals[var.expr] = val
      else:
        raise RuntimeError(f"input to kernel must be ASSIGN or BUFFER, not {s.op}")
  # linearize KERNEL UOps into ScheduleItems in BFS order
  def _heuristic(k: UOp):
    if k.arg.ast.op is Ops.COPY and not all_same([Device[cast(Buffer, s.buf_uop.buffer).device].group_id for s in k.src]): return 1000
    return 0
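  # the heuristic separates cross-device COPYs (cost 1000) from compute kernels (cost 0); since the loop below keeps
  # draining the queue closest to the last heuristic, items with the same cost run consecutively and batch together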
  last_heuristic: int = 0
  queues: defaultdict[int, deque[UOp]] = defaultdict(deque)
  last_queue: deque[UOp] = deque()
  for k,v in in_degree.items():
    if v == 0: queues[_heuristic(k)].append(k)
  schedule: list[ScheduleItem] = []
  while last_queue or any(queues.values()):
    if not last_queue: last_heuristic, last_queue = min((it for it in queues.items() if it[1]), key=lambda x: abs(x[0]-last_heuristic))
    k = last_queue.popleft()
    ast = k.arg.ast
    # create subbuffers if needed
    if ast.op is Ops.BUFFER_VIEW:
      base = k.src[1].buf_uop.buffer
      assert isinstance(base, Buffer), "base can't be MultiBuffer"
      buffers[k.src[0]] = base.view(k.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
    ubufs = tuple(s.buf_uop.buffer for s in k.src if s.op is not Ops.BIND)
    if any(isinstance(x, MultiBuffer) for x in ubufs):
      assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
      dnums = [x for x in ast.variables() if x.arg[0] == '_device_num']
      for i,bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
        schedule.append(ScheduleItem(ast, bufs, k.arg.metadata, {dnums[0].expr:i} if len(dnums) else {}))
    else:
      # ONE -> ONE
      schedule.append(ScheduleItem(ast, cast(tuple[Buffer, ...], ubufs), k.arg.metadata))
    for x in children[k]:
      in_degree[x] -= 1
      if in_degree[x] == 0: queues[_heuristic(x)].append(x)
  return schedule, var_vals
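
For reference, a hedged sketch of the user-facing entry point that ends up here (Tensor.schedule_with_vars in mainline tinygrad returns the linearized ScheduleItems plus the values of any bound Variables):

from tinygrad.tensor import Tensor

sched, var_vals = (Tensor.ones(8) + 2).schedule_with_vars()
print(len(sched), var_vals)  # ScheduleItems in execution order; var_vals is {} when no Variables are bound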