Release 260111
tinygrad/codegen/late/devectorizer.py (new file, 306 lines added)
@@ -0,0 +1,306 @@
from typing import Any, cast
import functools, operator, itertools
from collections import defaultdict
from dataclasses import dataclass
from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace, Invalid
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite, GroupOp, identity_element
from tinygrad.uop.symbolic import uop_given_valid, parse_valid, sym, symbolic_flat, invalid_gate
from tinygrad.helpers import getenv, flatten, AMX, prod
from tinygrad.renderer import Renderer

# ***** image load valid simplification *****

def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
  if (idx:=uop_given_valid(valid, start_idx)) is None: return buf.index(UOp.invalid())
  if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx, valid)

  # wait for it to be image indexed before running simplification
  if start_idx.dtype.count != 2: return None

  # can drop valid if idx is out of bound when valid is False
  drop_stmt = []
  for stmt in valid.split_uop(Ops.AND):
    try: X, is_upper_bound, c = parse_valid(stmt)
    except ValueError: return None

    # for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i
    if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in X.split_uop(Ops.ADD)):
      testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), X.split_uop(Ops.ADD), idx)
      testidx = testidx.simplify()
      if testidx.gep(0).vmax < 0 or testidx.gep(1).vmax < 0:
        drop_stmt.append(stmt)
        continue

    # if X <= c, check if it's out of bound when X = c+1
    # if X >= c, check if it's out of bound when X = c-1
    test_value = c + 1 if is_upper_bound else c - 1
    for i,b in zip(idx.src, (buf.dtype.shape[1], buf.dtype.shape[0])):
      if i.is_increasing():
        rw = i.substitute({X:X.const_like(test_value)}).simplify()
        if rw.vmin >= b or rw.vmax < 0:
          drop_stmt.append(stmt)
          break

  if not drop_stmt and idx is start_idx: return None
  new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in valid.split_uop(Ops.AND) if s not in drop_stmt]) else None
  return buf.index(idx, new_valid)
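 
# Illustrative sketch of the drop rule above (hypothetical shapes and valids, not from this commit):
# for an ImageDType with shape (H, W, 4) and a 2D idx (col_expr, row_expr) gated by a clause "X <= W-1",
# substituting X = W into col_expr gives rw.vmin >= W, i.e. the index is already out of bounds exactly
# when that clause is False, so the clause can be dropped and the hardware's out-of-bounds image
# handling takes over.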

def delete_redundant_gates(store:UOp, buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None:
  if store_gate not in [gate.src[0] for gate in val.toposort() if gate.op is Ops.IF]: return None
  # remove the gate from the index
  return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val, *store.src[2:])

load_store_indexing = PatternMatcher([
  # image load valid idx simplification
  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
  # lower the invalid into a gate, must come before index dtype lowering
  (UPat(Ops.INDEX, src=(UPat.var("buf"), invalid_gate,),), lambda buf,x,cond,i: buf.index(x, cond)),
  # drop true gate
  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x"), UPat.const(dtypes.bool, True)),), lambda buf,x: buf.index(x)),
  # remove hanging cast
  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.int).cast()),), lambda buf,idx: buf.index(idx)),
  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.int).cast(), UPat.var("valid"))), lambda buf,idx,valid: buf.index(idx, valid)),
  # delete_redundant_gates (after expand)
  (UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")),
                        UPat.var("val")), name="store", allow_any_len=True), delete_redundant_gates),
])

# ***** load/store grouping *****

def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
  if getenv("UNSAFE_DISABLE_MASK", 0): mask = None
  # generate the individual indexes
  midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i), mask.gep(i) if mask is not None else None) for i in range(vec.dtype.count)]),
                       symbolic_flat+load_store_indexing, name=f"index_buf_{buf.arg}")
  # extract all the relevant offsets
  offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict)
  for i in range(vec.dtype.count):
    idx: Any = midx.src[i].src[1]
    if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
    elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
    elif idx.op is Ops.CONST and idx.arg is Invalid: root_src, arg = "INVALID", 0
    elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
    else: root_src, arg = idx, 0
    if len(midx.src[i].src) == 3: root_src = (midx.src[i].src[2], root_src)
    offsets_rootsrc[root_src].setdefault(arg, []).append(i)

  # then rewrite everything we can into groups
  ret = []
  idxs: list[int|None] = [None]*vec.dtype.count
  global_offset = 0
  for offsets in offsets_rootsrc.values():
    grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
    for grp in grouped_offsets:
      # get the index offset for this element. using [0] is okay, because they are the same
      lidx = midx.src[offsets[grp[0]][0]]
      if len(grp) > 1: lidx = lidx.cast(buf.ptrdtype.base.vec(len(grp)).ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace))
      # set the idxs of the output
      for i,g in enumerate(grp):
        for oo in offsets[g]: idxs[oo] = global_offset+i
      # add this lidx to the CAT
      ret.append(lidx)
      global_offset += len(grp)
  assert None not in idxs, f"some idxs are missing {idxs}"
  # this base thing is for image, we want the CAT to be a normal pointer
  post_cat = UOp(Ops.PTRCAT, buf.ptrdtype.base.ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace).vec(vec.dtype.count), tuple(ret))
  return post_cat.gep(tuple(cast(list[int], idxs)))
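 
# Example of the grouping above (illustrative values): if a 4-wide index expands to offsets
# {base+0: [0], base+1: [1], base+2: [2], base+3: [3]} under a single root_src, the offsets form one
# consecutive run, so a single INDEX is cast to a 4-wide vector pointer and the trailing gep maps all
# four output lanes back onto it; non-consecutive offsets would instead become separate PTRCAT sources.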

def cat_after_store(cat:UOp, data:UOp, sto:UOp):
  # TODO: this is written in many places
  offset = 0
  ret: list[UOp] = []
  for s in cat.src:
    ret.append(s.store(data.gep(tuple(range(offset, offset+s.dtype.count))), *sto.src[2:]))
    offset += s.dtype.count
  return UOp(Ops.NOOP, src=tuple(ret))

def gep_on_store(gep:UOp, st:UOp, sto:UOp):
  # NOTE: we need to invert the gep here, but it may be an expanding gep
  # fake argsort. TODO: handle duplicates
  a = {}
  for i,x in enumerate(gep.arg): a[x] = i
  new_arg = tuple(x[1] for x in sorted(a.items()))
  return gep.src[0].store(st.gep(new_arg), *sto.src[2:])
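 
# Example of the inversion above (illustrative): for gep.arg == (2, 0, 1) the dict a becomes
# {2: 0, 0: 1, 1: 2}, so new_arg == (1, 2, 0), the inverse permutation: the value the GEP routed to
# output lane i is sent back to input lane gep.arg[i] before the store.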

load_store_folding = PatternMatcher([
  (UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat(GroupOp.Defines, name="buf")), UPat.var("vec"))), expand_index),
  (UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat(GroupOp.Defines, name="buf")), UPat.var("vec"),
                        UPat.var("mask"))), expand_index),
  # GEP after LOAD
  (UPat(Ops.LOAD, src=(UPat(Ops.GEP, name="gep"),), name="ld", allow_any_len=True),
   lambda gep, ld: ld.replace(dtype=ld.dtype.scalar().vec(gep.dtype.count), src=(gep.src[0],)+ld.src[1:]).gep(gep.arg)),
  # GEP on data of STORE
  (UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st")), allow_any_len=True, name="sto"), gep_on_store),
  # put PTRCAT after LOAD
  (UPat(Ops.LOAD, src=(UPat(Ops.PTRCAT, name="cat"),), name="ld", allow_any_len=True),
   lambda cat,ld: UOp(Ops.CAT, ld.dtype, tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))),
  # put PTRCAT after STORE
  (UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data")), allow_any_len=True, name="sto"), cat_after_store),
])

# *** correct load/store ***

def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
  # this splits loads and stores into multiple chunks

  # if there's only one element to load/store, no splitting needed
  if (sz:=ls.src[0].dtype.count) == 1: return None
  buf = idx.src[0]

  # determine fold lengths
  lengths = []
  must_divide = True
  if ctx is not None and ctx.device == "DSP":
    lengths = [128,64,32,16,8,4]
    must_divide = False
  elif buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType):
    pass
  elif buf.ptrdtype.addrspace == AddrSpace.REG:
    pass
  elif isinstance(buf.dtype, ImageDType):
    lengths = [4]
  elif ctx is not None and ctx.supports_float4:
    # TODO: a better way to get this than ctx
    lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2])
  lengths.append(1) # worst case, it's not folded

  # filter fold lengths that don't divide
  if must_divide: lengths = [x for x in lengths if idx.src[1].divides(x) is not None]

  # split based on the fold lengths
  global_offset = 0
  ret = []
  while global_offset < sz:
    # with 1 at the end of the lengths list, this will always hit
    for fold_length in lengths:
      if global_offset+fold_length > sz: continue
      lidx = buf.index(idx.src[1] + global_offset, idx.src[2] if len(idx.src) > 2 else None)
      if fold_length > 1: lidx = lidx.cast(buf.ptrdtype.base.vec(fold_length).ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace))
      if ls.op is Ops.STORE: ret.append(ls.replace(src=(lidx,ls.src[1].gep(tuple(range(global_offset, global_offset+fold_length))))+ls.src[2:]))
      else: ret.append(ls.replace(src=(lidx,)+ls.src[1:], dtype=ls.dtype.scalar().vec(fold_length)))
      global_offset += fold_length
      break

  # if it wasn't split, we return None. otherwise we CAT them
  if len(ret) <= 1: return None
  return UOp(Ops.CAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp(Ops.NOOP, src=tuple(ret))
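 
# Worked example (illustrative, non-DSP float buffer with supports_float4, assuming the index divides
# evenly so no length is filtered): lengths becomes [4, 2, 1], so an 8-wide access splits into two
# 4-wide chunks and a 6-wide access into a 4-wide plus a 2-wide chunk, each chunk getting its own
# INDEX cast to a vector pointer; the pieces are rejoined with CAT (loads) or NOOP (stores).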

def image_fixup(ls:UOp):
  # normal image load or store, with the CAST from expand_index
  if ls.src[0].op is Ops.CAST and isinstance(image_dtype:=ls.src[0].src[0].dtype, ImageDType):
    assert ls.src[0].dtype.count == 4, "image must be casted to 4"
    idx = ls.src[0].src[0]
    oidx = UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((idx.src[1] // 4) % image_dtype.shape[1], (idx.src[1] // (4*image_dtype.shape[1]))))
    idx = idx.replace(src=(idx.src[0], oidx)+idx.src[2:])
    return ls.replace(src=(idx,)+ls.src[1:])

  # this is an unprocessed image without a cast, aka unfoldable image load. this doesn't work for stores
  if isinstance(image_dtype:=ls.src[0].dtype, ImageDType) and ls.src[0].src[1].dtype != dtypes.int.vec(2):
    assert ls.op is Ops.LOAD, "if an image store isn't upcasted to 4, we can't store it"
    idx = ls.src[0]
    id4 = idx.src[1] % 4
    oidx = UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((idx.src[1] // 4) % image_dtype.shape[1], (idx.src[1] // (4*image_dtype.shape[1]))))
    idx = idx.replace(src=(idx.src[0], oidx)+idx.src[2:])
    vec_load = ls.replace(dtype=ls.dtype.vec(4), src=(idx,)+ls.src[1:])
    return functools.reduce(lambda ret, i: id4.ne(i).where(ret, vec_load.gep(i)), range(4), ls.const_like(float('nan')))

  return None
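 
# Illustrative mapping used above (assuming image_dtype.shape is laid out as (height, width, 4)):
# a flat element index i addresses pixel column (i // 4) % width and row i // (4 * width), with i % 4
# selecting the channel; the unfoldable-load path loads the whole 4-wide pixel and uses id4 to pick
# that channel out of it.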

correct_load_store = PatternMatcher([
  # split LOAD/STORE
  (UPat((Ops.LOAD, Ops.STORE), src=(UPat(Ops.INDEX, name="idx").cast(),), name="ls", allow_any_len=True), split_load_store),
  # image indexing, including unfoldable images
  (UPat((Ops.LOAD, Ops.STORE), name="ls"), image_fixup),
])

# *** uop expander ***

# TODO: there's a lot shared with gep_through_wmma here
def no_vectorized_wmma(wmma:UOp):
  out_sz = prod(x[1] for x in wmma.arg[6][-1])
  if wmma.dtype.count == out_sz: return None
  tsrcs = []
  for s,sz in zip(wmma.src, wmma.arg[6]):
    ssz = prod(x[1] for x in sz)
    tsrcs.append([s.gep(tuple(range(grp, grp+ssz))) for grp in range(0, s.dtype.count, ssz)])
  wmmas = [UOp(Ops.WMMA, wmma.dtype.scalar().vec(out_sz), tsrc, wmma.arg) for tsrc in zip(*tsrcs)]
  wmma_ex = flatten([[e.gep(i) for i in range(out_sz)] for e in wmmas])
  return UOp(Ops.VECTORIZE, wmma.dtype, tuple(wmma_ex))

def no_vectorized_alu(alu:UOp):
  if alu.dtype.vcount == 1: return None
  alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount))
  return UOp(Ops.VECTORIZE, alu.dtype, alus)

def no_vectorized_buf(buf:UOp):
  return buf.replace(dtype=buf.ptrdtype.base.scalar().ptr(buf.ptrdtype.size*buf.ptrdtype.count, buf.ptrdtype.addrspace)).cast(buf.dtype)

def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp):
  cnt = cast.dtype.count
  assert idx.dtype.count == 1, f"idx dtype must be 1 {idx.dtype}"
  return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.int.vec(cnt), tuple(range(cnt))))
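 
# Example of the index rewrite above (illustrative): for a cast to a 4-wide register pointer and a
# scalar idx i, the access becomes the explicit lane indices (4*i+0, 4*i+1, 4*i+2, 4*i+3) into the
# scalar-typed register file, i.e. the vector access is devectorized into its individual lanes.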

devectorize = PatternMatcher([
  # no ALU on vectorized dtypes
  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name="alu"), no_vectorized_alu),
  (UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
  (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf),
  (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
])

pm_render = PatternMatcher([
  # for rendering, we use explicit VECTORIZE
  (UPat(Ops.CONST, name='c'),
   lambda c: UOp(Ops.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
  (UPat(Ops.VCONST, name='c'), lambda c: UOp(Ops.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
  (UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
  (UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None),
  (UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
  # give any loads that are masked an alt value
  (UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat())).or_casted(),), allow_any_len=True, name="x"),
   lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:])
   if len(x.src) == 1 or x.src[1].op in (Ops.CUSTOM, Ops.STORE, Ops.BARRIER) else None),
  # gate any stores that aren't gated with ifs
  (UPat(Ops.STORE, src=(UPat(src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="idx").or_casted(), UPat()), name="store", allow_any_len=True),
   lambda store,idx: UOp(Ops.STORE, dtype=store.dtype, src=store.src[:2]+(UOp(Ops.IF, src=(idx.src[2],)),)+store.src[2:]) if \
   len(store.src) <= 2 or store.src[2].op != Ops.IF else None),
])

# *** Ops.REDUCE -> Ops.DEFINE_ACC ***

@dataclass
class ReduceContext:
  acc_num: int = 0

def horizontal_reduce(inp:UOp, out_dtype:DType) -> list[UOp]:
  # if this has a horizontal reduction component, do that first
  if inp.dtype != out_dtype:
    # NOTE: [0 1 2 3 4 5 6 7] -> [0+4, 1+5, 2+6, 3+7]
    horizontal_amount = inp.dtype.count//out_dtype.count
    return [inp.gep(tuple(range(i, inp.dtype.count, horizontal_amount))) for i in range(0, horizontal_amount)]
  return [inp]

def reduce_to_acc(ctx:ReduceContext, red:UOp):
  inp, reduce_range = red.src[0], red.src[1:]
  lst = horizontal_reduce(inp, red.dtype)
  assert all(x.dtype == red.dtype for x in lst), f"horizontal reduction mismatch {lst[0].dtype} != {red.dtype}"
  # if we have a range
  if len(reduce_range) != 0:
    topo = inp.toposort()
    stored_ranges = flatten([x.src[2:] for x in topo if x.op is Ops.STORE])
    input_ranges = tuple([x for x in topo if x.op is Ops.RANGE and x not in reduce_range and x not in stored_ranges])
    identity = red.const(red.dtype, identity_element(red.arg, red.dtype.scalar()))
    acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=(ctx.acc_num,)).index(UOp.const(dtypes.int, 0))
    do_store = acc.store(identity, UOp(Ops.NOOP, src=input_ranges)) if len(input_ranges) else acc.store(identity)
    lst = [acc.load(do_store, *reduce_range)] + lst # put acc as the first element
    ctx.acc_num += 1
  ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
  return acc.load(acc.store(ret, *reduce_range)) if len(reduce_range) != 0 else ret
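 
# Sketch of the lowering above (illustrative): a REDUCE over range r with arg Ops.ADD becomes a
# one-element DEFINE_REG accumulator that is stored with the identity (0 for ADD) before the loop,
# loaded, combined with the horizontal partials inside r, stored back, and finally loaded once more
# after the range closes.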

pm_reduce = PatternMatcher([
  # REDUCE -> DEFINE_ACC+ASSIGN
  (UPat(Ops.REDUCE, name="red"), reduce_to_acc),
  # tensor core built in accumulate
  (UPat(Ops.WMMA, name="wmma") + UPat.var("add"),
   lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
])+sym
tinygrad/codegen/late/expander.py (new file, 162 lines added)
@@ -0,0 +1,162 @@
# this converts a lowerer program into a vectorized program
import functools, itertools, operator
from tinygrad.dtype import dtypes, PtrDType, AddrSpace
from tinygrad.helpers import AMX, dedup, flatten, all_same, prod, partition
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, AxisType, range_start
from tinygrad.schedule.rangeify import BufferizeOpts

def _expand_arg_to_idx(args:tuple[tuple[int, int], ...], rpk:dict[int, int]) -> int:
  idx, mul = 0, 1
  for axis,m in args[::-1]:
    idx += rpk[axis] * mul
    mul *= m
  return idx

def _choices_from_args(args:tuple[tuple[int, int], ...]) -> list[dict[int, int]]:
  return [dict(x) for x in itertools.product(*[zip(itertools.repeat(axis), range(m)) for axis,m in args])]
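 
# Worked example for the two helpers above (illustrative axis ids): with args == ((0, 2), (1, 3)),
# _choices_from_args yields the 6 digit assignments {0:0,1:0}, {0:0,1:1}, ..., {0:1,1:2}, and
# _expand_arg_to_idx(((0, 2), (1, 3)), {0: 1, 1: 2}) == 1*3 + 2 == 5, i.e. the args act as
# mixed-radix digits with the last axis varying fastest.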

@functools.cache
def _swizzle_args(cargs:tuple[tuple[int, int], ...], eargs:tuple[tuple[int, int], ...], exclude_args:tuple[int, ...]) -> list[int]:
  return [_expand_arg_to_idx(eargs, {**rpk, **{x:0 for x in exclude_args}} if exclude_args else rpk) for rpk in _choices_from_args(cargs)]

def do_expand(root:UOp):
  expands = [x for x in root.src if x.op is Ops.UNROLL]
  if len(expands) == 0: return None
  # NOTE: we 0 out the reduce axis for WMMA. in theory they should all be the same, but is this always correct?
  exclude_args = tuple(dedup(root.arg[-1] + tuple(y[0] for y in flatten(root.arg[-2])))) if root.op is Ops.WMMA else ()
  if all_same(expands_args:=[x.arg for x in expands]) and len(exclude_args) == 0:
    # if there's only one expand arg, it's okay to use it (optimization)
    expand_args = expands[0].arg
  else:
    # otherwise, we sort them and GEP
    expand_args = tuple(x for x in sorted(dedup(flatten(expands_args))) if x[0] not in exclude_args)
  expand_sz = prod([x[1] for x in expand_args])
  new_srcs = []
  for i,src in enumerate(root.src):
    if src.op is Ops.UNROLL:
      if root.op is Ops.IF and i == 0:
        # IF means OR on first arg to IF
        new_srcs.append(functools.reduce(operator.__or__, [src.src[0].gep(i) for i in range(expand_sz)]))
      elif expand_args == src.arg:
        # just remove the expand
        new_srcs.append(src.src[0])
      else:
        lst = _swizzle_args(expand_args, src.arg, exclude_args)
        # if the base dtype count is > 1, put those at the end
        if src.dtype.count > 1: lst = flatten([[i*src.dtype.count+j for j in range(src.dtype.count)] for i in lst])
        new_srcs.append(src.src[0].gep(tuple(lst)))
    else:
      # non-UNROLL input
      if root.op is Ops.IF or src.op is Ops.IF:
        # for the first arg of IF, just pass them through ignoring UNROLLS
        new_srcs.append(src)
      elif root.op in range_start and i >= range_start[root.op]:
        # for any range args of STORE/REDUCE, pass them through
        new_srcs.append(src)
      elif root.op is Ops.INDEX and i >= 1 and not isinstance(root.dtype, PtrDType):
        new_srcs.append(src)
      elif src.dtype.count > 1:
        # put any inputs with dtype count > 1 grouped together
        new_srcs.append(UOp(Ops.CAT, src.dtype.scalar().vec(expand_sz*src.dtype.count), (src,)*expand_sz))
      else:
        # repeat the arg
        new_srcs.append(src.broadcast(expand_sz))

  new_arg = root.arg
  if root.op is Ops.GEP:
    assert root.dtype.count == 1
    # is this right?
    new_arg = tuple(range(root.arg[0], new_srcs[0].dtype.count, new_srcs[0].dtype.count // expand_sz))
  nsrc = UOp(root.op, root.dtype.scalar().vec(root.dtype.count*expand_sz), tuple(new_srcs), new_arg)
  return UOp(Ops.UNROLL, root.dtype, (nsrc,), expand_args)
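 
# Example of the expansion above (illustrative axis id): a scalar ADD whose sources carry UNROLL with
# expand_args == ((2, 4),) is rewritten into a single 4-wide ADD over the unrolled lanes, and the
# result is wrapped back in UNROLL so downstream consumers still see the unrolled axis.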

def do_contract(con:UOp):
  ex = con.src[0]
  # CONTRACT without UNROLL repeats the element VECTORIZED
  if ex.op is not Ops.UNROLL: return UOp(Ops.VECTORIZE, con.dtype, con.src*con.dtype.count)
  # CONTRACT may remove several axes from UNROLL
  assert con.dtype == dtypes.void or con.dtype.count == prod([x[1] for x in con.arg]), "dtype is wrong"
  idxs = []
  for rpk in _choices_from_args(new_ex_args:=tuple(x for x in ex.arg if x not in con.arg)):
    idxs += [_expand_arg_to_idx(ex.arg, {**rpk, **lrpk}) for lrpk in _choices_from_args(con.arg)]
  return UOp(Ops.UNROLL, con.dtype, (ex.src[0].gep(tuple(idxs)),), new_ex_args)
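 
# Example of the contraction above (illustrative): contracting axis (1, 2) out of an UNROLL over
# ((0, 2), (1, 2)) geps the 4 underlying lanes so that the two contracted values for each remaining
# axis value sit adjacently in the CONTRACT dtype, and returns an UNROLL over the leftover axis
# (0, 2) only.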

expander = PatternMatcher([
  # double expand
  (UPat(Ops.UNROLL, name="outer", src=(UPat(Ops.UNROLL, name="inner"),)),
   lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
  # do expansion
  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.BUFFERIZE,
         Ops.VECTORIZE, Ops.IF, Ops.REDUCE), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
  (UPat(Ops.CONTRACT, name="con"), do_contract),
  # BARRIERs aren't actually expanded
  (UPat(Ops.BARRIER, src=(UPat(Ops.UNROLL, name="ex"),)),
   lambda ex: UOp(Ops.UNROLL, src=(UOp(Ops.BARRIER, src=ex.src),)*len(ex.src), arg=ex.arg)),
  # empty UNROLL is NOOP
  (UPat(Ops.UNROLL, src=(UPat.var('x'),), arg=()), lambda x: x),
  # UNROLL GEP (needed for WMMA, generalize this) -> vectorized ALU
  (UPat(Ops.UNROLL, name="ex", src=tuple(UPat.var('x').gep(i)+UPat.var('y').gep(i) for i in range(256 if AMX else 8))),
   lambda ex,x,y: UOp(Ops.UNROLL, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)),
])

def create_gate(root:UOp) -> UOp|None:
  @functools.cache
  def _gate_srcs(u:UOp, gate:UOp) -> UOp:
    if u.op is Ops.BARRIER: return u
    if u.op is Ops.LOAD and u.src[-1].op is Ops.BARRIER:
      return UOp(u.op, u.dtype, u.src[:-1]+(UOp(Ops.IF, src=(gate, u.src[-1])),), arg=u.arg)
    return u if (replace_source:=tuple(_gate_srcs(x, gate) for x in u.src)) == u.src else UOp(u.op, u.dtype, replace_source, u.arg)
  idx = root.src[0]
  if idx.op is Ops.CAST: idx = idx.src[0]
  return None if idx.op is not Ops.INDEX or len(idx.src) == 2 or (ret:=_gate_srcs(root, idx.src[2])) is root else ret

migrate_indexing = PatternMatcher([
  # create gate MUST BE BEFORE expander
  (UPat(Ops.STORE, name="root"), create_gate),
])

# ****

def fix_reduce_unroll(x:UOp):
  reduce_range, reduce_expand = partition(x.src[1:], lambda y: y.op is Ops.RANGE)
  if len(reduce_expand) == 0: return None
  reduce_expand = [x for x in reduce_expand if x.op is not Ops.CONST]
  assert all(x.op is Ops.UNROLL for x in reduce_expand), f"not all UNROLLS in {reduce_expand}"
  ret = x.src[0]
  if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
    ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis), tag=1)
  # REDUCE supports both "horizontal" reduction and range reduction. the horizontal elements are taken in the nearest group
  return x.replace(src=(ret,)+tuple(reduce_range))

def fix_store_unroll(x:UOp):
  store_expand, store_range = partition(x.src[2:], lambda y: y.op is Ops.UNROLL)
  if len(store_expand) == 0: return None
  return UOp(Ops.CONTRACT, dtypes.void, (x.replace(src=x.src[:2]+tuple(store_range)),), tuple(flatten(x.arg for x in store_expand)), tag=1)

def fix_group_for_reduce(x:UOp):
  reduce_gfr, reduce_r = partition(x.src[1:], lambda u: u.op is Ops.RANGE and u.arg[1] == AxisType.GROUP_REDUCE)
  if len(reduce_gfr) == 0: return None

  # NOTE: if there are other locals here, we need them in the buffer too
  upstream_locals = [u for u in x.toposort() if u.op is Ops.RANGE and u.arg[1] == AxisType.LOCAL]

  # do only the non grouped reduces early
  ret = x.replace(src=(x.src[0],)+tuple(reduce_r))
  reduce_loop = [x.replace(arg=(x.arg[0]+100, AxisType.REDUCE)) for x in reduce_gfr]
  buf = ret.bufferize(*upstream_locals, *reduce_gfr, arg=BufferizeOpts(reduce_gfr[0].arg[0], AddrSpace.LOCAL)).index(*upstream_locals, *reduce_loop)

  # gate with an if on the store + do the final reduce
  buf = UOp(Ops.IF, dtype=buf.dtype, src=(functools.reduce(operator.and_, [x.eq(0) for x in reduce_gfr]), buf))
  return buf.reduce(*reduce_loop, arg=x.arg)

pm_pre_expander = PatternMatcher([
  # rewrite UPCAST/UNROLL range to something to be expanded
  (UPat(Ops.RANGE, name="r"),
   lambda r: UOp(Ops.UNROLL, r.dtype, (UOp.const(r.dtype.vec(s:=r.vmax+1), tuple(range(s))),), ((r.arg[0],s),)) \
   if r.arg[1] in {AxisType.UNROLL, AxisType.UPCAST} else None),
  # fix REDUCEs with UNROLLs
  (UPat(Ops.REDUCE, name="x"), fix_reduce_unroll),
  (UPat(Ops.STORE, name="x"), fix_store_unroll),
  # fix group for reduce
  (UPat(Ops.REDUCE, name="x"), fix_group_for_reduce),
])
tinygrad/codegen/late/linearize.py (new file, 243 lines added)
@@ -0,0 +1,243 @@
from __future__ import annotations
import heapq
from collections import defaultdict
from dataclasses import dataclass, replace
from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat, GroupOp, BottomUpGate
from tinygrad.helpers import dedup, all_same, flatten, BLOCK_REORDER

# NOTE: any toposort should be valid here, unlike last time this isn't required, it's just for speed
def block_reorder(lst:list[UOp]) -> list[UOp]:
  in_this_block = set(lst)
  local_children: defaultdict[UOp, list[UOp]] = defaultdict(list)
  in_degree:dict[UOp, int] = {}
  priorities:dict[UOp, int] = {}

  # get local children and assign priorities
  # NOTE: this requires the lst be locally toposorted
  for u in reversed(lst):
    in_degree[u] = 0
    for s in u.src:
      if s in in_this_block:
        local_children[s].append(u)
        in_degree[u] += 1
    # put loads in the beginning of the block and prevent priority inversion. hack for BARRIER grouping too
    priority = [0] + [priorities[x] for x in local_children[u]]
    if u.op is Ops.LOAD: priority.append(-1000)
    if u.op is Ops.BARRIER: priority.append(-1500)
    priorities[u] = min(priority)

  # number the uops in "ideal" order
  nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: (priorities[x],)+x.tuplize))}

  # then force them to be toposorted in as close to the ideal order as possible
  heapq.heapify(heap:=[(nkey[u],u) for u in lst if in_degree[u] == 0])
  newlst = []
  while heap:
    newlst.append(u:=heapq.heappop(heap)[1])
    for v in local_children[u]:
      in_degree[v] -= 1
      if in_degree[v] == 0: heapq.heappush(heap, (nkey[v],v))

  assert len(newlst) == len(lst), f"len mismatch {len(newlst)} != {len(lst)}"
  return newlst
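 
# Illustrative effect of the priorities above: LOADs (priority -1000) and BARRIERs (-1500) are pulled
# toward the front of the block, and because a uop's priority is the min over its local children,
# anything a LOAD depends on is pulled forward with it; the heap then emits a valid toposort that
# follows this ideal numbering as closely as the dependencies allow.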

# ***** basic block *****

def disp(y:UOp) -> str:
  if y.op is Ops.IF: return f'IF{id(y)}'
  if y.op is Ops.RANGE: return str(y.arg)
  return "<NONE>"

@dataclass(frozen=True, eq=False)
class BasicBlock:
  lst: tuple[UOp, ...]
  ctx: tuple[UOp, ...] = ()
  end: UOp|None = None
  cnt: int = 0
  child_ctx: tuple[UOp, ...]|None = None
  def __lt__(self, _:BasicBlock): raise RuntimeError("no comparing basic blocks")
  def __repr__(self):
    return f"{(str(disp(self.end))+' ') if self.end is not None else ''}"+f'f{self.cnt} '+\
      f"{[disp(y) for y in self.ctx]} {[disp(y) for y in self.child_ctx] if self.child_ctx is not None else '-'} "+\
      f"{len(self.lst)}" + "\n" + '\n'.join([str(x.op) for x in self.lst])
  def last_ctx(self): return self.child_ctx if self.child_ctx is not None else self.ctx

def _sort_ctx(inp): return tuple(sorted(dedup(inp), key=lambda x: x.tuplize))

# ***** block context *****

@dataclass
class BlockContext:
  child_count: dict[UOp, int]
  block_ctxs: dict[UOp, tuple[UOp, ...]]
  child_ctxs: dict[UOp, tuple[UOp, ...]]
  def last_ctx(self, u): return self.child_ctxs.get(u, self.block_ctxs[u])
  @staticmethod
  def from_sink(sink:UOp) -> BlockContext:
    # get children and all block contexts
    ctx = BlockContext({}, {}, {})
    for u in sink.toposort(gate=lambda u:u.op is not Ops.SPECIAL):
      this_block_ctx: list[UOp] = []
      ctx.child_count[u] = 0

      # get children and accumulate the last_ctx
      for s in u.src:
        if s.op is Ops.SPECIAL: continue
        # NOTE: if a parent appears multiple times in the src, it counts multiple times as a child
        ctx.child_count[s] += 1
        this_block_ctx += ctx.last_ctx(s)

      # save the block ctx. SINK never has anything
      ctx.block_ctxs[u] = _sort_ctx(this_block_ctx) if u.op is not Ops.SINK else ()

      # RANGE/IF add to the next ctx
      # STORE/ASSIGN subtract from the next ctx
      if u.op in {Ops.RANGE, Ops.IF}: ctx.child_ctxs[u] = _sort_ctx(ctx.block_ctxs[u] + (u,))
      elif u.op is Ops.STORE: ctx.child_ctxs[u] = tuple([y for y in ctx.block_ctxs[u] if y not in u.src])
    return ctx
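 
# Illustrative context bookkeeping from the walk above: a uop nested inside a RANGE and an IF ends up
# with block_ctxs equal to that (RANGE, IF) pair (sorted), the RANGE/IF themselves advertise the
# extended context to their children via child_ctxs, and a STORE drops the ranges that appear in its
# own sources so uops placed after it start a fresh block context.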

# ***** make blocks *****

DONT_PLACE_IN_BLOCK = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST}

def add_blockends(base_block:UOp, new_ctx:tuple[UOp, ...], current_ctx:tuple[UOp, ...], cnt:int=1) -> UOp:
  ends_to_add = [z for z in new_ctx if z not in current_ctx]
  while len(ends_to_add):
    r:UOp = ends_to_add.pop(-1)
    new_ctx = tuple([z for z in new_ctx if z is not r])
    end_uop = UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,))
    base_block = UOp(Ops.BLOCKEND, src=(base_block,)*cnt, arg=BasicBlock((end_uop,), tuple(new_ctx), end=r, cnt=cnt))
  return base_block

def make_block_bottom_up(ctx:BlockContext, x:UOp):
  if x.op is Ops.BLOCKSTART:
    current_ctx, child_ctx = x.arg
    lst = list(x.src)
    child_count = 1
  else:
    current_ctx, child_count, child_ctx = ctx.block_ctxs[x], ctx.child_count[x], ctx.child_ctxs.get(x, None)
    lst = [x]

  # count of times we've seen this block, or a seed for a new block if we can't merge it
  unmergable: defaultdict[UOp, int] = defaultdict(int)
  blockseeds = defaultdict(list)

  # add the srcs of this to the frontier
  # NOTE: things may be in here multiple times, that's okay
  frontier_nodes = list(flatten(y.src[::-1] for y in lst))
  while len(frontier_nodes):
    u = frontier_nodes.pop(0)
    if u.op not in DONT_PLACE_IN_BLOCK and ctx.child_count[u] == unmergable[u]+1:
      # count is correct
      if (newctx:=ctx.block_ctxs[u]) == current_ctx:
        # block has same context, merge it, and put the srcs on the frontier
        lst.append(u)
        frontier_nodes.extend(u.src[::-1])
      else:
        # block has different context, add it to blockseeds
        blockseeds[(newctx, ctx.child_ctxs.get(u, None))].append(u)
      del unmergable[u]
    else:
      # count is incorrect (or it's DONT_PLACE_IN_BLOCK), add it to unmergable
      unmergable[u] += 1

  # add unmergables to sources
  srcs = []
  for u,cnt in unmergable.items(): srcs += [add_blockends(u, ctx.block_ctxs.get(u,()), current_ctx, cnt=cnt)]*cnt

  # add blockseeds, with blockends as needed
  for (new_ctx, new_child_ctx), v in blockseeds.items():
    base_block = UOp(Ops.BLOCKSTART, src=tuple(v), arg=(new_ctx, new_child_ctx))
    srcs.append(add_blockends(base_block, new_ctx, current_ctx))

  lst = lst[::-1]
  if BLOCK_REORDER: lst = block_reorder(lst)
  bb = BasicBlock(tuple(lst), ctx=current_ctx, cnt=child_count, child_ctx=child_ctx)
  return UOp(Ops.BLOCK, src=tuple(srcs), arg=bb)

# we prevent the source of the SPECIAL from being linearized since it's not part of the kernel
def raise_bottom_up_gate(): raise BottomUpGate()

block_create = PatternMatcher([
  (UPat(GroupOp.All-DONT_PLACE_IN_BLOCK.union({Ops.BLOCK, Ops.BLOCKEND}), name="x"), make_block_bottom_up),
  (UPat(Ops.SPECIAL), raise_bottom_up_gate)
])

# ***** blockend merging ****

def merge_blockends(sink:UOp) -> UOp|None:
  # only run on the final BLOCK with the SINK in it
  if sink.arg.lst[-1].op is not Ops.SINK: return None
  # combine matching BLOCKENDS, the keys of this dictionary are the RANGE UOps, values are the BLOCKENDs
  blockends_to_arg: dict[UOp, list[UOp]] = {}
  for be in sink.toposort():
    if be.op is Ops.BLOCKEND: blockends_to_arg.setdefault(be.arg.end, []).append(be)
  new_forks = {}
  for k,v in blockends_to_arg.items():
    # NOTE: if any BLOCKEND is the parent of any other with the same arg, this algo fails
    if len(v) > 1:
      bb = BasicBlock(v[0].arg.lst, _sort_ctx(flatten([y.arg.ctx for y in v])), k, cnt=sum(y.arg.cnt for y in v))
      out = UOp(Ops.BLOCKEND, src=tuple(flatten([x.src for x in v])), arg=bb)
      # NOTE: bb.ctx != u.arg.ctx can cause problems here
      for u in v: new_forks[u] = out
  if len(new_forks) == 0: return None
  return sink.substitute(new_forks)

pm_blockend_merge = PatternMatcher([(UPat(Ops.BLOCK, name="sink"), merge_blockends)])

# ***** block merging ****

def merge_block(x:UOp):
  unmergable_blocks, mergable_blocks = [], []
  mergable_dict: defaultdict[UOp, int] = defaultdict(int)
  for y in x.src:
    if y.op is Ops.BLOCK and x.op is Ops.BLOCK and x.arg.ctx == y.arg.ctx: mergable_dict[y] += 1
    elif y.op is Ops.BLOCK and x.op is Ops.BLOCKEND and x.arg.end in y.arg.ctx: mergable_dict[y] += 1
    else: unmergable_blocks.append(y)
  for k,v in mergable_dict.items():
    if v == k.arg.cnt: mergable_blocks.append(k)
    else: unmergable_blocks.extend([k]*v)
  if len(mergable_blocks) == 0: return None
  del mergable_dict

  # create the block
  arg = replace(x.arg, lst=tuple(flatten([y.arg.lst for y in mergable_blocks]))+x.arg.lst)
  return UOp(x.op, src=tuple(flatten([y.src for y in mergable_blocks])+unmergable_blocks), arg=arg)

def remove_blockend(x:UOp):
  # if there are any remaining blocks that need to go in this BLOCKEND, we don't remove it
  if any(x.arg.end in y.arg.ctx for y in x.src if y.op in {Ops.BLOCK, Ops.BLOCKEND}): return None

  if (parent_blocks := [y for y in x.src if y.op is Ops.BLOCK and y.arg.child_ctx is not None and x.arg.end in y.arg.child_ctx]):
    assert all_same(parent_blocks), f"should never have two parent blocks (has {len(parent_blocks)})"
    parent_block = parent_blocks[0]
    assert len(parent_blocks) == parent_block.arg.cnt
    # NOTE: DEFINE_ACC doesn't have to be handled in any special way
    late_ops = list(x.arg.lst)
    # NOTE: we have to add a barrier at the start if barrier is used in the range
    if x.op is Ops.BLOCKEND and any(y.op is Ops.BARRIER for y in late_ops) and late_ops[-1].op is Ops.ENDRANGE:
      late_ops = [UOp(Ops.BARRIER)] + late_ops
    # peephole opt, remove any BARRIERs next to each other
    for i in range(len(late_ops)-1):
      if late_ops[i].op is Ops.BARRIER and late_ops[i+1].op is Ops.BARRIER: late_ops[i+1] = UOp(Ops.NOOP)
    arg = BasicBlock(parent_block.arg.lst+tuple(late_ops), tuple([y for y in x.arg.ctx if y is not x.arg.end]), cnt=x.arg.cnt)
    return UOp(Ops.BLOCK, src=tuple(y for y in x.src if y is not parent_block)+parent_block.src, arg=arg)
  # else the whole context ended by the blockend is already in this block and we can safely turn it into a block
  return UOp(Ops.BLOCK, src=x.src, arg=BasicBlock(x.arg.lst, tuple([y for y in x.arg.ctx if y is not x.arg.end]), cnt=x.arg.cnt))

block_merge = PatternMatcher([
  (UPat((Ops.BLOCK, Ops.BLOCKEND), name="x"), merge_block),
  (UPat(Ops.BLOCKEND, name="x"), remove_blockend),
])

# ****** finalize ******

def finalize(sink:UOp) -> UOp:
  if sink.op is not Ops.BLOCK or not all(x.op in DONT_PLACE_IN_BLOCK for x in sink.src):
    raise RuntimeError(f"linearize failure {sink.op} {[x.op for x in sink.src if x.op not in DONT_PLACE_IN_BLOCK]}")

  # place the early things
  lst = sorted(dedup(sink.src), key=lambda x: x.tuplize) + list(sink.arg.lst)
  return UOp(Ops.BLOCKFINAL, arg=BasicBlock(tuple(lst)))

pm_finalize = PatternMatcher([(UPat(Ops.BLOCK, name="sink"), finalize)])