Release 260111

commit 3721ecbf8a
Author: Comma Device
Date:   2026-01-11 18:23:29 +08:00

2601 changed files with 855070 additions and 0 deletions

tinygrad/nn/__init__.py Normal file

@@ -0,0 +1,351 @@
from __future__ import annotations
import math
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import prod, make_tuple, flatten
from tinygrad.nn import optim, state, datasets # noqa: F401
class BatchNorm:
"""
Applies Batch Normalization over a 2D or 3D input.
- Paper: https://arxiv.org/abs/1502.03167v3
See: `Tensor.batchnorm`
```python exec="true" session="tensor"
from tinygrad import Tensor, dtypes, nn
import numpy as np
np.set_printoptions(precision=4)
```
```python exec="true" source="above" session="tensor" result="python"
norm = nn.BatchNorm(3)
t = Tensor.rand(2, 3, 4, 4)
print(t.mean().item(), t.std().item())
```
```python exec="true" source="above" session="tensor" result="python"
t = norm(t)
print(t.mean().item(), t.std().item())
```
"""
def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum
self.weight: Tensor|None = Tensor.ones(sz) if affine else None
self.bias: Tensor|None = Tensor.zeros(sz) if affine else None
self.num_batches_tracked = Tensor.zeros(1, dtype='long' if is_dtype_supported(dtypes.long) else 'int', requires_grad=False)
if track_running_stats: self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False)
def calc_stats(self, x:Tensor) -> tuple[Tensor, Tensor]:
shape_mask: list[int] = [1, -1, *([1]*(x.ndim-2))]
if self.track_running_stats and not Tensor.training: return self.running_mean, self.running_var.reshape(shape=shape_mask).expand(x.shape)
# This requires two full memory accesses to x
# https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
# There are "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
batch_mean = x.mean(axis=(reduce_axes:=tuple(x for x in range(x.ndim) if x != 1)))
y = (x - batch_mean.detach().reshape(shape=shape_mask)) # d(var)/d(mean) = 0
batch_var = (y*y).mean(axis=reduce_axes)
return batch_mean, batch_var
def __call__(self, x:Tensor) -> Tensor:
batch_mean, batch_var = self.calc_stats(x)
# NOTE: wow, this is done all throughout training in most PyTorch models
if self.track_running_stats and Tensor.training:
self.running_mean.assign((1-self.momentum) * self.running_mean + self.momentum * batch_mean.detach())
self.running_var.assign((1-self.momentum) * self.running_var + self.momentum * x.numel()/(x.numel()-x.shape[1]) * batch_var.detach())
self.num_batches_tracked += 1
return x.batchnorm(self.weight, self.bias, batch_mean, batch_var.add(self.eps).rsqrt())
BatchNorm2d = BatchNorm3d = BatchNorm
def Conv1d(in_channels:int, out_channels:int, kernel_size:int, stride=1, padding:int|str=0, dilation=1, groups=1, bias=True) -> Conv2d:
"""
Applies a 1D convolution over an input signal composed of several input planes.
See: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d
```python exec="true" source="above" session="tensor" result="python"
conv = nn.Conv1d(1, 1, 3)
t = Tensor.rand(1, 1, 4)
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
t = conv(t)
print(t.numpy())
```
"""
return Conv2d(in_channels, out_channels, (kernel_size,), stride, padding, dilation, groups, bias)
class Conv2d:
"""
Applies a 2D convolution over an input signal composed of several input planes.
See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d
```python exec="true" source="above" session="tensor" result="python"
conv = nn.Conv2d(1, 1, 3)
t = Tensor.rand(1, 1, 4, 4)
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
t = conv(t)
print(t.numpy())
```
"""
def __init__(self, in_channels:int, out_channels:int, kernel_size:int|tuple[int, ...], stride=1, padding:int|tuple[int, ...]|str=0,
dilation=1, groups=1, bias=True):
self.kernel_size = make_tuple(kernel_size, 2)
if isinstance(padding, str):
if padding.lower() != 'same': raise ValueError(f"Invalid padding string {padding!r}, only 'same' is supported")
if stride != 1: raise ValueError("padding='same' is not supported for strided convolutions")
pad = [(d*(k-1)//2, d*(k-1) - d*(k-1)//2) for d,k in zip(make_tuple(dilation, len(self.kernel_size)), self.kernel_size[::-1])]
padding = tuple(flatten(pad))
self.stride, self.dilation, self.groups, self.padding = stride, dilation, groups, padding
scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
self.weight = Tensor.uniform(out_channels, in_channels//groups, *self.kernel_size, low=-scale, high=scale)
self.bias: Tensor|None = Tensor.uniform(out_channels, low=-scale, high=scale) if bias else None
def __call__(self, x:Tensor) -> Tensor: return x.conv2d(self.weight, self.bias, self.groups, self.stride, self.dilation, self.padding)
def ConvTranspose1d(in_channels:int, out_channels:int, kernel_size:int, stride=1, padding=0, output_padding=0, dilation=1,
groups=1, bias=True) -> ConvTranspose2d:
"""
Applies a 1D transposed convolution operator over an input signal composed of several input planes.
See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d
```python exec="true" source="above" session="tensor" result="python"
conv = nn.ConvTranspose1d(1, 1, 3)
t = Tensor.rand(1, 1, 4)
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
t = conv(t)
print(t.numpy())
```
"""
return ConvTranspose2d(in_channels, out_channels, (kernel_size,), stride, padding, output_padding, dilation, groups, bias)
class ConvTranspose2d(Conv2d):
"""
Applies a 2D transposed convolution operator over an input image.
See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d
```python exec="true" source="above" session="tensor" result="python"
conv = nn.ConvTranspose2d(1, 1, 3)
t = Tensor.rand(1, 1, 4, 4)
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
t = conv(t)
print(t.numpy())
```
"""
def __init__(self, in_channels:int, out_channels:int, kernel_size:int|tuple[int, ...], stride=1, padding=0, output_padding=0,
dilation=1, groups=1, bias=True):
super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
self.weight = Tensor.uniform(in_channels, out_channels//groups, *self.kernel_size, low=-scale, high=scale)
self.output_padding = output_padding
def __call__(self, x:Tensor) -> Tensor:
return x.conv_transpose2d(self.weight, self.bias, self.groups, self.stride, self.dilation, self.padding, self.output_padding)
class Linear:
"""
Applies a linear transformation to the incoming data.
See: https://pytorch.org/docs/stable/generated/torch.nn.Linear
```python exec="true" source="above" session="tensor" result="python"
lin = nn.Linear(3, 4)
t = Tensor.rand(2, 3)
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
t = lin(t)
print(t.numpy())
```
"""
def __init__(self, in_features:int, out_features:int, bias=True):
bound = 1 / math.sqrt(in_features)
self.weight = Tensor.uniform(out_features, in_features, low=-bound, high=bound)
self.bias = Tensor.uniform(out_features, low=-bound, high=bound) if bias else None
def __call__(self, x:Tensor) -> Tensor: return x.linear(self.weight.transpose(), self.bias)
class GroupNorm:
"""
Applies Group Normalization over a mini-batch of inputs.
- Paper: https://arxiv.org/abs/1803.08494v3
```python exec="true" source="above" session="tensor" result="python"
norm = nn.GroupNorm(2, 12)
t = Tensor.rand(2, 12, 4, 4) * 2 + 1
print(t.mean().item(), t.std().item())
```
```python exec="true" source="above" session="tensor" result="python"
t = norm(t)
print(t.mean().item(), t.std().item())
```
"""
def __init__(self, num_groups:int, num_channels:int, eps=1e-5, affine=True):
self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps
self.weight: Tensor|None = Tensor.ones(num_channels) if affine else None
self.bias: Tensor|None = Tensor.zeros(num_channels) if affine else None
def __call__(self, x:Tensor) -> Tensor:
# reshape for layernorm to work as group norm
# subtract mean and divide stddev
x = x.reshape(x.shape[0], self.num_groups, -1).layernorm(eps=self.eps).reshape(x.shape)
if self.weight is None or self.bias is None: return x
# elementwise_affine on channels
return x * self.weight.reshape(1, -1, *[1] * (x.ndim-2)) + self.bias.reshape(1, -1, *[1] * (x.ndim-2))
class InstanceNorm:
"""
Applies Instance Normalization over a mini-batch of inputs.
- Paper: https://arxiv.org/abs/1607.08022v3
```python exec="true" source="above" session="tensor" result="python"
norm = nn.InstanceNorm(3)
t = Tensor.rand(2, 3, 4, 4) * 2 + 1
print(t.mean().item(), t.std().item())
```
```python exec="true" source="above" session="tensor" result="python"
t = norm(t)
print(t.mean().item(), t.std().item())
```
"""
def __init__(self, num_features:int, eps=1e-5, affine=True):
self.num_features, self.eps = num_features, eps
self.weight: Tensor|None = Tensor.ones(num_features) if affine else None
self.bias: Tensor|None = Tensor.zeros(num_features) if affine else None
def __call__(self, x:Tensor) -> Tensor:
x = x.reshape(x.shape[0], self.num_features, -1).layernorm(eps=self.eps).reshape(x.shape)
if self.weight is None or self.bias is None: return x
return x * self.weight.reshape(1, -1, *[1] * (x.ndim-2)) + self.bias.reshape(1, -1, *[1] * (x.ndim-2))
class LayerNorm:
"""
Applies Layer Normalization over a mini-batch of inputs.
- Paper: https://arxiv.org/abs/1607.06450v1
```python exec="true" source="above" session="tensor" result="python"
norm = nn.LayerNorm(3)
t = Tensor.rand(2, 5, 3) * 2 + 1
print(t.mean().item(), t.std().item())
```
```python exec="true" source="above" session="tensor" result="python"
t = norm(t)
print(t.mean().item(), t.std().item())
```
"""
def __init__(self, normalized_shape:int|tuple[int, ...], eps=1e-5, elementwise_affine=True):
self.normalized_shape: tuple[int, ...] = make_tuple(normalized_shape, 1)
self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
self.weight: Tensor|None = Tensor.ones(*self.normalized_shape) if elementwise_affine else None
self.bias: Tensor|None = Tensor.zeros(*self.normalized_shape) if elementwise_affine else None
def __call__(self, x:Tensor) -> Tensor:
assert self.normalized_shape == x.shape[-len(self.normalized_shape):], f"last dimensions of {x.shape} must match {self.normalized_shape}"
x = x.layernorm(eps=self.eps, axis=self.axis)
if not self.elementwise_affine: return x
return x * self.weight + self.bias
class LayerNorm2d(LayerNorm):
"""
Applies Layer Normalization over a mini-batch of 2D inputs.
See: `LayerNorm`
```python exec="true" source="above" session="tensor" result="python"
norm = nn.LayerNorm2d(3)
t = Tensor.rand(2, 3, 4, 4) * 2 + 1
print(t.mean().item(), t.std().item())
```
```python exec="true" source="above" session="tensor" result="python"
t = norm(t)
print(t.mean().item(), t.std().item())
```
"""
def __call__(self, x: Tensor) -> Tensor: return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
class RMSNorm:
"""
Applies Root Mean Square Normalization to input.
- Paper: https://arxiv.org/abs/1910.07467
```python exec="true" source="above" session="tensor" result="python"
norm = nn.RMSNorm(4)
t = Tensor.arange(12, dtype=dtypes.float).reshape(3, 4)
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(norm(t).numpy())
```
"""
def __init__(self, dim:int, eps=1e-6, elementwise_affine=True):
self.eps = eps
self.weight = Tensor.ones(dim) if elementwise_affine else None
def _norm(self, x:Tensor) -> Tensor: return x * (x.square().mean(-1, keepdim=True) + self.eps).rsqrt()
def __call__(self, x:Tensor) -> Tensor:
x = self._norm(x.float()).cast(x.dtype)
return x if self.weight is None else x * self.weight
class Embedding:
"""
A simple lookup table that stores embeddings of a fixed dictionary and size.
See: https://pytorch.org/docs/stable/generated/torch.nn.Embedding
```python exec="true" source="above" session="tensor" result="python"
emb = nn.Embedding(10, 3)
print(emb(Tensor([1, 2, 3, 1])).numpy())
```
"""
def __init__(self, vocab_size:int, embed_size:int):
self.vocab_sz, self.embed_sz, self.weight = vocab_size, embed_size, Tensor.glorot_uniform(vocab_size, embed_size)
def __call__(self, idx:Tensor) -> Tensor:
if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).unsqueeze(-1)
if not dtypes.is_int(idx.dtype): raise TypeError(f"Expected integer dtype for index in embedding, got {idx.dtype}")
big_shp = idx.shape+(self.vocab_sz, self.embed_sz)
arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1)).expand(big_shp), self.weight.expand(big_shp)
return (arange == idx).mul(vals).sum(-2, dtype=vals.dtype)
class LSTMCell:
"""
A long short-term memory (LSTM) cell.
Args:
input_size: The number of expected features in the input `x`
hidden_size: The number of features in the hidden state `h`
bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`
"""
def __init__(self, input_size:int, hidden_size:int, bias:bool=True):
stdv = 1.0 / math.sqrt(hidden_size)
self.weight_ih = Tensor.uniform(hidden_size*4, input_size, low=-stdv, high=stdv)
self.weight_hh = Tensor.uniform(hidden_size*4, hidden_size, low=-stdv, high=stdv)
self.bias_ih: Tensor|None = Tensor.zeros(hidden_size*4) if bias else None
self.bias_hh: Tensor|None = Tensor.zeros(hidden_size*4) if bias else None
def __call__(self, x:Tensor, hc:tuple[Tensor, Tensor]|None=None) -> tuple[Tensor, Tensor]:
if hc is None: hc = (Tensor.zeros(x.size(0), self.weight_hh.size(1), dtype=x.dtype, device=x.device),)*2
gates = x.linear(self.weight_ih.T, self.bias_ih) + hc[0].linear(self.weight_hh.T, self.bias_hh)
i, f, g, o = gates.chunk(4, dim=1)
i, f, g, o = i.sigmoid(), f.sigmoid(), g.tanh(), o.sigmoid()
new_c = f * hc[1] + i * g
new_h = o * new_c.tanh()
return (new_h.contiguous(), new_c.contiguous())
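The LSTMCell docstring above lists its arguments but carries no example, so here is a minimal usage sketch for this module; shapes and layer sizes are illustrative, and only classes defined in this file are used.

```python
from tinygrad import Tensor, nn

# LSTMCell: one step over a batch of 2, then a second step carrying state
cell = nn.LSTMCell(input_size=10, hidden_size=20)
h, c = cell(Tensor.rand(2, 10))              # hc=None -> zero-initialized state
h, c = cell(Tensor.rand(2, 10), (h, c))
print(h.shape, c.shape)                      # (2, 20) (2, 20)

# composing layers from this file into a tiny classifier
class TinyNet:
  def __init__(self):
    self.c1 = nn.Conv2d(1, 8, 3, padding=1)  # 28x28 stays 28x28 with padding=1
    self.bn = nn.BatchNorm(8)
    self.fc = nn.Linear(8 * 28 * 28, 10)
  def __call__(self, x: Tensor) -> Tensor:
    return self.fc(self.bn(self.c1(x)).relu().flatten(1))

print(TinyNet()(Tensor.rand(4, 1, 28, 28)).shape)  # (4, 10)
```

`nn.state.get_state_dict` (defined in state.py below) discovers all of these weights by walking `__dict__`, which is why plain attribute assignment is enough.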

tinygrad/nn/datasets.py Normal file

@@ -0,0 +1,14 @@
from tinygrad.tensor import Tensor
from tinygrad.nn.state import tar_extract
def mnist(device=None, fashion=False):
base_url = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/" if fashion else "https://storage.googleapis.com/cvdf-datasets/mnist/"
def _mnist(file): return Tensor.from_url(base_url+file, gunzip=True)
return _mnist("train-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device), _mnist("train-labels-idx1-ubyte.gz")[8:].to(device), \
_mnist("t10k-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device), _mnist("t10k-labels-idx1-ubyte.gz")[8:].to(device)
def cifar(device=None):
tt = tar_extract(Tensor.from_url('https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz', gunzip=True))
train = Tensor.cat(*[tt[f"cifar-10-batches-bin/data_batch_{i}.bin"].reshape(-1, 3073).to(device) for i in range(1,6)])
test = tt["cifar-10-batches-bin/test_batch.bin"].reshape(-1, 3073).to(device)
return train[:, 1:].reshape(-1,3,32,32), train[:, 0], test[:, 1:].reshape(-1,3,32,32), test[:, 0]
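A short sketch of calling these loaders; this assumes network access, since both helpers download their archives through `Tensor.from_url`.

```python
from tinygrad.nn.datasets import mnist

# standard MNIST split, reshaped to NCHW above
X_train, Y_train, X_test, Y_test = mnist()
print(X_train.shape, X_test.shape)   # (60000, 1, 28, 28) (10000, 1, 28, 28)

# fashion=True switches base_url to the Fashion-MNIST mirror; cifar() works the same way
Xf, Yf, _, _ = mnist(fashion=True)
```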

tinygrad/nn/optim.py Normal file

@@ -0,0 +1,177 @@
# sorted in order of increasing complexity
import itertools
from tinygrad.helpers import dedup, flatten, getenv, unwrap, FUSE_OPTIM
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes, least_upper_dtype
class Optimizer:
"""
Base class for all optimizers.
"""
def __init__(self, params: list[Tensor], lr: float, fused=FUSE_OPTIM):
# if it's None, but being put into an optimizer, set it to True
for x in params:
if x.requires_grad is None: x.requires_grad = True
self.params: list[Tensor] = dedup([x for x in params if x.requires_grad])
assert len(self.params) != 0, "optimizer must have at least one param"
self.device = self.params[0].device
self.buffers: list[Tensor] = dedup([x for x in params if not x.requires_grad]) # buffers are still realized
self.fused = fused
# store lr in at least float32 precision
self.lr = Tensor(lr if getenv("CONST_LR") else [lr], requires_grad=False, device=self.device,
dtype=least_upper_dtype(dtypes.default_float, dtypes.float32))
if self.fused: self.pos_params = list(itertools.accumulate(self.params, lambda x,y: x+y.numel(), initial=0))
def _new_optim_param(self) -> list[Tensor]:
param_dtype = getenv("OPTIM_DTYPE", "float32")
if self.fused: return [Tensor.zeros(self.pos_params[-1], dtype=param_dtype, device=self.device, requires_grad=False).contiguous()]
return [Tensor.zeros(*t.shape, dtype=param_dtype, device=t.device, requires_grad=False).contiguous() for t in self.params]
def zero_grad(self):
"""
Zeroes the gradients of all the parameters.
"""
for param in self.params: param.grad = None
def step(self):
"""
Performs a single optimization step.
"""
Tensor.realize(*self.schedule_step())
def schedule_step(self) -> list[Tensor]:
"""
Returns the tensors that need to be realized to perform a single optimization step.
"""
if not Tensor.training: raise RuntimeError(
f"""Tensor.training={Tensor.training}, Tensor.training must be enabled to use the optimizer.
- help: Consider setting Tensor.training=True before calling Optimizer.step().""")
if self.fused:
# optimizer fusion just concatenates all the buffers, runs the _step, then splits them back up
out, extra = self._step([Tensor.cat(*[t.flatten() for t in self.params], dim=0)],
[Tensor.cat(*[unwrap(t.grad).flatten() for t in self.params], dim=0)])
updated_params = [out[0][self.pos_params[i]:self.pos_params[i+1]].reshape(tt.shape) for i, tt in enumerate(self.params)]
else:
updated_params, extra = self._step(self.params, [unwrap(t.grad) for t in self.params])
for i, tt in enumerate(self.params): tt.assign(updated_params[i])
return extra+self.params+self.buffers
def _step(self, params:list[Tensor], grads:list[Tensor]) -> tuple[list[Tensor], list[Tensor]]: raise NotImplementedError
class OptimizerGroup(Optimizer):
"""
Combines multiple optimizers into one.
"""
def __init__(self, *optimizers: Optimizer): # pylint: disable=super-init-not-called
self.optimizers = optimizers
self.params, self.buffers = flatten([o.params for o in self.optimizers]), flatten([o.buffers for o in self.optimizers])
def __getitem__(self, i): return self.optimizers[i]
def zero_grad(self): [o.zero_grad() for o in self.optimizers]
def schedule_step(self) -> list[Tensor]: return [x for o in self.optimizers for x in o.schedule_step()]
# LARS is essentially just a trust ratio applied to SGD, so if we set the trust coeff to 0.0 it's just standard SGD.
def SGD(params: list[Tensor], lr=0.001, momentum=0.0, weight_decay=0.0, nesterov=False, classic=False, fused=FUSE_OPTIM):
"""
Stochastic Gradient Descent (SGD) optimizer with optional momentum and weight decay.
`classic` is a boolean flag that determines whether to use the popular momentum update rule or the classic momentum update rule.
"""
return LARS(params, lr, momentum, weight_decay, 0, None, nesterov, classic=classic, pre_wd=True, tcoef=0.0, fused=fused)
# Muon applies the newton schulz algorithm on gradient. also can include momentum, nesterov, and weight decay
def Muon(params: list[Tensor], lr=0.02, momentum=0.95, weight_decay=0.0, ns_steps=5, ns_params=(3.4445, -4.775, 2.0315),
nesterov=True, fused=FUSE_OPTIM):
"""
SGD with newton-schulz iteration and post momentum weight decay.
- Described: https://kellerjordan.github.io/posts/muon/
- Paper: https://arxiv.org/pdf/2502.16982
"""
assert not fused, "FUSE_OPTIM not allowed for Muon optimizer"
return LARS(params, lr, momentum, weight_decay, ns_steps, ns_params, nesterov, classic=False, pre_wd=False, tcoef=0.0, fused=fused)
class LARS(Optimizer):
"""
Layer-wise Adaptive Rate Scaling (LARS) optimizer with optional momentum and weight decay.
- Paper: https://arxiv.org/abs/1708.03888v3
"""
def __init__(self, params:list[Tensor], lr=0.001, momentum=0.9, weight_decay=1e-4, ns_steps=0, ns_params=None,
nesterov=False, classic=True, pre_wd=True, tcoef=0.001, fused=FUSE_OPTIM):
super().__init__(params, lr, fused)
self.momentum, self.wd, self.ns_steps, self.ns_params = momentum, weight_decay, ns_steps, ns_params
self.nesterov, self.classic, self.pre_wd, self.tcoef = nesterov, classic, pre_wd, tcoef
self.b = self._new_optim_param() if self.momentum else []
def _step(self, params:list[Tensor], grads:list[Tensor]) -> tuple[list[Tensor], list[Tensor]]:
ret = []
for i, (t, g) in enumerate(zip(params, grads)):
if self.tcoef != 0:
r1 = t.detach().square().sum().sqrt()
r2 = g.square().sum().sqrt()
r:Tensor|float = (r1 > 0).where((r2 > 0).where(self.tcoef * r1 / (r2 + self.wd * r1), 1.0), 1.0)
else: r = 1.0
if self.pre_wd and self.wd > 0: g = g + self.wd * t.detach()
# classic momentum does post learning rate update
if self.classic: g = g * r * self.lr
if self.momentum:
# TODO: this contiguous is required for correctness because self.b[i] becomes a non contiguous view
# the scheduler should detect this and just insert contiguous
self.b[i].assign(self.momentum * self.b[i].contiguous() + g) # NOTE: self.b[i] is zero on the first run, no if required
g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
if self.ns_params: g = g.reshape(g.shape[0], -1).newton_schulz(self.ns_steps, self.ns_params).reshape(g.shape)
# muon does post momentum weight decay
if not self.pre_wd and self.wd > 0: t = t.detach() * (1.0 - self.wd * self.lr)
# popular momentum does pre learning rate update
if not self.classic: g = g * r * self.lr
ret.append((t.detach() - g).cast(t.dtype))
return ret, self.b
# LAMB is essentially just the trust ratio part of LARS applied to Adam/W, so if we set the trust ratio to 1.0 it's just Adam/W.
def AdamW(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.01, fused=FUSE_OPTIM):
"""
AdamW optimizer with optional weight decay.
- Paper: https://arxiv.org/abs/1711.05101v3
"""
return LAMB(params, lr, b1, b2, eps, weight_decay, adam=True, fused=fused)
def Adam(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, fused=FUSE_OPTIM):
"""
Adam optimizer.
- Paper: https://arxiv.org/abs/1412.6980
"""
return LAMB(params, lr, b1, b2, eps, 0.0, adam=True, fused=fused)
class LAMB(Optimizer):
"""
LAMB optimizer with optional weight decay.
- Paper: https://arxiv.org/abs/1904.00962
"""
def __init__(self, params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False, fused=FUSE_OPTIM):
super().__init__(params, lr, fused)
self.b1, self.b2, self.eps, self.wd, self.adam = b1, b2, eps, weight_decay, adam
self.b1_t, self.b2_t = (Tensor.ones((1,), dtype=dtypes.float32, device=self.device, requires_grad=False).contiguous() for _ in [b1, b2])
self.m = self._new_optim_param()
self.v = self._new_optim_param()
def _step(self, params:list[Tensor], grads:list[Tensor]) -> tuple[list[Tensor], list[Tensor]]:
ret = []
self.b1_t *= self.b1
self.b2_t *= self.b2
for i, (t, g) in enumerate(zip(params, grads)):
self.m[i].assign((self.b1 * self.m[i] + (1.0 - self.b1) * g).cast(self.m[i].dtype))
self.v[i].assign((self.b2 * self.v[i] + (1.0 - self.b2) * (g * g)).cast(self.v[i].dtype))
m_hat = self.m[i] / (1.0 - self.b1_t)
v_hat = self.v[i] / (1.0 - self.b2_t)
up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach()
if not self.adam:
r1 = t.detach().square().sum().sqrt()
r2 = up.square().sum().sqrt()
r: Tensor|float = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0)
else:
r = 1.0
ret.append((t.detach() - self.lr * r * up).cast(t.dtype))
return ret, [self.b1_t, self.b2_t] + self.m + self.v
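A minimal training-step sketch against the `Optimizer` API above. `Tensor.training` must be enabled, as `schedule_step` enforces; the model and loss here are placeholders, and `get_parameters` comes from `state.py` later in this commit.

```python
from tinygrad import Tensor, nn
from tinygrad.nn.optim import Adam

model = nn.Linear(4, 2)
opt = Adam(nn.state.get_parameters(model), lr=1e-3)

Tensor.training = True                       # required by Optimizer.schedule_step
x, y = Tensor.rand(8, 4), Tensor.rand(8, 2)
opt.zero_grad()
loss = (model(x) - y).square().mean()
loss.backward()
opt.step()                                   # realizes the scheduled update tensors
print(loss.item())
```

Swapping in `SGD` or `AdamW` only changes the constructor call; both are thin wrappers over `LARS` and `LAMB`, as the comments above note.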

tinygrad/nn/state.py Normal file

@@ -0,0 +1,351 @@
import json, pathlib, zipfile, pickle, tarfile, struct, functools, io
from collections import OrderedDict
from typing import Any, Callable, BinaryIO, Iterable
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm, round_up, T
from tinygrad.shape.view import strides_for_shape
class TensorIO(io.RawIOBase, BinaryIO):
def __init__(self, t: Tensor):
if t.ndim != 1 or t.dtype != dtypes.uint8: raise ValueError("Tensor must be 1d and of dtype uint8!")
self._position, self._tensor = 0, t
def readable(self) -> bool: return True
def read(self, size: int = -1) -> bytes:
if (buf:=super().read(size)) is None: raise ValueError("io.RawIOBase.read returned None") # only happens if readinto returns None (never)
return buf
def readinto(self, buffer: Any) -> int:
data = self._tensor[self._position:self._position+len(buffer)].data()
buffer[:len(data)] = data
self._position += len(data)
return len(data)
def seekable(self) -> bool: return True
def seek(self, offset: int, whence: int = 0) -> int:
self._position = min(len(self._tensor), max(0, [offset, self._position+offset, len(self._tensor)+offset][whence]))
return self._position
# required to correctly implement BinaryIO
def __enter__(self): return self
def write(self, s: Any): raise io.UnsupportedOperation("TensorIO.write not supported")
def writelines(self, lines: Iterable[Any]): raise io.UnsupportedOperation("TensorIO.writelines not supported")
safe_dtypes = {"BOOL":dtypes.bool, "I8":dtypes.int8, "U8":dtypes.uint8, "I16":dtypes.int16, "U16":dtypes.uint16, "I32":dtypes.int, "U32":dtypes.uint,
"I64":dtypes.int64, "U64":dtypes.uint64, "F16":dtypes.float16, "BF16":dtypes.bfloat16, "F32":dtypes.float32, "F64":dtypes.float64}
inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
def accept_filename(func: Callable[[Tensor], T]) -> Callable[[Tensor|str|pathlib.Path], T]:
@functools.wraps(func)
def wrapper(fn: Tensor|str|pathlib.Path) -> T: return func(Tensor(pathlib.Path(fn)) if not isinstance(fn, Tensor) else fn)
return wrapper
@accept_filename
def safe_load_metadata(t:Tensor) -> tuple[Tensor, int, dict[str, Any]]:
"""
Loads a .safetensor file, returning the source tensor, data start position, and metadata.
"""
data_start = int.from_bytes(t[0:8].data(), "little") + 8
return t, data_start, json.loads(t[8:data_start].data().tobytes())
def safe_load(fn:Tensor|str|pathlib.Path) -> dict[str, Tensor]:
"""
Loads a .safetensor file, returning the `state_dict`.
```python
state_dict = nn.state.safe_load("test.safetensor")
```
"""
t, data_start, metadata = safe_load_metadata(fn)
data = t[data_start:]
return { k: data[v['data_offsets'][0]:v['data_offsets'][1]].bitcast(safe_dtypes[v['dtype']]).reshape(v['shape'])
for k, v in metadata.items() if k != "__metadata__" }
def safe_save(tensors:dict[str, Tensor], fn:str, metadata:dict[str, Any]|None=None):
"""
Saves a `state_dict` to disk in a .safetensor file with optional metadata.
```python
t = Tensor([1, 2, 3])
nn.state.safe_save({'t':t}, "test.safetensor")
```
"""
headers, offset = {}, 0
if metadata: headers['__metadata__'] = metadata
for k,v in tensors.items():
headers[k] = {'dtype': inverse_safe_dtypes[v.dtype], 'shape': list(v.shape), 'data_offsets':[offset, offset+v.nbytes()]}
offset += v.nbytes()
j = json.dumps(headers, separators=(',', ':'))
j += "\x20"*(round_up(len(j),8)-len(j))
pathlib.Path(fn).unlink(missing_ok=True)
t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
t[0:8].bitcast(dtypes.int64).assign([len(j)])
t[8:8+len(j)].assign(list(j.encode('utf-8')))
for k,v in safe_load(t).items(): v.assign(tensors[k])
# state dict
def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> dict[str, Tensor]:
"""
Returns a `state_dict` of the object, with optional prefix.
```python exec="true" source="above" session="tensor" result="python"
class Net:
def __init__(self):
self.l1 = nn.Linear(4, 5)
self.l2 = nn.Linear(5, 6)
net = Net()
print(nn.state.get_state_dict(net).keys())
```
"""
if isinstance(obj, tensor_type): return {prefix.strip('.'):obj}
if hasattr(obj, '_asdict'): return get_state_dict(obj._asdict(), prefix, tensor_type) # namedtuple
if isinstance(obj, OrderedDict): return get_state_dict(dict(obj), prefix, tensor_type)
if hasattr(obj, '__dict__'): return get_state_dict(obj.__dict__, prefix, tensor_type)
state_dict = {}
if isinstance(obj, (list, tuple)):
for i,x in enumerate(obj): state_dict.update(get_state_dict(x, f"{prefix}{str(i)}.", tensor_type))
elif isinstance(obj, dict):
for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type))
return state_dict
def get_parameters(obj) -> list[Tensor]:
"""
```python exec="true" source="above" session="tensor" result="python"
class Net:
def __init__(self):
self.l1 = nn.Linear(4, 5)
self.l2 = nn.Linear(5, 6)
net = Net()
print(len(nn.state.get_parameters(net)))
```
"""
return list(get_state_dict(obj).values())
def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=True, consume=False, realize=True) -> list[Tensor]:
"""
Loads a `state_dict` into a model. Return the loaded Tensors.
```python
class Net:
def __init__(self):
self.l1 = nn.Linear(4, 5)
self.l2 = nn.Linear(5, 6)
net = Net()
state_dict = nn.state.get_state_dict(net)
nn.state.load_state_dict(net, state_dict)
```
"""
start_mem_used = GlobalCounters.mem_used
ret = []
with Timing("loaded weights in ",
lambda et_ns: f", {(B:=(GlobalCounters.mem_used-start_mem_used))/1e9:.2f} GB loaded at {B/et_ns:.2f} GB/s", enabled=verbose):
model_state_dict = get_state_dict(model)
if DEBUG >= 1 and len(state_dict) > len(model_state_dict):
print("WARNING: unused weights in state_dict", sorted(list(state_dict.keys() - model_state_dict.keys())))
for k,v in (t := tqdm(model_state_dict.items(), disable=CI or not verbose)):
t.desc = f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, {k:50s}: "
if k not in state_dict and not strict:
if DEBUG >= 1: print(f"WARNING: not loading {k}")
continue
if v.shape != state_dict[k].shape:
raise ValueError(f'Shape mismatch in layer `{k}`: Expected shape {v.shape}, but found {state_dict[k].shape} in state dict.')
if isinstance(v.device, tuple):
if isinstance(state_dict[k].device, tuple): v.replace(state_dict[k])
else: v.replace(state_dict[k].shard(v.device, v.uop.axis))
else: v.replace(state_dict[k].to(v.device))
if realize: v.realize()
if consume: del state_dict[k]
ret.append(v)
return ret
@accept_filename
def tar_extract(t: Tensor) -> dict[str, Tensor]:
"""
```python
tar_extract(fn: Tensor | str | Path) -> dict[str, Tensor]
```
Extracts files from a tar archive and returns them as a dictionary of names (keys) and tensors (values).
```python
tensors = nn.state.tar_extract(Tensor(pathlib.Path("archive.tar")))
```
"""
with tarfile.open(fileobj=TensorIO(t), mode="r") as tar:
return {member.name:t[member.offset_data:member.offset_data+member.size] for member in tar if member.type == tarfile.REGTYPE}
# torch support!
@accept_filename
def torch_load(t:Tensor) -> dict[str, Tensor]:
"""
```python
torch_load(fn: Tensor | str | Path) -> dict[str, Tensor]
```
Loads a torch .pth file, returning the `state_dict`.
```python
state_dict = nn.state.torch_load("test.pth")
```
"""
offsets: dict[str|int, int] = {}
lens: dict[str|int, int] = {}
def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad=None, backward_hooks=None, metadata=None):
#print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
lens[storage[2]] = storage[4] * storage[1].itemsize
if storage[2] not in offsets: return None
byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
ret = t[byte_offset:byte_offset+prod(size)*storage[1].itemsize].bitcast(storage[1])
# 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
permute_indexes = [len(shape_strides)-1-y for y in argsort([x[1] for x in shape_strides])]
if tuple(permute_indexes) != tuple(range(len(permute_indexes))):
intermediate_shape = tuple([shape_strides[x][0] for x in argsort(permute_indexes)])
assert tuple([shape_strides[i][1] for i in argsort(permute_indexes)]) == strides_for_shape(intermediate_shape), "nonpermutable strides"
if DEBUG >= 3: print(f"WARNING: this torch load is slow. to permute {intermediate_shape} with {permute_indexes}")
assert storage[1] != dtypes.bfloat16, "can't permute BF16"
# TODO: find a nice way to support all shapetracker on disktensors
ret = ret.to(None).reshape(intermediate_shape).permute(permute_indexes)
return ret.reshape(size)
class Parameter:
def __setstate__(self, state): self.tensor = state[0]
deserialized_objects: dict[str, Any] = {}
intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16,
"IntStorage": dtypes.int32, "BoolStorage": dtypes.bool,
"LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2, "FloatTensor": None, "Parameter": Parameter}
whitelist = {"torch", "collections", "numpy", "_codecs"} # NOTE: this is not for security, only speed
class Dummy: pass
class TorchPickle(pickle.Unpickler):
def find_class(self, module, name):
module_root = module.split(".")[0]
if module_root not in whitelist:
if DEBUG >= 2: print(f"WARNING: returning Dummy for {module} {name}")
return Dummy
return intercept[name] if module_root == "torch" else super().find_class(module, name)
def persistent_load(self, pid): return deserialized_objects.get(pid, pid)
fobj = io.BufferedReader(TensorIO(t))
def passthrough_reset(v: bool): return fobj.seek(0, 0) or v
if passthrough_reset(zipfile.is_zipfile(fobj)): # NOTE: passthrough_reset required to support python < 3.14
myzip = zipfile.ZipFile(fobj, 'r')
base_name = myzip.namelist()[0].split('/', 1)[0]
for n in myzip.namelist():
if n.startswith(f'{base_name}/data/'):
with myzip.open(n) as myfile:
offsets[n.split("/")[-1]] = myfile._orig_compress_start # type: ignore
with myzip.open(f'{base_name}/data.pkl') as myfile:
return TorchPickle(myfile).load()
elif passthrough_reset(tarfile.is_tarfile(fobj)): # NOTE: passthrough_reset required to support python < 3.11
with tarfile.open(fileobj=fobj, mode="r") as tar:
storages_offset = tar.getmember('storages').offset_data
f = unwrap(tar.extractfile('storages'))
for i in range(TorchPickle(f).load()): # num_storages
(key, _, storage_type), sz = TorchPickle(f).load(), struct.unpack('<q', f.read(8))[0]
offsets[key] = storages_offset + f.tell()
f.seek(sz*storage_type.itemsize, 1)
f = unwrap(tar.extractfile('tensors'))
for _ in range(TorchPickle(f).load()): # num_tensors
(key, storage_id, _), ndim, _ = TorchPickle(f).load(), struct.unpack('<i', f.read(4))[0], f.read(4)
size, stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim)), struct.unpack(f'<{ndim}q', f.read(8 * ndim))
storage_offset = struct.unpack('<q', f.read(8))[0]
deserialized_objects[str(key)] = _rebuild_tensor_v2((None, storage_type, storage_id, None, -1), storage_offset, size, stride)
return {k:v.tensor if isinstance(v, Parameter) else v for k,v in TorchPickle(unwrap(tar.extractfile('pickle'))).load().items()}
else:
pkl = TorchPickle(fobj)
_, _, _, rwd, _, ids, base_offset = pkl.load(), pkl.load(), pkl.load(), fobj.tell(), pkl.load(), pkl.load(), fobj.tell()
for i in ids:
offsets[i] = base_offset + 8
base_offset += 8 + lens[i]
fobj.seek(rwd)
return TorchPickle(fobj).load()
def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
"""
Converts ggml tensor data to a tinygrad tensor.
Supported native types: float32 (id: 0), float16 (id: 1), int8 (id: 16), int16 (id: 17), int32 (id: 18)
Supported quantized types: Q4_0 (id: 2), Q4_1 (id: 3), Q8_0 (id: 8), Q6_K (id: 14), MXFP4 (id: 39)
"""
# https://github.com/ggerganov/ggml/blob/323951f1bdcdfbd5b5ff3a9a7c3770e63b1a560e/include/ggml.h#L356
# native types
if (dtype := { 0: dtypes.float32, 1: dtypes.float16, 16: dtypes.int8, 17: dtypes.int16, 18: dtypes.int32 }.get(ggml_type)) is not None:
return t[:dtype.itemsize * n].bitcast(dtype)
def q_to_uint8(t: Tensor, b: int) -> Tensor:
# TODO: rewrite with arange?
shift_tensor, bitmask = Tensor.stack(*[ Tensor(2**(i*b), device=t.device, dtype=t.dtype) for i in range(8//b) ]), 0xff >> (8 - b)
return t.unsqueeze(-1).expand((*t.shape,8//b)).idiv(shift_tensor).bitwise_and(bitmask).transpose(-1, -2).flatten(-2)
# map to (number of elements, number of bytes)
if (nelements_nbytes := { 2: (32, 18), 3: (32, 20), 14: (256, 210), 8: (32, 34), 39: (32, 17) }.get(ggml_type)) is not None:
blocks = t[:(n//nelements_nbytes[0])*nelements_nbytes[1]].reshape((-1, nelements_nbytes[1]))
if ggml_type == 2: return (q_to_uint8(blocks[:,2:], 4).bitcast(dtypes.int8) - 8) * blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32)
if ggml_type == 3:
d, m = (blocks[:,s:s+2].bitcast(dtypes.float16).cast(dtypes.float32) for s in [ 0, 2 ])
return q_to_uint8(blocks[:,4:], 4).bitcast(dtypes.int8) * d + m
if ggml_type == 8: return blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32) * blocks[:,2:].bitcast(dtypes.int8)
if ggml_type == 14:
xl, xh = q_to_uint8(blocks[:,:128].reshape((-1, 2, 64)), 4), q_to_uint8(blocks[:,128:192].reshape((-1, 2, 32)), 2).lshift(4)
scales = blocks[:,192:208].bitcast(dtypes.int8).unsqueeze(-1).expand((-1, 16, 16)).reshape((-1, 256))
d = blocks[:,-2:].bitcast(dtypes.float16).cast(dtypes.float32).expand((-1, 256))
return d * (xl.bitwise_or(xh).bitcast(dtypes.int8) - 32).flatten(-2) * scales
if ggml_type == 39:
e_int = blocks[:, 0].cast(dtypes.int32)
d = ((e_int >= 2).cast(dtypes.float32) * (e_int.cast(dtypes.float32) - 128).exp2() +
(e_int == 1).cast(dtypes.float32) * 2.0**(-127) +
(e_int == 0).cast(dtypes.float32) * 2.0**(-128)).unsqueeze(-1)
codes = q_to_uint8(blocks[:, 1:17], 4)
sign = 1.0 - codes.rshift(3).cast(dtypes.float32) * 2.0
exp, mant = codes.rshift(1).bitwise_and(0x3).cast(dtypes.float32), codes.bitwise_and(0x1).cast(dtypes.float32)
fp4_val = sign * ((exp != 0).cast(dtypes.float32) * (1.0 + 0.5 * mant) * (exp - 1.0).exp2() +
(exp == 0).cast(dtypes.float32) * 0.5 * mant)
return (fp4_val * d).flatten(-2)[:n]
raise ValueError(f"GGML type '{ggml_type}' is not supported!")
@accept_filename
def gguf_load(tensor: Tensor) -> tuple[dict, dict[str, Tensor]]:
"""
Loads a .gguf file, returning the `kv_data` and `state_dict`.
```python
gguf_tensor = Tensor(pathlib.Path("Meta-Llama-3-8B-Instruct.Q4_0.gguf")).to(Device.DEFAULT)
kv_data, state_dict = nn.state.gguf_load(gguf_tensor)
```
NOTE: The provided tensor must be on a device that supports execution.
"""
reader, kv_data, state_dict = io.BufferedReader(TensorIO(tensor), 1_000_000), {}, {}
def read_unpack(fmt: str, n: int): return struct.unpack(fmt, reader.read(n))[0]
def read_str(): return str(reader.read(read_uint64()), "utf-8")
def read_arr():
reader, n = readers[read_int32()], read_uint64()
return [ reader() for _ in range(n) ]
readers: dict[int, Callable[[], Any]] = { 8: read_str, 9: read_arr, **{ t: functools.partial(read_unpack, "<"+f, nb) for t,f,nb in \
[ (0,"c",1), (1,"b",1), (2,"H",2), (3,"h",2), (4,"I",4), (5,"i",4), (6,"f",4), (7,"?",1), (10,"Q",8), (11,"q",8), (12,"d",8) ] } }
read_uint32, read_int32, read_uint64, read_int64 = readers[4], readers[5], readers[10], readers[11]
magic, version, n_tensors, n_kv = reader.read(4), read_int32(), read_int64(), read_int64()
if magic != b"GGUF" or version not in [2, 3]: raise ValueError("Invalid GGUF format!")
for _ in range(n_kv):
k, typ = read_str(), read_int32()
kv_data[k] = readers[typ]()
t_infos = [ (read_str(), tuple(read_uint64() for _ in range(read_uint32())), read_int32(), read_uint64()) for _ in range(n_tensors) ]
alignment, pos = kv_data.get("general.alignment", 32), reader.tell()
data_start = round_up(pos, alignment)
for name, dims, typ, off in t_infos: state_dict[name] = ggml_data_to_tensor(tensor[data_start + off:], prod(dims), typ).reshape(*reversed(dims))
return kv_data, state_dict
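Finally, a sketch of a safetensors round trip using the helpers above; the file name is arbitrary and the model is a placeholder.

```python
from tinygrad import Tensor, nn

model = nn.Linear(4, 2)
sd = nn.state.get_state_dict(model)            # {'weight': Tensor, 'bias': Tensor}
nn.state.safe_save(sd, "model.safetensor")     # header + raw bytes written via a disk tensor
loaded = nn.state.safe_load("model.safetensor")
nn.state.load_state_dict(model, loaded)        # replaces each parameter in place
```

The flow is the same for GGUF or torch checkpoints, with `gguf_load` or `torch_load` producing the state_dict instead of `safe_load`.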