Release 260111

Comma Device
2026-01-11 18:23:29 +08:00
commit 3721ecbf8a
2601 changed files with 855070 additions and 0 deletions


@@ -0,0 +1,131 @@
from typing import Any, Callable
import functools
from dataclasses import dataclass
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype
from tinygrad.uop.spec import type_verify
from tinygrad.renderer import Renderer
# import all pattern matchers here
from tinygrad.codegen.lowerer import pm_lowerer, get_index
from tinygrad.codegen.quantize import pm_quant
from tinygrad.codegen.gpudims import pm_add_gpudims
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing
from tinygrad.uop.decompositions import get_late_rewrite_patterns
from tinygrad.codegen.late.expander import migrate_indexing, expander, pm_pre_expander
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
ReduceContext, correct_load_store, pm_render
from tinygrad.codegen.late.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
from tinygrad.codegen.opt.swizzler import view_left, view_right, fix_kernel_ops
from tinygrad.codegen.opt.postrange import pm_postrange_opt
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_reduce_simplify, pm_flatten_range
from tinygrad.schedule.rangeify import pm_add_buffers, rangeify_codegen
@dataclass
class RewriteStep:
pm: PatternMatcher
ctx: Callable[[UOp], Any]|None = None
name: str|None = None
bottom_up: bool = False
def __call__(self, sink:UOp):
return graph_rewrite(sink, self.pm, ctx=self.ctx(sink) if self.ctx is not None else None, name=self.name, bottom_up=self.bottom_up)
def apply_rewrites(sink:UOp, rewrites:list[RewriteStep]): return functools.reduce(lambda x,f: f(x), rewrites, sink)
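# illustrative note (not from the original file): apply_rewrites is a plain left fold over the steps,
# so apply_rewrites(sink, [step_a, step_b]) is step_b(step_a(sink)), and apply_rewrites(sink, []) is sink.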
rewrites_for_views = [
RewriteStep(view_left, name="Main View Left"),
RewriteStep(view_right, name="Main View Right"),
RewriteStep(view_left+fix_kernel_ops, bottom_up=True, name="Finalize Kernel"),
]
rewrites_for_linearizer = [
RewriteStep(block_create, ctx=BlockContext.from_sink, name="Linearizer: Create Blocks", bottom_up=True),
RewriteStep(pm_blockend_merge, name="Linearizer: Merge Blockends"),
RewriteStep(block_merge, name="Linearizer: Merge Blocks"),
RewriteStep(pm_finalize, name="Linearizer: Finalize")]
def get_rewrites_for_renderer(opts:Renderer, optimize:bool=True, linearizer:bool=True) -> list[RewriteStep]:
# cache with the values of the context vars
return _get_rewrites_for_renderer(opts, optimize, linearizer, QUANTIZE.value, DEVECTORIZE.value, TRANSCENDENTAL.value)
@functools.cache
def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) -> list[RewriteStep]:
# ** lowerer (rewrite_shapetracker_with_index) **
ret: list[RewriteStep] = []
if optimize:
# view pushing
ret.extend(rewrites_for_views)
# lowerer first
if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize"))
ret.append(RewriteStep(pm_lowerer, get_index, name="lowerer", bottom_up=True))
# symbolic (NOTE: this is a requirement for pm_simplify_ranges to be correct)
ret.append(RewriteStep(sym+pm_flatten_range, name="initial symbolic"))
# optimize (schedule) the AST
ret.append(RewriteStep(pm_simplify_ranges, name="simplify ranges"))
ret.append(RewriteStep(pm_reduce_simplify, name="simplify reduces"))
ret.append(RewriteStep(pm_postrange_opt, ctx=lambda _: opts, name="post optimize ast"))
# ** expander (expand_rewrite) **
ret.append(RewriteStep(sym+migrate_indexing, name="postopt symbolic"))
# expand
ret.append(RewriteStep(sym+pm_pre_expander+expander, name="expander"))
# add locals
ret.append(RewriteStep(pm_add_buffers+rangeify_codegen, name="add local buffers"))
# ** devectorizer (full_graph_rewrite) **
# remove reduce
ret.append(RewriteStep(pm_reduce+gep_pushing, lambda _: ReduceContext(), name="remove_reduce"))
# add gpu dims (late). this works after devectorize, but it's faster here
ret.append(RewriteStep(pm_add_gpudims, lambda _: opts, name="add gpudims"))
# devectorize (TODO: does this need opts?)
if _DEVECTORIZE >= 2: pm_devectorize = sym+load_store_folding+load_store_indexing
elif _DEVECTORIZE: pm_devectorize = sym+devectorize+load_store_folding+correct_load_store+load_store_indexing
else: pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing
ret.append(RewriteStep(pm_devectorize, lambda _: opts, name="devectorize"))
supported_ops = tuple(opts.code_for_op.keys())
extra_matcher = opts.extra_matcher if opts.extra_matcher is not None else PatternMatcher([])
# lower the index dtype to a concrete int
ret.append(RewriteStep(pm_lower_index_dtype+load_store_indexing, lambda _: opts.device, name="lower all index dtypes"))
# optional pre matcher
if opts.pre_matcher is not None: ret.append(RewriteStep(opts.pre_matcher, name="pre_matcher"))
# decompositions
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, _TRANSCENDENTAL>=2)
ret.append(RewriteStep(pm_decomp, lambda _: opts.device, name="decompositions"))
# final rules for the renderer (without sym)
pm_final_rewrite = pm_decomp+pm_render+extra_matcher
ret.append(RewriteStep(pm_final_rewrite, lambda _: opts.device, name="final rewrite"))
# return the list (with optional linearizer)
return ret + (rewrites_for_linearizer if linearizer else [])
def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, optimize:bool=True, linearizer:bool=False) -> UOp:
return apply_rewrites(sink, get_rewrites_for_renderer(opts if opts is not None else Renderer(), optimize, linearizer))
def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]:
"""
Function to transform the Kernel UOp graph into a linearized program.
Args:
sink: The Ops.SINK rooting the Kernel graph.
opts: The Renderer (can change how things are processed, fix this).
Returns:
Linear program in UOps.
"""
lst = list(full_rewrite_to_sink(sink, opts, optimize=sink.tag is None, linearizer=True).arg.lst)
if __debug__: type_verify(lst)
return lst
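# usage sketch (not from the original file; `sink` is assumed to be an Ops.SINK UOp produced by the
# scheduler, and any concrete Renderer subclass can stand in for the default):
#   from tinygrad.renderer.cstyle import ClangRenderer
#   uops = full_rewrite(sink, ClangRenderer())   # linearized UOps, ready to be rendered to source code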


@@ -0,0 +1,94 @@
import math
from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType, sint_to_uop
from tinygrad.helpers import all_int, dedup
from tinygrad.dtype import dtypes
from tinygrad.shape.view import get_contraction
from tinygrad.renderer import Renderer
def _group_dims(dims:tuple[sint, ...], max_sizes:tuple[int, ...]):
# TODO: symbolic shape
if not all_int(dims): return dims
while len(dims) > len(max_sizes) or any(d > m for d,m in zip(dims, max_sizes)):
for i,m in enumerate(max_sizes):
if i < (len(dims)-1) and dims[i] * dims[i+1] <= m:
dims = dims[:i] + (dims[i]*dims[i+1],) + dims[i+2:]
break
else: return None
return dims
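# worked example (comment only, not in the original): when there are more dims than max_sizes, adjacent
# dims are merged left to right until everything fits, otherwise None signals that grouping failed:
#   _group_dims((2, 3, 4, 5), (32, 32, 32)) -> (6, 4, 5)   # 2 and 3 merged
#   _group_dims((64, 64), (16, 16))         -> None        # no adjacent pair fits under the limits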
def _split_dims(dims, max_sizes):
if all(d <= m for d,m in zip(dims, max_sizes)): return dims
_dims = list(dims) + [1]*(3-len(dims))
for i in range(len(_dims)):
while _dims[i] > max_sizes[i]:
div = next((d for d in range(2, math.ceil(math.sqrt(_dims[i])) + 1) if (_dims[i] % d) == 0), 1)
if div == 1: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
_dims[i], _dims[(i+1)%len(_dims)] = _dims[i]//div, _dims[(i+1)%len(_dims)]*div
return tuple(_dims[:2] if _dims[2] == 1 else _dims[0] if _dims[1:3] == [1,1] else _dims)
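# worked example (comment only, not in the original): an oversized dim is split by its smallest divisor,
# pushing the factor onto the next dim (wrapping around), then trailing 1s are dropped:
#   _split_dims((9,), (3, 3, 3))   -> (3, 3)   # 9 = 3*3
#   _split_dims((8, 2), (4, 4, 4)) -> (4, 4)   # 8 -> 4, the factor 2 lands on the second dim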
def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|None, reverse=False) -> list[UOp]:
if reverse: dims = dims[::-1]
# try to group first: (a, b, c, d) -> (ab, c, d)
limited = (grouped if (grouped := _group_dims(dims, max_sizes)) else dims) if max_sizes is not None else dims
# check if grouping failed
if max_sizes is not None and len(limited) > len(max_sizes): raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
# try to split up dims: (a,) -> (b, c)
if limited == dims: limited = _split_dims(dims, max_sizes) if max_sizes is not None else dims
ret = raw_idxs = [UOp(Ops.SPECIAL, dtypes.index, (sint_to_uop(s),), (f"{prefix}{i}")) for i,s in enumerate(limited)]
if len(limited) < len(dims):
ret = []
if (contraction:=get_contraction(dims, limited)) is None: raise AssertionError(f"get_contraction should not be None {dims=} {limited=}")
for idx, contraction_group in zip(raw_idxs, contraction):
for c in contraction_group[:-1]:
ret.append(idx % dims[c])
idx //= dims[c]
ret.append(idx)
elif len(limited) > len(dims):
a, b = len(limited), len(dims)
if a == 2 and b == 1: ret = [raw_idxs[0] * limited[1] + raw_idxs[1]]
if a == 3 and b == 1: ret = [raw_idxs[0] * (limited[1] * limited[2]) + raw_idxs[1] * limited[2] + raw_idxs[2]]
if a == 3 and b == 2: ret = [raw_idxs[0] * limited[1] + raw_idxs[1], raw_idxs[2]]
return ret[::-1] if reverse else ret
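# worked example (comment only, not in the original; assumes max_sizes=(16, 16)): dims (2, 3, 4) group
# to limited=(6, 4), so two SPECIALs are created and the first is decomposed back into the original axes:
#   raw_idxs = [gidx0 (size 6), gidx1 (size 4)]
#   ret      = [gidx0 % 2, gidx0 // 2, gidx1]   # before the optional reverse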
def add_gpudims(ctx:Renderer, s:UOp):
if s.arg is None: return None
s_topo = list(s.toposort())
if any(x.op is Ops.SPECIAL for x in s_topo): return None
# get ranges
all_ranges = {x.arg[0:-1]:x for x in s_topo if x.op is Ops.RANGE}
# extract global/local dims
global_dims = sorted(dedup([x.arg[0:-1] for x in all_ranges.values() if x.arg[-1] in (AxisType.GLOBAL, AxisType.THREAD)]))
local_dims = sorted(dedup([x.arg[0:-1] for x in all_ranges.values() if x.arg[-1] in (AxisType.WARP, AxisType.LOCAL, AxisType.GROUP_REDUCE)]))
if not global_dims and not local_dims: return None
# get global and local shape
ranges = [all_ranges[r] for r in global_dims+local_dims if r in all_ranges]
global_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg[0:-1] in global_dims])
local_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg[0:-1] in local_dims])
# get the idxs
ki: KernelInfo = s.arg
if ki.dont_use_locals:
assert not local_dims, "can't have local dims when dont_use_locals is set"
idxs = get_grouped_dims("idx", global_shape, ctx.global_max, reverse=True)
else:
# define indexes for GPU-like execution
idxs = get_grouped_dims("gidx", global_shape, ctx.global_max, reverse=True) + get_grouped_dims("lidx", local_shape, ctx.local_max)
# apply to multiple ranges
subs = {}
for r in s_topo:
if r.op is not Ops.RANGE: continue
try:
ii = (global_dims+local_dims).index(r.arg[0:-1])
if r.arg[1] == AxisType.REDUCE: continue
subs[r] = idxs[ii]
except ValueError: continue
return s.substitute(subs)
pm_add_gpudims = PatternMatcher([
# add gpudims must be last
(UPat(Ops.SINK, name="s"), add_gpudims),
])
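# sketch of the overall effect (descriptive comment, not in the original): for a kernel whose RANGEs are
# tagged GLOBAL (say size 256) and LOCAL (say size 16), add_gpudims substitutes SPECIAL("gidx0") of size 256
# and SPECIAL("lidx0") of size 16 for those ranges, while REDUCE ranges are left as loops.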


@@ -0,0 +1,306 @@
from typing import Any, cast
import functools, operator, itertools
from collections import defaultdict
from dataclasses import dataclass
from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace, Invalid
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite, GroupOp, identity_element
from tinygrad.uop.symbolic import uop_given_valid, parse_valid, sym, symbolic_flat, invalid_gate
from tinygrad.helpers import getenv, flatten, AMX, prod
from tinygrad.renderer import Renderer
# ***** image load valid simplification *****
def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
if (idx:=uop_given_valid(valid, start_idx)) is None: return buf.index(UOp.invalid())
if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx, valid)
# wait for it to be image indexed before running simplification
if start_idx.dtype.count != 2: return None
# can drop valid if idx is out of bound when valid is False
drop_stmt = []
for stmt in valid.split_uop(Ops.AND):
try: X, is_upper_bound, c = parse_valid(stmt)
except ValueError: return None
# for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i
if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in X.split_uop(Ops.ADD)):
testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), X.split_uop(Ops.ADD), idx)
testidx = testidx.simplify()
if testidx.gep(0).vmax < 0 or testidx.gep(1).vmax < 0:
drop_stmt.append(stmt)
continue
# if X <= c, check if it's out of bound when X = c+1
# if X >= c, check if it's out of bound when X = c-1
test_value = c + 1 if is_upper_bound else c - 1
for i,b in zip(idx.src, (buf.dtype.shape[1], buf.dtype.shape[0])):
if i.is_increasing():
rw = i.substitute({X:X.const_like(test_value)}).simplify()
if rw.vmin >= b or rw.vmax < 0:
drop_stmt.append(stmt)
break
if not drop_stmt and idx is start_idx: return None
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in valid.split_uop(Ops.AND) if s not in drop_stmt]) else None
return buf.index(idx, new_valid)
def delete_redundant_gates(store:UOp, buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None:
if store_gate not in [gate.src[0] for gate in val.toposort() if gate.op is Ops.IF]: return None
# remove the gate from the index
return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val, *store.src[2:])
load_store_indexing = PatternMatcher([
# image load valid idx simplification
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
# turn the invalid into a gate; must come before index dtype lowering
(UPat(Ops.INDEX, src=(UPat.var("buf"), invalid_gate,),), lambda buf,x,cond,i: buf.index(x, cond)),
# drop true gate
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x"), UPat.const(dtypes.bool, True)),), lambda buf,x: buf.index(x)),
# remove hanging cast
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.int).cast()),), lambda buf,idx: buf.index(idx)),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.int).cast(), UPat.var("valid"))), lambda buf,idx,valid: buf.index(idx, valid)),
# delete_redundant_gates (after expand)
(UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")),
UPat.var("val")), name="store", allow_any_len=True), delete_redundant_gates),
])
# ***** load/store grouping *****
def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
if getenv("UNSAFE_DISABLE_MASK", 0): mask = None
# generate the individual indexes
midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i), mask.gep(i) if mask is not None else None) for i in range(vec.dtype.count)]),
symbolic_flat+load_store_indexing, name=f"index_buf_{buf.arg}")
# extract all the relevant offsets
offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict)
for i in range(vec.dtype.count):
idx: Any = midx.src[i].src[1]
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
elif idx.op is Ops.CONST and idx.arg is Invalid: root_src, arg = "INVALID", 0
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
else: root_src, arg = idx, 0
if len(midx.src[i].src) == 3: root_src = (midx.src[i].src[2], root_src)
offsets_rootsrc[root_src].setdefault(arg, []).append(i)
# then rewrite everything we can into groups
ret = []
idxs: list[int|None] = [None]*vec.dtype.count
global_offset = 0
for offsets in offsets_rootsrc.values():
grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
for grp in grouped_offsets:
# get the index offset for this element. using [0] is okay, because they are the same
lidx = midx.src[offsets[grp[0]][0]]
if len(grp) > 1: lidx = lidx.cast(buf.ptrdtype.base.vec(len(grp)).ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace))
# set the idxs of the output
for i,g in enumerate(grp):
for oo in offsets[g]: idxs[oo] = global_offset+i
# add this lidx to the CAT
ret.append(lidx)
global_offset += len(grp)
assert None not in idxs, f"some idxs are missing {idxs}"
# this base thing is for image, we want the CAT to be a normal pointer
post_cat = UOp(Ops.PTRCAT, buf.ptrdtype.base.ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace).vec(vec.dtype.count), tuple(ret))
return post_cat.gep(tuple(cast(list[int], idxs)))
def cat_after_store(cat:UOp, data:UOp, sto:UOp):
# TODO: this is written in many places
offset = 0
ret: list[UOp] = []
for s in cat.src:
ret.append(s.store(data.gep(tuple(range(offset, offset+s.dtype.count))), *sto.src[2:]))
offset += s.dtype.count
return UOp(Ops.NOOP, src=tuple(ret))
def gep_on_store(gep:UOp, st:UOp, sto:UOp):
# NOTE: we need to invert the gep here, but it may be an expanding gep
# fake argsort. TODO: handle duplicates
a = {}
for i,x in enumerate(gep.arg): a[x] = i
new_arg = tuple(x[1] for x in sorted(a.items()))
return gep.src[0].store(st.gep(new_arg), *sto.src[2:])
load_store_folding = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat(GroupOp.Defines, name="buf")), UPat.var("vec"))), expand_index),
(UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat(GroupOp.Defines, name="buf")), UPat.var("vec"),
UPat.var("mask"))), expand_index),
# GEP after LOAD
(UPat(Ops.LOAD, src=(UPat(Ops.GEP, name="gep"),), name="ld", allow_any_len=True),
lambda gep, ld: ld.replace(dtype=ld.dtype.scalar().vec(gep.dtype.count), src=(gep.src[0],)+ld.src[1:]).gep(gep.arg)),
# GEP on data of STORE
(UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st")), allow_any_len=True, name="sto"), gep_on_store),
# put PTRCAT after LOAD
(UPat(Ops.LOAD, src=(UPat(Ops.PTRCAT, name="cat"),), name="ld", allow_any_len=True),
lambda cat,ld: UOp(Ops.CAT, ld.dtype, tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))),
# put PTRCAT after STORE
(UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data")), allow_any_len=True, name="sto"), cat_after_store),
])
# *** correct load/store ***
def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
# this splits loads and stores into multiple chunks
# if there's only one element to load/store, no splitting needed
if (sz:=ls.src[0].dtype.count) == 1: return None
buf = idx.src[0]
# determine fold lengths
lengths = []
must_divide = True
if ctx is not None and ctx.device == "DSP":
lengths = [128,64,32,16,8,4]
must_divide = False
elif buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType):
pass
elif buf.ptrdtype.addrspace == AddrSpace.REG:
pass
elif isinstance(buf.dtype, ImageDType):
lengths = [4]
elif ctx is not None and ctx.supports_float4:
# TODO: a better way to get this than ctx
lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2])
lengths.append(1) # worst case, it's not folded
# filter fold lengths that don't divide
if must_divide: lengths = [x for x in lengths if idx.src[1].divides(x) is not None]
# split based on the fold lengths
global_offset = 0
ret = []
while global_offset < sz:
# with 1 at the end of the lengths list, this will always hit
for fold_length in lengths:
if global_offset+fold_length > sz: continue
lidx = buf.index(idx.src[1] + global_offset, idx.src[2] if len(idx.src) > 2 else None)
if fold_length > 1: lidx = lidx.cast(buf.ptrdtype.base.vec(fold_length).ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace))
if ls.op is Ops.STORE: ret.append(ls.replace(src=(lidx,ls.src[1].gep(tuple(range(global_offset, global_offset+fold_length))))+ls.src[2:]))
else: ret.append(ls.replace(src=(lidx,)+ls.src[1:], dtype=ls.dtype.scalar().vec(fold_length)))
global_offset += fold_length
break
# if it wasn't split, we return None. otherwise we CAT them
if len(ret) <= 1: return None
return UOp(Ops.CAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp(Ops.NOOP, src=tuple(ret))
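# worked example (comment only, not in the original): on a non-DSP device with supports_float4 and AMX
# off, an 8-wide float LOAD through a divisible INDEX uses lengths [4, 2, 1] and is split into two
# float4 loads at offsets 0 and 4, which are then CAT'd back into the original vec(8) value.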
def image_fixup(ls:UOp):
# normal image load or store, with the CAST from expand_index
if ls.src[0].op is Ops.CAST and isinstance(image_dtype:=ls.src[0].src[0].dtype, ImageDType):
assert ls.src[0].dtype.count == 4, "image must be casted to 4"
idx = ls.src[0].src[0]
oidx = UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((idx.src[1] // 4) % image_dtype.shape[1], (idx.src[1] // (4*image_dtype.shape[1]))))
idx = idx.replace(src=(idx.src[0], oidx)+idx.src[2:])
return ls.replace(src=(idx,)+ls.src[1:])
# this is an unprocessed image without a cast, aka unfoldable image load. this doesn't work for stores
if isinstance(image_dtype:=ls.src[0].dtype, ImageDType) and ls.src[0].src[1].dtype != dtypes.int.vec(2):
assert ls.op is Ops.LOAD, "if an image store isn't upcasted to 4, we can't store it"
idx = ls.src[0]
id4 = idx.src[1] % 4
oidx = UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((idx.src[1] // 4) % image_dtype.shape[1], (idx.src[1] // (4*image_dtype.shape[1]))))
idx = idx.replace(src=(idx.src[0], oidx)+idx.src[2:])
vec_load = ls.replace(dtype=ls.dtype.vec(4), src=(idx,)+ls.src[1:])
return functools.reduce(lambda ret, i: id4.ne(i).where(ret, vec_load.gep(i)), range(4), ls.const_like(float('nan')))
return None
correct_load_store = PatternMatcher([
# split LOAD/STORE
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(Ops.INDEX, name="idx").cast(),), name="ls", allow_any_len=True), split_load_store),
# image indexing, including unfoldable images
(UPat((Ops.LOAD, Ops.STORE), name="ls"), image_fixup),
])
# *** uop expander ***
# TODO: there's a lot shared with gep_through_wmma here
def no_vectorized_wmma(wmma:UOp):
out_sz = prod(x[1] for x in wmma.arg[6][-1])
if wmma.dtype.count == out_sz: return None
tsrcs = []
for s,sz in zip(wmma.src, wmma.arg[6]):
ssz = prod(x[1] for x in sz)
tsrcs.append([s.gep(tuple(range(grp, grp+ssz))) for grp in range(0, s.dtype.count, ssz)])
wmmas = [UOp(Ops.WMMA, wmma.dtype.scalar().vec(out_sz), tsrc, wmma.arg) for tsrc in zip(*tsrcs)]
wmma_ex = flatten([[e.gep(i) for i in range(out_sz)] for e in wmmas])
return UOp(Ops.VECTORIZE, wmma.dtype, tuple(wmma_ex))
def no_vectorized_alu(alu:UOp):
if alu.dtype.vcount == 1: return None
alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount))
return UOp(Ops.VECTORIZE, alu.dtype, alus)
def no_vectorized_buf(buf:UOp):
return buf.replace(dtype=buf.ptrdtype.base.scalar().ptr(buf.ptrdtype.size*buf.ptrdtype.count, buf.ptrdtype.addrspace)).cast(buf.dtype)
def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp):
cnt = cast.dtype.count
assert idx.dtype.count == 1, f"idx dtype must be 1 {idx.dtype}"
return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.int.vec(cnt), tuple(range(cnt))))
devectorize = PatternMatcher([
# no ALU on vectorized dtypes
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name="alu"), no_vectorized_alu),
(UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf),
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
])
pm_render = PatternMatcher([
# for rendering, we use explicit VECTORIZE
(UPat(Ops.CONST, name='c'),
lambda c: UOp(Ops.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
(UPat(Ops.VCONST, name='c'), lambda c: UOp(Ops.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
(UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
(UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None),
(UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
# give any loads that are masked an alt value
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat())).or_casted(),), allow_any_len=True, name="x"),
lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:])
if len(x.src) == 1 or x.src[1].op in (Ops.CUSTOM, Ops.STORE, Ops.BARRIER) else None),
# gate any stores that aren't gated with ifs
(UPat(Ops.STORE, src=(UPat(src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="idx").or_casted(), UPat()), name="store", allow_any_len=True),
lambda store,idx: UOp(Ops.STORE, dtype=store.dtype, src=store.src[:2]+(UOp(Ops.IF, src=(idx.src[2],)),)+store.src[2:]) if \
len(store.src) <= 2 or store.src[2].op != Ops.IF else None),
])
# *** Ops.REDUCE -> Ops.DEFINE_ACC ***
@dataclass
class ReduceContext:
acc_num: int = 0
def horizontal_reduce(inp:UOp, out_dtype:DType) -> list[UOp]:
# if this has a horizontal reduction component, do that first
if inp.dtype != out_dtype:
# NOTE: [0 1 2 3 4 5 6 7] -> [0+4, 1+5, 2+6, 3+7]
horizontal_amount = inp.dtype.count//out_dtype.count
return [inp.gep(tuple(range(i, inp.dtype.count, horizontal_amount))) for i in range(0, horizontal_amount)]
return [inp]
def reduce_to_acc(ctx:ReduceContext, red:UOp):
inp, reduce_range = red.src[0], red.src[1:]
lst = horizontal_reduce(inp, red.dtype)
assert all(x.dtype == red.dtype for x in lst), f"horizontal reduction mismatch {lst[0].dtype} != {red.dtype}"
# if we have a range
if len(reduce_range) != 0:
topo = inp.toposort()
stored_ranges = flatten([x.src[2:] for x in topo if x.op is Ops.STORE])
input_ranges = tuple([x for x in topo if x.op is Ops.RANGE and x not in reduce_range and x not in stored_ranges])
identity = red.const(red.dtype, identity_element(red.arg, red.dtype.scalar()))
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=(ctx.acc_num,)).index(UOp.const(dtypes.int, 0))
do_store = acc.store(identity, UOp(Ops.NOOP, src=input_ranges)) if len(input_ranges) else acc.store(identity)
lst = [acc.load(do_store, *reduce_range)] + lst # put acc as the first element
ctx.acc_num += 1
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
return acc.load(acc.store(ret, *reduce_range)) if len(reduce_range) != 0 else ret
pm_reduce = PatternMatcher([
# REDUCE -> DEFINE_ACC+ASSIGN
(UPat(Ops.REDUCE, name="red"), reduce_to_acc),
# tensor core built in accumulate
(UPat(Ops.WMMA, name="wmma") + UPat.var("add"),
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
])+sym
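# sketch of reduce_to_acc's output (descriptive comment, not in the original): REDUCE(ADD, x, r) becomes,
# roughly,
#   acc = DEFINE_REG(size=1).index(0)
#   acc.store(identity)              # 0 for ADD, placed outside the loop
#   acc.store(acc.load() + x, r)     # accumulate inside range r
#   acc.load()                       # the reduced value after the range ends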


@@ -0,0 +1,162 @@
# this converts a lowerer program into a vectorized program
import functools, itertools, operator
from tinygrad.dtype import dtypes, PtrDType, AddrSpace
from tinygrad.helpers import AMX, dedup, flatten, all_same, prod, partition
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, AxisType, range_start
from tinygrad.schedule.rangeify import BufferizeOpts
def _expand_arg_to_idx(args:tuple[tuple[int, int], ...], rpk:dict[int, int]) -> int:
idx, mul = 0, 1
for axis,m in args[::-1]:
idx += rpk[axis] * mul
mul *= m
return idx
def _choices_from_args(args:tuple[tuple[int, int], ...]) -> list[dict[int, int]]:
return [dict(x) for x in itertools.product(*[zip(itertools.repeat(axis), range(m)) for axis,m in args])]
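# worked example (comment only, not in the original): the last axis in args is the fastest moving one:
#   _choices_from_args(((0, 2), (1, 3)))             -> 6 dicts, from {0:0, 1:0} up to {0:1, 1:2}
#   _expand_arg_to_idx(((0, 2), (1, 3)), {0:1, 1:2}) -> 1*3 + 2 = 5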
@functools.cache
def _swizzle_args(cargs:tuple[tuple[int, int], ...], eargs:tuple[tuple[int, int], ...], exclude_args:tuple[int, ...]) -> list[int]:
return [_expand_arg_to_idx(eargs, {**rpk, **{x:0 for x in exclude_args}} if exclude_args else rpk) for rpk in _choices_from_args(cargs)]
def do_expand(root:UOp):
expands = [x for x in root.src if x.op is Ops.UNROLL]
if len(expands) == 0: return None
# NOTE: we 0 out the reduce axis for WMMA. in theory they should all be the same, but is this always correct?
exclude_args = tuple(dedup(root.arg[-1] + tuple(y[0] for y in flatten(root.arg[-2])))) if root.op is Ops.WMMA else ()
if all_same(expands_args:=[x.arg for x in expands]) and len(exclude_args) == 0:
# if there's only one expand arg, it's okay to use it (optimization)
expand_args = expands[0].arg
else:
# otherwise, we sort them and GEP
expand_args = tuple(x for x in sorted(dedup(flatten(expands_args))) if x[0] not in exclude_args)
expand_sz = prod([x[1] for x in expand_args])
new_srcs = []
for i,src in enumerate(root.src):
if src.op is Ops.UNROLL:
if root.op is Ops.IF and i == 0:
# IF means OR on first arg to IF
new_srcs.append(functools.reduce(operator.__or__, [src.src[0].gep(i) for i in range(expand_sz)]))
elif expand_args == src.arg:
# just remove the expand
new_srcs.append(src.src[0])
else:
lst = _swizzle_args(expand_args, src.arg, exclude_args)
# if the base dtype is > 1, put those at the end
if src.dtype.count > 1: lst = flatten([[i*src.dtype.count+j for j in range(src.dtype.count)] for i in lst])
new_srcs.append(src.src[0].gep(tuple(lst)))
else:
# non-UNROLL input
if root.op is Ops.IF or src.op is Ops.IF:
# for the first arg of IF, just pass them through ignoring UNROLLS
new_srcs.append(src)
elif root.op in range_start and i >= range_start[root.op]:
# for any range args of STORE/REDUCE, pass them through
new_srcs.append(src)
elif root.op is Ops.INDEX and i >= 1 and not isinstance(root.dtype, PtrDType):
new_srcs.append(src)
elif src.dtype.count > 1:
# put any input dtype > 1 grouped together
new_srcs.append(UOp(Ops.CAT, src.dtype.scalar().vec(expand_sz*src.dtype.count), (src,)*expand_sz))
else:
# repeat the arg
new_srcs.append(src.broadcast(expand_sz))
new_arg = root.arg
if root.op is Ops.GEP:
assert root.dtype.count == 1
# is this right?
new_arg = tuple(range(root.arg[0], new_srcs[0].dtype.count, new_srcs[0].dtype.count // expand_sz))
nsrc = UOp(root.op, root.dtype.scalar().vec(root.dtype.count*expand_sz), tuple(new_srcs), new_arg)
return UOp(Ops.UNROLL, root.dtype, (nsrc,), expand_args)
def do_contract(con:UOp):
ex = con.src[0]
# CONTRACT without UNROLL repeats the element VECTORIZED
if ex.op is not Ops.UNROLL: return UOp(Ops.VECTORIZE, con.dtype, con.src*con.dtype.count)
# CONTRACT may remove several axes from UNROLL
assert con.dtype == dtypes.void or con.dtype.count == prod([x[1] for x in con.arg]), "dtype is wrong"
idxs = []
for rpk in _choices_from_args(new_ex_args:=tuple(x for x in ex.arg if x not in con.arg)):
idxs += [_expand_arg_to_idx(ex.arg, {**rpk, **lrpk}) for lrpk in _choices_from_args(con.arg)]
return UOp(Ops.UNROLL, con.dtype, (ex.src[0].gep(tuple(idxs)),), new_ex_args)
expander = PatternMatcher([
# double expand
(UPat(Ops.UNROLL, name="outer", src=(UPat(Ops.UNROLL, name="inner"),)),
lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
# do expansion
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.BUFFERIZE,
Ops.VECTORIZE, Ops.IF, Ops.REDUCE), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
(UPat(Ops.CONTRACT, name="con"), do_contract),
# BARRIERs aren't actually expanded
(UPat(Ops.BARRIER, src=(UPat(Ops.UNROLL, name="ex"),)),
lambda ex: UOp(Ops.UNROLL, src=(UOp(Ops.BARRIER, src=ex.src),)*len(ex.src), arg=ex.arg)),
# empty UNROLL is NOOP
(UPat(Ops.UNROLL, src=(UPat.var('x'),), arg=()), lambda x: x),
# UNROLL GEP (needed for WMMA, generalize this) -> vectorized ALU
(UPat(Ops.UNROLL, name="ex", src=tuple(UPat.var('x').gep(i)+UPat.var('y').gep(i) for i in range(256 if AMX else 8))),
lambda ex,x,y: UOp(Ops.UNROLL, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)),
])
def create_gate(root:UOp) -> UOp|None:
@functools.cache
def _gate_srcs(u:UOp, gate:UOp) -> UOp:
if u.op is Ops.BARRIER: return u
if u.op is Ops.LOAD and u.src[-1].op is Ops.BARRIER:
return UOp(u.op, u.dtype, u.src[:-1]+(UOp(Ops.IF, src=(gate, u.src[-1])),), arg=u.arg)
return u if (replace_source:=tuple(_gate_srcs(x, gate) for x in u.src)) == u.src else UOp(u.op, u.dtype, replace_source, u.arg)
idx = root.src[0]
if idx.op is Ops.CAST: idx = idx.src[0]
return None if idx.op is not Ops.INDEX or len(idx.src) == 2 or (ret:=_gate_srcs(root, idx.src[2])) is root else ret
migrate_indexing = PatternMatcher([
# create gate MUST BE BEFORE expander
(UPat(Ops.STORE, name="root"), create_gate),
])
# ****
def fix_reduce_unroll(x:UOp):
reduce_range, reduce_expand = partition(x.src[1:], lambda y: y.op is Ops.RANGE)
if len(reduce_expand) == 0: return None
reduce_expand = [x for x in reduce_expand if x.op is not Ops.CONST]
assert all(x.op is Ops.UNROLL for x in reduce_expand), f"not all UNROLLS in {reduce_expand}"
ret = x.src[0]
if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis), tag=1)
# REDUCE supports both "horizontal" reduction and range reduction. the horizontal elements are taken in the nearest group
return x.replace(src=(ret,)+tuple(reduce_range))
def fix_store_unroll(x:UOp):
store_expand, store_range = partition(x.src[2:], lambda y: y.op is Ops.UNROLL)
if len(store_expand) == 0: return None
return UOp(Ops.CONTRACT, dtypes.void, (x.replace(src=x.src[:2]+tuple(store_range)),), tuple(flatten(x.arg for x in store_expand)), tag=1)
def fix_group_for_reduce(x:UOp):
reduce_gfr, reduce_r = partition(x.src[1:], lambda u: u.op is Ops.RANGE and u.arg[1] == AxisType.GROUP_REDUCE)
if len(reduce_gfr) == 0: return None
# NOTE: if there's other locals here, we need them in the buffer too
upstream_locals = [u for u in x.toposort() if u.op is Ops.RANGE and u.arg[1] == AxisType.LOCAL]
# do only the non grouped reduces early
ret = x.replace(src=(x.src[0],)+tuple(reduce_r))
reduce_loop = [x.replace(arg=(x.arg[0]+100, AxisType.REDUCE)) for x in reduce_gfr]
buf = ret.bufferize(*upstream_locals, *reduce_gfr, arg=BufferizeOpts(reduce_gfr[0].arg[0], AddrSpace.LOCAL)).index(*upstream_locals, *reduce_loop)
# gate with an if on the store + do the final reduce
buf = UOp(Ops.IF, dtype=buf.dtype, src=(functools.reduce(operator.and_, [x.eq(0) for x in reduce_gfr]), buf))
return buf.reduce(*reduce_loop, arg=x.arg)
pm_pre_expander = PatternMatcher([
# rewrite UPCAST/UNROLL range to something to be expanded
(UPat(Ops.RANGE, name="r"),
lambda r: UOp(Ops.UNROLL, r.dtype, (UOp.const(r.dtype.vec(s:=r.vmax+1), tuple(range(s))),), ((r.arg[0],s),)) \
if r.arg[1] in {AxisType.UNROLL, AxisType.UPCAST} else None),
# fix REDUCEs with UNROLLs
(UPat(Ops.REDUCE, name="x"), fix_reduce_unroll),
(UPat(Ops.STORE, name="x"), fix_store_unroll),
# fix group for reduce
(UPat(Ops.REDUCE, name="x"), fix_group_for_reduce),
])


@@ -0,0 +1,243 @@
from __future__ import annotations
import heapq
from collections import defaultdict
from dataclasses import dataclass, replace
from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat, GroupOp, BottomUpGate
from tinygrad.helpers import dedup, all_same, flatten, BLOCK_REORDER
# NOTE: any toposort should be valid here, unlike last time this isn't required, it's just for speed
def block_reorder(lst:list[UOp]) -> list[UOp]:
in_this_block = set(lst)
local_children: defaultdict[UOp, list[UOp]] = defaultdict(list)
in_degree:dict[UOp, int] = {}
priorities:dict[UOp, int] = {}
# get local children and assign priorities
# NOTE: this requires the lst be locally toposorted
for u in reversed(lst):
in_degree[u] = 0
for s in u.src:
if s in in_this_block:
local_children[s].append(u)
in_degree[u] += 1
# put loads in the beginning of the block and prevent priority inversion. hack for BARRIER grouping too
priority = [0] + [priorities[x] for x in local_children[u]]
if u.op is Ops.LOAD: priority.append(-1000)
if u.op is Ops.BARRIER: priority.append(-1500)
priorities[u] = min(priority)
# number the uops in "ideal" order
nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: (priorities[x],)+x.tuplize))}
# then force them to be toposorted in as close to the ideal order as possible
heapq.heapify(heap:=[(nkey[u],u) for u in lst if in_degree[u] == 0])
newlst = []
while heap:
newlst.append(u:=heapq.heappop(heap)[1])
for v in local_children[u]:
in_degree[v] -= 1
if in_degree[v] == 0: heapq.heappush(heap, (nkey[v],v))
assert len(newlst) == len(lst), f"len mismatch {len(newlst)} != {len(lst)}"
return newlst
# ***** basic block *****
def disp(y:UOp) -> str:
if y.op is Ops.IF: return f'IF{id(y)}'
if y.op is Ops.RANGE: return str(y.arg)
return "<NONE>"
@dataclass(frozen=True, eq=False)
class BasicBlock:
lst: tuple[UOp, ...]
ctx: tuple[UOp, ...] = ()
end: UOp|None = None
cnt: int = 0
child_ctx: tuple[UOp, ...]|None = None
def __lt__(self, _:BasicBlock): raise RuntimeError("no comparing basic blocks")
def __repr__(self):
return f"{(str(disp(self.end))+' ') if self.end is not None else ''}"+f'f{self.cnt} '+\
f"{[disp(y) for y in self.ctx]} {[disp(y) for y in self.child_ctx] if self.child_ctx is not None else '-'} "+\
f"{len(self.lst)}" + "\n" + '\n'.join([str(x.op) for x in self.lst])
def last_ctx(self): return self.child_ctx if self.child_ctx is not None else self.ctx
def _sort_ctx(inp): return tuple(sorted(dedup(inp), key=lambda x: x.tuplize))
# ***** block context *****
@dataclass
class BlockContext:
child_count: dict[UOp, int]
block_ctxs: dict[UOp, tuple[UOp, ...]]
child_ctxs: dict[UOp, tuple[UOp, ...]]
def last_ctx(self, u): return self.child_ctxs.get(u, self.block_ctxs[u])
@staticmethod
def from_sink(sink:UOp) -> BlockContext:
# get children and all block contexts
ctx = BlockContext({}, {}, {})
for u in sink.toposort(gate=lambda u:u.op is not Ops.SPECIAL):
this_block_ctx: list[UOp] = []
ctx.child_count[u] = 0
# get children and accumulate the last_ctx
for s in u.src:
if s.op is Ops.SPECIAL: continue
# NOTE: if a parent appears multiple times in the src, it counts multiple times as a child
ctx.child_count[s] += 1
this_block_ctx += ctx.last_ctx(s)
# save the block ctx. SINK never has anything
ctx.block_ctxs[u] = _sort_ctx(this_block_ctx) if u.op is not Ops.SINK else ()
# RANGE/IF add to the next ctx
# STORE/ASSIGN subtract from the next ctx
if u.op in {Ops.RANGE, Ops.IF}: ctx.child_ctxs[u] = _sort_ctx(ctx.block_ctxs[u] + (u,))
elif u.op is Ops.STORE: ctx.child_ctxs[u] = tuple([y for y in ctx.block_ctxs[u] if y not in u.src])
return ctx
# ***** make blocks *****
DONT_PLACE_IN_BLOCK = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST}
def add_blockends(base_block:UOp, new_ctx:tuple[UOp, ...], current_ctx:tuple[UOp, ...], cnt:int=1) -> UOp:
ends_to_add = [z for z in new_ctx if z not in current_ctx]
while len(ends_to_add):
r:UOp = ends_to_add.pop(-1)
new_ctx = tuple([z for z in new_ctx if z is not r])
end_uop = UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,))
base_block = UOp(Ops.BLOCKEND, src=(base_block,)*cnt, arg=BasicBlock((end_uop,), tuple(new_ctx), end=r, cnt=cnt))
return base_block
def make_block_bottom_up(ctx:BlockContext, x:UOp):
if x.op is Ops.BLOCKSTART:
current_ctx, child_ctx = x.arg
lst = list(x.src)
child_count = 1
else:
current_ctx, child_count, child_ctx = ctx.block_ctxs[x], ctx.child_count[x], ctx.child_ctxs.get(x, None)
lst = [x]
# count of times we've seen this block, or a seed for a new block if we can't merge it
unmergable: defaultdict[UOp, int] = defaultdict(int)
blockseeds = defaultdict(list)
# add the srcs of this to the frontier
# NOTE: things may be in here multiple times, that's okay
frontier_nodes = list(flatten(y.src[::-1] for y in lst))
while len(frontier_nodes):
u = frontier_nodes.pop(0)
if u.op not in DONT_PLACE_IN_BLOCK and ctx.child_count[u] == unmergable[u]+1:
# count is correct
if (newctx:=ctx.block_ctxs[u]) == current_ctx:
# block has same context, merge it, and put the srcs on the frontier
lst.append(u)
frontier_nodes.extend(u.src[::-1])
else:
# block has different context, add it to blockseeds
blockseeds[(newctx, ctx.child_ctxs.get(u, None))].append(u)
del unmergable[u]
else:
# count is incorrect (or it's DONT_PLACE_IN_BLOCK), add it to unmergable
unmergable[u] += 1
# add unmergables to sources
srcs = []
for u,cnt in unmergable.items(): srcs += [add_blockends(u, ctx.block_ctxs.get(u,()), current_ctx, cnt=cnt)]*cnt
# add blockseeds, with blockends as needed
for (new_ctx, new_child_ctx), v in blockseeds.items():
base_block = UOp(Ops.BLOCKSTART, src=tuple(v), arg=(new_ctx, new_child_ctx))
srcs.append(add_blockends(base_block, new_ctx, current_ctx))
lst = lst[::-1]
if BLOCK_REORDER: lst = block_reorder(lst)
bb = BasicBlock(tuple(lst), ctx=current_ctx, cnt=child_count, child_ctx=child_ctx)
return UOp(Ops.BLOCK, src=tuple(srcs), arg=bb)
# we prevent the source of the SPECIAL from being linearized since it's not part of the kernel
def raise_bottom_up_gate(): raise BottomUpGate()
block_create = PatternMatcher([
(UPat(GroupOp.All-DONT_PLACE_IN_BLOCK.union({Ops.BLOCK, Ops.BLOCKEND}), name="x"), make_block_bottom_up),
(UPat(Ops.SPECIAL), raise_bottom_up_gate)
])
# ***** blockend merging ****
def merge_blockends(sink:UOp) -> UOp|None:
# only run on the final BLOCK with the SINK in it
if sink.arg.lst[-1].op is not Ops.SINK: return None
# combine matching BLOCKENDS, the keys of this dictionary are the RANGE UOps, values are the BLOCKENDs
blockends_to_arg: dict[UOp, list[UOp]] = {}
for be in sink.toposort():
if be.op is Ops.BLOCKEND: blockends_to_arg.setdefault(be.arg.end, []).append(be)
new_forks = {}
for k,v in blockends_to_arg.items():
# NOTE: if any BLOCKEND is the parent of any other with the same arg, this algo fails
if len(v) > 1:
bb = BasicBlock(v[0].arg.lst, _sort_ctx(flatten([y.arg.ctx for y in v])), k, cnt=sum(y.arg.cnt for y in v))
out = UOp(Ops.BLOCKEND, src=tuple(flatten([x.src for x in v])), arg=bb)
# NOTE: bb.ctx != u.arg.ctx can cause problems here
for u in v: new_forks[u] = out
if len(new_forks) == 0: return None
return sink.substitute(new_forks)
pm_blockend_merge = PatternMatcher([(UPat(Ops.BLOCK, name="sink"), merge_blockends)])
# ***** block merging ****
def merge_block(x:UOp):
unmergable_blocks, mergable_blocks = [], []
mergable_dict: defaultdict[UOp, int] = defaultdict(int)
for y in x.src:
if y.op is Ops.BLOCK and x.op is Ops.BLOCK and x.arg.ctx == y.arg.ctx: mergable_dict[y] += 1
elif y.op is Ops.BLOCK and x.op is Ops.BLOCKEND and x.arg.end in y.arg.ctx: mergable_dict[y] += 1
else: unmergable_blocks.append(y)
for k,v in mergable_dict.items():
if v == k.arg.cnt: mergable_blocks.append(k)
else: unmergable_blocks.extend([k]*v)
if len(mergable_blocks) == 0: return None
del mergable_dict
# create the block
arg = replace(x.arg, lst=tuple(flatten([y.arg.lst for y in mergable_blocks]))+x.arg.lst)
return UOp(x.op, src=tuple(flatten([y.src for y in mergable_blocks])+unmergable_blocks), arg=arg)
def remove_blockend(x:UOp):
# if there's any remaining blocks that need to go in this BLOCKEND, we don't remove it
if any(x.arg.end in y.arg.ctx for y in x.src if y.op in {Ops.BLOCK, Ops.BLOCKEND}): return None
if (parent_blocks := [y for y in x.src if y.op is Ops.BLOCK and y.arg.child_ctx is not None and x.arg.end in y.arg.child_ctx]):
assert all_same(parent_blocks), f"should never have two parent blocks (has {len(parent_blocks)})"
parent_block = parent_blocks[0]
assert len(parent_blocks) == parent_block.arg.cnt
# NOTE: DEFINE_ACC doesn't have to be handled in any special way
late_ops = list(x.arg.lst)
# NOTE: we have to add a barrier at the start if barrier is used in the range
if x.op is Ops.BLOCKEND and any(y.op is Ops.BARRIER for y in late_ops) and late_ops[-1].op is Ops.ENDRANGE:
late_ops = [UOp(Ops.BARRIER)] + late_ops
# peephole opt, remove any BARRIERs next to each other
for i in range(len(late_ops)-1):
if late_ops[i].op is Ops.BARRIER and late_ops[i+1].op is Ops.BARRIER: late_ops[i+1] = UOp(Ops.NOOP)
arg = BasicBlock(parent_block.arg.lst+tuple(late_ops), tuple([y for y in x.arg.ctx if y is not x.arg.end]), cnt=x.arg.cnt)
return UOp(Ops.BLOCK, src=tuple(y for y in x.src if y is not parent_block)+parent_block.src, arg=arg)
# else the whole context ended by the blockend is already in this block and we can safely turn it into a block
return UOp(Ops.BLOCK, src=x.src, arg=BasicBlock(x.arg.lst, tuple([y for y in x.arg.ctx if y is not x.arg.end]), cnt=x.arg.cnt))
block_merge = PatternMatcher([
(UPat((Ops.BLOCK, Ops.BLOCKEND), name="x"), merge_block),
(UPat(Ops.BLOCKEND, name="x"), remove_blockend),
])
# ****** finalize ******
def finalize(sink:UOp) -> UOp:
if sink.op is not Ops.BLOCK or not all(x.op in DONT_PLACE_IN_BLOCK for x in sink.src):
raise RuntimeError(f"linearize failure {sink.op} {[x.op for x in sink.src if x.op not in DONT_PLACE_IN_BLOCK]}")
# place the early things
lst = sorted(dedup(sink.src), key=lambda x: x.tuplize) + list(sink.arg.lst)
return UOp(Ops.BLOCKFINAL, arg=BasicBlock(tuple(lst)))
pm_finalize = PatternMatcher([(UPat(Ops.BLOCK, name="sink"), finalize)])
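# sketch of the overall pass order (descriptive comment, not in the original): block_create builds
# BLOCK/BLOCKEND nodes bottom-up, pm_blockend_merge fuses BLOCKENDs closing the same RANGE/IF,
# block_merge folds compatible blocks into their consumers, and pm_finalize emits a single BLOCKFINAL
# whose lst is the linearized program.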


@@ -0,0 +1,86 @@
# the job of the lowerer is to do indexing
from dataclasses import dataclass
from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint_to_uop, AxisType, graph_rewrite, resolve
# ***** indexing *****
@dataclass
class IndexContext:
axis_types: tuple[AxisType, ...]
idxs: list[UOp]
start: int = 0
def shape_to_idx(s, axis_types, start=0):
return [UOp.range(sint_to_uop(s), start+i, at) for i, (s, at) in enumerate(zip(s, axis_types))]
def get_index(ast:UOp) -> IndexContext:
axis_types = ast.arg.axis_types if isinstance(ast.arg, KernelInfo) else ()
if len(ast.full_shape) != len(axis_types) and ast.st is not None:
axis_types = tuple([AxisType.REDUCE if resolve(s != fs) else AxisType.LOOP for s,fs in zip(ast.shape, ast.full_shape)])
return IndexContext(axis_types, [], 0)
# ***** lowering (given index) *****
def subblock(ctx: IndexContext, full_new_idx: list[UOp], src: UOp):
lc = IndexContext(ctx.axis_types, full_new_idx, ctx.start+1000)
ctx.start = lc.start
return graph_rewrite(src, pm_lowerer, lc, name="subblock", bottom_up=True)
def lower_reduce_axis(ctx: IndexContext, x: UOp):
new_idxs = shape_to_idx(x.src[0].shape, ctx.axis_types, ctx.start)
full_new_idx = list(ctx.idxs)
for a in x.axis_arg: full_new_idx[a] = new_idxs[a]
ret = subblock(ctx, full_new_idx, x.src[0])
return UOp(Ops.REDUCE, x.dtype, (ret,)+tuple([full_new_idx[i] for i in x.axis_arg]), x.arg[0])
def lower_store(ctx: IndexContext, x: UOp, buf: UOp):
# TODO: reenable after REDUCE_AXIS is fixed
#assert x.src[1].shape == x.src[0].shape, f"shape mismatch on store {x.src[1].shape} != {x.src[0].shape}"
new_idxs = shape_to_idx(x.src[0].shape, ctx.axis_types, ctx.start)
idx = x.st_arg.to_valid_uop(new_idxs)
used_idxs = [x for x in idx.toposort() if x in new_idxs]
real_new_idxs = []
for i in range(len(x.src[0].shape)):
if new_idxs[i] in used_idxs or len(ctx.idxs) <= i: real_new_idxs.append(new_idxs[i])
else: real_new_idxs.append(ctx.idxs[i])
stored = subblock(ctx, real_new_idxs, x.src[1])
used_ranges = [x for x in used_idxs if x.op is Ops.RANGE]
return buf.index(idx).store(stored, *used_ranges)
def fixup_wmma(ctx:IndexContext, x:UOp):
if x.tag is not None: return None
new_idxs = shape_to_idx(x.src[0].shape, ctx.axis_types, ctx.start)
full_new_idx = list(ctx.idxs)
for a in x.arg[-1]: full_new_idx[a] = new_idxs[a]
srcs = subblock(ctx, full_new_idx, UOp.sink(*x.src)).src
# NOTE: this assumes these are expanded. which now shouldn't change anything
new_x_arg_m2 = tuple([tuple([(full_new_idx[a].arg[0], sz) for a,sz in v]) for v in x.arg[-2]])
new_x_arg_m1 = tuple([full_new_idx[a].arg[0] for a in x.arg[-1]])
return x.replace(src=srcs, arg=x.arg[:-2]+(new_x_arg_m2, new_x_arg_m1), tag=1)
pm_lowerer = PatternMatcher([
# TODO: remove these hacks
# hack for old style CONST(VIEW) (now it's just VIEW(CONST))
(UPat((Ops.DEFINE_VAR, Ops.CONST), src=(UPat(Ops.VIEW, name="v"),), name="c"), lambda c,v: c.replace(src=()).view(v.arg)),
# hack for old style VALID (now it's just VIEW(CONST))
(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c"), UPat(Ops.CONST, arg=0)), lambda c,v: c.replace(src=()).view(v.arg)),
# consts and loads
(UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"),), name="view"),
lambda ctx,view,c: c if all(x.mask is None for x in view.arg.views) else view.arg.to_valid_uop(ctx.idxs).get_valid().where(c, c.const_like(0))),
(UPat(Ops.LOAD, src=(UPat.var("buf").view(),), allow_any_len=True, name="x"),
lambda ctx,buf,x: UOp(Ops.LOAD, x.dtype, (buf.index(x.st_arg.to_valid_uop(ctx.idxs)),)+x.src[1:])),
# reduce/view_const
(UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
(UPat(Ops.STORE, src=(UPat.var("buf").view(),), allow_any_len=True, name="x"), lower_store),
(UPat(Ops.WMMA, name="x"), fixup_wmma),
# axis fixups for WMMA
(UPat((Ops.CONTRACT, Ops.UNROLL), name="x"),
lambda ctx,x: x.replace(tag=1, arg=tuple([(ctx.idxs[a].arg[0], sz) for a,sz in x.arg])) if x.tag is None else None),
])


@@ -0,0 +1,26 @@
# opt opinionatedly transforms an ast into an optimized ast using either heuristics or beam search
from __future__ import annotations
from enum import Enum, auto
from dataclasses import dataclass
from tinygrad.uop.ops import AxisType
class OptOps(Enum):
TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto(); THREAD = auto() # noqa: E702
GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
def __lt__(self, x:OptOps): return self.value < x.value
@dataclass(frozen=True, order=True)
class Opt:
op: OptOps
axis: int|None = None
arg: int|tuple|None = None
def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, arg={self.arg})"
axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"}
axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "green", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"}
class KernelOptError(Exception): pass
def check(cond:bool, msg:str=""):
if not cond: raise KernelOptError(msg)
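# usage sketch (not from the original file): Opts are small, ordered, hashable records consumed by the
# scheduler, e.g.
#   Opt(OptOps.UPCAST, axis=0, arg=4)   # upcast axis 0 by a factor of 4
#   Opt(OptOps.LOCAL, axis=1, arg=16)   # make axis 1 a local dim of size 16
#   Opt(OptOps.NOLOCALS)                # flag-style opt, no axis or arg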


@@ -0,0 +1,188 @@
import itertools
from tinygrad.codegen.opt import Opt, OptOps, KernelOptError
from tinygrad.helpers import getenv, DEBUG, prod, NOLOCALS, TC_OPT, TC_SELECT, USE_TC, AMX
from tinygrad.dtype import ImageDType
from tinygrad.uop.ops import Ops, resolve, AxisType
from tinygrad.codegen.opt.postrange import Scheduler
def hand_coded_optimizations(k:Scheduler) -> Scheduler:
# first try the tensor cores
""" Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false.
Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N).
Keyword arguments:
use_tensor_cores -- controls how tensor cores are applied (default 1)
0: will disable any tensor core matching
1: enable tensor cores
2: apply tensor core shape but don't use UOp.WMMA
extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None)
tc_select -- specifies which tensor core(s) to use for optimization (default -1)
-1: iterates through all available tensor cores in order and uses the first one that matches the requirements (dims and dtypes)
[0-N]: uses only the n'th tensor core available; useful for search
tc_opt -- controls which kinds of kernels may be eligible for tensor core application (default 2 during BEAM, 0 otherwise)
0: applies to only kernels with a single reduce axis and direct Ops.LOAD into Ops.MUL
1: allows kernels with multiple reduce axes and also multiplication of Ops.CAST'd buffers
2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by padding those axes as needed
"""
# NOTE: unless TC_OPT is > 0, we only trigger tensor cores if there's only one reduce axis
if USE_TC > 0 and (len(k.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE)) == 1 or (TC_OPT.value >= 1)):
good_tc_opt = False
try: # check TC first and apply hand-coded opts if successful
tk = k.copy()
rngs = tk.apply_opt(Opt(OptOps.TC, 0, (TC_SELECT.value, TC_OPT.value, USE_TC.value)))
good_tc_opt = True
except KernelOptError:
pass
if good_tc_opt:
# skip hand-coded TC opts if AMX, upcasting will make kernel slower
if rngs is not None and not AMX:
for tc_dim in [1,0]: # attempt to upcast M and N
szs = [sz for sz in [5,4,3,2] if rngs[tc_dim].src[0].divides(sz) is not None]
if szs:
# set it to the replaced range
rngs[tc_dim] = tk.apply_opt(Opt(OptOps.UPCAST, tk.rngs.index(rngs[tc_dim]), szs[0]))[0]
if (szs := [sz for sz in [4,2] if rngs[0].src[0].divides(sz) is not None]): # attempt to local N
tk.apply_opt(Opt(OptOps.LOCAL, tk.rngs.index(rngs[0]), szs[0]))
return tk
# make a copy so it does not mutate the input
k = k.copy()
# should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
if k.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
k.reduceop is not None and k.reduceop.arg[0] is Ops.ADD and len(k.full_shape) >= 2 and k.opts.has_shared and \
(mulop:=k.reduceop.src[0]).op is Ops.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD:
idx0, idx1 = mulop.src[0].src[0].src[1].get_idx(), mulop.src[1].src[0].src[1].get_idx()
first_reduce_rng = k.ranges_of(AxisType.REDUCE)[0]
if any(u is first_reduce_rng for u in idx0.split_uop(Ops.ADD)) and all(r in idx1.ranges for r in idx0.ranges):
for global_idx in k.axes_of(AxisType.GLOBAL):
if first_reduce_rng.src[0].divides(MV_THREADS_PER_ROW) is not None and k.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
if DEBUG >= 3:
print(f"MATVEC: {k.full_shape=} {first_reduce_rng.render()} {MV_BLOCKSIZE=} {MV_THREADS_PER_ROW=} {MV_ROWS_PER_THREAD=}")
if MV_THREADS_PER_ROW > 1: k.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
if MV_BLOCKSIZE > 1: k.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
if MV_ROWS_PER_THREAD > 1: k.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
return k
# are we grouping? (requires local shape support)
if resolve(prod(k.output_shape[i] for i in k.upcastable_dims) <= 2048, False):
for sz in [16]:
try:
k.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
break
except KernelOptError: pass
# upcast float4 images
for buf_index,buf in enumerate(k.bufs):
if isinstance(buf.src[0].dtype, ImageDType):
# part of real_strides
unit_stride_axes_mul_4 = [k.rngs.index(c) for c in k.bufs[buf_index].src[1].get_idx().split_uop(Ops.ADD) if
c.op is Ops.RANGE and (c.vmax+1)%4 == 0]
if len(unit_stride_axes_mul_4):
if (axis:=unit_stride_axes_mul_4[0]) in k.upcastable_dims:
k.apply_opt(Opt(OptOps.UPCAST, axis, 4))
elif axis in k.unrollable_dims:
k.apply_opt(Opt(OptOps.UNROLL, k.unrollable_dims.index(axis), 4))
# no more opt if we are grouping
if k.group_for_reduces: return k
# **** below this line need to be optional and benchmarked ****
# if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
to_upcast: list[int] = []
# upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
for axis in k.upcastable_dims:
# for Schedule, we check if the range is used in INDEX gates or WHERE gates
is_masked = any(any(o is k.rngs[axis] for o in u.src[0].parents) for u in k.ast.parents if u.op is Ops.WHERE)
if k.full_shape[axis] <= 7 and is_masked and prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7:
if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
to_upcast.append(axis)
for axis in to_upcast[::-1]: k.apply_opt(Opt(OptOps.UPCAST, axis, 0))
# potentially do more upcasts of non reduce axes based on a heuristic
is_dsp = k.opts is not None and k.opts.device == "DSP"
upcasted_axis: set[int] = set()
while resolve(prod(k.output_shape[i] for i in k.upcastable_dims) >= 1024):
xb_choices = []
# consider all upcastable axes with 3 or 4 upcast (128 on the DSP)
for axis, upcast_amount in itertools.product(k.upcastable_dims, ([128] if not len(upcasted_axis) else []) if is_dsp else [3,4]):
# only consider axes we haven't upcasted where the upcast amount divides the size; prefer ones where some buffer has stride 0 on this axis while having no stride-0 axis among those already upcasted
if axis in upcasted_axis or k.full_shape[axis]%upcast_amount != 0: continue
rng = k.rngs[axis]
if any(rng not in b.src[1].get_idx().parents and all(r2 in b.src[1].get_idx().parents
for r2 in k.ranges_of(AxisType.UPCAST, AxisType.UNROLL)) for b in k.bufs):
num_strides, sum_strides = 0, 0
for b in k.bufs:
idx = b.src[1].get_idx()
if rng in idx.parents: num_strides += 1
for c in idx.split_uop(Ops.ADD):
if c is rng: sum_strides += 1
if c.op is Ops.MUL and c.src[0] is rng and c.src[1].op is Ops.CONST: sum_strides += c.src[1].arg
if c.op is Ops.MUL and c.src[1] is rng and c.src[0].op is Ops.CONST: sum_strides += c.src[0].arg
xb_choices.append((num_strides, sum_strides, axis, upcast_amount))
if xb_choices:
xb_choices = sorted(xb_choices)
if DEBUG >= 4: print(f"more upcast axis : {xb_choices}")
k.apply_opt(Opt(OptOps.UPCAST, xb_choices[0][2], xb_choices[0][3]))
upcasted_axis.add(xb_choices[0][2])
else: break
# if last reduce dim is small(ish), loop unroll the reduce
# NOTE: this can fail on multireduce with mismatching dimensions, this is okay
try:
upcast_size = prod(k.full_shape[a] for a in k.axes_of(AxisType.UPCAST, AxisType.UNROLL))
if k.unrollable_dims and (upcast_size <= 4 or not k.axes_of(AxisType.UNROLL)) and (upcast_size < 64):
if (s:=k.full_shape[k.unrollable_dims[-1]]) <= 32:
k.apply_opt(Opt(OptOps.UNROLL, len(k.unrollable_dims)-1, 0))
# if it's small, upcast a second reduce dimension too
if k.unrollable_dims and s <= 3 and k.full_shape[k.unrollable_dims[-1]] <= 3:
k.apply_opt(Opt(OptOps.UNROLL, len(k.unrollable_dims)-1, 0))
else:
for splits in [4]:
if k.full_shape[axis:=k.unrollable_dims[-1]]%splits == 0:
k.apply_opt(Opt(OptOps.UNROLL, len(k.unrollable_dims)-1, splits))
break
except KernelOptError: pass
# if nothing at all is upcasted and it's easy to, do an upcast
for splits in [4]:
# TODO: somehow this never hits a reduce
if not k.upcasted and k.upcastable_dims and k.full_shape[k.upcastable_dims[-1]] % splits == 0:
k.apply_opt(Opt(OptOps.UPCAST, k.upcastable_dims[-1], splits))
# **** local groups ****
if k.opts.has_local:
if NOLOCALS:
k.apply_opt(Opt(OptOps.NOLOCALS))
else:
# prioritize making expand axes local
local_axis_ranking = [(any(k.rngs[axis] not in b.src[1].get_idx().parents for b in k.bufs), axis) \
for axis in k.axes_of(AxisType.GLOBAL, AxisType.LOOP) if k.rngs[axis].src[0].op is Ops.CONST]
to_local: list[tuple[int, int]] = []
for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
local_size = prod(sz for _, sz in to_local)
local_sz: int|None = next((x for x in ([32] * (axis == 0) + [16,8,4,3,2]) if k.full_shape[axis] % x == 0 and local_size * x <= 128), None)
if local_sz is not None: to_local.append((axis, local_sz))
deleted_shape = 0
for axis, local_sz in sorted(to_local[:3]):
axis = axis - deleted_shape
will_delete_shape = local_sz == k.full_shape[axis]
k.apply_opt(Opt(OptOps.LOCAL, axis, local_sz))
if will_delete_shape: deleted_shape += 1
# **** threading ****
if k.opts.has_threads and k.opts.global_max is not None:
for threads in [32,16,12,8,6,5,4,3,2]:
# Skip if there are too many threads. Heuristic: use about 128K ops per thread
if threads > k.opts.global_max[0] or resolve(prod(k.full_shape) // (128 << 10) < threads): continue
for axis in k.axes_of(AxisType.LOOP):
if k.full_shape[axis] % threads == 0:
k.apply_opt(Opt(OptOps.THREAD, axis, threads))
break
if k.applied_opts and k.applied_opts[-1].op is OptOps.THREAD: break
return k

View File

@@ -0,0 +1,334 @@
from __future__ import annotations
import math, itertools
from collections import defaultdict
from typing import cast, Final
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, can_pad, GroupOp
from tinygrad.device import Buffer
from tinygrad.dtype import AddrSpace, dtypes, ImageDType
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod
from tinygrad.codegen.opt import axis_colors, Opt, OptOps, KernelOptError, check, axis_letters
from tinygrad.codegen.simplify import pm_flatten_range
from tinygrad.renderer import Renderer
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
# NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
class Scheduler:
def __init__(self, ast:UOp, opts:Renderer):
self.ast, self.opts = ast, opts
self.dont_use_locals = self.ast.arg.dont_use_locals if self.ast.arg is not None else False
self.applied_opts = list(self.ast.arg.applied_opts) if self.ast.arg is not None else []
@property
def rngs(self):
# always in order by axistype
return sorted([u for u in self.ast.parents if u.op is Ops.RANGE and u.vmax > 0], key=lambda x: (axis_to_pos[x.arg[-1]],) + x.arg[0:-1])
@property
def shape_len(self): return len(self.rngs)
@property
def full_shape(self): return [ssimplify(x.src[0]) for x in self.rngs]
@property
def axis_types(self): return [x.arg[-1] for x in self.rngs]
@property
def maxarg(self): return max([x.arg[0] for x in self.rngs], default=0)
# strings like ['g0', 'g1', 'l0', 'l1', 'l2', 'l3', 'l4', 'l5', 'R0', 'r0', 'r1', 'r2', 'u0', 'u1', 'u2']
def shape_str(self) -> list[str]:
ret: list[str] = []
cnt: dict[AxisType, int] = {}
for x in self.axis_types:
cnt[x] = (cnt[x] + 1) if x in cnt else 0
ret.append(f"{axis_letters[x]}{cnt[x]}")
return ret
def shape_str_to_axis(self, nms:list[str]) -> tuple[int, ...]: return tuple([self.shape_str().index(x) for x in nms])
def copy(self):
ret = Scheduler(self.ast, self.opts)
ret.dont_use_locals = self.dont_use_locals
ret.applied_opts = self.applied_opts[:]
return ret
kernel_cnt: Final[defaultdict[str, int]] = defaultdict(int)
def get_optimized_ast(self, name_override:str|None=None):
if name_override is not None: name = name_override
else:
kernel_type = "r" if self.reduceop is not None else "E"
name = kernel_type + colored('_', 'BLACK').join(['']+[colored(x.src[0].render(), color) for x,color in zip(self.rngs, self.colors())])
Scheduler.kernel_cnt[(function_name := to_function_name(name))] += 1
num = f"n{Scheduler.kernel_cnt[function_name]-1}" if Scheduler.kernel_cnt[function_name] > 1 else ""
name += colored(num, 'BLACK')
self.ast = graph_rewrite(self.ast, pm_flatten_range, name="flatten range")
return self.ast.replace(arg=KernelInfo(name=name, applied_opts=tuple(self.applied_opts), dont_use_locals=self.dont_use_locals), tag=1)
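# only LOOP ranges of the output STORE that also appear in every LOCAL store/BUFFERIZE can become GLOBAL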
def _globalizable_rngs(self) -> list[UOp]:
store_rngs = self.ast.src[0].src[2:]
# filter any not in local stores
local_store_rngs = [x.ranges for x in self.ast.toposort() if (x.op is Ops.STORE and x.src[0].ptrdtype.addrspace == AddrSpace.LOCAL) \
or (x.op is Ops.BUFFERIZE and x.arg == AddrSpace.LOCAL)]
for ls in local_store_rngs: store_rngs = tuple([x for x in store_rngs if x in ls])
return [x for x in UOp.sink(*store_rngs).toposort() if x.op is Ops.RANGE and x.arg[1] == AxisType.LOOP] if store_rngs else []
def convert_loop_to_global(self):
if not self.opts.has_local: return None
globalizable_rngs = self._globalizable_rngs()
rng = [x.replace(arg=(x.arg[0], AxisType.GLOBAL)) if x in globalizable_rngs else x for x in self.rngs]
self.ast = self.ast.substitute(dict(zip(self.rngs, rng)))
def colors(self) -> list[str]: return [axis_colors[x] if not self.dont_use_locals or not x == AxisType.GLOBAL else "BLUE" for x in self.axis_types]
def colored_shape(self) -> str: return ' '.join([colored(f'{x.src[0].render():>4s}', color) for x,color in zip(self.rngs, self.colors())])
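# shift_to splits `rng` into an outer range of size old_sz = size//amount and a new range of size `amount` with new_type.
# top=True makes the new range the outer (slow) factor, otherwise it is the innermost one,
# e.g. splitting a RANGE(64) with amount=16 and top=False rewrites it as RANGE(4)*16 + RANGE(16)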
def shift_to(self, rng:UOp, amount:int, new_type:AxisType, top:bool=False, input_new_rng=None):
if (old_sz:=rng.src[0].divides(amount)) is None:
raise KernelOptError(f"{amount} can't divide {rng.src[0]} in {self.colored_shape()}")
new_rng = UOp.range(amount, self.maxarg+1, new_type) if input_new_rng is None else input_new_rng
replaced_rng = rng.replace(src=(UOp.const(dtypes.int, old_sz),))
sub_axis = (new_rng * old_sz + replaced_rng) if top else (replaced_rng * amount + new_rng)
self.ast = self.ast.substitute({rng:sub_axis}, name=f"shift {rng.arg[0]} {amount} {str(new_type).split('.')[1].lower()}")
return replaced_rng, new_rng
def ranges_of(self, *axis_type:AxisType) -> list[UOp]: return [r for r in self.rngs if r.arg[-1] in axis_type]
def axes_of(self, *axis_type:AxisType) -> list[int]: return [i for i,t in enumerate(self.axis_types) if t in axis_type]
# copied from kernel.py
@property
def upcastable_dims(self) -> list[int]: return [i for i in self.axes_of(AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP) \
if isinstance(s:=self.full_shape[i], int) and s > 1]
@property
def unrollable_dims(self) -> list[int]: return [i for i in self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE) \
if isinstance(s:=self.full_shape[i], int) and s > 1]
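# map an Opt's axis to an index into self.rngs: UNROLL counts within unrollable_dims, GROUP/GROUPTOP within
# the REDUCE axes, everything else is already a direct index. -1 means the opt has no axis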
def real_axis(self, op:OptOps, axis:int|None):
try:
if axis is None or op is OptOps.TC: return -1
if op is OptOps.UNROLL: return self.unrollable_dims[axis]
if op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.axes_of(AxisType.REDUCE)[axis]
check(axis < self.shape_len, f"invalid axis on {axis=} {op=} {self.shape_len=}")
return axis
except IndexError as e: raise KernelOptError from e
def apply_opt(self, opt:Opt, append_opt:bool=True):
if opt.op is OptOps.NOLOCALS:
check(all(x not in {AxisType.WARP, AxisType.LOCAL, AxisType.GROUP_REDUCE} for x in self.axis_types), "no locals can't have locals")
if append_opt: self.applied_opts.append(opt)
self.dont_use_locals = True
return
if opt.op in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP}:
check(self.opts.has_local, "locals needed for opt")
rng = self.rngs[real_axis] if (real_axis:=self.real_axis(opt.op, opt.axis)) >= 0 else UOp(Ops.NOOP)
opt_to_at = {
OptOps.LOCAL: AxisType.LOCAL, OptOps.UPCAST: AxisType.UPCAST,
OptOps.UNROLL: AxisType.UNROLL, OptOps.GROUP: AxisType.GROUP_REDUCE,
OptOps.GROUPTOP: AxisType.GROUP_REDUCE, OptOps.THREAD: AxisType.THREAD}
ret = None
if opt.op in opt_to_at:
amt:int = int(rng.vmax+1) if opt.arg == 0 else cast(int, opt.arg)
# copied from kernel.py. prevents METAL compiler hangs
if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \
(self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
upcast_local_sz = prod([self.full_shape[a] for a in self.axes_of(AxisType.UPCAST, AxisType.WARP, AxisType.LOCAL, AxisType.GROUP_REDUCE)])
smem_sz = amt*upcast_local_sz*self.reduceop.dtype.itemsize
check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")
if opt.op is OptOps.UNROLL:
check(amt <= 32, "don't unroll more than 32")
check(rng.arg[-1] in {AxisType.GROUP_REDUCE, AxisType.REDUCE}, "unroll is for GROUP_REDUCE/REDUCE")
if opt.op is OptOps.UPCAST:
check((self.opts is not None and self.opts.device == "DSP") or amt <= 16, "don't upcast more than 16")
check(rng.arg[-1] in {AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP}, f"upcast is for GLOBAL/LOCAL/LOOP, not {rng.arg[-1]}")
if opt.op is OptOps.LOCAL:
check(not self.dont_use_locals, "can't use locals")
check(rng.arg[-1] in {AxisType.GLOBAL, AxisType.LOOP}, "local is for globals")
if opt.op is OptOps.THREAD:
check(self.opts is not None and self.opts.has_threads, "target does not support threads")
check(self.opts is not None and self.opts.global_max is not None and amt <= self.opts.global_max[0], "too many threads")
check(all(x is not AxisType.THREAD for x in self.axis_types), "already threaded")
check(rng in self._globalizable_rngs(), "can't apply range to this dim")
if opt.op in {OptOps.GROUP, OptOps.GROUPTOP}:
check(all(x.op is not OptOps.TC for x in self.applied_opts), "no grouping with tensor cores") # TODO: why is this wrong?
check(not self.dont_use_locals, "can't use locals")
check(rng.arg[-1] == AxisType.REDUCE, "group is for reduce")
ret = self.shift_to(rng, amt, opt_to_at[opt.op], top=opt.op in {OptOps.GROUPTOP, OptOps.THREAD})
elif opt.op is OptOps.TC:
check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: remove the need for this by having warps
check(opt.axis is not None, "tensor core opts must have an axis")
check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) == 3, "tensor core opts must have valid arg")
check(-1 <= (tc_select:=cast(tuple, opt.arg)[0]) < len(self.opts.tensor_cores), "tensor core opts must have valid tc_select")
check(0 <= (tc_opt:=cast(tuple, opt.arg)[1]) <= 2, "tensor core opts must have valid tc_opt")
check(0 < (use_tensor_cores:=cast(tuple, opt.arg)[2]) <= 2, "use_tensor_cores value is not valid")
try: ret = self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), tc_select, tc_opt)
except ValueError as e: raise KernelOptError(str(e))
check(ret is not None, "no tensor core available")
elif opt.op is OptOps.PADTO:
check(rng.src[0].op is Ops.CONST, "only pad const axes")
check(rng.arg[-1] not in {AxisType.UPCAST, AxisType.UNROLL}, "cannot pad upcasted") # TODO: why is this wrong?
check(rng.arg[-1] is not AxisType.THREAD, "cannot pad thread")
# ok to pad SUM if all parent ALU ops have f(0) = 0
if (r:=self.reduceop) is not None and rng.arg[-1] in (AxisType.GROUP_REDUCE, AxisType.REDUCE):
check(r.arg[0] is Ops.ADD and can_pad(r, {}), f"cannot pad {r}")
new_sz = round_up(int(rng.vmax+1), cast(int, opt.arg))
check(rng.vmax+1 > new_sz//4, "pad adds more than quadruple the work")
replaced_rng = UOp.range(new_sz, *rng.arg)
replaces = {rng:replaced_rng}
valid = replaced_rng < rng.vmax+1
for b in self.bufs:
if rng in (i:=b.src[1].get_idx()).sparents:
replaces[b] = b.replace(src=(b.src[0],(valid&b.src[1].get_valid()).where(i, UOp.invalid())))
self.ast = self.ast.substitute(replaces, f"padto {rng.arg[:-1]} {opt.arg}")
elif opt.op is OptOps.SWAP:
try:
altrng = self.rngs[opt.arg]
except IndexError:
raise KernelOptError
check(rng.arg[-1] == AxisType.GLOBAL and altrng.arg[-1] == AxisType.GLOBAL, "swap only for globals")
self.ast = self.ast.substitute({rng:rng.replace(arg=(*altrng.arg[0:-1], rng.arg[-1]), tag=1),
altrng:altrng.replace(arg=(*rng.arg[0:-1], altrng.arg[-1]), tag=1)})
self.ast = graph_rewrite(self.ast, remove_tags)
else:
raise KernelOptError(f"unsupported opt {opt.op}")
if append_opt: self.applied_opts.append(opt)
return ret
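# try to map an ADD reduce of a MUL onto one of the device's tensor cores: pick the (in1, in0, reduce) ranges
# selected by `axis`, split them into WARP/LOCAL/UPCAST/UNROLL axes per tc.opts, and (unless use_tensor_cores==2,
# which only applies the shape opts) replace the reduce with a WMMA of the two CONTRACTed srcs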
def _apply_tc_opt(self, use_tensor_cores:int, axis:int, tc_select:int, opt_level:int) -> None|list[UOp]:
reduceops = [x for x in self.ast.toposort() if x.op is Ops.REDUCE]
if not len(reduceops): raise KernelOptError("no reduce ops for TensorCore")
reduceop = reduceops[0]
if use_tensor_cores and reduceop is not None and reduceop.arg is Ops.ADD:
mul = reduceop.src[0] if reduceop.src[0].op is not Ops.CAST else reduceop.src[0].src[0]
if mul.op is not Ops.MUL: return None
in0, in1 = mul.src
try:
tensor_cores = self.opts.tensor_cores if tc_select == -1 else [self.opts.tensor_cores[tc_select]]
except IndexError:
raise KernelOptError(f"invalid tensor core choice {tc_select}")
for tc in tensor_cores:
if tc.dtype_in == in0.dtype.scalar() and tc.dtype_in == in1.dtype.scalar() and tc.dtype_out == reduceop.dtype.scalar():
# tensor cores have three ranges. X, Y, and REDUCE
in0_ranges = sorted([u for u in in0.ranges if u not in in1.ranges], key=lambda x: -x.arg[0])
in1_ranges = sorted([u for u in in1.ranges if u not in in0.ranges], key=lambda x: -x.arg[0])
red_ranges = sorted(reduceop.src[1:], key=lambda x: -x.arg[0])
if DEBUG >= 3:
print(f"TC({axis}): {[(x.arg[0],x.vmax+1) for x in in0_ranges]}",
f"{[(x.arg[0],x.vmax+1) for x in in1_ranges]} {[(x.arg[0],x.vmax+1) for x in red_ranges]}")
if not len(in0_ranges) or not len(in1_ranges) or not len(red_ranges): continue
# pick ranges
# NOTE: why are in1 and in0 switched?
axis_choices = list(itertools.product(in1_ranges, in0_ranges, red_ranges))
if not (axis < len(axis_choices)): continue
axes = list(axis_choices[axis])
# do optimizations and save the ranges
try:
for i,a in enumerate(axes):
idx = self.rngs.index(a)
if (a.vmax+1) % tc.dims[i] != 0:
if opt_level < 2: raise KernelOptError("tc padding requires opt_level >= 2")
# apply_opt should return the updated range?
self.apply_opt(Opt(OptOps.PADTO, idx, tc.dims[i]), append_opt=False) # PADTO might fail
axes[i] = self.rngs[idx]
except KernelOptError: continue
# we create the warp as a whole thing, in case some of these ranges are moved/removed later
warp = UOp.range(tc.threads, -1, AxisType.WARP)
ne: list[UOp] = []
for opt in tc.opts:
if opt[0] == "l":
axes[int(opt[1])], new_range = self.shift_to(axes[int(opt[1])], 2, AxisType.LOCAL, input_new_rng=warp%2)
warp //= 2
elif opt[0] == "u":
axes[int(opt[1])], new_range = self.shift_to(axes[int(opt[1])], 2, AxisType.UPCAST)
else: raise RuntimeError(f"unsupported opt {opt[0]} in tensor cores")
ne.append(new_range)
for _, amt in tc.get_reduce_axes():
axes[2], new_range = self.shift_to(axes[2], amt, AxisType.UNROLL)
ne.append(new_range)
if use_tensor_cores != 2:
# fix the srcs
reduceop = [x for x in self.ast.toposort() if x.op is Ops.REDUCE][0]
tne = [x.replace(tag=1) for x in ne]
ret = reduceop.substitute(dict(zip(ne, tne)))
srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src)
srcs = [x.substitute(dict(zip(tne, [ne[i] for i in argsort(p)]))) for x,p in zip(srcs, tc.permutes_for_shape_str(tc.base_shape_str()))]
# get reduce/upcast axes for the tensor cores
tc_reduce_axes = self.shape_str_to_axis([f"r{i}" for i in range(len(tc.get_reduce_axes()))])
base_upcast_axes = tuple([(s,2) for s in self.shape_str_to_axis(tc.base_upcast_axes())])
tc_upcast_axes = tuple([base_upcast_axes[:int(math.log2(tc.elements_per_thread[i]))] for i in range(3)])
# axes to range number (was done in lowerer)
tc_upcast_axes = tuple([tuple([(self.rngs[a].arg[0], sz) for a,sz in v]) for v in tc_upcast_axes])
tc_reduce_axes = tuple([self.rngs[a].arg[0] for a in tc_reduce_axes])
# construct the op
# TODO: remove tc_upcast_axes from the arg
# do the reduce_axes always disappear? i think they don't
# they need to be moved into the WMMA srcs
wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, tc.threads, tc_upcast_axes, ()) #, tc_reduce_axes)
wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(tc.elements_per_thread[2]), src=(
UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(tc.elements_per_thread[0]), src=(srcs[0],), arg=tc_upcast_axes[0], tag=1),
UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(tc.elements_per_thread[1]), src=(srcs[1],), arg=tc_upcast_axes[1], tag=1),
UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0)), arg=wmma_arg, tag=1)
tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=tc_upcast_axes[2], tag=1)
# preserve extra reduces
reduce_ranges = [x for x in UOp.sink(*reduceop.src[1:]).toposort() if x.op is Ops.RANGE and x.arg[0] not in tc_reduce_axes]
if len(reduce_ranges): tc_uop = UOp(Ops.REDUCE, tc_uop.dtype, (tc_uop,)+tuple(reduce_ranges), Ops.ADD)
self.ast = self.ast.substitute({reduceop: tc_uop})
return axes
return None
# helpers for hand_coded_optimizations
@property
def reduceop(self) -> UOp|None:
red = [x for x in self.ast.parents if x.op is Ops.REDUCE]
if not len(red): return None
return UOp(Ops.REDUCE_AXIS, red[0].dtype, red[0].src, (red[0].arg, ()))
@property
def bufs(self) -> list[UOp]: return [x for x in self.ast.toposort() if x.op is Ops.INDEX][::-1]
@property
def output_shape(self):
return [s if at not in {AxisType.REDUCE, AxisType.UNROLL, AxisType.GROUP_REDUCE} else 1 for s,at in zip(self.full_shape, self.axis_types)]
@property
def upcasted(self) -> int: return len(self.axes_of(AxisType.UPCAST, AxisType.UNROLL))
@property
def group_for_reduces(self) -> int: return len(self.axes_of(AxisType.GROUP_REDUCE))
def bufs_from_ast(ast:UOp, dname:str) -> list[Buffer]:
glbls = sorted([x for x in ast.parents if x.op is Ops.DEFINE_GLOBAL], key=lambda x: x.arg)
return [Buffer(dname, x.ptrdtype.size, x.dtype.base if not isinstance(x.dtype, ImageDType) else x.dtype) for x in glbls]
def apply_opts(ctx:Renderer, ast:UOp):
if ast.tag is not None: return None
k = Scheduler(ast, ctx)
k.convert_loop_to_global()
if ast.arg is not None and ast.arg.opts_to_apply is not None:
for opt in ast.arg.opts_to_apply: k.apply_opt(opt)
elif BEAM >= 1:
from tinygrad.codegen.opt.search import beam_search
rawbufs = bufs_from_ast(ast, ctx.device)
k = beam_search(k, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
elif not NOOPT and (ast.arg is None or ast.arg.applied_opts == ()):
from tinygrad.codegen.opt.heuristic import hand_coded_optimizations
# NOTE: hand_coded_optimizations doesn't support multiblock opts yet
if all(len(u.src) == 1 for u in ast.parents if u.op is Ops.LOAD):
k = hand_coded_optimizations(k)
return k.get_optimized_ast(name_override=ast.arg.name if ast.arg is not None and ast.arg.name != "test" else None)
pm_postrange_opt = PatternMatcher([
(UPat(Ops.SINK, name="ast"), apply_opts),
])

View File

@@ -0,0 +1,183 @@
from typing import cast
import functools, math, time, multiprocessing, traceback, signal, atexit
from dataclasses import replace
from tinygrad.uop.ops import sym_infer, AxisType, pyrender
from tinygrad.device import Device, Buffer, Compiler
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, time_to_str
from tinygrad.helpers import IGNORE_BEAM_CACHE
from tinygrad.codegen.opt import Opt, OptOps, KernelOptError
from tinygrad.tensor import Tensor
from tinygrad.engine.realize import CompiledRunner, get_program
from tinygrad.renderer import ProgramSpec
from tinygrad.codegen.opt.postrange import Scheduler
actions = [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7] for axis in range(8)]
actions += [Opt(op=OptOps.UNROLL, axis=axis, arg=amt) for amt in [0,4,7] for axis in range(5)]
actions += [Opt(op=OptOps.LOCAL, axis=axis, arg=amt) for amt in [2,3,4,8,13,16,29] for axis in range(6)]
actions += [Opt(op=OptOps.GROUPTOP, axis=axis, arg=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)]
actions += [Opt(op=OptOps.GROUP, axis=axis, arg=amt) for amt in [0,4,8,16] for axis in range(3)]
if getenv("BEAM_PADTO", 0): actions += [Opt(op=OptOps.PADTO, axis=axis, arg=amt) for amt in [32] for axis in range(7)]
actions += [Opt(op=OptOps.LOCAL, axis=0, arg=32), Opt(op=OptOps.LOCAL, axis=6, arg=2)]
actions += [Opt(op=OptOps.TC, axis=0, arg=(-1, 0, getenv("TC", 1)))]
# covers resnet kernels (3 global * 3 reduce)
actions += [Opt(op=OptOps.TC, axis=axis, arg=(-1, getenv("TC_OPT", 2), getenv("TC", 1))) for axis in range(9)]
actions += [Opt(op=OptOps.SWAP, axis=axis_0, arg=axis_1) for axis_0 in range(5) for axis_1 in range(axis_0+1, 5)]
actions += [Opt(op=OptOps.THREAD, axis=axis, arg=amt) for amt in [2,3,4,5,8,12,16,24,32,64] for axis in range(3)]
if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
def get_test_global_size(global_size, max_global_size, var_vals):
test_global_size = [sym_infer(sz, var_vals) for sz in global_size]
input_size = prod(test_global_size)
while prod(test_global_size) > max_global_size:
for j in range(len(global_size)-1,-1,-1):
if test_global_size[j] > 16:
test_global_size[j] //= 2
break
return test_global_size, input_size / prod(test_global_size)
def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[str, int], rawbufs:list[Buffer], early_stop:float|None=None,
allow_test_size:bool=True, max_global_size:int|None=65536, clear_l2=False, cnt=3, name="test") -> list[float]:
factor = 1
if allow_test_size and p.global_size is not None and max_global_size is not None:
global_size, factor = get_test_global_size(p.global_size, max_global_size, var_vals)
p = replace(p, global_size=global_size)
try: car = CompiledRunner(p, precompiled=lib)
except AssertionError: return [math.inf] * cnt
tms = []
input_bufs = [rawbufs[i] for i in car.p.globals]
for _ in range(cnt):
if clear_l2:
if hasattr(dev:=Device[p.device], 'invalidate_caches'): dev.invalidate_caches()
else:
with Context(DEBUG=0, BEAM=0, CAPTURING=0, TRACK_MATCH_STATS=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
tms.append(cast(float, car(input_bufs, var_vals, wait=True))*factor)
if early_stop is not None and early_stop < min(tms): break
return tms
class TimeoutException(Exception): pass
def timeout_handler(signum, frame):
if DEBUG >= 2: print("*** BEAM COMPILE TIMEOUT")
raise TimeoutException()
def _try_compile_linearized_w_idx(x:tuple[int,Scheduler], compiler:Compiler) -> tuple[int, tuple[ProgramSpec, bytes, float]|None]:
if hasattr(signal, "alarm"):
signal.signal(getattr(signal, 'SIGALRM'), timeout_handler)
# set timeout
signal.alarm(getenv("BEAM_TIMEOUT_SEC", 10))
ret = None
try:
p = get_program(x[1].copy().get_optimized_ast(name_override="test"), x[1].opts)
assert p.uops is not None, "uop list wasn't generated?"
if len(p.uops) >= (uops_max:=getenv("BEAM_UOPS_MAX", 3000)) > 0:
if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too many uops. {len(p.uops)=}, {uops_max=}")
raise RuntimeError("too many uops")
st = time.perf_counter()
prog = compiler.compile(p.src)
et = time.perf_counter() - st
ret = (p, prog, et)
except RuntimeError:
if DEBUG >= 4: traceback.print_exc()
except Exception as e:
if getenv("BEAM_STRICT_MODE"): raise e
finally:
if hasattr(signal, "alarm"): signal.alarm(0)
return x[0], ret
# workers should not open devices and should ignore ctrl c and should not launch VIZ
def _init_worker():
Context(ALLOW_DEVICE_USAGE=0, VIZ=0, TRACK_MATCH_STATS=0).__enter__()
signal.signal(signal.SIGINT, signal.SIG_IGN)
def _ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]: return [buf.ensure_allocated() if buf is not None else buf for buf in bufs]
# *** external API ***
# get dictionary of all possible actions
def get_kernel_actions(lin:Scheduler, include_0=True, candidates:list[Opt]|None=None) -> dict[int, Scheduler]:
acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
kernel_actions = (actions if candidates is None else candidates).copy()
for i,a in enumerate(kernel_actions):
if a.axis is not None and a.op is not OptOps.TC:
try: ax = lin.real_axis(a.op, a.axis)
except KernelOptError: continue
if (ax >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, a.axis, 0) in kernel_actions): continue
lin2 = lin.copy()
try:
lin2.apply_opt(a)
up, lcl, tc_up = 1, 1, prod(tc.dims)//tc.threads if hasattr(lin2, 'tensor_core') and (tc:=lin2.tensor_core) else 1
for s,c in zip(lin2.full_shape, lin2.axis_types):
if c in (AxisType.UPCAST, AxisType.UNROLL): up *= s
elif c in (AxisType.WARP, AxisType.LOCAL, AxisType.GROUP_REDUCE): lcl *= s
if up//tc_up > max_up or lcl > max_lcl:
if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too many upcast/local. {up//tc_up=}, {max_up=}, {lcl=}, {max_lcl=}")
continue
acted_lins[i+1] = lin2
except KernelOptError: pass
return acted_lins
beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG")
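# beam search over kernel opts: expand every kernel in the beam with get_kernel_actions, compile and time each
# candidate (in a worker pool when PARALLEL > 0), keep the `amt` fastest, and stop once the best improvement is
# below BEAM_MIN_PROGRESS. with CACHELEVEL >= 1 the winning applied_opts are cached on disk keyed by the ast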
def beam_search(lin:Scheduler, rawbufs:list[Buffer], amt:int, allow_test_size=True, disable_cache=IGNORE_BEAM_CACHE.value):
global beam_pool
key = {"ast": lin.ast.key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
if not disable_cache and CACHELEVEL >= 1 and (val:=diskcache_get("beam_search", key)) is not None:
ret = lin.copy()
for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
return ret
beam: list[tuple[Scheduler, float]] = [(lin, float("inf"))]
seen_libs = set()
default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV", "METAL", "HIP"} else 0
if beam_pool is None and (workers := getenv("PARALLEL", default_parallel)):
beam_pool = multiprocessing.get_context("spawn").Pool(workers, _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16))
@atexit.register
def close_pool(): beam_pool.close()
min_progress = getenv("BEAM_MIN_PROGRESS", 0.01)/1e6
if BEAM_DEBUG:
print("BEAM_SEARCH:")
print('\n'.join(pyrender(lin.ast.replace(arg=None))))
if DEBUG >= 2: print(f" 0.00s: from 1 -> 1 actions {lin.colored_shape()}")
try:
rawbufs = _ensure_buffer_alloc(rawbufs)
var_vals: dict[str, int] = {k.expr:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
exiting, st = False, time.perf_counter()
dev = Device[lin.opts.device]
while not exiting:
acted_lins: list[Scheduler] = flatten([get_kernel_actions(lin, include_0=False).values() for lin,_ in beam])
timed_lins: list[tuple[Scheduler, float]] = []
_compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler)
least_compute_ops = math.inf
for i,proc in (map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))):
if proc is None: continue
p, lib, compile_et = proc
if lib in seen_libs: continue
# filter out kernels that use 1000x more compute than the smallest
least_compute_ops = min(this_compute_ops:=sym_infer(p.estimates.ops, var_vals), least_compute_ops)
if least_compute_ops*1000 < this_compute_ops: continue
seen_libs.add(lib)
try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0,
allow_test_size=allow_test_size, clear_l2=hasattr(dev, 'invalidate_caches'))
except Exception as e:
if BEAM_DEBUG: print(f"BEAM failed for opts: {acted_lins[i].applied_opts}\n{e}")
if isinstance(e, RuntimeError): continue
raise
timed_lins.append((acted_lins[i], min(tms)))
if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(list, p.uops)):5d} uops {time_to_str(compile_et, w=12)} compile/{time_to_str(timed_lins[-1][1], w=12)} run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {time_to_str(timed_lins[-1][1], w=12)} {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501
# done
opts = sorted(timed_lins, key=lambda x: x[1])
exiting = len(opts) == 0 or (opts[0][1] < min_progress) or (len(beam) > 0 and ((beam[0][1]-opts[0][1]) < min_progress))
if not exiting: beam = opts[:amt]
elif len(opts) > 0 and opts[0][1] < beam[0][1]: beam = opts[:1]
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(time_to_str(beam[0][1], w=12), "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape()) # noqa: E501
except KeyboardInterrupt as e:
if beam_pool is not None: beam_pool.terminate()
raise e
if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
if BEAM_DEBUG: print(f"BEAM_SEARCH: final tm={time_to_str(beam[0][1], w=0)}, applied_opts={beam[0][0].applied_opts}")
return beam[0][0]

View File

@@ -0,0 +1,135 @@
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, resolve, sint
from tinygrad.helpers import all_same, prod, unwrap, colored
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View, strides_for_shape, get_contraction_with_reduce
from tinygrad.schedule.grouper import ALWAYS_CONTIGUOUS
from tinygrad.dtype import ImageDType, dtypes
merge_views = PatternMatcher([
# merge adjacent views
(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="v1"),), name="v2"), lambda v1,v2: v1.replace(arg=v1.arg+v2.arg)),
# replace MovementOps with VIEW
(UPat(GroupOp.Movement, src=(UPat.var("x"),), name="mop"), lambda mop,x: x.base.view(mop.st)),
# remove NOOP views
(UPat.var("x").view(name="view"),
lambda x,view: x if x.st is not None and x.op not in GroupOp.Defines and view.st.contiguous and view.shape == x.shape else None),
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}).view(name="view"),
lambda view: view.const_like(0) if (mask:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in mask) else None),
# only unmasked VIEW on CONST replaces the ShapeTracker
(UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),), name="view"),
lambda x,view: x.replace(src=(UOp(Ops.VIEW, x.dtype, x.src, view.arg),)) if all(v.mask is None for v in view.st.views) else None),
])
def reduce_push_add_ones(src:UOp, r:UOp, view:UOp):
# view is contiguous, higher rank than the reduce, and the shapes match once size-1 dims are removed
if unwrap(view.st).contiguous and len(r.shape) < len(view.shape) and \
tuple(x for x in r.shape if resolve(x != 1)) == tuple(x for x in view.shape if resolve(x != 1)):
new_shape: list[sint] = []
new_reduce_axis = []
if (contraction:=get_contraction_with_reduce(view.shape, r.shape, r.arg[1])) is None: return None
for i,pairs in enumerate(contraction):
new_shape_chunk = [view.shape[p] for p in pairs]
if i in r.arg[1]:
# if this is a reduce axis, we need a 1 in the view here to put it
assert len(new_shape_chunk) > 0
new_shape += [1]*(len(pairs)-1) + [src.shape[i]]
new_reduce_axis.append(len(new_shape)-1)
else:
# otherwise, pass through the new_shape_chunk
new_shape += new_shape_chunk
ret = r.replace(src=(src.reshape(tuple(new_shape)),), arg=(r.arg[0], tuple(new_reduce_axis))+r.arg[2:])
assert ret.shape == view.shape, f"shape mismatch on reduce_push_add_ones, {ret.shape} != {view.shape}"
return ret
return None
view_left = merge_views+PatternMatcher([
# view before elementwise and buffer ops
(UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.BIND, Ops.STORE, Ops.VALID, Ops.SINK}, name="e"),), name="view"),
lambda e,view: e.replace(src=tuple(s.view(view.st) for s in e.src))),
# if there's ones added after reduce, put this before the reduce
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), reduce_push_add_ones),
])
view_left_through_load = PatternMatcher([
# view before load
(UPat(Ops.VIEW, src=(UPat(Ops.LOAD, name="e"),), name="view"),
lambda e,view: e.replace(src=tuple(s.view(view.st) for s in e.src))),
])
def apply_swizzle(u:UOp) -> UOp: return graph_rewrite(u, view_left, name="Sub View Left")
# change reduceop axes and input ShapeTrackers, view gets replaced with a reshape.
def swizzle_reduceop(r:UOp, src:UOp, view:UOp, fuse=False):
# contiguous and same size can push to children
# if there's a reduce child, shapes match with ones removed
if unwrap(view.st).contiguous and view.size == r.size and \
(not (len(r.arg) == 3 and r.arg[2]) or # arg[2] = True is fuse marker
tuple((i,x) for i,x in enumerate(r.shape) if resolve(x != 1)) == tuple((i,x) for i,x in enumerate(view.shape) if resolve(x != 1))):
return None
# swizzle the input
input_st = ShapeTracker.from_shape(src.shape)
tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg)
prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):])
strides = strides_for_shape(rshape)
nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in unwrap(view.st).views]
new_view = tmp + ShapeTracker(tuple(nv))
swizzled_input = apply_swizzle(src.view(new_view))
# create a new reduceop
new_axis = tuple(range(len(view.shape), len(view.shape) + len(r.axis_arg)))
if fuse: red = UOp(Ops.REDUCE_AXIS, r.dtype, (swizzled_input.fuse(),), (r.arg[0], new_axis, True))
else: red = UOp(Ops.REDUCE_AXIS, r.dtype, (swizzled_input,), (r.arg[0], new_axis))
return red.reshape(view.shape)
def reduceop_view_right(src:UOp, v:UOp, r:UOp):
assert unwrap(v.st).contiguous and v.size == src.size, f"can't compute new axis for {src.shape} -> {r.shape}"
new_axis = [i for i,(s,u) in enumerate(zip(src.shape, r.shape)) if s != u]
return src.r(r.arg[0], tuple(new_axis)).reshape(r.shape)
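# move views after elementwise ops: if any src is a VIEW of a non-ALWAYS_CONTIGUOUS base, apply the op at the
# (common-size) base shape and reshape the result back to the original shape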
def elementwise_view_right(root:UOp):
if not (swizzles:=[x for x in root.src if x.op is Ops.VIEW and x.base.op not in ALWAYS_CONTIGUOUS]): return None
assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"
# place view after applying the elementwise op
new_st = ShapeTracker.from_shape(swizzles[0].base.shape)
new_src = [x.base if x.base.shape==new_st.shape else apply_swizzle(x.view(new_st)) for x in root.src]
# reshape to match downstream shapes
return root.replace(src=tuple(new_src)).reshape(root.shape)
# push VIEW to children
view_right = merge_views+PatternMatcher([
# push a non contiguous ShapeTracker through reduceop
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop),
# apply view after reduceops
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="src"),), name="v"),), name="r"), reduceop_view_right),
# apply view after elementwise ops
(UPat(GroupOp.All-{Ops.SINK, Ops.REDUCE_AXIS}, name="root"), elementwise_view_right),
# merge axes for double reduce (invert of SPLIT_REDUCEOP=1)
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] is r2.arg[0] else None),
# remove view from sink
(UPat(Ops.VIEW, name="v").sink(name="sink"), lambda v,sink: v.src[0].sink(arg=sink.arg)),
])
def check_load_st(glbl:UOp, view:UOp):
if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
# if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
if len(st.views) == 1 and st.shrink(tuple((0,1) if stride == 0 else (0,s) for s,stride in zip(st.shape, st.views[0].strides))).contiguous: return
# if it has a single view and it's equal when you shrink a contig, it's fine
if len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): return
# otherwise, it's not fine
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
fix_kernel_ops = view_left_through_load+PatternMatcher([
# add view to LOAD and STORE
(UPat(Ops.DEFINE_GLOBAL, name="g").load(), lambda g: g.view(g.st).load()),
(UPat(Ops.DEFINE_GLOBAL, name="g").store(UPat.var('x')), lambda g,x: g.view(g.st).store(x)),
# VALID
(UPat(Ops.VIEW, src=(UPat.cvar(),), name="self"),
lambda self: UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)),
# no ImageDType after index
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL, Ops.VIEW, Ops.INDEX}, name="x"),
lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly
(UPat(Ops.LOAD, src=(UPat.var("glbl").view(name="view"),)), check_load_st),
])

tinygrad/codegen/opt/tc.py
View File

@@ -0,0 +1,136 @@
import math, functools
from dataclasses import dataclass
from tinygrad.dtype import DType, dtypes
from tinygrad.helpers import getenv
@dataclass(frozen=True)
class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
dims: tuple[int,int,int] # N, M, K
threads: int # number of threads that construct the warp
elements_per_thread: tuple[int, int, int] # elements per-thread to load/store from A/B/C
dtype_in: DType # dtype for A and B
dtype_out: DType # dtype for C and D
opts: tuple[str, ...] # ordered tuple of "ux" or "lx" specifying kernel opts to perform. "ux" upcasts dim x and "lx" localizes dim x
# (local_swizzle, upcast_swizzle, reduce_swizzle)
# l<num> is the num axis of the locals, similar for u<num> and upcasts, r<num> and reduces
swizzle: tuple[tuple[tuple[str, ...], tuple[str, ...], tuple[str, ...]], tuple[tuple[str, ...], tuple[str, ...], tuple[str, ...]]]
@functools.cache # pylint: disable=method-cache-max-size-none
def _remaps(self) -> list[dict[str, str]]:
local_axes, upcast_axes, reduce_axes = len(self.get_local_axes()), len(self.get_upcast_axes()), len(self.get_reduce_axes())
fwd_st = [f"l{i}" for i in range(local_axes)] + [f"u{i}" for i in range(upcast_axes)] + [f"r{i}" for i in range(reduce_axes)]
return [dict(zip(fwd_st, sum(s, ()))) for s in self.swizzle]
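# the two permutations (one per WMMA input) that reorder a kernel shape string into the swizzled layouts above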
def permutes_for_shape_str(self, shape_str:list[str]) -> tuple[tuple[int, ...], tuple[int, ...]]:
ret = [[shape_str.index(remap[ss]) if ss in remap else i for i,ss in enumerate(shape_str)] for remap in self._remaps()]
return tuple(ret[0]), tuple(ret[1])
@functools.cache # pylint: disable=method-cache-max-size-none
def base_shape_str(self) -> list[str]:
ret = []
cnt = {'u': 0, 'l': 0}
for opt in self.opts:
ret.append(f"{opt[0]}{cnt[opt[0]]}")
cnt[opt[0]] += 1
# assumes you do the UNROLL after the opts
return ret + [f"r{i}" for i in range(len(self.get_reduce_axes()))]
def get_reduce_axes(self): return [(i, 2) for i in range(int(math.log2(self.dims[2])))]
def get_upcast_axes(self): return [opt for opt in self.opts if opt[0] == "u"]
def get_local_axes(self): return [opt for opt in self.opts if opt[0] == "l"]
def base_upcast_axes(self):
# this is defined in the swizzle. first we use the upcast axes, then the reduce
return ([f"r{i}" for i in range(len(self.get_reduce_axes()))] + [f"u{i}" for i in range(len(self.get_upcast_axes()))])[::-1]
def __str__(self): return "_".join(["WMMA"] + list(map(str, self.dims)) + [self.dtype_in.name, self.dtype_out.name])
def __post_init__(self):
# all axes have size 2, <local> <reduce> <upcast> is the order
local_axes, upcast_axes, reduce_axes = len(self.get_local_axes()), len(self.get_upcast_axes()), len(self.get_reduce_axes())
assert self.dims[0] * self.dims[1] == 2**(local_axes + upcast_axes), \
f"N({self.dims[0]}) x M({self.dims[1]}) != local({2**local_axes}) x upcast({2**upcast_axes}) with opts({self.opts})"
assert 2**local_axes == self.threads, f"{self.threads} threads construct the warp but found {2**local_axes} in {self.opts}"
assert 2**upcast_axes == self.elements_per_thread[2], \
f"{self.elements_per_thread[2]} elements from C are processed per thread but found {2**upcast_axes} in {self.opts}"
# check dims match opts
assert self.dims[0] == 2**len(gd:=[x for x in self.opts if x[1] == '0']), f"opts wrong on dims[0], {self.dims[0]} vs {gd}"
assert self.dims[1] == 2**len(gd:=[x for x in self.opts if x[1] == '1']), f"opts wrong on dims[1], {self.dims[1]} vs {gd}"
# NOTE: the K opt is implicitly set by the dim
# check swizzle
assert len(self.swizzle[0]) == 3 and len(self.swizzle[1]) == 3, "swizzle has wrong part count"
assert len(self.swizzle[0][0]) == len(self.swizzle[1][0]) == local_axes, "local swizzle size is wrong"
assert len(self.swizzle[0][1]) == len(self.swizzle[1][1]) == upcast_axes, "upcast swizzle size is wrong"
assert len(self.swizzle[0][2]) == len(self.swizzle[1][2]) == reduce_axes, "reduce swizzle size is wrong"
assert all(len(s) == local_axes+upcast_axes+reduce_axes for s in self._remaps()), "remaps are the wrong size"
# check elements_per_thread
un, ln = 0, 0
zero_stride_0 = []
zero_stride_1 = []
for o in self.opts:
if o[1] == '0': zero_stride_0.append(o[0] + str(un if o[0] == 'u' else ln))
if o[1] == '1': zero_stride_1.append(o[0] + str(un if o[0] == 'u' else ln))
if o[0] == 'u': un += 1
if o[0] == 'l': ln += 1
# NOTE: all the zero_stride dims can be placed in any order in the swizzle
upcasted_0 = [x for x in (self.swizzle[0][1] + self.swizzle[0][2]) if x not in zero_stride_0 and x[0] != 'l']
upcasted_1 = [x for x in (self.swizzle[1][1] + self.swizzle[1][2]) if x not in zero_stride_1 and x[0] != 'l']
assert 2**len(upcasted_0) == self.elements_per_thread[0], f"mismatch in elements_per_thread[0], {upcasted_0} vs {self.elements_per_thread[0]}"
assert 2**len(upcasted_1) == self.elements_per_thread[1], f"mismatch in elements_per_thread[1], {upcasted_1} vs {self.elements_per_thread[1]}"
# ***** NVIDIA *****
cuda_tc_opts = ("u0","l0","l0","l1","l1","l1","u1") # shared by all shapes with M=16 N=8
# https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-multiply-accumulate-instructions
cuda_81616 = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
swizzle=((('r1', 'r2', 'l2', 'l3', 'l4'), ('u1', 'r3'), ('l0', 'l1', 'u0', 'r0')),
(('r1', 'r2', 'u0', 'l0', 'l1'), ('r0', 'r3'), ('l2', 'l3', 'l4', 'u1'))))
for di,do in [(dtypes.half,dtypes.float), (dtypes.bfloat16,dtypes.float), (dtypes.half,dtypes.half)]]
cuda_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
swizzle=((('r1', 'r2', 'l2', 'l3', 'l4'), ('r0', 'u1'), ('l0', 'l1', 'u0')),
(('r1', 'r2', 'u0', 'l0', 'l1'), ('u1', 'r0'), ('l2', 'l3', 'l4'))))
for di,do in [(dtypes.half,dtypes.float), (dtypes.half,dtypes.half)]]
cuda_8168_tf32 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.float, dtype_out=dtypes.float, opts=cuda_tc_opts,
swizzle=((('r0', 'r1', 'l2', 'l3', 'l4'), ('u1', 'r2'), ('l0', 'l1', 'u0')),
(('r0', 'r1', 'u0', 'l0', 'l1'), ('u1', 'r2'), ('l2', 'l3', 'l4'))))]
cuda_sm80: list[TensorCore] = cuda_81616 + cuda_8168_f16
if getenv("ALLOW_TF32", 0): cuda_sm80 += cuda_8168_tf32
cuda_sm75: list[TensorCore] = cuda_8168_f16
# ***** AMD *****
# https://gpuopen.com/learn/wmma_on_rdna3/
amd_rdna3 = [TensorCore(dims=(16,16,16), threads=32, elements_per_thread=(16,16,8), dtype_in=di, dtype_out=do,
opts=("l0","l0","l0","l0","l1","u1","u1","u1"),
swizzle=((('l4', 'u0', 'u1', 'u2', 'l0'), ('r1', 'r2', 'r3'), ('l1', 'l2', 'l3', 'r0')),
(('l0', 'l1', 'l2', 'l3', 'l4'), ('r1', 'r2', 'r3'), ('u0', 'u1', 'u2', 'r0'))))
for di,do in [(dtypes.half,dtypes.float),(dtypes.half,dtypes.half),(dtypes.bfloat16,dtypes.float)]]
amd_rdna4 = [TensorCore(dims=(16,16,16), threads=32, elements_per_thread=(8,8,8), dtype_in=di, dtype_out=do,
opts=("l0","l0","l0","l0","u1","u1","u1","l1"),
swizzle=((('u0', 'u1', 'u2', 'l4', 'r2'), ('r0', 'r1', 'r3'), ('l0', 'l1', 'l2', 'l3')),
(('l0', 'l1', 'l2', 'l3', 'r2'), ('r0', 'r1', 'r3'), ('l4', 'u0', 'u1', 'u2'))))
for di,do in [(dtypes.half,dtypes.float),(dtypes.half,dtypes.half),(dtypes.bfloat16,dtypes.float),(dtypes.bfloat16,dtypes.bfloat16)]]
# https://gpuopen.com/learn/amd-lab-notes/amd-lab-notes-matrix-cores-readme
amd_cdna = [TensorCore(dims=(16,16,16), threads=64, elements_per_thread=(4,4,4), dtype_in=di, dtype_out=do,
opts=("l0","l0","l0","l0","u1","u1","l1","l1"),
swizzle=((('u0', 'u1', 'l4', 'l5', 'r2', 'r3'), ('r0', 'r1'), ('l0', 'l1', 'l2', 'l3')),
(('l0', 'l1', 'l2', 'l3', 'r2', 'r3'), ('r0', 'r1'), ('l4', 'l5', 'u0', 'u1'))))
for di,do in [(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)]]
# ***** Apple Metal *****
metal = [TensorCore(dims=(8,8,8), threads=32, elements_per_thread=(2,2,2), dtype_in=di, dtype_out=do,
opts=("u0","l0","l1","l1","l0","l1"),
swizzle=((('r1', 'l1', 'l2', 'r2', 'l4'), ('r0',), ('u0', 'l0', 'l3')),
(('l0', 'r0', 'r1', 'l3', 'r2'), ('u0',), ('l1', 'l2', 'l4'))))
for di,do in [(dtypes.float,dtypes.float),(dtypes.half,dtypes.float),
(dtypes.half,dtypes.half),(dtypes.bfloat16,dtypes.float),(dtypes.bfloat16,dtypes.bfloat16)]]
# ***** Apple AMX *****
amx = [TensorCore(dims=(sz,sz,1), threads=1, elements_per_thread=(sz,sz,sz*sz), dtype_in=dt, dtype_out=dt,
swizzle=(((), ('u0', 'u1', 'u2', 'u3', 'u4', 'u5', 'u6', 'u7'), ()),
((), ('u4', 'u5', 'u6', 'u7', 'u0', 'u1', 'u2', 'u3'), ())),
opts=("u0","u0","u0","u0","u1","u1","u1","u1")) for dt,sz in [(dt, 64 // dt.itemsize) for dt in [dtypes.float]]]
# ***** Intel ****
intel = [TensorCore(dims=(8,8,16), threads=8, elements_per_thread=(16,16,8), dtype_in=dtypes.half, dtype_out=dtypes.float,
opts=("l0","l0","l0","u1","u1","u1"),
swizzle=((('r1', 'r2', 'r3'), ('u0', 'u1', 'u2'), ('l0', 'l1', 'l2', 'r0')),
(('l0', 'l1', 'l2'), ('r1', 'r2', 'r3'), ('u0', 'u1', 'u2', 'r0'))))]

View File

@@ -0,0 +1,67 @@
from tinygrad.dtype import dtypes, least_upper_dtype
from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat
from tinygrad.uop.symbolic import symbolic
# **** this is the "quantization preprocessor", it makes ONNX quantized models, and probably also others, actually use ints ****
# this is badly tested and low quality. remove it?
FP = (1 << 15)
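# FP is the fixed-point scale (2**15) used below to rewrite float multiply/adds into integer arithmetic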
pm_quant = symbolic+PatternMatcher([
# cast after add/mul
(UPat.var("x").cast(dtypes.float32) + UPat.var("y").cast(dtypes.float32),
lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))+y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
(UPat.var("x").cast(dtypes.float32) * UPat.var("y").cast(dtypes.float32),
lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))*y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
# masked MUL after masked ADD
((UPat.var("x") + UPat.var("v").where(UPat.var('cadd'), UPat(Ops.CONST, arg=0))) * UPat.var("v").where(UPat.var('cmul'), UPat(Ops.CONST, arg=0)),
lambda x,v,cadd,cmul: x*v.where(cmul, 0)+v.where(cadd*cmul, 0)),
# MUL after reduce
(UPat(Ops.REDUCE_AXIS, src=(UPat.var("x") * UPat.cvar("c"),), name="r"), lambda x,c,r: r.replace(src=(x,))*c.arg),
# CAST after reduce (doesn't work if it's a size change)
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, src=(UPat.var("x"),)),), name="r"),
lambda x,r: r.replace(dtype=x.dtype, src=(x,)).cast(r.dtype) if dtypes.is_float(r.dtype) else None),
# x*c1 + y*c2 -> (x+y)*c1 (if c1 and c2 are close floats)
(UPat.var("x")*UPat.cvar("c1", dtype=dtypes.floats) + UPat.var("y")*UPat.cvar("c2", dtype=dtypes.floats),
lambda x,y,c1,c2: (x+y)*c1 if abs(c1.arg-c2.arg) < 1e-9 else None),
# mul 0 * c1 is 0
(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
UPat(Ops.LOAD, src=(UPat().view(name="v"),)).cast(dtypes.int).cast(dtypes.float).named("ld"), lambda ld,v,c1: ld*c1),
# mul (with plus) 0 * c1 is 0
(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
(UPat(Ops.LOAD, src=(UPat().view(name="v"),)).cast(dtypes.int) + \
UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar(), UPat(Ops.CONST, arg=0))).cast(dtypes.float).named("ld"),
lambda ld,v,c1: ld*c1),
# const push through add
((UPat.var("x")*UPat.cvar("c1") + UPat.var("y")*UPat.cvar("c2")) * UPat.cvar("c3"), lambda x,y,c1,c2,c3: (x*c1*c3) + (y*c2*c3)),
# fixed point mult, replace (x.float()*c1+c2).int() with an int expression
((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("cc")).cast(dtypes.int),
lambda x,c1,cc: ((x*(c1*FP).cast(x.dtype) + (cc*FP).cast(x.dtype)) // FP).cast(dtypes.int)),
# fixed point mult, replace (x.float()*c1 + y.float()*c2)*cc.int() with an int expression
((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("y").cast(dtypes.float)*UPat.var("c2")+UPat.var("cc")).cast(dtypes.int),
lambda x,c1,y,c2,cc: ((x*(c1*FP).cast(x.dtype) + y.cast(x.dtype)*(c2*FP).cast(x.dtype) + (cc*FP).cast(x.dtype)) // FP).cast(dtypes.int)),
# where move
(UPat.var("valid").where(UPat.var("yes"), UPat(Ops.CONST, arg=0))*UPat.var("mul"), lambda valid, yes, mul:
(yes*mul*valid.where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))) if yes.op is not Ops.CONST or yes.arg != 1 else None),
((UPat.var("x")*UPat.cvar("c"))*(UPat.var().where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)).named("v")), lambda x,c,v: (x*v)*c),
(UPat.var("x").cast().named('c') * UPat.var('valid').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)), lambda x,c,valid:
(x*valid.where(UOp.const(x.dtype, 1), UOp.const(x.dtype, 0))).cast(c.dtype)),
((UPat.var('x') * UPat.var('v1').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)) *
UPat.var('v2').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0))).named("mul"), lambda x, mul, v1, v2:
x * (v1&v2).where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))),
# where on two adds
(UPat.var("x") + UPat.var("v").where(UPat.var("a0"), UPat.var("a1")) + UPat.var("v").where(UPat.var("b0"), UPat.var("b1")),
lambda x,v,a0,a1,b0,b1: x + v.where(a0+b0, a1+b1)),
# split REDUCE into multiple reduces (who remembers FOIL?)
(UPat(Ops.REDUCE_AXIS, src=((UPat(Ops.CAST, name="v1")+UPat.var("c1")) * UPat(Ops.CAST, name="v2"),), name="r"),
lambda v1,v2,c1,r: r.replace(src=(v1*v2,)) + r.replace(src=(c1*v2,))),
(UPat(Ops.REDUCE_AXIS, src=((UPat(Ops.CAST, name="v1")+UPat.var("c1")) * (UPat(Ops.CAST, name="v2",)+UPat.var("c2")),), name="r"),
lambda v1,v2,c1,c2,r: r.replace(src=(v1*v2,)) + r.replace(src=(c2*v1,)) + r.replace(src=(c1*v2,)) + r.replace(src=(c1*c2,))),
])

View File

@@ -0,0 +1,120 @@
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, graph_rewrite, _substitute, range_start
from tinygrad.uop.symbolic import symbolic_flat, sym
from tinygrad.helpers import partition
from tinygrad.dtype import dtypes
def flatten_range(r:UOp):
off = range_start[r.op]
rngs = r.src[off:]
if not len(rngs): return None
new_rngs = [x for x in UOp.sink(*rngs).toposort() if x.op is Ops.RANGE]
return r.replace(src=r.src[:off]+tuple(new_rngs))
pm_flatten_range = PatternMatcher([
# real ranges only
(UPat((Ops.REDUCE, Ops.STORE), name="r"), flatten_range),
])
def count_divmod(x:UOp): return len([u for u in x.toposort() if u.op in {Ops.IDIV, Ops.MOD}])
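# try to merge adjacent same-type RANGEs r0,r1 into a single range of size s0*s1, rewriting r0 -> new//s1 and
# r1 -> new%s1; the merge is kept only if it doesn't increase the number of IDIV/MOD ops in the graph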
def simplify_merge_adjacent(u:UOp) -> UOp|None:
i = range_start[u.op]
while i < len(u.src)-1:
r0, r1 = u.src[i], u.src[i+1]
# check same type
if r0.arg[-1] == r1.arg[-1]:
s0, s1 = r0.src[0], r1.src[0]
# do the merge
new_range = r0.replace(src=(s0*s1,))
nidx = graph_rewrite(u, _substitute+symbolic_flat+pm_flatten_range, ctx={r0:new_range//s1, r1:new_range%s1},
name=f"check_merge_{r0.arg[0]}_{r1.arg[0]}")
# check if it simplifies
if count_divmod(nidx) <= count_divmod(u):
u = nidx
continue
i += 1
return u
pm_simplify_ranges = PatternMatcher([
(UPat((Ops.STORE, Ops.REDUCE), name="u"), simplify_merge_adjacent),
])
# **** reduce simplification ****
def reduce_rangeless(red:UOp):
# TODO: share code with reduce_unparented
if red.arg not in {Ops.ADD, Ops.MAX}: return None
if red.src[0].dtype != red.dtype: return None
if any(x.op in {Ops.RANGE} for x in red.src[0].toposort()): return None
ret = red.src[0]
if red.arg is Ops.ADD:
for r in red.src[1:]:
ret = ret * r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
return ret
def no_range(u:UOp) -> bool: return not any(x.op is Ops.RANGE for x in u.sparents)
pm_reduce_collapse = PatternMatcher([
# lift x+y out of reduce on lt
((UPat.var("x")+UPat.var("y")).or_casted() < UPat.var("c"), lambda x,y,c: (x < (c.cast(y.dtype)-y)) if no_range(y) and no_range(c) else None),
# lift x*y out of reduce
((UPat.var("x")*UPat.var("y")) < UPat.var("c"),
lambda x,y,c: (x < ((c+y-1) // y)) if no_range(y) and no_range(c) and y.vmin > 0 else None),
# lift x+y out of reduce on ne
((UPat.var("x")+UPat.var("y")).or_casted() != UPat.var("c"), lambda x,y,c: (x != (c.cast(y.dtype)-y)) if no_range(y) and no_range(c) else None),
# fold the range
((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(0, UPat.cvar("val")).reduce(arg=Ops.ADD, allow_any_len=True),
lambda r,cut,val: (r.src[0]-cut).maximum(0).minimum(r.src[0]).cast(val.dtype) * val),
((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(UPat.cvar("val"), 0).reduce(arg=Ops.ADD, allow_any_len=True),
lambda r,cut,val: cut.maximum(0).minimum(r.src[0]).cast(val.dtype) * val),
# REDUCE on ADD
((UPat.var("x")+UPat.var("y")).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
lambda x,y,r: x.reduce(*r.src[1:], arg=Ops.ADD) + y.reduce(*r.src[1:],arg=Ops.ADD)),
# MUL casted bool
((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast().or_broadcasted(name="b")),
lambda x,gate,b=None: gate.broadcast(x.dtype.count).where(x, 0) if b is not None else gate.where(x, 0)),
# WHERE on LOAD (works on max too)
(UPat.var("gate").where(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load(), 0).reduce(arg=Ops.ADD, allow_any_len=True),
lambda buf,idx,gate: buf.index(idx, gate).load()),
(UPat.var("gate").where(0, UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load()).reduce(arg=Ops.ADD, allow_any_len=True),
lambda buf,idx,gate: buf.index(idx, gate.logical_not()).load()),
# INDEX on RANGE / gated RANGE
(UPat.var("buf").index(UPat.var("expr"), UPat.var("idx").eq(UPat(Ops.RANGE, name="r").or_casted())),
lambda buf,r,idx,expr: buf.index(expr.substitute({r:idx.cast(r.dtype)}), (idx.cast(r.dtype) >= 0) & (idx.cast(r.dtype) < r.src[0]))),
# AND on WHERE
((UPat.any(UPat(Ops.DEFINE_VAR, name="x"), UPat(Ops.DEFINE_VAR).gep(name="x")) & UPat.var("y")) \
.where(UPat.cvar("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
lambda x,y,c,r: y.where(c, 0).reduce(*r.src[1:], arg=Ops.ADD)*x.cast(c.dtype)),
# remove REDUCEs that no longer have a RANGE in the src
(UPat(Ops.REDUCE, name="red"), reduce_rangeless),
])+sym
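# replace everything the REDUCE reads that doesn't depend on its ranges with DEFINE_VARs, try to rewrite the
# reduce into a closed form with pm_reduce_collapse, and accept the result only if no RANGE survives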
def reduce_collapse(red:UOp):
included, not_included = partition(red.parents, lambda x: any(y in x.sparents for y in red.src[1:]))
if any(x.op in {Ops.STORE, Ops.REDUCE} for x in included): return None
replaces: dict[UOp, UOp] = {}
for u in included:
for s in u.src:
if s in not_included and s not in replaces and s.op not in {Ops.CONST, Ops.VCONST, Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR}:
replaces[s] = UOp(Ops.DEFINE_VAR, dtype=s.dtype, arg=(f'in{len(replaces)}', s.vmin, s.vmax))
collapse_fxn = red.substitute(replaces)
sink = graph_rewrite(collapse_fxn, pm_reduce_collapse, name="reduce_collapse")
if any(x.op is Ops.RANGE for x in sink.toposort()): return None
return sink.substitute({v:k for k,v in replaces.items()})
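# ranges in a REDUCE's srcs that the reduced expression never reads can be dropped: for ADD multiply the result
# by the range size, for MUL raise it to that power, for MAX they just disappear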
def reduce_unparented(red:UOp):
if red.arg not in {Ops.ADD, Ops.MAX, Ops.MUL}: return None
reduce_parented, reduce_unparented = partition(red.src[1:], lambda x: x in red.src[0].sparents)
if len(reduce_unparented) == 0: return None
ret = red.replace(src=(red.src[0],)+tuple(reduce_parented)) if len(reduce_parented) or red.dtype != red.src[0].dtype else red.src[0]
if red.arg is Ops.ADD:
for r in reduce_unparented: ret = ret * r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
if red.arg is Ops.MUL:
for r in reduce_unparented: ret = ret ** r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
return ret
pm_reduce_simplify = PatternMatcher([
# remove any ranges from a REDUCE that aren't referenced in the reduce source
(UPat(Ops.REDUCE, name="red"), reduce_unparented),
# remove REDUCE without loads (generic arange opt / indexing). TODO: support multi range
(UPat(Ops.REDUCE, src=(UPat(), UPat()), name="red"), reduce_collapse),
])