Release 260111

Comma Device
2026-01-11 18:23:29 +08:00
commit 3721ecbf8a
2601 changed files with 855070 additions and 0 deletions

tinygrad/runtime/graph/cuda.py

@@ -0,0 +1,75 @@
import ctypes
from typing import Any, cast
import tinygrad.runtime.autogen.cuda as cuda
from tinygrad.helpers import init_c_var, dedup
from tinygrad.device import Buffer, Device
from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
from tinygrad.engine.jit import MultiGraphRunner, GraphException
class CUDAGraph(MultiGraphRunner):
def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
super().__init__(jit_cache, input_rawbuffers, var_vals)
# Check all jit items are compatible.
if not all(isinstance(ji.prg, (CompiledRunner, BufferXfer)) for ji in jit_cache): raise GraphException
self.jc_idx_with_updatable_rawbufs = dedup([x[0] for x in self.input_replace.keys()])
self.updatable_nodes: dict[int, tuple[Any, Any, Any, bool]] = {} # dict[jc index] = tuple(graph node, node params, input kernel params, is memcpy)
self.graph = init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))
for j,ji in enumerate(jit_cache):
if isinstance(ji.prg, CompiledRunner):
global_size, local_size = ji.prg.p.launch_dims(var_vals)
new_node = cuda.CUgraphNode()
deps = self._access_resources([x.base for x in ji.bufs if x is not None], ji.prg.p.outs, new_dependency=new_node)
c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals.get(x.expr, ji.fixedvars.get(x.expr)) for x in ji.prg.p.vars])
kern_params = cuda.CUDA_KERNEL_NODE_PARAMS_v1(ji.prg._prg.prg, *global_size, *local_size, 0, None, vargs)
check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))
if j in self.launch_dims_replace or j in self.var_vals_replace or j in self.jc_idx_with_updatable_rawbufs:
self.updatable_nodes[j] = (new_node, kern_params, c_args, False)
elif isinstance(ji.prg, BufferXfer):
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
src_dev = cast(CUDADevice, Device[src.device])
node_from = cuda.CUgraphNode()
deps = self._access_resources(rawbufs=[dest.base, src.base], write=[0], new_dependency=node_from)
c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1,
dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1,
WidthInBytes=dest.nbytes, Height=1, Depth=1)
check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_from), self.graph, c_deps, len(deps), ctypes.byref(cp_params), src_dev.context))
if j in self.jc_idx_with_updatable_rawbufs: self.updatable_nodes[j] = (node_from, cp_params, src_dev.context, True)
self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0)))
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
# Update rawbuffers in the c_args struct.
for (j,i),input_idx in self.input_replace.items():
if not self.updatable_nodes[j][3]: setattr(self.updatable_nodes[j][2], f'f{i}', input_rawbuffers[input_idx]._buf)
else:
        if i == 0: self.updatable_nodes[j][1].dstDevice = input_rawbuffers[input_idx]._buf
elif i == 1: self.updatable_nodes[j][1].srcDevice = input_rawbuffers[input_idx]._buf
# Update var_vals in the c_args struct.
for j, i, v in self.updated_vars(var_vals): setattr(self.updatable_nodes[j][2], f'v{i}', v)
# Update launch dims in the kern_params struct.
for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
node = self.updatable_nodes[j][1]
node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_dims, *global_dims # type: ignore[misc]
# Update graph nodes with the updated structs.
for node, c_node_params, c_args, is_copy in self.updatable_nodes.values():
if not is_copy: check(cuda.cuGraphExecKernelNodeSetParams(self.instance, node, ctypes.byref(c_node_params)))
else: check(cuda.cuGraphExecMemcpyNodeSetParams(self.instance, node, ctypes.byref(c_node_params), c_args))
return cu_time_execution(lambda: check(cuda.cuGraphLaunch(self.instance, None)), enable=wait)
def __del__(self):
if hasattr(self, 'graph'): check(cuda.cuGraphDestroy(self.graph))
if hasattr(self, 'instance'): check(cuda.cuGraphExecDestroy(self.instance))
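
For context, a minimal sketch of how a graph runner like this one is driven. In tinygrad the TinyJit machinery constructs and calls graph runners internally; this only restates the constructor and __call__ signatures above, with jit_cache/input_rawbuffers/var_vals standing for values the JIT would supply:

# Illustrative only: normally the JIT does this, not user code.
graph = CUDAGraph(jit_cache, input_rawbuffers, var_vals)  # capture the kernels/copies into one CUDA graph
et = graph(input_rawbuffers, var_vals, wait=True)         # replay with updated buffers; returns elapsed seconds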

tinygrad/runtime/graph/hcq.py

@@ -0,0 +1,245 @@
import collections, time
from typing import Any, cast
from tinygrad.helpers import round_up, PROFILE, merge_dicts, getenv, dedup
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator, MMIOInterface
from tinygrad.device import Buffer, BufferSpec, Compiled, Device, ProfileGraphEntry, ProfileGraphEvent
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp, Variable
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner, BufferCopy
from tinygrad.engine.jit import MultiGraphRunner
class HCQGraph(MultiGraphRunner):
def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
super().__init__(jit_cache, input_rawbuffers, var_vals)
self.devices = list(set(cast(HCQCompiled, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))
# CPU Device is always last
self.devices = sorted(self.devices, key=lambda x: 1 if x._is_cpu() else 0)
# Replace input buffers with variables.
self.hcq_bufs = [[cast(Buffer, x)._buf for x in ji.bufs] for ji in jit_cache]
self.input_replace_to_var: dict[tuple[int, int], Variable] = {}
for (j,i), input_idx in self.input_replace.items():
x = self.input_replace_to_var.setdefault((j,i), UOp.variable(f"input_{input_idx}", 0, 0xffffffffffffffff, dtype=dtypes.uint64))
self.hcq_bufs[j][i] = HCQBuffer(x, self.hcq_bufs[j][i].size, texture_info=self.hcq_bufs[j][i].texture_info) # Create fake buffer with variable
# Allocate kernel args.
kernargs_size: dict[Compiled, int] = collections.defaultdict(int)
for ji in jit_cache:
if not isinstance(ji.prg, CompiledRunner): continue
kernargs_size[ji.prg.dev] += round_up(ji.prg._prg.kernargs_alloc_size, 16)
self.kernargs_bufs: dict[Compiled, HCQBuffer] = {d:d.allocator._alloc(max(sz, 1), BufferSpec(cpu_access=True)) for d,sz in kernargs_size.items()}
# Fill initial arguments.
self.ji_args: dict[int, HCQArgsState] = {}
kargs_alloc: dict[Compiled, BumpAllocator] = {dev:BumpAllocator(buf.size) for dev,buf in self.kernargs_bufs.items()}
for j,ji in enumerate(jit_cache):
if not isinstance(ji.prg, CompiledRunner): continue
argsbuf = self.kernargs_bufs[ji.prg.dev].offset(kargs_alloc[ji.prg.dev].alloc(ji.prg._prg.kernargs_alloc_size, 16))
self.ji_args[j] = ji.prg._prg.fill_kernargs(self.hcq_bufs[j], ji.prg.p.vars, argsbuf)
    # Schedule Dependencies.
    # There are two types of queues on each device: copy and compute. Both must synchronize with all external operations before launching any
    # graph-related tasks. This synchronization uses a global timeline signal per device. Within the graph, the compute queue coordinates with
    # global operations and sets a kickoff signal. Any queue accessing a buffer from another device waits for this signal from that device's
    # compute queue to ensure exclusive access. The compute queue signals the completion of the graph, synchronizing with the device's copy queue.
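    # Example with devices A and B: A's compute queue waits on A's timeline signal and on KICK, then signals self.signals[A] with the
    # kickoff value. If a copy queue on B touches a buffer owned by A, it first waits on self.signals[A], guaranteeing that A has been
    # kicked off and its buffers are safe to use inside this graph. (Illustrative restatement of the scheme described above.)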
self.ji_schedule: dict[int, tuple[HCQCompiled, HWQueue, list, list, HCQSignal, int|None]] = {}
self.comp_queues: dict[HCQCompiled, HWQueue] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
self.copy_queues: dict[HCQCompiled, HWQueue] = {} # lazy allocation
self.signals: dict[Any, HCQSignal] = {**{dev: dev.new_signal(value=0) for dev in self.devices if not dev._is_cpu()},
**{"KICK": self.devices[0].new_signal(value=0)}, **{dev: self.devices[0].new_signal(value=0) for dev in self.devices if dev._is_cpu()}}
self.kickoff_value: int = 0
self.kickoff_var = UOp.variable("kickoff_var", 0, 0xffffffff, dtype=dtypes.uint32)
    # When profiling, allocate 2 signals for each jit item to measure execution time. The jth jit item has signals at 2*j and 2*j+1.
    # TODO: This logic might allocate a few extra signals...
self.prof_signals: list[HCQSignal] = []
self.prof_graph_deps: list[list[int]] = []
self.prof_graph_entries: list[ProfileGraphEntry] = []
last_j: dict[HWQueue, int|None] = collections.defaultdict(lambda: None)
queue_access: dict[HWQueue, dict[HWQueue, int|None]] = collections.defaultdict(lambda: collections.defaultdict(lambda: None))
dev_access: dict[HWQueue, set[HCQCompiled]] = collections.defaultdict(set)
for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev)
self.input_replace_map: dict[HCQCompiled, set[int]] = collections.defaultdict(set)
self.fixedvars: dict[HCQCompiled, dict[str, int]] = {}
for j,ji in enumerate(jit_cache):
if is_exec_prg:=isinstance(ji.prg, CompiledRunner): enqueue_dev: HCQCompiled = ji.prg.dev
else:
        # For copy ops, prioritize enqueueing on the dest device, so iterate over the buffers in reverse.
for b in cast(list[Buffer], ji.bufs[::-1]):
if (enqueue_dev:=cast(HCQCompiled, Device[b.device])).hw_copy_queue_t is not None: break
# set any fixedvars on the device
self.fixedvars[enqueue_dev] = merge_dicts([self.fixedvars.get(enqueue_dev, {}), ji.fixedvars])
if is_exec_prg:
enqueue_queue = self.comp_queues[enqueue_dev]
else:
assert (enqueue_dev.hw_copy_queue_t is not None), "device must implement a copy queue"
enqueue_queue = self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())
out_signal = self.signals.setdefault(enqueue_queue, self.devices[0].new_signal(value=0))
# Get dependencies based on input and output buffers.
rdeps = self._access_resources(ji.bufs, ji.prg.p.outs if is_exec_prg else [0], (enqueue_queue, j + 1)) #type:ignore
# Update dependencies to include previous kernel in queue. This is required for timeline signals.
opt_deps, deps = [], rdeps + ([(enqueue_queue, prev_ji + 1)] if (prev_ji:=last_j[enqueue_queue]) is not None else [])
      # Optimize dependencies by removing redundant ones: drop waits on queue values that are already known to be
      # synced with the current queue.
for dep_queue, dep_val in sorted(deps, key=lambda x: x[1], reverse=True):
if (qa:=queue_access[enqueue_queue][dep_queue]) is None or qa < dep_val:
opt_deps.append((self.signals[dep_queue], dep_val))
queue_access[enqueue_queue][dep_queue] = dep_val
dev_access[enqueue_queue].update(dev_access[dep_queue])
# Ensure device is ready for use in current context: the graph has initialized the device and it's safe to operate on it within this graph.
sync_signals = [(self.signals[d], self.kickoff_var) for b in ji.bufs if (d:=Device[cast(Buffer, b).device]) not in dev_access[enqueue_queue]]
dev_access[enqueue_queue].update(cast(HCQCompiled, Device[cast(Buffer, b).device]) for b in ji.bufs)
      # Remove self-dependencies for the compute and copy queues.
      # For compute on NV, this is only done when exactly one same-queue dependency exists, since NV chains 2+ back-to-back executions
      # in that case, making the explicit dependency unnecessary.
dname = enqueue_dev.device.split(":", 1)[0]
can_opt = dname in {"AMD", "QCOM"} or (dname == "NV" and len(sync_signals) == 0 and len(opt_deps) == 1 and id(opt_deps[0][0]) == id(out_signal))
if can_opt or isinstance(ji.prg, BufferXfer): opt_deps = [x for x in opt_deps if id(x[0]) != id(out_signal)]
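      # e.g. on NV, a kernel that follows another on the same compute queue with no other waits has opt_deps == [(out_signal, prev_val)];
      # that wait can be dropped because the hardware already serializes back-to-back executions on a single queue. (Illustrative.)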
# Enable necessary signals in the schedule by setting the signal value.
for sig, val in opt_deps: self.ji_schedule[val - 1] = self.ji_schedule[val - 1][:5] + (val,)
self.ji_schedule[j] = (enqueue_dev, enqueue_queue, sync_signals, opt_deps[::-1], out_signal, None if is_exec_prg else (j + 1))
# Collect profile information if profiling is enabled.
if PROFILE:
        # When executions are chained, we can reuse the end timestamp from the previous command as the start timestamp for the current command.
sig_st = prev_ji * 2 + 1 if len(opt_deps) == 0 and (prev_ji:=last_j[enqueue_queue]) is not None else j * 2
# Description based on the command.
prof_ji_desc = ji.prg._prg.name if is_exec_prg else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore
self.prof_graph_entries.append(ProfileGraphEntry(enqueue_dev.device, prof_ji_desc, sig_st, j * 2 + 1, is_copy=not is_exec_prg))
self.prof_graph_deps.append([d - 1 for _, d in rdeps])
last_j[enqueue_queue] = j
# Check which signals are used in the profile graph.
self.prof_signal_is_used = [any(ent.st_id == j or ent.en_id == j for ent in self.prof_graph_entries) for j in range(len(jit_cache) * 2)]
# Build hardware queues.
self.copy_to_devs: dict[HCQCompiled, set[HCQCompiled]] = {dev: set() for dev in self.devices}
# Create variable timeline signals for each device.
timeline_sigaddrs = {dev: UOp.variable(f"timeline_sig_{self.dev_name(dev)}", 0, 0xffffffffffffffff, dtype=dtypes.uint64) for dev in self.devices}
self.virt_timeline_vals = {dev: UOp.variable(f"timeline_var_{self.dev_name(dev)}", 0, 0xffffffff, dtype=dtypes.uint32) for dev in self.devices}
self.virt_timeline_signals = {dev: dev.signal_t(HCQBuffer(timeline_sigaddrs[dev], 16), owner=dev, is_timeline=True) for dev in self.devices}
for dev in self.devices:
self.comp_queues[dev].memory_barrier().wait(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev]) \
.wait(self.signals['KICK'], self.kickoff_var).signal(self.signals[dev], self.kickoff_var)
for j,ji in enumerate(jit_cache):
enqueue_dev, enqueue_queue, sync_signals, deps, signal, signal_val = self.ji_schedule[j]
      # Lazily allocate profiling signals
if PROFILE: self.prof_signals += [enqueue_dev.new_signal(value=0) for _ in range(2)]
for sig, val in sync_signals + deps: enqueue_queue.wait(sig, val)
# Encode waits and start profile timestamp (if needed).
if PROFILE and self.prof_signal_is_used[j * 2]: enqueue_queue.timestamp(self.prof_signals[j * 2])
# Encode main commands based on ji type.
if isinstance(ji.prg, CompiledRunner):
enqueue_queue.exec(ji.prg._prg, self.ji_args[j], tuple(ji.prg.p.global_size or (1,1,1)), tuple(ji.prg.p.local_size or (1,1,1)))
elif isinstance(ji.prg, (BufferXfer, BufferCopy)):
        dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
        for bufid in range(len(ji.bufs)):
          if (inprep_idx:=self.input_replace.get((j, bufid))) is not None: self.input_replace_map[enqueue_dev].add(inprep_idx)
          else: cast(HCQAllocator, enqueue_dev.allocator).map(self.hcq_bufs[j][bufid])
enqueue_queue.copy(self.hcq_bufs[j][0].va_addr, self.hcq_bufs[j][1].va_addr, dest.nbytes)
self.copy_to_devs[cast(HCQCompiled, Device[dest.device])].add(cast(HCQCompiled, Device[src.device]))
# Encode finish profile timestamp (if needed).
if PROFILE and self.prof_signal_is_used[j * 2 + 1]: enqueue_queue.timestamp(self.prof_signals[j * 2 + 1])
if signal_val is not None: enqueue_queue.signal(signal, signal_val)
for dev in self.devices:
for dep_dev in list(self.copy_to_devs[dev]) + [dev]:
if dep_dev in self.copy_queues: self.comp_queues[dev].wait(self.signals[(copy_q:=self.copy_queues[dep_dev])], cast(int, last_j[copy_q]) + 1)
self.comp_queues[dev].signal(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev] + 1).bind(dev)
if dev in self.copy_queues: self.copy_queues[dev].bind(dev)
self.last_timeline: dict[HCQCompiled, tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
self.queue_signals_to_reset = [self.signals[q] for q in list(self.comp_queues.values()) + list(self.copy_queues.values()) if q in self.signals]
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
# Wait and restore signals
self.kickoff_value += 1
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
for sig in self.queue_signals_to_reset: sig.value = 0
self.signals['KICK'].value = self.kickoff_value
for dev in self.devices:
for idx_to_map in self.input_replace_map[dev]: cast(HCQAllocator, dev.allocator).map(input_rawbuffers[idx_to_map]._buf)
if PROFILE and self.kickoff_value > 1: self.collect_timestamps()
hcq_var_vals = {self.kickoff_var.expr: self.kickoff_value, **var_vals,
**{var.expr: dev.timeline_value - 1 for dev, var in self.virt_timeline_vals.items()},
**{sig.base_buf.va_addr.expr: dev.timeline_signal.base_buf.va_addr for dev, sig in self.virt_timeline_signals.items()}}
# Update rawbuffers
for (j,i),input_idx in self.input_replace.items():
hcq_var_vals[self.input_replace_to_var[(j,i)].expr] = input_rawbuffers[input_idx]._buf.va_addr
for dev in self.devices:
self.comp_queues[dev].submit(dev, hcq_var_vals_local:=hcq_var_vals|self.fixedvars.get(dev, {}))
if (copy_queue:=self.copy_queues.get(dev, None)) is not None: copy_queue.submit(dev, hcq_var_vals_local)
self.last_timeline[dev] = (dev.timeline_signal, dev.next_timeline())
if wait:
st = time.perf_counter()
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
return time.perf_counter() - st
return None
def collect_timestamps(self):
    # NOTE: appending to any device is fine...
self.devices[0].profile_events += [ProfileGraphEvent(self.prof_graph_entries, self.prof_graph_deps, [s.timestamp for s in self.prof_signals])]
def dev_name(self, dev) -> str: return dev.device.replace(":", "_")
def __del__(self):
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
if PROFILE and self.kickoff_value >= 1: self.collect_timestamps()
for fdev, buf in self.kernargs_bufs.items(): fdev.allocator._free(buf, BufferSpec(cpu_access=True))
@staticmethod
def supports_exec_item(devs:list[Compiled], ei:ExecItem) -> bool:
# Check if all devices are HCQ
all_devs = cast(list[HCQCompiled], dedup(devs + [Device[b.device] for b in ei.bufs if b]))
if not all(issubclass(type(d), HCQCompiled) for d in all_devs): return False
    # If all devices are mapped into the CPU address space, the CPU can be used inside the peer group.
    cpu_support = all(isinstance(d.timeline_signal.base_buf.view, MMIOInterface) for d in all_devs)
    # Check if all devices are within the same peer group. If CPU is supported, don't count it as a separate peer group.
    if len(set(d.peer_group for d in all_devs if not (cpu_support and d._is_cpu()))) > 1: return False
# MOCKGPU is not supported, since it can't execute commands in parallel
copy = (isinstance(ei.prg, BufferCopy) and cast(HCQCompiled, devs[0]).hw_copy_queue_t is not None) and not getenv("MOCKGPU")
return isinstance(ei.prg, (CompiledRunner, BufferXfer)) or copy
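
A minimal standalone sketch of the variable-binding idea used throughout this file: queue commands are encoded once against UOp variables (fake buffer addresses, kickoff and timeline values), and every replay just submits a fresh expr-to-value mapping. UOp.variable and dtypes are the real APIs used above; the address value here is hypothetical:

from tinygrad.uop.ops import UOp
from tinygrad.dtype import dtypes

# Encode once against a variable instead of a concrete address...
input0 = UOp.variable("input_0", 0, 0xffffffffffffffff, dtype=dtypes.uint64)
# ...then each replay binds the variable to the current input buffer's address.
hcq_var_vals = {input0.expr: 0x7f00_dead_0000}  # hypothetical va_addr for this call's input buffer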

tinygrad/runtime/graph/metal.py

@@ -0,0 +1,116 @@
from typing import Any, cast
import ctypes, re, decimal
from tinygrad.dtype import dtypes
from tinygrad.helpers import dedup, getenv, merge_dicts, PROFILE
from tinygrad.device import Buffer, ProfileGraphEntry, ProfileGraphEvent
from tinygrad.engine.realize import ExecItem, CompiledRunner
from tinygrad.engine.jit import GraphRunner, GraphException
from tinygrad.runtime.ops_metal import wait_check, msg, libobjc, to_struct, objc_instance,\
MTLResourceOptions, cmdbuf_st_time, cmdbuf_en_time, objc_id, to_ns_str
class MTLIndirectCommandType:
MTLIndirectCommandTypeConcurrentDispatch = (1 << 5)
class MTLResourceUsage:
MTLResourceUsageRead = 0b01
MTLResourceUsageWrite = 0b10
class MetalGraph(GraphRunner):
def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
super().__init__(jit_cache, input_rawbuffers, var_vals)
if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException
# create metal batch exec
icb_descriptor = msg("new", objc_instance)(libobjc.objc_getClass(b"MTLIndirectCommandBufferDescriptor"))
msg("setCommandTypes:")(icb_descriptor, MTLIndirectCommandType.MTLIndirectCommandTypeConcurrentDispatch)
msg("setInheritBuffers:")(icb_descriptor, False)
msg("setInheritPipelineState:")(icb_descriptor, False)
msg("setMaxKernelBufferBindCount:")(icb_descriptor, 31)
self.icb = msg("newIndirectCommandBufferWithDescriptor:maxCommandCount:options:", objc_instance)(self.dev.sysdevice,
icb_descriptor, len(jit_cache), MTLResourceOptions.MTLResourceCPUCacheModeDefaultCache)
if self.icb.value is None: raise GraphException("create indirect command buffer failed, does your system support this?")
icb_label = bytes(msg("UTF8String", ctypes.c_char_p)(msg("description", objc_instance)(self.icb))).decode()
self.needs_icb_fix = int((m := re.search(r'AGXG(\d+)XFamily', icb_label)) is None or int(m.group(1)) < 15) # not required on M3+
self.fixedvars = merge_dicts([ji.fixedvars for ji in jit_cache])
self.varlist = self.vars + list(self.fixedvars.keys())
if len(self.varlist): self.int_buf = self.dev.allocator.alloc(len(self.varlist)*dtypes.int32.itemsize)
all_pipelines, all_resources = [], [self.int_buf.buf] if len(self.varlist) else []
for j,ji in enumerate(jit_cache):
prg: CompiledRunner = cast(CompiledRunner, ji.prg)
icb_command = msg("indirectComputeCommandAtIndex:", objc_instance)(self.icb, j)
all_pipelines.append(prg._prg.pipeline_state)
msg("setComputePipelineState:")(icb_command, prg._prg.pipeline_state)
for i,b in enumerate(ji.bufs):
if b is not None and b not in input_rawbuffers:
msg("setKernelBuffer:offset:atIndex:")(icb_command, b._buf.buf, b._buf.offset, i)
all_resources.append(b._buf.buf)
for i,v in enumerate(prg.p.vars):
msg("setKernelBuffer:offset:atIndex:")(icb_command, self.int_buf.buf, self.varlist.index(v.expr)*4, len(ji.bufs)+i)
global_size, local_size = prg.p.launch_dims(var_vals)
msg("concurrentDispatchThreadgroups:threadsPerThreadgroup:")(icb_command, to_struct(*global_size), to_struct(*local_size))
msg("setBarrier")(icb_command)
self.all_resources = dedup(all_resources)
self.all_pipelines = dedup(all_pipelines)
self.command_buffer: Any = None
if len(self.varlist): self.int_buf_view = self.dev.allocator._as_buffer(self.int_buf).cast('i')
for var in self.fixedvars: self.int_buf_view[self.varlist.index(var)] = self.fixedvars[var]
self.range = to_struct(0, len(jit_cache))
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
if self.command_buffer is not None and self.command_buffer in self.dev.mtl_buffers_in_flight: wait_check(self.command_buffer)
# NOTE: old command buffer may not be inflight anymore
if self.command_buffer is not None and PROFILE: self.collect_timestamps()
all_resources = dedup(self.all_resources + [input_rawbuffers[input_idx]._buf.buf for input_idx in self.input_replace.values()])
for (j,i),input_idx in self.input_replace.items():
computeCommand = msg("indirectComputeCommandAtIndex:", objc_id)(self.icb, j)
msg("setKernelBuffer:offset:atIndex:")(computeCommand, input_rawbuffers[input_idx]._buf.buf, input_rawbuffers[input_idx]._buf.offset, i)
for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
computeCommand = msg("indirectComputeCommandAtIndex:", objc_id)(self.icb, j)
msg("concurrentDispatchThreadgroups:threadsPerThreadgroup:")(computeCommand, to_struct(*global_dims), to_struct(*local_dims))
for var in self.vars: self.int_buf_view[self.varlist.index(var)] = var_vals[var]
command_buffer = msg("commandBuffer", objc_instance)(self.dev.mtl_queue)
encoder = msg("computeCommandEncoder", objc_instance)(command_buffer)
msg("useResources:count:usage:")(encoder, (objc_id * len(all_resources))(*all_resources), len(all_resources),
MTLResourceUsage.MTLResourceUsageRead | MTLResourceUsage.MTLResourceUsageWrite)
    # NOTE: the pipelines likely need to be added to the used resources to fix the crash on M1/M2, but I haven't figured out how.
    # this is an O(n) hack to get them used. what should work is:
    #encoder.useResources_count_usage_(self.all_pipelines, len(self.all_pipelines), Metal.MTLResourceUsageRead)
    # but it fails with "Invalid Resource (00000009:kIOGPUCommandBufferCallbackErrorInvalidResource)"
    # to repro the crash (which can also crash other running GPU apps), run with FIX_METAL_ICB=0
if getenv("FIX_METAL_ICB", self.needs_icb_fix):
for ps in self.all_pipelines:
msg("setComputePipelineState:")(encoder, ps)
msg("dispatchThreadgroups:threadsPerThreadgroup:")(encoder, to_struct(0,0,0), to_struct(0,0,0))
msg("executeCommandsInBuffer:withRange:")(encoder, self.icb, self.range)
msg("endEncoding")(encoder)
msg("setLabel:")(command_buffer, to_ns_str(f"batched {len(self.jit_cache)}"))
msg("commit")(command_buffer)
self.command_buffer = command_buffer
self.dev.mtl_buffers_in_flight.append(command_buffer)
if wait:
wait_check(command_buffer)
return cmdbuf_en_time(command_buffer) - cmdbuf_st_time(command_buffer)
return None
def collect_timestamps(self):
# create a graph event and evenly space each program
st, en = decimal.Decimal(cmdbuf_st_time(self.command_buffer)) * 1000000, decimal.Decimal(cmdbuf_en_time(self.command_buffer)) * 1000000
ents = [ProfileGraphEntry(self.device, cast(CompiledRunner, ji.prg)._prg.name, i, i+1, is_copy=False) for i,ji in enumerate(self.jit_cache)]
step = (en-st)/len(ents)
self.dev.profile_events += [ProfileGraphEvent(ents, [], [st+step*i for i in range(len(ents)+1)])]
def __del__(self):
if PROFILE and self.command_buffer is not None:
wait_check(self.command_buffer)
self.collect_timestamps()
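
The needs_icb_fix check above keys off the GPU family name embedded in the indirect command buffer's description string. A standalone sketch of just that detection, with hypothetical label strings (real labels come from the ICB's description):

import re

def needs_icb_fix(icb_label: str) -> int:
  # AGXG15XFamily and newer (M3+) don't need the dummy-dispatch workaround
  return int((m := re.search(r'AGXG(\d+)XFamily', icb_label)) is None or int(m.group(1)) < 15)

assert needs_icb_fix("<AGXG13XFamilyIndirectCommandBuffer: 0x0>") == 1  # M1/M2-era: apply the fix
assert needs_icb_fix("<AGXG15XFamilyIndirectCommandBuffer: 0x0>") == 0  # M3+: skip it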

tinygrad/runtime/graph/remote.py

@@ -0,0 +1,113 @@
import time, itertools
from tinygrad.engine.jit import MultiGraphRunner
from tinygrad.engine.realize import CompiledRunner, BufferXfer, ExecItem
from tinygrad.device import Device, Compiled, Buffer
from tinygrad.runtime.ops_remote import RemoteDevice, RemoteConnection, RemoteRequest, GraphComputeItem, Transfer, GraphAlloc, GraphFree, GraphExec
from tinygrad.runtime.ops_remote import BatchTransfer, Event, Wait
from tinygrad.helpers import unwrap, flatten, dedup
from enum import Enum, auto
from dataclasses import replace
from collections import defaultdict
from typing import cast
class StagingType(Enum): NONE = auto(); GRAPH = auto(); TRANSFER = auto() # noqa: E702
def rd(dev:Compiled) -> RemoteDevice: return cast(RemoteDevice, dev)
def dev_key(dev:RemoteDevice): return dev.conn if dev.properties.graph_supports_multi else dev
def map_rawbuf(rawbuf:Buffer): return (cast(RemoteDevice, Device[rawbuf.device]).session, rawbuf._buf)
class RemoteGraph(MultiGraphRunner):
def __init__(self, jit_cache: list[ExecItem], rawbufs: list[Buffer], var_vals: dict[str, int]):
super().__init__(jit_cache, rawbufs, var_vals)
devices = dedup(flatten([[Device[unwrap(buf).device] for buf in ji.bufs] for ji in jit_cache]))
c2d = {device.conn: device for device in devices}
self.handle_indexes = {map_rawbuf(rawbufs[i]): i for i in sorted(dedup(self.input_replace.values()))}
self.template: list[RemoteRequest] = []
stagings: dict[RemoteDevice|RemoteConnection, list[GraphComputeItem|Transfer]] = defaultdict(list)
clobbered_buffers: set[Buffer] = set()
cur_staging_type: StagingType = StagingType.NONE
def _flush(new_staging_type:StagingType, force_break:bool=False):
nonlocal cur_staging_type
if cur_staging_type == new_staging_type and not force_break: return
# Pre-sync
if cur_staging_type == StagingType.TRANSFER:
for sdev,ddev in itertools.permutations(c2d.values(), 2):
self.template.append(Event(ddev.session, event:=next(ddev.event_num), session=sdev.session))
self.template.append(Wait(event, session=ddev.session))
# Flush
for dev in devices:
dk = dev_key(dev)
staging = stagings[dk]
if not staging: continue
match cur_staging_type:
case StagingType.GRAPH:
bufs = tuple(map_rawbuf(rawbufs[i]) for i in sorted(dedup(self.input_replace.values())) if dev_key(rd(Device[rawbufs[i].device])) == dk)
dev.q(GraphAlloc(graph_num:=next(dev.graph_num), tuple(staging), tuple(bufs), var_vals))
self.template.append(GraphExec(graph_num, bufs, var_vals, wait=False, session=dev.session))
case StagingType.TRANSFER:
st = cast(list[Transfer], staging)
for host in dedup(t.dsession.host for t in st):
sbuffer_nums = [(unwrap(t.session), t.buffer_num) for t in st if t.dsession.host == host]
dbuffer_nums = [(t.dsession, t.dbuffer_num) for t in st if t.dsession.host == host]
self.template.append(BatchTransfer(sbuffer_nums, dbuffer_nums, session=dev.session))
staging.clear()
# Post-sync
if cur_staging_type == StagingType.TRANSFER:
for sdev,ddev in itertools.permutations(c2d.values(), 2):
self.template.append(Event(ddev.session, event:=next(ddev.event_num), session=sdev.session))
self.template.append(Wait(event, session=ddev.session))
cur_staging_type = new_staging_type
clobbered_buffers.clear()
for ji in jit_cache:
match ji.prg:
case CompiledRunner():
_flush(StagingType.GRAPH)
gi = GraphComputeItem(ji.prg.dev.session, ji.prg._prg.name, ji.prg._prg.datahash, tuple(unwrap(buf)._buf for buf in ji.bufs),
tuple(ji.prg.p.vars), ji.fixedvars, tuple(ji.prg.p.ins), tuple(ji.prg.p.outs),
tuple(ji.prg.p.global_size) if ji.prg.p.global_size is not None else None,
tuple(ji.prg.p.local_size) if ji.prg.p.local_size is not None else None)
stagings[dev_key(ji.prg.dev)].append(gi)
case BufferXfer():
dest, src = ji.bufs[0:2]
dest_dev, src_dev = cast(RemoteDevice, Device[unwrap(dest).device]), cast(RemoteDevice, Device[unwrap(src).device])
assert dest is not None and src is not None, ji
ti = Transfer(session=src_dev.session, buffer_num=src._buf, dsession=dest_dev.session, dbuffer_num=dest._buf)
if dev_key(dest_dev) == dev_key(src_dev):
_flush(StagingType.GRAPH)
stagings[dev_key(src_dev)].append(ti)
elif dest_dev.conn == src_dev.conn:
_flush(StagingType.NONE)
self.template.append(ti)
else:
_flush(StagingType.TRANSFER, force_break=src in clobbered_buffers)
clobbered_buffers.add(dest)
stagings[dev_key(src_dev)].append(ti)
case _: raise NotImplementedError(ji.prg)
_flush(StagingType.NONE)
def __del__(self):
for req in self.template:
match req:
case GraphExec(): RemoteConnection(unwrap(req.session).host).q(GraphFree(req.graph_num, session=req.session))
def __call__(self, rawbufs: list[Buffer], var_vals: dict[str, int], wait=False):
if wait: st = time.perf_counter()
rmap = {orig: map_rawbuf(rawbufs[replace_idx]) for orig,replace_idx in self.handle_indexes.items()}
for req in self.template:
match req:
case GraphExec():
req = replace(req, bufs=tuple(rmap[buf] for buf in req.bufs), var_vals=var_vals, wait=wait)
case Transfer():
if (req.session, req.buffer_num) in rmap: req = replace(req, buffer_num=rmap[(req.session, req.buffer_num)][1])
if (req.dsession, req.dbuffer_num) in rmap: req = replace(req, dbuffer_num=rmap[(req.dsession, req.dbuffer_num)][1])
case BatchTransfer():
req = replace(req, sbuffer_nums=[rmap.get(b, b) for b in req.sbuffer_nums], dbuffer_nums=[rmap.get(b, b) for b in req.dbuffer_nums])
case Event()|Wait():
pass # event number can be reused
case _: raise NotImplementedError(req)
RemoteConnection(unwrap(req.session).host).q(req)
if wait:
RemoteConnection(unwrap(req.session).host).batch_submit()
return time.perf_counter() - st
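
The _flush helper above implements a small state machine: consecutive jit items of the same staging type are batched together, and a change of type (or a clobbered source buffer) forces a flush, with event/wait pairs synchronizing cross-connection transfers. A stripped-down sketch of just the batching behavior, using simplified stand-in types rather than the real request classes:

from enum import Enum, auto

class StageKind(Enum): NONE = auto(); GRAPH = auto(); TRANSFER = auto()  # noqa: E702

def batch(items: list[tuple[StageKind, str]]) -> list[list[str]]:
  out: list[list[str]] = []
  cur, cur_kind = [], StageKind.NONE
  for kind, op in items:
    if kind != cur_kind:          # kind change: flush the running batch
      if cur: out.append(cur)
      cur, cur_kind = [], kind
    cur.append(op)
  if cur: out.append(cur)
  return out

# two graph kernels, a transfer, then another kernel become three batches:
assert batch([(StageKind.GRAPH, "k1"), (StageKind.GRAPH, "k2"),
              (StageKind.TRANSFER, "t1"), (StageKind.GRAPH, "k3")]) == [["k1", "k2"], ["t1"], ["k3"]]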