Release 260111
tinygrad/codegen/__init__.py (new file, 131 lines)
@@ -0,0 +1,131 @@
from typing import Any, Callable
import functools
from dataclasses import dataclass
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype
from tinygrad.uop.spec import type_verify
from tinygrad.renderer import Renderer

# import all pattern matchers here
from tinygrad.codegen.lowerer import pm_lowerer, get_index
from tinygrad.codegen.quantize import pm_quant
from tinygrad.codegen.gpudims import pm_add_gpudims
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing
from tinygrad.uop.decompositions import get_late_rewrite_patterns
from tinygrad.codegen.late.expander import migrate_indexing, expander, pm_pre_expander
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
  ReduceContext, correct_load_store, pm_render
from tinygrad.codegen.late.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
from tinygrad.codegen.opt.swizzler import view_left, view_right, fix_kernel_ops
from tinygrad.codegen.opt.postrange import pm_postrange_opt
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_reduce_simplify, pm_flatten_range
from tinygrad.schedule.rangeify import pm_add_buffers, rangeify_codegen

@dataclass
class RewriteStep:
  pm: PatternMatcher
  ctx: Callable[[UOp], Any]|None = None
  name: str|None = None
  bottom_up: bool = False
  def __call__(self, sink:UOp):
    return graph_rewrite(sink, self.pm, ctx=self.ctx(sink) if self.ctx is not None else None, name=self.name, bottom_up=self.bottom_up)

def apply_rewrites(sink:UOp, rewrites:list[RewriteStep]): return functools.reduce(lambda x,f: f(x), rewrites, sink)

rewrites_for_views = [
  RewriteStep(view_left, name="Main View Left"),
  RewriteStep(view_right, name="Main View Right"),
  RewriteStep(view_left+fix_kernel_ops, bottom_up=True, name="Finalize Kernel"),
]

rewrites_for_linearizer = [
  RewriteStep(block_create, ctx=BlockContext.from_sink, name="Linearizer: Create Blocks", bottom_up=True),
  RewriteStep(pm_blockend_merge, name="Linearizer: Merge Blockends"),
  RewriteStep(block_merge, name="Linearizer: Merge Blocks"),
  RewriteStep(pm_finalize, name="Linearizer: Finalize")]

def get_rewrites_for_renderer(opts:Renderer, optimize:bool=True, linearizer:bool=True) -> list[RewriteStep]:
  # cache with the values of the context vars
  return _get_rewrites_for_renderer(opts, optimize, linearizer, QUANTIZE.value, DEVECTORIZE.value, TRANSCENDENTAL.value)

@functools.cache
def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) -> list[RewriteStep]:
  # ** lowerer (rewrite_shapetracker_with_index) **
  ret: list[RewriteStep] = []

  if optimize:
    # view pushing
    ret.extend(rewrites_for_views)

    # lowerer first
    if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize"))
    ret.append(RewriteStep(pm_lowerer, get_index, name="lowerer", bottom_up=True))

    # symbolic (NOTE: this is a requirement for pm_simplify_ranges to be correct)
    ret.append(RewriteStep(sym+pm_flatten_range, name="initial symbolic"))

    # optimize (schedule) the AST
    ret.append(RewriteStep(pm_simplify_ranges, name="simplify ranges"))
    ret.append(RewriteStep(pm_reduce_simplify, name="simplify reduces"))
    ret.append(RewriteStep(pm_postrange_opt, ctx=lambda _: opts, name="post optimize ast"))

  # ** expander (expand_rewrite) **
  ret.append(RewriteStep(sym+migrate_indexing, name="postopt symbolic"))

  # expand
  ret.append(RewriteStep(sym+pm_pre_expander+expander, name="expander"))

  # add locals
  ret.append(RewriteStep(pm_add_buffers+rangeify_codegen, name="add local buffers"))

  # ** devectorizer (full_graph_rewrite) **
  # remove reduce
  ret.append(RewriteStep(pm_reduce+gep_pushing, lambda _: ReduceContext(), name="remove_reduce"))

  # add gpu dims (late). this works after devectorize, but it's faster here
  ret.append(RewriteStep(pm_add_gpudims, lambda _: opts, name="add gpudims"))

  # devectorize (TODO: does this need opts?)
  if _DEVECTORIZE >= 2: pm_devectorize = sym+load_store_folding+load_store_indexing
  elif _DEVECTORIZE: pm_devectorize = sym+devectorize+load_store_folding+correct_load_store+load_store_indexing
  else: pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing
  ret.append(RewriteStep(pm_devectorize, lambda _: opts, name="devectorize"))

  supported_ops = tuple(opts.code_for_op.keys())
  extra_matcher = opts.extra_matcher if opts.extra_matcher is not None else PatternMatcher([])

  # lower the index dtype to a concrete int
  ret.append(RewriteStep(pm_lower_index_dtype+load_store_indexing, lambda _: opts.device, name="lower all index dtypes"))

  # optional pre matcher
  if opts.pre_matcher is not None: ret.append(RewriteStep(opts.pre_matcher, name="pre_matcher"))

  # decompositions
  pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, _TRANSCENDENTAL>=2)
  ret.append(RewriteStep(pm_decomp, lambda _: opts.device, name="decompositions"))

  # final rules for the renderer (without sym)
  pm_final_rewrite = pm_decomp+pm_render+extra_matcher
  ret.append(RewriteStep(pm_final_rewrite, lambda _: opts.device, name="final rewrite"))

  # return the list (with optional linearizer)
  return ret + (rewrites_for_linearizer if linearizer else [])

def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, optimize:bool=True, linearizer:bool=False) -> UOp:
  return apply_rewrites(sink, get_rewrites_for_renderer(opts if opts is not None else Renderer(), optimize, linearizer))

def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]:
  """
  Function to transform the Kernel UOp graph into a linearized program.

  Args:
    sink: The Ops.SINK rooting the Kernel graph.
    opts: The Renderer (can change how things are processed, fix this).

  Returns:
    Linear program in UOps.
  """

  lst = list(full_rewrite_to_sink(sink, opts, optimize=sink.tag is None, linearizer=True).arg.lst)
  if __debug__: type_verify(lst)
  return lst
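apply_rewrites above is just a left fold of graph_rewrite calls over the sink. A minimal, self-contained sketch of the same fold pattern, with plain integer transforms standing in for PatternMatchers (all names here are illustrative, not part of tinygrad):

import functools
from dataclasses import dataclass
from typing import Callable

@dataclass
class Step:
  fn: Callable[[int], int]   # stands in for a PatternMatcher + graph_rewrite
  name: str
  def __call__(self, x:int) -> int: return self.fn(x)

def apply_steps(x:int, steps:list[Step]) -> int:
  # same shape as apply_rewrites: functools.reduce threads the value through every step in order
  return functools.reduce(lambda acc, step: step(acc), steps, x)

pipeline = [Step(lambda v: v+1, "lower"), Step(lambda v: v*2, "expand"), Step(lambda v: v-3, "render")]
assert apply_steps(10, pipeline) == 19   # ((10+1)*2)-3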
tinygrad/codegen/gpudims.py (new file, 94 lines)
@@ -0,0 +1,94 @@
import math
from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType, sint_to_uop
from tinygrad.helpers import all_int, dedup
from tinygrad.dtype import dtypes
from tinygrad.shape.view import get_contraction
from tinygrad.renderer import Renderer

def _group_dims(dims:tuple[sint, ...], max_sizes:tuple[int, ...]):
  # TODO: symbolic shape
  if not all_int(dims): return dims
  while len(dims) > len(max_sizes) or any(d > m for d,m in zip(dims, max_sizes)):
    for i,m in enumerate(max_sizes):
      if i < (len(dims)-1) and dims[i] * dims[i+1] <= m:
        dims = dims[:i] + (dims[i]*dims[i+1],) + dims[i+2:]
        break
    else: return None
  return dims

def _split_dims(dims, max_sizes):
  if all(d <= m for d,m in zip(dims, max_sizes)): return dims
  _dims = list(dims) + [1]*(3-len(dims))
  for i in range(len(_dims)):
    while _dims[i] > max_sizes[i]:
      div = next((d for d in range(2, math.ceil(math.sqrt(_dims[i])) + 1) if (_dims[i] % d) == 0), 1)
      if div == 1: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
      _dims[i], _dims[(i+1)%len(_dims)] = _dims[i]//div, _dims[(i+1)%len(_dims)]*div
  return tuple(_dims[:2] if _dims[2] == 1 else _dims[0] if _dims[1:3] == [1,1] else _dims)

def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|None, reverse=False) -> list[UOp]:
  if reverse: dims = dims[::-1]
  # try to group first: (a, b, c, d) -> (ab, c, d)
  limited = (grouped if (grouped := _group_dims(dims, max_sizes)) else dims) if max_sizes is not None else dims
  # check if grouping failed
  if max_sizes is not None and len(limited) > len(max_sizes): raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
  # try to split up dims: (a,) -> (b, c)
  if limited == dims: limited = _split_dims(dims, max_sizes) if max_sizes is not None else dims
  ret = raw_idxs = [UOp(Ops.SPECIAL, dtypes.index, (sint_to_uop(s),), (f"{prefix}{i}")) for i,s in enumerate(limited)]
  if len(limited) < len(dims):
    ret = []
    if (contraction:=get_contraction(dims, limited)) is None: raise AssertionError(f"get_contraction should not be None {dims=} {limited=}")
    for idx, contraction_group in zip(raw_idxs, contraction):
      for c in contraction_group[:-1]:
        ret.append(idx % dims[c])
        idx //= dims[c]
      ret.append(idx)
  elif len(limited) > len(dims):
    a, b = len(limited), len(dims)
    if a == 2 and b == 1: ret = [raw_idxs[0] * limited[1] + raw_idxs[1]]
    if a == 3 and b == 1: ret = [raw_idxs[0] * (limited[1] * limited[2]) + raw_idxs[1] * limited[2] + raw_idxs[2]]
    if a == 3 and b == 2: ret = [raw_idxs[0] * limited[1] + raw_idxs[1], raw_idxs[2]]
  return ret[::-1] if reverse else ret

def add_gpudims(ctx:Renderer, s:UOp):
  if s.arg is None: return None
  s_topo = list(s.toposort())
  if any(x.op is Ops.SPECIAL for x in s_topo): return None

  # get ranges
  all_ranges = {x.arg[0:-1]:x for x in s_topo if x.op is Ops.RANGE}

  # extract global/local dims
  global_dims = sorted(dedup([x.arg[0:-1] for x in all_ranges.values() if x.arg[-1] in (AxisType.GLOBAL, AxisType.THREAD)]))
  local_dims = sorted(dedup([x.arg[0:-1] for x in all_ranges.values() if x.arg[-1] in (AxisType.WARP, AxisType.LOCAL, AxisType.GROUP_REDUCE)]))
  if not global_dims and not local_dims: return None

  # get global and local shape
  ranges = [all_ranges[r] for r in global_dims+local_dims if r in all_ranges]
  global_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg[0:-1] in global_dims])
  local_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg[0:-1] in local_dims])

  # get the idxs
  ki: KernelInfo = s.arg
  if ki.dont_use_locals:
    assert not local_dims, "can't use locals if there's no local dims"
    idxs = get_grouped_dims("idx", global_shape, ctx.global_max, reverse=True)
  else:
    # define indexes for GPU-like execution
    idxs = get_grouped_dims("gidx", global_shape, ctx.global_max, reverse=True) + get_grouped_dims("lidx", local_shape, ctx.local_max)

  # apply to multiple ranges
  subs = {}
  for r in s_topo:
    if r.op is not Ops.RANGE: continue
    try:
      ii = (global_dims+local_dims).index(r.arg[0:-1])
      if r.arg[1] == AxisType.REDUCE: continue
      subs[r] = idxs[ii]
    except ValueError: continue
  return s.substitute(subs)

pm_add_gpudims = PatternMatcher([
  # add gpudims must be last
  (UPat(Ops.SINK, name="s"), add_gpudims),
])
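_group_dims and _split_dims only do integer arithmetic, so the grouping behaviour is easy to check in isolation. A standalone sketch of the same greedy merge, re-implemented here without the tinygrad imports (illustrative only):

def group_dims(dims:tuple[int, ...], max_sizes:tuple[int, ...]):
  # greedily merge adjacent dims until they fit the launch limits, mirroring _group_dims
  while len(dims) > len(max_sizes) or any(d > m for d,m in zip(dims, max_sizes)):
    for i,m in enumerate(max_sizes):
      if i < len(dims)-1 and dims[i]*dims[i+1] <= m:
        dims = dims[:i] + (dims[i]*dims[i+1],) + dims[i+2:]
        break
    else: return None   # cannot be grouped into the limits
  return dims

# (2,3,4,5) has four dims but only three launch dims are available: 2 and 3 merge into 6
assert group_dims((2,3,4,5), (32,16,16)) == (6,4,5)
# a single dim larger than every limit cannot be fixed by grouping alone; _split_dims handles that case
assert group_dims((64,), (32,32,32)) is None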
tinygrad/codegen/late/devectorizer.py (new file, 306 lines)
@@ -0,0 +1,306 @@
from typing import Any, cast
import functools, operator, itertools
from collections import defaultdict
from dataclasses import dataclass
from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace, Invalid
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite, GroupOp, identity_element
from tinygrad.uop.symbolic import uop_given_valid, parse_valid, sym, symbolic_flat, invalid_gate
from tinygrad.helpers import getenv, flatten, AMX, prod
from tinygrad.renderer import Renderer

# ***** image load valid simplification *****

def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
  if (idx:=uop_given_valid(valid, start_idx)) is None: return buf.index(UOp.invalid())
  if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx, valid)

  # wait for it to be image indexed before running simplification
  if start_idx.dtype.count != 2: return None

  # can drop valid if idx is out of bound when valid is False
  drop_stmt = []
  for stmt in valid.split_uop(Ops.AND):
    try: X, is_upper_bound, c = parse_valid(stmt)
    except ValueError: return None

    # for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i
    if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in X.split_uop(Ops.ADD)):
      testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), X.split_uop(Ops.ADD), idx)
      testidx = testidx.simplify()
      if testidx.gep(0).vmax < 0 or testidx.gep(1).vmax < 0:
        drop_stmt.append(stmt)
        continue

    # if X <= c, check if it's out of bound when X = c+1
    # if X >= c, check if it's out of bound when X = c-1
    test_value = c + 1 if is_upper_bound else c - 1
    for i,b in zip(idx.src, (buf.dtype.shape[1], buf.dtype.shape[0])):
      if i.is_increasing():
        rw = i.substitute({X:X.const_like(test_value)}).simplify()
        if rw.vmin >= b or rw.vmax < 0:
          drop_stmt.append(stmt)
          break

  if not drop_stmt and idx is start_idx: return None
  new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in valid.split_uop(Ops.AND) if s not in drop_stmt]) else None
  return buf.index(idx, new_valid)

def delete_redundant_gates(store:UOp, buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None:
  if store_gate not in [gate.src[0] for gate in val.toposort() if gate.op is Ops.IF]: return None
  # remove the gate from the index
  return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val, *store.src[2:])

load_store_indexing = PatternMatcher([
  # image load valid idx simplification
  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
  # lower turn the invalid into a gate, must come before index dtype lowering
  (UPat(Ops.INDEX, src=(UPat.var("buf"), invalid_gate,),), lambda buf,x,cond,i: buf.index(x, cond)),
  # drop true gate
  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x"), UPat.const(dtypes.bool, True)),), lambda buf,x: buf.index(x)),
  # remove hanging cast
  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.int).cast()),), lambda buf,idx: buf.index(idx)),
  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.int).cast(), UPat.var("valid"))), lambda buf,idx,valid: buf.index(idx, valid)),
  # delete_redundant_gates (after expand)
  (UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")),
                        UPat.var("val")), name="store", allow_any_len=True), delete_redundant_gates),
])

# ***** load/store grouping *****

def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
  if getenv("UNSAFE_DISABLE_MASK", 0): mask = None
  # generate the individual indexes
  midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i), mask.gep(i) if mask is not None else None) for i in range(vec.dtype.count)]),
                       symbolic_flat+load_store_indexing, name=f"index_buf_{buf.arg}")
  # extract all the relevant offsets
  offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict)
  for i in range(vec.dtype.count):
    idx: Any = midx.src[i].src[1]
    if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
    elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
    elif idx.op is Ops.CONST and idx.arg is Invalid: root_src, arg = "INVALID", 0
    elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
    else: root_src, arg = idx, 0
    if len(midx.src[i].src) == 3: root_src = (midx.src[i].src[2], root_src)
    offsets_rootsrc[root_src].setdefault(arg, []).append(i)

  # then rewrite everything we can into groups
  ret = []
  idxs: list[int|None] = [None]*vec.dtype.count
  global_offset = 0
  for offsets in offsets_rootsrc.values():
    grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
    for grp in grouped_offsets:
      # get the index offset for this element. using [0] is okay, because they are the same
      lidx = midx.src[offsets[grp[0]][0]]
      if len(grp) > 1: lidx = lidx.cast(buf.ptrdtype.base.vec(len(grp)).ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace))
      # set the idxs of the output
      for i,g in enumerate(grp):
        for oo in offsets[g]: idxs[oo] = global_offset+i
      # add this lidx to the CAT
      ret.append(lidx)
      global_offset += len(grp)
  assert None not in idxs, f"some idxs are missing {idxs}"
  # this base thing is for image, we want the CAT to be a normal pointer
  post_cat = UOp(Ops.PTRCAT, buf.ptrdtype.base.ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace).vec(vec.dtype.count), tuple(ret))
  return post_cat.gep(tuple(cast(list[int], idxs)))

def cat_after_store(cat:UOp, data:UOp, sto:UOp):
  # TODO: this is written in many places
  offset = 0
  ret: list[UOp] = []
  for s in cat.src:
    ret.append(s.store(data.gep(tuple(range(offset, offset+s.dtype.count))), *sto.src[2:]))
    offset += s.dtype.count
  return UOp(Ops.NOOP, src=tuple(ret))

def gep_on_store(gep:UOp, st:UOp, sto:UOp):
  # NOTE: we need to invert the gep here, but it may be an expanding gep
  # fake argsort. TODO: handle duplicates
  a = {}
  for i,x in enumerate(gep.arg): a[x] = i
  new_arg = tuple(x[1] for x in sorted(a.items()))
  return gep.src[0].store(st.gep(new_arg), *sto.src[2:])

load_store_folding = PatternMatcher([
  (UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat(GroupOp.Defines, name="buf")), UPat.var("vec"))), expand_index),
  (UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat(GroupOp.Defines, name="buf")), UPat.var("vec"),
                        UPat.var("mask"))), expand_index),
  # GEP after LOAD
  (UPat(Ops.LOAD, src=(UPat(Ops.GEP, name="gep"),), name="ld", allow_any_len=True),
   lambda gep, ld: ld.replace(dtype=ld.dtype.scalar().vec(gep.dtype.count), src=(gep.src[0],)+ld.src[1:]).gep(gep.arg)),
  # GEP on data of STORE
  (UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st")), allow_any_len=True, name="sto"), gep_on_store),
  # put PTRCAT after LOAD
  (UPat(Ops.LOAD, src=(UPat(Ops.PTRCAT, name="cat"),), name="ld", allow_any_len=True),
   lambda cat,ld: UOp(Ops.CAT, ld.dtype, tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))),
  # put PTRCAT after STORE
  (UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data")), allow_any_len=True, name="sto"), cat_after_store),
])

# *** correct load/store ***

def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
  # this splits loads and stores into multiple chunks

  # if there's only one element to load/store, no splitting needed
  if (sz:=ls.src[0].dtype.count) == 1: return None
  buf = idx.src[0]

  # determine fold lengths
  lengths = []
  must_divide = True
  if ctx is not None and ctx.device == "DSP":
    lengths = [128,64,32,16,8,4]
    must_divide = False
  elif buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType):
    pass
  elif buf.ptrdtype.addrspace == AddrSpace.REG:
    pass
  elif isinstance(buf.dtype, ImageDType):
    lengths = [4]
  elif ctx is not None and ctx.supports_float4:
    # TODO: a better way to get this than ctx
    lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2])
  lengths.append(1) # worst case, it's not folded

  # filter fold lengths that don't divide
  if must_divide: lengths = [x for x in lengths if idx.src[1].divides(x) is not None]

  # split based on the fold lengths
  global_offset = 0
  ret = []
  while global_offset < sz:
    # with 1 at the end of the lengths list, this will always hit
    for fold_length in lengths:
      if global_offset+fold_length > sz: continue
      lidx = buf.index(idx.src[1] + global_offset, idx.src[2] if len(idx.src) > 2 else None)
      if fold_length > 1: lidx = lidx.cast(buf.ptrdtype.base.vec(fold_length).ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace))
      if ls.op is Ops.STORE: ret.append(ls.replace(src=(lidx,ls.src[1].gep(tuple(range(global_offset, global_offset+fold_length))))+ls.src[2:]))
      else: ret.append(ls.replace(src=(lidx,)+ls.src[1:], dtype=ls.dtype.scalar().vec(fold_length)))
      global_offset += fold_length
      break

  # if it wasn't split, we return None. otherwise we CAT them
  if len(ret) <= 1: return None
  return UOp(Ops.CAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp(Ops.NOOP, src=tuple(ret))

def image_fixup(ls:UOp):
  # normal image load or store, with the CAST from expand_index
  if ls.src[0].op is Ops.CAST and isinstance(image_dtype:=ls.src[0].src[0].dtype, ImageDType):
    assert ls.src[0].dtype.count == 4, "image must be casted to 4"
    idx = ls.src[0].src[0]
    oidx = UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((idx.src[1] // 4) % image_dtype.shape[1], (idx.src[1] // (4*image_dtype.shape[1]))))
    idx = idx.replace(src=(idx.src[0], oidx)+idx.src[2:])
    return ls.replace(src=(idx,)+ls.src[1:])

  # this is an unprocessed image without a cast, aka unfoldable image load. this doesn't work for stores
  if isinstance(image_dtype:=ls.src[0].dtype, ImageDType) and ls.src[0].src[1].dtype != dtypes.int.vec(2):
    assert ls.op is Ops.LOAD, "if an image store isn't upcasted to 4, we can't store it"
    idx = ls.src[0]
    id4 = idx.src[1] % 4
    oidx = UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((idx.src[1] // 4) % image_dtype.shape[1], (idx.src[1] // (4*image_dtype.shape[1]))))
    idx = idx.replace(src=(idx.src[0], oidx)+idx.src[2:])
    vec_load = ls.replace(dtype=ls.dtype.vec(4), src=(idx,)+ls.src[1:])
    return functools.reduce(lambda ret, i: id4.ne(i).where(ret, vec_load.gep(i)), range(4), ls.const_like(float('nan')))

  return None

correct_load_store = PatternMatcher([
  # split LOAD/STORE
  (UPat((Ops.LOAD, Ops.STORE), src=(UPat(Ops.INDEX, name="idx").cast(),), name="ls", allow_any_len=True), split_load_store),
  # image indexing, including unfoldable images
  (UPat((Ops.LOAD, Ops.STORE), name="ls"), image_fixup),
])

# *** uop expander ***

# TODO: there's a lot shared with gep_through_wmma here
def no_vectorized_wmma(wmma:UOp):
  out_sz = prod(x[1] for x in wmma.arg[6][-1])
  if wmma.dtype.count == out_sz: return None
  tsrcs = []
  for s,sz in zip(wmma.src, wmma.arg[6]):
    ssz = prod(x[1] for x in sz)
    tsrcs.append([s.gep(tuple(range(grp, grp+ssz))) for grp in range(0, s.dtype.count, ssz)])
  wmmas = [UOp(Ops.WMMA, wmma.dtype.scalar().vec(out_sz), tsrc, wmma.arg) for tsrc in zip(*tsrcs)]
  wmma_ex = flatten([[e.gep(i) for i in range(out_sz)] for e in wmmas])
  return UOp(Ops.VECTORIZE, wmma.dtype, tuple(wmma_ex))

def no_vectorized_alu(alu:UOp):
  if alu.dtype.vcount == 1: return None
  alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount))
  return UOp(Ops.VECTORIZE, alu.dtype, alus)

def no_vectorized_buf(buf:UOp):
  return buf.replace(dtype=buf.ptrdtype.base.scalar().ptr(buf.ptrdtype.size*buf.ptrdtype.count, buf.ptrdtype.addrspace)).cast(buf.dtype)

def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp):
  cnt = cast.dtype.count
  assert idx.dtype.count == 1, f"idx dtype must be 1 {idx.dtype}"
  return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.int.vec(cnt), tuple(range(cnt))))

devectorize = PatternMatcher([
  # no ALU on vectorized dtypes
  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name="alu"), no_vectorized_alu),
  (UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
  (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf),
  (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
])

pm_render = PatternMatcher([
  # for rendering, we use explicit VECTORIZE
  (UPat(Ops.CONST, name='c'),
   lambda c: UOp(Ops.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
  (UPat(Ops.VCONST, name='c'), lambda c: UOp(Ops.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
  (UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
  (UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None),
  (UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
  # give any loads that are masked an alt value
  (UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat())).or_casted(),), allow_any_len=True, name="x"),
   lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:])
   if len(x.src) == 1 or x.src[1].op in (Ops.CUSTOM, Ops.STORE, Ops.BARRIER) else None),
  # gate any stores that aren't gated with ifs
  (UPat(Ops.STORE, src=(UPat(src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="idx").or_casted(), UPat()), name="store", allow_any_len=True),
   lambda store,idx: UOp(Ops.STORE, dtype=store.dtype, src=store.src[:2]+(UOp(Ops.IF, src=(idx.src[2],)),)+store.src[2:]) if \
     len(store.src) <= 2 or store.src[2].op != Ops.IF else None),
])

# *** Ops.REDUCE -> Ops.DEFINE_ACC ***

@dataclass
class ReduceContext:
  acc_num: int = 0

def horizontal_reduce(inp:UOp, out_dtype:DType) -> list[UOp]:
  # if this has a horizontal reduction component, do that first
  if inp.dtype != out_dtype:
    # NOTE: [0 1 2 3 4 5 6 7] -> [0+4, 1+5, 2+6, 3+7]
    horizontal_amount = inp.dtype.count//out_dtype.count
    return [inp.gep(tuple(range(i, inp.dtype.count, horizontal_amount))) for i in range(0, horizontal_amount)]
  return [inp]

def reduce_to_acc(ctx:ReduceContext, red:UOp):
  inp, reduce_range = red.src[0], red.src[1:]
  lst = horizontal_reduce(inp, red.dtype)
  assert all(x.dtype == red.dtype for x in lst), f"horizontal reduction mismatch {lst[0].dtype} != {red.dtype}"
  # if we have a range
  if len(reduce_range) != 0:
    topo = inp.toposort()
    stored_ranges = flatten([x.src[2:] for x in topo if x.op is Ops.STORE])
    input_ranges = tuple([x for x in topo if x.op is Ops.RANGE and x not in reduce_range and x not in stored_ranges])
    identity = red.const(red.dtype, identity_element(red.arg, red.dtype.scalar()))
    acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=(ctx.acc_num,)).index(UOp.const(dtypes.int, 0))
    do_store = acc.store(identity, UOp(Ops.NOOP, src=input_ranges)) if len(input_ranges) else acc.store(identity)
    lst = [acc.load(do_store, *reduce_range)] + lst # put acc as the first element
    ctx.acc_num += 1
  ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
  return acc.load(acc.store(ret, *reduce_range)) if len(reduce_range) != 0 else ret

pm_reduce = PatternMatcher([
  # REDUCE -> DEFINE_ACC+ASSIGN
  (UPat(Ops.REDUCE, name="red"), reduce_to_acc),
  # tensor core built in accumulate
  (UPat(Ops.WMMA, name="wmma") + UPat.var("add"),
   lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
])+sym
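expand_index groups vector lanes whose indices are consecutive so they can become a single wide access. The consecutive-run detection it uses is the value-minus-position groupby trick; a standalone illustration (not tinygrad code):

import itertools

def group_consecutive(offsets:list[int]) -> list[list[int]]:
  # consecutive integers share the same (value - position) key, so groupby splits the runs
  return [[x for _,x in grp] for _,grp in itertools.groupby(enumerate(sorted(offsets)), lambda p: p[1]-p[0])]

# offsets 0..3 are contiguous and can become one 4-wide access; 8 stays on its own
assert group_consecutive([0, 1, 2, 3, 8]) == [[0, 1, 2, 3], [8]]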
tinygrad/codegen/late/expander.py (new file, 162 lines)
@@ -0,0 +1,162 @@
# this converts a lowerer program into a vectorized program
import functools, itertools, operator
from tinygrad.dtype import dtypes, PtrDType, AddrSpace
from tinygrad.helpers import AMX, dedup, flatten, all_same, prod, partition
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, AxisType, range_start
from tinygrad.schedule.rangeify import BufferizeOpts

def _expand_arg_to_idx(args:tuple[tuple[int, int], ...], rpk:dict[int, int]) -> int:
  idx, mul = 0, 1
  for axis,m in args[::-1]:
    idx += rpk[axis] * mul
    mul *= m
  return idx

def _choices_from_args(args:tuple[tuple[int, int], ...]) -> list[dict[int, int]]:
  return [dict(x) for x in itertools.product(*[zip(itertools.repeat(axis), range(m)) for axis,m in args])]

@functools.cache
def _swizzle_args(cargs:tuple[tuple[int, int], ...], eargs:tuple[tuple[int, int], ...], exclude_args:tuple[int, ...]) -> list[int]:
  return [_expand_arg_to_idx(eargs, {**rpk, **{x:0 for x in exclude_args}} if exclude_args else rpk) for rpk in _choices_from_args(cargs)]

def do_expand(root:UOp):
  expands = [x for x in root.src if x.op is Ops.UNROLL]
  if len(expands) == 0: return None
  # NOTE: we 0 out the reduce axis for WMMA. in theory they should all be the same, but is this always correct?
  exclude_args = tuple(dedup(root.arg[-1] + tuple(y[0] for y in flatten(root.arg[-2])))) if root.op is Ops.WMMA else ()
  if all_same(expands_args:=[x.arg for x in expands]) and len(exclude_args) == 0:
    # if there's only one expand arg, it's okay to use it (optimization)
    expand_args = expands[0].arg
  else:
    # otherwise, we sort them and GEP
    expand_args = tuple(x for x in sorted(dedup(flatten(expands_args))) if x[0] not in exclude_args)
  expand_sz = prod([x[1] for x in expand_args])
  new_srcs = []
  for i,src in enumerate(root.src):
    if src.op is Ops.UNROLL:
      if root.op is Ops.IF and i == 0:
        # IF means OR on first arg to IF
        new_srcs.append(functools.reduce(operator.__or__, [src.src[0].gep(i) for i in range(expand_sz)]))
      elif expand_args == src.arg:
        # just remove the expand
        new_srcs.append(src.src[0])
      else:
        lst = _swizzle_args(expand_args, src.arg, exclude_args)
        # if the base dtype is > 1, put those at the end
        if src.dtype.count > 1: lst = flatten([[i*src.dtype.count+j for j in range(src.dtype.count)] for i in lst])
        new_srcs.append(src.src[0].gep(tuple(lst)))
    else:
      # non-UNROLL input
      if root.op is Ops.IF or src.op is Ops.IF:
        # for the first arg of IF, just pass them through ignoring UNROLLS
        new_srcs.append(src)
      elif root.op in range_start and i >= range_start[root.op]:
        # for any range args of STORE/REDUCE, pass them through
        new_srcs.append(src)
      elif root.op is Ops.INDEX and i >= 1 and not isinstance(root.dtype, PtrDType):
        new_srcs.append(src)
      elif src.dtype.count > 1:
        # put any input dtype > 1 grouped together
        new_srcs.append(UOp(Ops.CAT, src.dtype.scalar().vec(expand_sz*src.dtype.count), (src,)*expand_sz))
      else:
        # repeat the arg
        new_srcs.append(src.broadcast(expand_sz))

  new_arg = root.arg
  if root.op is Ops.GEP:
    assert root.dtype.count == 1
    # is this right?
    new_arg = tuple(range(root.arg[0], new_srcs[0].dtype.count, new_srcs[0].dtype.count // expand_sz))
  nsrc = UOp(root.op, root.dtype.scalar().vec(root.dtype.count*expand_sz), tuple(new_srcs), new_arg)
  return UOp(Ops.UNROLL, root.dtype, (nsrc,), expand_args)

def do_contract(con:UOp):
  ex = con.src[0]
  # CONTRACT without UNROLL repeats the element VECTORIZED
  if ex.op is not Ops.UNROLL: return UOp(Ops.VECTORIZE, con.dtype, con.src*con.dtype.count)
  # CONTRACT may remove several axes from UNROLL
  assert con.dtype == dtypes.void or con.dtype.count == prod([x[1] for x in con.arg]), "dtype is wrong"
  idxs = []
  for rpk in _choices_from_args(new_ex_args:=tuple(x for x in ex.arg if x not in con.arg)):
    idxs += [_expand_arg_to_idx(ex.arg, {**rpk, **lrpk}) for lrpk in _choices_from_args(con.arg)]
  return UOp(Ops.UNROLL, con.dtype, (ex.src[0].gep(tuple(idxs)),), new_ex_args)

expander = PatternMatcher([
  # double expand
  (UPat(Ops.UNROLL, name="outer", src=(UPat(Ops.UNROLL, name="inner"),)),
   lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
  # do expansion
  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.BUFFERIZE,
         Ops.VECTORIZE, Ops.IF, Ops.REDUCE), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
  (UPat(Ops.CONTRACT, name="con"), do_contract),
  # BARRIERs aren't actually expanded
  (UPat(Ops.BARRIER, src=(UPat(Ops.UNROLL, name="ex"),)),
   lambda ex: UOp(Ops.UNROLL, src=(UOp(Ops.BARRIER, src=ex.src),)*len(ex.src), arg=ex.arg)),
  # empty UNROLL is NOOP
  (UPat(Ops.UNROLL, src=(UPat.var('x'),), arg=()), lambda x: x),
  # UNROLL GEP (needed for WMMA, generalize this) -> vectorized ALU
  (UPat(Ops.UNROLL, name="ex", src=tuple(UPat.var('x').gep(i)+UPat.var('y').gep(i) for i in range(256 if AMX else 8))),
   lambda ex,x,y: UOp(Ops.UNROLL, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)),
])

def create_gate(root:UOp) -> UOp|None:
  @functools.cache
  def _gate_srcs(u:UOp, gate:UOp) -> UOp:
    if u.op is Ops.BARRIER: return u
    if u.op is Ops.LOAD and u.src[-1].op is Ops.BARRIER:
      return UOp(u.op, u.dtype, u.src[:-1]+(UOp(Ops.IF, src=(gate, u.src[-1])),), arg=u.arg)
    return u if (replace_source:=tuple(_gate_srcs(x, gate) for x in u.src)) == u.src else UOp(u.op, u.dtype, replace_source, u.arg)
  idx = root.src[0]
  if idx.op is Ops.CAST: idx = idx.src[0]
  return None if idx.op is not Ops.INDEX or len(idx.src) == 2 or (ret:=_gate_srcs(root, idx.src[2])) is root else ret

migrate_indexing = PatternMatcher([
  # create gate MUST BE BEFORE expander
  (UPat(Ops.STORE, name="root"), create_gate),
])

# ****

def fix_reduce_unroll(x:UOp):
  reduce_range, reduce_expand = partition(x.src[1:], lambda y: y.op is Ops.RANGE)
  if len(reduce_expand) == 0: return None
  reduce_expand = [x for x in reduce_expand if x.op is not Ops.CONST]
  assert all(x.op is Ops.UNROLL for x in reduce_expand), f"not all UNROLLS in {reduce_expand}"
  ret = x.src[0]
  if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
    ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis), tag=1)
  # REDUCE supports both "horizontal" reduction and range reduction. the horizontal elements are taken in the nearest group
  return x.replace(src=(ret,)+tuple(reduce_range))

def fix_store_unroll(x:UOp):
  store_expand, store_range = partition(x.src[2:], lambda y: y.op is Ops.UNROLL)
  if len(store_expand) == 0: return None
  return UOp(Ops.CONTRACT, dtypes.void, (x.replace(src=x.src[:2]+tuple(store_range)),), tuple(flatten(x.arg for x in store_expand)), tag=1)

def fix_group_for_reduce(x:UOp):
  reduce_gfr, reduce_r = partition(x.src[1:], lambda u: u.op is Ops.RANGE and u.arg[1] == AxisType.GROUP_REDUCE)
  if len(reduce_gfr) == 0: return None

  # NOTE: if there's other locals here, we need them in the buffer too
  upstream_locals = [u for u in x.toposort() if u.op is Ops.RANGE and u.arg[1] == AxisType.LOCAL]

  # do only the non grouped reduces early
  ret = x.replace(src=(x.src[0],)+tuple(reduce_r))
  reduce_loop = [x.replace(arg=(x.arg[0]+100, AxisType.REDUCE)) for x in reduce_gfr]
  buf = ret.bufferize(*upstream_locals, *reduce_gfr, arg=BufferizeOpts(reduce_gfr[0].arg[0], AddrSpace.LOCAL)).index(*upstream_locals, *reduce_loop)

  # gate with an if on the store + do the final reduce
  buf = UOp(Ops.IF, dtype=buf.dtype, src=(functools.reduce(operator.and_, [x.eq(0) for x in reduce_gfr]), buf))
  return buf.reduce(*reduce_loop, arg=x.arg)

pm_pre_expander = PatternMatcher([
  # rewrite UPCAST/UNROLL range to something to be expanded
  (UPat(Ops.RANGE, name="r"),
   lambda r: UOp(Ops.UNROLL, r.dtype, (UOp.const(r.dtype.vec(s:=r.vmax+1), tuple(range(s))),), ((r.arg[0],s),)) \
     if r.arg[1] in {AxisType.UNROLL, AxisType.UPCAST} else None),
  # fix REDUCEs with UNROLLs
  (UPat(Ops.REDUCE, name="x"), fix_reduce_unroll),
  (UPat(Ops.STORE, name="x"), fix_store_unroll),
  # fix group for reduce
  (UPat(Ops.REDUCE, name="x"), fix_group_for_reduce),
])
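_expand_arg_to_idx and _choices_from_args are pure functions, so the UNROLL index math can be checked on its own. A standalone copy with a worked example (last axis varies fastest, i.e. row-major flattening); this is an illustration, not part of the module above:

import itertools

def expand_arg_to_idx(args, rpk) -> int:
  # last axis varies fastest, exactly like the helper above
  idx, mul = 0, 1
  for axis,m in args[::-1]:
    idx += rpk[axis] * mul
    mul *= m
  return idx

def choices_from_args(args):
  return [dict(x) for x in itertools.product(*[zip(itertools.repeat(axis), range(m)) for axis,m in args])]

args = ((1, 2), (2, 3))                        # axis 1 has 2 values, axis 2 has 3 values
assert len(choices_from_args(args)) == 6       # full cartesian product of the two axes
assert expand_arg_to_idx(args, {1: 0, 2: 0}) == 0
assert expand_arg_to_idx(args, {1: 1, 2: 2}) == 5   # 1*3 + 2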
tinygrad/codegen/late/linearize.py (new file, 243 lines)
@@ -0,0 +1,243 @@
from __future__ import annotations
import heapq
from collections import defaultdict
from dataclasses import dataclass, replace
from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat, GroupOp, BottomUpGate
from tinygrad.helpers import dedup, all_same, flatten, BLOCK_REORDER

# NOTE: any toposort should be valid here, unlike last time this isn't required, it's just for speed
def block_reorder(lst:list[UOp]) -> list[UOp]:
  in_this_block = set(lst)
  local_children: defaultdict[UOp, list[UOp]] = defaultdict(list)
  in_degree:dict[UOp, int] = {}
  priorities:dict[UOp, int] = {}

  # get local children and assign priorities
  # NOTE: this requires the lst be locally toposorted
  for u in reversed(lst):
    in_degree[u] = 0
    for s in u.src:
      if s in in_this_block:
        local_children[s].append(u)
        in_degree[u] += 1
    # put loads in the beginning of the block and prevent priority inversion. hack for BARRIER grouping too
    priority = [0] + [priorities[x] for x in local_children[u]]
    if u.op is Ops.LOAD: priority.append(-1000)
    if u.op is Ops.BARRIER: priority.append(-1500)
    priorities[u] = min(priority)

  # number the uops in "ideal" order
  nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: (priorities[x],)+x.tuplize))}

  # then force them to be toposorted in as close to the ideal order as possible
  heapq.heapify(heap:=[(nkey[u],u) for u in lst if in_degree[u] == 0])
  newlst = []
  while heap:
    newlst.append(u:=heapq.heappop(heap)[1])
    for v in local_children[u]:
      in_degree[v] -= 1
      if in_degree[v] == 0: heapq.heappush(heap, (nkey[v],v))

  assert len(newlst) == len(lst), f"len mismatch {len(newlst)} != {len(lst)}"
  return newlst

# ***** basic block *****

def disp(y:UOp) -> str:
  if y.op is Ops.IF: return f'IF{id(y)}'
  if y.op is Ops.RANGE: return str(y.arg)
  return "<NONE>"

@dataclass(frozen=True, eq=False)
class BasicBlock:
  lst: tuple[UOp, ...]
  ctx: tuple[UOp, ...] = ()
  end: UOp|None = None
  cnt: int = 0
  child_ctx: tuple[UOp, ...]|None = None
  def __lt__(self, _:BasicBlock): raise RuntimeError("no comparing basic blocks")
  def __repr__(self):
    return f"{(str(disp(self.end))+' ') if self.end is not None else ''}"+f'f{self.cnt} '+\
           f"{[disp(y) for y in self.ctx]} {[disp(y) for y in self.child_ctx] if self.child_ctx is not None else '-'} "+\
           f"{len(self.lst)}" + "\n" + '\n'.join([str(x.op) for x in self.lst])
  def last_ctx(self): return self.child_ctx if self.child_ctx is not None else self.ctx

def _sort_ctx(inp): return tuple(sorted(dedup(inp), key=lambda x: x.tuplize))

# ***** block context *****

@dataclass
class BlockContext:
  child_count: dict[UOp, int]
  block_ctxs: dict[UOp, tuple[UOp, ...]]
  child_ctxs: dict[UOp, tuple[UOp, ...]]
  def last_ctx(self, u): return self.child_ctxs.get(u, self.block_ctxs[u])
  @staticmethod
  def from_sink(sink:UOp) -> BlockContext:
    # get children and all block contexts
    ctx = BlockContext({}, {}, {})
    for u in sink.toposort(gate=lambda u:u.op is not Ops.SPECIAL):
      this_block_ctx: list[UOp] = []
      ctx.child_count[u] = 0

      # get children and accumulate the last_ctx
      for s in u.src:
        if s.op is Ops.SPECIAL: continue
        # NOTE: if a parent appears multiple times in the src, it counts multiple times as a child
        ctx.child_count[s] += 1
        this_block_ctx += ctx.last_ctx(s)

      # save the block ctx. SINK never has anything
      ctx.block_ctxs[u] = _sort_ctx(this_block_ctx) if u.op is not Ops.SINK else ()

      # RANGE/IF add to the next ctx
      # STORE/ASSIGN subtract from the next ctx
      if u.op in {Ops.RANGE, Ops.IF}: ctx.child_ctxs[u] = _sort_ctx(ctx.block_ctxs[u] + (u,))
      elif u.op is Ops.STORE: ctx.child_ctxs[u] = tuple([y for y in ctx.block_ctxs[u] if y not in u.src])
    return ctx

# ***** make blocks *****

DONT_PLACE_IN_BLOCK = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST}

def add_blockends(base_block:UOp, new_ctx:tuple[UOp, ...], current_ctx:tuple[UOp, ...], cnt:int=1) -> UOp:
  ends_to_add = [z for z in new_ctx if z not in current_ctx]
  while len(ends_to_add):
    r:UOp = ends_to_add.pop(-1)
    new_ctx = tuple([z for z in new_ctx if z is not r])
    end_uop = UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,))
    base_block = UOp(Ops.BLOCKEND, src=(base_block,)*cnt, arg=BasicBlock((end_uop,), tuple(new_ctx), end=r, cnt=cnt))
  return base_block

def make_block_bottom_up(ctx:BlockContext, x:UOp):
  if x.op is Ops.BLOCKSTART:
    current_ctx, child_ctx = x.arg
    lst = list(x.src)
    child_count = 1
  else:
    current_ctx, child_count, child_ctx = ctx.block_ctxs[x], ctx.child_count[x], ctx.child_ctxs.get(x, None)
    lst = [x]

  # count of times we've seen this block, or a seed for a new block if we can't merge it
  unmergable: defaultdict[UOp, int] = defaultdict(int)
  blockseeds = defaultdict(list)

  # add the srcs of this to the frontier
  # NOTE: things may be in here multiple times, that's okay
  frontier_nodes = list(flatten(y.src[::-1] for y in lst))
  while len(frontier_nodes):
    u = frontier_nodes.pop(0)
    if u.op not in DONT_PLACE_IN_BLOCK and ctx.child_count[u] == unmergable[u]+1:
      # count is correct
      if (newctx:=ctx.block_ctxs[u]) == current_ctx:
        # block has same context, merge it, and put the srcs on the frontier
        lst.append(u)
        frontier_nodes.extend(u.src[::-1])
      else:
        # block has different context, add it to blockseeds
        blockseeds[(newctx, ctx.child_ctxs.get(u, None))].append(u)
      del unmergable[u]
    else:
      # count is incorrect (or it's DONT_PLACE_IN_BLOCK), add it to unmergable
      unmergable[u] += 1

  # add unmergables to sources
  srcs = []
  for u,cnt in unmergable.items(): srcs += [add_blockends(u, ctx.block_ctxs.get(u,()), current_ctx, cnt=cnt)]*cnt

  # add blockseeds, with blockends as needed
  for (new_ctx, new_child_ctx), v in blockseeds.items():
    base_block = UOp(Ops.BLOCKSTART, src=tuple(v), arg=(new_ctx, new_child_ctx))
    srcs.append(add_blockends(base_block, new_ctx, current_ctx))

  lst = lst[::-1]
  if BLOCK_REORDER: lst = block_reorder(lst)
  bb = BasicBlock(tuple(lst), ctx=current_ctx, cnt=child_count, child_ctx=child_ctx)
  return UOp(Ops.BLOCK, src=tuple(srcs), arg=bb)

# we prevent the source of the SPECIAL from being linearized since it's not part of the kernel
def raise_bottom_up_gate(): raise BottomUpGate()

block_create = PatternMatcher([
  (UPat(GroupOp.All-DONT_PLACE_IN_BLOCK.union({Ops.BLOCK, Ops.BLOCKEND}), name="x"), make_block_bottom_up),
  (UPat(Ops.SPECIAL), raise_bottom_up_gate)
])

# ***** blockend merging ****

def merge_blockends(sink:UOp) -> UOp|None:
  # only run on the final BLOCK with the SINK in it
  if sink.arg.lst[-1].op is not Ops.SINK: return None
  # combine matching BLOCKENDS, the keys of this dictionary are the RANGE UOps, values are the BLOCKENDs
  blockends_to_arg: dict[UOp, list[UOp]] = {}
  for be in sink.toposort():
    if be.op is Ops.BLOCKEND: blockends_to_arg.setdefault(be.arg.end, []).append(be)
  new_forks = {}
  for k,v in blockends_to_arg.items():
    # NOTE: if any BLOCKEND is the parent of any other with the same arg, this algo fails
    if len(v) > 1:
      bb = BasicBlock(v[0].arg.lst, _sort_ctx(flatten([y.arg.ctx for y in v])), k, cnt=sum(y.arg.cnt for y in v))
      out = UOp(Ops.BLOCKEND, src=tuple(flatten([x.src for x in v])), arg=bb)
      # NOTE: bb.ctx != u.arg.ctx can cause problems here
      for u in v: new_forks[u] = out
  if len(new_forks) == 0: return None
  return sink.substitute(new_forks)

pm_blockend_merge = PatternMatcher([(UPat(Ops.BLOCK, name="sink"), merge_blockends)])

# ***** block merging ****

def merge_block(x:UOp):
  unmergable_blocks, mergable_blocks = [], []
  mergable_dict: defaultdict[UOp, int] = defaultdict(int)
  for y in x.src:
    if y.op is Ops.BLOCK and x.op is Ops.BLOCK and x.arg.ctx == y.arg.ctx: mergable_dict[y] += 1
    elif y.op is Ops.BLOCK and x.op is Ops.BLOCKEND and x.arg.end in y.arg.ctx: mergable_dict[y] += 1
    else: unmergable_blocks.append(y)
  for k,v in mergable_dict.items():
    if v == k.arg.cnt: mergable_blocks.append(k)
    else: unmergable_blocks.extend([k]*v)
  if len(mergable_blocks) == 0: return None
  del mergable_dict

  # create the block
  arg = replace(x.arg, lst=tuple(flatten([y.arg.lst for y in mergable_blocks]))+x.arg.lst)
  return UOp(x.op, src=tuple(flatten([y.src for y in mergable_blocks])+unmergable_blocks), arg=arg)

def remove_blockend(x:UOp):
  # if there's any remaining blocks that need to go in this BLOCKEND, we don't remove it
  if any(x.arg.end in y.arg.ctx for y in x.src if y.op in {Ops.BLOCK, Ops.BLOCKEND}): return None

  if (parent_blocks := [y for y in x.src if y.op is Ops.BLOCK and y.arg.child_ctx is not None and x.arg.end in y.arg.child_ctx]):
    assert all_same(parent_blocks), f"should never have two parent blocks (has {len(parent_blocks)})"
    parent_block = parent_blocks[0]
    assert len(parent_blocks) == parent_block.arg.cnt
    # NOTE: DEFINE_ACC doesn't have to be handled in any special way
    late_ops = list(x.arg.lst)
    # NOTE: we have to add a barrier at the start if barrier is used in the range
    if x.op is Ops.BLOCKEND and any(y.op is Ops.BARRIER for y in late_ops) and late_ops[-1].op is Ops.ENDRANGE:
      late_ops = [UOp(Ops.BARRIER)] + late_ops
    # peephole opt, remove any BARRIERs next to each other
    for i in range(len(late_ops)-1):
      if late_ops[i].op is Ops.BARRIER and late_ops[i+1].op is Ops.BARRIER: late_ops[i+1] = UOp(Ops.NOOP)
    arg = BasicBlock(parent_block.arg.lst+tuple(late_ops), tuple([y for y in x.arg.ctx if y is not x.arg.end]), cnt=x.arg.cnt)
    return UOp(Ops.BLOCK, src=tuple(y for y in x.src if y is not parent_block)+parent_block.src, arg=arg)
  # else the whole context ended by the blockend is already in this block and we can safely turn it into a block
  return UOp(Ops.BLOCK, src=x.src, arg=BasicBlock(x.arg.lst, tuple([y for y in x.arg.ctx if y is not x.arg.end]), cnt=x.arg.cnt))

block_merge = PatternMatcher([
  (UPat((Ops.BLOCK, Ops.BLOCKEND), name="x"), merge_block),
  (UPat(Ops.BLOCKEND, name="x"), remove_blockend),
])

# ****** finalize ******

def finalize(sink:UOp) -> UOp:
  if sink.op is not Ops.BLOCK or not all(x.op in DONT_PLACE_IN_BLOCK for x in sink.src):
    raise RuntimeError(f"linearize failure {sink.op} {[x.op for x in sink.src if x.op not in DONT_PLACE_IN_BLOCK]}")

  # place the early things
  lst = sorted(dedup(sink.src), key=lambda x: x.tuplize) + list(sink.arg.lst)
  return UOp(Ops.BLOCKFINAL, arg=BasicBlock(tuple(lst)))

pm_finalize = PatternMatcher([(UPat(Ops.BLOCK, name="sink"), finalize)])
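block_reorder is a Kahn topological sort driven by a heap of priorities, so among all ready uops the one with the lowest priority key (loads, barriers) is emitted first. A self-contained sketch of that idea on a toy dependency graph (the tuplize tiebreak is omitted; names are illustrative):

import heapq
from collections import defaultdict

def reorder(nodes, deps, priority):
  # Kahn's algorithm driven by a heap: among the ready nodes, always emit the one
  # with the lowest priority key first
  children, in_degree = defaultdict(list), {n: 0 for n in nodes}
  for n in nodes:
    for d in deps.get(n, ()):
      children[d].append(n)
      in_degree[n] += 1
  heapq.heapify(heap:=[(priority[n], n) for n in nodes if in_degree[n] == 0])
  out = []
  while heap:
    out.append(n:=heapq.heappop(heap)[1])
    for c in children[n]:
      in_degree[c] -= 1
      if in_degree[c] == 0: heapq.heappush(heap, (priority[c], c))
  return out

# loads get a large negative priority so they float to the top of the block when legal
nodes = ["load_a", "load_b", "alu", "store"]
deps = {"alu": ["load_a", "load_b"], "store": ["alu"]}
prio = {"load_a": -1000, "load_b": -1000, "alu": 0, "store": 0}
assert reorder(nodes, deps, prio) == ["load_a", "load_b", "alu", "store"]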
tinygrad/codegen/lowerer.py (new file, 86 lines)
@@ -0,0 +1,86 @@
# the job of the lowerer is to do indexing
from dataclasses import dataclass
from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint_to_uop, AxisType, graph_rewrite, resolve

# ***** indexing *****

@dataclass
class IndexContext:
  axis_types: tuple[AxisType, ...]
  idxs: list[UOp]
  start: int = 0

def shape_to_idx(s, axis_types, start=0):
  return [UOp.range(sint_to_uop(s), start+i, at) for i, (s, at) in enumerate(zip(s, axis_types))]

def get_index(ast:UOp) -> IndexContext:
  axis_types = ast.arg.axis_types if isinstance(ast.arg, KernelInfo) else ()
  if len(ast.full_shape) != len(axis_types) and ast.st is not None:
    axis_types = tuple([AxisType.REDUCE if resolve(s != fs) else AxisType.LOOP for s,fs in zip(ast.shape, ast.full_shape)])
  return IndexContext(axis_types, [], 0)

# ***** lowering (given index) *****

def subblock(ctx: IndexContext, full_new_idx: list[UOp], src: UOp):
  lc = IndexContext(ctx.axis_types, full_new_idx, ctx.start+1000)
  ctx.start = lc.start
  return graph_rewrite(src, pm_lowerer, lc, name="subblock", bottom_up=True)

def lower_reduce_axis(ctx: IndexContext, x: UOp):
  new_idxs = shape_to_idx(x.src[0].shape, ctx.axis_types, ctx.start)
  full_new_idx = list(ctx.idxs)
  for a in x.axis_arg: full_new_idx[a] = new_idxs[a]
  ret = subblock(ctx, full_new_idx, x.src[0])
  return UOp(Ops.REDUCE, x.dtype, (ret,)+tuple([full_new_idx[i] for i in x.axis_arg]), x.arg[0])

def lower_store(ctx: IndexContext, x: UOp, buf: UOp):
  # TODO: reenable after REDUCE_AXIS is fixed
  #assert x.src[1].shape == x.src[0].shape, f"shape mismatch on store {x.src[1].shape} != {x.src[0].shape}"

  new_idxs = shape_to_idx(x.src[0].shape, ctx.axis_types, ctx.start)
  idx = x.st_arg.to_valid_uop(new_idxs)
  used_idxs = [x for x in idx.toposort() if x in new_idxs]
  real_new_idxs = []
  for i in range(len(x.src[0].shape)):
    if new_idxs[i] in used_idxs or len(ctx.idxs) <= i: real_new_idxs.append(new_idxs[i])
    else: real_new_idxs.append(ctx.idxs[i])

  stored = subblock(ctx, real_new_idxs, x.src[1])
  used_ranges = [x for x in used_idxs if x.op is Ops.RANGE]
  return buf.index(idx).store(stored, *used_ranges)

def fixup_wmma(ctx:IndexContext, x:UOp):
  if x.tag is not None: return None
  new_idxs = shape_to_idx(x.src[0].shape, ctx.axis_types, ctx.start)
  full_new_idx = list(ctx.idxs)
  for a in x.arg[-1]: full_new_idx[a] = new_idxs[a]

  srcs = subblock(ctx, full_new_idx, UOp.sink(*x.src)).src

  # NOTE: this assumes these are expanded. which now shouldn't change anything
  new_x_arg_m2 = tuple([tuple([(full_new_idx[a].arg[0], sz) for a,sz in v]) for v in x.arg[-2]])
  new_x_arg_m1 = tuple([full_new_idx[a].arg[0] for a in x.arg[-1]])
  return x.replace(src=srcs, arg=x.arg[:-2]+(new_x_arg_m2, new_x_arg_m1), tag=1)

pm_lowerer = PatternMatcher([
  # TODO: remove these hacks
  # hack for old style CONST(VIEW) (now it's just VIEW(CONST))
  (UPat((Ops.DEFINE_VAR, Ops.CONST), src=(UPat(Ops.VIEW, name="v"),), name="c"), lambda c,v: c.replace(src=()).view(v.arg)),
  # hack for old style VALID (now it's just VIEW(CONST))
  (UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c"), UPat(Ops.CONST, arg=0)), lambda c,v: c.replace(src=()).view(v.arg)),

  # consts and loads
  (UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"),), name="view"),
   lambda ctx,view,c: c if all(x.mask is None for x in view.arg.views) else view.arg.to_valid_uop(ctx.idxs).get_valid().where(c, c.const_like(0))),
  (UPat(Ops.LOAD, src=(UPat.var("buf").view(),), allow_any_len=True, name="x"),
   lambda ctx,buf,x: UOp(Ops.LOAD, x.dtype, (buf.index(x.st_arg.to_valid_uop(ctx.idxs)),)+x.src[1:])),

  # reduce/view_const
  (UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
  (UPat(Ops.STORE, src=(UPat.var("buf").view(),), allow_any_len=True, name="x"), lower_store),
  (UPat(Ops.WMMA, name="x"), fixup_wmma),

  # axis fixups for WMMA
  (UPat((Ops.CONTRACT, Ops.UNROLL), name="x"),
   lambda ctx,x: x.replace(tag=1, arg=tuple([(ctx.idxs[a].arg[0], sz) for a,sz in x.arg])) if x.tag is None else None),
])
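get_index falls back to classifying each axis by comparing the output shape with the full shape. A simplified standalone illustration of that rule (symbolic shapes and resolve are ignored here; the Axis enum is a stand-in for AxisType):

from enum import Enum, auto

class Axis(Enum): LOOP = auto(); REDUCE = auto()   # stand-ins for AxisType.LOOP / AxisType.REDUCE

def classify(shape, full_shape):
  # an axis whose output extent differs from the full (pre-reduce) extent is a reduce axis,
  # everything else is a plain loop
  return tuple(Axis.REDUCE if s != fs else Axis.LOOP for s,fs in zip(shape, full_shape))

# a (4,1) output produced from a (4,8) full shape reduces over the second axis
assert classify((4, 1), (4, 8)) == (Axis.LOOP, Axis.REDUCE)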
tinygrad/codegen/opt/__init__.py (new file, 26 lines)
@@ -0,0 +1,26 @@
# opt opinionatedly transforms an ast into an optimized ast using either heuristics or beam search
from __future__ import annotations
from enum import Enum, auto
from dataclasses import dataclass
from tinygrad.uop.ops import AxisType

class OptOps(Enum):
  TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto(); THREAD = auto() # noqa: E702
  GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
  def __lt__(self, x:OptOps): return self.value < x.value

@dataclass(frozen=True, order=True)
class Opt:
  op: OptOps
  axis: int|None = None
  arg: int|tuple|None = None
  def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, arg={self.arg})"

axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
                AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"}
axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
               AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "green", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"}

class KernelOptError(Exception): pass
def check(cond:bool, msg:str=""):
  if not cond: raise KernelOptError(msg)
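As a quick, hypothetical illustration of the API above (not part of this commit): an optimization schedule is just a list of Opt records, and check() turns a failed precondition into a KernelOptError.

# hypothetical usage of Opt/OptOps/check
opts = [Opt(OptOps.UPCAST, axis=0, arg=4), Opt(OptOps.UNROLL, axis=0, arg=0)]  # arg=0 means "use the full axis size"
for o in opts: check(o.axis is not None and o.axis >= 0, f"bad axis in {o}")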
188
tinygrad/codegen/opt/heuristic.py
Normal file
@@ -0,0 +1,188 @@
|
||||
import itertools
|
||||
from tinygrad.codegen.opt import Opt, OptOps, KernelOptError
|
||||
from tinygrad.helpers import getenv, DEBUG, prod, NOLOCALS, TC_OPT, TC_SELECT, USE_TC, AMX
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.uop.ops import Ops, resolve, AxisType
|
||||
from tinygrad.codegen.opt.postrange import Scheduler
|
||||
|
||||
def hand_coded_optimizations(k:Scheduler) -> Scheduler:
|
||||
# first try the tensor cores
|
||||
""" Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false.
|
||||
Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N).
|
||||
|
||||
Keyword arguments:
|
||||
use_tensor_cores -- controls how tensor cores are applied (default 1)
|
||||
0: will disable any tensor core matching
|
||||
1: enable tensor cores
|
||||
2: apply tensor core shape but don't use UOp.WMMA
|
||||
extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None)
|
||||
tc_select -- specifies which tensor core(s) to use for optimization (default -1)
|
||||
-1: iterates through all available tensor cores in order and uses the first one that matches the requirements (dims and dtypes)
|
||||
[0-N]: uses only the n'th tensor core available; useful for search
|
||||
tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise)
|
||||
0: applies to only kernels with a single reduce axis and direct Ops.LOAD into Ops.MUL
|
||||
1: allows kernels with multiple reduce axes and also multiplication of Ops.CAST'd buffers
|
||||
2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by padding those axes as needed
|
||||
"""
|
||||
# NOTE: unless TC_OPT is > 0, we only trigger tensor cores if there's only one reduce axis
|
||||
if USE_TC > 0 and (len(k.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE)) == 1 or (TC_OPT.value >= 1)):
|
||||
good_tc_opt = False
|
||||
try: # check TC first and apply hand-coded opts if successful
|
||||
tk = k.copy()
|
||||
rngs = tk.apply_opt(Opt(OptOps.TC, 0, (TC_SELECT.value, TC_OPT.value, USE_TC.value)))
|
||||
good_tc_opt = True
|
||||
except KernelOptError:
|
||||
pass
|
||||
if good_tc_opt:
|
||||
# skip hand-coded TC opts if AMX; upcasting will make the kernel slower
|
||||
if rngs is not None and not AMX:
|
||||
for tc_dim in [1,0]: # attempt to upcast M and N
|
||||
szs = [sz for sz in [5,4,3,2] if rngs[tc_dim].src[0].divides(sz) is not None]
|
||||
if szs:
|
||||
# set it to the replaced range
|
||||
rngs[tc_dim] = tk.apply_opt(Opt(OptOps.UPCAST, tk.rngs.index(rngs[tc_dim]), szs[0]))[0]
|
||||
if (szs := [sz for sz in [4,2] if rngs[0].src[0].divides(sz) is not None]): # attempt to local N
|
||||
tk.apply_opt(Opt(OptOps.LOCAL, tk.rngs.index(rngs[0]), szs[0]))
|
||||
return tk
|
||||
|
||||
# make a copy so it does not mutate the input
|
||||
k = k.copy()
|
||||
|
||||
# should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
|
||||
MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
|
||||
if k.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
|
||||
k.reduceop is not None and k.reduceop.arg[0] is Ops.ADD and len(k.full_shape) >= 2 and k.opts.has_shared and \
|
||||
(mulop:=k.reduceop.src[0]).op is Ops.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD:
|
||||
idx0, idx1 = mulop.src[0].src[0].src[1].get_idx(), mulop.src[1].src[0].src[1].get_idx()
|
||||
first_reduce_rng = k.ranges_of(AxisType.REDUCE)[0]
|
||||
if any(u is first_reduce_rng for u in idx0.split_uop(Ops.ADD)) and all(r in idx1.ranges for r in idx0.ranges):
|
||||
for global_idx in k.axes_of(AxisType.GLOBAL):
|
||||
if first_reduce_rng.src[0].divides(MV_THREADS_PER_ROW) is not None and k.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
|
||||
if DEBUG >= 3:
|
||||
print(f"MATVEC: {k.full_shape=} {first_reduce_rng.render()} {MV_BLOCKSIZE=} {MV_THREADS_PER_ROW=} {MV_ROWS_PER_THREAD=}")
|
||||
if MV_THREADS_PER_ROW > 1: k.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
|
||||
if MV_BLOCKSIZE > 1: k.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
|
||||
if MV_ROWS_PER_THREAD > 1: k.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
|
||||
return k
|
||||
|
||||
# are we grouping? (requires local shape support)
|
||||
if resolve(prod(k.output_shape[i] for i in k.upcastable_dims) <= 2048, False):
|
||||
for sz in [16]:
|
||||
try:
|
||||
k.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
|
||||
break
|
||||
except KernelOptError: pass
|
||||
|
||||
# upcast float4 images
|
||||
for buf_index,buf in enumerate(k.bufs):
|
||||
if isinstance(buf.src[0].dtype, ImageDType):
|
||||
# part of real_strides
|
||||
unit_stride_axes_mul_4 = [k.rngs.index(c) for c in k.bufs[buf_index].src[1].get_idx().split_uop(Ops.ADD) if
|
||||
c.op is Ops.RANGE and (c.vmax+1)%4 == 0]
|
||||
if len(unit_stride_axes_mul_4):
|
||||
if (axis:=unit_stride_axes_mul_4[0]) in k.upcastable_dims:
|
||||
k.apply_opt(Opt(OptOps.UPCAST, axis, 4))
|
||||
elif axis in k.unrollable_dims:
|
||||
k.apply_opt(Opt(OptOps.UNROLL, k.unrollable_dims.index(axis), 4))
|
||||
|
||||
# no more opts if we are grouping
|
||||
if k.group_for_reduces: return k
|
||||
|
||||
# **** everything below this line needs to be optional and benchmarked ****
|
||||
|
||||
# if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
|
||||
to_upcast: list[int] = []
|
||||
# upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
|
||||
for axis in k.upcastable_dims:
|
||||
# for Schedule, we check if the range is used in INDEX gates or WHERE gates
|
||||
is_masked = any(any(o is k.rngs[axis] for o in u.src[0].parents) for u in k.ast.parents if u.op is Ops.WHERE)
|
||||
if k.full_shape[axis] <= 7 and is_masked and prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7:
|
||||
if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
|
||||
to_upcast.append(axis)
|
||||
for axis in to_upcast[::-1]: k.apply_opt(Opt(OptOps.UPCAST, axis, 0))
|
||||
|
||||
# potentially do more upcasts of non reduce axes based on a heuristic
|
||||
is_dsp = k.opts is not None and k.opts.device == "DSP"
|
||||
upcasted_axis: set[int] = set()
|
||||
while resolve(prod(k.output_shape[i] for i in k.upcastable_dims) >= 1024):
|
||||
xb_choices = []
|
||||
# consider all upcastable axes with 3 or 4 upcast (128 on the DSP)
|
||||
for axis, upcast_amount in itertools.product(k.upcastable_dims, ([128] if not len(upcasted_axis) else []) if is_dsp else [3,4]):
|
||||
# if we haven't upcasted it, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
|
||||
if axis in upcasted_axis or k.full_shape[axis]%upcast_amount != 0: continue
|
||||
rng = k.rngs[axis]
|
||||
if any(rng not in b.src[1].get_idx().parents and all(r2 in b.src[1].get_idx().parents
|
||||
for r2 in k.ranges_of(AxisType.UPCAST, AxisType.UNROLL)) for b in k.bufs):
|
||||
num_strides, sum_strides = 0, 0
|
||||
for b in k.bufs:
|
||||
idx = b.src[1].get_idx()
|
||||
if rng in idx.parents: num_strides += 1
|
||||
for c in idx.split_uop(Ops.ADD):
|
||||
if c is rng: sum_strides += 1
|
||||
if c.op is Ops.MUL and c.src[0] is rng and c.src[1].op is Ops.CONST: sum_strides += c.src[1].arg
|
||||
if c.op is Ops.MUL and c.src[1] is rng and c.src[0].op is Ops.CONST: sum_strides += c.src[0].arg
|
||||
xb_choices.append((num_strides, sum_strides, axis, upcast_amount))
|
||||
if xb_choices:
|
||||
xb_choices = sorted(xb_choices)
|
||||
if DEBUG >= 4: print(f"more upcast axis : {xb_choices}")
|
||||
k.apply_opt(Opt(OptOps.UPCAST, xb_choices[0][2], xb_choices[0][3]))
|
||||
upcasted_axis.add(xb_choices[0][2])
|
||||
else: break
|
||||
|
||||
# if last reduce dim is small(ish), loop unroll the reduce
|
||||
# NOTE: this can fail on multireduce with mismatching dimensions; that is okay
|
||||
try:
|
||||
upcast_size = prod(k.full_shape[a] for a in k.axes_of(AxisType.UPCAST, AxisType.UNROLL))
|
||||
if k.unrollable_dims and (upcast_size <= 4 or not k.axes_of(AxisType.UNROLL)) and (upcast_size < 64):
|
||||
if (s:=k.full_shape[k.unrollable_dims[-1]]) <= 32:
|
||||
k.apply_opt(Opt(OptOps.UNROLL, len(k.unrollable_dims)-1, 0))
|
||||
# if it's small, upcast a second reduce dimension too
|
||||
if k.unrollable_dims and s <= 3 and k.full_shape[k.unrollable_dims[-1]] <= 3:
|
||||
k.apply_opt(Opt(OptOps.UNROLL, len(k.unrollable_dims)-1, 0))
|
||||
else:
|
||||
for splits in [4]:
|
||||
if k.full_shape[axis:=k.unrollable_dims[-1]]%splits == 0:
|
||||
k.apply_opt(Opt(OptOps.UNROLL, len(k.unrollable_dims)-1, splits))
|
||||
break
|
||||
except KernelOptError: pass
|
||||
|
||||
# if nothing at all is upcasted and it's easy to, do an upcast
|
||||
for splits in [4]:
|
||||
# TODO: somehow this never hits a reduce
|
||||
if not k.upcasted and k.upcastable_dims and k.full_shape[k.upcastable_dims[-1]] % splits == 0:
|
||||
k.apply_opt(Opt(OptOps.UPCAST, k.upcastable_dims[-1], splits))
|
||||
|
||||
# **** local groups ****
|
||||
|
||||
if k.opts.has_local:
|
||||
if NOLOCALS:
|
||||
k.apply_opt(Opt(OptOps.NOLOCALS))
|
||||
else:
|
||||
# prioritize making expand axes local
|
||||
local_axis_ranking = [(any(k.rngs[axis] not in b.src[1].get_idx().parents for b in k.bufs), axis) \
|
||||
for axis in k.axes_of(AxisType.GLOBAL, AxisType.LOOP) if k.rngs[axis].src[0].op is Ops.CONST]
|
||||
to_local: list[tuple[int, int]] = []
|
||||
for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
|
||||
local_size = prod(sz for _, sz in to_local)
|
||||
local_sz: int|None = next((x for x in ([32] * (axis == 0) + [16,8,4,3,2]) if k.full_shape[axis] % x == 0 and local_size * x <= 128), None)
|
||||
if local_sz is not None: to_local.append((axis, local_sz))
|
||||
deleted_shape = 0
|
||||
for axis, local_sz in sorted(to_local[:3]):
|
||||
axis = axis - deleted_shape
|
||||
will_delete_shape = local_sz == k.full_shape[axis]
|
||||
k.apply_opt(Opt(OptOps.LOCAL, axis, local_sz))
|
||||
if will_delete_shape: deleted_shape += 1
|
||||
|
||||
# **** threading ****
|
||||
|
||||
if k.opts.has_threads and k.opts.global_max is not None:
|
||||
for threads in [32,16,12,8,6,5,4,3,2]:
|
||||
# Skip if there are too many threads. Heuristic: use about 128K ops per thread
|
||||
if threads > k.opts.global_max[0] or resolve(prod(k.full_shape) // (128 << 10) < threads): continue
|
||||
for axis in k.axes_of(AxisType.LOOP):
|
||||
if k.full_shape[axis] % threads == 0:
|
||||
k.apply_opt(Opt(OptOps.THREAD, axis, threads))
|
||||
break
|
||||
if k.applied_opts and k.applied_opts[-1].op is OptOps.THREAD: break
|
||||
|
||||
return k
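A minimal sketch of how this heuristic pass is driven (hypothetical standalone usage; in this commit it is normally invoked from apply_opts in postrange.py), assuming a kernel ast and a Renderer opts already exist:

# hypothetical driver, assuming `ast` (a SINK UOp) and `opts` (a Renderer) are in scope
k = Scheduler(ast, opts)
k.convert_loop_to_global()
k = hand_coded_optimizations(k)
print(k.applied_opts, k.colored_shape())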
334
tinygrad/codegen/opt/postrange.py
Normal file
@@ -0,0 +1,334 @@
|
||||
from __future__ import annotations
|
||||
import math, itertools
|
||||
from collections import defaultdict
|
||||
from typing import cast, Final
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, can_pad, GroupOp
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.dtype import AddrSpace, dtypes, ImageDType
|
||||
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod
|
||||
from tinygrad.codegen.opt import axis_colors, Opt, OptOps, KernelOptError, check, axis_letters
|
||||
from tinygrad.codegen.simplify import pm_flatten_range
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
|
||||
|
||||
# NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters
|
||||
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
|
||||
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
|
||||
|
||||
class Scheduler:
|
||||
def __init__(self, ast:UOp, opts:Renderer):
|
||||
self.ast, self.opts = ast, opts
|
||||
self.dont_use_locals = self.ast.arg.dont_use_locals if self.ast.arg is not None else False
|
||||
self.applied_opts = list(self.ast.arg.applied_opts) if self.ast.arg is not None else []
|
||||
|
||||
@property
|
||||
def rngs(self):
|
||||
# always in order by axistype
|
||||
return sorted([u for u in self.ast.parents if u.op is Ops.RANGE and u.vmax > 0], key=lambda x: (axis_to_pos[x.arg[-1]],) + x.arg[0:-1])
|
||||
@property
|
||||
def shape_len(self): return len(self.rngs)
|
||||
@property
|
||||
def full_shape(self): return [ssimplify(x.src[0]) for x in self.rngs]
|
||||
@property
|
||||
def axis_types(self): return [x.arg[-1] for x in self.rngs]
|
||||
@property
|
||||
def maxarg(self): return max([x.arg[0] for x in self.rngs], default=0)
|
||||
|
||||
# strings like ['g0', 'g1', 'l0', 'l1', 'l2', 'l3', 'l4', 'l5', 'R0', 'r0', 'r1', 'r2', 'u0', 'u1', 'u2']
|
||||
def shape_str(self) -> list[str]:
|
||||
ret: list[str] = []
|
||||
cnt: dict[AxisType, int] = {}
|
||||
for x in self.axis_types:
|
||||
cnt[x] = (cnt[x] + 1) if x in cnt else 0
|
||||
ret.append(f"{axis_letters[x]}{cnt[x]}")
|
||||
return ret
|
||||
def shape_str_to_axis(self, nms:list[str]) -> tuple[int, ...]: return tuple([self.shape_str().index(x) for x in nms])
|
||||
|
||||
def copy(self):
|
||||
ret = Scheduler(self.ast, self.opts)
|
||||
ret.dont_use_locals = self.dont_use_locals
|
||||
ret.applied_opts = self.applied_opts[:]
|
||||
return ret
|
||||
|
||||
kernel_cnt: Final[defaultdict[str, int]] = defaultdict(int)
|
||||
def get_optimized_ast(self, name_override:str|None=None):
|
||||
if name_override is not None: name = name_override
|
||||
else:
|
||||
kernel_type = "r" if self.reduceop is not None else "E"
|
||||
name = kernel_type + colored('_', 'BLACK').join(['']+[colored(x.src[0].render(), color) for x,color in zip(self.rngs, self.colors())])
|
||||
Scheduler.kernel_cnt[(function_name := to_function_name(name))] += 1
|
||||
num = f"n{Scheduler.kernel_cnt[function_name]-1}" if Scheduler.kernel_cnt[function_name] > 1 else ""
|
||||
name += colored(num, 'BLACK')
|
||||
self.ast = graph_rewrite(self.ast, pm_flatten_range, name="flatten range")
|
||||
return self.ast.replace(arg=KernelInfo(name=name, applied_opts=tuple(self.applied_opts), dont_use_locals=self.dont_use_locals), tag=1)
|
||||
|
||||
def _globalizable_rngs(self) -> list[UOp]:
|
||||
store_rngs = self.ast.src[0].src[2:]
|
||||
|
||||
# filter any not in local stores
|
||||
local_store_rngs = [x.ranges for x in self.ast.toposort() if (x.op is Ops.STORE and x.src[0].ptrdtype.addrspace == AddrSpace.LOCAL) \
|
||||
or (x.op is Ops.BUFFERIZE and x.arg == AddrSpace.LOCAL)]
|
||||
for ls in local_store_rngs: store_rngs = tuple([x for x in store_rngs if x in ls])
|
||||
|
||||
return [x for x in UOp.sink(*store_rngs).toposort() if x.op is Ops.RANGE and x.arg[1] == AxisType.LOOP] if store_rngs else []
|
||||
|
||||
def convert_loop_to_global(self):
|
||||
if not self.opts.has_local: return None
|
||||
|
||||
globalizible_rngs = self._globalizable_rngs()
|
||||
rng = [x.replace(arg=(x.arg[0], AxisType.GLOBAL)) if x in globalizible_rngs else x for x in self.rngs]
|
||||
|
||||
self.ast = self.ast.substitute(dict(zip(self.rngs, rng)))
|
||||
|
||||
def colors(self) -> list[str]: return [axis_colors[x] if not self.dont_use_locals or not x == AxisType.GLOBAL else "BLUE" for x in self.axis_types]
|
||||
def colored_shape(self) -> str: return ' '.join([colored(f'{x.src[0].render():>4s}', color) for x,color in zip(self.rngs, self.colors())])
|
||||
|
||||
def shift_to(self, rng:UOp, amount:int, new_type:AxisType, top:bool=False, input_new_rng=None):
|
||||
if (old_sz:=rng.src[0].divides(amount)) is None:
|
||||
raise KernelOptError(f"{amount} can't divide {rng.src[0]} in {self.colored_shape()}")
|
||||
new_rng = UOp.range(amount, self.maxarg+1, new_type) if input_new_rng is None else input_new_rng
|
||||
replaced_rng = rng.replace(src=(UOp.const(dtypes.int, old_sz),))
|
||||
sub_axis = (new_rng * old_sz + replaced_rng) if top else (replaced_rng * amount + new_rng)
|
||||
self.ast = self.ast.substitute({rng:sub_axis}, name=f"shift {rng.arg[0]} {amount} {str(new_type).split('.')[1].lower()}")
|
||||
return replaced_rng, new_rng
|
||||
|
||||
def ranges_of(self, *axis_type:AxisType) -> list[UOp]: return [r for r in self.rngs if r.arg[-1] in axis_type]
|
||||
def axes_of(self, *axis_type:AxisType) -> list[int]: return [i for i,t in enumerate(self.axis_types) if t in axis_type]
|
||||
|
||||
# copied from kernel.py
|
||||
@property
|
||||
def upcastable_dims(self) -> list[int]: return [i for i in self.axes_of(AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP) \
|
||||
if isinstance(s:=self.full_shape[i], int) and s > 1]
|
||||
@property
|
||||
def unrollable_dims(self) -> list[int]: return [i for i in self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE) \
|
||||
if isinstance(s:=self.full_shape[i], int) and s > 1]
|
||||
|
||||
def real_axis(self, op:OptOps, axis:int|None):
|
||||
try:
|
||||
if axis is None or op is OptOps.TC: return -1
|
||||
if op is OptOps.UNROLL: return self.unrollable_dims[axis]
|
||||
if op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.axes_of(AxisType.REDUCE)[axis]
|
||||
check(axis < self.shape_len, f"invalid axis on {axis=} {op=} {self.shape_len=}")
|
||||
return axis
|
||||
except IndexError as e: raise KernelOptError from e
|
||||
|
||||
def apply_opt(self, opt:Opt, append_opt:bool=True):
|
||||
if opt.op is OptOps.NOLOCALS:
|
||||
check(all(x not in {AxisType.WARP, AxisType.LOCAL, AxisType.GROUP_REDUCE} for x in self.axis_types), "no locals can't have locals")
|
||||
if append_opt: self.applied_opts.append(opt)
|
||||
self.dont_use_locals = True
|
||||
return
|
||||
|
||||
if opt.op in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP}:
|
||||
check(self.opts.has_local, "locals needed for opt")
|
||||
|
||||
rng = self.rngs[real_axis] if (real_axis:=self.real_axis(opt.op, opt.axis)) >= 0 else UOp(Ops.NOOP)
|
||||
|
||||
opt_to_at = {
|
||||
OptOps.LOCAL: AxisType.LOCAL, OptOps.UPCAST: AxisType.UPCAST,
|
||||
OptOps.UNROLL: AxisType.UNROLL, OptOps.GROUP: AxisType.GROUP_REDUCE,
|
||||
OptOps.GROUPTOP: AxisType.GROUP_REDUCE, OptOps.THREAD: AxisType.THREAD}
|
||||
|
||||
ret = None
|
||||
if opt.op in opt_to_at:
|
||||
amt:int = int(rng.vmax+1) if opt.arg == 0 else cast(int, opt.arg)
|
||||
|
||||
# copied from kernel.py. prevents METAL compiler hangs
|
||||
if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \
|
||||
(self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
|
||||
upcast_local_sz = prod([self.full_shape[a] for a in self.axes_of(AxisType.UPCAST, AxisType.WARP, AxisType.LOCAL, AxisType.GROUP_REDUCE)])
|
||||
smem_sz = amt*upcast_local_sz*self.reduceop.dtype.itemsize
|
||||
check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")
|
||||
|
||||
if opt.op is OptOps.UNROLL:
|
||||
check(amt <= 32, "don't unroll more than 32")
|
||||
check(rng.arg[-1] in {AxisType.GROUP_REDUCE, AxisType.REDUCE}, "unroll is for GROUP_REDUCE/REDUCE")
|
||||
if opt.op is OptOps.UPCAST:
|
||||
check((self.opts is not None and self.opts.device == "DSP") or amt <= 16, "don't upcast more than 16")
|
||||
check(rng.arg[-1] in {AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP}, f"upcast is for GLOBAL/LOCAL/LOOP, not {rng.arg[-1]}")
|
||||
if opt.op is OptOps.LOCAL:
|
||||
check(not self.dont_use_locals, "can't use locals")
|
||||
check(rng.arg[-1] in {AxisType.GLOBAL, AxisType.LOOP}, "local is for globals")
|
||||
if opt.op is OptOps.THREAD:
|
||||
check(self.opts is not None and self.opts.has_threads, "target does not support threads")
|
||||
check(self.opts is not None and self.opts.global_max is not None and amt <= self.opts.global_max[0], "too many threads")
|
||||
check(all(x is not AxisType.THREAD for x in self.axis_types), "already threaded")
|
||||
check(rng in self._globalizable_rngs(), "can't apply range to this dim")
|
||||
if opt.op in {OptOps.GROUP, OptOps.GROUPTOP}:
|
||||
check(all(x.op is not OptOps.TC for x in self.applied_opts), "no grouping with tensor cores") # TODO: why is this wrong?
|
||||
check(not self.dont_use_locals, "can't use locals")
|
||||
check(rng.arg[-1] == AxisType.REDUCE, "group is for reduce")
|
||||
ret = self.shift_to(rng, amt, opt_to_at[opt.op], top=opt.op in {OptOps.GROUPTOP, OptOps.THREAD})
|
||||
elif opt.op is OptOps.TC:
|
||||
check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: remove the need for this by having warps
|
||||
check(opt.axis is not None, "tensor core opts must have an axis")
|
||||
check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) == 3, "tensor core opts must have valid arg")
|
||||
check(-1 <= (tc_select:=cast(tuple, opt.arg)[0]) < len(self.opts.tensor_cores), "tensor core opts must have valid tc_select")
|
||||
check(0 <= (tc_opt:=cast(tuple, opt.arg)[1]) <= 2, "tensor core opts must have valid tc_opt")
|
||||
check(0 < (use_tensor_cores:=cast(tuple, opt.arg)[2]) <= 2, "use_tensor_cores value is not valid")
|
||||
try: ret = self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), tc_select, tc_opt)
|
||||
except ValueError as e: raise KernelOptError(str(e))
|
||||
check(ret is not None, "no tensor core available")
|
||||
elif opt.op is OptOps.PADTO:
|
||||
check(rng.src[0].op is Ops.CONST, "only pad const axes")
|
||||
check(rng.arg[-1] not in {AxisType.UPCAST, AxisType.UNROLL}, "cannot pad upcasted") # TODO: why is this wrong?
|
||||
check(rng.arg[-1] is not AxisType.THREAD, "cannot pad thread")
|
||||
# ok to pad SUM if all parent ALU ops have f(0) = 0
|
||||
if (r:=self.reduceop) is not None and rng.arg[-1] in (AxisType.GROUP_REDUCE, AxisType.REDUCE):
|
||||
check(r.arg[0] is Ops.ADD and can_pad(r, {}), f"cannot pad {r}")
|
||||
new_sz = round_up(int(rng.vmax+1), cast(int, opt.arg))
|
||||
check(rng.vmax+1 > new_sz//4, "pad adds more than quadruple the work")
|
||||
replaced_rng = UOp.range(new_sz, *rng.arg)
|
||||
replaces = {rng:replaced_rng}
|
||||
valid = replaced_rng < rng.vmax+1
|
||||
for b in self.bufs:
|
||||
if rng in (i:=b.src[1].get_idx()).sparents:
|
||||
replaces[b] = b.replace(src=(b.src[0],(valid&b.src[1].get_valid()).where(i, UOp.invalid())))
|
||||
self.ast = self.ast.substitute(replaces, f"padto {rng.arg[:-1]} {opt.arg}")
|
||||
elif opt.op is OptOps.SWAP:
|
||||
try:
|
||||
altrng = self.rngs[opt.arg]
|
||||
except IndexError:
|
||||
raise KernelOptError
|
||||
check(rng.arg[-1] == AxisType.GLOBAL and altrng.arg[-1] == AxisType.GLOBAL, "swap only for globals")
|
||||
self.ast = self.ast.substitute({rng:rng.replace(arg=(*altrng.arg[0:-1], rng.arg[-1]), tag=1),
|
||||
altrng:altrng.replace(arg=(*rng.arg[0:-1], altrng.arg[-1]), tag=1)})
|
||||
self.ast = graph_rewrite(self.ast, remove_tags)
|
||||
else:
|
||||
raise KernelOptError(f"unsupported opt {opt.op}")
|
||||
|
||||
if append_opt: self.applied_opts.append(opt)
|
||||
return ret
|
||||
|
||||
def _apply_tc_opt(self, use_tensor_cores:int, axis:int, tc_select:int, opt_level:int) -> None|list[UOp]:
|
||||
reduceops = [x for x in self.ast.toposort() if x.op is Ops.REDUCE]
|
||||
if not len(reduceops): raise KernelOptError("no reduce ops for TensorCore")
|
||||
reduceop = reduceops[0]
|
||||
if use_tensor_cores and reduceop is not None and reduceop.arg is Ops.ADD:
|
||||
mul = reduceop.src[0] if reduceop.src[0].op is not Ops.CAST else reduceop.src[0].src[0]
|
||||
if mul.op is not Ops.MUL: return None
|
||||
in0, in1 = mul.src
|
||||
try:
|
||||
tensor_cores = self.opts.tensor_cores if tc_select == -1 else [self.opts.tensor_cores[tc_select]]
|
||||
except IndexError:
|
||||
raise KernelOptError(f"invalid tensor core choice {tc_select}")
|
||||
for tc in tensor_cores:
|
||||
if tc.dtype_in == in0.dtype.scalar() and tc.dtype_in == in1.dtype.scalar() and tc.dtype_out == reduceop.dtype.scalar():
|
||||
# tensor cores have three ranges. X, Y, and REDUCE
|
||||
in0_ranges = sorted([u for u in in0.ranges if u not in in1.ranges], key=lambda x: -x.arg[0])
|
||||
in1_ranges = sorted([u for u in in1.ranges if u not in in0.ranges], key=lambda x: -x.arg[0])
|
||||
red_ranges = sorted(reduceop.src[1:], key=lambda x: -x.arg[0])
|
||||
if DEBUG >= 3:
|
||||
print(f"TC({axis}): {[(x.arg[0],x.vmax+1) for x in in0_ranges]}",
|
||||
f"{[(x.arg[0],x.vmax+1) for x in in1_ranges]} {[(x.arg[0],x.vmax+1) for x in red_ranges]}")
|
||||
if not len(in0_ranges) or not len(in1_ranges) or not len(red_ranges): continue
|
||||
|
||||
# pick ranges
|
||||
# NOTE: why are in1 and in0 switched?
|
||||
axis_choices = list(itertools.product(in1_ranges, in0_ranges, red_ranges))
|
||||
if not (axis < len(axis_choices)): continue
|
||||
axes = list(axis_choices[axis])
|
||||
|
||||
# do optimizations and save the ranges
|
||||
try:
|
||||
for i,a in enumerate(axes):
|
||||
idx = self.rngs.index(a)
|
||||
if (a.vmax+1) % tc.dims[i] != 0:
|
||||
if opt_level < 2: raise KernelOptError("tc padding requires opt_level >= 2")
|
||||
# apply_opt should return the updated range?
|
||||
self.apply_opt(Opt(OptOps.PADTO, idx, tc.dims[i]), append_opt=False) # PADTO might fail
|
||||
axes[i] = self.rngs[idx]
|
||||
except KernelOptError: continue
|
||||
|
||||
# we create the warp as a whole thing, in case some of these ranges are moved/removed later
|
||||
warp = UOp.range(tc.threads, -1, AxisType.WARP)
|
||||
ne: list[UOp] = []
|
||||
for opt in tc.opts:
|
||||
if opt[0] == "l":
|
||||
axes[int(opt[1])], new_range = self.shift_to(axes[int(opt[1])], 2, AxisType.LOCAL, input_new_rng=warp%2)
|
||||
warp //= 2
|
||||
elif opt[0] == "u":
|
||||
axes[int(opt[1])], new_range = self.shift_to(axes[int(opt[1])], 2, AxisType.UPCAST)
|
||||
else: raise RuntimeError(f"unsupported opt {opt[0]} in tensor cores")
|
||||
ne.append(new_range)
|
||||
|
||||
for _, amt in tc.get_reduce_axes():
|
||||
axes[2], new_range = self.shift_to(axes[2], amt, AxisType.UNROLL)
|
||||
ne.append(new_range)
|
||||
|
||||
if use_tensor_cores != 2:
|
||||
# fix the srcs
|
||||
reduceop = [x for x in self.ast.toposort() if x.op is Ops.REDUCE][0]
|
||||
tne = [x.replace(tag=1) for x in ne]
|
||||
ret = reduceop.substitute(dict(zip(ne, tne)))
|
||||
srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src)
|
||||
srcs = [x.substitute(dict(zip(tne, [ne[i] for i in argsort(p)]))) for x,p in zip(srcs, tc.permutes_for_shape_str(tc.base_shape_str()))]
|
||||
|
||||
# get reduce/upcast axes for the tensor cores
|
||||
tc_reduce_axes = self.shape_str_to_axis([f"r{i}" for i in range(len(tc.get_reduce_axes()))])
|
||||
base_upcast_axes = tuple([(s,2) for s in self.shape_str_to_axis(tc.base_upcast_axes())])
|
||||
tc_upcast_axes = tuple([base_upcast_axes[:int(math.log2(tc.elements_per_thread[i]))] for i in range(3)])
|
||||
|
||||
# axes to range number (was done in lowerer)
|
||||
tc_upcast_axes = tuple([tuple([(self.rngs[a].arg[0], sz) for a,sz in v]) for v in tc_upcast_axes])
|
||||
tc_reduce_axes = tuple([self.rngs[a].arg[0] for a in tc_reduce_axes])
|
||||
|
||||
# construct the op
|
||||
# TODO: remove tc_upcast_axes from the arg
|
||||
# do the reduce_axes always disappear? i think they don't
|
||||
# they need to be moved into the WMMA srcs
|
||||
wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, tc.threads, tc_upcast_axes, ()) #, tc_reduce_axes)
|
||||
wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(tc.elements_per_thread[2]), src=(
|
||||
UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(tc.elements_per_thread[0]), src=(srcs[0],), arg=tc_upcast_axes[0], tag=1),
|
||||
UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(tc.elements_per_thread[1]), src=(srcs[1],), arg=tc_upcast_axes[1], tag=1),
|
||||
UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0)), arg=wmma_arg, tag=1)
|
||||
tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=tc_upcast_axes[2], tag=1)
|
||||
|
||||
# preserve extra reduces
|
||||
reduce_ranges = [x for x in UOp.sink(*reduceop.src[1:]).toposort() if x.op is Ops.RANGE and x.arg[0] not in tc_reduce_axes]
|
||||
if len(reduce_ranges): tc_uop = UOp(Ops.REDUCE, tc_uop.dtype, (tc_uop,)+tuple(reduce_ranges), Ops.ADD)
|
||||
self.ast = self.ast.substitute({reduceop: tc_uop})
|
||||
return axes
|
||||
return None
|
||||
|
||||
# helpers for hand_coded_optimizations
|
||||
@property
|
||||
def reduceop(self) -> UOp|None:
|
||||
red = [x for x in self.ast.parents if x.op is Ops.REDUCE]
|
||||
if not len(red): return None
|
||||
return UOp(Ops.REDUCE_AXIS, red[0].dtype, red[0].src, (red[0].arg, ()))
|
||||
@property
|
||||
def bufs(self) -> list[UOp]: return [x for x in self.ast.toposort() if x.op is Ops.INDEX][::-1]
|
||||
@property
|
||||
def output_shape(self):
|
||||
return [s if at not in {AxisType.REDUCE, AxisType.UNROLL, AxisType.GROUP_REDUCE} else 1 for s,at in zip(self.full_shape, self.axis_types)]
|
||||
@property
|
||||
def upcasted(self) -> int: return len(self.axes_of(AxisType.UPCAST, AxisType.UNROLL))
|
||||
@property
|
||||
def group_for_reduces(self) -> int: return len(self.axes_of(AxisType.GROUP_REDUCE))
|
||||
|
||||
def bufs_from_ast(ast:UOp, dname:str) -> list[Buffer]:
|
||||
glbls = sorted([x for x in ast.parents if x.op is Ops.DEFINE_GLOBAL], key=lambda x: x.arg)
|
||||
return [Buffer(dname, x.ptrdtype.size, x.dtype.base if not isinstance(x.dtype, ImageDType) else x.dtype) for x in glbls]
|
||||
|
||||
def apply_opts(ctx:Renderer, ast:UOp):
|
||||
if ast.tag is not None: return None
|
||||
k = Scheduler(ast, ctx)
|
||||
k.convert_loop_to_global()
|
||||
if ast.arg is not None and ast.arg.opts_to_apply is not None:
|
||||
for opt in ast.arg.opts_to_apply: k.apply_opt(opt)
|
||||
elif BEAM >= 1:
|
||||
from tinygrad.codegen.opt.search import beam_search
|
||||
rawbufs = bufs_from_ast(ast, ctx.device)
|
||||
k = beam_search(k, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
|
||||
elif not NOOPT and (ast.arg is None or ast.arg.applied_opts == ()):
|
||||
from tinygrad.codegen.opt.heuristic import hand_coded_optimizations
|
||||
# NOTE: hand_coded_optimizations doesn't support multiblock opts yet
|
||||
if all(len(u.src) == 1 for u in ast.parents if u.op is Ops.LOAD):
|
||||
k = hand_coded_optimizations(k)
|
||||
return k.get_optimized_ast(name_override=ast.arg.name if ast.arg is not None and ast.arg.name != "test" else None)
|
||||
|
||||
pm_postrange_opt = PatternMatcher([
|
||||
(UPat(Ops.SINK, name="ast"), apply_opts),
|
||||
])
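To make the Scheduler flow concrete, a hedged sketch of driving it by hand (hypothetical usage; apply_opts above does the equivalent when opts_to_apply is set), assuming ast and a Renderer renderer already exist:

# hypothetical manual use of Scheduler
k = Scheduler(ast, renderer)
k.convert_loop_to_global()
for opt in [Opt(OptOps.LOCAL, axis=0, arg=16), Opt(OptOps.UNROLL, axis=0, arg=4)]:
  k.apply_opt(opt)                  # raises KernelOptError if the opt does not apply
optimized = k.get_optimized_ast()   # SINK tagged with KernelInfo(applied_opts=...)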
183
tinygrad/codegen/opt/search.py
Normal file
@@ -0,0 +1,183 @@
|
||||
from typing import cast
|
||||
import functools, math, time, multiprocessing, traceback, signal, atexit
|
||||
from dataclasses import replace
|
||||
from tinygrad.uop.ops import sym_infer, AxisType, pyrender
|
||||
from tinygrad.device import Device, Buffer, Compiler
|
||||
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, time_to_str
|
||||
from tinygrad.helpers import IGNORE_BEAM_CACHE
|
||||
from tinygrad.codegen.opt import Opt, OptOps, KernelOptError
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.engine.realize import CompiledRunner, get_program
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.codegen.opt.postrange import Scheduler
|
||||
|
||||
actions = [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7] for axis in range(8)]
|
||||
actions += [Opt(op=OptOps.UNROLL, axis=axis, arg=amt) for amt in [0,4,7] for axis in range(5)]
|
||||
actions += [Opt(op=OptOps.LOCAL, axis=axis, arg=amt) for amt in [2,3,4,8,13,16,29] for axis in range(6)]
|
||||
actions += [Opt(op=OptOps.GROUPTOP, axis=axis, arg=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)]
|
||||
actions += [Opt(op=OptOps.GROUP, axis=axis, arg=amt) for amt in [0,4,8,16] for axis in range(3)]
|
||||
if getenv("BEAM_PADTO", 0): actions += [Opt(op=OptOps.PADTO, axis=axis, arg=amt) for amt in [32] for axis in range(7)]
|
||||
actions += [Opt(op=OptOps.LOCAL, axis=0, arg=32), Opt(op=OptOps.LOCAL, axis=6, arg=2)]
|
||||
actions += [Opt(op=OptOps.TC, axis=0, arg=(-1, 0, getenv("TC", 1)))]
|
||||
# covers resnet kernels (3 global * 3 reduce)
|
||||
actions += [Opt(op=OptOps.TC, axis=axis, arg=(-1, getenv("TC_OPT", 2), getenv("TC", 1))) for axis in range(9)]
|
||||
actions += [Opt(op=OptOps.SWAP, axis=axis_0, arg=axis_1) for axis_0 in range(5) for axis_1 in range(axis_0+1, 5)]
|
||||
actions += [Opt(op=OptOps.THREAD, axis=axis, arg=amt) for amt in [2,3,4,5,8,12,16,24,32,64] for axis in range(3)]
|
||||
if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
|
||||
|
||||
def get_test_global_size(global_size, max_global_size, var_vals):
|
||||
test_global_size = [sym_infer(sz, var_vals) for sz in global_size]
|
||||
input_size = prod(test_global_size)
|
||||
while prod(test_global_size) > max_global_size:
|
||||
for j in range(len(global_size)-1,-1,-1):
|
||||
if test_global_size[j] > 16:
|
||||
test_global_size[j] //= 2
|
||||
break
|
||||
return test_global_size, input_size / prod(test_global_size)
|
||||
|
||||
def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[str, int], rawbufs:list[Buffer], early_stop:float|None=None,
|
||||
allow_test_size:int=True, max_global_size:int|None=65536, clear_l2=False, cnt=3, name="test") -> list[float]:
|
||||
factor = 1
|
||||
if allow_test_size and p.global_size is not None and max_global_size is not None:
|
||||
global_size, factor = get_test_global_size(p.global_size, max_global_size, var_vals)
|
||||
p = replace(p, global_size=global_size)
|
||||
try: car = CompiledRunner(p, precompiled=lib)
|
||||
except AssertionError: return [math.inf] * cnt
|
||||
tms = []
|
||||
input_bufs = [rawbufs[i] for i in car.p.globals]
|
||||
for _ in range(cnt):
|
||||
if clear_l2:
|
||||
if hasattr(dev:=Device[p.device], 'invalidate_caches'): dev.invalidate_caches()
|
||||
else:
|
||||
with Context(DEBUG=0, BEAM=0, CAPTURING=0, TRACK_MATCH_STATS=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
|
||||
tms.append(cast(float, car(input_bufs, var_vals, wait=True))*factor)
|
||||
if early_stop is not None and early_stop < min(tms): break
|
||||
return tms
|
||||
|
||||
class TimeoutException(Exception): pass
|
||||
def timeout_handler(signum, frame):
|
||||
if DEBUG >= 2: print("*** BEAM COMPILE TIMEOUT")
|
||||
raise TimeoutException()
|
||||
|
||||
def _try_compile_linearized_w_idx(x:tuple[int,Scheduler], compiler:Compiler) -> tuple[int, tuple[ProgramSpec, bytes, float]|None]:
|
||||
if hasattr(signal, "alarm"):
|
||||
signal.signal(getattr(signal, 'SIGALRM'), timeout_handler)
|
||||
# set timeout
|
||||
signal.alarm(getenv("BEAM_TIMEOUT_SEC", 10))
|
||||
ret = None
|
||||
try:
|
||||
p = get_program(x[1].copy().get_optimized_ast(name_override="test"), x[1].opts)
|
||||
assert p.uops is not None, "uop list wasn't generated?"
|
||||
if len(p.uops) >= (uops_max:=getenv("BEAM_UOPS_MAX", 3000)) > 0:
|
||||
if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too many uops. {len(p.uops)=}, {uops_max=}")
|
||||
raise RuntimeError("too many uops")
|
||||
st = time.perf_counter()
|
||||
prog = compiler.compile(p.src)
|
||||
et = time.perf_counter() - st
|
||||
ret = (p, prog, et)
|
||||
except RuntimeError:
|
||||
if DEBUG >= 4: traceback.print_exc()
|
||||
except Exception as e:
|
||||
if getenv("BEAM_STRICT_MODE"): raise e
|
||||
finally:
|
||||
if hasattr(signal, "alarm"): signal.alarm(0)
|
||||
return x[0], ret
|
||||
|
||||
# workers should not open devices and should ignore ctrl c and should not launch VIZ
|
||||
def _init_worker():
|
||||
Context(ALLOW_DEVICE_USAGE=0, VIZ=0, TRACK_MATCH_STATS=0).__enter__()
|
||||
signal.signal(signal.SIGINT, signal.SIG_IGN)
|
||||
|
||||
def _ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]: return [buf.ensure_allocated() if buf is not None else buf for buf in bufs]
|
||||
|
||||
# *** external API ***
|
||||
|
||||
# get dictionary of all possible actions
|
||||
def get_kernel_actions(lin:Scheduler, include_0=True, candidates:list[Opt]|None=None) -> dict[int, Scheduler]:
|
||||
acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
|
||||
kernel_actions = (actions if candidates is None else candidates).copy()
|
||||
|
||||
for i,a in enumerate(kernel_actions):
|
||||
if a.axis is not None and a.op is not OptOps.TC:
|
||||
try: ax = lin.real_axis(a.op, a.axis)
|
||||
except KernelOptError: continue
|
||||
if (ax >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, a.axis, 0) in kernel_actions): continue
|
||||
lin2 = lin.copy()
|
||||
try:
|
||||
lin2.apply_opt(a)
|
||||
up, lcl, tc_up = 1, 1, prod(tc.dims)//tc.threads if hasattr(lin2, 'tensor_core') and (tc:=lin2.tensor_core) else 1
|
||||
for s,c in zip(lin2.full_shape, lin2.axis_types):
|
||||
if c in (AxisType.UPCAST, AxisType.UNROLL): up *= s
|
||||
elif c in (AxisType.WARP, AxisType.LOCAL, AxisType.GROUP_REDUCE): lcl *= s
|
||||
if up//tc_up > max_up or lcl > max_lcl:
|
||||
if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too many upcast/local. {up//tc_up=}, {max_up=}, {lcl=}, {max_lcl=}")
|
||||
continue
|
||||
acted_lins[i+1] = lin2
|
||||
except KernelOptError: pass
|
||||
return acted_lins
|
||||
|
||||
beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG")
|
||||
def beam_search(lin:Scheduler, rawbufs:list[Buffer], amt:int, allow_test_size=True, disable_cache=IGNORE_BEAM_CACHE.value):
|
||||
global beam_pool
|
||||
key = {"ast": lin.ast.key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
|
||||
if not disable_cache and CACHELEVEL >= 1 and (val:=diskcache_get("beam_search", key)) is not None:
|
||||
ret = lin.copy()
|
||||
for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
|
||||
return ret
|
||||
|
||||
beam: list[tuple[Scheduler, float]] = [(lin, float("inf"))]
|
||||
seen_libs = set()
|
||||
|
||||
default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV", "METAL", "HIP"} else 0
|
||||
if beam_pool is None and (workers := getenv("PARALLEL", default_parallel)):
|
||||
beam_pool = multiprocessing.get_context("spawn").Pool(workers, _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16))
|
||||
@atexit.register
|
||||
def close_pool(): beam_pool.close()
|
||||
|
||||
min_progress = getenv("BEAM_MIN_PROGRESS", 0.01)/1e6
|
||||
if BEAM_DEBUG:
|
||||
print("BEAM_SEARCH:")
|
||||
print('\n'.join(pyrender(lin.ast.replace(arg=None))))
|
||||
if DEBUG >= 2: print(f" 0.00s: from 1 -> 1 actions {lin.colored_shape()}")
|
||||
|
||||
try:
|
||||
rawbufs = _ensure_buffer_alloc(rawbufs)
|
||||
var_vals: dict[str, int] = {k.expr:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
|
||||
exiting, st = False, time.perf_counter()
|
||||
dev = Device[lin.opts.device]
|
||||
while not exiting:
|
||||
acted_lins: list[Scheduler] = flatten([get_kernel_actions(lin, include_0=False).values() for lin,_ in beam])
|
||||
timed_lins: list[tuple[Scheduler, float]] = []
|
||||
_compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler)
|
||||
least_compute_ops = math.inf
|
||||
for i,proc in (map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))):
|
||||
if proc is None: continue
|
||||
p, lib, compile_et = proc
|
||||
if lib in seen_libs: continue
|
||||
# filter out kernels that use 1000x more compute than the smallest
|
||||
least_compute_ops = min(this_compute_ops:=sym_infer(p.estimates.ops, var_vals), least_compute_ops)
|
||||
if least_compute_ops*1000 < this_compute_ops: continue
|
||||
seen_libs.add(lib)
|
||||
try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0,
|
||||
allow_test_size=allow_test_size, clear_l2=hasattr(dev, 'invalidate_caches'))
|
||||
except Exception as e:
|
||||
if BEAM_DEBUG: print(f"BEAM failed for opts: {acted_lins[i].applied_opts}\n{e}")
|
||||
if isinstance(e, RuntimeError): continue
|
||||
raise
|
||||
timed_lins.append((acted_lins[i], min(tms)))
|
||||
if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(list, p.uops)):5d} uops {time_to_str(compile_et, w=12)} compile/{time_to_str(timed_lins[-1][1], w=12)} run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
|
||||
elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {time_to_str(timed_lins[-1][1], w=12)} {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501
|
||||
|
||||
# done
|
||||
opts = sorted(timed_lins, key=lambda x: x[1])
|
||||
exiting = len(opts) == 0 or (opts[0][1] < min_progress) or (len(beam) > 0 and ((beam[0][1]-opts[0][1]) < min_progress))
|
||||
if not exiting: beam = opts[:amt]
|
||||
elif len(opts) > 0 and opts[0][1] < beam[0][1]: beam = opts[:1]
|
||||
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(time_to_str(beam[0][1], w=12), "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape()) # noqa: E501
|
||||
except KeyboardInterrupt as e:
|
||||
if beam_pool is not None: beam_pool.terminate()
|
||||
raise e
|
||||
|
||||
if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
|
||||
if BEAM_DEBUG: print(f"BEAM_SEARCH: final tm={time_to_str(beam[0][1], w=0)}, applied_opts={beam[0][0].applied_opts}")
|
||||
return beam[0][0]
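For context, a hypothetical direct invocation of the search (the normal entry point is apply_opts in postrange.py with BEAM>=1 set in the environment), assuming k is a Scheduler whose buffers can be allocated:

# hypothetical direct call to beam_search
rawbufs = bufs_from_ast(k.ast, k.opts.device)
best = beam_search(k, rawbufs, amt=4)   # beam width of 4
print(best.applied_opts)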
135
tinygrad/codegen/opt/swizzler.py
Normal file
@@ -0,0 +1,135 @@
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, resolve, sint
|
||||
from tinygrad.helpers import all_same, prod, unwrap, colored
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View, strides_for_shape, get_contraction_with_reduce
|
||||
from tinygrad.schedule.grouper import ALWAYS_CONTIGUOUS
|
||||
from tinygrad.dtype import ImageDType, dtypes
|
||||
|
||||
merge_views = PatternMatcher([
|
||||
# merge adjacent views
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="v1"),), name="v2"), lambda v1,v2: v1.replace(arg=v1.arg+v2.arg)),
|
||||
# replace MovementOps with VIEW
|
||||
(UPat(GroupOp.Movement, src=(UPat.var("x"),), name="mop"), lambda mop,x: x.base.view(mop.st)),
|
||||
# remove NOOP views
|
||||
(UPat.var("x").view(name="view"),
|
||||
lambda x,view: x if x.st is not None and x.op not in GroupOp.Defines and view.st.contiguous and view.shape == x.shape else None),
|
||||
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}).view(name="view"),
|
||||
lambda view: view.const_like(0) if (mask:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in mask) else None),
|
||||
# only unmasked VIEW on CONST replaces the ShapeTracker
|
||||
(UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),), name="view"),
|
||||
lambda x,view: x.replace(src=(UOp(Ops.VIEW, x.dtype, x.src, view.arg),)) if all(v.mask is None for v in view.st.views) else None),
|
||||
])
|
||||
|
||||
def reduce_push_add_ones(src:UOp, r:UOp, view:UOp):
|
||||
# contiguous, expand, and the same with ones removed
|
||||
if unwrap(view.st).contiguous and len(r.shape) < len(view.shape) and \
|
||||
tuple(x for x in r.shape if resolve(x != 1)) == tuple(x for x in view.shape if resolve(x != 1)):
|
||||
new_shape: list[sint] = []
|
||||
new_reduce_axis = []
|
||||
if (contraction:=get_contraction_with_reduce(view.shape, r.shape, r.arg[1])) is None: return None
|
||||
for i,pairs in enumerate(contraction):
|
||||
new_shape_chunk = [view.shape[p] for p in pairs]
|
||||
if i in r.arg[1]:
|
||||
# if this is a reduce axis, we need a 1 in the view here to put it
|
||||
assert len(new_shape_chunk) > 0
|
||||
new_shape += [1]*(len(pairs)-1) + [src.shape[i]]
|
||||
new_reduce_axis.append(len(new_shape)-1)
|
||||
else:
|
||||
# otherwise, pass through the new_shape_chunk
|
||||
new_shape += new_shape_chunk
|
||||
ret = r.replace(src=(src.reshape(tuple(new_shape)),), arg=(r.arg[0], tuple(new_reduce_axis))+r.arg[2:])
|
||||
assert ret.shape == view.shape, f"shape mismatch on reduce_push_add_ones, {ret.shape} != {view.shape}"
|
||||
return ret
|
||||
return None
|
||||
|
||||
view_left = merge_views+PatternMatcher([
|
||||
# view before elementwise and buffer ops
|
||||
(UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.BIND, Ops.STORE, Ops.VALID, Ops.SINK}, name="e"),), name="view"),
|
||||
lambda e,view: e.replace(src=tuple(s.view(view.st) for s in e.src))),
|
||||
# if there's ones added after reduce, put this before the reduce
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), reduce_push_add_ones),
|
||||
])
|
||||
|
||||
view_left_through_load = PatternMatcher([
|
||||
# view before load
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.LOAD, name="e"),), name="view"),
|
||||
lambda e,view: e.replace(src=tuple(s.view(view.st) for s in e.src))),
|
||||
])
|
||||
|
||||
def apply_swizzle(u:UOp) -> UOp: return graph_rewrite(u, view_left, name="Sub View Left")
|
||||
|
||||
# change reduceop axes and input ShapeTrackers, view gets replaced with a reshape.
|
||||
def swizzle_reduceop(r:UOp, src:UOp, view:UOp, fuse=False):
|
||||
# contiguous and same size can push to children
|
||||
# if there's a reduce child, shapes match with ones removed
|
||||
if unwrap(view.st).contiguous and view.size == r.size and \
|
||||
(not (len(r.arg) == 3 and r.arg[2]) or # arg[2] = True is fuse marker
|
||||
tuple((i,x) for i,x in enumerate(r.shape) if resolve(x != 1)) == tuple((i,x) for i,x in enumerate(view.shape) if resolve(x != 1))):
|
||||
return None
|
||||
# swizzle the input
|
||||
input_st = ShapeTracker.from_shape(src.shape)
|
||||
tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg)
|
||||
prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):])
|
||||
strides = strides_for_shape(rshape)
|
||||
nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
|
||||
v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in unwrap(view.st).views]
|
||||
new_view = tmp + ShapeTracker(tuple(nv))
|
||||
swizzled_input = apply_swizzle(src.view(new_view))
|
||||
# create a new reduceop
|
||||
new_axis = tuple(range(len(view.shape), len(view.shape) + len(r.axis_arg)))
|
||||
if fuse: red = UOp(Ops.REDUCE_AXIS, r.dtype, (swizzled_input.fuse(),), (r.arg[0], new_axis, True))
|
||||
else: red = UOp(Ops.REDUCE_AXIS, r.dtype, (swizzled_input,), (r.arg[0], new_axis))
|
||||
return red.reshape(view.shape)
|
||||
|
||||
def reduceop_view_right(src:UOp, v:UOp, r:UOp):
|
||||
assert unwrap(v.st).contiguous and v.size == src.size, f"can't compute new axis for {src.shape} -> {r.shape}"
|
||||
new_axis = [i for i,(s,u) in enumerate(zip(src.shape, r.shape)) if s != u]
|
||||
return src.r(r.arg[0], tuple(new_axis)).reshape(r.shape)
|
||||
|
||||
def elementwise_view_right(root:UOp):
|
||||
if not (swizzles:=[x for x in root.src if x.op is Ops.VIEW and x.base.op not in ALWAYS_CONTIGUOUS]): return None
|
||||
assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"
|
||||
# place view after applying the elementwise op
|
||||
new_st = ShapeTracker.from_shape(swizzles[0].base.shape)
|
||||
new_src = [x.base if x.base.shape==new_st.shape else apply_swizzle(x.view(new_st)) for x in root.src]
|
||||
# reshape to match downstream shapes
|
||||
return root.replace(src=tuple(new_src)).reshape(root.shape)
|
||||
|
||||
# push VIEW to children
|
||||
view_right = merge_views+PatternMatcher([
|
||||
# push a non contiguous ShapeTracker through reduceop
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop),
|
||||
# apply view after reduceops
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="src"),), name="v"),), name="r"), reduceop_view_right),
|
||||
# apply view after elementwise ops
|
||||
(UPat(GroupOp.All-{Ops.SINK, Ops.REDUCE_AXIS}, name="root"), elementwise_view_right),
|
||||
# merge axes for double reduce (invert of SPLIT_REDUCEOP=1)
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
|
||||
lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] is r2.arg[0] else None),
|
||||
# remove view from sink
|
||||
(UPat(Ops.VIEW, name="v").sink(name="sink"), lambda v,sink: v.src[0].sink(arg=sink.arg)),
|
||||
])
|
||||
|
||||
def check_load_st(glbl:UOp, view:UOp):
|
||||
if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
|
||||
# if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
|
||||
if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: return
|
||||
# if it has a single view and it's equal when you shrink a contig, it's fine
|
||||
if len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): return
|
||||
# otherwise, it's not fine
|
||||
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
|
||||
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
|
||||
|
||||
fix_kernel_ops = view_left_through_load+PatternMatcher([
|
||||
# add view to LOAD and STORE
|
||||
(UPat(Ops.DEFINE_GLOBAL, name="g").load(), lambda g: g.view(g.st).load()),
|
||||
(UPat(Ops.DEFINE_GLOBAL, name="g").store(UPat.var('x')), lambda g,x: g.view(g.st).store(x)),
|
||||
# VALID
|
||||
(UPat(Ops.VIEW, src=(UPat.cvar(),), name="self"),
|
||||
lambda self: UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)),
|
||||
# no ImageDType after index
|
||||
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL, Ops.VIEW, Ops.INDEX}, name="x"),
|
||||
lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
|
||||
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly
|
||||
(UPat(Ops.LOAD, src=(UPat.var("glbl").view(name="view"),)), check_load_st),
|
||||
])
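The check_load_st error above is easiest to see from the Tensor frontend; an illustrative snippet (not part of this file):

# illustrative: the pattern check_load_st rejects, and the suggested fix
from tinygrad import Tensor
a = Tensor.rand(4, 4).realize()
# a += a.T                # would raise: self operand of augmented assign must be contiguous
a += a.T.contiguous()     # fine: the transposed operand is materialized first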
136
tinygrad/codegen/opt/tc.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import math, functools
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.dtype import DType, dtypes
|
||||
from tinygrad.helpers import getenv
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
|
||||
dims: tuple[int,int,int] # N, M, K
|
||||
threads: int # number of threads that construct the warp
|
||||
elements_per_thread: tuple[int, int, int] # elements per-thread to load/store from A/B/C
|
||||
dtype_in: DType # dtype for A and B
|
||||
dtype_out: DType # dtype for C and D
|
||||
opts: tuple[str, ...] # ordered tuple of "ux" or "lx" specifying kernel opts to perform. "ux" upcasts dim x and "lx" localizes dim x
|
||||
# (local_swizzle, upcast_swizzle, reduce_swizzle)
|
||||
# l<num> is the num'th axis of the locals; likewise u<num> for upcasts and r<num> for reduces
|
||||
swizzle: tuple[tuple[tuple[str, ...], tuple[str, ...], tuple[str, ...]], tuple[tuple[str, ...], tuple[str, ...], tuple[str, ...]]]
|
||||
@functools.cache # pylint: disable=method-cache-max-size-none
|
||||
def _remaps(self) -> list[dict[str, str]]:
|
||||
local_axes, upcast_axes, reduce_axes = len(self.get_local_axes()), len(self.get_upcast_axes()), len(self.get_reduce_axes())
|
||||
fwd_st = [f"l{i}" for i in range(local_axes)] + [f"u{i}" for i in range(upcast_axes)] + [f"r{i}" for i in range(reduce_axes)]
|
||||
return [dict(zip(fwd_st, sum(s, ()))) for s in self.swizzle]
|
||||
def permutes_for_shape_str(self, shape_str:list[str]) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
||||
ret = [[shape_str.index(remap[ss]) if ss in remap else i for i,ss in enumerate(shape_str)] for remap in self._remaps()]
|
||||
return tuple(ret[0]), tuple(ret[1])
|
||||
@functools.cache # pylint: disable=method-cache-max-size-none
|
||||
def base_shape_str(self) -> list[str]:
|
||||
ret = []
|
||||
cnt = {'u': 0, 'l': 0}
|
||||
for opt in self.opts:
|
||||
ret.append(f"{opt[0]}{cnt[opt[0]]}")
|
||||
cnt[opt[0]] += 1
|
||||
# assumes you do the UNROLL after the opts
|
||||
return ret + [f"r{i}" for i in range(len(self.get_reduce_axes()))]
|
||||
def get_reduce_axes(self): return [(i, 2) for i in range(int(math.log2(self.dims[2])))]
|
||||
def get_upcast_axes(self): return [opt for opt in self.opts if opt[0] == "u"]
|
||||
def get_local_axes(self): return [opt for opt in self.opts if opt[0] == "l"]
|
||||
def base_upcast_axes(self):
|
||||
# this is defined in the swizzle. first we use the upcast axes, then the reduce
|
||||
return ([f"r{i}" for i in range(len(self.get_reduce_axes()))] + [f"u{i}" for i in range(len(self.get_upcast_axes()))])[::-1]
|
||||
def __str__(self): return "_".join(["WMMA"] + list(map(str, self.dims)) + [self.dtype_in.name, self.dtype_out.name])
|
||||
def __post_init__(self):
|
||||
# all axes have size 2, <local> <reduce> <upcast> is the order
|
||||
local_axes, upcast_axes, reduce_axes = len(self.get_local_axes()), len(self.get_upcast_axes()), len(self.get_reduce_axes())
|
||||
assert self.dims[0] * self.dims[1] == 2**(local_axes + upcast_axes), \
|
||||
f"N({self.dims[0]}) x M({self.dims[1]}) != local({2**local_axes}) x upcast({2**upcast_axes}) with opts({self.opts})"
|
||||
assert 2**local_axes == self.threads, f"{self.threads} threads construct the warp but found {2**local_axes} in {self.opts}"
|
||||
assert 2**upcast_axes == self.elements_per_thread[2], \
|
||||
f"{self.elements_per_thread[2]} elements from C are processed per thread but found {2**upcast_axes} in {self.opts}"
|
||||
# check dims match opts
|
||||
assert self.dims[0] == 2**len(gd:=[x for x in self.opts if x[1] == '0']), f"opts wrong on dims[0], {self.dims[0]} vs {gd}"
|
||||
assert self.dims[1] == 2**len(gd:=[x for x in self.opts if x[1] == '1']), f"opts wrong on dims[1], {self.dims[1]} vs {gd}"
|
||||
# NOTE: the K opts are implicitly set by the dim
|
||||
# check swizzle
|
||||
assert len(self.swizzle[0]) == 3 and len(self.swizzle[1]) == 3, "swizzle has wrong part count"
|
||||
assert len(self.swizzle[0][0]) == len(self.swizzle[1][0]) == local_axes, "local swizzle size is wrong"
|
||||
assert len(self.swizzle[0][1]) == len(self.swizzle[1][1]) == upcast_axes, "upcast swizzle size is wrong"
|
||||
assert len(self.swizzle[0][2]) == len(self.swizzle[1][2]) == reduce_axes, "reduce swizzle size is wrong"
|
||||
assert all(len(s) == local_axes+upcast_axes+reduce_axes for s in self._remaps()), "remaps are the wrong size"
|
||||
# check elements_per_thread
|
||||
un, ln = 0, 0
|
||||
zero_stride_0 = []
|
||||
zero_stride_1 = []
|
||||
for o in self.opts:
|
||||
if o[1] == '0': zero_stride_0.append(o[0] + str(un if o[0] == 'u' else ln))
|
||||
if o[1] == '1': zero_stride_1.append(o[0] + str(un if o[0] == 'u' else ln))
|
||||
if o[0] == 'u': un += 1
|
||||
if o[0] == 'l': ln += 1
|
||||
# NOTE: all the zero_stride dims can be placed in any order in the swizzle
|
||||
upcasted_0 = [x for x in (self.swizzle[0][1] + self.swizzle[0][2]) if x not in zero_stride_0 and x[0] != 'l']
|
||||
upcasted_1 = [x for x in (self.swizzle[1][1] + self.swizzle[1][2]) if x not in zero_stride_1 and x[0] != 'l']
|
||||
assert 2**len(upcasted_0) == self.elements_per_thread[0], f"mismatch in elements_per_thread[0], {upcasted_0} vs {self.elements_per_thread[0]}"
|
||||
assert 2**len(upcasted_1) == self.elements_per_thread[1], f"mismatch in elements_per_thread[1], {upcasted_1} vs {self.elements_per_thread[1]}"
|
||||
|
||||
# ***** NVIDIA *****

cuda_tc_opts = ("u0","l0","l0","l1","l1","l1","u1")  # shared by all shapes with M=16 N=8

# https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-multiply-accumulate-instructions
cuda_81616 = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
  swizzle=((('r1', 'r2', 'l2', 'l3', 'l4'), ('u1', 'r3'), ('l0', 'l1', 'u0', 'r0')),
           (('r1', 'r2', 'u0', 'l0', 'l1'), ('r0', 'r3'), ('l2', 'l3', 'l4', 'u1'))))
  for di,do in [(dtypes.half,dtypes.float), (dtypes.bfloat16,dtypes.float), (dtypes.half,dtypes.half)]]
cuda_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
  swizzle=((('r1', 'r2', 'l2', 'l3', 'l4'), ('r0', 'u1'), ('l0', 'l1', 'u0')),
           (('r1', 'r2', 'u0', 'l0', 'l1'), ('u1', 'r0'), ('l2', 'l3', 'l4'))))
  for di,do in [(dtypes.half,dtypes.float), (dtypes.half,dtypes.half)]]
cuda_8168_tf32 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.float, dtype_out=dtypes.float, opts=cuda_tc_opts,
  swizzle=((('r0', 'r1', 'l2', 'l3', 'l4'), ('u1', 'r2'), ('l0', 'l1', 'u0')),
           (('r0', 'r1', 'u0', 'l0', 'l1'), ('u1', 'r2'), ('l2', 'l3', 'l4'))))]
cuda_sm80: list[TensorCore] = cuda_81616 + cuda_8168_f16
if getenv("ALLOW_TF32", 0): cuda_sm80 += cuda_8168_tf32
cuda_sm75: list[TensorCore] = cuda_8168_f16

# ***** AMD *****

# https://gpuopen.com/learn/wmma_on_rdna3/
amd_rdna3 = [TensorCore(dims=(16,16,16), threads=32, elements_per_thread=(16,16,8), dtype_in=di, dtype_out=do,
  opts=("l0","l0","l0","l0","l1","u1","u1","u1"),
  swizzle=((('l4', 'u0', 'u1', 'u2', 'l0'), ('r1', 'r2', 'r3'), ('l1', 'l2', 'l3', 'r0')),
           (('l0', 'l1', 'l2', 'l3', 'l4'), ('r1', 'r2', 'r3'), ('u0', 'u1', 'u2', 'r0'))))
  for di,do in [(dtypes.half,dtypes.float),(dtypes.half,dtypes.half),(dtypes.bfloat16,dtypes.float)]]
amd_rdna4 = [TensorCore(dims=(16,16,16), threads=32, elements_per_thread=(8,8,8), dtype_in=di, dtype_out=do,
  opts=("l0","l0","l0","l0","u1","u1","u1","l1"),
  swizzle=((('u0', 'u1', 'u2', 'l4', 'r2'), ('r0', 'r1', 'r3'), ('l0', 'l1', 'l2', 'l3')),
           (('l0', 'l1', 'l2', 'l3', 'r2'), ('r0', 'r1', 'r3'), ('l4', 'u0', 'u1', 'u2'))))
  for di,do in [(dtypes.half,dtypes.float),(dtypes.half,dtypes.half),(dtypes.bfloat16,dtypes.float),(dtypes.bfloat16,dtypes.bfloat16)]]

# https://gpuopen.com/learn/amd-lab-notes/amd-lab-notes-matrix-cores-readme
amd_cdna = [TensorCore(dims=(16,16,16), threads=64, elements_per_thread=(4,4,4), dtype_in=di, dtype_out=do,
  opts=("l0","l0","l0","l0","u1","u1","l1","l1"),
  swizzle=((('u0', 'u1', 'l4', 'l5', 'r2', 'r3'), ('r0', 'r1'), ('l0', 'l1', 'l2', 'l3')),
           (('l0', 'l1', 'l2', 'l3', 'r2', 'r3'), ('r0', 'r1'), ('l4', 'l5', 'u0', 'u1'))))
  for di,do in [(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)]]

# ***** Apple Metal *****

metal = [TensorCore(dims=(8,8,8), threads=32, elements_per_thread=(2,2,2), dtype_in=di, dtype_out=do,
  opts=("u0","l0","l1","l1","l0","l1"),
  swizzle=((('r1', 'l1', 'l2', 'r2', 'l4'), ('r0',), ('u0', 'l0', 'l3')),
           (('l0', 'r0', 'r1', 'l3', 'r2'), ('u0',), ('l1', 'l2', 'l4'))))
  for di,do in [(dtypes.float,dtypes.float),(dtypes.half,dtypes.float),
                (dtypes.half,dtypes.half),(dtypes.bfloat16,dtypes.float),(dtypes.bfloat16,dtypes.bfloat16)]]

# ***** Apple AMX *****

amx = [TensorCore(dims=(sz,sz,1), threads=1, elements_per_thread=(sz,sz,sz*sz), dtype_in=dt, dtype_out=dt,
  swizzle=(((), ('u0', 'u1', 'u2', 'u3', 'u4', 'u5', 'u6', 'u7'), ()),
           ((), ('u4', 'u5', 'u6', 'u7', 'u0', 'u1', 'u2', 'u3'), ())),
  opts=("u0","u0","u0","u0","u1","u1","u1","u1")) for dt,sz in [(dt, 64 // dt.itemsize) for dt in [dtypes.float]]]
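# note (illustrative): for dtypes.float, sz = 64 // 4 = 16, so this is a 16x16 tile with 16*16 = 256 accumulator elements per "thread"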

# ***** Intel *****

intel = [TensorCore(dims=(8,8,16), threads=8, elements_per_thread=(16,16,8), dtype_in=dtypes.half, dtype_out=dtypes.float,
  opts=("l0","l0","l0","u1","u1","u1"),
  swizzle=((('r1', 'r2', 'r3'), ('u0', 'u1', 'u2'), ('l0', 'l1', 'l2', 'r0')),
           (('l0', 'l1', 'l2'), ('r1', 'r2', 'r3'), ('u0', 'u1', 'u2', 'r0'))))]
67
tinygrad/codegen/quantize.py
Normal file
@@ -0,0 +1,67 @@
from tinygrad.dtype import dtypes, least_upper_dtype
from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat
from tinygrad.uop.symbolic import symbolic

# **** this is the "quantization preprocessor", it makes ONNX quantized models, and probably also others, actually use ints ****
# this is badly tested and low quality. remove it?

FP = (1 << 15)
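# FP is the fixed-point scale (2**15 = 32768): the fixed-point patterns below multiply float constants by FP,
# do the arithmetic in x's integer dtype, then divide by FP, so quantized graphs can stay in integer math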
pm_quant = symbolic+PatternMatcher([
  # cast after add/mul
  (UPat.var("x").cast(dtypes.float32) + UPat.var("y").cast(dtypes.float32),
   lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))+y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
  (UPat.var("x").cast(dtypes.float32) * UPat.var("y").cast(dtypes.float32),
   lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))*y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),

  # masked MUL after masked ADD
  ((UPat.var("x") + UPat.var("v").where(UPat.var('cadd'), UPat(Ops.CONST, arg=0))) * UPat.var("v").where(UPat.var('cmul'), UPat(Ops.CONST, arg=0)),
   lambda x,v,cadd,cmul: x*v.where(cmul, 0)+v.where(cadd*cmul, 0)),

  # MUL after reduce
  (UPat(Ops.REDUCE_AXIS, src=(UPat.var("x") * UPat.cvar("c"),), name="r"), lambda x,c,r: r.replace(src=(x,))*c.arg),
  # CAST after reduce (doesn't work if it's a size change)
  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, src=(UPat.var("x"),)),), name="r"),
   lambda x,r: r.replace(dtype=x.dtype, src=(x,)).cast(r.dtype) if dtypes.is_float(r.dtype) else None),

  # x*c1 + y*c2 -> (x+y)*c1 (if c1 and c2 are close floats)
  (UPat.var("x")*UPat.cvar("c1", dtype=dtypes.floats) + UPat.var("y")*UPat.cvar("c2", dtype=dtypes.floats),
   lambda x,y,c1,c2: (x+y)*c1 if abs(c1.arg-c2.arg) < 1e-9 else None),
  # mul 0 * c1 is 0
  (UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
   UPat(Ops.LOAD, src=(UPat().view(name="v"),)).cast(dtypes.int).cast(dtypes.float).named("ld"), lambda ld,v,c1: ld*c1),
  # mul (with plus) 0 * c1 is 0
  (UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
   (UPat(Ops.LOAD, src=(UPat().view(name="v"),)).cast(dtypes.int) + \
    UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar(), UPat(Ops.CONST, arg=0))).cast(dtypes.float).named("ld"),
   lambda ld,v,c1: ld*c1),

  # const push through add
  ((UPat.var("x")*UPat.cvar("c1") + UPat.var("y")*UPat.cvar("c2")) * UPat.cvar("c3"), lambda x,y,c1,c2,c3: (x*c1*c3) + (y*c2*c3)),

  # fixed point mult, replace (x.float()*c1 + cc).int() with an int expression
  ((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("cc")).cast(dtypes.int),
   lambda x,c1,cc: ((x*(c1*FP).cast(x.dtype) + (cc*FP).cast(x.dtype)) // FP).cast(dtypes.int)),
  # fixed point mult, replace (x.float()*c1 + y.float()*c2 + cc).int() with an int expression
  ((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("y").cast(dtypes.float)*UPat.var("c2")+UPat.var("cc")).cast(dtypes.int),
   lambda x,c1,y,c2,cc: ((x*(c1*FP).cast(x.dtype) + y.cast(x.dtype)*(c2*FP).cast(x.dtype) + (cc*FP).cast(x.dtype)) // FP).cast(dtypes.int)),
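  # worked example (illustrative): with FP = 32768, x = 10 (int), c1 = 0.5, cc = 3:
  #   (10*16384 + 98304) // 32768 = 262144 // 32768 = 8 == int(10*0.5 + 3)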

  # where move
  (UPat.var("valid").where(UPat.var("yes"), UPat(Ops.CONST, arg=0))*UPat.var("mul"), lambda valid, yes, mul:
   (yes*mul*valid.where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))) if yes.op is not Ops.CONST or yes.arg != 1 else None),
  ((UPat.var("x")*UPat.cvar("c"))*(UPat.var().where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)).named("v")), lambda x,c,v: (x*v)*c),
  (UPat.var("x").cast().named('c') * UPat.var('valid').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)), lambda x,c,valid:
   (x*valid.where(UOp.const(x.dtype, 1), UOp.const(x.dtype, 0))).cast(c.dtype)),
  ((UPat.var('x') * UPat.var('v1').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)) *
    UPat.var('v2').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0))).named("mul"), lambda x, mul, v1, v2:
   x * (v1&v2).where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))),

  # where on two adds
  (UPat.var("x") + UPat.var("v").where(UPat.var("a0"), UPat.var("a1")) + UPat.var("v").where(UPat.var("b0"), UPat.var("b1")),
   lambda x,v,a0,a1,b0,b1: x + v.where(a0+b0, a1+b1)),

  # split REDUCE into multiple reduces (who remembers FOIL?)
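  # i.e. sum((v1+c1)*(v2+c2)) = sum(v1*v2) + sum(c2*v1) + sum(c1*v2) + sum(c1*c2), with each term reduced separately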
  (UPat(Ops.REDUCE_AXIS, src=((UPat(Ops.CAST, name="v1")+UPat.var("c1")) * UPat(Ops.CAST, name="v2"),), name="r"),
   lambda v1,v2,c1,r: r.replace(src=(v1*v2,)) + r.replace(src=(c1*v2,))),
  (UPat(Ops.REDUCE_AXIS, src=((UPat(Ops.CAST, name="v1")+UPat.var("c1")) * (UPat(Ops.CAST, name="v2",)+UPat.var("c2")),), name="r"),
   lambda v1,v2,c1,c2,r: r.replace(src=(v1*v2,)) + r.replace(src=(c2*v1,)) + r.replace(src=(c1*v2,)) + r.replace(src=(c1*c2,))),
])
120
tinygrad/codegen/simplify.py
Normal file
@@ -0,0 +1,120 @@
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, graph_rewrite, _substitute, range_start
from tinygrad.uop.symbolic import symbolic_flat, sym
from tinygrad.helpers import partition
from tinygrad.dtype import dtypes

def flatten_range(r:UOp):
  off = range_start[r.op]
  rngs = r.src[off:]
  if not len(rngs): return None
  new_rngs = [x for x in UOp.sink(*rngs).toposort() if x.op is Ops.RANGE]
  return r.replace(src=r.src[:off]+tuple(new_rngs))
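# flatten_range rewrites the range sources of a REDUCE/STORE to the RANGE uops actually reachable from them,
# so composite index expressions built from several ranges are replaced by the underlying ranges themselves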

pm_flatten_range = PatternMatcher([
  # real ranges only
  (UPat((Ops.REDUCE, Ops.STORE), name="r"), flatten_range),
])

def count_divmod(x:UOp): return len([u for u in x.toposort() if u.op in {Ops.IDIV, Ops.MOD}])
def simplify_merge_adjacent(u:UOp) -> UOp|None:
  i = range_start[u.op]
  while i < len(u.src)-1:
    r0, r1 = u.src[i], u.src[i+1]
    # check same type
    if r0.arg[-1] == r1.arg[-1]:
      s0, s1 = r0.src[0], r1.src[0]
      # do the merge
      new_range = r0.replace(src=(s0*s1,))
      nidx = graph_rewrite(u, _substitute+symbolic_flat+pm_flatten_range, ctx={r0:new_range//s1, r1:new_range%s1},
                           name=f"check_merge_{r0.arg[0]}_{r1.arg[0]}")
      # check if it simplifies
      if count_divmod(nidx) <= count_divmod(u):
        u = nidx
        continue
    i += 1
  return u
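# illustration: adjacent ranges of sizes s0 and s1 become a single range of size s0*s1, with the old ranges
# substituted by new//s1 and new%s1; the merge is only kept if it does not increase the IDIV/MOD count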

pm_simplify_ranges = PatternMatcher([
  (UPat((Ops.STORE, Ops.REDUCE), name="u"), simplify_merge_adjacent),
])

# **** reduce simplification ****

def reduce_rangeless(red:UOp):
  # TODO: share code with reduce_unparented
  if red.arg not in {Ops.ADD, Ops.MAX}: return None
  if red.src[0].dtype != red.dtype: return None
  if any(x.op in {Ops.RANGE} for x in red.src[0].toposort()): return None
  ret = red.src[0]
  if red.arg is Ops.ADD:
    for r in red.src[1:]:
      ret = ret * r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
  return ret
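# if the reduce source contains no RANGE, an ADD reduce over ranges of sizes n1, n2, ... is just the source
# multiplied by n1*n2*..., and a MAX reduce of a loop-invariant value is the value itself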

def no_range(u:UOp) -> bool: return not any(x.op is Ops.RANGE for x in u.sparents)

pm_reduce_collapse = PatternMatcher([
  # lift x+y out of reduce on lt
  ((UPat.var("x")+UPat.var("y")).or_casted() < UPat.var("c"), lambda x,y,c: (x < (c.cast(y.dtype)-y)) if no_range(y) and no_range(c) else None),
  # lift x*y out of reduce
  ((UPat.var("x")*UPat.var("y")) < UPat.var("c"),
   lambda x,y,c: (x < ((c+y-1) // y)) if no_range(y) and no_range(c) and y.vmin > 0 else None),
  # lift x+y out of reduce on ne
  ((UPat.var("x")+UPat.var("y")).or_casted() != UPat.var("c"), lambda x,y,c: (x != (c.cast(y.dtype)-y)) if no_range(y) and no_range(c) else None),
  # fold the range
  ((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(0, UPat.cvar("val")).reduce(arg=Ops.ADD, allow_any_len=True),
   lambda r,cut,val: (r.src[0]-cut).maximum(0).minimum(r.src[0]).cast(val.dtype) * val),
  ((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(UPat.cvar("val"), 0).reduce(arg=Ops.ADD, allow_any_len=True),
   lambda r,cut,val: cut.maximum(0).minimum(r.src[0]).cast(val.dtype) * val),
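  # illustration: summing (0 if i < cut else val) over i in range(n) gives clamp(n-cut, 0, n) * val,
  # and the mirrored pattern gives clamp(cut, 0, n) * val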
  # REDUCE on ADD
  ((UPat.var("x")+UPat.var("y")).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
   lambda x,y,r: x.reduce(*r.src[1:], arg=Ops.ADD) + y.reduce(*r.src[1:],arg=Ops.ADD)),
  # MUL casted bool
  ((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast().or_broadcasted(name="b")),
   lambda x,gate,b=None: gate.broadcast(x.dtype.count).where(x, 0) if b is not None else gate.where(x, 0)),
  # WHERE on LOAD (works on max too)
  (UPat.var("gate").where(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load(), 0).reduce(arg=Ops.ADD, allow_any_len=True),
   lambda buf,idx,gate: buf.index(idx, gate).load()),
  (UPat.var("gate").where(0, UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load()).reduce(arg=Ops.ADD, allow_any_len=True),
   lambda buf,idx,gate: buf.index(idx, gate.logical_not()).load()),
  # INDEX on RANGE / gated RANGE
  (UPat.var("buf").index(UPat.var("expr"), UPat.var("idx").eq(UPat(Ops.RANGE, name="r").or_casted())),
   lambda buf,r,idx,expr: buf.index(expr.substitute({r:idx.cast(r.dtype)}), (idx.cast(r.dtype) >= 0) & (idx.cast(r.dtype) < r.src[0]))),
  # AND on WHERE
  ((UPat.any(UPat(Ops.DEFINE_VAR, name="x"), UPat(Ops.DEFINE_VAR).gep(name="x")) & UPat.var("y")) \
    .where(UPat.cvar("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
   lambda x,y,c,r: y.where(c, 0).reduce(*r.src[1:], arg=Ops.ADD)*x.cast(c.dtype)),
  # remove REDUCEs that no longer have a RANGE in the src
  (UPat(Ops.REDUCE, name="red"), reduce_rangeless),
])+sym

def reduce_collapse(red:UOp):
  included, not_included = partition(red.parents, lambda x: any(y in x.sparents for y in red.src[1:]))
  if any(x.op in {Ops.STORE, Ops.REDUCE} for x in included): return None
  replaces: dict[UOp, UOp] = {}
  for u in included:
    for s in u.src:
      if s in not_included and s not in replaces and s.op not in {Ops.CONST, Ops.VCONST, Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR}:
        replaces[s] = UOp(Ops.DEFINE_VAR, dtype=s.dtype, arg=(f'in{len(replaces)}', s.vmin, s.vmax))
  collapse_fxn = red.substitute(replaces)
  sink = graph_rewrite(collapse_fxn, pm_reduce_collapse, name="reduce_collapse")
  if any(x.op is Ops.RANGE for x in sink.toposort()): return None
  return sink.substitute({v:k for k,v in replaces.items()})

def reduce_unparented(red:UOp):
  if red.arg not in {Ops.ADD, Ops.MAX, Ops.MUL}: return None
  reduce_parented, reduce_unparented = partition(red.src[1:], lambda x: x in red.src[0].sparents)
  if len(reduce_unparented) == 0: return None
  ret = red.replace(src=(red.src[0],)+tuple(reduce_parented)) if len(reduce_parented) or red.dtype != red.src[0].dtype else red.src[0]
  if red.arg is Ops.ADD:
    for r in reduce_unparented: ret = ret * r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
  if red.arg is Ops.MUL:
    for r in reduce_unparented: ret = ret ** r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
  return ret
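# ranges listed on the REDUCE but never used in its source contribute a constant factor per iteration:
# for ADD the result is multiplied by the range size, for MUL it is raised to that power, and for MAX the range is simply dropped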

pm_reduce_simplify = PatternMatcher([
  # remove any ranges from a REDUCE that aren't referenced in the reduce source
  (UPat(Ops.REDUCE, name="red"), reduce_unparented),
  # remove REDUCE without loads (generic arange opt / indexing). TODO: support multi range
  (UPat(Ops.REDUCE, src=(UPat(), UPat()), name="red"), reduce_collapse),
])