Release 260111

This commit is contained in:
Comma Device
2026-01-11 18:23:29 +08:00
commit 3721ecbf8a
2601 changed files with 855070 additions and 0 deletions

121
tinygrad/uop/__init__.py Normal file
View File

@@ -0,0 +1,121 @@
from enum import auto, IntEnum, Enum
# wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses
class FastEnum(IntEnum):
def __str__(self): return Enum.__str__(self)
@staticmethod
def _generate_next_value_(_, __, ___, last_values): return 1 + max([0, *last_values, *[max(c) for c in FastEnum.__subclasses__()]])
# the order of these Ops controls the order of the toposort
class Ops(FastEnum):
# uops that aren't rendered
NOOP = auto(); SINK = auto(); UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); PRECAST = auto(); REWRITE_ERROR = auto() # noqa: E702
# track children
CHILD = auto(); CHILDREN = auto() # noqa: E702
# buffer ops
COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto() # noqa: E702
# create buffer
BUFFERIZE = auto()
# ops that adjust the behavior of the scheduler
CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); FUSE = auto() # noqa: E702
REALIZE = auto()
# blocks in linearizer (only used there)
BLOCK = auto(); BLOCKSTART = auto(); BLOCKEND = auto(); BLOCKFINAL = auto() # noqa: E702
# movement ops! these only exist in the tensor graph
RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702
MULTI = auto() # MULTI is really a movement op
# view is what all movement ops become
VIEW = auto()
# TODO: remove VALID with the VIEW(CONST(DEVICE)) refactor
VALID = auto()
# TODO: unify these ops into the levels of the memory hierarchy. depends on ASSIGN is STORE
DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_REG = auto() # noqa: E702
# this is for symbolic shapes
DEFINE_VAR = auto(); BIND = auto() # noqa: E702
# this is a RANGE for GPU dimensions, similar to symbolic shapes but not exactly
SPECIAL = auto()
# reduce
REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto() # noqa: E702
# optimization helper ops
UNROLL = auto(); CONTRACT = auto(); GEP = auto(); VECTORIZE = auto(); CAT = auto(); PTRCAT = auto() # noqa: E702
# UnaryOps
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto(); TRUNC = auto() # noqa: E702
# load/store before math
LOAD = auto(); STORE = auto() # noqa: E702
ASSIGN = auto() # TODO: ASSIGN is STORE, remove ASSIGN
# tensor core math op, not elementwise
WMMA = auto()
# INDEX is a BinaryOp similar to ADD, but it operates on pointers
INDEX = auto()
# BinaryOps
ADD = auto(); MUL = auto(); SHL = auto(); SHR = auto(); IDIV = auto(); MAX = auto(); MOD = auto() # noqa: E702
CMPLT = auto(); CMPNE = auto(); CMPEQ = auto() # noqa: E702
XOR = auto(); OR = auto(); AND = auto() # noqa: E702
THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto() # noqa: E702
# TernaryOps
WHERE = auto(); MULACC = auto() # noqa: E702
# control flow ops
BARRIER = auto(); RANGE = auto(); IF = auto(); ENDRANGE = auto(); ENDIF = auto() # noqa: E702
# consts. VCONST is a vectorized const
VCONST = auto(); CONST = auto() # noqa: E702
# CUSTOM/CUSTOMI are used to output strings into codegen. the I makes the string inline
CUSTOM = auto(); CUSTOMI = auto() # noqa: E702
class GroupOp:
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG, Ops.TRUNC}
Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ,
Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, Ops.SUB, Ops.FDIV, Ops.POW}
Ternary = {Ops.WHERE, Ops.MULACC}
ALU = set.union(Unary, Binary, Ternary)
# TODO: is BITCAST always Elementwise if it's shape changing?
Elementwise = set.union(ALU, {Ops.CAST, Ops.BITCAST})
Defines = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}
Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}
Buffer = {Ops.LOAD, Ops.STORE, Ops.VALID, Ops.CONST, Ops.DEFINE_VAR}
Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKSTART}
# BinaryOps that can be flipped
Commutative = {Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPNE, Ops.CMPEQ, Ops.XOR, Ops.AND, Ops.OR}
# BinaryOps where f(f(a,b),c) = f(a,f(b,c))
Associative = {Ops.ADD, Ops.MUL, Ops.AND, Ops.OR, Ops.MAX}
# BinaryOps that satisfy f(x,x)=x see https://en.wikipedia.org/wiki/Idempotence
Idempotent = {Ops.OR, Ops.AND, Ops.MAX}
# These can change the dtype to bool
Comparison = {Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ}
# do not preserve f(0) = 0
UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW}
Meta = {Ops.COPY, Ops.BUFFER_VIEW}
All = set(Ops)

View File

@@ -0,0 +1,367 @@
from typing import Callable
import math, functools
from tinygrad.dtype import dtypes, DType, promo_lattice
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import polyN, DISABLE_FAST_IDIV
from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher
TRANSCENDENTAL_DTYPES = (dtypes.float16, dtypes.float32, dtypes.float64)
def _lazy_map_numbers(x:UOp, inf:UOp, _inf:UOp, nan:UOp, ratio:UOp):
"""replace inf -> inf, -inf -> _inf, nan -> nan, otherwise -> ratio"""
return x.ne(math.inf).where(x.ne(x).where(nan, x.ne(-math.inf).where(ratio, _inf)), inf)
# *** helper functions for bit manipulation ***
def mantissa_bits(d:DType) -> int: return dtypes.finfo(d.scalar())[1]
def exponent_bias(d:DType) -> int: return {dtypes.float64: 1023, dtypes.float32: 127, dtypes.float16: 15}[d.scalar()]
def exponent_mask(d:DType) -> int: return {dtypes.float64: 2047, dtypes.float32: 255, dtypes.float16: 31}[d.scalar()]
# **** utils ****
def shr(x:UOp, y:int) -> UOp: return x // (2**y)
def shl(x:UOp, y:int) -> UOp: return x * (2**y)
def rintk(d:UOp) -> UOp:
"""round d:float to int away from 0"""
out_dtype = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype.scalar()].vec(d.dtype.vcount)
return (d + (d<0.0).where(d.const_like(-0.5), d.const_like(0.5))).cast(out_dtype)
def pow2if(q:UOp, float_dtype:DType):
"""cast(2^q, float_dtype) where q is any integer in the range of [-126, 127]"""
out_dtype = {dtypes.int64: dtypes.float64, dtypes.int32: dtypes.float32, dtypes.int16: float_dtype.scalar()}[q.dtype.scalar()].vec(q.dtype.vcount)
return shl(q + exponent_bias(out_dtype), mantissa_bits(out_dtype)).bitcast(out_dtype)
def ilogb2k(d:UOp) -> UOp:
"""calculate the integer part of log2(d), where d is normalized fp value in the range of [0, +inf)."""
assert d.dtype.scalar() in TRANSCENDENTAL_DTYPES
dint = d.bitcast({dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype.scalar()].vec(d.dtype.vcount))
# -1 <= ilog2bk(d) <= 128
return (shr(dint, mantissa_bits(d.dtype)) & exponent_mask(d.dtype)) - exponent_bias(d.dtype)
def ldexp3k(d:UOp, e:UOp) -> UOp:
"""d*2^e. e is a number obtained by casting an integer in the range [-127, 127] to a float. d is any float number."""
assert d.dtype.scalar() in TRANSCENDENTAL_DTYPES and e.dtype.scalar() in TRANSCENDENTAL_DTYPES
dtype = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype.scalar()].vec(d.dtype.count)
m1 = d.bitcast(dtype)
m2 = shl(e.cast(dtype), mantissa_bits(d.dtype))
return (m1 + m2).bitcast(d.dtype).cast(d.dtype)
def ldexp2k(d:UOp, e:UOp) -> UOp:
"""d*2^e. much faster than ldexp3k but risky. d > 0 and d is not denormal."""
assert d.dtype.scalar() in TRANSCENDENTAL_DTYPES and e.dtype.scalar() in (dtypes.int16, dtypes.int32, dtypes.int64)
return (d * pow2if(shr(e, 1), d.dtype)) * pow2if(e - shr(e, 1), d.dtype)
def frexp(v:UOp) -> tuple[UOp, UOp]:
"""frexp(v) -> (mantissa, exponent) assuming v != 0"""
assert v.dtype.scalar() in TRANSCENDENTAL_DTYPES
# m1 = masks for mantissa, m2 = masks to normalize the mantissa.
m1 = {dtypes.float64: 0x000FFFFFFFFFFFFF, dtypes.float32: 0x807FFFFF, dtypes.float16: 0x83FF}[v.dtype.scalar()]
m2 = {dtypes.float64: 0x3FE0000000000000, dtypes.float32: 0x3F000000, dtypes.float16: 0x3800}[v.dtype.scalar()]
bits = v.bitcast({dtypes.float64: dtypes.uint64, dtypes.float32: dtypes.uint32, dtypes.float16: dtypes.uint16}[v.dtype.scalar()].vec(v.dtype.count))
exponent = shr(bits, mantissa_bits(v.dtype)) & exponent_mask(v.dtype)
# Set the exponent bits appropriately to normalize the mantissa into the range of [0.5, 1.0).
mantissa = ((bits & m1) | m2).bitcast(v.dtype)
exp = exponent - exponent_bias(v.dtype) + 1
return mantissa, exp
# *** reduction algorithms for sine ***
def payne_hanek_reduction(d:UOp) -> tuple[UOp, UOp]:
"""
Performs Payne-Hanek Reduction: computes the remainder of `d` modulo pi/2 for the values `d` where
39800.0 <= d <= +Inf
Returns a tuple of `(r, q)`:
- `r`[d.dtype] is the reminder value corresponding to `round_to_nearest(x % pi/2)`.
- `q`[int32] is an integer, and q % 4 is corresponding to the quadrant of the original angle `d`.
"""
assert d.dtype.scalar() in TRANSCENDENTAL_DTYPES
# https://stackoverflow.com/questions/30463616/payne-hanek-algorithm-implementation-in-c/30465751#30465751
# 190 bits of 2/pi for Payne-Hanek style argument reduction
two_over_pi_f = [0x00000000, 0x28be60db, 0x9391054a, 0x7f09d5f4, 0x7d4d3770, 0x36d8a566, 0x4f10e410]
intermediate_dtype = dtypes.float32.vec(d.dtype.count) if d.dtype.base.scalar() == dtypes.float16 else d.dtype
f, e = frexp(d)
ia = (f.cast(intermediate_dtype) * 4.294967296e9).cast(dtypes.uint64)
# extract 96 relevant bits of 2/pi based on magnitude of argument
i = shr(e.cast(dtypes.uint64), 5)
e = e.cast(dtypes.int32) & 31
offset = 32 - e
def _take(an:UOp, offset:int, count:int=0) -> UOp:
"""an = two_over_pi_f[i+offset]"""
if count+offset < len(two_over_pi_f) - 1:
an = i.ne(count).where(_take(an, offset, count=count+1), an.const_like(two_over_pi_f[count+offset]))
return an
def _shl_lazy(x:UOp, y:UOp): return (x.cast(dtypes.uint64) * pow2if(y, d.dtype).cast(dtypes.uint64)).cast(dtypes.uint32)
def _shr_lazy(x:UOp, y:UOp): return (x.cast(dtypes.uint64) // pow2if(y, d.dtype).cast(dtypes.uint64)).cast(dtypes.uint32)
a = [_take(UOp.const(dtypes.uint32.vec(d.dtype.count), 0), i) for i in range(4)]
# (two_over_pi_f[Int(i) + n] << e) | (two_over_pi_f[Int(i) + n+1] >> (nbits - e))
# Note: e >= 1 for all numbers d >= 1.0. assume e != 0
hi = _shl_lazy(a[0], e) | _shr_lazy(a[1], offset)
mi = _shl_lazy(a[1], e) | _shr_lazy(a[2], offset)
lo = _shl_lazy(a[2], e) | _shr_lazy(a[3], offset)
def _hp_mul(x:UOp, y:UOp) -> UOp: return x.cast(dtypes.uint64) * y.cast(dtypes.uint64)
# compute x * 2/pi
p = shl(_hp_mul(ia, hi), 32) + _hp_mul(ia, mi) + shr(_hp_mul(ia, lo), 32)
# round quotient to nearest
q = shr(p, 62).cast(dtypes.int32)
p = p & 0x3fffffffffffffff
r = (p.cast(intermediate_dtype) * (3.4061215800865545e-19)).cast(d.dtype)
# if fraction >= 0.5, r -= pi/2, q += 1
return (f<0.5).where(r, r - math.pi/2), (f<0.5).where(q, q + 1)
def cody_waite_reduction(d:UOp) -> tuple[UOp, UOp]:
"""
Performs Cody-Waite Reduction: computes the reminder of `d` modulo pi/2 for the values `d` where
0 <= abs(d) <= 39800.0
Returns a tuple of `(r, q)`, where the output format is the same as that of `payne_hanek_reduction`.
"""
def _reduce_d(x:UOp, q:UOp):
# https://github.com/shibatch/sleef/blob/4e08851f59fc2b545f9c393c6a23dfd311a26308/src/libm/sleefdp.c#L789-L823
if x.dtype.scalar() == dtypes.float64:
# https://github.com/shibatch/sleef/blob/f6d8a841fbfddd26ce712834d4da220cd76048fb/src/common/misc.h#L77
PI_A, PI_B, PI_C, PI_D = 3.1415926218032836914, 3.1786509424591713469e-08, 1.2246467864107188502e-16, 1.2736634327021899816e-24
d = qdh * -PI_A + x
d = q * -PI_A + d
d = qdh * -PI_B + d
d = q * -PI_B + d
d = qdh * -PI_C + d
d = q * -PI_C + d
d = (qdh + q) * -PI_D + d
elif x.dtype.scalar() == dtypes.float16:
# [FIXME] when reducing `d`, FP16 needs FP32 precision to achieve 1.0 ULP precision.
d = _reduce_d(x.cast(dtypes.float32), q.cast(dtypes.float32)).cast(dtypes.float16)
else:
# https://github.com/shibatch/sleef/blob/4e08851f59fc2b545f9c393c6a23dfd311a26308/src/libm/sleefsp.c#L464-L503
d = q * -3.1414794921875 + x
d = q * -0.00011315941810607910156 + d
d = q * -1.9841872589410058936e-09 + d
d = q * -1.2154201256553420762e-10 + d
return d
m_1_pi = 0.318309886183790671537767526745028724
qdh = (d * (m_1_pi / 2.0**24)).cast(dtypes.int64).cast(d.dtype) * (2.0**24)
quadrant = rintk(d * m_1_pi -qdh) if d.dtype.base.scalar() == dtypes.float64 else rintk(d * m_1_pi)
return _reduce_d(d, quadrant.cast(d.dtype)), quadrant.cast(dtypes.int32)
# *** approximate sine on small angle. ***
def trig_poly(d:UOp, coeff32, coeff64): return d * (polyN(d*d, coeff64) if d.dtype.scalar() == dtypes.float64 else polyN(d*d, coeff32))
# approximate sine on [-pi/2, pi/2]
def sin_poly(d:UOp) -> UOp:
return trig_poly(d, [2.6083159809786593541503e-06, -0.0001981069071916863322258, 0.00833307858556509017944336, -0.166666597127914428710938, 1.0],
[-7.97255955009037868891952e-18, 2.81009972710863200091251e-15, -7.64712219118158833288484e-13, 1.60590430605664501629054e-10,
-2.50521083763502045810755e-08, 2.75573192239198747630416e-06, -0.000198412698412696162806809, 0.00833333333333332974823815,
-0.166666666666666657414808, 1.0])
def _ifand(q:UOp, n:int): return (q & n).ne(0)
def sin_poly_small(d:UOp, q:UOp) -> UOp:
r = sin_poly(d)
return r * _ifand(q, 1).where(r.const_like(-1), r.const_like(1))
def sin_poly_large(d:UOp, q:UOp) -> UOp:
r = sin_poly(d + _ifand(q, 1).where(d.const_like(math.pi / 2), d.const_like(0)))
return r * _ifand(q, 2).where(r.const_like(-1), r.const_like(1))
# *** toplevel functions for xsin/xlog2/xexp2 ***
def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:
"""
Implements a 1.0 ULP approximation for Ops.SIN.
- fast=True assumes x <= switch_over.
- switch_over is the threshold for switching to payne_hanek_reduction.
"""
assert d.dtype.scalar() in TRANSCENDENTAL_DTYPES
# mask +-inf/nan as zero
x = _lazy_map_numbers(d, d.const_like(0.0), d.const_like(0.0), d.const_like(0.0), d)
# x_sign = sign(x)
x_sign = x.ne(0).where((x<0).where(x.const_like(-1), x.const_like(1)), x.const_like(0))
x_abs = x * x_sign
r, q = (cody_waite_reduction if fast else payne_hanek_reduction)(x_abs)
if fast: result = sin_poly_small(r, q)
else:
# Payne Hanek Reduction assumes abs(x) >= pi/4, so for smaller values, use cody_waite_reduction.
r_small, q_small = cody_waite_reduction(x_abs)
result = (x_abs<switch_over).where(sin_poly_small(r_small, q_small), sin_poly_large(r, q))
# adjusts the sign for abs(x)
result = result * x_sign
# sin(Inf) = NaN, sin(-Inf) = NaN, sin(NaN) = NaN
return _lazy_map_numbers(d, d.const_like(math.nan), d.const_like(math.nan), d.const_like(math.nan), result)
def xexp2(d:UOp) -> UOp:
"""
Implements a 1.0 ULP approximation for Ops.EXP2
- Paper: https://arxiv.org/pdf/2001.09258
"""
assert d.dtype.scalar() in TRANSCENDENTAL_DTYPES
# mask +=inf/nan as zero.
x = _lazy_map_numbers(d, d.const_like(0.0), d.const_like(0.0), d.const_like(0.0), d)
q = rintk(x)
# s = d - round(d)
s = x - q.cast(x.dtype)
# a polynomial approximation with 13 non-zero terms in the range of [(log 2)/2,(log 2)/2].
if d.dtype.scalar() == dtypes.float64:
u = polyN(s, [0.4434359082926529454e-9, 0.7073164598085707425e-8, 0.1017819260921760451e-6, 0.1321543872511327615e-5, 0.1525273353517584730e-4,
0.1540353045101147808e-3, 0.1333355814670499073e-2, 0.9618129107597600536e-2, 0.5550410866482046596e-1, 0.2402265069591012214e+0,
0.6931471805599452862e+0, 0.1000000000000000000e+1])
else: u = polyN(s, [0.1535920892e-3, 0.1339262701e-2, 0.9618384764e-2, 0.5550347269e-1, 0.2402264476e+0, 0.6931471825e+0, 1.0])
u = ldexp2k(u, q) # u*2^q
upper, lower = {dtypes.float64: (1024, -2000), dtypes.float32: (128, -150), dtypes.float16: (23, -22)}[d.dtype.scalar()]
# Replace x >= upper with +inf
u = (d >= upper).where(d.const_like(math.inf), u)
# Replace x < lower with zero.
u = (d<lower).where(d.const_like(0.0), u)
# exp2(NaN) = NaN
return d.ne(d).where(d.const_like(math.nan), u)
def xlog2(d:UOp) -> UOp:
"""
Implements a 1.0 ULP approximation for Ops.LOG2
Paper: https://arxiv.org/pdf/2001.09258 5.5
"""
assert d.dtype.scalar() in TRANSCENDENTAL_DTYPES
# TODO: float16 denormal need float32 to achieve precision
if d.dtype.scalar() == dtypes.float16: return xlog2(d.cast(dtypes.float32)).cast(dtypes.float16)
FLT_MIN = d.const_like(1e-6 if d.dtype.scalar() == dtypes.float16 else 1e-4)
is_denormal = d<FLT_MIN
a = is_denormal.where(d * (2 ** 64), d)
e = ilogb2k(a * (1.0 / 0.75)).cast(a.dtype)
m = ldexp3k(a, -e)
e = is_denormal.where(e - 64, e)
x = (m - 1.0) / (m + 1.0)
x2 = x * x
if d.dtype.scalar() == dtypes.float64:
t = polyN(x2, [0.2211941750456081490e+0, 0.2200768693152277689e+0, 0.2623708057488514656e+0, 0.3205977477944495502e+0,
0.4121985945485324709e+0, 0.5770780162997058982e+0, 0.96179669392608091449])
s_hi, s_lo = e+x*2.885390081777926774, e.const_like(0)
else:
t = polyN(x2, [0.4374550283e+0, 0.5764790177e+0, 0.9618012905120])
s_hi, s_lo = e+x*2.8853900432586669922, x*3.2734474483568488616e-08
r = t * (x * x2) + (s_hi + s_lo)
# log2(Inf) = Inf
r = d.ne(math.inf).where(r, r.const_like(math.inf))
# log2(x) = NaN for x < 0
r = (d<-0.0).where(r.const_like(math.nan), r)
# log2(0) = -Inf, but we will compare using the value of y because 1e-200==0 is true.
# log2_zero = the value of unmasked xlog2(0.0).
log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79}[d.dtype.scalar()]
r = r.ne(log2_zero).where(r, r.const_like(-math.inf))
# log2(NaN) = NaN
r = d.ne(d).where(r.const_like(math.nan), r)
# log2(-0.0) = -Inf. In certain devices like PTX, x == -0.0 won't be true. so making reciprocal.
return d.reciprocal().ne(-math.inf).where(r, r.const_like(-math.inf))
def xpow(base:UOp, exponent:UOp) -> UOp:
# start with b ** e = exp2(e * log2(b))
ret = (base < 0).where(-base, base).log2().mul(exponent).exp2()
# negative base adjustment: nan for non-integer exponent and -1 for odd exponent
non_int = exponent != exponent.cast(dtypes.int32).cast(exponent.dtype)
adj = non_int.where(ret.const_like(math.nan),
(exponent < 0).where(-exponent, exponent).cast(dtypes.int32).mod(2).cast(dtypes.bool).where(ret.const_like(-1), ret.const_like(1)))
# fix 0 ** 0 = 1
return (base.eq(0) & exponent.eq(0)).where(ret.const_like(1), ret * (base < 0).where(adj, ret.const_like(1)))
# *** integer division ***
@functools.lru_cache(None)
def magicgu(vmax:int, d:int) -> tuple[int,int]:
# calculate m,s such that x//d == (x*m) >> s for all 0 <= x <= vmax, d>0; adapted from Hacker's Delight, Chapter 10
nc = (vmax+1)//(d) * d - 1
nbits = vmax.bit_length()
for s in range(0, 2*nbits + 1):
if 2**s > nc*(d - 1 - (2**s - 1) % d):
m = (2**s + d - 1 - (2**s - 1) % d)//d
return m, s
assert False
def fast_idiv(device: str, x: UOp, d: int, dont_cast=False) -> UOp|None:
# If d is a power of two this is not valid for signed ints!
is_unsigned = True if x.vmin>=0 or x.dtype in dtypes.uints else False
assert d>0, "Sign should have been taken out of divisor"
vmin,vmax = max(x.vmin, x.dtype.min), min(x.vmax, x.dtype.max)
m,s = magicgu(max(vmax, abs(vmin)), d)
if m*vmin >= dtypes.min(x.dtype) and m*vmax <= dtypes.max(x.dtype):
return ((x*m) >> s) if is_unsigned else ((x*m) >> s) + (x<0).where(x.ufix(1), 0)
# before we try casting to a larger dtype (slow), we see if there are powers of two in d we can shift to make x smaller
if (largest_factor_of_two_in_d := (d & -d)) > 1:
if (ret:=fast_idiv(device, x//largest_factor_of_two_in_d, d//largest_factor_of_two_in_d, dont_cast=True)) is not None: return ret
if dont_cast: return None
# promo_lattice needs to return an unsigned type if the type is unsigned
if dtypes.is_int(next_dtype := promo_lattice[x.dtype.scalar()][-1]) and is_dtype_supported(next_dtype, None if device=='' else device):
if m*vmin >= dtypes.min(next_dtype) and m*vmax <= dtypes.max(next_dtype):
return ((x.cast(next_dtype)*m) >> s).cast(x.dtype) if is_unsigned else ((x.cast(next_dtype)*m) >> s).cast(x.dtype) + (x<0).where(x.ufix(1), 0)
return None
# ***** threefry *****
def threefry2x32(x: UOp, key: UOp):
# split x and key from uint64 to two uint32
x0, x1 = (x & 0xffffffff).cast(dtypes.uint32), ((x // 2**32) & 0xffffffff).cast(dtypes.uint32)
key0, key1 = (key & 0xffffffff).cast(dtypes.uint32), ((key // 2**32) & 0xffffffff).cast(dtypes.uint32)
rotations = [[13, 15, 26, 6], [17, 29, 16, 24]]
ks = [key1, key0 ^ key1 ^ 0x1BD11BDA, key0]
xr:list[UOp] = [x0 + ks[-1], x1 + ks[0]]
for i in range(5):
for r in rotations[i % 2]: xr[0], xr[1] = (x0 := xr[0] + xr[1]), x0 ^ ((xr[1] * 2**r) + (xr[1] // 2**(32 - r)))
xr = [(xr[0] + ks[i % 3]), (xr[1] + ks[(i + 1) % 3] + i + 1)]
return xr[1].cast(dtypes.uint64) * 2**32 | xr[0].cast(dtypes.uint64)
# ***** decomposition patterns *****
powers_of_two = {2**i:i for i in range(64)}
@functools.cache
def get_late_rewrite_patterns(ops:tuple[Ops, ...], force_transcendental=False):
pat: list[tuple[UPat, Callable]] = []
for op,f in ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)):
if op not in ops or force_transcendental:
pat += [(UPat(op, dtype=TRANSCENDENTAL_DTYPES, src=(UPat.var("d"),)), f),
(UPat(op, dtype=tuple(dt for dt in dtypes.floats if dt not in TRANSCENDENTAL_DTYPES), src=(UPat.var("d"),), name="x"),
lambda x,d: d.cast(dtypes.float32).alu(x.op).cast(x.dtype))]
# no real hardware supports THREEFRY, but NullRenderer does
if Ops.THREEFRY not in ops: pat.append((UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32))
# MAX can be rewritten as CMPLT + WHERE (max function is annoying on many cstyle backends)
if Ops.MAX not in ops and Ops.CMPLT in ops: pat.append((UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])))
# rewrite SQRT to xpow 0.5
if Ops.SQRT not in ops: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5))))
# rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
if Ops.AND in ops: pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
if Ops.OR in ops: pat += [(UPat.var("x", dtypes.bool).logical_not()&UPat.var("y", dtypes.bool).logical_not(),
lambda x,y: (x | y).logical_not())]
# rewrite MUL/IDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)
if Ops.SHL in ops: pat += [(UPat.var("x", dtypes.ints)*UPat.cvar("c"), lambda c,x: x << v if (v:=powers_of_two.get(c.arg, 0)) else None)]
if Ops.SHR in ops:
# no reason to check x<0 for uints
pat += [(UPat.var("x", dtypes.uints)//UPat.cvar("c"), lambda x,c: x >> v if (v:=powers_of_two.get(c.arg, 0)) else None)]
pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("c"), lambda x,c: (x+(l.const_like(l.vmin) if (l:=(x<0)).vmin==l.vmax else l).where(
c-1, 0)) >> v if (v:=powers_of_two.get(c.arg, 0)) else None)] # (x+(x<0).where(c-1, 0)) >> v
if not DISABLE_FAST_IDIV:
pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("d", vec=False), lambda ctx, x, d: fast_idiv(ctx, x, d.arg))]
pat += [(UPat.var("x", dtypes.ints)%UPat.var("d"), lambda x, d: x-d*(x//d))]
if Ops.NEG in ops:
pat += [(UPat.var('x')*-1, lambda x: x.alu(Ops.NEG))]
if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda x,y: x.alu(Ops.SUB, y))]
if Ops.CMPLT in ops:
# These are late rewrites because simplex expects equalities to be a certain format
pat += [
((UPat.var("x", dtypes.sints) < UPat.cvar("c", dtypes.sints)).logical_not(), lambda x,c: c-1<x),
((UPat.cvar("c", dtypes.sints) < UPat.var("x", dtypes.sints)).logical_not(), lambda x,c: x<c+1),
(UPat.var("x", dtypes.sints)*-1 < UPat.var("y", dtypes.sints)*UPat.cvar("c"), lambda x,y,c: y*(-c)<x),
(UPat.var("x", dtypes.sints)*-1 < UPat.cvar("c"), lambda x,c:-c<x),
((UPat.cvar("c1",vec=False)<UPat.var("x", dtypes.sints)) & (UPat.var("x", dtypes.sints)<UPat.cvar("c2",vec=False)),
lambda x,c1,c2: x.eq(c1+1) if c1.arg+1==c2.arg-1 else None), # (c-1)<x & x<(c+1) -> x==c
]
if Ops.CMPEQ in ops: pat += [(UPat.var('x').ne(UPat.var('y')).logical_not(), lambda x,y: x.alu(Ops.CMPEQ, y))]
if Ops.MULACC in ops: pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(Ops.MULACC, b, c))]
# some backends emit FDIV for RECIP, in that case: a*(1/b) -> a/b
if Ops.FDIV in ops:
pat += [(UPat.var("x").reciprocal(), lambda x: x.const_like(1).alu(Ops.FDIV, x))]
pat += [(UPat.var("a", dtypes.floats) * UPat.const(dtypes.floats, 1).alu(Ops.FDIV, UPat.var("b")), lambda a,b: a.alu(Ops.FDIV, b))]
return PatternMatcher(pat)

170
tinygrad/uop/mathtraits.py Normal file
View File

@@ -0,0 +1,170 @@
from tinygrad.uop import Ops
from tinygrad.helpers import T
from tinygrad.dtype import dtypes
class MathTrait:
# required to implement
def alu(self:T, op:Ops, *src) -> T: raise NotImplementedError
def const_like(self:T, b) -> T: raise NotImplementedError
# great functions you get!
def ufix(self, x): return self.const_like(x) if not isinstance(x, MathTrait) else x
def _binop(self, op, x, reverse): return self.ufix(x).alu(op, self) if reverse else self.alu(op, self.ufix(x))
def logical_not(self): return self.ne(True)
def neg(self):
if (dtype:=getattr(self, 'dtype')) is None: raise TypeError(f"MathTraits __neg__ requires a dtype, {self=}")
return self.logical_not() if dtype.scalar() == dtypes.bool else self*(-1)
def _check_dtype(self):
if (dtype:=getattr(self, 'dtype')) is not None:
if isinstance(dtype, tuple): dtype = dtype[0]
if not (dtypes.is_bool(dtype) or dtypes.is_int(dtype)): raise RuntimeError(f"{dtype} is not supported")
def add(self, x, reverse=False):
"""
Adds `self` and `x`.
Equivalent to `self + x`.
Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor.randn(4)
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.add(20).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.add(Tensor([[2.0], [3.5]])).numpy())
```
"""
return self._binop(Ops.ADD, x, reverse)
def mul(self, x, reverse=False):
"""
Multiplies `self` and `x`.
Equivalent to `self * x`.
Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor.randn(4)
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.mul(3).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.mul(Tensor([[-1.0], [2.0]])).numpy())
```
"""
return self._binop(Ops.MUL, x, reverse)
def bitwise_and(self, x, reverse=False):
"""
Computes the bitwise AND of `self` and `x`.
Equivalent to `self & x`.
Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([2, 5, 255]).bitwise_and(Tensor([3, 14, 16])).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([True, True, False, False]).bitwise_and(Tensor([True, False, True, False])).numpy())
```
"""
self._check_dtype()
return self._binop(Ops.AND, x, reverse)
def bitwise_or(self, x, reverse=False):
"""
Computes the bitwise OR of `self` and `x`.
Equivalent to `self | x`.
Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([2, 5, 255]).bitwise_or(Tensor([4, 4, 4])).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([True, True, False, False]).bitwise_or(Tensor([True, False, True, False])).numpy())
```
"""
self._check_dtype()
return self._binop(Ops.OR, x, reverse)
def bitwise_xor(self, x, reverse=False):
"""
Computes bitwise xor of `self` and `x`.
Equivalent to `self ^ x`.
Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-1, -2, 3]).bitwise_xor(Tensor([1, 0, 3])).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([True, True, False, False]).bitwise_xor(Tensor([True, False, True, False])).numpy())
```
"""
self._check_dtype()
return self._binop(Ops.XOR, x, reverse)
def idiv(self, x, reverse=False):
"""
Divides `self` by `x`.
Equivalent to `self // x`.
Supports broadcasting to a common shape, type promotion, and integer inputs.
`idiv` performs integer division (truncate towards zero).
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-4, 7, 5, 4, -7, 8]).idiv(Tensor([2, -3, 8, -2, 3, 5])).numpy())
```
"""
return self._binop(Ops.IDIV, x, reverse)
def mod(self, x, reverse=False): return self._binop(Ops.MOD, x, reverse)
def sub(self, x, reverse=False): return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x))
def div(self, x, reverse=False): return (self.ufix(x)*self.alu(Ops.RECIP)) if reverse else (self*self.ufix(x).alu(Ops.RECIP))
def __neg__(self): return self.neg()
def __add__(self, x): return self.add(x)
def __sub__(self, x): return self.sub(x)
def __mul__(self, x): return self.mul(x)
def __truediv__(self, x): return self.div(x)
def __floordiv__(self, x): return self.idiv(x) # TODO: idiv is trunc div, not floordiv
def __mod__(self, x): return self.mod(x)
def __and__(self, x): return self.bitwise_and(x)
def __or__(self, x): return self.bitwise_or(x)
def __xor__(self, x): return self.bitwise_xor(x)
def __radd__(self, x): return self.add(x, True)
def __rsub__(self, x): return self.sub(x, True)
def __rmul__(self, x): return self.mul(x, True)
def __rtruediv__(self, x): return self.div(x, True)
def __rfloordiv__(self, x): return self.idiv(x, True)
def __rand__(self, x): return self.bitwise_and(x, True)
def __ror__(self, x): return self.bitwise_or(x, True)
def __rxor__(self, x): return self.bitwise_xor(x, True)
def __rmod__(self, x): return self.mod(x, True)
def __lt__(self, x): return self.alu(Ops.CMPLT, self.ufix(x))
def __gt__(self, x): return self.ufix(x).alu(Ops.CMPLT, self)
def __ge__(self, x): return (self < x).logical_not()
def __le__(self, x): return (self > x).logical_not()
def ne(self, x): return self.alu(Ops.CMPNE, self.ufix(x))
def eq(self, x): return self.ne(x).logical_not()
def __ne__(self, x): return self.ne(x)
# NOTE: __eq__ isn't overridden, and means the same thing as is by default
def lshift(self, x, reverse=False): return self._binop(Ops.SHL, x, reverse)
def rshift(self, x, reverse=False): return self._binop(Ops.SHR, x, reverse)
def __lshift__(self, x): return self.lshift(x)
def __rshift__(self, x): return self.rshift(x)
def __rlshift__(self, x): return self.lshift(x, True)
def __rrshift__(self, x): return self.rshift(x, True)
def maximum(self, x): return self.alu(Ops.MAX, self.ufix(x))
def minimum(self, x): return -(-self).maximum(-x)
def where(self, x, y):
if type(self) is type(x): return self.alu(Ops.WHERE, x, x.ufix(y))
if type(self) is type(y): return self.alu(Ops.WHERE, y.ufix(x), y)
raise RuntimeError("where needs at least one UOp arg")
def threefry(self, seed): return self.alu(Ops.THREEFRY, seed)
def reciprocal(self): return self.alu(Ops.RECIP)
def trunc(self): return self.alu(Ops.TRUNC)
def sqrt(self): return self.alu(Ops.SQRT)
def sin(self): return self.alu(Ops.SIN)
def log2(self): return self.alu(Ops.LOG2)
def exp2(self): return self.alu(Ops.EXP2)
def pow(self, x): return self.alu(Ops.POW, self.ufix(x))
def __pow__(self, x): return self.pow(x)

1154
tinygrad/uop/ops.py Normal file

File diff suppressed because it is too large Load Diff

261
tinygrad/uop/spec.py Normal file
View File

@@ -0,0 +1,261 @@
from typing import cast, Callable
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, python_alu, graph_rewrite, AxisType
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace, Invalid
from tinygrad.helpers import all_same, prod, DEBUG, ContextVar, Context, cpu_profile
from tinygrad.shape.shapetracker import ShapeTracker
try:
import z3
# older versions of z3 dont have some operators like & overloaded
if z3.get_version() < (4, 12, 4, 0): raise ImportError
# IDIV is truncated division but z3 does euclidian division (floor if b>0 ceil otherwise); mod by power of two sometimes uses Ops.AND
def z3_cdiv(a, b):return z3.If((a<0), z3.If(0<b, (a+(b-1))/b, (a-(b+1))/b), a/b)
def z3_xor(a,b):
if isinstance(a, z3.BoolRef): return a^b
assert a==-1 or b==-1, "xor can only be used in indexing if one of the aruments is -1"
return -a-1 if b==-1 else -b-1
z3_alu: dict[Ops, Callable] = python_alu | {Ops.MOD: lambda a,b: a-z3_cdiv(a,b)*b, Ops.IDIV: z3_cdiv, Ops.SHR: lambda a,b: a/(2**b.as_long()),
Ops.SHL: lambda a,b: a*(2**b.as_long()), Ops.AND: lambda a,b: a%(b+1) if isinstance(b, z3.ArithRef) else a&b, Ops.WHERE: z3.If, Ops.XOR: z3_xor,
Ops.MAX: lambda a,b: z3.If(a<b, b, a), Ops.TRUNC: lambda a: a if a.is_int() else z3.ToReal(z3.If(a >= 0, z3.ToInt(a), -z3.ToInt(-a)))}
def create_bounded(name:str, vmin, vmax, solver:z3.Solver) -> z3.ArithRef:
s = z3.Int(name, ctx=solver.ctx)
solver.add(vmin <= s, s <= vmax)
return s
# ctx is (solver, load_number_dict)
# each uop gets rewritten to NOOP(arg=(solver, z3_object)), the arg has the solver first due to UOpMetaClass caching. z3 objects from different
# contexts can have the same hash but error on comparison
z3_renderer = PatternMatcher([
(UPat(Ops.SPECIAL, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg, 0, x.src[0].arg[1]-1, ctx[0])))),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0])))),
(UPat(Ops.RANGE, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"ridx{x.arg}", 0, x.src[0].arg[1]-1, ctx[0])))),
# float loads only become a variable when they get cast to int/bool
(UPat(Ops.LOAD, dtypes.ints, name="x"),
lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"load{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0])))),
(UPat(Ops.CONST, dtype=dtypes.ints+(dtypes.bool,dtypes.index), name="x"),
lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx)))),
# z3 can cast from bool to int automatically
(UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.index,), src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
(UPat(Ops.CAST, dtype=dtypes.bool, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], x.src[0].arg[1]!=0))),
# if the source of the cast is not a noop it means that it is a float and so we create a new variable
(UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.index,), name="x"), lambda x,ctx:
UOp(Ops.NOOP, arg=(ctx[0], create_bounded(f"cast{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0])))),
(UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x,ctx:
UOp(Ops.NOOP, arg=(ctx[0], z3.Bool(f"cast{ctx[1].setdefault(x, len(ctx[1]))}",ctx=ctx[0].ctx)))),
(UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], z3_alu[x.op](*(s.arg[1] for s in x.src))))),
# A comparison between floats introduces a new bool variable
(UPat(GroupOp.Comparison, src=UPat(dtype=dtypes.floats), name="x"), lambda x,ctx:
UOp(Ops.NOOP, arg=(ctx[0], z3.Bool(f"float_cmp{ctx[1].setdefault(x, len(ctx[1]))}",ctx=ctx[0].ctx)))),
])
def uops_to_z3(solver, *uops: UOp) -> 'list[z3.ExprRef]':
with Context(TRACK_MATCH_STATS=0): # cant pickle z3 objects
return [s.arg[1] for s in graph_rewrite(uops[0].sink(*uops[1:]), z3_renderer, ctx=(solver, {})).src]
z3_imported = True
except (ImportError, AttributeError): z3_imported = False
# if you have z3 installed, by default we check the bounds
IGNORE_OOB = ContextVar("IGNORE_OOB", int(not z3_imported))
buffer_spec = PatternMatcher([
(UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
(UPat(Ops.DEVICE, dtypes.void, (), name="d"), lambda d:
isinstance(d.arg, str) or (isinstance(d.arg, tuple) and all(isinstance(s, str) for s in d.arg))),
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), allow_any_len=True, name="buf"),
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),), name="buf_view"),
lambda buf_view: isinstance(buf_view.arg, tuple) and len(buf_view.arg) == 2 and all(isinstance(arg, (int, UOp)) for arg in buf_view.arg)),
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.MSTACK, src=UPat(Ops.BUFFER)),)), lambda: True),
# allow VIEW here. TODO: what views specifically are allowed? does this mess with gradient?
(UPat(Ops.VIEW), lambda: True),
])
assign_spec = PatternMatcher([
# KERNEL can attach to an ASSIGN to describe the compute required to realize a BUFFER
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),
# ASSIGN has a target and a value. It can also optionally depend on other assigns
(UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),
# MSELECT chooses one of the multi buffers
(UPat(Ops.MSELECT, name="x"), lambda x: isinstance(x.src[0].device, tuple) and x.arg < len(x.src[0].device)),
# MSTACK combines buffers into multi
(UPat(Ops.MSTACK, name="x"), lambda x: all(isinstance(x.device, str) for x in x.src)),
])
# *** this is the spec of a Tensor in UOp ***
tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
(UPat(GroupOp.Movement, name="mv", src=(UPat.var("x"),)),
# naturally correct
lambda mv,x: (isinstance(mv.arg, tuple) and mv.dtype == x.dtype) or
# "make things that can't be images not images" can change the buffer dtype
# this is fine as long as it's a realized buffer and base dtypes match.
((isinstance(mv.dtype, ImageDType) or isinstance(x.dtype, ImageDType)) and x.dtype.base == mv.dtype.base and x.base.op is Ops.BUFFER)),
(UPat(Ops.VIEW, src=(UPat.var("x"),)), lambda x: x.base.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN, Ops.CONST, Ops.DEVICE}),
# Tensor variable bindings
(UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=(dtypes.int,dtypes.index,))), arg=None), lambda: True),
# Tensor const has a device and an unmasked ShapeTracker of stride 0
# NOTE: variables in shape can cause multiple views in this ShapeTracker and other issues, see TestSymbolicJit.test_ones_sum
# TODO: remove after rangeify is default
(UPat(Ops.CONST, src=(UPat.any(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="st"),
UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.BIND)), name="st")),)),
lambda st: len(st.st.views) == 1 and all(v.mask is None for v in st.st.views)),
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE),)), lambda: True),
# DETACH and CONTIGUOUS change how we interpret the source UOp
# CONTIGUOUS ensures the source UOp realizes
(UPat((Ops.DETACH, Ops.CONTIGUOUS, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="root", src=(UPat.var("x"),), arg=None),
lambda root,x: root.dtype == x.dtype),
# CONTIGUOUS with a range
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat.var("x"),), allow_any_len=True, arg=None),
lambda root,x: root.dtype == x.dtype and all(u.op is Ops.RANGE for u in root.src[1:])),
# COPY/ALLREDUCE/MULTI
(UPat(Ops.COPY, name="copy", src=(UPat.var("x"), UPat(Ops.DEVICE)), arg=None), lambda copy,x: copy.dtype == x.dtype),
(UPat(Ops.ALLREDUCE, name="red", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda red,x: red.dtype == x.dtype and isinstance(red.arg, Ops)),
(UPat(Ops.MULTI, name="multi"), lambda multi: all(x.dtype == multi.dtype for x in multi.src) and isinstance(multi.arg, int)),
])
# ***** uop type spec *****
def validate_index(idx:UOp, gate:UOp=UOp.const(dtypes.bool, True)):
# TODO: check for overflow
if IGNORE_OOB or isinstance(idx.dtype, ImageDType) or (sz := idx.src[0].ptrdtype.size) == -1: return True
# We can use UOp min/max to do a faster check, but it can give false positive since its not an exact bound and doesn't consider the mask
if 0<=idx.src[1].vmin and idx.src[1].vmax<sz: return True
mask = idx.src[2]&gate if len(idx.src)==3 else gate
# WEBGPU has a BITCAST in the index. TODO: fix
if any(x.op is Ops.BITCAST for x in idx.toposort()): return True
if not z3_imported: raise ImportError("z3 >= 4.12.4 is required for bounds checking, try IGNORE_OOB=0 or \"pip install 'z3-solver>=4.12.4\"")
solver = z3.Solver(ctx=z3.Context())
z3_idx, z3_mask = uops_to_z3(solver, idx.src[1], mask)
solver.add(z3_mask)
with cpu_profile("validate index with z3", "TINY"):
if solver.check((z3_idx<0)|(sz<=z3_idx)) == z3.sat:
print(f"idx={idx.src[1].render(simplify=False)}")
print(f"mask & gate={mask.render(simplify=False)}")
print(f"# OUT OF BOUNDS ACCESS: at {solver.model()} INDEX not in 0 - {sz}\nconstraints = {solver}")
return False
return True
def validate_store(idx:UOp, val:UOp, gate:UOp=UOp.const(dtypes.bool, True)):
if gate.op is Ops.IF: gate = gate.src[0]
# we need to find the implicit gates, inverse of delete_redundant_gates
for u in val.toposort():
if u.op is Ops.IF: gate &= u.src[0]
return validate_index(idx, gate)
index_pat = UPat(Ops.INDEX, name="idx").or_casted()
# this is the matcher for the final rendered UOps
# matcher functions returns True or False (or None to not match)
spec = PatternMatcher([
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and x.dtype.addrspace == AddrSpace.GLOBAL),
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.addrspace == AddrSpace.LOCAL),
(UPat(Ops.DEFINE_REG, src=()), lambda: True),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
(UPat(Ops.RANGE, src=(UPat.var("x"),), name="rng"), lambda rng,x: rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) == 2 and \
isinstance(rng.arg[0], int) and isinstance(rng.arg[1], AxisType)),
(UPat(Ops.SPECIAL, src=(UPat.var("x"),), name="s"), lambda s,x: s.dtype == x.dtype == dtypes.int32 and isinstance(s.arg, str)),
(UPat(Ops.VIEW, dtypes.void, src=(), name="x"), lambda x: isinstance(x.arg, ShapeTracker)),
(UPat(Ops.VIEW, src=(UPat.var("src"),), name="x"),
lambda x,src: isinstance(x.arg, ShapeTracker) and src.op is not Ops.STORE and x.dtype.base == src.dtype.base),
(UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True),
(UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
# early LOAD has a <bufview, store?>
(UPat(Ops.LOAD, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.Defines),)),)), lambda: True),
(UPat(Ops.LOAD, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.Defines),)), UPat(Ops.STORE))), lambda: True),
# early STORE has a <bufview, val>
(UPat(Ops.STORE, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.Defines),)), UPat())), lambda: True),
# **** new style load/store ****
# make sure all index dtypes have been lowered
(UPat(GroupOp.All, dtype=dtypes.index), lambda: False),
(UPat(Ops.CONST, arg=Invalid), lambda: False),
(UPat(Ops.VCONST, name="x"), lambda x: all(v is not Invalid for v in x.src)),
# INDEX is used in new style load/store
# INDEX takes a <buf, alu, gate?>
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines), UPat())), lambda: True),
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
# LOAD on STORE
(UPat(Ops.LOAD, src=(UPat(Ops.STORE),), allow_any_len=True), lambda: True),
# LOAD takes a <bufidx, alt?, barrier?>
(UPat(Ops.LOAD, src=(index_pat, UPat(Ops.IF, name="cond")), allow_any_len=True), lambda idx,cond: validate_index(idx,cond.src[0])),
(UPat(Ops.LOAD, src=(index_pat,), allow_any_len=True), validate_index),
# STORE takes a <bufidx, val, gate?>
(UPat(Ops.STORE, src=(index_pat, UPat(name="val"), UPat(Ops.IF, name="gate")), allow_any_len=True), validate_store),
(UPat(Ops.STORE, src=(index_pat, UPat(name="val")), allow_any_len=True), validate_store),
# most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
(UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat.var("x"), UPat.var("y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype),
(UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ), dtype=dtypes.bool, src=(UPat.var("x"), UPat.var("y"))), lambda x,y: x.dtype.base == y.dtype.base),
# and SHL/SHR, the shift distance can be an int
(UPat((Ops.SHL, Ops.SHR), src=(UPat.var("x"), UPat.var("y")), name="a"), lambda a,x,y: a.dtype == x.dtype and y.dtype in (x.dtype, dtypes.uint)),
(UPat((Ops.IDIV, Ops.MOD), name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
(UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype.base == y.dtype.base for y in x.src)),
(UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True),
# WMMA has a <a, b, acc>
(UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8),
(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
(UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
# if has a <gate, barrier?>
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),)), lambda: True),
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(), UPat(Ops.BARRIER))), lambda: True),
(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
(UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) >= 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
(UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
(UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
(UPat((Ops.BITCAST, Ops.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None),
(UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local
(UPat(Ops.BARRIER, dtypes.void), lambda: True), # BARRIERs can also happen at the end of loops
# NOTE: for testing, we let sinks be anything
#(UPat(Ops.SINK, src=UPat(Ops.STORE)), lambda: True),
(UPat(Ops.SINK, dtypes.void), lambda: True),
(UPat((Ops.NOOP, Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True),
# PTX LOAD/STORE
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
])
# *** this is the UOp AST spec ***
ast_spec = PatternMatcher([
# VIEW can only exist in the edges
(UPat(Ops.VIEW, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL),))), lambda: True),
(UPat(Ops.VIEW, name="view"), lambda view: len(view.src) == 0),
# all parent UOps must have the same shape
(UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: all_same([x.shape for x in root.src if x.st is not None])),
])
# ***** uop helpers *****
def type_verify(uops:list[UOp], extra_spec:PatternMatcher|None=None):
check_spec = (extra_spec+spec) if extra_spec is not None else spec
for i,u in enumerate(uops):
with Context(TRACK_MATCH_STATS=0): ret = check_spec.rewrite(u)
if cast(bool|None, ret) is not True:
if DEBUG >= 3: print_uops(uops)
raise RuntimeError(f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[(x.op, x.dtype, x.arg) for x in u.src]} {u.arg}")

528
tinygrad/uop/symbolic.py Normal file
View File

@@ -0,0 +1,528 @@
# all of symbolic lives here now
from typing import cast
import math, operator, struct, functools
from collections import defaultdict
from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
from tinygrad.dtype import ConstType, dtypes, PtrDType, AddrSpace, can_safe_cast, Invalid
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, cdiv, cmod, CORRECT_DIVMOD_FOLDING, unwrap
from tinygrad.uop.decompositions import xpow
# ******** phase 1 of symbolic used to live in ops, it's the most generic folding rules ********
def simplify_pow(x:UOp, c:UOp) -> UOp|None:
if c.arg < 0: return x.reciprocal().pow(-c)
if c.arg == 0: return x.const_like(1)
if int(c.arg-0.5)+0.5 == c.arg: return x.pow(c.const_like(c.arg-0.5)) * x.sqrt()
if int(c.arg) == c.arg: return (y := x.pow(c.const_like(c.arg//2))) * y * (x if c.arg%2 == 1 else 1)
return None
def fold_bitcast(root:UOp, c:UOp) -> UOp|None:
if (from_fmt:=c.dtype.scalar().fmt) is None or (to_fmt:=root.dtype.scalar().fmt) is None: return None
if c.dtype.itemsize != root.dtype.itemsize: return None
def convert(v:ConstType): return struct.unpack(to_fmt, struct.pack(from_fmt, v))[0]
return root.const_like(convert(c.arg) if root.dtype.count == 1 else tuple(map(convert, c.arg)))
invalid_pat = UPat.const(dtypes.index, Invalid).named("i")
invalid_gate = UPat.var("cond").where(UPat.var("x",dtype=dtypes.index), invalid_pat)
propagate_invalid = PatternMatcher([
# this needs to be before symbolic so that 0*something_that_might_be_invalid doesnt become 0
# propagate invalid, push it past children
*((invalid_gate.alu(op, UPat.var("y")).named("alu"), lambda cond,x,y,alu,i: cond.where(x.alu(alu.op,y), i))
for op in GroupOp.Binary-GroupOp.Comparison),
*((invalid_gate.alu(op, UPat.var("y")).named("alu"), lambda cond,x,y,alu,i: x.alu(alu.op,y)) for op in GroupOp.Comparison),
# invalid + y -> y same for other ops
*((invalid_pat.alu(op, UPat(dtype=dtypes.index)).named("alu"), lambda alu,i: i) for op in GroupOp.Binary-GroupOp.Comparison),
# i < y -> a_bool_value_that_will_never_be_used: we choose a random bool const
*((invalid_pat.alu(op, UPat(dtype=dtypes.index)), lambda i: UOp.const(dtypes.bool, True)) for op in GroupOp.Comparison),
# a.where(b.where(c, d), d) -> (a & b).where(c, d)
(UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)),
# order of gate&!cond matters!, and-clauses are only simplified left to right and we need to gate to be used to fold cond
(UPat.var("gate").where(invalid_gate, UPat.var("y")), lambda gate,cond,x,y,i: ((gate&cond.logical_not()).logical_not()).where(gate.where(x,y), i)),
# unswap the branches for the rule above
(UPat.var("gate").where(UPat.var("y"), invalid_gate).named("where"), lambda gate,cond,x,y,i: gate.logical_not().where(cond.where(x,i), y))
])
symbolic_simple = propagate_invalid + PatternMatcher([
# ** self folding **
(UPat.var("x") + 0, lambda x: x), # x+0 -> x
(UPat.var("x") * 1, lambda x: x), # x*1 -> x
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)) ^ 0, lambda x: x), # x^0 -> x
(UPat.var("x") // UPat.var("x"), lambda x: x.const_like(1)), # x//x -> 1
(UPat.var("x") // 1, lambda x: x), # x//1 -> x
(UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x
(UPat.var("x") / UPat.var("x"), lambda x: x.const_like(1)), # x/x -> 1
((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
((UPat.var() % UPat.var("y")).named("base") % UPat.var("y"), lambda base,y: base), # (x%y)%y = -> x%y (rewritten with base for speed)
# 4 variations of (x%c)+(x//c)*c = x TODO: add sorting to remove some variations
(UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x
((UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"),
lambda x,c1,c2,c3: x*c2 if c1.arg*c2.arg==c3.arg else None), # (x%c1)*c2+(x//c1)*c3 = x*c2 if c1*c2==c3
((UPat.var("y")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"))+UPat.var("x")%UPat.cvar("c"), lambda y,x,c: y+x),
((UPat.var("y")+UPat.var("x")%UPat.cvar("c"))+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda y,x,c: y+x),
((UPat.var("y")+(UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3"))+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"),
lambda y,x,c1,c2,c3: y+x*c2 if c1.arg*c2.arg==c3.arg else None),
((UPat.var("y")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"))+(UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3"),
lambda y,x,c1,c2,c3: y+x*c2 if c1.arg*c2.arg==c3.arg else None),
(UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
(UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
(UPat(GroupOp.Idempotent, src=(UPat.var("x"), UPat.var("x"))), lambda x: x),
(UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
(UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x),
(UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, False), UPat.const(dtypes.bool, True)), lambda x: x.logical_not()),
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)).trunc(), lambda x: x),
# ** zero folding **
(UPat.var("x") < UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x < x -> False
(UPat.var("x") % UPat.var("x"), lambda x: x.const_like(0)), # x%x -> 0
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)) != UPat.var("x"),
lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
# x*0 -> 0 or 0*x -> 0
# if x is nan or inf it should render the nan value.
# NOTE: this can be wrong for loaded NaN
(UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
# ** constant folding **
# TODO: add const folding for Ops.THREEFRY
(UPat(GroupOp.Unary, src=(UPat((Ops.VCONST, Ops.CONST)),), name="a"), lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg], False))),
(UPat(GroupOp.Binary-{Ops.THREEFRY}, src=(UPat((Ops.VCONST, Ops.CONST)),)*2, name="a"),
lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg, a.src[1].arg], False))),
(UPat(GroupOp.Ternary, src=(UPat((Ops.VCONST, Ops.CONST)),)*3, name="a"),
lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg, a.src[1].arg, a.src[2].arg], False))),
# bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
(UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
(UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
(UPat.var('x', dtype=dtypes.bool).maximum(UPat.var('y', dtype=dtypes.bool)), lambda x,y: x|y),
# *** cast/bitcast ***
(UPat(Ops.CAST, name="root", src=(UPat.cvar("c"),)), lambda root, c: root.const_like(c.arg)),
(UPat((Ops.CAST, Ops.BITCAST), name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
(UPat(Ops.BITCAST, name="root", src=(UPat.cvar("c"),)), fold_bitcast),
# b.cast(a).cast(b) -> b if a preserves all values in b
(UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x if x.dtype == b.dtype and can_safe_cast(b.dtype, a.dtype) else None),
# ** pow **
(UPat.var("x").alu(Ops.POW, UPat.cvar("c", vec=False)), simplify_pow),
# positive const ** x
(UPat.cvar("c", vec=False).alu(Ops.POW, UPat.var("x")), lambda c,x: c if c.arg == 1 else (x*math.log2(c.arg)).exp2() if c.arg > 0 else None),
# rules for threefry
((UPat.var('x', dtypes.uint64)&0xFFFFFFFF).cast(dtypes.uint32), lambda x: x.cast(dtypes.uint32)&0xFFFFFFFF), # TODO: why is the and needed?
(((UPat.var(None, dtypes.uint64)*(1<<32)) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
(((UPat.var('x', dtypes.uint64)*(1<<32)) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
# hacks for threefry long removal when padded (TODO: genericize)
(UPat.var('x', dtypes.uint32).cast(dtypes.uint64) * UPat.var('y').where(UPat.const(dtypes.uint64, 1<<32), UPat.const(dtypes.uint64, 0)),
lambda x,y: y.where(x, 0).cast(dtypes.uint64) * (1<<32)),
((UPat.var('x', dtypes.uint64)&(UPat.var('y').where(UPat.const(dtypes.uint64, 0xFFFFFFFF), UPat.const(dtypes.uint64, 0)))).cast(dtypes.uint32),
lambda x,y: y.where(x.cast(dtypes.uint32), 0)),
# new decomp rules for threefry
(((UPat.var(None, dtypes.uint64)<<32) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
(((UPat.var('x', dtypes.uint64)<<32) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))>>32, lambda x: x),
(UPat.var('b').where(UPat.var('x', dtypes.uint32).cast(dtypes.uint64), UPat.const(dtypes.uint64, 0)).cast(dtypes.uint32), lambda b,x: b.where(x,0))
])
# ******** phase 2 builds on phase 1, it includes the old "symbolic", rules that match deeper ********
def lt_folding(x:UOp, c:int) -> UOp|None:
p, np = partition(x.split_uop(Ops.ADD), lambda u: u.const_factor() == 1)
if np and (d:=math.gcd(*[u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d:
return cast(UOp, functools.reduce(operator.add, np).divides(d))<(c//d)
return None
def canonicalize_simplex(X:UOp) -> UOp|None:
# (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints.
# returns x0 + x1 + ... in such case, or None if not
changed, ret = False, []
for u in X.split_uop(Ops.ADD):
# assumed the const is the last src of MUL
if u.op is Ops.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0:
changed = True
u = u.src[0]
if not (u.op in GroupOp.Irreducible and u.vmin >= 0): return None
ret.append(u)
return functools.reduce(operator.add, ret) if changed else None
def cancel_divmod(d: UOp, x: UOp, y: UOp) -> UOp|None:
# simple cancel div/mod case when the range of the numerator lies within a single denominator interval
x_min, x_max, y_min, y_max = x.vmin, x.vmax, y.vmin, y.vmax
assert isinstance(x_min, int) and isinstance(x_max, int) and isinstance(y_min, int) and isinstance(y_max, int)
if y_min==y_max==0: raise ZeroDivisionError(f"{'Division' if d.op is Ops.IDIV else 'Mod'} by zero trying to rewrite {x.alu(d.op, y)}")
if y_min*y_max > 0 and (q:=cdiv(x_min,y_min)) == cdiv(x_min,y_max) == cdiv(x_max,y_min) == cdiv(x_max,y_max):
return x - q*y if d.op is Ops.MOD else d.const_like(q)
return None
def remove_nested_mod(m: UOp, x: UOp, y: UOp) -> UOp|None:
# remove nested mod in case the inner mod is a multiple of the outer mod
# example: (a%4 + b)%2 -> (a+b)%2
if ((c := y.arg) < 0) or x.vmin<0: return None
new_xs = []
something_changed = False
for u in x.split_uop(Ops.ADD):
if u.op is Ops.MOD:
if u.src[1].divides(c) is not None:
something_changed = True
u = u.src[0]
new_xs.append(u)
new_x: UOp = functools.reduce(operator.add, new_xs)
if something_changed and new_x.vmin>=0: return new_x % y
return None
def fold_binary_numerator(d: UOp, x: UOp, y: UOp) -> UOp|None:
# we can fold if the expression has only one non-constant term and this term can only take on two values
if ((c := y.arg) < 0): return None
x,const = x.pop_const()
terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in x.split_uop(Ops.ADD)])
if len(terms)==1 and (v:=terms[0]).vmax-v.vmin == 1:
y1 = cmod(factors[0]*v.vmin+const, c) if d.op is Ops.MOD else cdiv(factors[0]*v.vmin+const, c) # type: ignore
y2 = cmod(factors[0]*v.vmax+const, c) if d.op is Ops.MOD else cdiv(factors[0]*v.vmax+const, c) # type: ignore
return (y2-y1)*(v-v.vmin) + y1
return None
def fold_divmod_congruence(d: UOp, x: UOp, y: UOp) -> UOp|None:
# within a mod we can freely subtract multiples of c, we use this to see if a is congruent to an expression whose vmin/vmax are between 0 and c
if (x.vmin<0 and CORRECT_DIVMOD_FOLDING) or ((c := y.arg) < 0): return None
x,const = x.pop_const()
terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in x.split_uop(Ops.ADD)])
# a//c = (a-a%c)/c, if we can fold a%c, we can fold a//c
rems = [min((r:=f%c), r-c, key=abs) for f in factors]
if (rem:=sum(r*v for r,v in zip(rems,terms))+const%c).vmin//c!=rem.vmax//c: return None
if d.op is Ops.MOD: return rem - rem.vmin//c*c
return sum((f-r)//c * v for f,r,v in zip(factors,rems,terms)) + (const-const%c+rem.vmin//c*c)//c
def divide_by_gcd(d: UOp, x: UOp, y: UOp) -> UOp|None:
# x//y -> (x//gcd)//(y//gcd) or x%y -> gcd*(x//gcd)%(y//gcd)
gcd = UOp.gcd(*x.split_uop(Ops.ADD), y).simplify()
if gcd.op is Ops.CONST and gcd.arg==1: return None
ret = unwrap(x.divide_exact(gcd)).alu(d.op, unwrap(y.divide_exact(gcd)))
return ret*gcd if d.op is Ops.MOD else ret
def gcd_with_remainder(d: UOp, x: UOp, y: UOp):
# (gcd*x+r)//(gcd*d) -> (x+(r%d)//gcd)//d + r//(gcd*d)
# (gcd*x+r)%(gcd*d) -> gcd*(x+(r%d)//gcd)%d + r%gcd
# These only work for floordiv (and the corresponding remainder)! Thats why we check the sign of x,y and new_x
if ((c := y.arg) < 0) or x.vmin<0: return None
x_no_const, const = x.pop_const()
gcd = UOp.gcd(*x_no_const.split_uop(Ops.ADD), y).simplify()
assert gcd.op is Ops.CONST
if gcd.arg==1: return None
new_x = unwrap(x_no_const.divide_exact(gcd)).simplify() + (const%c)//gcd
if new_x.vmin<0: return None
ret = new_x.alu(d.op, x.ufix(c//gcd.arg))
return ret*gcd + const%gcd.arg if d.op is Ops.MOD else ret+const//c
def nest_div_by_smallest_factor(d: UOp, x: UOp, y: UOp) -> UOp|None:
# we try and nest the div and see if it allows the numerator to be simplified
if ((c := y.arg) < 0): return None
factors = [u.const_factor() for u in x.pop_const()[0].split_uop(Ops.ADD)]
# div is the smallest factor of the denominator (greater than 1) out of all "factors"
# TODO: there are better ways to pick `div`, this sometimes adds extra divisions
# TODO: add same optimization for mod
div = min([y.arg]+[abs(f) for f in factors if abs(f) > 1 and (c%f)==0])
if (1 < div < c) and (newxs:=(newx:=(x//div)).simplify()) is not newx and x.vmin>=0 and newx.vmin>=0: return newxs//(c//div)
return None
def factor_remainder(d: UOp, x: UOp, y: UOp) -> UOp|None:
# (d*x+y)//d -> x+y//d or (d*x+y)%d
# for mod we go further and take the remainder of all factors to reduce their size
# These only work for floordiv (and the corresponding remainder)! Thats why we check the sign of x,y and new_x
if y.vmin<0 or x.vmin<0: return None
quo, rem = [], []
for u in x.split_uop(Ops.ADD):
if (q:=u.divide_exact(y)) is not None: quo.append(q)
# if this is mod and y is a const, we can make the remainder factor sm
elif d.op is Ops.MOD and y.op is Ops.CONST and (c:=u.const_factor())%y.arg!=c:
rem.append(u.divides(c)*(c%y.arg))
quo.append(u.const_like(0)) # we append this so we can check if something changed
else: rem.append(u)
new_x = sum(rem)+x.const_like(0)
if len(quo)==0 or new_x.vmin<0: return None
return new_x%y if d.op is Ops.MOD else new_x//y+sum(quo)
def gep_through_wmma(gep:UOp, wmma:UOp):
out_sz = prod(x[1] for x in wmma.arg[6][-1])
wmma_idxs = gep.arg[::out_sz]
for i in range(out_sz):
if tuple(x-i for x in gep.arg[i::out_sz]) != wmma_idxs: return None
tsrcs = []
for s,sz in zip(wmma.src, wmma.arg[6]):
src_args = []
ssz = prod(x[1] for x in sz)
for w in wmma_idxs: src_args += list(range((w//out_sz)*ssz, (w//out_sz)*ssz + ssz))
tsrcs.append(s.gep(tuple(src_args)))
return UOp(Ops.WMMA, gep.dtype, tuple(tsrcs), wmma.arg)
gep_pushing = PatternMatcher([
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
(UPat(Ops.GEP, name='g2').f(Ops.GEP, name='g1'),
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(len(g1.arg))))),
(UPat(Ops.VECTORIZE, name='vec').f(Ops.GEP, name='gep'),
lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
(UPat.cvar("c", vec=False).f(Ops.GEP, name="gep"), lambda gep, c: gep.const_like(c.arg)),
(UPat(Ops.VCONST, name="c").f(Ops.GEP, name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
# GEP on void is skipped
(UPat(Ops.GEP, src=(UPat(dtype=dtypes.void, name="x"),)), lambda x: x),
# GEP in order is removed
(UPat(Ops.GEP, name="g"), lambda g: g.src[0] if not isinstance(g.dtype, PtrDType) and g.arg == tuple(range(g.src[0].dtype.count)) else None),
# push all GEPs through ALUs (fix arange stuff)
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu').f(Ops.GEP, name='gep'),
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
if not isinstance(gep.dtype, PtrDType) else None),
# CAT can't be rendered. it's a VECTORIZE on vectors, we expand to a single VECTORIZEs with GEPs (TODO: move this later)
(UPat(Ops.CAT, name="x"), lambda x: UOp(Ops.VECTORIZE, x.dtype, tuple(y.gep(i) for y in x.src for i in range(y.dtype.count))) \
if not isinstance(x.dtype, PtrDType) else None),
# VECTORIZE on same GEP
(UPat(Ops.VECTORIZE, name="v", src=UPat(Ops.GEP, src=(UPat.var("x"),))), lambda v,x: x.gep(tuple(get_single_element(i.arg) for i in v.src))),
# push some GEPs through WMMAs
(UPat(Ops.WMMA, name="wmma").f(Ops.GEP, name="gep"), gep_through_wmma),
])
commutative = PatternMatcher([
# ** COMMUTATIVE flipping (only for index) **
# NOTE: this can break merging vector math by only flipping some of them
(UPat(GroupOp.Commutative, dtype=dtypes.index, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
])
symbolic = symbolic_simple+commutative+PatternMatcher([
# ** boolean algebra **
(UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
# TODO: make a more general or folder like simplify_valid
(UPat.var("x", dtype=dtypes.bool) | UPat.var("x").logical_not(), lambda x: x.const_like(True)), # x|!x -> True
# ** combine terms **
(UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
((UPat.var("y") + UPat.var("x") * UPat.cvar("c0")) + UPat.var("x") * UPat.cvar("c1"), lambda x,y,c0,c1: y+x*(c0+c1)),
(UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
((UPat.var("y") + UPat.var("x")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: y+x*(c+1)),
((UPat.var("y") + UPat.var("x") * UPat.cvar("c")) + UPat.var("x"), lambda x,y,c: y+x*(c+1)),
(UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2
((UPat.var("y") + UPat.var("x")) + UPat.var("x"), lambda y,x: y+x*2),
((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3) if x2 is not x3 else None), # (x/x2)/x3 -> x/(x2*x3)
(-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c
# a conditional with the same results either way is a noop, also fold const conditionals
(UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
(UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
(UPat.var("cond", dtype=dtypes.bool).logical_not().where(UPat.var("t"), UPat.var("f")), lambda cond, t, f: cond.where(f,t)
if f.arg is not Invalid else None),
# alu of two where with same conds can combine, only do if true branch or false branch is const
(UPat(GroupOp.Binary, name="alu", src=(UPat.var("c").where(UPat.var("t"), UPat.var("f")), UPat.var("c").where(UPat.var("tt"), UPat.var("ff")))), \
lambda alu,c,t,tt,f,ff: c.where(t.alu(alu.op, tt), f.alu(alu.op, ff)) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
# if its a plus we add the associative variation too
((UPat.var("y")+UPat.var("c").where(UPat.var("t"), UPat.var("f"))) + UPat.var("c").where(UPat.var("tt"), UPat.var("ff")), \
lambda y,c,t,tt,f,ff: y+c.where(t+tt, f+ff) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
# ALU/variable min==max -> CONST (slow!)
(UPat(GroupOp.ALU|{Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
# max folding
(UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
# TODO: why does this rule break beautiful_mnist?
#((UPat.var("x")+UPat.var("z")).maximum(UPat.var("y")+UPat.var("z")), lambda x,y,z: x.maximum(y) + z),
#((UPat.var("x")*UPat.cvar("c1")).maximum(UPat.var("x")*UPat.cvar("c2")), max_var_const),
# ** two stage ALU folding **
*((UPat.var("x").alu(op, UPat.cvar("c1")).alu(op, UPat.cvar("c2")).named("f"),
lambda f,x,c1,c2: x.alu(f.op,c1.alu(f.op,c2))) for op in GroupOp.Associative),
((UPat.cvar("c0") + UPat.var("x")) < UPat.cvar("c1"), lambda x,c0,c1: x<(c1-c0)), # c0 + x < c1 -> x < c1 - c0
((UPat.var("x") // UPat.cvar("c1")) // UPat.cvar("c2"), lambda x,c1,c2: x//(c1*c2)), # (x//c1)//c2 -> x//(c1*c2)
# ** lt **
# c0*x<c1 for positive int c0,c1
((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.index))<UPat.cvar("c1", vec=False),
lambda x,c0,c1: x<math.ceil(c1.arg/c0.arg) if c0.arg > 0 and c1.arg > 0 else None),
# c0*x<c1 for negative int c0 and non-positive c1
((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.index))<UPat.cvar("c1", vec=False),
lambda x,c0,c1: (-x)<(-(math.floor(-c1.arg/-c0.arg))) if c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
# x//d<c
((UPat.var("x", dtype=dtypes.index)//UPat.cvar("d", vec=False))<UPat.cvar("c", vec=False),
lambda x,d,c: (x<(c.arg*d.arg) if c.arg > 0 else x<(c.arg*d.arg-(d.arg-1))) if d.arg > 0 else None),
# ** move add/mul consts to end (NOTE: this is still happening before constant folding) **
((UPat.var("x") + UPat.cvar("c1")) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
((UPat.var("x") * UPat.cvar("c1")) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
# *** rules from symbolic ***
# generic lt folding
(UPat.var("x", dtypes.index)<UPat.cvar("c", vec=False), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
(UPat.var("x", dtypes.index)*-1 < UPat.var("y")*-1, lambda x,y: y<x),
# canonicalize a simplex with positive coefficients > 0
# not x < 1 -> X > 0
((UPat.var("x", dtypes.index)<1).ne(True), lambda x: (newx<1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None),
# ** div **
# div folding
((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d)
if c.vmin>0 and d.vmin>0 and ((x.vmin>=0 and a.vmin>=0) or (x.vmax<=0 and a.vmax<=0)) else None), # (x//c+a)//d -> (x+a*c)//(c*d)
# a range mod its own upper bound is just the range
(UPat(Ops.RANGE, src=UPat.var("end"), name="r")%UPat.var("end"), lambda r,end: r),
(UPat(Ops.RANGE, src=UPat.var("end"), name="r")//UPat.var("end"), lambda r,end: r.const_like(0)),
(UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.var("y"))), cancel_divmod),
(UPat.var("x", dtypes.index) // UPat.var("d"), lambda x,d: -(x//(-d)) if d.vmax < 0 else None),
(UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_binary_numerator),
(UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_divmod_congruence),
(UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.var("y"))), divide_by_gcd),
(UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), gcd_with_remainder),
(UPat(Ops.MOD, dtypes.index, name="m", src=(UPat.var("x"), UPat.cvar("y", vec=False))), remove_nested_mod),
(UPat((Ops.IDIV), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), nest_div_by_smallest_factor),
(UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.var("y"))), factor_remainder),
(UPat.var("x", dtypes.index) // UPat.var("d"), lambda x,d: -((-x)//d) if x.vmax<=0 else None),
((UPat.var("x", dtypes.index)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False),
lambda x,c,n,d: ((x+c.arg%d.arg)//d + c.arg//d.arg) if c.arg%d.arg!=c.arg and x.vmin>=0 and n.vmin>=0 and d.arg>0 else None),
((UPat.var("x", dtypes.index)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False),
lambda x,c,n,d: (-(-(c.arg%d.arg + x - (d.arg-1))//d) + c.arg//d.arg) if x.vmax<=0 and n.vmin>=0 and d.arg>0 else None),
# ** mod **
# mod folding
(UPat.var("x", dtypes.index) % UPat.var("d"), lambda x,d: -((-x)%d) if x.vmax <= 0 else None),
(UPat.var("x", dtypes.index) % UPat.var("d"), lambda x,d: (x%(-d)) if d.vmax < 0 else None),
# cast/long folding
# if the intermediate cast doesnt narrow we can do it in one cast
(UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x.cast(b.dtype) if can_safe_cast(x.dtype, a.dtype) else None),
(UPat.var('x', dtypes.ints+(dtypes.index,)).cast(dtypes.ints+(dtypes.index,), name="a").cast(name="b"),
lambda x,a,b: x.cast(b.dtype) if a.dtype.min<=x.vmin and x.vmax<=a.dtype.max else None),
# try to do math in int instead of long
(UPat(GroupOp.Binary, src=(UPat.var("x", dtypes.long), UPat.var("y", dtypes.long)), name="u"), lambda u,x,y:
x.cast(dtypes.int).alu(u.op, y.cast(dtypes.int)).cast(u.dtype) if not any(v.overflows(dtypes.int) for v in (u,x,y)) else None),
((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
])+gep_pushing
symbolic_flat = symbolic+PatternMatcher([
# ** combine terms (opinionated) **
(-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
# (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
((UPat.var("x", dtypes.index) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
])
# ******** we take a small aside to "simplify_valid" to rewrite valids ********
def parse_valid(valid:UOp) -> tuple[UOp, bool, int]:
# if it's X <= c, returns X, True, c
# if it's X >= c, returns X, False, c
# (X < c).ne(True) -> X >= c
if valid.op is Ops.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \
(s0:=valid.src[0]).op is Ops.CMPLT and dtypes.is_int(s0.src[0].dtype): return s0.src[0], False, int(s0.src[1].vmin)
# X < c -> X <= c-1
if valid.op is Ops.CMPLT and dtypes.is_int(valid.src[0].dtype): return valid.src[0], True, int((valid.src[1]).vmax)-1
raise ValueError(f"not able to parse {valid=}")
def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
# return None if valid is always False, otherwise the simplified uop (might be the same as input)
# first, parse valid into {expr: (lower_bound, upper_bound)}
bounds:defaultdict[UOp, list[ConstType|None]] = defaultdict(lambda: [None, None])
for stmt in valid.split_uop(Ops.AND):
try: expr, is_upper, c = parse_valid(stmt)
except ValueError: continue # give up if we cannot parse the valid
bounds[expr][int(is_upper)] = c
# don't simplify any other gates, can lead to OOB, we substitute them back later
uop = uop.substitute((load_subs:={u: UOp(Ops.NOOP, arg=u) for u in uop.toposort() if u.op is Ops.INDEX}))
# simplify uop given that valid is True
for expr,v in bounds.items():
v0, v1 = (expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1])
expr = expr.substitute(load_subs) # make sure expr appears in same form in the uop
# some expr has lower bound > upper bound -> valid is an empty set and we return None
if v0 > v1: return None
# whole node became a const
if v0 == v1:
uop = uop.substitute({expr:expr.const_like(v0)}).simplify()
continue
# every candidate is a set of constrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
candidates = []
if expr.op is Ops.ADD and v0 == 1 and all(u.op in GroupOp.Irreducible for u in expr.split_uop(Ops.ADD)):
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in expr.split_uop(Ops.ADD)])
# try checking the whole clause
if expr in uop.toposort(): candidates.append([(expr, UOp.variable("fake", v0, v1, expr.dtype))])
for candidate in candidates:
# if every branch in candidate gives the same simplified uop, we can rewrite the uop
newuops = [uop.substitute({X:newX}).simplify().substitute({newX:X}).simplify() for X,newX in candidate]
if uop.op is Ops.VECTORIZE and len(uop.src) == 2:
if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1]))
if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1]))
elif all_same(newuops): uop = newuops[0]
# put the loads back in
uop = uop.substitute({v:k for k,v in load_subs.items()})
return uop
def _valid_priority(v: UOp, valids:list[UOp]):
# we want valid that's in other valids' parents to be first, so it's more likely the other valids get simplified
try: return sum(-1 if parse_valid(v)[0] in other.toposort() else 0 for other in valids)
except ValueError: return 0
def simplify_valid(valid:UOp) -> UOp|None:
ret:list[UOp] = []
something_changed = False
valids = list(valid.split_uop(Ops.AND))
for stmt in sorted(valids, key=lambda v: _valid_priority(v, valids)):
# TODO: root cause this and test_simplify_valid_from_div
if stmt.op is Ops.CAST: return None
ret.append(newstmt if ret and (newstmt:=uop_given_valid(functools.reduce(operator.and_, ret), stmt)) is not None else stmt)
if ret[-1] is not stmt: something_changed = True
return functools.reduce(operator.and_, ret) if something_changed else None
# ******** phase 3 is the complete symbolic, and deals with very complex things like loop rewriting and threefry transform ********
def reduce_mul_chain(r:UOp):
if r.arg not in {Ops.ADD, Ops.MAX}: return None
if r.dtype != r.src[0].dtype: return None
inside, outside = [], []
for m in r.src[0].split_uop(Ops.MUL):
m_parents = m.toposort()
if all(r not in m_parents for r in r.src[1:]) and (r.arg != Ops.MAX or m.vmin >= 0): outside.append(m)
else: inside.append(m)
if len(outside) == 0: return None
return r.replace(src=(prod(inside) if len(inside) else r.src[0].const_like(1),)+r.src[1:])*prod(outside)
# this is symbolic 2.0
REMOVE_FROM_SINK = {Ops.SINK, Ops.UNROLL, Ops.PTRCAT, Ops.CAT, Ops.NOOP}
REMOVE_FROM_BARRIER = {Ops.VECTORIZE, Ops.SINK, Ops.CAT, Ops.PTRCAT, Ops.NOOP}
sym = symbolic_flat+PatternMatcher([
# simplify valid
(UPat(Ops.AND, name="valid"), simplify_valid),
(UPat.var("cond").where(UPat.var("x", dtype=dtypes.index), invalid_pat), lambda cond,x,i: cond.where(newx, i) if
(newx:=uop_given_valid(cond, x)) is not x else None),
# LOAD/STORE -> NOOP
(UPat.var('x').store(UPat.var('x').load(), allow_any_len=True), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]),
(UPat(Ops.LOAD, src=(UPat.cvar('c'))), lambda c: c),
# VECTORIZE/CONST, VECTORIZE/GEP
(UPat(Ops.VECTORIZE, src=UPat(Ops.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
(UPat(Ops.VECTORIZE, src=UPat(Ops.GEP, src=(UPat.var("x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
# reorder ALU/VECTORIZE
(UPat(GroupOp.ALU, src=(UPat(Ops.VECTORIZE, src=UPat(name='x')), UPat(Ops.VECTORIZE, src=UPat(name='y'))), name='alu'),
lambda x,y,alu: UOp(Ops.VECTORIZE, alu.dtype, (UOp(alu.op, alu.dtype.scalar(), (x,y)),)*alu.dtype.count)),
# VECTORIZE of a single element is just that element
(UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
# VECTORIZE void is SINK
(UPat(Ops.VECTORIZE, dtype=dtypes.void, src=UPat(Ops.BARRIER, name='b')), lambda b: b),
(UPat(Ops.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(Ops.SINK, dtypes.void, x.src)),
# tensor core with a 0 input is acc
(UPat(Ops.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc),
(UPat(Ops.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc),
# ** self folding **
# x!=0 -> (bool)x
(UPat.var("x")!=0, lambda x: x.cast(dtypes.bool.vec(x.dtype.count))),
# ** where **
# push cast to branches
(UPat.var("s").where(UPat.var("a"), UPat.var("b")).cast().named("cast"), lambda s,a,b,cast: s.where(a.cast(cast.dtype), b.cast(cast.dtype))),
# ** pow **
((UPat(Ops.POW, name="p"), lambda p: xpow(*p.src))),
# ** load/store folding **
(UPat.store(UPat(Ops.INDEX, name="index"), UPat.load(UPat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)),
(UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"),
UPat.load(UPat(Ops.INDEX, name="index"))), allow_any_len=True, name="store"),
lambda index, gate, alt, store: UOp.store(index.src[0].index(gate.where(index.src[1], UOp.invalid())), alt, *store.src[2:])),
# fold gated LOAD/STORE
(UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat.const(dtypes.index, Invalid)).or_casted(),), allow_any_len=True, name="x"),
lambda x: UOp(Ops.NOOP) if x.op is Ops.STORE else x.const_like(0)), # invalid store does nothing. invalid load produces 0
(UPat.var("c").where(UPat(Ops.LOAD, src=(UPat().index(UPat.var("idx"), UPat.var("c")).or_casted(),), allow_any_len=True, name="l"), UPat.var("a")),
lambda c,idx,l,a: l.replace(src=(l.src[0], a)+l.src[1:])),
(UPat.var("c").where(UPat.var("a"), UPat(Ops.LOAD, src=(UPat().index(UPat.var("idx"), UPat.var("c").logical_not()).or_casted(),),
allow_any_len=True, name="l")), lambda c,idx,l,a: l.replace(src=(l.src[0], a)+l.src[1:])),
# remove VECTORIZE from SINK/BARRIER. TODO: SINK/BARRIER are really the same thing at GLOBAL/LOCAL levels
(UPat(Ops.BARRIER, name="root"),
lambda root: UOp(Ops.BARRIER, root.dtype, tuple(flatten(x.src if x.op in REMOVE_FROM_BARRIER else (x,) for x in root.src)), root.arg)
if any(x.op in REMOVE_FROM_BARRIER for x in root.src) else None),
(UPat(Ops.SINK, name="root"),
lambda root: UOp(Ops.SINK, root.dtype, tuple(flatten(x.src if x.op in REMOVE_FROM_SINK else (x,) for x in root.src)), root.arg)
if any(x.op in REMOVE_FROM_SINK for x in root.src) else None),
((UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()), # 1/(x^c) -> (1/x)^c
((UPat.var("x") * UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()*x.reciprocal()),
((UPat.var("x") * UPat.cvar("c")).reciprocal(), lambda x,c: x.reciprocal()*c.reciprocal()), # 1/(x*c) -> (1/c)*(1/x)
(UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")), lambda x,d: 1-d), # x*/(1+x) -> 1-1/(1+x)
(UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")*UPat.var("y")), lambda x,y,d: y*(1-d)),
(UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")+UPat.var("y")), lambda x,y,d: (1-d)+x*y),
# move const multiply after REDUCE (NOTE: the mul chain can do this, but only if it's a same dtype reduce)
((UPat.var("x")*UPat.cvar("c", vec=False)).reduce(arg=Ops.ADD, name="r", allow_any_len=True), lambda x,c,r: r.replace(src=(x,)+r.src[1:])*c.arg),
# reduce mul chain, move muls after the reduce
(UPat(Ops.MUL).reduce(name="r", allow_any_len=True), reduce_mul_chain),
])

163
tinygrad/uop/upat.py Normal file
View File

@@ -0,0 +1,163 @@
from typing import Any, Callable
import itertools, inspect, functools, types
from tinygrad.helpers import partition, dedup, Context
from tinygrad.uop.ops import UPat, UPatAny, UOp, Ops, PatternMatcher, graph_rewrite, deconstruct_function
class UPatCompileError(Exception): pass
# **** UPat compiled ****
def _get_clause(self:UPat, base:UOp, depth=0) -> UOp:
if isinstance(self, UPatAny):
assert len(self.src) == 1
return UOp(Ops.AND, src=(UOp(Ops.OR, src=tuple(_get_clause(s, base, depth) for s in self.src[0])),))
# build the and_clause for acceptance
and_clause:list[UOp] = []
if self.op is not None:
if len(self.op) > 1: and_clause.append(UOp(Ops.CUSTOM, src=(base, UOp(Ops.BIND, arg=tuple(int(x) for x in self.op))), arg="{0}.op in {1}"))
else: and_clause.append(UOp(Ops.CUSTOM, src=(base,), arg="{0}.op == "+str(self.op[0].value)))
if self.arg is not None:
if isinstance(self.arg, int): and_clause.append(UOp(Ops.CUSTOM, src=(base,), arg="{0}.arg == "+str(int(self.arg))))
else: and_clause.append(UOp(Ops.CUSTOM, src=(base, UOp(Ops.BIND, arg=self.arg)), arg="{0}.arg == {1}"))
if self.strict_length or self.required_len > 0:
and_clause.append(UOp(Ops.CUSTOM, src=(base,), arg=("len({0}.src)"+(" == " if self.strict_length else " >= ")+str(self.required_len))))
if self.name is not None: and_clause.append(UOp(Ops.STORE, src=(UOp(Ops.DEFINE_VAR, arg=self.name), base)))
if self.dtype is not None:
if len(self.dtype) > 1:
and_clause.append(UOp(Ops.CUSTOM, src=(base, UOp(Ops.BIND, arg=tuple(self.dtype))), arg="({0}.dtype in {1} or {0}.dtype._scalar in {1})"))
else: and_clause.append(UOp(Ops.CUSTOM, src=(base, UOp(Ops.BIND, arg=self.dtype[0])), arg="({0}.dtype == {1} or {0}.dtype._scalar == {1})"))
if self.src is not None:
# single match
if len(self.src) == 1 and isinstance(self.src[0], tuple):
and_clause += [_get_clause(s, base.gep(i), depth) for i,s in enumerate(self.src[0])]
# repeat match
elif len(self.src) == 1 and isinstance(self.src[0], itertools.repeat):
it = UOp(Ops.NOOP, arg=f"ituop{depth}")
match = _get_clause(next(self.src[0]), it, depth+1)
and_clause.append(UOp(Ops.RANGE, src=(match, it, base), arg="all([{0} for {1} in {2}.src])"))
# multi match (fork)
elif len(self.src) > 1 and all(isinstance(x, tuple) for x in self.src):
fork_cond = [UOp(Ops.AND, src=tuple([_get_clause(s, base.gep(i), depth) for i,s in enumerate(ss)])) for ss in self.src]
and_clause.append(UOp(Ops.OR, src=tuple(fork_cond)))
else: raise RuntimeError("broken")
return UOp(Ops.AND, src=tuple(and_clause))
# *** pattern matcher ***
def do_process_and(a:UOp) -> UOp|None:
found = False
new_src:list[UOp] = []
or_clause:list[UOp] = []
# remove any nested ANDs, extract or clauses
for x in a.src:
if x.op is Ops.AND:
new_src.extend(x.src)
found = True
elif x.op is Ops.OR: or_clause.append(x)
else: new_src.append(x)
# too big to compile
if len(or_clause) >= 4: raise UPatCompileError("too big to compile")
# one or clause max
if len(or_clause) > 1:
# need the product of the or clauses
or_clause = [UOp(Ops.OR, src=tuple([UOp(Ops.AND, src=x) for x in itertools.product(*[x.src for x in or_clause])]))]
found = True
# handle stores
stores, new_src = partition(new_src, lambda x: x.op is Ops.STORE)
if len(stores):
if len(or_clause):
# push stores to the top if we have an or_clause
assert len(or_clause) == 1 and all(x.op is Ops.AND for x in or_clause[0].src)
or_clause = [UOp(Ops.OR, src=tuple([x.replace(src=x.src+tuple(stores)) for x in or_clause[0].src]))]
found = True
else:
# check for duplicate stores
dict_stores: dict[UOp, UOp] = {}
for a in stores:
if a.src[0] in dict_stores:
# duplicate store is a compare
new_src.append(UOp(Ops.CMPNE, src=(dict_stores[a.src[0]], a.src[1])))
found = True
else:
dict_stores[a.src[0]] = a.src[1]
# put the stores back
for k,v in dict_stores.items(): new_src.append(UOp(Ops.STORE, src=(k,v)))
# reassemble, if there's any deduping to do, do it
if len(dretand:=dedup(new_src+or_clause)) != len(new_src)+len(or_clause): found = True
return UOp(Ops.AND, src=tuple(dretand)) if found else None
# processor
pm_proc = PatternMatcher([(UPat(Ops.AND, name="a"), do_process_and)], compiled=False)
# renderer
def wrap(ctx, x) -> UOp:
ctx[ret:=f"a{len(ctx)}"] = x.arg
return UOp(Ops.NOOP, arg=ret)
pm_renderer = PatternMatcher([
(UPat(Ops.BIND, name="x"), wrap),
# CMPNE is actually equal
(UPat(Ops.CMPNE, name="x"), lambda x: UOp(Ops.CUSTOM, src=x.src, arg="{0} is {1}")),
# RANGE can't have OR inside it
(UPat(Ops.RANGE, src=(UPat(Ops.AND, src=UPat(Ops.NOOP), name="x"), UPat(), UPat()), name="r"),
lambda r,x: r.replace(op=Ops.CUSTOM, src=(UOp(Ops.NOOP, arg="(" + ' and '.join(y.arg for y in x.src) + ")"),)+r.src[1:])),
(UPat(Ops.CUSTOM, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg.format(*[y.arg for y in x.src]))),
(UPat(Ops.GEP, src=UPat(Ops.NOOP, name="x"), name="g"), lambda x,g: x.replace(arg=x.arg+f".src[{g.arg[0]}]"))
], compiled=False)
def _final_render(x:UOp, has_ctx:bool, depth=1) -> list[str]:
assert x.op is Ops.AND
and_pieces, store_pieces = [], []
or_pieces: list[str] = []
for s in x.src:
if s.op is Ops.OR:
assert len(or_pieces) == 0 and len(s.src) >= 1
for ss in s.src: or_pieces.extend(_final_render(ss, has_ctx, depth+1))
elif s.op is Ops.STORE:
assert s.src[0].op is Ops.DEFINE_VAR and s.src[1].op is Ops.NOOP
store_pieces.append(f"{s.src[0].arg}={s.src[1].arg}")
elif s.op is Ops.NOOP: and_pieces.append(s.arg)
else: raise UPatCompileError(f"can't compile this {s}")
# if we have an or, render it
if len(or_pieces):
assert len(store_pieces) == 0
and_clause = ' and '.join(and_pieces)
return [f"{' '*depth}if {and_clause if len(and_clause) else 'True'}:"] + or_pieces
# if we don't, this is a final return
store_clause = ', '.join((["ctx=ctx"] if has_ctx else [])+store_pieces)
and_clause = ' and '.join(and_pieces + [f"(_ret:=_fxn({store_clause})) is not None"])
return [f"{' '*depth}if {and_clause}: return _ret"]
def _get_code(self:UPat, has_ctx:bool):
ret = _get_clause(self, UOp(Ops.NOOP, arg="uop"))
try:
# TODO: this should be tracked in a "system" rewrite, not untracked or tracked with kernel
with Context(TRACK_MATCH_STATS=0):
ret = graph_rewrite(ret, pm_proc, name="process UPat")
dyn_lookup: dict[str, Any] = {}
out = graph_rewrite(ret, pm_renderer, ctx=dyn_lookup, name="compile UPat")
rendered = _final_render(out, has_ctx)
except UPatCompileError:
#print("FAILED", self, self.location)
return None
return '\n'.join([f"# match for {self.location}", "def compiled_match(uop, ctx):"] + rendered + [" return None"]), dyn_lookup
@functools.cache
def upat_compile(self:UPat, fxn) -> Callable|None:
real_fxn = types.FunctionType(*deconstruct_function(fxn))
code = _get_code(self, 'ctx' in inspect.signature(real_fxn).parameters)
if code is None: return None
code_str, dyn_lookup = code
globs = dyn_lookup.copy()
globs["_fxn"] = real_fxn
namespace: dict = {}
exec(code_str, globs, namespace) # pylint: disable=W0122
return namespace["compiled_match"]