Release 260111
351 tinygrad/nn/__init__.py Normal file
@@ -0,0 +1,351 @@
from __future__ import annotations
import math
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import prod, make_tuple, flatten
from tinygrad.nn import optim, state, datasets # noqa: F401

class BatchNorm:
  """
  Applies Batch Normalization over a 2D or 3D input.

  - Paper: https://arxiv.org/abs/1502.03167v3

  See: `Tensor.batchnorm`

  ```python exec="true" session="tensor"
  from tinygrad import Tensor, dtypes, nn
  import numpy as np
  np.set_printoptions(precision=4)
  ```

  ```python exec="true" source="above" session="tensor" result="python"
  norm = nn.BatchNorm(3)
  t = Tensor.rand(2, 3, 4, 4)
  print(t.mean().item(), t.std().item())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = norm(t)
  print(t.mean().item(), t.std().item())
  ```
  """
  def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
    self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum

    self.weight: Tensor|None = Tensor.ones(sz) if affine else None
    self.bias: Tensor|None = Tensor.zeros(sz) if affine else None

    self.num_batches_tracked = Tensor.zeros(1, dtype='long' if is_dtype_supported(dtypes.long) else 'int', requires_grad=False)
    if track_running_stats: self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False)

  def calc_stats(self, x:Tensor) -> tuple[Tensor, Tensor]:
    shape_mask: list[int] = [1, -1, *([1]*(x.ndim-2))]
    if self.track_running_stats and not Tensor.training: return self.running_mean, self.running_var.reshape(shape=shape_mask).expand(x.shape)
    # This requires two full memory accesses to x
    # https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
    # There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
    batch_mean = x.mean(axis=(reduce_axes:=tuple(x for x in range(x.ndim) if x != 1)))
    y = (x - batch_mean.detach().reshape(shape=shape_mask)) # d(var)/d(mean) = 0
    batch_var = (y*y).mean(axis=reduce_axes)
    return batch_mean, batch_var

  def __call__(self, x:Tensor) -> Tensor:
    batch_mean, batch_var = self.calc_stats(x)
    # NOTE: wow, this is done all throughout training in most PyTorch models
    if self.track_running_stats and Tensor.training:
      self.running_mean.assign((1-self.momentum) * self.running_mean + self.momentum * batch_mean.detach())
      self.running_var.assign((1-self.momentum) * self.running_var + self.momentum * x.numel()/(x.numel()-x.shape[1]) * batch_var.detach())
      self.num_batches_tracked += 1
    return x.batchnorm(self.weight, self.bias, batch_mean, batch_var.add(self.eps).rsqrt())
BatchNorm2d = BatchNorm3d = BatchNorm

def Conv1d(in_channels:int, out_channels:int, kernel_size:int, stride=1, padding:int|str=0, dilation=1, groups=1, bias=True) -> Conv2d:
  """
  Applies a 1D convolution over an input signal composed of several input planes.

  See: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d

  ```python exec="true" source="above" session="tensor" result="python"
  conv = nn.Conv1d(1, 1, 3)
  t = Tensor.rand(1, 1, 4)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = conv(t)
  print(t.numpy())
  ```
  """
  return Conv2d(in_channels, out_channels, (kernel_size,), stride, padding, dilation, groups, bias)

class Conv2d:
  """
  Applies a 2D convolution over an input signal composed of several input planes.

  See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d

  ```python exec="true" source="above" session="tensor" result="python"
  conv = nn.Conv2d(1, 1, 3)
  t = Tensor.rand(1, 1, 4, 4)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = conv(t)
  print(t.numpy())
  ```
  """
  def __init__(self, in_channels:int, out_channels:int, kernel_size:int|tuple[int, ...], stride=1, padding:int|tuple[int, ...]|str=0,
               dilation=1, groups=1, bias=True):
    self.kernel_size = make_tuple(kernel_size, 2)
    if isinstance(padding, str):
      if padding.lower() != 'same': raise ValueError(f"Invalid padding string {padding!r}, only 'same' is supported")
      if stride != 1: raise ValueError("padding='same' is not supported for strided convolutions")
      pad = [(d*(k-1)//2, d*(k-1) - d*(k-1)//2) for d,k in zip(make_tuple(dilation, len(self.kernel_size)), self.kernel_size[::-1])]
      padding = tuple(flatten(pad))
    self.stride, self.dilation, self.groups, self.padding = stride, dilation, groups, padding
    scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
    self.weight = Tensor.uniform(out_channels, in_channels//groups, *self.kernel_size, low=-scale, high=scale)
    self.bias: Tensor|None = Tensor.uniform(out_channels, low=-scale, high=scale) if bias else None

  def __call__(self, x:Tensor) -> Tensor: return x.conv2d(self.weight, self.bias, self.groups, self.stride, self.dilation, self.padding)

def ConvTranspose1d(in_channels:int, out_channels:int, kernel_size:int, stride=1, padding=0, output_padding=0, dilation=1,
                    groups=1, bias=True) -> ConvTranspose2d:
  """
  Applies a 1D transposed convolution operator over an input signal composed of several input planes.

  See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d

  ```python exec="true" source="above" session="tensor" result="python"
  conv = nn.ConvTranspose1d(1, 1, 3)
  t = Tensor.rand(1, 1, 4)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = conv(t)
  print(t.numpy())
  ```
  """
  return ConvTranspose2d(in_channels, out_channels, (kernel_size,), stride, padding, output_padding, dilation, groups, bias)

class ConvTranspose2d(Conv2d):
  """
  Applies a 2D transposed convolution operator over an input image.

  See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d

  ```python exec="true" source="above" session="tensor" result="python"
  conv = nn.ConvTranspose2d(1, 1, 3)
  t = Tensor.rand(1, 1, 4, 4)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = conv(t)
  print(t.numpy())
  ```
  """
  def __init__(self, in_channels:int, out_channels:int, kernel_size:int|tuple[int, ...], stride=1, padding=0, output_padding=0,
               dilation=1, groups=1, bias=True):
    super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
    scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
    self.weight = Tensor.uniform(in_channels, out_channels//groups, *self.kernel_size, low=-scale, high=scale)
    self.output_padding = output_padding

  def __call__(self, x:Tensor) -> Tensor:
    return x.conv_transpose2d(self.weight, self.bias, self.groups, self.stride, self.dilation, self.padding, self.output_padding)

class Linear:
  """
  Applies a linear transformation to the incoming data.

  See: https://pytorch.org/docs/stable/generated/torch.nn.Linear

  ```python exec="true" source="above" session="tensor" result="python"
  lin = nn.Linear(3, 4)
  t = Tensor.rand(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = lin(t)
  print(t.numpy())
  ```
  """
  def __init__(self, in_features:int, out_features:int, bias=True):
    bound = 1 / math.sqrt(in_features)
    self.weight = Tensor.uniform(out_features, in_features, low=-bound, high=bound)
    self.bias = Tensor.uniform(out_features, low=-bound, high=bound) if bias else None

  def __call__(self, x:Tensor) -> Tensor: return x.linear(self.weight.transpose(), self.bias)

class GroupNorm:
  """
  Applies Group Normalization over a mini-batch of inputs.

  - Paper: https://arxiv.org/abs/1803.08494v3

  ```python exec="true" source="above" session="tensor" result="python"
  norm = nn.GroupNorm(2, 12)
  t = Tensor.rand(2, 12, 4, 4) * 2 + 1
  print(t.mean().item(), t.std().item())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = norm(t)
  print(t.mean().item(), t.std().item())
  ```
  """
  def __init__(self, num_groups:int, num_channels:int, eps=1e-5, affine=True):
    self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps
    self.weight: Tensor|None = Tensor.ones(num_channels) if affine else None
    self.bias: Tensor|None = Tensor.zeros(num_channels) if affine else None

  def __call__(self, x:Tensor) -> Tensor:
    # reshape for layernorm to work as group norm
    # subtract mean and divide stddev
    x = x.reshape(x.shape[0], self.num_groups, -1).layernorm(eps=self.eps).reshape(x.shape)

    if self.weight is None or self.bias is None: return x
    # elementwise_affine on channels
    return x * self.weight.reshape(1, -1, *[1] * (x.ndim-2)) + self.bias.reshape(1, -1, *[1] * (x.ndim-2))

class InstanceNorm:
  """
  Applies Instance Normalization over a mini-batch of inputs.

  - Paper: https://arxiv.org/abs/1607.08022v3

  ```python exec="true" source="above" session="tensor" result="python"
  norm = nn.InstanceNorm(3)
  t = Tensor.rand(2, 3, 4, 4) * 2 + 1
  print(t.mean().item(), t.std().item())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = norm(t)
  print(t.mean().item(), t.std().item())
  ```
  """
  def __init__(self, num_features:int, eps=1e-5, affine=True):
    self.num_features, self.eps = num_features, eps
    self.weight: Tensor|None = Tensor.ones(num_features) if affine else None
    self.bias: Tensor|None = Tensor.zeros(num_features) if affine else None

  def __call__(self, x:Tensor) -> Tensor:
    x = x.reshape(x.shape[0], self.num_features, -1).layernorm(eps=self.eps).reshape(x.shape)
    if self.weight is None or self.bias is None: return x
    return x * self.weight.reshape(1, -1, *[1] * (x.ndim-2)) + self.bias.reshape(1, -1, *[1] * (x.ndim-2))

class LayerNorm:
  """
  Applies Layer Normalization over a mini-batch of inputs.

  - Paper: https://arxiv.org/abs/1607.06450v1

  ```python exec="true" source="above" session="tensor" result="python"
  norm = nn.LayerNorm(3)
  t = Tensor.rand(2, 5, 3) * 2 + 1
  print(t.mean().item(), t.std().item())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = norm(t)
  print(t.mean().item(), t.std().item())
  ```
  """
  def __init__(self, normalized_shape:int|tuple[int, ...], eps=1e-5, elementwise_affine=True):
    self.normalized_shape: tuple[int, ...] = make_tuple(normalized_shape, 1)
    self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
    self.weight: Tensor|None = Tensor.ones(*self.normalized_shape) if elementwise_affine else None
    self.bias: Tensor|None = Tensor.zeros(*self.normalized_shape) if elementwise_affine else None

  def __call__(self, x:Tensor) -> Tensor:
    assert self.normalized_shape == x.shape[-len(self.normalized_shape):], f"last dimensions of {x.shape} must match {self.normalized_shape}"
    x = x.layernorm(eps=self.eps, axis=self.axis)
    if not self.elementwise_affine: return x
    return x * self.weight + self.bias

class LayerNorm2d(LayerNorm):
  """
  Applies Layer Normalization over a mini-batch of 2D inputs.

  See: `LayerNorm`

  ```python exec="true" source="above" session="tensor" result="python"
  norm = nn.LayerNorm2d(3)
  t = Tensor.rand(2, 3, 4, 4) * 2 + 1
  print(t.mean().item(), t.std().item())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = norm(t)
  print(t.mean().item(), t.std().item())
  ```
  """
  def __call__(self, x: Tensor) -> Tensor: return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

class RMSNorm:
  """
  Applies Root Mean Square Normalization to input.

  - Paper: https://arxiv.org/abs/1910.07467

  ```python exec="true" source="above" session="tensor" result="python"
  norm = nn.RMSNorm(4)
  t = Tensor.arange(12, dtype=dtypes.float).reshape(3, 4)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(norm(t).numpy())
  ```
  """
  def __init__(self, dim:int, eps=1e-6, elementwise_affine=True):
    self.eps = eps
    self.weight = Tensor.ones(dim) if elementwise_affine else None

  def _norm(self, x:Tensor) -> Tensor: return x * (x.square().mean(-1, keepdim=True) + self.eps).rsqrt()

  def __call__(self, x:Tensor) -> Tensor:
    x = self._norm(x.float()).cast(x.dtype)
    return x if self.weight is None else x * self.weight

class Embedding:
  """
  A simple lookup table that stores embeddings of a fixed dictionary and size.

  See: https://pytorch.org/docs/stable/generated/torch.nn.Embedding

  ```python exec="true" source="above" session="tensor" result="python"
  emb = nn.Embedding(10, 3)
  print(emb(Tensor([1, 2, 3, 1])).numpy())
  ```
  """
  def __init__(self, vocab_size:int, embed_size:int):
    self.vocab_sz, self.embed_sz, self.weight = vocab_size, embed_size, Tensor.glorot_uniform(vocab_size, embed_size)

  def __call__(self, idx:Tensor) -> Tensor:
    if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).unsqueeze(-1)
    if not dtypes.is_int(idx.dtype): raise TypeError(f"Expected integer dtype for index in embedding, got {idx.dtype}")
    big_shp = idx.shape+(self.vocab_sz, self.embed_sz)
    arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1)).expand(big_shp), self.weight.expand(big_shp)
    return (arange == idx).mul(vals).sum(-2, dtype=vals.dtype)

class LSTMCell:
  """
  A long short-term memory (LSTM) cell.

  Args:
    input_size: The number of expected features in the input `x`
    hidden_size: The number of features in the hidden state `h`
    bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`
  """
  def __init__(self, input_size:int, hidden_size:int, bias:bool=True):
    stdv = 1.0 / math.sqrt(hidden_size)
    self.weight_ih = Tensor.uniform(hidden_size*4, input_size, low=-stdv, high=stdv)
    self.weight_hh = Tensor.uniform(hidden_size*4, hidden_size, low=-stdv, high=stdv)
    self.bias_ih: Tensor|None = Tensor.zeros(hidden_size*4) if bias else None
    self.bias_hh: Tensor|None = Tensor.zeros(hidden_size*4) if bias else None

  def __call__(self, x:Tensor, hc:tuple[Tensor, Tensor]|None=None) -> tuple[Tensor, Tensor]:
    if hc is None: hc = (Tensor.zeros(x.size(0), self.weight_hh.size(1), dtype=x.dtype, device=x.device),)*2
    gates = x.linear(self.weight_ih.T, self.bias_ih) + hc[0].linear(self.weight_hh.T, self.bias_hh)
    i, f, g, o = gates.chunk(4, dim=1)
    i, f, g, o = i.sigmoid(), f.sigmoid(), g.tanh(), o.sigmoid()
    new_c = f * hc[1] + i * g
    new_h = o * new_c.tanh()
    return (new_h.contiguous(), new_c.contiguous())
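Unlike the layers above, `LSTMCell` ships no runnable docstring example. A minimal usage sketch (the sizes and sequence length below are illustrative, not taken from this commit):

```python
from tinygrad import Tensor, nn

# step one LSTM cell over a short sequence; passing hc=None starts from a zero state
cell = nn.LSTMCell(input_size=8, hidden_size=16)
seq = Tensor.rand(10, 2, 8)      # (seq_len, batch, input_size)
hc = None
for step in range(10):
  h, c = cell(seq[step], hc)     # each output is (batch, hidden_size)
  hc = (h, c)
print(h.shape, c.shape)
```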
14 tinygrad/nn/datasets.py Normal file
@@ -0,0 +1,14 @@
from tinygrad.tensor import Tensor
from tinygrad.nn.state import tar_extract

def mnist(device=None, fashion=False):
  base_url = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/" if fashion else "https://storage.googleapis.com/cvdf-datasets/mnist/"
  def _mnist(file): return Tensor.from_url(base_url+file, gunzip=True)
  return _mnist("train-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device), _mnist("train-labels-idx1-ubyte.gz")[8:].to(device), \
         _mnist("t10k-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device), _mnist("t10k-labels-idx1-ubyte.gz")[8:].to(device)

def cifar(device=None):
  tt = tar_extract(Tensor.from_url('https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz', gunzip=True))
  train = Tensor.cat(*[tt[f"cifar-10-batches-bin/data_batch_{i}.bin"].reshape(-1, 3073).to(device) for i in range(1,6)])
  test = tt["cifar-10-batches-bin/test_batch.bin"].reshape(-1, 3073).to(device)
  return train[:, 1:].reshape(-1,3,32,32), train[:, 0], test[:, 1:].reshape(-1,3,32,32), test[:, 0]
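For reference, a sketch of how these loaders are typically consumed; the cast-and-scale step is a common preprocessing choice, not something this file performs:

```python
from tinygrad import nn

# four tensors: train images/labels, then test images/labels
X_train, Y_train, X_test, Y_test = nn.datasets.mnist()
print(X_train.shape, Y_train.shape)   # images are reshaped to (-1, 1, 28, 28) above
# the images arrive as raw bytes; cast and scale before feeding a float model
X_train, X_test = X_train.float() / 255.0, X_test.float() / 255.0
```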
177 tinygrad/nn/optim.py Normal file
@@ -0,0 +1,177 @@
# sorted in order of increasing complexity
import itertools
from tinygrad.helpers import dedup, flatten, getenv, unwrap, FUSE_OPTIM
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes, least_upper_dtype

class Optimizer:
  """
  Base class for all optimizers.
  """
  def __init__(self, params: list[Tensor], lr: float, fused=FUSE_OPTIM):
    # if it's None, but being put into an optimizer, set it to True
    for x in params:
      if x.requires_grad is None: x.requires_grad = True

    self.params: list[Tensor] = dedup([x for x in params if x.requires_grad])
    assert len(self.params) != 0, "optimizer must have at least one param"
    self.device = self.params[0].device
    self.buffers: list[Tensor] = dedup([x for x in params if not x.requires_grad]) # buffers are still realized
    self.fused = fused
    # store lr in at least float32 precision
    self.lr = Tensor(lr if getenv("CONST_LR") else [lr], requires_grad=False, device=self.device,
                     dtype=least_upper_dtype(dtypes.default_float, dtypes.float32))
    if self.fused: self.pos_params = list(itertools.accumulate(self.params, lambda x,y: x+y.numel(), initial=0))

  def _new_optim_param(self) -> list[Tensor]:
    param_dtype = getenv("OPTIM_DTYPE", "float32")
    if self.fused: return [Tensor.zeros(self.pos_params[-1], dtype=param_dtype, device=self.device, requires_grad=False).contiguous()]
    return [Tensor.zeros(*t.shape, dtype=param_dtype, device=t.device, requires_grad=False).contiguous() for t in self.params]

  def zero_grad(self):
    """
    Zeroes the gradients of all the parameters.
    """
    for param in self.params: param.grad = None

  def step(self):
    """
    Performs a single optimization step.
    """
    Tensor.realize(*self.schedule_step())

  def schedule_step(self) -> list[Tensor]:
    """
    Returns the tensors that need to be realized to perform a single optimization step.
    """
    if not Tensor.training: raise RuntimeError(
      f"""Tensor.training={Tensor.training}, Tensor.training must be enabled to use the optimizer.
          - help: Consider setting Tensor.training=True before calling Optimizer.step().""")
    if self.fused:
      # optimizer fusion just concatenates all the buffers, runs the _step, then splits them back up
      out, extra = self._step([Tensor.cat(*[t.flatten() for t in self.params], dim=0)],
                              [Tensor.cat(*[unwrap(t.grad).flatten() for t in self.params], dim=0)])
      updated_params = [out[0][self.pos_params[i]:self.pos_params[i+1]].reshape(tt.shape) for i, tt in enumerate(self.params)]
    else:
      updated_params, extra = self._step(self.params, [unwrap(t.grad) for t in self.params])
    for i, tt in enumerate(self.params): tt.assign(updated_params[i])
    return extra+self.params+self.buffers

  def _step(self, params:list[Tensor], grads:list[Tensor]) -> tuple[list[Tensor], list[Tensor]]: raise NotImplementedError

class OptimizerGroup(Optimizer):
  """
  Combines multiple optimizers into one.
  """
  def __init__(self, *optimizers: Optimizer): # pylint: disable=super-init-not-called
    self.optimizers = optimizers
    self.params, self.buffers = flatten([o.params for o in self.optimizers]), flatten([o.buffers for o in self.optimizers])
  def __getitem__(self, i): return self.optimizers[i]
  def zero_grad(self): [o.zero_grad() for o in self.optimizers]
  def schedule_step(self) -> list[Tensor]: return [x for o in self.optimizers for x in o.schedule_step()]

# LARS is essentially just trust ratio to SGD so if we just set the trust coeff 0.0 it's just standard SGD.
def SGD(params: list[Tensor], lr=0.001, momentum=0.0, weight_decay=0.0, nesterov=False, classic=False, fused=FUSE_OPTIM):
  """
  Stochastic Gradient Descent (SGD) optimizer with optional momentum and weight decay.

  `classic` is a boolean flag that determines whether to use the popular momentum update rule or the classic momentum update rule.
  """
  return LARS(params, lr, momentum, weight_decay, 0, None, nesterov, classic=classic, pre_wd=True, tcoef=0.0, fused=fused)

# Muon applies the newton schulz algorithm on gradient. also can include momentum, nesterov, and weight decay
def Muon(params: list[Tensor], lr=0.02, momentum=0.95, weight_decay=0.0, ns_steps=5, ns_params=(3.4445, -4.775, 2.0315),
         nesterov=True, fused=FUSE_OPTIM):
  """
  SGD with newton-schulz iteration and post momentum weight decay.

  - Described: https://kellerjordan.github.io/posts/muon/
  - Paper: https://arxiv.org/pdf/2502.16982
  """
  assert not fused, "FUSE_OPTIM not allowed for Muon optimizer"
  return LARS(params, lr, momentum, weight_decay, ns_steps, ns_params, nesterov, classic=False, pre_wd=False, tcoef=0.0, fused=fused)

class LARS(Optimizer):
  """
  Layer-wise Adaptive Rate Scaling (LARS) optimizer with optional momentum and weight decay.

  - Paper: https://arxiv.org/abs/1708.03888v3
  """
  def __init__(self, params:list[Tensor], lr=0.001, momentum=0.9, weight_decay=1e-4, ns_steps=0, ns_params=None,
               nesterov=False, classic=True, pre_wd=True, tcoef=0.001, fused=FUSE_OPTIM):
    super().__init__(params, lr, fused)
    self.momentum, self.wd, self.ns_steps, self.ns_params = momentum, weight_decay, ns_steps, ns_params
    self.nesterov, self.classic, self.pre_wd, self.tcoef = nesterov, classic, pre_wd, tcoef
    self.b = self._new_optim_param() if self.momentum else []

  def _step(self, params:list[Tensor], grads:list[Tensor]) -> tuple[list[Tensor], list[Tensor]]:
    ret = []
    for i, (t, g) in enumerate(zip(params, grads)):
      if self.tcoef != 0:
        r1 = t.detach().square().sum().sqrt()
        r2 = g.square().sum().sqrt()
        r:Tensor|float = (r1 > 0).where((r2 > 0).where(self.tcoef * r1 / (r2 + self.wd * r1), 1.0), 1.0)
      else: r = 1.0
      if self.pre_wd and self.wd > 0: g = g + self.wd * t.detach()
      # classic momentum does post learning rate update
      if self.classic: g = g * r * self.lr
      if self.momentum:
        # TODO: this contiguous is required for correctness because self.b[i] becomes a non contiguous view
        # the scheduler should detect this and just insert contiguous
        self.b[i].assign(self.momentum * self.b[i].contiguous() + g) # NOTE: self.b[i] is zero on the first run, no if required
        g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
      if self.ns_params: g = g.reshape(g.shape[0], -1).newton_schulz(self.ns_steps, self.ns_params).reshape(g.shape)
      # muon does post momentum weight decay
      if not self.pre_wd and self.wd > 0: t = t.detach() * (1.0 - self.wd * self.lr)
      # popular momentum does pre learning rate update
      if not self.classic: g = g * r * self.lr
      ret.append((t.detach() - g).cast(t.dtype))
    return ret, self.b

# LAMB is essentially just the trust ratio part of LARS applied to Adam/W so if we just set the trust ratio to 1.0 it's just Adam/W.
def AdamW(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.01, fused=FUSE_OPTIM):
  """
  AdamW optimizer with optional weight decay.

  - Paper: https://arxiv.org/abs/1711.05101v3
  """
  return LAMB(params, lr, b1, b2, eps, weight_decay, adam=True, fused=fused)
def Adam(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, fused=FUSE_OPTIM):
  """
  Adam optimizer.

  - Paper: https://arxiv.org/abs/1412.6980
  """
  return LAMB(params, lr, b1, b2, eps, 0.0, adam=True, fused=fused)

class LAMB(Optimizer):
  """
  LAMB optimizer with optional weight decay.

  - Paper: https://arxiv.org/abs/1904.00962
  """
  def __init__(self, params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False, fused=FUSE_OPTIM):
    super().__init__(params, lr, fused)
    self.b1, self.b2, self.eps, self.wd, self.adam = b1, b2, eps, weight_decay, adam
    self.b1_t, self.b2_t = (Tensor.ones((1,), dtype=dtypes.float32, device=self.device, requires_grad=False).contiguous() for _ in [b1, b2])
    self.m = self._new_optim_param()
    self.v = self._new_optim_param()

  def _step(self, params:list[Tensor], grads:list[Tensor]) -> tuple[list[Tensor], list[Tensor]]:
    ret = []
    self.b1_t *= self.b1
    self.b2_t *= self.b2
    for i, (t, g) in enumerate(zip(params, grads)):
      self.m[i].assign((self.b1 * self.m[i] + (1.0 - self.b1) * g).cast(self.m[i].dtype))
      self.v[i].assign((self.b2 * self.v[i] + (1.0 - self.b2) * (g * g)).cast(self.v[i].dtype))
      m_hat = self.m[i] / (1.0 - self.b1_t)
      v_hat = self.v[i] / (1.0 - self.b2_t)
      up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach()
      if not self.adam:
        r1 = t.detach().square().sum().sqrt()
        r2 = up.square().sum().sqrt()
        r: Tensor|float = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0)
      else:
        r = 1.0
      ret.append((t.detach() - self.lr * r * up).cast(t.dtype))
    return ret, [self.b1_t, self.b2_t] + self.m + self.v
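The optimizers above carry no usage snippet of their own; a minimal training-step sketch (the model, shapes, and loss function are illustrative assumptions, not part of this commit):

```python
from tinygrad import Tensor, nn

model = nn.Linear(16, 4)
opt = nn.optim.Adam(nn.state.get_parameters(model), lr=1e-3)

x, y = Tensor.rand(32, 16), Tensor.randint(32, high=4)
with Tensor.train():   # schedule_step raises unless Tensor.training is enabled
  opt.zero_grad()
  loss = model(x).sparse_categorical_crossentropy(y)
  loss.backward()
  opt.step()           # realizes the scheduled parameter updates
print(loss.item())
```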
351 tinygrad/nn/state.py Normal file
@@ -0,0 +1,351 @@
import json, pathlib, zipfile, pickle, tarfile, struct, functools, io
from collections import OrderedDict
from typing import Any, Callable, BinaryIO, Iterable
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm, round_up, T
from tinygrad.shape.view import strides_for_shape

class TensorIO(io.RawIOBase, BinaryIO):
  def __init__(self, t: Tensor):
    if t.ndim != 1 or t.dtype != dtypes.uint8: raise ValueError("Tensor must be 1d and of dtype uint8!")
    self._position, self._tensor = 0, t

  def readable(self) -> bool: return True
  def read(self, size: int = -1) -> bytes:
    if (buf:=super().read(size)) is None: raise ValueError("io.RawIOBase.read returned None") # only happens if readinto returns None (never)
    return buf
  def readinto(self, buffer: Any) -> int:
    data = self._tensor[self._position:self._position+len(buffer)].data()
    buffer[:len(data)] = data
    self._position += len(data)
    return len(data)

  def seekable(self) -> bool: return True
  def seek(self, offset: int, whence: int = 0) -> int:
    self._position = min(len(self._tensor), max(0, [offset, self._position+offset, len(self._tensor)+offset][whence]))
    return self._position

  # required to correctly implement BinaryIO
  def __enter__(self): return self
  def write(self, s: Any): raise io.UnsupportedOperation("TensorIO.write not supported")
  def writelines(self, lines: Iterable[Any]): raise io.UnsupportedOperation("TensorIO.writelines not supported")

safe_dtypes = {"BOOL":dtypes.bool, "I8":dtypes.int8, "U8":dtypes.uint8, "I16":dtypes.int16, "U16":dtypes.uint16, "I32":dtypes.int, "U32":dtypes.uint,
               "I64":dtypes.int64, "U64":dtypes.uint64, "F16":dtypes.float16, "BF16":dtypes.bfloat16, "F32":dtypes.float32, "F64":dtypes.float64}
inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}

def accept_filename(func: Callable[[Tensor], T]) -> Callable[[Tensor|str|pathlib.Path], T]:
  @functools.wraps(func)
  def wrapper(fn: Tensor|str|pathlib.Path) -> T: return func(Tensor(pathlib.Path(fn)) if not isinstance(fn, Tensor) else fn)
  return wrapper

@accept_filename
def safe_load_metadata(t:Tensor) -> tuple[Tensor, int, dict[str, Any]]:
  """
  Loads a .safetensor file, returning the source tensor, data start position, and metadata.
  """
  data_start = int.from_bytes(t[0:8].data(), "little") + 8
  return t, data_start, json.loads(t[8:data_start].data().tobytes())

def safe_load(fn:Tensor|str|pathlib.Path) -> dict[str, Tensor]:
  """
  Loads a .safetensor file, returning the `state_dict`.

  ```python
  state_dict = nn.state.safe_load("test.safetensor")
  ```
  """
  t, data_start, metadata = safe_load_metadata(fn)
  data = t[data_start:]
  return { k: data[v['data_offsets'][0]:v['data_offsets'][1]].bitcast(safe_dtypes[v['dtype']]).reshape(v['shape'])
           for k, v in metadata.items() if k != "__metadata__" }

def safe_save(tensors:dict[str, Tensor], fn:str, metadata:dict[str, Any]|None=None):
  """
  Saves a `state_dict` to disk in a .safetensor file with optional metadata.

  ```python
  t = Tensor([1, 2, 3])
  nn.state.safe_save({'t':t}, "test.safetensor")
  ```
  """
  headers, offset = {}, 0
  if metadata: headers['__metadata__'] = metadata
  for k,v in tensors.items():
    headers[k] = {'dtype': inverse_safe_dtypes[v.dtype], 'shape': list(v.shape), 'data_offsets':[offset, offset+v.nbytes()]}
    offset += v.nbytes()
  j = json.dumps(headers, separators=(',', ':'))
  j += "\x20"*(round_up(len(j),8)-len(j))
  pathlib.Path(fn).unlink(missing_ok=True)
  t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
  t[0:8].bitcast(dtypes.int64).assign([len(j)])
  t[8:8+len(j)].assign(list(j.encode('utf-8')))
  for k,v in safe_load(t).items(): v.assign(tensors[k])

# state dict

def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> dict[str, Tensor]:
  """
  Returns a `state_dict` of the object, with optional prefix.

  ```python exec="true" source="above" session="tensor" result="python"
  class Net:
    def __init__(self):
      self.l1 = nn.Linear(4, 5)
      self.l2 = nn.Linear(5, 6)

  net = Net()
  print(nn.state.get_state_dict(net).keys())
  ```
  """
  if isinstance(obj, tensor_type): return {prefix.strip('.'):obj}
  if hasattr(obj, '_asdict'): return get_state_dict(obj._asdict(), prefix, tensor_type) # namedtuple
  if isinstance(obj, OrderedDict): return get_state_dict(dict(obj), prefix, tensor_type)
  if hasattr(obj, '__dict__'): return get_state_dict(obj.__dict__, prefix, tensor_type)
  state_dict = {}
  if isinstance(obj, (list, tuple)):
    for i,x in enumerate(obj): state_dict.update(get_state_dict(x, f"{prefix}{str(i)}.", tensor_type))
  elif isinstance(obj, dict):
    for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type))
  return state_dict

def get_parameters(obj) -> list[Tensor]:
  """
  ```python exec="true" source="above" session="tensor" result="python"
  class Net:
    def __init__(self):
      self.l1 = nn.Linear(4, 5)
      self.l2 = nn.Linear(5, 6)

  net = Net()
  print(len(nn.state.get_parameters(net)))
  ```
  """
  return list(get_state_dict(obj).values())

def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=True, consume=False, realize=True) -> list[Tensor]:
  """
  Loads a `state_dict` into a model. Return the loaded Tensors.

  ```python
  class Net:
    def __init__(self):
      self.l1 = nn.Linear(4, 5)
      self.l2 = nn.Linear(5, 6)

  net = Net()
  state_dict = nn.state.get_state_dict(net)
  nn.state.load_state_dict(net, state_dict)
  ```
  """
  start_mem_used = GlobalCounters.mem_used
  ret = []
  with Timing("loaded weights in ",
              lambda et_ns: f", {(B:=(GlobalCounters.mem_used-start_mem_used))/1e9:.2f} GB loaded at {B/et_ns:.2f} GB/s", enabled=verbose):
    model_state_dict = get_state_dict(model)
    if DEBUG >= 1 and len(state_dict) > len(model_state_dict):
      print("WARNING: unused weights in state_dict", sorted(list(state_dict.keys() - model_state_dict.keys())))
    for k,v in (t := tqdm(model_state_dict.items(), disable=CI or not verbose)):
      t.desc = f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, {k:50s}: "
      if k not in state_dict and not strict:
        if DEBUG >= 1: print(f"WARNING: not loading {k}")
        continue
      if v.shape != state_dict[k].shape:
        raise ValueError(f'Shape mismatch in layer `{k}`: Expected shape {v.shape}, but found {state_dict[k].shape} in state dict.')
      if isinstance(v.device, tuple):
        if isinstance(state_dict[k].device, tuple): v.replace(state_dict[k])
        else: v.replace(state_dict[k].shard(v.device, v.uop.axis))
      else: v.replace(state_dict[k].to(v.device))
      if realize: v.realize()
      if consume: del state_dict[k]
      ret.append(v)
  return ret

@accept_filename
def tar_extract(t: Tensor) -> dict[str, Tensor]:
  """
  ```python
  tar_extract(fn: Tensor | str | Path) -> dict[str, Tensor]
  ```

  Extracts files from a tar archive and returns them as a dictionary of names (keys) and tensors (values).

  ```python
  tensors = nn.state.tar_extract(Tensor(pathlib.Path("archive.tar")))
  ```
  """
  with tarfile.open(fileobj=TensorIO(t), mode="r") as tar:
    return {member.name:t[member.offset_data:member.offset_data+member.size] for member in tar if member.type == tarfile.REGTYPE}

# torch support!

@accept_filename
def torch_load(t:Tensor) -> dict[str, Tensor]:
  """
  ```python
  torch_load(fn: Tensor | str | Path) -> dict[str, Tensor]
  ```

  Loads a torch .pth file, returning the `state_dict`.

  ```python
  state_dict = nn.state.torch_load("test.pth")
  ```
  """
  offsets: dict[str|int, int] = {}
  lens: dict[str|int, int] = {}
  def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad=None, backward_hooks=None, metadata=None):
    #print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
    lens[storage[2]] = storage[4] * storage[1].itemsize
    if storage[2] not in offsets: return None
    byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
    ret = t[byte_offset:byte_offset+prod(size)*storage[1].itemsize].bitcast(storage[1])

    # 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
    shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
    permute_indexes = [len(shape_strides)-1-y for y in argsort([x[1] for x in shape_strides])]
    if tuple(permute_indexes) != tuple(range(len(permute_indexes))):
      intermediate_shape = tuple([shape_strides[x][0] for x in argsort(permute_indexes)])
      assert tuple([shape_strides[i][1] for i in argsort(permute_indexes)]) == strides_for_shape(intermediate_shape), "nonpermutable strides"
      if DEBUG >= 3: print(f"WARNING: this torch load is slow. to permute {intermediate_shape} with {permute_indexes}")
      assert storage[1] != dtypes.bfloat16, "can't permute BF16"
      # TODO: find a nice way to support all shapetracker on disktensors
      ret = ret.to(None).reshape(intermediate_shape).permute(permute_indexes)

    return ret.reshape(size)

  class Parameter:
    def __setstate__(self, state): self.tensor = state[0]

  deserialized_objects: dict[str, Any] = {}
  intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16,
               "IntStorage": dtypes.int32, "BoolStorage": dtypes.bool,
               "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2, "FloatTensor": None, "Parameter": Parameter}
  whitelist = {"torch", "collections", "numpy", "_codecs"} # NOTE: this is not for security, only speed
  class Dummy: pass
  class TorchPickle(pickle.Unpickler):
    def find_class(self, module, name):
      module_root = module.split(".")[0]
      if module_root not in whitelist:
        if DEBUG >= 2: print(f"WARNING: returning Dummy for {module} {name}")
        return Dummy
      return intercept[name] if module_root == "torch" else super().find_class(module, name)
    def persistent_load(self, pid): return deserialized_objects.get(pid, pid)

  fobj = io.BufferedReader(TensorIO(t))
  def passthrough_reset(v: bool): return fobj.seek(0, 0) or v

  if passthrough_reset(zipfile.is_zipfile(fobj)): # NOTE: passthrough_reset required to support python < 3.14
    myzip = zipfile.ZipFile(fobj, 'r')
    base_name = myzip.namelist()[0].split('/', 1)[0]
    for n in myzip.namelist():
      if n.startswith(f'{base_name}/data/'):
        with myzip.open(n) as myfile:
          offsets[n.split("/")[-1]] = myfile._orig_compress_start # type: ignore
    with myzip.open(f'{base_name}/data.pkl') as myfile:
      return TorchPickle(myfile).load()
  elif passthrough_reset(tarfile.is_tarfile(fobj)): # NOTE: passthrough_reset required to support python < 3.11
    with tarfile.open(fileobj=fobj, mode="r") as tar:
      storages_offset = tar.getmember('storages').offset_data
      f = unwrap(tar.extractfile('storages'))
      for i in range(TorchPickle(f).load()): # num_storages
        (key, _, storage_type), sz = TorchPickle(f).load(), struct.unpack('<q', f.read(8))[0]
        offsets[key] = storages_offset + f.tell()
        f.seek(sz*storage_type.itemsize, 1)
      f = unwrap(tar.extractfile('tensors'))
      for _ in range(TorchPickle(f).load()): # num_tensors
        (key, storage_id, _), ndim, _ = TorchPickle(f).load(), struct.unpack('<i', f.read(4))[0], f.read(4)
        size, stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim)), struct.unpack(f'<{ndim}q', f.read(8 * ndim))
        storage_offset = struct.unpack('<q', f.read(8))[0]
        deserialized_objects[str(key)] = _rebuild_tensor_v2((None, storage_type, storage_id, None, -1), storage_offset, size, stride)
      return {k:v.tensor if isinstance(v, Parameter) else v for k,v in TorchPickle(unwrap(tar.extractfile('pickle'))).load().items()}
  else:
    pkl = TorchPickle(fobj)
    _, _, _, rwd, _, ids, base_offset = pkl.load(), pkl.load(), pkl.load(), fobj.tell(), pkl.load(), pkl.load(), fobj.tell()
    for i in ids:
      offsets[i] = base_offset + 8
      base_offset += 8 + lens[i]
    fobj.seek(rwd)
    return TorchPickle(fobj).load()

def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
  """
  Converts ggml tensor data to a tinygrad tensor.

  Supported native types: float32 (id: 0), float16 (id: 1), int8 (id: 16), int16 (id: 17), int32 (id: 18)
  Supported quantized types: Q4_0 (id: 2), Q4_1 (id: 3), Q8_0 (id: 8), Q6_K (id: 14), MXFP4 (id: 39)
  """
  # https://github.com/ggerganov/ggml/blob/323951f1bdcdfbd5b5ff3a9a7c3770e63b1a560e/include/ggml.h#L356

  # native types
  if (dtype := { 0: dtypes.float32, 1: dtypes.float16, 16: dtypes.int8, 17: dtypes.int16, 18: dtypes.int32 }.get(ggml_type)) is not None:
    return t[:dtype.itemsize * n].bitcast(dtype)

  def q_to_uint8(t: Tensor, b: int) -> Tensor:
    # TODO: rewrite with arange?
    shift_tensor, bitmask = Tensor.stack(*[ Tensor(2**(i*b), device=t.device, dtype=t.dtype) for i in range(8//b) ]), 0xff >> (8 - b)
    return t.unsqueeze(-1).expand((*t.shape,8//b)).idiv(shift_tensor).bitwise_and(bitmask).transpose(-1, -2).flatten(-2)

  # map to (number of elements, number of bytes)
  if (nelements_nbytes := { 2: (32, 18), 3: (32, 20), 14: (256, 210), 8: (32, 34), 39: (32, 17) }.get(ggml_type)) is not None:
    blocks = t[:(n//nelements_nbytes[0])*nelements_nbytes[1]].reshape((-1, nelements_nbytes[1]))
    if ggml_type == 2: return (q_to_uint8(blocks[:,2:], 4).bitcast(dtypes.int8) - 8) * blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32)
    if ggml_type == 3:
      d, m = (blocks[:,s:s+2].bitcast(dtypes.float16).cast(dtypes.float32) for s in [ 0, 2 ])
      return q_to_uint8(blocks[:,4:], 4).bitcast(dtypes.int8) * d + m
    if ggml_type == 8: return blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32) * blocks[:,2:].bitcast(dtypes.int8)
    if ggml_type == 14:
      xl, xh = q_to_uint8(blocks[:,:128].reshape((-1, 2, 64)), 4), q_to_uint8(blocks[:,128:192].reshape((-1, 2, 32)), 2).lshift(4)
      scales = blocks[:,192:208].bitcast(dtypes.int8).unsqueeze(-1).expand((-1, 16, 16)).reshape((-1, 256))
      d = blocks[:,-2:].bitcast(dtypes.float16).cast(dtypes.float32).expand((-1, 256))
      return d * (xl.bitwise_or(xh).bitcast(dtypes.int8) - 32).flatten(-2) * scales
    if ggml_type == 39:
      e_int = blocks[:, 0].cast(dtypes.int32)
      d = ((e_int >= 2).cast(dtypes.float32) * (e_int.cast(dtypes.float32) - 128).exp2() +
           (e_int == 1).cast(dtypes.float32) * 2.0**(-127) +
           (e_int == 0).cast(dtypes.float32) * 2.0**(-128)).unsqueeze(-1)
      codes = q_to_uint8(blocks[:, 1:17], 4)
      sign = 1.0 - codes.rshift(3).cast(dtypes.float32) * 2.0
      exp, mant = codes.rshift(1).bitwise_and(0x3).cast(dtypes.float32), codes.bitwise_and(0x1).cast(dtypes.float32)
      fp4_val = sign * ((exp != 0).cast(dtypes.float32) * (1.0 + 0.5 * mant) * (exp - 1.0).exp2() +
                        (exp == 0).cast(dtypes.float32) * 0.5 * mant)
      return (fp4_val * d).flatten(-2)[:n]
  raise ValueError(f"GGML type '{ggml_type}' is not supported!")

@accept_filename
def gguf_load(tensor: Tensor) -> tuple[dict, dict[str, Tensor]]:
  """
  Loads a .gguf file, returning the `kv_data` and `state_dict`.

  ```python
  gguf_tensor = Tensor(pathlib.Path("Meta-Llama-3-8B-Instruct.Q4_0.gguf")).to(Device.DEFAULT)
  kv_data, state_dict = nn.state.gguf_load(gguf_tensor)
  ```

  NOTE: The provided tensor must be on a device that supports execution.
  """
  reader, kv_data, state_dict = io.BufferedReader(TensorIO(tensor), 1_000_000), {}, {}
  def read_unpack(fmt: str, n: int): return struct.unpack(fmt, reader.read(n))[0]
  def read_str(): return str(reader.read(read_uint64()), "utf-8")
  def read_arr():
    reader, n = readers[read_int32()], read_uint64()
    return [ reader() for _ in range(n) ]

  readers: dict[int, Callable[[], Any]] = { 8: read_str, 9: read_arr, **{ t: functools.partial(read_unpack, "<"+f, nb) for t,f,nb in \
    [ (0,"c",1), (1,"b",1), (2,"H",2), (3,"h",2), (4,"I",4), (5,"i",4), (6,"f",4), (7,"?",1), (10,"Q",8), (11,"q",8), (12,"d",8) ] } }
  read_uint32, read_int32, read_uint64, read_int64 = readers[4], readers[5], readers[10], readers[11]

  magic, version, n_tensors, n_kv = reader.read(4), read_int32(), read_int64(), read_int64()
  if magic != b"GGUF" or version not in [2, 3]: raise ValueError("Invalid GGUF format!")
  for _ in range(n_kv):
    k, typ = read_str(), read_int32()
    kv_data[k] = readers[typ]()

  t_infos = [ (read_str(), tuple(read_uint64() for _ in range(read_uint32())), read_int32(), read_uint64()) for _ in range(n_tensors) ]
  alignment, pos = kv_data.get("general.alignment", 32), reader.tell()
  data_start = round_up(pos, alignment)

  for name, dims, typ, off in t_infos: state_dict[name] = ggml_data_to_tensor(tensor[data_start + off:], prod(dims), typ).reshape(*reversed(dims))

  return kv_data, state_dict
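Each helper above documents itself in isolation; a sketch of how they compose into a checkpoint round trip (the file name and model here are placeholders):

```python
from tinygrad import nn

class Net:
  def __init__(self):
    self.l1 = nn.Linear(4, 5)
    self.l2 = nn.Linear(5, 6)

net = Net()
# collect every Tensor reachable from the object and write them to a safetensors file
nn.state.safe_save(nn.state.get_state_dict(net), "net.safetensors")
# load the file back and copy the weights into a freshly constructed model
nn.state.load_state_dict(Net(), nn.state.safe_load("net.safetensors"))
```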