Release 260111
0  tinygrad/runtime/graph/__init__.py  Normal file
75  tinygrad/runtime/graph/cuda.py  Normal file
@@ -0,0 +1,75 @@
import ctypes
from typing import Any, cast
import tinygrad.runtime.autogen.cuda as cuda
from tinygrad.helpers import init_c_var, dedup
from tinygrad.device import Buffer, Device
from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
from tinygrad.engine.jit import MultiGraphRunner, GraphException

class CUDAGraph(MultiGraphRunner):
  def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
    super().__init__(jit_cache, input_rawbuffers, var_vals)

    # Check all jit items are compatible.
    if not all(isinstance(ji.prg, (CompiledRunner, BufferXfer)) for ji in jit_cache): raise GraphException

    self.jc_idx_with_updatable_rawbufs = dedup([x[0] for x in self.input_replace.keys()])
    self.updatable_nodes: dict[int, tuple[Any, Any, Any, bool]] = {} # dict[jc index] = tuple(graph node, node params, input kernel params, is memcpy)

    self.graph = init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))

    for j,ji in enumerate(jit_cache):
      if isinstance(ji.prg, CompiledRunner):
        global_size, local_size = ji.prg.p.launch_dims(var_vals)

        new_node = cuda.CUgraphNode()
        deps = self._access_resources([x.base for x in ji.bufs if x is not None], ji.prg.p.outs, new_dependency=new_node)
        c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None

        c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals.get(x.expr, ji.fixedvars.get(x.expr)) for x in ji.prg.p.vars])
        kern_params = cuda.CUDA_KERNEL_NODE_PARAMS_v1(ji.prg._prg.prg, *global_size, *local_size, 0, None, vargs)
        check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))

        if j in self.launch_dims_replace or j in self.var_vals_replace or j in self.jc_idx_with_updatable_rawbufs:
          self.updatable_nodes[j] = (new_node, kern_params, c_args, False)
      elif isinstance(ji.prg, BufferXfer):
        dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
        src_dev = cast(CUDADevice, Device[src.device])
        node_from = cuda.CUgraphNode()
        deps = self._access_resources(rawbufs=[dest.base, src.base], write=[0], new_dependency=node_from)
        c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
        cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1,
                                          dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1,
                                          WidthInBytes=dest.nbytes, Height=1, Depth=1)
        check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_from), self.graph, c_deps, len(deps), ctypes.byref(cp_params), src_dev.context))
        if j in self.jc_idx_with_updatable_rawbufs: self.updatable_nodes[j] = (node_from, cp_params, src_dev.context, True)

    self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0)))

  def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
    # Update rawbuffers in the c_args struct.
    for (j,i),input_idx in self.input_replace.items():
      if not self.updatable_nodes[j][3]: setattr(self.updatable_nodes[j][2], f'f{i}', input_rawbuffers[input_idx]._buf)
      else:
        if i == 0: self.updatable_nodes[j][1].destDevice = input_rawbuffers[input_idx]._buf
        elif i == 1: self.updatable_nodes[j][1].srcDevice = input_rawbuffers[input_idx]._buf

    # Update var_vals in the c_args struct.
    for j, i, v in self.updated_vars(var_vals): setattr(self.updatable_nodes[j][2], f'v{i}', v)

    # Update launch dims in the kern_params struct.
    for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
      node = self.updatable_nodes[j][1]
      node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_dims, *global_dims # type: ignore[misc]

    # Update graph nodes with the updated structs.
    for node, c_node_params, c_args, is_copy in self.updatable_nodes.values():
      if not is_copy: check(cuda.cuGraphExecKernelNodeSetParams(self.instance, node, ctypes.byref(c_node_params)))
      else: check(cuda.cuGraphExecMemcpyNodeSetParams(self.instance, node, ctypes.byref(c_node_params), c_args))

    return cu_time_execution(lambda: check(cuda.cuGraphLaunch(self.instance, None)), enable=wait)

  def __del__(self):
    if hasattr(self, 'graph'): check(cuda.cuGraphDestroy(self.graph))
    if hasattr(self, 'instance'): check(cuda.cuGraphExecDestroy(self.instance))
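As context for the class above: a graph runner is built once from the JIT's captured `jit_cache`, and every later call only patches buffer pointers (`input_replace`), variable values (`updated_vars`), and launch dims (`updated_launch_dims`) before relaunching the prebuilt CUgraph. A minimal toy sketch of that capture/replay contract (the `ToyGraphRunner` class is illustrative only, not part of this commit):

```python
# Toy stand-in (not tinygrad code) showing the shape of the graph-runner contract:
# capture work once in __init__, then each __call__ only patches inputs and relaunches.
class ToyGraphRunner:
  def __init__(self, jit_cache, input_rawbuffers, var_vals):
    self.jit_cache = jit_cache                                  # "captured" work, built into a graph once
  def __call__(self, input_rawbuffers, var_vals, wait=False):
    for fn in self.jit_cache: fn(input_rawbuffers, var_vals)    # real runners just patch params and launch
    return 0.0 if wait else None                                # CUDAGraph returns time via cu_time_execution

runner = ToyGraphRunner([lambda bufs, vv: print("kernel", vv)], [], {"i": 0})
runner([], {"i": 1}, wait=True)                                 # replay with updated var_vals
```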
245  tinygrad/runtime/graph/hcq.py  Normal file
@@ -0,0 +1,245 @@
import collections, time
from typing import Any, cast
from tinygrad.helpers import round_up, PROFILE, merge_dicts, getenv, dedup
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator, MMIOInterface
from tinygrad.device import Buffer, BufferSpec, Compiled, Device, ProfileGraphEntry, ProfileGraphEvent
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp, Variable
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner, BufferCopy
from tinygrad.engine.jit import MultiGraphRunner

class HCQGraph(MultiGraphRunner):
  def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
    super().__init__(jit_cache, input_rawbuffers, var_vals)
    self.devices = list(set(cast(HCQCompiled, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))

    # CPU Device is always last
    self.devices = sorted(self.devices, key=lambda x: 1 if x._is_cpu() else 0)

    # Replace input buffers with variables.
    self.hcq_bufs = [[cast(Buffer, x)._buf for x in ji.bufs] for ji in jit_cache]
    self.input_replace_to_var: dict[tuple[int, int], Variable] = {}

    for (j,i), input_idx in self.input_replace.items():
      x = self.input_replace_to_var.setdefault((j,i), UOp.variable(f"input_{input_idx}", 0, 0xffffffffffffffff, dtype=dtypes.uint64))
      self.hcq_bufs[j][i] = HCQBuffer(x, self.hcq_bufs[j][i].size, texture_info=self.hcq_bufs[j][i].texture_info) # Create fake buffer with variable

    # Allocate kernel args.
    kernargs_size: dict[Compiled, int] = collections.defaultdict(int)
    for ji in jit_cache:
      if not isinstance(ji.prg, CompiledRunner): continue
      kernargs_size[ji.prg.dev] += round_up(ji.prg._prg.kernargs_alloc_size, 16)
    self.kernargs_bufs: dict[Compiled, HCQBuffer] = {d:d.allocator._alloc(max(sz, 1), BufferSpec(cpu_access=True)) for d,sz in kernargs_size.items()}

    # Fill initial arguments.
    self.ji_args: dict[int, HCQArgsState] = {}

    kargs_alloc: dict[Compiled, BumpAllocator] = {dev:BumpAllocator(buf.size) for dev,buf in self.kernargs_bufs.items()}
    for j,ji in enumerate(jit_cache):
      if not isinstance(ji.prg, CompiledRunner): continue

      argsbuf = self.kernargs_bufs[ji.prg.dev].offset(kargs_alloc[ji.prg.dev].alloc(ji.prg._prg.kernargs_alloc_size, 16))
      self.ji_args[j] = ji.prg._prg.fill_kernargs(self.hcq_bufs[j], ji.prg.p.vars, argsbuf)

    # Schedule Dependencies.
    # There are two types of queues on each device: copy and compute. Both must synchronize with all external operations before launching any
    # graph-related tasks. This synchronization uses a global timeline signal per device. Within the graph, the compute queue coordinates with
    # global operations and sets a kickoff signal. Any queue accessing a buffer from another device waits for this signal from the device's
    # compute queue to ensure exclusive access. The compute queue signals the completion of the graph, synchronizing with the device's copy queue.
    self.ji_schedule: dict[int, tuple[HCQCompiled, HWQueue, list, list, HCQSignal, int|None]] = {}

    self.comp_queues: dict[HCQCompiled, HWQueue] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
    self.copy_queues: dict[HCQCompiled, HWQueue] = {} # lazy allocation

    self.signals: dict[Any, HCQSignal] = {**{dev: dev.new_signal(value=0) for dev in self.devices if not dev._is_cpu()},
      **{"KICK": self.devices[0].new_signal(value=0)}, **{dev: self.devices[0].new_signal(value=0) for dev in self.devices if dev._is_cpu()}}
    self.kickoff_value: int = 0
    self.kickoff_var = UOp.variable("kickoff_var", 0, 0xffffffff, dtype=dtypes.uint32)

    # When profiling, allocate 2 signals for each jit item to measure speed. The jth jit item has signals at 2*j and 2*j+1.
    # TODO: This logic might allocate a few extra signals...
    self.prof_signals: list[HCQSignal] = []
    self.prof_graph_deps: list[list[int]] = []
    self.prof_graph_entries: list[ProfileGraphEntry] = []

    last_j: dict[HWQueue, int|None] = collections.defaultdict(lambda: None)
    queue_access: dict[HWQueue, dict[HWQueue, int|None]] = collections.defaultdict(lambda: collections.defaultdict(lambda: None))
    dev_access: dict[HWQueue, set[HCQCompiled]] = collections.defaultdict(set)

    for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev)

    self.input_replace_map: dict[HCQCompiled, set[int]] = collections.defaultdict(set)
    self.fixedvars: dict[HCQCompiled, dict[str, int]] = {}

    for j,ji in enumerate(jit_cache):
      if is_exec_prg:=isinstance(ji.prg, CompiledRunner): enqueue_dev: HCQCompiled = ji.prg.dev
      else:
        # For copy ops prioritize enqueuing on the dest device, so reverse the buffers.
        for b in cast(list[Buffer], ji.bufs[::-1]):
          if (enqueue_dev:=cast(HCQCompiled, Device[b.device])).hw_copy_queue_t is not None: break

      # set any fixedvars on the device
      self.fixedvars[enqueue_dev] = merge_dicts([self.fixedvars.get(enqueue_dev, {}), ji.fixedvars])

      if is_exec_prg:
        enqueue_queue = self.comp_queues[enqueue_dev]
      else:
        assert (enqueue_dev.hw_copy_queue_t is not None), "device must implement a copy queue"
        enqueue_queue = self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())

      out_signal = self.signals.setdefault(enqueue_queue, self.devices[0].new_signal(value=0))

      # Get dependencies based on input and output buffers.
      rdeps = self._access_resources(ji.bufs, ji.prg.p.outs if is_exec_prg else [0], (enqueue_queue, j + 1)) #type:ignore

      # Update dependencies to include previous kernel in queue. This is required for timeline signals.
      opt_deps, deps = [], rdeps + ([(enqueue_queue, prev_ji + 1)] if (prev_ji:=last_j[enqueue_queue]) is not None else [])

      # Optimize dependencies by removing redundant ones. Remove waiting for the value of the queue which is known to be already
      # synced with the current queue.
      for dep_queue, dep_val in sorted(deps, key=lambda x: x[1], reverse=True):
        if (qa:=queue_access[enqueue_queue][dep_queue]) is None or qa < dep_val:
          opt_deps.append((self.signals[dep_queue], dep_val))
          queue_access[enqueue_queue][dep_queue] = dep_val
          dev_access[enqueue_queue].update(dev_access[dep_queue])

      # Ensure device is ready for use in current context: the graph has initialized the device and it's safe to operate on it within this graph.
      sync_signals = [(self.signals[d], self.kickoff_var) for b in ji.bufs if (d:=Device[cast(Buffer, b).device]) not in dev_access[enqueue_queue]]
      dev_access[enqueue_queue].update(cast(HCQCompiled, Device[cast(Buffer, b).device]) for b in ji.bufs)

      # Remove self-dependency for compute and copy queues.
      # For compute, in case of NV, optimize when only 1 same-queue dependency exists, since NV chains 2+ executions in this case,
      # eliminating dependency need.
      dname = enqueue_dev.device.split(":", 1)[0]
      can_opt = dname in {"AMD", "QCOM"} or (dname == "NV" and len(sync_signals) == 0 and len(opt_deps) == 1 and id(opt_deps[0][0]) == id(out_signal))
      if can_opt or isinstance(ji.prg, BufferXfer): opt_deps = [x for x in opt_deps if id(x[0]) != id(out_signal)]

      # Enable necessary signals in the schedule by setting the signal value.
      for sig, val in opt_deps: self.ji_schedule[val - 1] = self.ji_schedule[val - 1][:5] + (val,)
      self.ji_schedule[j] = (enqueue_dev, enqueue_queue, sync_signals, opt_deps[::-1], out_signal, None if is_exec_prg else (j + 1))

      # Collect profile information if profiling is enabled.
      if PROFILE:
        # When executions are chained, we can reuse the end timestamp from the previous command as the start timestamp for the current command.
        sig_st = prev_ji * 2 + 1 if len(opt_deps) == 0 and (prev_ji:=last_j[enqueue_queue]) is not None else j * 2

        # Description based on the command.
        prof_ji_desc = ji.prg._prg.name if is_exec_prg else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore

        self.prof_graph_entries.append(ProfileGraphEntry(enqueue_dev.device, prof_ji_desc, sig_st, j * 2 + 1, is_copy=not is_exec_prg))
        self.prof_graph_deps.append([d - 1 for _, d in rdeps])

      last_j[enqueue_queue] = j

    # Check which signals are used in the profile graph.
    self.prof_signal_is_used = [any(ent.st_id == j or ent.en_id == j for ent in self.prof_graph_entries) for j in range(len(jit_cache) * 2)]

    # Build hardware queues.
    self.copy_to_devs: dict[HCQCompiled, set[HCQCompiled]] = {dev: set() for dev in self.devices}

    # Create variable timeline signals for each device.
    timeline_sigaddrs = {dev: UOp.variable(f"timeline_sig_{self.dev_name(dev)}", 0, 0xffffffffffffffff, dtype=dtypes.uint64) for dev in self.devices}
    self.virt_timeline_vals = {dev: UOp.variable(f"timeline_var_{self.dev_name(dev)}", 0, 0xffffffff, dtype=dtypes.uint32) for dev in self.devices}
    self.virt_timeline_signals = {dev: dev.signal_t(HCQBuffer(timeline_sigaddrs[dev], 16), owner=dev, is_timeline=True) for dev in self.devices}

    for dev in self.devices:
      self.comp_queues[dev].memory_barrier().wait(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev]) \
                           .wait(self.signals['KICK'], self.kickoff_var).signal(self.signals[dev], self.kickoff_var)

    for j,ji in enumerate(jit_cache):
      enqueue_dev, enqueue_queue, sync_signals, deps, signal, signal_val = self.ji_schedule[j]

      # Lazy allocate signals
      if PROFILE: self.prof_signals += [enqueue_dev.new_signal(value=0) for _ in range(2)]

      for sig, val in sync_signals + deps: enqueue_queue.wait(sig, val)

      # Encode waits and start profile timestamp (if needed).
      if PROFILE and self.prof_signal_is_used[j * 2]: enqueue_queue.timestamp(self.prof_signals[j * 2])

      # Encode main commands based on ji type.
      if isinstance(ji.prg, CompiledRunner):
        enqueue_queue.exec(ji.prg._prg, self.ji_args[j], tuple(ji.prg.p.global_size or (1,1,1)), tuple(ji.prg.p.local_size or (1,1,1)))
      elif isinstance(ji.prg, (BufferXfer, BufferCopy)):
        dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
        for bufid, src in enumerate(cast(list[Buffer], ji.bufs)):
          if (inprep_idx:=self.input_replace.get((j, bufid))) is not None: self.input_replace_map[enqueue_dev].add(inprep_idx)
          else: cast(HCQAllocator, enqueue_dev.allocator).map(self.hcq_bufs[j][bufid])
        enqueue_queue.copy(self.hcq_bufs[j][0].va_addr, self.hcq_bufs[j][1].va_addr, dest.nbytes)
        self.copy_to_devs[cast(HCQCompiled, Device[dest.device])].add(cast(HCQCompiled, Device[src.device]))

      # Encode finish profile timestamp (if needed).
      if PROFILE and self.prof_signal_is_used[j * 2 + 1]: enqueue_queue.timestamp(self.prof_signals[j * 2 + 1])

      if signal_val is not None: enqueue_queue.signal(signal, signal_val)

    for dev in self.devices:
      for dep_dev in list(self.copy_to_devs[dev]) + [dev]:
        if dep_dev in self.copy_queues: self.comp_queues[dev].wait(self.signals[(copy_q:=self.copy_queues[dep_dev])], cast(int, last_j[copy_q]) + 1)

      self.comp_queues[dev].signal(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev] + 1).bind(dev)
      if dev in self.copy_queues: self.copy_queues[dev].bind(dev)

    self.last_timeline: dict[HCQCompiled, tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
    self.queue_signals_to_reset = [self.signals[q] for q in list(self.comp_queues.values()) + list(self.copy_queues.values()) if q in self.signals]

  def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
    # Wait and restore signals
    self.kickoff_value += 1
    for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
    for sig in self.queue_signals_to_reset: sig.value = 0
    self.signals['KICK'].value = self.kickoff_value

    for dev in self.devices:
      for idx_to_map in self.input_replace_map[dev]: cast(HCQAllocator, dev.allocator).map(input_rawbuffers[idx_to_map]._buf)

    if PROFILE and self.kickoff_value > 1: self.collect_timestamps()

    hcq_var_vals = {self.kickoff_var.expr: self.kickoff_value, **var_vals,
                    **{var.expr: dev.timeline_value - 1 for dev, var in self.virt_timeline_vals.items()},
                    **{sig.base_buf.va_addr.expr: dev.timeline_signal.base_buf.va_addr for dev, sig in self.virt_timeline_signals.items()}}

    # Update rawbuffers
    for (j,i),input_idx in self.input_replace.items():
      hcq_var_vals[self.input_replace_to_var[(j,i)].expr] = input_rawbuffers[input_idx]._buf.va_addr

    for dev in self.devices:
      self.comp_queues[dev].submit(dev, hcq_var_vals_local:=hcq_var_vals|self.fixedvars.get(dev, {}))
      if (copy_queue:=self.copy_queues.get(dev, None)) is not None: copy_queue.submit(dev, hcq_var_vals_local)

      self.last_timeline[dev] = (dev.timeline_signal, dev.next_timeline())

    if wait:
      st = time.perf_counter()
      for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
      return time.perf_counter() - st
    return None

  def collect_timestamps(self):
    # NOTE: Appending to any device is fine...
    self.devices[0].profile_events += [ProfileGraphEvent(self.prof_graph_entries, self.prof_graph_deps, [s.timestamp for s in self.prof_signals])]

  def dev_name(self, dev) -> str: return dev.device.replace(":", "_")

  def __del__(self):
    for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])

    if PROFILE and self.kickoff_value >= 1: self.collect_timestamps()

    for fdev, buf in self.kernargs_bufs.items(): fdev.allocator._free(buf, BufferSpec(cpu_access=True))

  @staticmethod
  def supports_exec_item(devs:list[Compiled], ei:ExecItem) -> bool:
    # Check if all devices are HCQ
    all_devs = cast(list[HCQCompiled], dedup(devs + [Device[b.device] for b in ei.bufs if b]))
    if not all(issubclass(type(d), HCQCompiled) for d in all_devs): return False

    # If all devices are mapped into the CPU address space, the CPU can be used inside the peer group.
    cpu_support = all(isinstance(d.timeline_signal.base_buf.view, MMIOInterface) for d in all_devs)

    # Check if all devices are within the same peer group. If CPU is supported, don't count it as a separate peer group.
    if len(set(d.peer_group for d in all_devs if cpu_support and not d._is_cpu())) > 1: return False

    # MOCKGPU is not supported, since it can't execute commands in parallel
    copy = (isinstance(ei.prg, BufferCopy) and cast(HCQCompiled, devs[0]).hw_copy_queue_t is not None) and not getenv("MOCKGPU")
    return isinstance(ei.prg, (CompiledRunner, BufferXfer)) or copy
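The dependency pruning inside `HCQGraph.__init__` (the `queue_access` loop) can be read in isolation: a wait on `(queue, value)` is dropped when the enqueueing queue is already known to have synced with that queue at an equal or newer value. A small pure-Python model of that rule (not tinygrad code; names are illustrative):

```python
from collections import defaultdict

# Tracks, per enqueueing queue, the highest value of every other queue it has already waited on.
queue_access: dict[str, dict[str, int | None]] = defaultdict(lambda: defaultdict(lambda: None))

def prune(enqueue_q: str, deps: list[tuple[str, int]]) -> list[tuple[str, int]]:
  opt = []
  for dep_q, dep_val in sorted(deps, key=lambda x: x[1], reverse=True):   # newest values first
    if (qa := queue_access[enqueue_q][dep_q]) is None or qa < dep_val:    # only keep waits that add new information
      opt.append((dep_q, dep_val))
      queue_access[enqueue_q][dep_q] = dep_val
  return opt

print(prune("comp0", [("copy0", 3), ("copy0", 1)]))  # [('copy0', 3)]; the wait for 1 is implied by 3
print(prune("comp0", [("copy0", 2)]))                # []; comp0 already synced with copy0 at value 3
```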
116  tinygrad/runtime/graph/metal.py  Normal file
@@ -0,0 +1,116 @@
from typing import Any, cast
import ctypes, re, decimal
from tinygrad.dtype import dtypes
from tinygrad.helpers import dedup, getenv, merge_dicts, PROFILE
from tinygrad.device import Buffer, ProfileGraphEntry, ProfileGraphEvent
from tinygrad.engine.realize import ExecItem, CompiledRunner
from tinygrad.engine.jit import GraphRunner, GraphException
from tinygrad.runtime.ops_metal import wait_check, msg, libobjc, to_struct, objc_instance,\
  MTLResourceOptions, cmdbuf_st_time, cmdbuf_en_time, objc_id, to_ns_str

class MTLIndirectCommandType:
  MTLIndirectCommandTypeConcurrentDispatch = (1 << 5)

class MTLResourceUsage:
  MTLResourceUsageRead = 0b01
  MTLResourceUsageWrite = 0b10

class MetalGraph(GraphRunner):
  def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
    super().__init__(jit_cache, input_rawbuffers, var_vals)
    if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException

    # create metal batch exec
    icb_descriptor = msg("new", objc_instance)(libobjc.objc_getClass(b"MTLIndirectCommandBufferDescriptor"))
    msg("setCommandTypes:")(icb_descriptor, MTLIndirectCommandType.MTLIndirectCommandTypeConcurrentDispatch)
    msg("setInheritBuffers:")(icb_descriptor, False)
    msg("setInheritPipelineState:")(icb_descriptor, False)
    msg("setMaxKernelBufferBindCount:")(icb_descriptor, 31)

    self.icb = msg("newIndirectCommandBufferWithDescriptor:maxCommandCount:options:", objc_instance)(self.dev.sysdevice,
      icb_descriptor, len(jit_cache), MTLResourceOptions.MTLResourceCPUCacheModeDefaultCache)
    if self.icb.value is None: raise GraphException("create indirect command buffer failed, does your system support this?")
    icb_label = bytes(msg("UTF8String", ctypes.c_char_p)(msg("description", objc_instance)(self.icb))).decode()
    self.needs_icb_fix = int((m := re.search(r'AGXG(\d+)XFamily', icb_label)) is None or int(m.group(1)) < 15) # not required on M3+

    self.fixedvars = merge_dicts([ji.fixedvars for ji in jit_cache])
    self.varlist = self.vars + list(self.fixedvars.keys())
    if len(self.varlist): self.int_buf = self.dev.allocator.alloc(len(self.varlist)*dtypes.int32.itemsize)

    all_pipelines, all_resources = [], [self.int_buf.buf] if len(self.varlist) else []
    for j,ji in enumerate(jit_cache):
      prg: CompiledRunner = cast(CompiledRunner, ji.prg)
      icb_command = msg("indirectComputeCommandAtIndex:", objc_instance)(self.icb, j)
      all_pipelines.append(prg._prg.pipeline_state)
      msg("setComputePipelineState:")(icb_command, prg._prg.pipeline_state)
      for i,b in enumerate(ji.bufs):
        if b is not None and b not in input_rawbuffers:
          msg("setKernelBuffer:offset:atIndex:")(icb_command, b._buf.buf, b._buf.offset, i)
          all_resources.append(b._buf.buf)
      for i,v in enumerate(prg.p.vars):
        msg("setKernelBuffer:offset:atIndex:")(icb_command, self.int_buf.buf, self.varlist.index(v.expr)*4, len(ji.bufs)+i)

      global_size, local_size = prg.p.launch_dims(var_vals)
      msg("concurrentDispatchThreadgroups:threadsPerThreadgroup:")(icb_command, to_struct(*global_size), to_struct(*local_size))
      msg("setBarrier")(icb_command)

    self.all_resources = dedup(all_resources)
    self.all_pipelines = dedup(all_pipelines)
    self.command_buffer: Any = None
    if len(self.varlist): self.int_buf_view = self.dev.allocator._as_buffer(self.int_buf).cast('i')
    for var in self.fixedvars: self.int_buf_view[self.varlist.index(var)] = self.fixedvars[var]
    self.range = to_struct(0, len(jit_cache))

  def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
    if self.command_buffer is not None and self.command_buffer in self.dev.mtl_buffers_in_flight: wait_check(self.command_buffer)
    # NOTE: old command buffer may not be inflight anymore
    if self.command_buffer is not None and PROFILE: self.collect_timestamps()

    all_resources = dedup(self.all_resources + [input_rawbuffers[input_idx]._buf.buf for input_idx in self.input_replace.values()])
    for (j,i),input_idx in self.input_replace.items():
      computeCommand = msg("indirectComputeCommandAtIndex:", objc_id)(self.icb, j)
      msg("setKernelBuffer:offset:atIndex:")(computeCommand, input_rawbuffers[input_idx]._buf.buf, input_rawbuffers[input_idx]._buf.offset, i)

    for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
      computeCommand = msg("indirectComputeCommandAtIndex:", objc_id)(self.icb, j)
      msg("concurrentDispatchThreadgroups:threadsPerThreadgroup:")(computeCommand, to_struct(*global_dims), to_struct(*local_dims))
    for var in self.vars: self.int_buf_view[self.varlist.index(var)] = var_vals[var]

    command_buffer = msg("commandBuffer", objc_instance)(self.dev.mtl_queue)
    encoder = msg("computeCommandEncoder", objc_instance)(command_buffer)
    msg("useResources:count:usage:")(encoder, (objc_id * len(all_resources))(*all_resources), len(all_resources),
      MTLResourceUsage.MTLResourceUsageRead | MTLResourceUsage.MTLResourceUsageWrite)

    # NOTE: the pipelines likely need to be added to the used resources to fix the crash on M1/M2, but I haven't figured out how
    # this is an O(n) hack to get them used. what should work is:
    #encoder.useResources_count_usage_(self.all_pipelines, len(self.all_pipelines), Metal.MTLResourceUsageRead)
    # but it fails with "Invalid Resource (00000009:kIOGPUCommandBufferCallbackErrorInvalidResource)"
    # to repro the crash (which can also crash other running GPU apps), run with FIX_METAL_ICB=0
    if getenv("FIX_METAL_ICB", self.needs_icb_fix):
      for ps in self.all_pipelines:
        msg("setComputePipelineState:")(encoder, ps)
        msg("dispatchThreadgroups:threadsPerThreadgroup:")(encoder, to_struct(0,0,0), to_struct(0,0,0))

    msg("executeCommandsInBuffer:withRange:")(encoder, self.icb, self.range)
    msg("endEncoding")(encoder)
    msg("setLabel:")(command_buffer, to_ns_str(f"batched {len(self.jit_cache)}"))
    msg("commit")(command_buffer)
    self.command_buffer = command_buffer

    self.dev.mtl_buffers_in_flight.append(command_buffer)
    if wait:
      wait_check(command_buffer)
      return cmdbuf_en_time(command_buffer) - cmdbuf_st_time(command_buffer)
    return None

  def collect_timestamps(self):
    # create a graph event and evenly space each program
    st, en = decimal.Decimal(cmdbuf_st_time(self.command_buffer)) * 1000000, decimal.Decimal(cmdbuf_en_time(self.command_buffer)) * 1000000
    ents = [ProfileGraphEntry(self.device, cast(CompiledRunner, ji.prg)._prg.name, i, i+1, is_copy=False) for i,ji in enumerate(self.jit_cache)]
    step = (en-st)/len(ents)
    self.dev.profile_events += [ProfileGraphEvent(ents, [], [st+step*i for i in range(len(ents)+1)])]

  def __del__(self):
    if PROFILE and self.command_buffer is not None:
      wait_check(self.command_buffer)
      self.collect_timestamps()
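One detail worth calling out from `MetalGraph` above: all symbolic variables share a single int32 buffer (`int_buf`), each indirect command binds that buffer at byte offset `varlist.index(var)*4`, so updating a variable between launches is a single host-side store into `int_buf_view` with no re-encoding. A self-contained illustration of that layout (pure Python, not Metal or tinygrad code; the variable names are hypothetical):

```python
import array

# varlist order is fixed when the graph is built; a kernel reads its value at byte offset index*4.
varlist = ["start_pos", "seq_len"]                       # hypothetical variable names
int_buf_view = memoryview(array.array('i', [0] * len(varlist)))

def set_var(name: str, value: int) -> None:
  int_buf_view[varlist.index(name)] = value              # one store, command buffer stays untouched

set_var("start_pos", 17)
print(list(int_buf_view))                                # [17, 0]
```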
113  tinygrad/runtime/graph/remote.py  Normal file
@@ -0,0 +1,113 @@
import time, itertools
from tinygrad.engine.jit import MultiGraphRunner
from tinygrad.engine.realize import CompiledRunner, BufferXfer, ExecItem
from tinygrad.device import Device, Compiled, Buffer
from tinygrad.runtime.ops_remote import RemoteDevice, RemoteConnection, RemoteRequest, GraphComputeItem, Transfer, GraphAlloc, GraphFree, GraphExec
from tinygrad.runtime.ops_remote import BatchTransfer, Event, Wait
from tinygrad.helpers import unwrap, flatten, dedup
from enum import Enum, auto
from dataclasses import replace
from collections import defaultdict
from typing import cast

class StagingType(Enum): NONE = auto(); GRAPH = auto(); TRANSFER = auto() # noqa: E702

def rd(dev:Compiled) -> RemoteDevice: return cast(RemoteDevice, dev)
def dev_key(dev:RemoteDevice): return dev.conn if dev.properties.graph_supports_multi else dev
def map_rawbuf(rawbuf:Buffer): return (cast(RemoteDevice, Device[rawbuf.device]).session, rawbuf._buf)

class RemoteGraph(MultiGraphRunner):
  def __init__(self, jit_cache: list[ExecItem], rawbufs: list[Buffer], var_vals: dict[str, int]):
    super().__init__(jit_cache, rawbufs, var_vals)
    devices = dedup(flatten([[Device[unwrap(buf).device] for buf in ji.bufs] for ji in jit_cache]))
    c2d = {device.conn: device for device in devices}
    self.handle_indexes = {map_rawbuf(rawbufs[i]): i for i in sorted(dedup(self.input_replace.values()))}

    self.template: list[RemoteRequest] = []

    stagings: dict[RemoteDevice|RemoteConnection, list[GraphComputeItem|Transfer]] = defaultdict(list)
    clobbered_buffers: set[Buffer] = set()
    cur_staging_type: StagingType = StagingType.NONE

    def _flush(new_staging_type:StagingType, force_break:bool=False):
      nonlocal cur_staging_type
      if cur_staging_type == new_staging_type and not force_break: return
      # Pre-sync
      if cur_staging_type == StagingType.TRANSFER:
        for sdev,ddev in itertools.permutations(c2d.values(), 2):
          self.template.append(Event(ddev.session, event:=next(ddev.event_num), session=sdev.session))
          self.template.append(Wait(event, session=ddev.session))
      # Flush
      for dev in devices:
        dk = dev_key(dev)
        staging = stagings[dk]
        if not staging: continue
        match cur_staging_type:
          case StagingType.GRAPH:
            bufs = tuple(map_rawbuf(rawbufs[i]) for i in sorted(dedup(self.input_replace.values())) if dev_key(rd(Device[rawbufs[i].device])) == dk)
            dev.q(GraphAlloc(graph_num:=next(dev.graph_num), tuple(staging), tuple(bufs), var_vals))
            self.template.append(GraphExec(graph_num, bufs, var_vals, wait=False, session=dev.session))
          case StagingType.TRANSFER:
            st = cast(list[Transfer], staging)
            for host in dedup(t.dsession.host for t in st):
              sbuffer_nums = [(unwrap(t.session), t.buffer_num) for t in st if t.dsession.host == host]
              dbuffer_nums = [(t.dsession, t.dbuffer_num) for t in st if t.dsession.host == host]
              self.template.append(BatchTransfer(sbuffer_nums, dbuffer_nums, session=dev.session))
        staging.clear()
      # Post-sync
      if cur_staging_type == StagingType.TRANSFER:
        for sdev,ddev in itertools.permutations(c2d.values(), 2):
          self.template.append(Event(ddev.session, event:=next(ddev.event_num), session=sdev.session))
          self.template.append(Wait(event, session=ddev.session))
      cur_staging_type = new_staging_type
      clobbered_buffers.clear()

    for ji in jit_cache:
      match ji.prg:
        case CompiledRunner():
          _flush(StagingType.GRAPH)
          gi = GraphComputeItem(ji.prg.dev.session, ji.prg._prg.name, ji.prg._prg.datahash, tuple(unwrap(buf)._buf for buf in ji.bufs),
                                tuple(ji.prg.p.vars), ji.fixedvars, tuple(ji.prg.p.ins), tuple(ji.prg.p.outs),
                                tuple(ji.prg.p.global_size) if ji.prg.p.global_size is not None else None,
                                tuple(ji.prg.p.local_size) if ji.prg.p.local_size is not None else None)
          stagings[dev_key(ji.prg.dev)].append(gi)
        case BufferXfer():
          dest, src = ji.bufs[0:2]
          dest_dev, src_dev = cast(RemoteDevice, Device[unwrap(dest).device]), cast(RemoteDevice, Device[unwrap(src).device])
          assert dest is not None and src is not None, ji
          ti = Transfer(session=src_dev.session, buffer_num=src._buf, dsession=dest_dev.session, dbuffer_num=dest._buf)
          if dev_key(dest_dev) == dev_key(src_dev):
            _flush(StagingType.GRAPH)
            stagings[dev_key(src_dev)].append(ti)
          elif dest_dev.conn == src_dev.conn:
            _flush(StagingType.NONE)
            self.template.append(ti)
          else:
            _flush(StagingType.TRANSFER, force_break=src in clobbered_buffers)
            clobbered_buffers.add(dest)
            stagings[dev_key(src_dev)].append(ti)
        case _: raise NotImplementedError(ji.prg)
    _flush(StagingType.NONE)

  def __del__(self):
    for req in self.template:
      match req:
        case GraphExec(): RemoteConnection(unwrap(req.session).host).q(GraphFree(req.graph_num, session=req.session))

  def __call__(self, rawbufs: list[Buffer], var_vals: dict[str, int], wait=False):
    if wait: st = time.perf_counter()
    rmap = {orig: map_rawbuf(rawbufs[replace_idx]) for orig,replace_idx in self.handle_indexes.items()}
    for req in self.template:
      match req:
        case GraphExec():
          req = replace(req, bufs=tuple(rmap[buf] for buf in req.bufs), var_vals=var_vals, wait=wait)
        case Transfer():
          if (req.session, req.buffer_num) in rmap: req = replace(req, buffer_num=rmap[(req.session, req.buffer_num)][1])
          if (req.dsession, req.dbuffer_num) in rmap: req = replace(req, dbuffer_num=rmap[(req.dsession, req.dbuffer_num)][1])
        case BatchTransfer():
          req = replace(req, sbuffer_nums=[rmap.get(b, b) for b in req.sbuffer_nums], dbuffer_nums=[rmap.get(b, b) for b in req.dbuffer_nums])
        case Event()|Wait():
          pass # event number can be reused
        case _: raise NotImplementedError(req)
      RemoteConnection(unwrap(req.session).host).q(req)
    if wait:
      RemoteConnection(unwrap(req.session).host).batch_submit()
      return time.perf_counter() - st
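The `_flush` staging machinery in `RemoteGraph.__init__` above batches consecutive work of the same kind (graph items vs. cross-host transfers) and emits a flush whenever the kind changes, or when a clobbered source buffer forces a break. A stripped-down model of that batching behavior (illustrative only, not tinygrad code):

```python
from enum import Enum, auto

class Kind(Enum): NONE = auto(); GRAPH = auto(); TRANSFER = auto()

def plan(kinds: list[Kind]) -> list[str]:
  out: list[str] = []
  cur = Kind.NONE
  for k in kinds + [Kind.NONE]:          # the real code ends with a final _flush(StagingType.NONE)
    if k != cur:                         # staging kind changed -> flush whatever was accumulated
      if cur != Kind.NONE: out.append(f"flush {cur.name}")
      cur = k
  return out

print(plan([Kind.GRAPH, Kind.GRAPH, Kind.TRANSFER, Kind.GRAPH]))
# ['flush GRAPH', 'flush TRANSFER', 'flush GRAPH']
```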