Release 260111
0 tinygrad/engine/__init__.py Normal file
341 tinygrad/engine/jit.py Normal file
@@ -0,0 +1,341 @@
from typing import TypeVar, Generic, Callable, cast, Any
import functools, collections
from tinygrad.tensor import Tensor
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, JIT_BATCH_SIZE, dedup, partition, unwrap
from tinygrad.device import Buffer, Compiled, Device, MultiBuffer
from tinygrad.dtype import DType
from tinygrad.uop.ops import UOp, Variable, sym_infer, Ops
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner, Estimates
from tinygrad.engine.memory import _internal_memory_planner
from tinygrad.nn.state import get_parameters
from dataclasses import dataclass
from weakref import WeakKeyDictionary

class GraphException(Exception): pass

def graph_class(dev): return dev.graph.func if isinstance(dev.graph, functools.partial) else dev.graph

def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int], max_batch_size=0) -> list[ExecItem]:
  # Split JIT cache into batches for faster graph execution.
  # This allows the accelerator to run some batches while subsequent graphs are still being updated.
  graphed_jit_cache: list[ExecItem] = []
  current_batch: list[ExecItem] = []
  current_batch_devs: list[Compiled] = []

  def flush_batch():
    nonlocal current_batch, current_batch_devs, max_batch_size
    try:
      if len(current_batch_devs) == 0: raise GraphException("no device for graph")
      if len(current_batch) <= 1 and not getenv("GRAPH_ONE_KERNEL"): raise GraphException("only one kernel doesn't graph")
      graph_runner = current_batch_devs[0].graph(current_batch, input_rawbuffers, var_vals)
      # clear jit inputs to allow their memory to be freed/reused
      for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None
      graphed_jit_cache.append(ExecItem(graph_runner, cast(list[Buffer|None], input_rawbuffers)))
      max_batch_size *= 2
      if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels on device {current_batch_devs[0]}")
    except GraphException as e:
      graphed_jit_cache.extend(current_batch)
      if DEBUG >= 2: print(f"JIT GRAPHing failed batch with {len(current_batch)} kernels on device {current_batch_devs[0]}: {e}")
    current_batch = []
    current_batch_devs = []

  for ji in jit_cache:
    match ji.prg:
      case CompiledRunner(): ji_graph_dev = ji.prg.dev
      case BufferXfer(): ji_graph_dev = Device[unwrap(ji.bufs[0]).device]
      case BufferCopy(): ji_graph_dev = next((Device[unwrap(b).device] for b in ji.bufs if unwrap(b).device != "CPU"), None)
      case ViewOp(): continue # ViewOps are just ignored
      case _: ji_graph_dev = None # Everything else is not graphed and flushes existing graph if it's being constructed

    # Check if this jit item can be graphed at all, so check if a new graph supports the current item.
    can_be_graphed = ji_graph_dev is not None and ji_graph_dev.graph is not None and graph_class(ji_graph_dev).supports_exec_item([ji_graph_dev], ji)

    # Check if the current batch can be extended with this item.
    can_share_graph = can_be_graphed and len(current_batch_devs) > 0 and \
      graph_class(current_batch_devs[0]).supports_exec_item(dedup(current_batch_devs + [ji_graph_dev]), ji)
    can_extend_graph_batch = can_share_graph and (max_batch_size == 0 or len(current_batch) < max_batch_size)

    # Flush the current batch if any, since it can't be extended or is full.
    if not can_extend_graph_batch and len(current_batch) > 0: flush_batch()
    (current_batch if can_be_graphed else graphed_jit_cache).append(ji)
    current_batch_devs = dedup(current_batch_devs + [ji_graph_dev]) if can_be_graphed else []

  if len(current_batch) > 0: flush_batch()
  return graphed_jit_cache

def get_input_replace(jit_cache: list[ExecItem], input_rawbuffers:list[Buffer]) -> dict[tuple[int, int], int]:
  input_replace: dict[tuple[int, int], int] = {}
  for j,ji in enumerate(jit_cache):
    for i,a in enumerate(ji.bufs):
      if a in input_rawbuffers:
        input_replace[(j,i)] = input_rawbuffers.index(a)
  return input_replace

class GraphRunner(Runner):
  def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
    self.jit_cache = jit_cache # NOTE: this is not used, but you have to keep these objects alive for the Graph
    self.input_replace:dict[tuple[int, int], int] = get_input_replace(jit_cache, input_rawbuffers)
    self.var_vals_replace:dict[int, list[tuple[int, int]]] = {}
    self.launch_dims_replace:dict[int, tuple[int|None, int|None]] = {}
    self.launch_dims_base:dict[int, tuple[tuple[int, ...], tuple[int, ...]]] = {}

    def is_sym_dim(dim) -> bool: return not all(isinstance(d, (int, float)) for d in dim)

    self.vars = sorted(var_vals.keys())
    self.symbolic_dims = dedup([tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.local_size) and is_sym_dim(d)] +
                               [tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.global_size) and is_sym_dim(d)])
    def find_symbolic_dim(dim): return self.symbolic_dims.index(tuple(dim)) if dim is not None and tuple(dim) in self.symbolic_dims else None

    estimates = Estimates()
    for j,ji in enumerate(jit_cache):
      estimates += ji.prg.estimates
      if isinstance(ji.prg, CompiledRunner):
        if ji.prg.p.vars: self.var_vals_replace[j] = [(i, self.vars.index(v.expr)) for i, v in enumerate(ji.prg.p.vars) if v.expr not in ji.fixedvars]

        global_dim_idx, local_dim_idx = find_symbolic_dim(ji.prg.p.global_size), find_symbolic_dim(ji.prg.p.local_size)
        if global_dim_idx is not None or local_dim_idx is not None:
          self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
          assert ji.prg.p.global_size is not None and ji.prg.p.local_size is not None
          self.launch_dims_base[j] = (tuple(ji.prg.p.global_size), tuple(ji.prg.p.local_size))

    # used in MultiGraphRunner. the ints are id() of _bufs
    self.w_dependency_map: dict[int, Any] = {}
    self.r_dependency_map: dict[int, list[Any]] = collections.defaultdict(list)

    super().__init__(colored(f"<batched {len(jit_cache)}>", "cyan"), jit_cache[0].prg.device.split(":")[0], estimates.simplify())

  def updated_vars(self, var_vals: dict[str, int]):
    vals = [var_vals[v] for v in self.vars]
    for j, vidxs in self.var_vals_replace.items():
      for i, v in vidxs: yield j, i, vals[v]

  def updated_launch_dims(self, var_vals: dict[str, int]):
    dims = [tuple(sym_infer(s, var_vals) for s in dim) for dim in self.symbolic_dims]
    for j, (gl, lc) in self.launch_dims_replace.items():
      yield j, (dims[gl] if gl is not None else self.launch_dims_base[j][0]), (dims[lc] if lc is not None else self.launch_dims_base[j][1])

  def _access_resources(self, rawbufs:list[Buffer], write:list[int], new_dependency:Any):
    # To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource,
    # whether for write or read operations. A resource can be accessed by either a single writer or multiple readers.
    wait_nodes = []

    for i,rawbuf in enumerate(rawbufs):
      if id(rawbuf.base._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(rawbuf.base._buf)])
      if i in write:
        if id(rawbuf.base._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(rawbuf.base._buf)))

    for i,rawbuf in enumerate(rawbufs):
      if i in write: self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency
      else: self.r_dependency_map[id(rawbuf.base._buf)].append(new_dependency)

    return list({id(x):x for x in wait_nodes}.values())

  @staticmethod
  def supports_exec_item(devs:list[Compiled], ei:ExecItem) -> bool: return isinstance(ei.prg, CompiledRunner) and len(dedup(devs)) == 1

# a marker for your graph supporting multiple devices of the same type
class MultiGraphRunner(GraphRunner):
  @staticmethod
  def supports_exec_item(devs:list[Compiled], ei:ExecItem) -> bool:
    # Devices must be the same type
    return isinstance(ei.prg, (CompiledRunner, BufferXfer)) and len(dedup([type(Device[b.device]) for b in ei.bufs if b]+[type(d) for d in devs]))==1

def get_out_buffers_for_ei(ei:ExecItem) -> list[Buffer]:
  if isinstance(ei.prg, CompiledRunner): return [cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs if out not in ei.prg.p.ins]
  if isinstance(ei.prg, (BufferCopy, BufferXfer)): return [cast(Buffer, ei.bufs[0])]
  return []

def update_depends(depends:set[Buffer|None], jit_cache:list[ExecItem]):
  for ei in jit_cache:
    if any(b in depends for b in ei.bufs): depends.update(get_out_buffers_for_ei(ei))

ReturnType = TypeVar('ReturnType')
@dataclass
class CapturedJit(Generic[ReturnType]):
  ret: Any  # includes the Tensors or any other returned object
  jit_cache: list[ExecItem]
  input_replace: dict[tuple[int, int], int]
  extra_view_inputs: list[tuple[int, int, str, int, DType]]
  expected_names: list[int|str]
  expected_st_vars_dtype_device: list[tuple[ShapeTracker, tuple[Variable, ...], DType, str]]

  def __reduce__(self):
    # TODO: free_intermediates here? replan_buffers_memory_layout here?
    return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs,
                            self.expected_names, self.expected_st_vars_dtype_device)

  def __post_init__(self):
    self._jit_cache: list[ExecItem] = self.jit_cache
    self._input_replace: dict[tuple[int, int], int] = self.input_replace
    self._first_run = True
    self._clear_inputs()

  def _clear_inputs(self):
    for (j,i) in self._input_replace.keys(): self._jit_cache[j].bufs[i] = None

  def free_intermediates(self):
    depends: set[Buffer|None] = set([None])
    update_depends(depends, self.jit_cache)
    for b in depends:
      if b is not None:
        if b.is_allocated(): b.deallocate()
        if b._base is not None and b._base.allocated_views == 0 and b._base.is_allocated(): b._base.deallocate()
    self.__post_init__()  # reset the graph state

  def replan_buffers_memory_layout(self):
    blacklist = [t.uop.buffer for t in get_parameters(self.ret)]
    asgn = _internal_memory_planner([[b for item in self.jit_cache for b in item.bufs if b is not None and b not in blacklist]], ignore_checks=True)
    self.jit_cache = [ExecItem(item.prg, [asgn.get(b,b) if b is not None else None for b in item.bufs]) for item in self.jit_cache]
    for old, new in asgn.items():
      if old.is_allocated(): new.ensure_allocated().copyin(old.as_buffer())
    self.__post_init__()

  # jit exec
  def __call__(self, input_buffers:list[Buffer], var_vals:dict[str, int]) -> ReturnType:
    # assign inputs
    for idx, offset, device, size, dtype in self.extra_view_inputs:
      input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
    for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx]

    # Condense the items into a graph executor.
    if self._first_run:
      # allocate intermediates if freed
      for ji in self.jit_cache:
        for b in ji.bufs:
          if b is not None: b.ensure_allocated()
      # create graph if needed
      if JIT < 2:
        self._jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals, max_batch_size=JIT_BATCH_SIZE.value)
        self._input_replace = get_input_replace(self._jit_cache, input_buffers)
      self._first_run = False

    if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels")
    for ei in self._jit_cache: ei.run(var_vals, jit=True)
    self._clear_inputs()
    return self.ret

def _prepare_jit_inputs(args, kwargs):
  input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
  names, tensors = [name for name,_ in input_tensors], [t for _,t in input_tensors]
  if len(unrealized_tensors := [x for x in tensors if not x.uop.is_realized]): Tensor.realize(*unrealized_tensors)
  # TODO: this multi unpack stuff is not well tested.
  lbs: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
  input_buffers: list[Buffer] = flatten([rb.bufs if isinstance(rb:=lb.base.realized, MultiBuffer) else [rb]
                                         for lb in lbs if lb.base.realized is not None])
  assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
  st_varval_dtype_device = [(*unwrap(lb.st).unbind(), lb.dtype, lb.device) for lb in lbs]
  _var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
  var_vals = {k.expr:v for k,v in _var_vals.items()}
  st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varval_dtype_device]
  return input_buffers, var_vals, names, st_vars_dtype_device

class TinyJit(Generic[ReturnType]):
  def __init__(self, fxn:Callable[..., ReturnType]|None, captured:CapturedJit|None=None, prune=False, optimize=False):
    assert fxn or captured, "need either a function or a CapturedJit"
    self.fxn = fxn
    self.captured: CapturedJit|None = captured
    self.cnt: int = 2 if self.fxn is None else 0
    self.prune = prune
    self.optimize = optimize

  def add_buffer(self, b:Buffer) -> Buffer:
    if found:=self._buffer_replace.get(b, None): return found
    if b.is_allocated() or b.uop_refcount > 0: return b
    if b._base is not None:
      self._buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, base=self.add_buffer(b._base), offset=b.offset)
    else:
      self._buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, options=b.options)
    return ret

  def add(self, ei:ExecItem):
    self._jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None], ei.metadata, ei.fixedvars))

  def reset(self):
    assert self.fxn is not None, "can't reset without function"
    self.cnt = 0
    self.captured = None

  def __reduce__(self):
    assert self.captured is not None, "can't pickle an uncaptured JIT"
    return self.__class__, (None, self.captured)

  # keep legacy code working
  @property
  def jit_cache(self) -> list[ExecItem]: return self.captured._jit_cache if self.captured is not None else []
  @property
  def input_replace(self) -> dict[tuple[int, int], int]: return self.captured._input_replace if self.captured is not None else {}

  def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)  # add support for instance methods

  def __call__(self, *args, **kwargs) -> ReturnType:
    input_buffers, var_vals, names, st_vars_dtype_device = _prepare_jit_inputs(args, kwargs)
    if not JIT or self.cnt == 0:
      # jit ignore
      assert self.fxn is not None
      with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value):
        ret = self.fxn(*args, **kwargs)
        if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:])
    elif self.cnt == 1:
      # jit capture
      assert self.fxn is not None
      if capturing: raise RuntimeError(f"having TinyJit inside another TinyJit is not supported {len(capturing)=} {capturing=}")
      self._jit_cache: list[ExecItem] = []
      self._buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary()
      # TODO: should we always disable the memory planner here? it must be off for prune
      with Context(BEAM=getenv("JITBEAM", BEAM.value), NO_MEMORY_PLANNER=int(self.prune)):
        capturing.append(self)
        try:
          ret = self.fxn(*args, **kwargs)
          if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:])
        except Exception as e: raise e
        finally: capturing.clear()
      jit_cache = self._jit_cache
      del self._buffer_replace, self._jit_cache
      assert len(jit_cache), "didn't JIT anything!"
      if DEBUG >= 1: print(f"JIT captured {len(jit_cache)} kernels with {len(input_buffers)} inputs")

      # track inputs that are views of buffers
      # TODO: eventually expected_buffers should live in ExecItem
      extra_view_inputs: list[tuple[int, int, str, int, DType]] = []
      for item in jit_cache:
        for b in item.bufs:
          if b is not None and b._base is not None and b._base in input_buffers:
            input_buffers.append(b)
            extra_view_inputs.append((input_buffers.index(b.base), b.offset, b.device, b.size, b.dtype))

      # prune independent kernels (optional)
      if self.prune:
        depends = set(input_buffers)
        update_depends(depends, jit_cache)
        pruned, onetime = partition(jit_cache, lambda ei: any(b in depends for b in get_out_buffers_for_ei(ei)))
        if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")
        # run the onetime kernels here
        for ei in onetime:
          for b in ei.bufs: cast(Buffer, b).ensure_allocated()
          ei.run(var_vals, jit=True)
        jit_cache = pruned

      # memory planning (optional)
      # Exclude buffers involved in transfer ops to preserve parallelism.
      noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, (BufferXfer, BufferCopy)) for b in ji.bufs}
      assigned = _internal_memory_planner([cast(list[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ")
      jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None],
                            item.metadata, item.fixedvars) for item in jit_cache]

      input_replace = get_input_replace(jit_cache, input_buffers)
      if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")

      # set this for next run
      self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, st_vars_dtype_device)
      if self.optimize: self.captured.replan_buffers_memory_layout()
    elif self.cnt >= 2:
      # jit exec
      assert self.captured is not None
      assert self.captured.expected_names == names, f"args mismatch in JIT: {self.captured.expected_names=} != {names}"
      assert self.captured.expected_st_vars_dtype_device == st_vars_dtype_device, \
        f"args mismatch in JIT: {self.captured.expected_st_vars_dtype_device=} != {st_vars_dtype_device=}"
      ret = self.captured(input_buffers, var_vals)

    self.cnt += 1
    return ret
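For orientation (not part of the diff above), a minimal sketch of how the TinyJit class is typically used, assuming the public tinygrad API (Tensor, TinyJit): the first call runs the function eagerly (cnt == 0), the second call captures the kernels (cnt == 1), and later calls replay the captured ExecItems, optionally batched into device graphs by apply_graph_to_jit.

# Illustrative sketch only, assuming the public tinygrad API; not part of this commit.
from tinygrad import Tensor, TinyJit

@TinyJit
def fused(x:Tensor, y:Tensor) -> Tensor:
  return (x @ y).relu().sum().realize()

for i in range(4):
  # call 1: eager run, call 2: capture, calls 3+: replay the capture (possibly as a graph)
  out = fused(Tensor.rand(64, 64), Tensor.rand(64, 64))
  print(i, out.item())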
70 tinygrad/engine/memory.py Normal file
@@ -0,0 +1,70 @@
from typing import cast
from collections import defaultdict
from tinygrad.engine.schedule import ScheduleItem
from tinygrad.device import Device, Buffer
from tinygrad.helpers import NO_MEMORY_PLANNER, dedup, DEBUG, round_up
from tinygrad.uop.ops import Ops
from tinygrad.dtype import dtypes, ImageDType
from tinygrad.runtime.support.memory import TLSFAllocator

# **************** memory planning ****************

def _internal_memory_planner(buffers:list[list[Buffer]], noopt_buffers=None, ignore_checks=False, debug_prefix="") -> dict[Buffer, Buffer]:
  if NO_MEMORY_PLANNER: return {}
  first_appearance, last_appearance, buf_to_opt = {}, {}, set()
  for i,u in enumerate(buffers):
    for buf in u:
      should_skip = buf.is_allocated() or buf.base.is_allocated() or buf.uop_refcount > 0 or (noopt_buffers is not None and buf.base in noopt_buffers)
      if not ignore_checks and should_skip: continue
      if buf.base not in first_appearance: first_appearance[buf.base] = i
      last_appearance[buf.base] = i
      buf_to_opt.add(buf)

  # Sort buffer operations in timeline order. Two events: buffer is allocated or buffer is freed.
  buffer_requests = sorted([((first_appearance[buf], True), buf) for buf in first_appearance.keys()] + \
                           [((last_appearance[buf] + 1, False), buf) for buf in first_appearance.keys()], key=lambda x: x[0])
  total_memory = sum(round_up(buf.nbytes, min_block_size:=0x1000) for buf in first_appearance.keys()) * 2 # *2 for fragmentation (which is about 15%)

  # Try to suballocate from a shared buffer managed by global_planner using TLSFAllocator.
  # Also track buffer replacements for buffers that do not support suballocation.
  buffer_replace:dict[Buffer, tuple[Buffer|None, int|None]] = {}
  reuse_buffers:dict[tuple, list[Buffer]] = defaultdict(list)
  global_planner:dict[str, tuple[int, TLSFAllocator]] = defaultdict(lambda: (0, TLSFAllocator(total_memory, block_size=min_block_size, lv2_cnt=32)))
  for (_, is_open_ev), buf in buffer_requests:
    # Check if suballocation is possible for the given buffer and device.
    if hasattr(Device[buf.device].allocator, "_offset") and not isinstance(buf.dtype, ImageDType):
      if is_open_ev: buffer_replace[buf] = (None, global_planner[buf.device][1].alloc(round_up(buf.nbytes, 0x1000)))
      else: global_planner[buf.device][1].free(cast(int, buffer_replace[buf][1]))
      global_planner[buf.device] = (max(global_planner[buf.device][0], buffer_replace[buf][1] + buf.nbytes), global_planner[buf.device][1])
    else:
      key = (buf.device, buf.dtype, buf.options, buf.nbytes)
      if is_open_ev: buffer_replace[buf] = (reuse_buffers[key].pop(), None) if key in reuse_buffers and len(reuse_buffers[key]) > 0 else (buf, None)
      else: reuse_buffers[key].append(cast(Buffer, buffer_replace[buf][0]))

  # Allocate global buffers based on the memory planner.
  global_buffers = {dev: Buffer(dev, round_up(sz, 0x1000), dtypes.int8) for dev, (sz, _) in global_planner.items()}
  buffer_resolve:dict[Buffer, tuple[Buffer, int|None]] = {buf: (base or global_buffers[buf.device], off) for buf,(base,off) in buffer_replace.items()}

  # Assign buffers. First, assign full buffers (not sub-buffers).
  assigned:dict[Buffer, Buffer] = {}
  for buf, (base, off) in buffer_resolve.items():
    if buf != base:
      assigned[buf] = base if off is None else Buffer(buf.device, buf.size, buf.dtype, base=base, offset=off)

  # Now assign sub-buffers.
  for buf in buf_to_opt:
    if buf._base is not None:
      assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=(pbuf:=assigned.get(buf.base, buf.base)).base, offset=pbuf.offset+buf.offset)

  if DEBUG >= 1:
    ak, av = dedup(x for x in assigned.keys() if x._base is None),dedup(x for x in assigned.values() if x._base is None)+list(global_buffers.values())
    omem, nmem = sum([x.nbytes for x in ak])/1e6, sum([x.nbytes for x in av])/1e6
    if omem != nmem: print(f"{debug_prefix}memory reduced from {omem:.2f} MB -> {nmem:.2f} MB,", f"{len(ak)} -> {len(av)} bufs")

  return assigned

def memory_planner(schedule:list[ScheduleItem]) -> list[ScheduleItem]:
  # Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
  assigned = _internal_memory_planner([list(si.bufs) for si in schedule],
                                      noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs})
  return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata, si.fixedvars) for si in schedule]
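As a usage note (not part of the commit), a hedged sketch of where memory_planner sits: it is applied to a schedule before it is run, so intermediate Buffers with non-overlapping lifetimes are either suballocated from one big per-device buffer or reused outright. Tensor.schedule() and run_schedule are assumed from the surrounding tinygrad API.

# Illustrative sketch only; assumes Tensor.schedule() from the public tinygrad API.
from tinygrad import Tensor
from tinygrad.engine.memory import memory_planner
from tinygrad.engine.realize import run_schedule

out = (Tensor.rand(256, 256) @ Tensor.rand(256, 256)).relu()
schedule = memory_planner(out.schedule())  # reassign intermediate buffers to shared/reused storage
run_schedule(schedule)                     # lower each ScheduleItem and execute it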
233 tinygrad/engine/realize.py Normal file
@@ -0,0 +1,233 @@
from typing import cast, Generator, Callable
import time, pprint, random, itertools, math
from dataclasses import dataclass, replace, field
from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, getenv, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, Context
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, print_uops, track_rewrites, KernelInfo, pyrender
from tinygrad.device import Device, Buffer
from tinygrad.renderer import Renderer, ProgramSpec, Estimates
from tinygrad.engine.schedule import ScheduleItem
from tinygrad.codegen import full_rewrite
from tinygrad.codegen.opt import Opt

# **************** Program Creation ****************

@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret), replay=True)
def get_program(ast:UOp, renderer:Renderer|None=None, opts:list[Opt]|None=None) -> ProgramSpec:
  """
  Transform an AST into a ProgramSpec. May trigger BEAM search.

  Args:
    ast: The Ops.SINK rooted AST
    renderer: The renderer used to generate the code

  Returns:
    The ProgramSpec of the program.
  """
  if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
  if DEBUG >= 5: print('\n'.join(pyrender(ast)))

  # linearize
  if renderer is None: renderer = Device.default.renderer
  if opts is not None:
    assert ast.arg is None, "can't apply opts if sink has an arg"
    ast = ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts)))
  try:
    uops = full_rewrite(ast, renderer)
  except RuntimeError as e:
    print("***** LINEARIZE FAILURE *****")
    print(e)
    print('\n'.join(pyrender(ast)))
    raise
  assert uops[-1].op is Ops.SINK, "last uop must be sink"

  # print and render
  if DEBUG >= 6: print_uops(uops)
  src = renderer.render(uops)

  return ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test", src, renderer.device, ast, uops,
                     global_size=[1,1,1] if renderer.has_local or renderer.has_threads else None,
                     local_size=[1,1,1] if renderer.has_local else None)

# **************** Runners ****************

class Runner:
  def __init__(self, display_name:str, device:str, estimates=Estimates()):
    self.first_run, self.display_name, self.device, self.estimates = True, display_name, device, estimates
  @property
  def dev(self): return Device[self.device]
  def exec(self, rawbufs:list[Buffer], var_vals:dict[str, int]|None=None) -> float|None:
    return self(rawbufs, {} if var_vals is None else var_vals)
  def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False) -> float|None:
    raise NotImplementedError("override this")

def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffer]) -> list[int]:
  test_rawbuffers = [Buffer(rawbufs[0].device, rawbufs[0].size, rawbufs[0].dtype).allocate(), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
  MAX_WORKGROUP = 1024
  local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
  local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
  def try_exec(local_size):
    try:
      return _prg(*[x._buf for x in test_rawbuffers],global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)],
                  local_size=local_size, wait=True)
    except Exception: return float('inf')
  ret = min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])
  assert not math.isinf(ret[0]), "all optimize_local_size exec failed"
  return ret[1]

class CompiledRunner(Runner):
  def __init__(self, p:ProgramSpec, precompiled:bytes|None=None, prg=None):
    if DEBUG >= 4: print(p.src)
    self.p:ProgramSpec = p
    if precompiled is not None: self.lib = precompiled
    else:
      with cpu_profile(TracingKey(f"compile {p.name}", (p.function_name,)), "TINY"):
        self.lib = Device[p.device].compiler.compile_cached(p.src)
    if DEBUG >= 7: Device[p.device].compiler.disassemble(self.lib)
    self._prg = Device[p.device].runtime(p.function_name, self.lib) if prg is None else prg
    super().__init__(p.name, p.device, p.estimates)

  def __reduce__(self): return self.__class__, (self.p, self.lib)

  def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False) -> float|None:
    has_local = Device[self.p.device].renderer.has_local
    global_size, local_size = self.p.launch_dims(var_vals)
    if has_local and global_size is not None and local_size is None and all_int(self.p.global_size): # type: ignore[arg-type]
      local_size = optimize_local_size(self._prg, global_size, rawbufs)
      global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
      self.p = replace(self.p, global_size=global_size, local_size=local_size)
    lra = {}
    if global_size:
      lra['global_size'] = tuple(global_size)
      assert len(global_size) == 3, "global size must have len 3"
    if local_size:
      lra['local_size'] = tuple(local_size)
      assert len(local_size) == 3, "local size must have len 3"
    return self._prg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k.expr] for k in self.p.vars), wait=wait)

class ViewOp(Runner):
  def __init__(self, buf:Buffer): super().__init__(colored(f"view {buf.nbytes:8d} @ {buf.offset:<10d}", "yellow"), buf.device)
  def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False):
    assert rawbufs[0]._base is not None and rawbufs[0]._base == rawbufs[1].base, f"must be base {rawbufs}"

class BufferCopy(Runner):
  def __init__(self, total_sz, dest_device, src_device):
    if total_sz >= 1e6: name = f"{type(self).__name__[6:].lower()} {total_sz/1e6:7.2f}M, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
    else: name = f"{type(self).__name__[6:].lower()} {total_sz:8d}, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
    super().__init__(colored(name, "yellow"), dest_device, Estimates(lds=total_sz, mem=total_sz))
  def copy(self, dest, src):
    disk_supports_fast_copyout = src.device.startswith("DISK") and hasattr(src.allocator.dev, 'io_uring') and \
                                 getattr(src.allocator.dev, 'fd', None) is not None and dest.allocator.supports_copy_from_disk
    if src.device.startswith("DISK") and hasattr(dest.allocator, 'copy_from_disk') and disk_supports_fast_copyout and src.nbytes >= 4096:
      dest.allocator.copy_from_disk(dest._buf, src._buf, src.nbytes)
    elif src.device.startswith("DISK") and hasattr(dest.allocator, '_as_buffer'):
      # fast(ish) path, uses readinto in diskbuffers
      src.allocator._copyout(dest.allocator._as_buffer(dest._buf), src._buf)
    else:
      dest.copyin(src.as_buffer(allow_zero_copy=True)) # may allocate a CPU buffer depending on allow_zero_copy
  def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False):
    dest, src = rawbufs[0:2]
    assert dest.size == src.size and dest.dtype == src.dtype, f"buffer copy mismatch, {dest.size} != {src.size}, {dest.dtype} != {src.dtype}"
    st = time.perf_counter()
    self.copy(dest, src)
    if wait:
      Device[dest.device].synchronize()
      return time.perf_counter() - st

class BufferXfer(BufferCopy):
  def copy(self, dest, src): dest.allocator._transfer(dest._buf, src._buf, dest.nbytes, src_dev=src.allocator.dev, dest_dev=dest.allocator.dev)

# **************** method cache ****************

method_cache: dict[tuple[str, type, bytes, tuple[int, ...], bool], CompiledRunner] = {}
def get_runner(device:str, ast:UOp) -> CompiledRunner:
  # TODO: this should be all context relevant to rendering
  context = (BEAM.value, NOOPT.value, DEVECTORIZE.value)
  ckey = (device, type(Device[device].compiler), ast.key, context, False)
  if cret:=method_cache.get(ckey): return cret
  bkey = (device.split(":")[0], type(Device[device].compiler), ast.key, context, True)
  if bret:=method_cache.get(bkey):
    method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device), bret.lib)
  else:
    prg: ProgramSpec = get_program(ast, Device[device].renderer)
    method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, device=device))
  return ret

# **************** lowering functions ****************

@dataclass(frozen=True)
class ExecItem:
  prg: Runner
  bufs: list[Buffer|None]
  metadata: tuple[Metadata, ...]|None = None
  fixedvars: dict[str, int] = field(default_factory=dict)
  def run(self, _var_vals:dict[str, int]|None=None, wait=False, jit=False, do_update_stats=True) -> float|None:
    var_vals = self.fixedvars if _var_vals is None else (_var_vals|self.fixedvars)
    bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs]
    if PROFILE: cpu_events.append(ProfilePointEvent(self.prg.device, "exec", self.prg.display_name, {"metadata":self.metadata, "var_vals":var_vals}))
    et = self.prg(bufs, var_vals, wait=wait or DEBUG >= 2)
    if do_update_stats:
      GlobalCounters.kernel_count += 1
      GlobalCounters.global_ops += (op_est:=sym_infer(self.prg.estimates.ops, var_vals))
      GlobalCounters.global_mem += (mem_est:=sym_infer(self.prg.estimates.mem, var_vals))
      if et is not None: GlobalCounters.time_sum_s += et
      if DEBUG >= 2:
        lds_est = sym_infer(self.prg.estimates.lds, var_vals)
        mem_est = min(mem_est, lds_est) # there can't be more memory accessed than loads/stores. remove this when symbolic is fixed
        header_color = 'magenta' if jit else ('green' if self.prg.first_run else None)
        ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else ""
        flops, membw, ldsbw = op_est/(et or 1e-20), mem_est/(et or 1e-20), lds_est/(et or 1e-20)
        flops_str = f"{flops*1e-9:9.2f} GFLOPS" if flops < 1e14 else colored(f"{flops*1e-12:9.2f} TFLOPS", 'green')
        mem_str = f"{membw*1e-9:6.1f}|{ldsbw*1e-9:<7.1f} GB/s" if membw < 1e13 else colored(f"{membw*1e-12:6.1f}|{ldsbw*1e-12:<7.1f} TB/s", 'green')
        print(f"{colored(f'*** {self.prg.device[:7]:7s} {GlobalCounters.kernel_count:4d}', header_color)}"+
              f" {self.prg.display_name+' '*(44-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:5.2f} GB"+
              ("" if et is None else f" tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({flops_str} {mem_str})")+
              f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in self.metadata] if self.metadata else ''}")
      self.prg.first_run = False
    return et

# NOTE: ctx is the buffers
si_lowerer = PatternMatcher([
  (UPat(Ops.SINK, name="sink"), lambda ctx,sink: (runner:=get_runner(ctx[0].device, sink), [ctx[x] for x in runner.p.globals])),
  (UPat(Ops.BUFFER_VIEW), lambda ctx: (ViewOp(ctx[0]), list(ctx))),
  (UPat(Ops.COPY, name="copy"), lambda ctx,copy: ((BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
      if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \
      else BufferCopy(ctx[0].nbytes, ctx[0].device, ctx[1].device)), list(ctx))),
])
def lower_schedule_item(si:ScheduleItem) -> ExecItem:
  return ExecItem(*cast(tuple[Runner,list], si_lowerer.rewrite(si.ast, si.bufs)), si.metadata, si.fixedvars)

def lower_schedule(schedule:list[ScheduleItem]) -> Generator[tuple[ScheduleItem, ExecItem], None, None]:
  while len(schedule):
    si = schedule.pop(0)
    try: yield (si, lower_schedule_item(si))
    except Exception as e:
      if DEBUG >= 2:
        print(f"error lowering {si.ast.op}")
        print("tensor operations:")
        pprint.pprint(si.metadata, indent=2)
      raise e

# **************** main run function ****************

capturing: list = [] # put classes with an add method in here

def run_schedule(schedule:list[ScheduleItem], var_vals:dict[str, int]|None=None, do_update_stats=True):
  for si, ei in lower_schedule(schedule):
    if len(capturing) and CAPTURING: capturing[0].add(ei)
    if VALIDATE_WITH_CPU and si.ast.op is Ops.SINK:
      # copy in allocated buffers from the GPU
      nb: tuple[Buffer, ...] = tuple(Buffer("CPU", b.size, b.dtype) for b in si.bufs)
      for cpu_b, gpu_b in zip(nb, si.bufs):
        if gpu_b.is_allocated(): cpu_b.ensure_allocated().copyin(gpu_b.as_buffer())

      # run on GPU
      ei.run(var_vals, do_update_stats=do_update_stats)

      # validate the output buffers match (NOTE: this is assuming the output is buffer 0)
      with Context(BEAM=0): lower_schedule_item(ScheduleItem(si.ast, nb, si.metadata, si.fixedvars)).run(var_vals, do_update_stats=do_update_stats)
      import numpy as np
      np.testing.assert_allclose(si.bufs[0].numpy(), nb[0].numpy(), rtol=1e-3, atol=1e-3)
    else:
      ei.run(var_vals, do_update_stats=do_update_stats)
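For context (not part of the commit), a hedged sketch of the lowering path in this file: lower_schedule turns each ScheduleItem into an ExecItem whose Runner is a CompiledRunner, BufferCopy/BufferXfer, or ViewOp, and run_schedule is essentially this loop plus capture and validation hooks. Tensor.schedule() is assumed from the public API; the example has no symbolic variables, so ei.run() needs no var_vals.

# Illustrative sketch only; assumes Tensor.schedule() from the public tinygrad API.
from tinygrad import Tensor
from tinygrad.engine.realize import lower_schedule

out = (Tensor.rand(64, 64) + 1).sum()
for si, ei in lower_schedule(out.schedule()):
  # si is the ScheduleItem (AST + Buffers), ei is the lowered ExecItem (Runner + Buffers)
  print(type(ei.prg).__name__, ei.prg.display_name)
  ei.run()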
83 tinygrad/engine/schedule.py Normal file
@@ -0,0 +1,83 @@
from typing import cast
from dataclasses import dataclass, field
from collections import deque, defaultdict
from tinygrad.uop.ops import UOp, Ops, buffers
from tinygrad.device import Device, Buffer, MultiBuffer
from tinygrad.helpers import Metadata, all_same

# **** ScheduleItem return type

@dataclass(frozen=True)
class ScheduleItem:
  ast: UOp
  bufs: tuple[Buffer, ...]
  metadata: tuple[Metadata, ...] = ()
  fixedvars: dict[str, int] = field(default_factory=dict)

# **** schedule linearizer

def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[str, int]]:
  # construct the KERNEL children graph based on assigns
  children: defaultdict[UOp, list[UOp]] = defaultdict(list)
  in_degree: dict[UOp, int] = {}
  var_vals: dict[str, int] = {}
  for u in sched_sink.toposort():
    if u.op is not Ops.ASSIGN: continue # anything that's not an ASSIGN doesn't write a kernel, so we can skip
    k = u.src[1]
    in_degree.setdefault(k, 0)
    for s in k.src:
      if s.op is Ops.ASSIGN:
        children[s.src[1]].append(k)
        in_degree[k] += 1
      elif s.op in {Ops.MSELECT, Ops.MSTACK}:
        for ss in s.src:
          if ss.op is Ops.MSELECT: ss = ss.src[0]
          if ss.op is not Ops.BUFFER:
            assert ss.op is Ops.ASSIGN, f"ss.op is not ASSIGN, it's {ss.op}"
            children[ss.src[1]].append(k)
            in_degree[k] += 1
      elif s.op is Ops.BUFFER:
        pass # a BUFFER is already realized, nothing to do here
      elif s.op is Ops.BIND:
        var, val = s.unbind()
        assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch on {var}, {var_vals[var.expr]} != {val}"
        var_vals[var.expr] = val
      else:
        raise RuntimeError(f"input to kernel must be ASSIGN or BUFFER, not {s.op}")

  # linearize KERNEL UOps into ScheduleItems in BFS order

  def _heuristic(k: UOp):
    if k.arg.ast.op is Ops.COPY and not all_same([Device[cast(Buffer, s.buf_uop.buffer).device].group_id for s in k.src]): return 1000
    return 0

  last_heuristic: int = 0
  queues: defaultdict[int, deque[UOp]] = defaultdict(deque)
  last_queue: deque[UOp] = deque()
  for k,v in in_degree.items():
    if v == 0: queues[_heuristic(k)].append(k)

  schedule: list[ScheduleItem] = []
  while last_queue or any(queues.values()):
    if not last_queue: last_heuristic, last_queue = min((it for it in queues.items() if it[1]), key=lambda x: abs(x[0]-last_heuristic))
    k = last_queue.popleft()
    ast = k.arg.ast
    # create subbuffers if needed
    if ast.op is Ops.BUFFER_VIEW:
      base = k.src[1].buf_uop.buffer
      assert isinstance(base, Buffer), "base can't be MultiBuffer"
      buffers[k.src[0]] = base.view(k.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
    ubufs = tuple(s.buf_uop.buffer for s in k.src if s.op is not Ops.BIND)
    if any(isinstance(x, MultiBuffer) for x in ubufs):
      assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
      dnums = [x for x in ast.variables() if x.arg[0] == '_device_num']
      for i,bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
        schedule.append(ScheduleItem(ast, bufs, k.arg.metadata, {dnums[0].expr:i} if len(dnums) else {}))
    else:
      # ONE -> ONE
      schedule.append(ScheduleItem(ast, cast(tuple[Buffer, ...], ubufs), k.arg.metadata))
    for x in children[k]:
      in_degree[x] -= 1
      if in_degree[x] == 0: queues[_heuristic(x)].append(x)

  return schedule, var_vals
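To make the linearization loop above easier to follow, here is a simplified, self-contained sketch (illustrative only; the toy graph, node names, and heuristic values are invented): ready kernels are binned by heuristic value and the bin closest to the previously drained one is preferred, which tends to keep cross-device copies (heuristic 1000) batched together.

# Simplified illustration of the linearization loop in create_schedule_with_vars:
# Kahn's-style BFS where ready nodes are grouped by a heuristic and the group
# closest to the last drained one is picked next. Not part of the commit.
from collections import deque, defaultdict

def linearize(in_degree: dict[str, int], children: dict[str, list[str]], heuristic: dict[str, int]) -> list[str]:
  queues: defaultdict[int, deque[str]] = defaultdict(deque)
  for k, v in in_degree.items():
    if v == 0: queues[heuristic[k]].append(k)
  order, last_h, last_q = [], 0, deque()
  while last_q or any(queues.values()):
    if not last_q: last_h, last_q = min((it for it in queues.items() if it[1]), key=lambda x: abs(x[0]-last_h))
    k = last_q.popleft()
    order.append(k)
    for x in children.get(k, []):
      in_degree[x] -= 1
      if in_degree[x] == 0: queues[heuristic[x]].append(x)
  return order

# toy graph: A -> B -> D, A -> C -> D; C stands in for a cross-device copy (heuristic 1000)
print(linearize({"A":0,"B":1,"C":1,"D":2}, {"A":["B","C"],"B":["D"],"C":["D"]}, {"A":0,"B":0,"C":1000,"D":0}))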
Block a user