Release 260111

This commit is contained in:
Comma Device
2026-01-11 18:23:29 +08:00
commit 3721ecbf8a
2601 changed files with 855070 additions and 0 deletions

View File

View File

@@ -0,0 +1,119 @@
from tinygrad.uop.ops import Ops, UOp, resolve, can_pad, GroupOp, UPat, PatternMatcher, graph_rewrite
from tinygrad.helpers import all_int, prod, unwrap, dedup, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, FUSE_CONV_BW
from tinygrad.shape.shapetracker import ShapeTracker
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL,
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD}
# **** Grouper decides which of the UOps realize
def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None
def realize_parents(ctx:dict[UOp, None], rb:UOp) -> None:
for s in rb.src:
if s.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
def realize_before_view(ctx:dict[UOp, None], view:UOp, tr:UOp) -> None:
st = unwrap(view.st)
# always realize unsafe pad ops before masked view
if any(v.mask is not None for v in st.views) and not can_pad(tr, ctx): return realize(ctx, tr)
# fold simple pads
if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(tr.shape) and resolve(prod(tr.shape) >= prod([y-x for x,y in m])): return
# realize before expand
if resolve(prod(tr.shape) < prod(st.shape)) and not DONT_REALIZE_EXPAND: return realize(ctx, tr)
do_realize = PatternMatcher([
# always realize SINK parents
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
# always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
(UPat({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize),
# realize before expand or unsafe pad ops
(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="tr"),), name="view"), realize_before_view),
# realize parents of COPY, MSELECT, MSTACK
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_parents),
])
def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:dict[UOp, dict[UOp, None]], realizes:dict[UOp, None],
reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
if (tr, st) in cache: return
cache.setdefault((tr, st))
rsize = unwrap(r.st).size
if tr in realizes and tr is not r:
# can only fuse contiguous
# max one reduceop per kernel
if not st.contiguous or st.size != rsize or tr in reduce_for_op: group.setdefault(r)
return group.setdefault(tr)
for tr_next in children.get(tr, {}):
# max one reduceop per kernel
if tr_next.op is Ops.REDUCE_AXIS: return group.setdefault(r)
# can only fuse contiguous
if len(st_childs:=dedup(unwrap(x.st) for x in tr_next.src if x.base == tr)) > 1: return group.setdefault(r)
recursive_group(tr_next, st+st_childs[0], r, children, realizes, reduce_for_op, group, cache)
def group_realizes(sink:UOp) -> dict[UOp, None]:
# start by adding uops that always realize
realizes: dict[UOp, None] = {}
sink = graph_rewrite(sink, do_realize, ctx=realizes, name="do_realize")
if DONT_GROUP_REDUCES: return realizes
# construct children graph (only for bases)
children: dict[UOp, dict[UOp, None]] = {}
assigns: dict[UOp, None] = {}
for u in (toposort:=sink.toposort()):
if u.op in {Ops.VIEW, Ops.SINK}: continue
if u.op is Ops.ASSIGN: assigns[u.buf_uop] = None
for s in u.src: children.setdefault(s.base, {})[u] = None
# find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
reduce_for_op: dict[UOp, UOp] = {}
double_reduces: list[UOp] = []
for r in toposort:
if r.op is not Ops.REDUCE_AXIS: continue
if len(r.arg) == 3 and r.arg[2] is True: continue
if FUSE_CONV_BW and r.src[0].base.op is Ops.REDUCE_AXIS and r.src[0] is not r.src[0].base: double_reduces.append(r)
if r in realizes: continue
group: dict[UOp, None] = {}
recursive_group(r, unwrap(r.st), r, children, realizes, reduce_for_op, group, cache={})
# max one reduceop per kernel
can_chase = all(tr not in reduce_for_op for tr in group)
for u in r.toposort(gate=lambda u: u not in realizes):
if u.op is Ops.REDUCE_AXIS and u.src[0].base.op is Ops.CONST:
can_chase = False
break
# TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
forced_realize = r in group
# can only have one output
if not forced_realize and len(group) > 1: forced_realize = True
# can only fuse assign if no other assign_target is used in the kernel
if not forced_realize and (assign_targets:={x.buf_uop for x in group if x.op is Ops.ASSIGN}):
parents = [r, *group]
while parents and not forced_realize:
p = parents.pop().base
if p.op is Ops.BUFFER and p in assigns and p not in assign_targets: forced_realize, can_chase = True, False
if p in realizes: continue
parents.extend(p.src)
if forced_realize or not group:
tr = r
if can_chase:
# can chase this down to contiguous children
st = unwrap(tr.st)
while len(lst:=children.get(tr, {})) == 1:
tr_next = next(iter(lst))
st_childs = dedup(unwrap(s.st) for s in tr_next.src if s.base is tr)
if len(st_childs) > 1: break
if st.size != st_childs[0].size: break
st = st + st_childs[0]
if not st.contiguous or tr_next.op is Ops.REDUCE_AXIS: break
tr = tr_next
# don't cast to higher size before store (tr cannot be realized if forced_realize)
if tr.op is Ops.CAST and tr.dtype.itemsize > tr.src[0].dtype.itemsize:
tr = tr.src[0].base
group = {tr: None}
realizes[tr] = None
reduce_for_op.update((tr, r) for tr in group)
# fuse double reduces with no other child
for reduceop in double_reduces:
top_reduce = reduceop.src[0].base
if len(children.get(top_reduce, {})) == 1: del realizes[top_reduce]
return realizes

View File

@@ -0,0 +1,382 @@
from dataclasses import dataclass
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve
from tinygrad.uop.ops import track_rewrites, _substitute, KernelInfo
from tinygrad.uop.spec import type_verify, tensor_uop_spec
from tinygrad.uop.symbolic import symbolic_simple
from tinygrad.helpers import Metadata, all_int, all_same, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
from tinygrad.dtype import ImageDType
from tinygrad.schedule.multi import multi_pm
from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS
from tinygrad.codegen.opt.swizzler import merge_views, apply_swizzle, swizzle_reduceop
from tinygrad.codegen.opt import Opt
# creation can recurse a lot
import sys
sys.setrecursionlimit(10000)
# **** schedule simplifier
def simplify_stride0_reduce(reduce:UOp, x:UOp):
# must be unmasked (NOTE: can be relaxed if not masked on stride 0 axis)
if any(v.mask is not None for v in unwrap(x.st).views): return None
# must have all stride 0 in the relevant axis (NOTE: can do partial)
if not all(unwrap(x.st).views[-1].strides[axis] == 0 for axis in reduce.arg[1]) or not all_int(x.shape): return None
prshape = prod(x.shape[i] for i in reduce.arg[1])
ret = x.shrink(tuple((0,s) if i not in reduce.arg[1] else (0,1) for i,s in enumerate(x.shape)))
match reduce.arg[0]:
case Ops.ADD: return ret*prshape
case Ops.MUL: return ret.pow(prshape)
case Ops.MAX: return ret # NOTE: Ops.MAX is passthrough
def split_reduceop(reduce:UOp, x:UOp):
if not SPLIT_REDUCEOP or not all_int(x.shape) or (prod(x.shape)//prod(reduce.shape))<getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return None
# if there are few globals, make some reduces into globals by splitting into two kernels
# cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm
# ~2**10 should be enough if GROUP is used
# 256 split maximum should be "negligible reduce" for low prod(reduce.shape), 8 split minimum.
# split is moved to the end to provide maximum locality for the second phase reduce.
real_strides = unwrap(x.st).real_strides(ignore_valid=True)
if not (split_candidates:=[(i,d) for i in reduce.arg[1] for d in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(reduce.shape)),8-1,-1)
if x.shape[i]%d==0 and real_strides[i]!=0]): return None
dim_to_split, divisor = split_candidates[0]
splitted_shape = x.shape[:dim_to_split]+(divisor,)+(x.shape[dim_to_split]//divisor,)+x.shape[dim_to_split+1:]
splitted = x.reshape(splitted_shape).permute(tuple([d for d in range(len(splitted_shape)) if d!=dim_to_split]+[dim_to_split]))
if DEBUG >= 3: print(f"split {divisor}: {x.shape} -> {splitted.shape} -> {reduce.shape}")
# reduce original axes, then split
return splitted.r(*reduce.arg).r(reduce.arg[0], (len(reduce.shape),)).reshape(reduce.shape)
def copy_reorder_view(copy:UOp, view:UOp, base:UOp):
if prod(view.shape) < prod(base.shape): return view.contiguous().copy_to_device(copy.device)
return base.copy_to_device(copy.device).view(view.arg)
kernelize_sym = symbolic_simple+PatternMatcher([
# UOp with size 0 is zero
(UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 else None),
# DETACH and CONTIGUOUS_BACKWARD are NOOPs here
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
# reduce of size 0 is the identity element
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
# reduce on stride 0 is collapsed
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), simplify_stride0_reduce),
# split_reduceop
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop),
# COPY(CONST) creates a new CONST on the destination device
(UPat(Ops.COPY, name="root", src=(UPat.cvar("x"), UPat(Ops.DEVICE))), lambda root,x: root.const_like(x.arg)),
# non device changing COPY is a NOOP
(UPat(Ops.COPY, name="c", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda c,x: x if c.device == x.device else None),
# store a shrink before COPY, otherwise view after the COPY
(UPat(Ops.COPY, src=(UPat(Ops.VIEW, src=(UPat.var("base"),), name="view"), UPat(Ops.DEVICE)), name="copy"), copy_reorder_view),
# remove cast to image when it's already a contiguous image
(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"),)),)),
lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
# CAST before masking constants
(UPat.cvar("x").view().cast(name="c"), lambda x,c: x.cast(c.dtype).view(c.src[0].arg)),
# make things that can't be images not images
(UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW, Ops.CONST, Ops.DEVICE}, name="u"), lambda u: u.replace(dtype=dt.base) if isinstance(dt:=u.dtype,ImageDType)
and (prod(u.shape) != prod(dt.shape) or not any(u.shape[x]%4 == 0 for x in u.st.unit_stride_axes())) else None),
# remove contiguous if we can just view the buffer
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
# contiguous/buffer/copy/assign is already contiguous
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY, Ops.ASSIGN)),)), lambda root: root.src[0]),
# substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
(UPat((Ops.BITCAST, Ops.CONTIGUOUS), src=(UPat.var("x"),), name="t"), lambda x,t: UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,),
(t.size, x.st.views[0].offset)).reshape(t.shape) if isinstance(x.device, str) and x.device.startswith("DISK") else None),
# double ASSIGN to same target is one ASSIGN
(UPat(Ops.ASSIGN, src=(UPat.var("t"), UPat(Ops.ASSIGN, src=(UPat.var("t"), UPat.var("x"))))), lambda x,t: t.assign(x.contiguous())),
# ASSIGN to unrealized replaces the UOp
(UPat(Ops.ASSIGN, src=(UPat.var("t"), UPat.var("x"))), lambda x,t: x.contiguous() if t.base.op not in {Ops.BUFFER, Ops.BUFFER_VIEW} and
not (t.base.op is Ops.MSTACK and all(x.op is Ops.BUFFER for x in t.base.src)) else None),
# put CAST to smaller dtype before EXPAND
(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm"),)), lambda cast,vm: vm.base.cast(cast.dtype).view(vm.st)
if cast.dtype.itemsize <= vm.dtype.itemsize and resolve(prod(vm.shape) > vm.st.real_size()) else None),
# put UnaryOps before EXPANDs, if it can fuse with the input
(UPat(GroupOp.Unary, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="inp"),), name="v"),), name="alu"),
lambda inp,v,alu: inp.alu(alu.op).view(v.st) if resolve(prod(alu.shape) > v.st.real_size()) else None),
])
# support for using a contiguous permuted view instead of the parent view if one exists
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)
replace_contiguous = PatternMatcher([
(UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, name="src"),), name="contig"), found_contiguous),
(UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
])
# **** create kernels
@dataclass(frozen=True)
class Kernel:
ast: UOp
metadata: tuple[Metadata, ...] = ()
def __repr__(self):
ast_rep = f"SINK{tuple(s.op for s in self.ast.src)}" if self.ast.op is Ops.SINK else repr(self.ast.op)
return f"<Kernel {len(list(self.ast.toposort()))} {ast_rep} {self.metadata}>"
def create_kernel(x:UOp, b:UOp|None=None):
if b is None: b = UOp.new_buffer(x.device, x.size, x.dtype)
kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x.sink(), m if (m:=x.metadata) else ()))
buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
# we have to shrink the buffer back to the symbolic shape
return buffer.assign(kernel).reshape(tuple(d.vmax if isinstance(d, UOp) else d for d in x.shape)).shrink(tuple((0, d) for d in x.shape))
DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.MULTI, Ops.BIND}
def append_to_kernel(x:UOp):
new_srcs: list[UOp] = []
metadata = x.arg.metadata
for s in x.src:
if s.op in DONT_PLACE_IN_KERNEL: new_srcs.append(s)
else:
new_srcs.extend(s.src)
# NOTE: because const and device are shared UOps they don't change metadata
# NOTE: if it's a reshape after ASSIGN we're not fusing that parent kernel
if s.base.op not in {Ops.CONST, Ops.DEVICE} and (not (s.op is Ops.RESHAPE and s.base.op is Ops.ASSIGN)) and (m:=s.metadata): metadata += m
if (new_src:=tuple(dedup(new_srcs))) != x.src: return x.replace(src=new_src, arg=Kernel(x.arg.ast, tuple(dedup(metadata))))
create_kernels = PatternMatcher([
# always give assign/contiguous a kernel
(UPat.assign(UPat.var("b"), UPat(GroupOp.All-{Ops.KERNEL}), name="x"), create_kernel),
(UPat(Ops.CONTIGUOUS, name="x"), create_kernel),
# walk back the local graph until we reach a realized source
(UPat(Ops.KERNEL, name="x"), append_to_kernel),
# push RESHAPE through MSELECT
(UPat(Ops.MSELECT, src=(UPat(Ops.RESHAPE, name="r"),), name="ms"), lambda ms,r: r.src[0].mselect(ms.arg).reshape(r.arg)),
# push RESHAPE through MSTACK
(UPat(Ops.MSTACK, src=UPat(Ops.RESHAPE), name="ms"),
lambda ms: UOp(Ops.MSTACK, ms.dtype, tuple(x.src[0] for x in ms.src)).reshape(ms.src[0].arg)),
])
def add_stores(ctx, sink: UOp):
stores = []
for i,x in enumerate(sink.src):
gbl = UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(ctx[i].size), (), i)
# if this is an assign then we already have a buffer with a view that should be the target of the store
if x.op is Ops.ASSIGN: stores.append(UOp.store(gbl.view(unwrap(s.st)), s))
# otherwise we have to create the shapetracker and shrink it to the correct symbolic shape
else: stores.append(
UOp.store(gbl.reshape(tuple(int(d.vmax) if isinstance(d,UOp) else d for d in s.shape)).shrink(tuple((0,d) for d in s.shape)),s))
return UOp.sink(*stores, arg=sink.arg)
# **** fix kernel AST
def unbind_view(x:UOp):
if any(x.op is Ops.BIND for x in x.arg.vars()): return x.replace(arg=x.arg.unbind()[0])
return None
replace_buffers = PatternMatcher([
# sink on contig creates a KernelInfo
(UPat(Ops.CONTIGUOUS, name="c").sink(name="s"),
lambda s,c: s.replace(src=(c.replace(arg=None),), arg=KernelInfo(opts_to_apply=c.arg)) \
if s.arg is None and c.arg is not None and isinstance(c.arg[0], Opt) else None),
# replace ASSIGN with the target BUFFER
(UPat(Ops.ASSIGN, src=(UPat((Ops.BUFFER, Ops.LOAD)), UPat(Ops.KERNEL)), name="assign", allow_any_len=True), lambda assign: assign.src[0]),
# HACK: select the 0 branch of MSTACK (the device is wrong after this, is that okay?)
(UPat(Ops.MSTACK, name="x"), lambda x: x.src[0]),
# LOAD
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)).load()),
# no SINK for meta ops
(UPat(Ops.SINK, src=(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Meta, name="x"),),))), lambda x:x),
# STORE (except for meta ops)
(UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), add_stores),
# remove CONTIGUOUS/DEVICE from kernel AST
(UPat((Ops.CONTIGUOUS, Ops.MSELECT), src=(UPat.var("x"),)), lambda x: x),
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
# passthrough ASSIGN (but let MSTACK process first)
(UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.MSTACK}), UPat()), name="x"), lambda x: x.src[1]),
# remove any BINDs from VIEWS
(UPat(Ops.VIEW, src=(UPat(), UPat((Ops.BIND, Ops.DEFINE_VAR))), allow_any_len=True, name="x"), lambda x: x.replace(src=x.src[0:1])),
# remove any BINDs from DEFINE_VARs
(UPat(Ops.BIND, name="x"), lambda x: x.src[0]),
# remove BINDs from ShapeTrackers
(UPat(Ops.VIEW, name="x"), unbind_view),
])
def fix_kernel_ast(k:UOp) -> UOp|None:
if k.arg.ast.op in GroupOp.Meta or all(s.op is Ops.STORE for s in k.arg.ast.src): return None
# replace buffer with define_global + add load/store last
bufs = []
for s in k.src:
if s.op is Ops.BIND: continue
s = s.buf_uop
# traverse back through MSELECT and MSTACK. HACK: 0 branch of MSTACK only
while s.op in {Ops.MSELECT, Ops.MSTACK}: s = s.src[0]
bufs.append(s)
# replace global memory ops with the BUFFER they write to
# NOTE: merge_views is needed to unbind the reshapes
ast = graph_rewrite(k.arg.ast, merge_views+replace_buffers, bufs, bottom_up=True, name="replace buffers")
if ast.op is Ops.SINK and not all_same([x.device for x in k.src if x.op is not Ops.BIND]):
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in k.src)}")
return k.replace(arg=Kernel(ast, k.arg.metadata))
create_ast = PatternMatcher([
(UPat(Ops.KERNEL, name="k"), fix_kernel_ast),
(UPat(Ops.DEFINE_VAR, src=(UPat(),), allow_any_len=True, name="x"), lambda x: x.replace(src=())),
])
# ** add metadata of KERNEL outputs
def append_metadata(root:UOp, k:UOp):
if not root.metadata or (new_metadata:=tuple(dedup(k.arg.metadata+root.metadata))) == k.arg.metadata: return None
return root.replace(src=(root.src[0], k.replace(arg=Kernel(k.arg.ast, new_metadata)))+root.src[2:])
replace_metadata = PatternMatcher([(UPat(Ops.ASSIGN, src=(UPat(), UPat(Ops.KERNEL, name="k")), name="root", allow_any_len=True), append_metadata),])
pm_fuse = PatternMatcher([
# FUSE on CONTIGUOUS removes FUSE
(UPat(Ops.CONTIGUOUS, name="c").fuse(), lambda c: c),
# FUSE triggers swizzle on reduceop
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").or_casted(),), name="view").fuse(),
lambda r,src,view: ret.cast(view.dtype) if (ret:=swizzle_reduceop(r, src, view, fuse=True)) is not None else None),
# FUSE on reduce (without view) adds fuse marker to grouper
(UPat(Ops.REDUCE_AXIS, name="r").fuse(),
lambda r: r.replace(src=(r.src[0].fuse(),), arg=r.arg+(True,)) if len(r.arg) == 2 else None),
# remove FUSE and insert CONTIGUOUS if it's an unsafe pad
(UPat(Ops.VIEW, src=(UPat(GroupOp.UnsafePad, name="alu"),), name="view").fuse(),
lambda alu, view: alu.contiguous().view(view.st) if any(v.mask is not None for v in view.st.views) else None),
# FUSE elementwise.
(UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST}, name="alu"),), name="view").fuse(),
lambda alu, view: alu.replace(src=tuple(apply_swizzle(x.view(view.arg)).fuse() for x in alu.src))),
# push FUSE through to srcs
(UPat(Ops.FUSE, name="x"), lambda x: x.src[0].replace(src=tuple(y.fuse() for y in x.src[0].src))),
])
def do_fusion(x:UOp):
found_contiguous = {}
def gate_contiguous(x):
if is_contiguous:=(x.op is Ops.CONTIGUOUS): found_contiguous[x] = x.replace(src=(UOp(Ops.VIEW, arg=x.st), UOp.unique()))
return not is_contiguous
x.toposort(gate=gate_contiguous)
del gate_contiguous
return graph_rewrite(x.substitute(found_contiguous), pm_fuse, name="local fusion").substitute({v:k for k,v in found_contiguous.items()})
def fuse_arange(root:UOp):
# skip if root is arange
if not FUSE_ARANGE or root.src[0].base.op is Ops.CONST: return None
# gather all local aranges (including any fused ones)
local_arange: list[UOp] = []
def gate_reduce(u):
if u.op is Ops.REDUCE_AXIS and u.src[0].base.op is Ops.CONST: local_arange.append(u)
return u.op not in {*ALWAYS_CONTIGUOUS, Ops.REDUCE_AXIS} or u is root
toposort = root.toposort(gate=gate_reduce)
if not local_arange: return None
# fuse the nearest expand child of arange
local_children: dict[UOp, list[UOp]] = {}
for u in toposort:
for s in u.src: local_children.setdefault(s, []).append(u)
fuse_rep: dict[UOp, UOp] = {}
for r in local_arange:
# skip if already fused
if len(r.arg) > 2: continue
q = list(local_children[r])
while q:
u = q.pop()
if not (curr_children:=local_children.get(u, [])): continue
for child in curr_children:
other_paths = {s for s in child.toposort() if s.op in {Ops.REDUCE_AXIS, Ops.BUFFER} and s not in {root, r}}
fuse_rep[child] = child.replace(src=tuple(s.fuse() if s is u else s for s in child.src))
if other_paths: break
else: q.extend(curr_children)
return root.substitute(fuse_rep, name="fuse_arange") if fuse_rep else None
do_fuse = PatternMatcher([
(UPat(Ops.FUSE, name="x"), do_fusion),
(UPat(Ops.REDUCE_AXIS, name="root"), fuse_arange),
])
add_contiguous = PatternMatcher([(UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.ASSIGN}, name="x"),
lambda ctx,x: x.replace(tag=1).contiguous() if x in ctx and x.tag is None else None)])
# TODO: get this from the device through GrouperOpts
DEVICE_MAX_BUFS = {"METAL":32, "WEBGPU":8}
def limit_bufs(root:UOp):
# check if backend has a buffer limit
device = root.device if isinstance(root.device, str) else root.device[0].split(":")[0]
if not (MAX_BUFS:=getenv("MAX_KERNEL_BUFFERS", DEVICE_MAX_BUFS.get(device, 0))): return None
# count number of unique buffers flowing into this op
bufs: set[UOp] = set()
def gate_input(u:UOp):
if (is_load:=(u.op in {Ops.BUFFER, Ops.CONTIGUOUS, Ops.ASSIGN, Ops.MSTACK})): bufs.add(u)
return not is_load
root.toposort(gate=gate_input)
# NOTE: this -1 is for the output buffer
if len(bufs)>=MAX_BUFS-1:
return root.replace(src=tuple(s if s.base in bufs else s.replace(tag=1).contiguous() for s in root.src))
def view_add_srcs(x:UOp):
if len(avars:=x.arg.vars()) and len(x.src) == 1:
return x.replace(src=x.src+tuple(avars))
return None
finalize_contiguous = PatternMatcher([
# if an op takes more than one input, check combined LOADs don't exceed device limits
(UPat(set.union(GroupOp.Binary, GroupOp.Ternary), name="root"), limit_bufs),
# merge contiguous
(UPat(Ops.CONTIGUOUS, src=(UPat(Ops.CONTIGUOUS),), name="x"), lambda x: x.src[0]),
# simplify views
(UPat(Ops.VIEW, src=(UPat.var('x')), name="v"), lambda x,v: x.view(new_st) if (new_st:=v.arg.simplify()) != v.arg else None),
# vars to views srcs
(UPat(Ops.VIEW, name="x"), view_add_srcs),
])
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
@track_rewrites(name=lambda sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[sink].toposort() if u.op is Ops.KERNEL]))}", replay=True)
def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
"""
Function to transform the Tensor UOp graph into a version with Ops.KERNEL
Args:
sink: The Ops.SINK rooting the Tensor graph.
Returns:
Map transforming each UOp in the sink to the Ops.KERNEL graph.
"""
# multi + merge_views + simplify
tensor_map = graph_rewrite_map(sink, multi_pm+do_fuse+merge_views+kernelize_sym+replace_contiguous, ctx={}, name="merge_views")
# display the cleaned up tensor graph
if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Tensor Graph")
# insert contiguous in places determined by the realize map
realize_map = group_realizes(tensor_map[sink])
tensor_map = graph_rewrite_map(tensor_map[sink], add_contiguous, ctx=realize_map, bottom_up=True, input_map=tensor_map, name="add_contiguous")
tensor_map = graph_rewrite_map(tensor_map[sink], finalize_contiguous+remove_tags, input_map=tensor_map, name="finalize_contiguous")
# group into kernels (this is context-free)
tensor_map = graph_rewrite_map(tensor_map[sink], create_kernels, input_map=tensor_map, name="create_kernels")
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
kernel_assign: dict[UOp, UOp] = {}
assign_rep: dict[UOp, UOp] = {}
for u in tensor_map[sink].toposort():
if u.op is not Ops.ASSIGN: continue
kernel_assign[u.buf_uop] = u
for s in u.src[1].src:
# TODO: this is probably broken for MSELECT/MSTACK
if s.op is not Ops.BUFFER or s is u.buf_uop or (a:=kernel_assign.get(s)) is None: continue
if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort()):
raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER")
assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
if assign_rep:
tensor_map = graph_rewrite_map(tensor_map[sink], _substitute, ctx=assign_rep, bottom_up=True, input_map=tensor_map, name="fix_assign")
# finally, create the AST for kernels
tensor_map = graph_rewrite_map(tensor_map[sink], create_ast+replace_metadata, bottom_up=True, input_map=tensor_map, name="create_ast")
# display the final graph
sched_sink = tensor_map[sink]
if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph")
# verify Kernels match the spec
if __debug__: type_verify(list(sched_sink.toposort()), tensor_uop_spec)
return tensor_map

231
tinygrad/schedule/multi.py Normal file
View File

@@ -0,0 +1,231 @@
from typing import cast
import functools, itertools, operator
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv, unwrap
from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, resolve
from tinygrad.device import Device
# *** allreduce implementation ***
def handle_allreduce_multirank(buf:UOp, red:UOp) -> UOp|None:
if not isinstance(buf.device, tuple): return None
# Group buffers
groups: dict[int|None, list[UOp]] = {}
for i,dev in enumerate(buf.device):
groups.setdefault(Device[dev].group_id, []).append(buf.mselect(i))
# Put reduce leader of each group first
reduce_leaders = set(getenv("REDUCE_LEADERS", "").split(","))
groups = {gid: sorted(bufs, key=lambda x: (x.device not in reduce_leaders, x.device)) for gid,bufs in groups.items()}
# Skip if only one group or if every group has only one buffer
if len(groups) <= 1 or not any(len(g) > 1 for g in groups.values()): return None
# Reduce inside each group
inner = [UOp(Ops.MSTACK, buf.dtype, tuple(bufs)).allreduce(red.arg, (cast(str, bufs[0].device),)).mselect(0) for bufs in groups.values()]
# Allreduce across groups
outer = UOp(Ops.MSTACK, buf.dtype, tuple(inner)).allreduce(red.arg, tuple(buf.device for buf in inner))
# Broadcast back to all devices in the group
gid2bid = {Device[device].group_id: i for i,device in enumerate(outer.device)}
return outer.mselect(gid2bid[Device[red.device].group_id]).copy_to_device(red.device) if not isinstance(red.device, tuple) else \
UOp(Ops.MSTACK, buf.dtype, tuple(outer.mselect(gid2bid[Device[device].group_id]).copy_to_device(device) for device in red.device))
def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
if not isinstance(buf.device, tuple): return None
assert all_int(buf.shape), f"does not support symbolic shape {buf.shape}"
n_lbs, shape, numel = len(buf.device), buf.shape, prod(buf.shape)
# ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
# fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
use_ring = (RING >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{numel} | {buf.dtype}")
# contiguous before we copy it
buf = buf.contiguous()
# copy to all devices. if you shrink later, that'll be handled
if not use_ring: return functools.reduce(lambda x,y: x.alu(red.arg, y),
[UOp(Ops.COPY, buf.dtype, (buf.mselect(i), red.src[1])) for i in range(len(buf.device))])
# new ring reduce
factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs
chunk_sizes = [(base + 1) * factor] * left + [base * factor] * (n_lbs - left)
chunks = list(itertools.pairwise(itertools.accumulate(chunk_sizes, initial=0)))
# extract chunks and scatter-reduce
reduced_chunks = []
for i,(s,e) in enumerate(chunks):
chunk = buf.reshape((numel,)).shrink(((s,e),))
reduced_chunk = chunk
for step in range(n_lbs-1):
src, dest = (i+step)%n_lbs, (i+step+1)%n_lbs
# copy the chunk from the src device to the dest (operating device), and select the chunk on the dest device
reduced_chunk = reduced_chunk.copy_to_device(buf.device[dest], src if isinstance(reduced_chunk.device, tuple) else None) \
.alu(red.arg, chunk.copy_to_device(buf.device[dest], dest))
reduced_chunks.append(reduced_chunk)
# allgather
copied_chunks = []
for i,c in enumerate(reduced_chunks):
this_chunk = [None] * len(buf.device)
this_chunk[(i+len(buf.device)-1)%n_lbs] = c
for step in range(n_lbs-1):
dest = (i+step)%n_lbs
this_chunk[dest] = c = c.copy_to_device(buf.device[dest])
copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(cast(list[UOp], this_chunk))))
# reassemble
pads = [((s,numel-e),) for s,e in chunks]
return functools.reduce(operator.add, [c.pad(pad) for pad,c in zip(pads, copied_chunks)]).reshape(shape)
# ***** multi rewrite MSELECT/MSTACK *****
def _replace_dnum(st, val):
# replace dnum in ShapeTracker with literal const for this mselect
if (dnums:=[x for x in st.vars() if x.op is Ops.DEFINE_VAR and x.arg[0] == '_device_num']):
assert len(dnums) == 1, f"view must have exactly 0 or 1 dnum, got {dnums}"
st = st.substitute({dnums[0]:dnums[0].const_like(val)})
return st
def mstack_reorder_view(ms:UOp):
args = [x.arg for x in ms.src]
if not all_same(args) or len([x for x in args[0].vars() if x.arg[0] == '_device_num']) != 0: return None
return UOp(Ops.MSTACK, ms.dtype, tuple(x.src[0] for x in ms.src)).view(args[0])
def mstack_early_shrink(view:UOp, ms:UOp):
if resolve(prod(view.shape) >= prod(ms.shape)) or _replace_dnum(view.st, 0) == view.st: return None
ret = []
for i, x in enumerate(ms.src):
new_view = _replace_dnum(view.st, i)
if x.op is Ops.COPY:
# if src device doesn't have a renderer, we have to view after the copy
# TODO: a way to understand this
if x.src[0].device in {"DISK", "NPY"}:
ret.append(x.view(new_view))
else:
ret.append(x.src[0].view(new_view).copy_to_device(x.device))
else:
ret.append(x.view(new_view).contiguous())
return ms.replace(src=tuple(ret))
replace_allreduce = PatternMatcher([
(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce_multirank),
(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),
# BROADCAST: explicitly expand broadcast copies and combine with MSTACK
(UPat(Ops.COPY, name="c", src=(UPat(GroupOp.All-{Ops.CONST}, name="x"), UPat(Ops.DEVICE))), lambda c,x:
UOp(Ops.MSTACK, c.dtype, tuple(x.copy_to_device(d) for d in c.device)) if isinstance(c.device, tuple) and isinstance(x.device, str) else None),
# COPY_TO_ONE: if copying from multidevice to one, MSELECT the first (TODO: a little from each?)
(UPat(Ops.COPY, name="c", src=(UPat(GroupOp.All-{Ops.CONST}, name="x"), UPat(Ops.DEVICE))), lambda c,x:
x.mselect(0).copy_to_device(c.device) if isinstance(c.device, str) and isinstance(x.device, tuple) else None),
# MSELECT on MSTACK is replaced with nothing
(UPat(Ops.MSELECT, src=(UPat(Ops.MSTACK, name="mstack"),), name="ms"), lambda mstack, ms: mstack.src[ms.arg]),
# MSELECT must select a base, if there are views apply them after selecting the base
(UPat(Ops.MSELECT, src=(UPat(Ops.VIEW, src=(UPat.var("base"),), name="view"),), name="ms"), lambda ms, view, base:
base.mselect(ms.arg).view(_replace_dnum(unwrap(view.st), ms.arg))),
# move view through MSTACK
(UPat(Ops.MSTACK, src=UPat(Ops.VIEW), name="ms"), mstack_reorder_view),
# move shrink before MSTACK
(UPat(Ops.VIEW, src=(UPat(Ops.MSTACK, name="ms"),), name="view"), mstack_early_shrink),
])
# ***** multi functions *****
def alu_multi(root:UOp):
msrcs = root.src
assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"
axis = root.axis
assert axis is not None
srcs = []
for mlb in msrcs:
if mlb.axis == axis:
# same axis, just copy through
assert mlb.op is Ops.MULTI
srcs.append(mlb.src[0])
elif mlb.axis is None:
# no axis, shard it
assert mlb.op is not Ops.MULTI
srcs.append(mlb._shard(axis))
else:
# axis mismatch, unshard it, send it to all devices, and shard it correctly
assert mlb.op is Ops.MULTI
srcs.append(mlb.src[0]._unshard(mlb.axis).allreduce(Ops.ADD, mlb.device)._shard(axis))
return srcs[0].alu(root.op, *srcs[1:]).multi(axis)
def reduce_multi(root:UOp, multi:UOp):
op, axis = root.arg
if multi.axis is not None and multi.axis in axis:
# all-reduce on sharded axes
return multi.src[0].r(op, axis).allreduce(op, multi.device)
# reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
return multi.src[0].r(op, axis).multi(axis=multi.axis)
def _shape_to_single_shard(axis, shape:tuple[sint, ...], lb:UOp) -> tuple[sint, ...]:
return tuple(lb.shape[axis] if a == axis else s for a,s in enumerate(shape))
def reshape_multi(root:UOp, multi:UOp):
arg = root.arg
if (new_axis:=root.axis) is None: return multi.src[0].reshape(arg).multi(new_axis)
assert prod(multi.shape) == prod(arg), "reshape must maintain prod(shape)"
assert prod(multi.src[0].shape[multi.axis:])%prod(arg[new_axis+1:]) == 0, f"reshape cannot move items between shards {multi.shape} -> {root.arg=}"
new_shape_axis = prod(multi.src[0].shape[multi.axis:]) // prod(arg[new_axis+1:])
return multi.src[0].reshape(tuple(s if a!=new_axis else new_shape_axis for a,s in enumerate(arg))).multi(new_axis)
def expand_multi(root:UOp, multi:UOp):
# NOTE: this assert isn't needed, sharded axis can have dim 1
assert multi.axis is None or root.arg[multi.axis] == multi.shape[multi.axis], f"expand not supported on sharded axis {root.arg=}"
return multi.src[0].expand(_shape_to_single_shard(multi.axis, root.arg, multi.src[0])).multi(multi.axis)
def pad_multi(root:UOp, multi:UOp):
assert multi.axis is None or root.arg[multi.axis] == (0,0), f"padding not supported for {root.arg=}"
return multi.src[0].pad(root.arg).multi(multi.axis)
def permute_multi(root:UOp, multi:UOp):
# all permutes supported!
return multi.src[0].permute(root.arg).multi(root.axis)
def shrink_multi(root:UOp, multi:UOp):
assert multi.axis is None or root.arg[multi.axis] == (0, multi.shape[multi.axis]) or root.arg[multi.axis] in multi.bounds, \
f"shrinking not supported for {root.arg=}"
if multi.axis is not None and root.arg[multi.axis] in multi.bounds and root.arg[multi.axis] != (0, multi.shape[multi.axis]):
assert all(root.arg[i] == (0, s) or i == multi.axis for i,s in enumerate(multi.shape)), \
"cannot shrink sharded and non-sharded axis at the same time"
# NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real
# we just copy it to all the devices, no real. this will be optimized out later
return multi.src[0].copy_to_device(multi.device, arg=multi.bounds.index(root.arg[multi.axis]))
return multi.src[0].shrink(tuple((0, multi.src[0].shape[multi.axis]) if a == multi.axis else s for a,s in enumerate(root.arg))).multi(multi.axis)
def flip_multi(root:UOp, multi:UOp):
assert multi.axis is None or not root.arg[multi.axis], "flipping not supported on sharded axis"
return multi.src[0].flip(root.arg).multi(multi.axis)
# from multiple devices -> one
def copy_multi(multi:UOp, device:UOp):
assert multi.axis is not None, "all multi ops have axis"
return multi.src[0]._unshard(multi.axis).allreduce(Ops.ADD, device)
def assign_multi(dest:UOp, src:UOp):
if dest.axis != src.axis: raise RuntimeError(f"axis must match in assign {dest.axis} != {src.axis}")
return dest.src[0].assign(src.src[0]).multi(src.axis)
def passthrough_multi(root:UOp, multi:UOp):
return root.replace(src=(multi.src[0],)).multi(multi.axis)
# NOTE: this is the same pattern as Ops.UNROLL
multi_pm = PatternMatcher([
(UPat(GroupOp.ALU, name="root", custom_early_reject=set([Ops.MULTI])), alu_multi),
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reduce_multi),
(UPat(Ops.RESHAPE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reshape_multi),
(UPat(Ops.EXPAND, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), expand_multi),
(UPat(Ops.PAD, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), pad_multi),
(UPat(Ops.PERMUTE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), permute_multi),
(UPat(Ops.SHRINK, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), shrink_multi),
(UPat(Ops.FLIP, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), flip_multi),
(UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi),
(UPat(Ops.COPY, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device"))), copy_multi),
(UPat(Ops.ALLREDUCE, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device")), name="red"),
lambda multi,device,red: multi.src[0].allreduce(red.arg, device).multi(axis=multi.axis)),
(UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE),
src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
])+replace_allreduce

View File

@@ -0,0 +1,615 @@
from typing import Any, cast
import functools, operator
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify, graph_rewrite_map
from tinygrad.uop.symbolic import sym, symbolic_simple
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context, flatten, dedup
from tinygrad.schedule.multi import multi_pm
from tinygrad.schedule.kernelize import Kernel
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType
# *****************
# 0. do some cleanup rewrites, mostly copied from the old stuff
double_reshape = PatternMatcher([
# RESHAPE on RESHAPE is the second reshape
(UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE),), name="x"), lambda x: x.replace(src=(x.src[0].src[0],))),
])
earliest_rewrites = double_reshape+PatternMatcher([
# non shape changing RESHAPE is NOOP
#(UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0] if x.src[0].shape == x.arg else None),
# DETACH and CONTIGUOUS_BACKWARD are NOOPs here, so is FUSE
#(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="x"), lambda x: x.src[0].f(Ops.NOOP, tag=x.tag)),
# just removing it works...
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="x"), lambda x: x.src[0]),
# preserve tags?
# reduce of size 0 is the identity element
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
# COPY and source size need to match
# TODO: expand after copy creates issues with tagging
(UPat(Ops.COPY, src=(UPat(GroupOp.Movement, name="r"), UPat(name="d")), name="c"),
lambda c,r,d: c.replace(src=(r.contiguous(), d)) if r.size != r.base.size else None),
# assign only to buffer
(UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.BUFFER}, name="target"), UPat(name="x")), name="assign"),
lambda x,target,assign: x.f(Ops.NOOP, tag=assign.tag) if target.base.op is not Ops.BUFFER else None),
# handle disk
# TODO: this doesn't need to use st.views
(UPat.var("x").f((Ops.BITCAST, Ops.CONTIGUOUS), name="t"),
lambda x,t: UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (t.size, x.st.views[0].offset), tag=t.tag).reshape(t.shape) if isinstance(x.device, str) \
and x.device.startswith("DISK") else None),
# contiguous/buffer/copy/assign is already contiguous
#(UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY, Ops.ASSIGN)),)), lambda root: root.src[0]),
])
# *****************
# 1. add realize where we have to
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL,
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD}
def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None
def realize_parents(ctx:dict[UOp, None], rb:UOp) -> None:
for s in rb.src:
if s.base.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
def realize_assign(ctx:dict[UOp, None], a:UOp) -> None:
if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None
do_realize = PatternMatcher([
# always realize SINK parents
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
# always realize ASSIGN/COPY/BUFFER_VIEW/CONTIGUOUS
(UPat({Ops.ASSIGN, Ops.COPY, Ops.BUFFER_VIEW, Ops.CONTIGUOUS}, name="tr"), realize),
# realize parents of COPY, MSELECT, MSTACK
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_parents),
# realize input to assign (might be optimized out)
(UPat(Ops.ASSIGN, name="a"), realize_assign),
])
class WrappedContig:
def __init__(self, x): self.x = x
def __repr__(self): return f"C({self.x})"
add_contiguous = PatternMatcher([
(UPat(GroupOp.All, name="x"),
lambda ctx,x: x.replace(tag=WrappedContig(x.tag)).realize() if x in ctx and not isinstance(x.tag, WrappedContig) else None),
])
remove_contig_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=x.tag.x) if isinstance(x.tag, WrappedContig) else None)])
# *****************
# 2. mark all children
@dataclass
class ChildrenContext: children: dict[UOp, list[UOp]]|None = None
def extract_children(ctx:ChildrenContext, x:UOp):
if ctx.children is not None: return
children_map = x.get_children_map()
ctx.children = {}
for k,v in children_map.items():
non_sink_children = [u for u in v if u.op is not Ops.SINK]
if len(non_sink_children) <= 1: continue
# NOTE: this gate shouldn't be here
if any(x.op is Ops.REDUCE_AXIS for x in k.toposort()) and any(x.op in {Ops.BUFFER, Ops.CONTIGUOUS} for x in k.toposort()):
ctx.children[k] = non_sink_children
def mark_children(ctx:ChildrenContext, x:UOp):
assert ctx.children is not None
new_srcs = [(UOp(Ops.CHILD, s.dtype, src=(UOp(Ops.CHILDREN, s.dtype, (s,), arg=len(ctx.children[s])),),
arg=(ctx.children[s].index(x), len(ctx.children[s]))) if s in ctx.children else s) for s in x.src]
return x.replace(src=tuple(new_srcs))
pm_children = PatternMatcher([
(UPat(Ops.SINK, name="x"), extract_children),
(UPat(GroupOp.All-{Ops.CHILD, Ops.CHILDREN, Ops.SINK}, name="x"), mark_children),
])
# *****************
# 3a. rangeify (movement)
@dataclass
class RangeifyContext:
# block on parent until all children have been seen
seen_children: dict[UOp, dict[int, UOp]] = field(default_factory=dict)
seen_child: dict[UOp, Any] = field(default_factory=dict)
progress: int = 0
# create ranges
range_idx: int = 0
def new_range(self, s:sint, axistype:AxisType=AxisType.LOOP):
ret = UOp.range(s, self.range_idx, axistype)
self.range_idx += 1
return ret
def map_reshape(idx:UOp, r:UOp):
acc = 1
to_sum = []
for s,src in list(zip(idx.shape, idx.src[1:]))[::-1]:
to_sum.append(acc*src)
acc *= s
mish = sum(to_sum, start=UOp.const(dtypes.index, 0))
ret:list[UOp] = []
for s in r.src[0].shape[::-1]:
ret.append(mish % s) # NOTE: simplify will turn this to CONST
mish //= s
tret = ret[0].sink(*ret[1:]).simplify().src[::-1] if len(ret) else ()
return r.src[0].index(*tret, dtype=idx.dtype, arg=idx.arg)
def map_pad(idx:UOp, r:UOp):
ret = list(idx.src[1:])
bigwhere = UOp.const(dtypes.bool, True)
for i,(sh,(s,e)) in enumerate(zip(r.shape, r.arg)):
if s == 0 and e == 0: continue
where = UOp.const(dtypes.bool, True)
if resolve(e > 0): where = where & (ret[i] < (sh-e))
if resolve(s > 0): where = where & (ret[i] >= s)
bigwhere = bigwhere & where
with Context(TRACK_MATCH_STATS=0):
ret[i] = graph_rewrite(where.where(ret[i]-s, UOp.invalid()), sym)
# PAD is with 0
return bigwhere.simplify().where(r.src[0].index(*ret, dtype=idx.dtype, arg=idx.arg), UOp.const(r.dtype, 0))
def map_expand(r:UOp, idx:UOp):
new_rngs = []
ending_ranges = []
non_ending_ranges = []
for a,x,y in zip(idx.src[1:], r.src[0].shape, r.shape):
axis_to_range = [u for u in a.toposort() if u.op is Ops.RANGE]
if resolve(x==y, False):
non_ending_ranges.extend(axis_to_range)
new_rngs.append(a)
else:
ending_ranges.extend(axis_to_range)
new_rngs.append(a.const_like(0))
ending_ranges = [x.arg for x in ending_ranges if x not in non_ending_ranges]
if idx.arg is not None: ending_ranges.append(idx.arg)
return r.src[0].index(*new_rngs, arg=min(ending_ranges) if ending_ranges else None)
pm_mops = PatternMatcher([
# this is like the definitions of these
(UPat(Ops.SHRINK, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
lambda r,idx: r.src[0].index(*[a+ss if resolve(ss != 0) else a for a,(ss,_) in zip(idx.src[1:], r.arg)], dtype=idx.dtype, arg=idx.arg)),
(UPat(Ops.PERMUTE, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
lambda r,idx: r.src[0].index(*[idx.src[1+p] for p in argsort(idx.src[0].arg)], dtype=idx.dtype, arg=idx.arg)),
(UPat(Ops.FLIP, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
lambda r,idx: r.src[0].index(*[((s-1)-a) if f else a for a,s,f in zip(idx.src[1:], r.shape, r.arg)], dtype=idx.dtype, arg=idx.arg)),
# expand needs to end ranges
(UPat(Ops.EXPAND, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), map_expand),
# reshape does a lot of symbolic stuff
(UPat(Ops.RESHAPE, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), map_reshape),
# pad adds min and max
(UPat(Ops.PAD, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), map_pad),
])
# *****************
# 3b. rangeify (ops)
# bufferization can happen in three ways
# 1. there's an explicit REALIZE in the graph
# 2. the ranges from the children don't match and we have to create a buffer (only on children)
# 3. might_end_axis triggers because we should be closing a loop to save compute
@dataclass(frozen=True)
class BufferizeOpts:
# on AddrSpace.LOCAL, device is the id
device: str|tuple[str, ...]|int|None
addrspace: AddrSpace = AddrSpace.GLOBAL
def map_partial_realize(ctx:RangeifyContext, x:UOp, idx:UOp):
if x.arg is None: return None # map_contiguous can handle this
# NOTE: all partial contiguous can safely be replaced by full contiguous. we should be able to match old functionality like this
if not (RANGEIFY > 1): return idx.replace(src=(x.replace(arg=None),)+idx.src[1:])
ranges = []
new_ranges = []
passthrough_idx = []
for i,s in enumerate(x.shape):
if i not in x.arg:
ranges.append(idx.src[1+i])
continue
passthrough_idx.append(idx.src[1+i])
ranges.append(ctx.new_range(s))
new_ranges.append(ranges[-1])
# TODO: this should be able to be global or local
ret = x.src[0].index(*ranges).bufferize(*[x for x in new_ranges if x.op is not Ops.CONST],
arg=BufferizeOpts(device=None, addrspace=AddrSpace.LOCAL))
return ret.index(*passthrough_idx)
def map_realize(ctx:RangeifyContext, x:UOp):
if x.arg is not None: return None
ranges = [ctx.new_range(s) for s in x.shape]
return x.src[0].index(*ranges).bufferize(*x.src[1:], *ranges, arg=BufferizeOpts(device=x.device), tag=x.src[0].tag)
def map_reduce(ctx:RangeifyContext, idx:UOp, red:UOp):
rngs = list(idx.src[1:])
new_ranges = []
for i,s in enumerate(red.src[0].shape):
if i in red.arg[1]:
rngs[i] = ctx.new_range(s, axistype=AxisType.REDUCE)
new_ranges.append(rngs[i])
return UOp(Ops.REDUCE, red.dtype, src=(red.src[0].index(*rngs),)+tuple(new_ranges), arg=red.arg[0], tag=red.tag)
def index_child(ctx:RangeifyContext, c:UOp, x:UOp, idx:UOp):
if c not in ctx.seen_children: ctx.seen_children[c] = {}
# wait here until we have seen all the children
if len(ctx.seen_children[c]) != x.arg[1]:
ctx.progress += 1
if ctx.progress > 10000: raise RuntimeError("children not making progress")
# NOTE: we mark this here
ctx.seen_children[c][x.arg[0]] = idx
raise RewriteNotReady
ctx.progress = 0
if c not in ctx.seen_child:
all_rngs = list(zip(*[ch.src[1:] for ch in ctx.seen_children[c].values()]))
out_rngs = []
end_ranges = []
idx_ranges = []
# NOTE: locals aren't working, so we only fully bufferize here (unless RANGEIFY > 1)
all_all_same = all(all_same(r) for r in all_rngs)
for i,valid_rngs in enumerate(all_rngs):
rngs, valids = zip(*[(r.get_idx(), r.get_valid()) for r in valid_rngs])
# we compare the ranges without their valids
if all_same(rngs) and (all_all_same or RANGEIFY > 1):
# the new valid is the OR of all the children valids
minimum_valid = functools.reduce(operator.or_, valids, UOp.const(dtypes.bool, False))
out_rngs.append(minimum_valid.where(rngs[0], UOp.invalid()).simplify())
else:
out_rngs.append(ctx.new_range(c.shape[i]))
end_ranges.append(out_rngs[-1])
idx_ranges.append(i)
ctx.seen_child[c] = (out_rngs, idx_ranges, end_ranges)
else:
out_rngs, idx_ranges, end_ranges = ctx.seen_child[c]
for i,nr in zip(idx_ranges, end_ranges): out_rngs[i] = nr
# index based on the shared ranges
ret = c.index(*out_rngs)
# if all ranges aren't the same between children, we have to bufferize
if len(idx_ranges) > 0:
if len(idx_ranges) == len(out_rngs):
# this is a global bufferize
ret = ret.bufferize(*end_ranges, arg=BufferizeOpts(device=x.device))
else:
assert RANGEIFY > 1, "this isn't supported with RANGEIFY=1"
ret = ret.bufferize(*end_ranges, arg=BufferizeOpts(device=None, addrspace=AddrSpace.LOCAL))
ret = ret.index(*[idx.src[1+i] for i in idx_ranges])
return ret
def children_gate(ctx:RangeifyContext, idx:UOp, c:UOp):
if len(ctx.seen_children[c]) != c.arg: raise RuntimeError("all children should have been seen by now")
return idx.replace(src=(idx.src[0].src[0],)+idx.src[1:])
def might_end_axis(idx:UOp):
if idx.arg is None: return None
# TODO: write a proper cost function here
if all(x.op not in {Ops.BUFFER, Ops.REALIZE, Ops.BUFFERIZE} for x in idx.toposort()): return None
if all(x.op not in {Ops.REDUCE_AXIS} for x in idx.toposort()): return None
to_end_axis = []
for i,a in enumerate(idx.src[1:]):
if any(x.arg > idx.arg for x in a.toposort() if x.op is Ops.RANGE):
to_end_axis.append(i)
if to_end_axis: return idx.replace(src=(idx.src[0].realize(arg=tuple(to_end_axis)),)+idx.src[1:], arg=None)
return idx.replace(arg=None)
def unprocessed_index(x:UOp): raise RuntimeError(f"unprocessed index on {x.src[0].op}")
pm_rangeify = pm_mops+PatternMatcher([
# sink contigs to kick it off
(UPat(Ops.REALIZE, src=(UPat(),), name="x", allow_any_len=True), map_realize),
# if there's an INDEX it can support partial contig
(UPat(Ops.INDEX, src=(UPat(Ops.REALIZE, src=(UPat(),), name="x"),), allow_any_len=True, name="idx"), map_partial_realize),
# if there are new ended children, tag the SINK
(UPat(Ops.INDEX, src=(UPat(Ops.CHILD, src=(UPat(name="c"), ), name="x"),), allow_any_len=True, name="idx"), index_child),
(UPat(Ops.INDEX, src=(UPat(Ops.CHILDREN, name="c"),), allow_any_len=True, name="idx"), children_gate),
# if we come across this, remove it. it was a CHILD unused in an INDEX
(UPat(Ops.CHILD, src=(UPat(Ops.CHILDREN, src=(UPat.var("x"),)),)), lambda x: x),
# CONST (or DEFINE_VAR) can't have axes. remove INDEX when we get here
(UPat(Ops.INDEX, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"),)), lambda c: c),
# handle arg on any op with weight. old endrange stuff
(UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union({Ops.REDUCE_AXIS})),), allow_any_len=True, name="idx"), might_end_axis),
# handle size 0
(UPat(Ops.INDEX, name="x"), lambda x: x.replace(src=(x.const_like(0),)+x.src[1:]) if x.st is not None and x.size == 0 else None),
# handle assign
(UPat(Ops.INDEX, src=(UPat(Ops.ASSIGN, name="assign"),), allow_any_len=True, name="x"),
lambda x,assign: assign.replace(src=tuple([s.index(*x.src[1:]) for s in assign.src])+(assign.src[0],))),
# move MAP through elementwise ALU / reduce. these are the items with cost
(UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union(
{Ops.STORE, Ops.COPY, Ops.BUFFER_VIEW, Ops.DEVICE, Ops.BIND, Ops.CONTIGUOUS, Ops.NOOP})),), allow_any_len=True, name="x"),
lambda x: x.src[0].replace(src=tuple([s.index(*x.src[1:]) for s in x.src[0].src]))),
(UPat(Ops.INDEX, src=(UPat(Ops.REDUCE_AXIS, name="red"),), allow_any_len=True, name="idx"), map_reduce),
# assert if there's any index we didn't process
(UPat(GroupOp.All-{Ops.REALIZE, Ops.BUFFERIZE}).f(Ops.INDEX, name="x"), unprocessed_index),
])
# *****************
# 3.5 cleanups
# you don't know in the first pass if axes are going to die, this happens if there's an EXPAND to the left
def cleanup_dead_axes(b:UOp):
new_rng = []
hit = False
reshape: list[sint] = []
for s,rng in zip(b.shape, b.src[1:]):
if rng not in b.src[0].sparents and rng.op is Ops.RANGE:
reshape.append(1)
hit = True
else:
reshape.append(s)
new_rng.append(rng)
if hit:
return b.replace(src=b.src[0:1]+tuple(new_rng)).reshape(tuple(reshape)).expand(b.shape)
# if a buffer is being stored just for permutes or something, remove it
# we want to reexpress the indexes of idx2 in terms of the implied b1
def remove_bufferize(src:UOp, buf:UOp, idx:UOp):
# see if we can't do it, should this ever hit?
assert len(buf.src) == len(idx.src), "index on wrong bufferize"
assert all(x.op is Ops.RANGE for x in buf.src[1:])
# if it's user contiguous, we never remove it
if src.op is Ops.CONTIGUOUS: return None
# here is where we compute the cost
# for now just no REDUCE, COPY, or ASSIGN
ran = src.toposort(gate=lambda x: x.op not in {Ops.INDEX})
# we don't want to bufferize threefry, also causes problems because not all platforms support long
if any(x.op in {Ops.REDUCE, Ops.COPY, Ops.BUFFER_VIEW, Ops.ASSIGN} for x in ran) and src.op is not Ops.THREEFRY: return None
# simple, matching old behavior
#if src.op is not Ops.INDEX: return None
# this is the ranges replaced
return src.substitute(dict(zip(buf.src[1:], idx.src[1:])))
def pre_bufferize(b:UOp, x:UOp, copy:UOp):
nb = b.replace(src=(b.src[0].contiguous(),)+b.src[1:])
return copy.replace(src=(x.replace(src=(nb,)+x.src[1:]), copy.src[1]))
pm_cleanups = double_reshape+pm_mops+PatternMatcher([
#(UPat(Ops.BUFFERIZE, name="b"), cleanup_dead_axes),
# remove noop buffers. if we look at the next index we can remove even more of these
# NOTE: this is mostly the same case as below, but if there's no INDEX this gets more
(UPat(Ops.INDEX, name="idx").f(Ops.BUFFERIZE, allow_any_len=True, name="b2"),
lambda idx,b2: idx.src[0].replace(tag=nt if len(nt:=(idx.src[0].tag or ()) + (b2.tag or ())) else None) if idx.src[1:] == b2.src[1:] \
and idx.src[0].op is not Ops.BUFFER_VIEW else None),
# remove reindexing with cost function
(UPat.var("src").f(Ops.BUFFERIZE, allow_any_len=True, name="buf").f(Ops.INDEX, allow_any_len=True, name="idx"), remove_bufferize),
# no buffers for const
(UPat(Ops.CONST, name='c').f(Ops.BUFFERIZE, allow_any_len=True, name="b"),
lambda c,b: c.reshape((1,)*len(b.shape)).expand(b.shape).replace(tag=b.tag)),
# if any CONST with DEVICE make it here (symbolic/copy issue), remove it
#(UPat(Ops.DEVICE).f(Ops.CONST, name="c"), lambda c: c.replace(src=())),
# copy on CONST is CONST
(UPat(Ops.COPY, src=(UPat.cvar("x"), UPat()), name="copy"), lambda copy,x: copy.const_like(x.arg)),
(UPat(Ops.COPY, src=(UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.COPY}).f(Ops.BUFFERIZE, allow_any_len=True, name="b")
.f(Ops.INDEX, allow_any_len=True, name="x"), UPat()), name="copy"), pre_bufferize),
])
# *****************
# 4. put in buffers for bufferize
# TODO: should BUFFERIZE look a lot more like STORE
# BUFFERIZE has device in arg
# BUFFERIZE doesn't have indexing, that's implied by the ranges it closes
# BUFFERIZE returns the BUFFER ready for INDEXing (doing this will make splitting a lot easier)
# NOTE: this has been fixed up a bit
def bufferize_to_store(x:UOp):
rngs = x.src[1:]
shape = tuple([int(r.vmax+1) for r in rngs])
sym_shape = tuple([ssimplify(r.src[0]) for r in rngs])
size = prod(shape)
assert size > 0, f"no zero sized buffers {shape}"
sdtype = x.dtype.ptr(size=size, addrspace=x.arg.addrspace)
if x.src[0].op is Ops.ASSIGN:
assign_target, assign_src, assign_mops = x.src[0].src
assert assign_target.op is Ops.INDEX
# in assign, this is the buffer size, not the bufferize size
# TODO: assign_mops here
ret = assign_target.replace(dtype=sdtype).store(assign_src, *rngs, dtype=x.dtype)
mops = []
walk = assign_mops
while walk is not assign_mops.base:
mops.append((walk.op, walk.arg))
walk = walk.src[0]
for m in mops[::-1]: ret = ret._mop(*m)
return ret.forced_reshape(shape).replace(tag=x.tag)
# NOTE: the DEFINE_LOCAL needs to be disambiguated here
if sdtype.addrspace == AddrSpace.GLOBAL:
buf = UOp.new_buffer(x.arg.device, size, x.dtype)
ret = buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=x.dtype)
ret = ret.forced_reshape(shape)
# TODO: is this right? what if it's offset
if shape is not sym_shape: ret = ret.shrink(tuple([(0,x) for x in sym_shape]))
return ret.replace(tag=x.tag)
# handle locals
tag = x.arg.device
if tag is None: tag = UOp.unique().arg # TODO: hack
buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=tag)
# store has the other dtype here
# TODO: how is this unified?
return buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=sdtype).forced_reshape(shape, dtype=x.dtype)
pm_add_buffers = pm_mops+PatternMatcher([
(UPat(Ops.BUFFERIZE, name="x"), bufferize_to_store),
# move RESHAPEs through MSELECT/MSTACK
(UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"),
lambda m: m.replace(src=tuple([x.src[0] for x in m.src])).reshape(m.src[0].arg)),
])
# *****************
# 5. split into kernels
@dataclass
class LocalAddBufferContext:
dg:int = 0
map:dict = field(default_factory=dict)
vars:dict = field(default_factory=dict)
range:int = 0
def debuf(ctx:LocalAddBufferContext, buf:UOp):
ret = UOp(Ops.DEFINE_GLOBAL, buf.dtype.ptr(buf.arg), arg=ctx.dg)
if buf not in ctx.map: ctx.map[buf] = buf
ctx.dg += 1
return ret
def unbind_kernel(ctx:LocalAddBufferContext, b:UOp):
ctx.vars[b] = None
return b.src[0]
def handle_assign(ctx:LocalAddBufferContext, assign:UOp):
buf = assign.as_buf()
# HACK to put the buffer in the MAP instead of MSTACK/MSELECT
if buf.op in {Ops.MSTACK, Ops.MSELECT}: buf = buf.src[0]
assert buf not in ctx.map
ctx.map[buf] = assign
return buf
def renumber_range(ctx:LocalAddBufferContext, r:UOp):
if r.tag is not None: return None
ret = r.replace(arg=(ctx.range,)+r.arg[1:], tag=())
ctx.range += 1
return ret
to_define_global = PatternMatcher([
(UPat(Ops.BUFFER, name="buf"), debuf),
(UPat(Ops.BIND, name="b"), unbind_kernel),
(UPat((Ops.ASSIGN, Ops.MSTACK, Ops.MSELECT), name="assign"), handle_assign),
# HACK in case any CONSTs were replaced
# this is only needed if you are using symbolic
(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"), lambda c: c.replace(src=()) if len(c.src) else None),
# renumber the ranges starting with 0 so that kernel deduping works
(UPat(Ops.RANGE, name="r"), renumber_range),
])
rangeify_codegen = PatternMatcher([
# no NOOP in the kernel graph
# TODO: this can be moved into codegen?
(UPat((Ops.NOOP, Ops.CONTIGUOUS), name="x"), lambda x: x.src[0]),
# strip the arg from store
(UPat(Ops.STORE, name="x"), lambda x: x.replace(arg=None) if x.arg is not None else None),
# add loads to non ptr indexes
# TODO: this can be moved into codegen?
(UPat((Ops.DEFINE_GLOBAL, Ops.STORE), name="dg").f(Ops.INDEX, name="idx", allow_any_len=True),
lambda dg,idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else idx.replace(dtype=dg.dtype, arg=None).load()),
# TODO: this can be moved into codegen
(UPat(Ops.STORE, name="store").f(Ops.INDEX, allow_any_len=True, name="idx").f(Ops.LOAD),
lambda store,idx: idx.replace(src=(store.as_buf(),)+idx.src[1:]).load(store if idx.dtype.addrspace != AddrSpace.LOCAL else store.barrier())),
# TODO: hack for group for reduce
(UPat(Ops.IF, src=(UPat.var("gate"), UPat(Ops.LOAD, src=(UPat.var("src"), UPat.var("barrier"))),)),
lambda src, barrier, gate: src.load(UOp(Ops.IF, src=(gate, barrier)))),
])
def split_store(ctx:list[UOp], x:UOp):
if len(x.ranges): return None
if x.src[0].ptrdtype.addrspace is AddrSpace.LOCAL: return None
# local kernel rewrite
lctx = LocalAddBufferContext()
ret = graph_rewrite(x, to_define_global+rangeify_codegen, ctx=lctx, name="kernel split", bottom_up=True)
# gather the metadata
metadatas = [ctx[y].metadata for x in ret.sparents if x.tag is not None for y in x.tag]
# NOTE: the hack for COPY is here
ret = ret.sink() if ret.src[1].op not in {Ops.COPY, Ops.BUFFER_VIEW} else ret.src[1]
kernel_arg = Kernel(ret,tuple(dedup(flatten([x for x in metadatas if x is not None]))))
kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg)
return x.as_buf().assign(kernel)
split_kernels = PatternMatcher([
(UPat(Ops.STORE, name="x"), split_store),
])
def tag_uop(ctx:list[UOp], x:UOp):
if x.tag is not None: return None
ctx.append(x)
return x.replace(tag=(len(ctx)-1,))
add_tags = PatternMatcher([
# don't tag BUFFERs, they are global
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.DEFINE_VAR, Ops.BIND}.union(GroupOp.Movement), name="x"), tag_uop),
])
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len([u for u in UOp.sink(*ret.values()).toposort() if u.op is Ops.KERNEL]))}", True)
def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
uop_list: list[UOp] = []
tsink = graph_rewrite(sink, add_tags, ctx=uop_list, bottom_up=True, name="number the uops")
# HACKS: handle multi with graph_rewrite_map in order to not have to add all the tag logic to multi
msink = graph_rewrite_map(tsink, multi_pm, name="multi")
tsink = msink[tsink].substitute({v:v.rtag(k.tag) for k,v in msink.items() if v.tag is None and k.tag is not None})
tsink = graph_rewrite(tsink, earliest_rewrites, name="earliest rewrites")
realize_map: dict[UOp, UOp] = {}
graph_rewrite(tsink, do_realize, ctx=realize_map, name="Input Graph")
# NOTE: we don't use contiguous here, contiguous is a user op
tsink = graph_rewrite(tsink, add_contiguous, ctx=realize_map, bottom_up=True, name="add realize")
tsink = graph_rewrite(tsink, remove_contig_tags, name="remove contiguous tags")
tsink = graph_rewrite(tsink, pm_children, ctx=ChildrenContext(), bottom_up=True, name="get children")
# rangeify
tsink = graph_rewrite(tsink, pm_rangeify, ctx=RangeifyContext(), bottom_up=True, name="rangeify")
# NOTE: sym (vs symbolic_simple) breaks things here because ranges with len 1 aren't handled right
tsink = graph_rewrite(tsink, symbolic_simple, name="symbolic") # this supports const folding
tsink = graph_rewrite(tsink, pm_cleanups, bottom_up=True, name="remove costly buffers")
# rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph
# if it's not tagged by here, it's out
tsink = UOp.sink(*[x for x in tsink.parents if (x.op is Ops.BUFFERIZE or x.base.op in {Ops.CONST}) and x.tag is not None])
if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Tagged Rangeify")
# bufferize -> store
tsink = graph_rewrite(tsink, pm_add_buffers, bottom_up=True, name="bufferize to store")
tsink = graph_rewrite(tsink, split_kernels, ctx=uop_list, name="split kernels")
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
kernel_assign: dict[UOp, UOp] = {}
assign_rep: dict[UOp, UOp] = {}
for u in tsink.toposort():
if u.op is not Ops.ASSIGN: continue
kernel_assign[u.buf_uop] = u
for s in u.src[1].src:
# TODO: this is probably broken for MSELECT/MSTACK
if s.op is not Ops.BUFFER or s is u.buf_uop or (a:=kernel_assign.get(s)) is None: continue
if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort()):
raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER")
assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
if assign_rep: tsink = graph_rewrite(tsink, _substitute, ctx=assign_rep, bottom_up=True, name="fix_assign")
if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph")
becomes_map: dict[UOp, UOp] = {}
for s in tsink.src:
assert s.tag is not None
for a in s.tag:
if a is None: continue
becomes_map[uop_list[cast(int, a)]] = s.replace(tag=None)
return becomes_map