Release 260111
This commit is contained in:
113
tinygrad/runtime/graph/remote.py
Normal file
113
tinygrad/runtime/graph/remote.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import time, itertools
|
||||
from tinygrad.engine.jit import MultiGraphRunner
|
||||
from tinygrad.engine.realize import CompiledRunner, BufferXfer, ExecItem
|
||||
from tinygrad.device import Device, Compiled, Buffer
|
||||
from tinygrad.runtime.ops_remote import RemoteDevice, RemoteConnection, RemoteRequest, GraphComputeItem, Transfer, GraphAlloc, GraphFree, GraphExec
|
||||
from tinygrad.runtime.ops_remote import BatchTransfer, Event, Wait
|
||||
from tinygrad.helpers import unwrap, flatten, dedup
|
||||
from enum import Enum, auto
|
||||
from dataclasses import replace
|
||||
from collections import defaultdict
|
||||
from typing import cast
|
||||
|
||||
class StagingType(Enum): NONE = auto(); GRAPH = auto(); TRANSFER = auto() # noqa: E702
|
||||
|
||||
def rd(dev:Compiled) -> RemoteDevice: return cast(RemoteDevice, dev)
|
||||
def dev_key(dev:RemoteDevice): return dev.conn if dev.properties.graph_supports_multi else dev
|
||||
def map_rawbuf(rawbuf:Buffer): return (cast(RemoteDevice, Device[rawbuf.device]).session, rawbuf._buf)
|
||||
|
||||
class RemoteGraph(MultiGraphRunner):
|
||||
def __init__(self, jit_cache: list[ExecItem], rawbufs: list[Buffer], var_vals: dict[str, int]):
|
||||
super().__init__(jit_cache, rawbufs, var_vals)
|
||||
devices = dedup(flatten([[Device[unwrap(buf).device] for buf in ji.bufs] for ji in jit_cache]))
|
||||
c2d = {device.conn: device for device in devices}
|
||||
self.handle_indexes = {map_rawbuf(rawbufs[i]): i for i in sorted(dedup(self.input_replace.values()))}
|
||||
|
||||
self.template: list[RemoteRequest] = []
|
||||
|
||||
stagings: dict[RemoteDevice|RemoteConnection, list[GraphComputeItem|Transfer]] = defaultdict(list)
|
||||
clobbered_buffers: set[Buffer] = set()
|
||||
cur_staging_type: StagingType = StagingType.NONE
|
||||
|
||||
def _flush(new_staging_type:StagingType, force_break:bool=False):
|
||||
nonlocal cur_staging_type
|
||||
if cur_staging_type == new_staging_type and not force_break: return
|
||||
# Pre-sync
|
||||
if cur_staging_type == StagingType.TRANSFER:
|
||||
for sdev,ddev in itertools.permutations(c2d.values(), 2):
|
||||
self.template.append(Event(ddev.session, event:=next(ddev.event_num), session=sdev.session))
|
||||
self.template.append(Wait(event, session=ddev.session))
|
||||
# Flush
|
||||
for dev in devices:
|
||||
dk = dev_key(dev)
|
||||
staging = stagings[dk]
|
||||
if not staging: continue
|
||||
match cur_staging_type:
|
||||
case StagingType.GRAPH:
|
||||
bufs = tuple(map_rawbuf(rawbufs[i]) for i in sorted(dedup(self.input_replace.values())) if dev_key(rd(Device[rawbufs[i].device])) == dk)
|
||||
dev.q(GraphAlloc(graph_num:=next(dev.graph_num), tuple(staging), tuple(bufs), var_vals))
|
||||
self.template.append(GraphExec(graph_num, bufs, var_vals, wait=False, session=dev.session))
|
||||
case StagingType.TRANSFER:
|
||||
st = cast(list[Transfer], staging)
|
||||
for host in dedup(t.dsession.host for t in st):
|
||||
sbuffer_nums = [(unwrap(t.session), t.buffer_num) for t in st if t.dsession.host == host]
|
||||
dbuffer_nums = [(t.dsession, t.dbuffer_num) for t in st if t.dsession.host == host]
|
||||
self.template.append(BatchTransfer(sbuffer_nums, dbuffer_nums, session=dev.session))
|
||||
staging.clear()
|
||||
# Post-sync
|
||||
if cur_staging_type == StagingType.TRANSFER:
|
||||
for sdev,ddev in itertools.permutations(c2d.values(), 2):
|
||||
self.template.append(Event(ddev.session, event:=next(ddev.event_num), session=sdev.session))
|
||||
self.template.append(Wait(event, session=ddev.session))
|
||||
cur_staging_type = new_staging_type
|
||||
clobbered_buffers.clear()
|
||||
|
||||
for ji in jit_cache:
|
||||
match ji.prg:
|
||||
case CompiledRunner():
|
||||
_flush(StagingType.GRAPH)
|
||||
gi = GraphComputeItem(ji.prg.dev.session, ji.prg._prg.name, ji.prg._prg.datahash, tuple(unwrap(buf)._buf for buf in ji.bufs),
|
||||
tuple(ji.prg.p.vars), ji.fixedvars, tuple(ji.prg.p.ins), tuple(ji.prg.p.outs),
|
||||
tuple(ji.prg.p.global_size) if ji.prg.p.global_size is not None else None,
|
||||
tuple(ji.prg.p.local_size) if ji.prg.p.local_size is not None else None)
|
||||
stagings[dev_key(ji.prg.dev)].append(gi)
|
||||
case BufferXfer():
|
||||
dest, src = ji.bufs[0:2]
|
||||
dest_dev, src_dev = cast(RemoteDevice, Device[unwrap(dest).device]), cast(RemoteDevice, Device[unwrap(src).device])
|
||||
assert dest is not None and src is not None, ji
|
||||
ti = Transfer(session=src_dev.session, buffer_num=src._buf, dsession=dest_dev.session, dbuffer_num=dest._buf)
|
||||
if dev_key(dest_dev) == dev_key(src_dev):
|
||||
_flush(StagingType.GRAPH)
|
||||
stagings[dev_key(src_dev)].append(ti)
|
||||
elif dest_dev.conn == src_dev.conn:
|
||||
_flush(StagingType.NONE)
|
||||
self.template.append(ti)
|
||||
else:
|
||||
_flush(StagingType.TRANSFER, force_break=src in clobbered_buffers)
|
||||
clobbered_buffers.add(dest)
|
||||
stagings[dev_key(src_dev)].append(ti)
|
||||
case _: raise NotImplementedError(ji.prg)
|
||||
_flush(StagingType.NONE)
|
||||
def __del__(self):
|
||||
for req in self.template:
|
||||
match req:
|
||||
case GraphExec(): RemoteConnection(unwrap(req.session).host).q(GraphFree(req.graph_num, session=req.session))
|
||||
def __call__(self, rawbufs: list[Buffer], var_vals: dict[str, int], wait=False):
|
||||
if wait: st = time.perf_counter()
|
||||
rmap = {orig: map_rawbuf(rawbufs[replace_idx]) for orig,replace_idx in self.handle_indexes.items()}
|
||||
for req in self.template:
|
||||
match req:
|
||||
case GraphExec():
|
||||
req = replace(req, bufs=tuple(rmap[buf] for buf in req.bufs), var_vals=var_vals, wait=wait)
|
||||
case Transfer():
|
||||
if (req.session, req.buffer_num) in rmap: req = replace(req, buffer_num=rmap[(req.session, req.buffer_num)][1])
|
||||
if (req.dsession, req.dbuffer_num) in rmap: req = replace(req, dbuffer_num=rmap[(req.dsession, req.dbuffer_num)][1])
|
||||
case BatchTransfer():
|
||||
req = replace(req, sbuffer_nums=[rmap.get(b, b) for b in req.sbuffer_nums], dbuffer_nums=[rmap.get(b, b) for b in req.dbuffer_nums])
|
||||
case Event()|Wait():
|
||||
pass # event number can be reused
|
||||
case _: raise NotImplementedError(req)
|
||||
RemoteConnection(unwrap(req.session).host).q(req)
|
||||
if wait:
|
||||
RemoteConnection(unwrap(req.session).host).batch_submit()
|
||||
return time.perf_counter() - st
|
||||
Reference in New Issue
Block a user