Release 260111

This commit is contained in:
Comma Device
2026-01-11 18:23:29 +08:00
commit 3721ecbf8a
2601 changed files with 855070 additions and 0 deletions

52
tinygrad/viz/README Normal file
View File

@@ -0,0 +1,52 @@
viz is a replacement for:
GRAPH=1
JITGRAPH=1 (this restricts the graph...no need if we can select the schedules)
GRAPHUOPS=1
most uses of DEBUG >= 3
tiny-tools
and a viewer for:
TRACK_MATCH_STATS=2
ProfileEvents
to use:
1. Run tinygrad with VIZ=1 (this saves the pkls and launches the server (new process please!))
2. That's it!
This should be able to:
1. See all schedules (VIZ=1)
2. See all graphs and how they were rewritten (VIZ=1)
3. See generated code (VIZ=1)
4. See profile (click on 'profiler')
bunch of dev rules:
* everything must be responsive to keyboard smashing! lag should never happen
* no requirement to use any of these libraries, but in general libraries are bad
* pure python server + browser ready JS
* serialization is very annoying! UOps are fine...others think carefully
* NOTE: we don't have to save very much
* anything pure functional can be regen by the server (stable tinygrad APIs only!)
user story: viewing code
* tinygrad ran 3 schedules: init the model + first train step, train step, test step
* schedule 1 (123) = main.py:97
* schedule 2 (97) = main.py:97
* schedule 3 (10) = main.py:145
* click "schedule 1", get list of kernels (like DEBUG=2)
* kernel 1 "E_34_34" -- 'sin'
* kernel 2 "R_4545"
* click "E_34_34"
* pre-rewritten UOp graph (step through rewrite here)
* post-rewritten UOp graph
* UOp list
* generated code
user story: debugging scheduler
* tinygrad ran 3 schedules: init the model + first train step, train step, test step
* ...
* click "schedule 1 graph", get a graph of the schedule in UOps
* step through rewrite rules
* see how things are broken into kernels
* see why two kernels didn't fuse
this needs to be tested, both as the server and as the frontend

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,47 @@
/*! `cpp` grammar compiled for Highlight.js 11.10.0 */
(()=>{var e=(()=>{"use strict";return e=>{const t=e.regex,a=e.COMMENT("//","$",{
contains:[{begin:/\\\n/}]
}),n="decltype\\(auto\\)",r="[a-zA-Z_]\\w*::",i="(?!struct)("+n+"|"+t.optional(r)+"[a-zA-Z_]\\w*"+t.optional("<[^<>]+>")+")",s={
className:"type",begin:"\\b[a-z\\d_]*_t\\b"},c={className:"string",variants:[{
begin:'(u8?|U|L)?"',end:'"',illegal:"\\n",contains:[e.BACKSLASH_ESCAPE]},{
begin:"(u8?|U|L)?'(\\\\(x[0-9A-Fa-f]{2}|u[0-9A-Fa-f]{4,8}|[0-7]{3}|\\S)|.)",
end:"'",illegal:"."},e.END_SAME_AS_BEGIN({
begin:/(?:u8?|U|L)?R"([^()\\ ]{0,16})\(/,end:/\)([^()\\ ]{0,16})"/})]},o={
className:"number",variants:[{
begin:"[+-]?(?:(?:[0-9](?:'?[0-9])*\\.(?:[0-9](?:'?[0-9])*)?|\\.[0-9](?:'?[0-9])*)(?:[Ee][+-]?[0-9](?:'?[0-9])*)?|[0-9](?:'?[0-9])*[Ee][+-]?[0-9](?:'?[0-9])*|0[Xx](?:[0-9A-Fa-f](?:'?[0-9A-Fa-f])*(?:\\.(?:[0-9A-Fa-f](?:'?[0-9A-Fa-f])*)?)?|\\.[0-9A-Fa-f](?:'?[0-9A-Fa-f])*)[Pp][+-]?[0-9](?:'?[0-9])*)(?:[Ff](?:16|32|64|128)?|(BF|bf)16|[Ll]|)"
},{
begin:"[+-]?\\b(?:0[Bb][01](?:'?[01])*|0[Xx][0-9A-Fa-f](?:'?[0-9A-Fa-f])*|0(?:'?[0-7])*|[1-9](?:'?[0-9])*)(?:[Uu](?:LL?|ll?)|[Uu][Zz]?|(?:LL?|ll?)[Uu]?|[Zz][Uu]|)"
}],relevance:0},l={className:"meta",begin:/#\s*[a-z]+\b/,end:/$/,keywords:{
keyword:"if else elif endif define undef warning error line pragma _Pragma ifdef ifndef include"
},contains:[{begin:/\\\n/,relevance:0},e.inherit(c,{className:"string"}),{
className:"string",begin:/<.*?>/},a,e.C_BLOCK_COMMENT_MODE]},u={
className:"title",begin:t.optional(r)+e.IDENT_RE,relevance:0
},d=t.optional(r)+e.IDENT_RE+"\\s*\\(",p={
type:["bool","char","char16_t","char32_t","char8_t","double","float","int","long","short","void","wchar_t","unsigned","signed","const","static"],
keyword:["alignas","alignof","and","and_eq","asm","atomic_cancel","atomic_commit","atomic_noexcept","auto","bitand","bitor","break","case","catch","class","co_await","co_return","co_yield","compl","concept","const_cast|10","consteval","constexpr","constinit","continue","decltype","default","delete","do","dynamic_cast|10","else","enum","explicit","export","extern","false","final","for","friend","goto","if","import","inline","module","mutable","namespace","new","noexcept","not","not_eq","nullptr","operator","or","or_eq","override","private","protected","public","reflexpr","register","reinterpret_cast|10","requires","return","sizeof","static_assert","static_cast|10","struct","switch","synchronized","template","this","thread_local","throw","transaction_safe","transaction_safe_dynamic","true","try","typedef","typeid","typename","union","using","virtual","volatile","while","xor","xor_eq"],
literal:["NULL","false","nullopt","nullptr","true"],built_in:["_Pragma"],
_type_hints:["any","auto_ptr","barrier","binary_semaphore","bitset","complex","condition_variable","condition_variable_any","counting_semaphore","deque","false_type","future","imaginary","initializer_list","istringstream","jthread","latch","lock_guard","multimap","multiset","mutex","optional","ostringstream","packaged_task","pair","promise","priority_queue","queue","recursive_mutex","recursive_timed_mutex","scoped_lock","set","shared_future","shared_lock","shared_mutex","shared_timed_mutex","shared_ptr","stack","string_view","stringstream","timed_mutex","thread","true_type","tuple","unique_lock","unique_ptr","unordered_map","unordered_multimap","unordered_multiset","unordered_set","variant","vector","weak_ptr","wstring","wstring_view"]
},_={className:"function.dispatch",relevance:0,keywords:{
_hint:["abort","abs","acos","apply","as_const","asin","atan","atan2","calloc","ceil","cerr","cin","clog","cos","cosh","cout","declval","endl","exchange","exit","exp","fabs","floor","fmod","forward","fprintf","fputs","free","frexp","fscanf","future","invoke","isalnum","isalpha","iscntrl","isdigit","isgraph","islower","isprint","ispunct","isspace","isupper","isxdigit","labs","launder","ldexp","log","log10","make_pair","make_shared","make_shared_for_overwrite","make_tuple","make_unique","malloc","memchr","memcmp","memcpy","memset","modf","move","pow","printf","putchar","puts","realloc","scanf","sin","sinh","snprintf","sprintf","sqrt","sscanf","std","stderr","stdin","stdout","strcat","strchr","strcmp","strcpy","strcspn","strlen","strncat","strncmp","strncpy","strpbrk","strrchr","strspn","strstr","swap","tan","tanh","terminate","to_underlying","tolower","toupper","vfprintf","visit","vprintf","vsprintf"]
},
begin:t.concat(/\b/,/(?!decltype)/,/(?!if)/,/(?!for)/,/(?!switch)/,/(?!while)/,e.IDENT_RE,t.lookahead(/(<[^<>]+>|)\s*\(/))
},m=[_,l,s,a,e.C_BLOCK_COMMENT_MODE,o,c],f={variants:[{begin:/=/,end:/;/},{
begin:/\(/,end:/\)/},{beginKeywords:"new throw return else",end:/;/}],
keywords:p,contains:m.concat([{begin:/\(/,end:/\)/,keywords:p,
contains:m.concat(["self"]),relevance:0}]),relevance:0},g={className:"function",
begin:"("+i+"[\\*&\\s]+)+"+d,returnBegin:!0,end:/[{;=]/,excludeEnd:!0,
keywords:p,illegal:/[^\w\s\*&:<>.]/,contains:[{begin:n,keywords:p,relevance:0},{
begin:d,returnBegin:!0,contains:[u],relevance:0},{begin:/::/,relevance:0},{
begin:/:/,endsWithParent:!0,contains:[c,o]},{relevance:0,match:/,/},{
className:"params",begin:/\(/,end:/\)/,keywords:p,relevance:0,
contains:[a,e.C_BLOCK_COMMENT_MODE,c,o,s,{begin:/\(/,end:/\)/,keywords:p,
relevance:0,contains:["self",a,e.C_BLOCK_COMMENT_MODE,c,o,s]}]
},s,a,e.C_BLOCK_COMMENT_MODE,l]};return{name:"C++",
aliases:["cc","c++","h++","hpp","hh","hxx","cxx"],keywords:p,illegal:"</",
classNameAliases:{"function.dispatch":"built_in"},
contains:[].concat(f,g,_,m,[l,{
begin:"\\b(deque|list|queue|priority_queue|pair|stack|vector|map|set|bitset|multiset|multimap|unordered_map|unordered_set|unordered_multiset|unordered_multimap|array|tuple|optional|variant|function)\\s*<(?!<)",
end:">",keywords:p,contains:["self",s]},{begin:e.IDENT_RE+"::",keywords:p},{
match:[/\b(?:enum(?:\s+(?:class|struct))?|class|struct|union)/,/\s+/,/\w+/],
className:{1:"keyword",3:"title.class"}}])}}})();hljs.registerLanguage("cpp",e)
})();

View File

@@ -0,0 +1,42 @@
/*! `python` grammar compiled for Highlight.js 11.10.0 */
(()=>{var e=(()=>{"use strict";return e=>{
const n=e.regex,a=/[\p{XID_Start}_]\p{XID_Continue}*/u,s=["and","as","assert","async","await","break","case","class","continue","def","del","elif","else","except","finally","for","from","global","if","import","in","is","lambda","match","nonlocal|10","not","or","pass","raise","return","try","while","with","yield"],t={
$pattern:/[A-Za-z]\w+|__\w+__/,keyword:s,
built_in:["__import__","abs","all","any","ascii","bin","bool","breakpoint","bytearray","bytes","callable","chr","classmethod","compile","complex","delattr","dict","dir","divmod","enumerate","eval","exec","filter","float","format","frozenset","getattr","globals","hasattr","hash","help","hex","id","input","int","isinstance","issubclass","iter","len","list","locals","map","max","memoryview","min","next","object","oct","open","ord","pow","print","property","range","repr","reversed","round","set","setattr","slice","sorted","staticmethod","str","sum","super","tuple","type","vars","zip"],
literal:["__debug__","Ellipsis","False","None","NotImplemented","True"],
type:["Any","Callable","Coroutine","Dict","List","Literal","Generic","Optional","Sequence","Set","Tuple","Type","Union"]
},i={className:"meta",begin:/^(>>>|\.\.\.) /},r={className:"subst",begin:/\{/,
end:/\}/,keywords:t,illegal:/#/},l={begin:/\{\{/,relevance:0},o={
className:"string",contains:[e.BACKSLASH_ESCAPE],variants:[{
begin:/([uU]|[bB]|[rR]|[bB][rR]|[rR][bB])?'''/,end:/'''/,
contains:[e.BACKSLASH_ESCAPE,i],relevance:10},{
begin:/([uU]|[bB]|[rR]|[bB][rR]|[rR][bB])?"""/,end:/"""/,
contains:[e.BACKSLASH_ESCAPE,i],relevance:10},{
begin:/([fF][rR]|[rR][fF]|[fF])'''/,end:/'''/,
contains:[e.BACKSLASH_ESCAPE,i,l,r]},{begin:/([fF][rR]|[rR][fF]|[fF])"""/,
end:/"""/,contains:[e.BACKSLASH_ESCAPE,i,l,r]},{begin:/([uU]|[rR])'/,end:/'/,
relevance:10},{begin:/([uU]|[rR])"/,end:/"/,relevance:10},{
begin:/([bB]|[bB][rR]|[rR][bB])'/,end:/'/},{begin:/([bB]|[bB][rR]|[rR][bB])"/,
end:/"/},{begin:/([fF][rR]|[rR][fF]|[fF])'/,end:/'/,
contains:[e.BACKSLASH_ESCAPE,l,r]},{begin:/([fF][rR]|[rR][fF]|[fF])"/,end:/"/,
contains:[e.BACKSLASH_ESCAPE,l,r]},e.APOS_STRING_MODE,e.QUOTE_STRING_MODE]
},b="[0-9](_?[0-9])*",c=`(\\b(${b}))?\\.(${b})|\\b(${b})\\.`,d="\\b|"+s.join("|"),g={
className:"number",relevance:0,variants:[{
begin:`(\\b(${b})|(${c}))[eE][+-]?(${b})[jJ]?(?=${d})`},{begin:`(${c})[jJ]?`},{
begin:`\\b([1-9](_?[0-9])*|0+(_?0)*)[lLjJ]?(?=${d})`},{
begin:`\\b0[bB](_?[01])+[lL]?(?=${d})`},{begin:`\\b0[oO](_?[0-7])+[lL]?(?=${d})`
},{begin:`\\b0[xX](_?[0-9a-fA-F])+[lL]?(?=${d})`},{begin:`\\b(${b})[jJ](?=${d})`
}]},p={className:"comment",begin:n.lookahead(/# type:/),end:/$/,keywords:t,
contains:[{begin:/# type:/},{begin:/#/,end:/\b\B/,endsWithParent:!0}]},m={
className:"params",variants:[{className:"",begin:/\(\s*\)/,skip:!0},{begin:/\(/,
end:/\)/,excludeBegin:!0,excludeEnd:!0,keywords:t,
contains:["self",i,g,o,e.HASH_COMMENT_MODE]}]};return r.contains=[o,g,i],{
name:"Python",aliases:["py","gyp","ipython"],unicodeRegex:!0,keywords:t,
illegal:/(<\/|\?)|=>/,contains:[i,g,{scope:"variable.language",match:/\bself\b/
},{beginKeywords:"if",relevance:0},{match:/\bor\b/,scope:"keyword"
},o,p,e.HASH_COMMENT_MODE,{match:[/\bdef/,/\s+/,a],scope:{1:"keyword",
3:"title.function"},contains:[m]},{variants:[{
match:[/\bclass/,/\s+/,a,/\s*/,/\(\s*/,a,/\s*\)/]},{match:[/\bclass/,/\s+/,a]}],
scope:{1:"keyword",3:"title.class",6:"title.class.inherited"}},{
className:"meta",begin:/^[\t ]*@/,end:/(?=#)|$/,contains:[g,m,o]}]}}})()
;hljs.registerLanguage("python",e)})();

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,9 @@
/*!
Theme: Default
Description: Original highlight.js style
Author: (c) Ivan Sagalaev <maniac@softwaremaniacs.org>
Maintainer: @highlightjs/core-team
Website: https://highlightjs.org/
License: see project LICENSE
Touched: 2021
*/pre code.hljs{display:block;overflow-x:auto;padding:1em}code.hljs{padding:3px 5px}.hljs{background:#f3f3f3;color:#444}.hljs-comment{color:#697070}.hljs-punctuation,.hljs-tag{color:#444a}.hljs-tag .hljs-attr,.hljs-tag .hljs-name{color:#444}.hljs-attribute,.hljs-doctag,.hljs-keyword,.hljs-meta .hljs-keyword,.hljs-name,.hljs-selector-tag{font-weight:700}.hljs-deletion,.hljs-number,.hljs-quote,.hljs-selector-class,.hljs-selector-id,.hljs-string,.hljs-template-tag,.hljs-type{color:#800}.hljs-section,.hljs-title{color:#800;font-weight:700}.hljs-link,.hljs-operator,.hljs-regexp,.hljs-selector-attr,.hljs-selector-pseudo,.hljs-symbol,.hljs-template-variable,.hljs-variable{color:#ab5656}.hljs-literal{color:#695}.hljs-addition,.hljs-built_in,.hljs-bullet,.hljs-code{color:#397300}.hljs-meta{color:#1f7199}.hljs-meta .hljs-string{color:#38a}.hljs-emphasis{font-style:italic}.hljs-strong{font-weight:700}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,8 @@
pre code.hljs{display:block;overflow-x:auto;padding:1em}code.hljs{padding:3px 5px}/*!
Theme: Tokyo-night-Dark
origin: https://github.com/enkia/tokyo-night-vscode-theme
Description: Original highlight.js style
Author: (c) Henri Vandersleyen <hvandersleyen@gmail.com>
License: see project LICENSE
Touched: 2022
*/.hljs-comment,.hljs-meta{color:#565f89}.hljs-deletion,.hljs-doctag,.hljs-regexp,.hljs-selector-attr,.hljs-selector-class,.hljs-selector-id,.hljs-selector-pseudo,.hljs-tag,.hljs-template-tag,.hljs-variable.language_{color:#f7768e}.hljs-link,.hljs-literal,.hljs-number,.hljs-params,.hljs-template-variable,.hljs-type,.hljs-variable{color:#ff9e64}.hljs-attribute,.hljs-built_in{color:#e0af68}.hljs-keyword,.hljs-property,.hljs-subst,.hljs-title,.hljs-title.class_,.hljs-title.class_.inherited__,.hljs-title.function_{color:#7dcfff}.hljs-selector-tag{color:#73daca}.hljs-addition,.hljs-bullet,.hljs-quote,.hljs-string,.hljs-symbol{color:#9ece6a}.hljs-code,.hljs-formula,.hljs-section{color:#7aa2f7}.hljs-attr,.hljs-char.escape_,.hljs-keyword,.hljs-name,.hljs-operator{color:#bb9af7}.hljs-punctuation{color:#c0caf5}.hljs{background:#1a1b26;color:#9aa5ce}.hljs-emphasis{font-style:italic}.hljs-strong{font-weight:700}

15
tinygrad/viz/fetch_assets.sh Executable file
View File

@@ -0,0 +1,15 @@
#!/bin/bash
fetch() {
echo "fetch $1"
mkdir -p assets/$1
rmdir assets/$1
curl -o assets/$1 https://$1
}
fetch "d3js.org/d3.v7.min.js"
fetch "dagrejs.github.io/project/dagre/latest/dagre.min.js"
fetch "cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css"
fetch "cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js"
fetch "cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js"
fetch "cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js"
fetch "cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js"
fetch "unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css"

364
tinygrad/viz/index.html Normal file
View File

@@ -0,0 +1,364 @@
<!DOCTYPE html>
<html>
<head>
<title>tinygrad viz</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="icon" href="data:;base64,iVBORw0KGgo=">
<script src="assets/d3js.org/d3.v7.min.js" charset="utf-8"></script>
<script src="assets/dagrejs.github.io/project/dagre/latest/dagre.min.js"></script>
<link rel="stylesheet" href="assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css">
<script src="assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js"></script>
<script src="assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js"></script>
<script src="assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js"></script>
<script src="assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js"></script>
<link rel="stylesheet" href="assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css" />
<style>
* {
box-sizing: border-box;
margin: 0;
padding: 0;
}
html, body {
color: #f0f0f5;
margin: 0;
padding: 0;
width: 100%;
height: 100%;
font-family: sans-serif;
font-optical-sizing: auto;
font-weight: 400;
font-style: normal;
font-variation-settings: "wdth" 100;
font-size: 14px;
overflow: hidden;
background-color: #08090e;
}
a {
color: #4a90e2;
}
ul {
padding: 0;
opacity: 0.6;
white-space: nowrap;
cursor: pointer;
}
ul.active {
opacity: 1;
}
ul > ul {
display: none;
}
ul.expanded > ul {
display: block;
}
ul.disabled {
opacity: 0.4;
pointer-events: none;
}
label {
display: inline-flex;
align-items: center;
gap: 4px;
}
.graph svg {
width: 100%;
height: 100%;
}
svg * {
cursor: default;
user-select: none;
}
g.clickable * {
cursor: pointer;
user-select: auto;
}
g.tag circle {
fill: #FFD700;
stroke: #B8860B;
}
g.port circle {
fill: #b3dcc2;
}
g.tag circle, #edge-labels circle {
stroke-width: 0.8;
}
g.tag text, #edge-labels text {
text-anchor: middle;
font-size: 6px;
fill: #08090e;
}
.label :is(text, p) {
font-weight: 350;
}
g.node rect {
stroke-width: 1.4;
stroke: #4a4b57;
}
g.overlay rect {
fill: rgba(26, 27, 38, 0.5);
}
.edgePath {
stroke: #4a4b57;
fill: none;
stroke-width: 1.4px;
}
g.node.highlight rect, .edgePath.highlight, g.port circle {
stroke: #89C9A2;
}
g.highlight.child rect, .edgePath.highlight.child {
stroke: #C888B0;
}
#edge-labels g.port.highlight {
display: block
}
#edge-labels g.port {
display: none
}
#arrowhead {
fill: #4a4b57;
}
.main-container {
display: flex;
width: 100%;
height: 100%;
position: relative;
}
.container {
flex: 0 0 auto;
background-color: #0f1018;
padding: 20px;
z-index: 2;
position: relative;
height: 100%;
}
.metadata > * + *, .rewrite-container > * + *, .ctx-list > * + * {
margin-top: 12px;
}
.ctx-list > ul > * + * {
margin-top: 4px;
}
.graph {
position: absolute;
inset: 0;
z-index: 1;
}
.profiler, .disasm {
flex: 1 1 auto;
min-width: 0;
width: 100%;
height: calc(100% - 50px);
margin-top: 50px;
overflow-y: auto;
overflow-x: hidden;
scrollbar-gutter: stable;
}
.ctx-list-parent {
width: 15%;
padding-top: 50px;
border-right: 1px solid #4a4b56;
}
.ctx-list, .metadata {
width: 100%;
height: 100%;
overflow-y: auto;
scrollbar-gutter: stable;
}
.metadata-parent {
width: 20%;
border-left: 1px solid #4a4b56;
margin-left: auto;
}
.resize-handle {
position: absolute;
top: 0;
bottom: 0;
width: 20px;
height: 100%;
cursor: col-resize;
z-index: 3;
background-color: transparent;
}
.floating-container {
position: fixed;
top: 10px;
left: 20px;
z-index: 4;
display: flex;
flex-direction: row;
gap: 8px;
}
.btn {
outline: none;
background-color: #1a1b26;
border: 1px solid #4a4b56;
color: #f0f0f5;
border-radius: 4px;
padding: 6px;
cursor: pointer;
height: 32px;
display: flex;
align-items: center;
justify-content: center;
text-decoration: none;
}
.btn:hover {
background-color: #2a2b36;
}
.collapsed .container {
display: none;
}
.rewrite-list {
display: flex;
flex-wrap: wrap;
}
.rewrite-list > ul {
padding: 2px;
}
.wrap {
word-wrap: break-word;
white-space: pre-wrap;
}
pre code.hljs {
overflow-y: auto;
max-height: 30vh;
padding: 8px;
}
#progress-message {
position: absolute;
z-index: 2;
left: 50%;
top: 2%;
color: #ffd230;
display: none;
}
#tooltip {
position: absolute;
z-index: 4;
background-color: #1e2029;
padding: 4px 8px;
max-width: 164px;
border-radius: 4px;
pointer-events: none;
display: none;
font-size: 10px;
}
#device-list > div {
min-height: 32px;
width: 132px;
overflow-x: auto;
overflow-y: hidden;
white-space: nowrap;
display: flex;
}
#device-list > div:hover {
background-color: rgba(20, 23, 35, 0.3);
}
#device-list {
height: fit-content;
}
.raw-text {
padding: 0 8px;
width: 100%;
height: 100%;
max-height: 100vh;
overflow-x: auto;
}
.raw-text code {
max-height: none !important;
}
table {
width: 100%;
border-collapse: separate;
border-spacing: 0;
background-color: #1a1b26;
color: #f0f0f5;
font-size: 0.95em;
}
table td {
border-bottom: 1px solid #4a4b56;
vertical-align: top;
}
table tr:last-child > td {
border-bottom: none;
}
tr.main-row:hover {
background-color: #2a2d3a;
}
tr.sub-row {
max-width: 150px;
}
tr.main-row > td, tr.sub-row > td {
padding: 8px 12px;
}
tr.code-row > td:first-child {
font-family: monospace;
}
td.pct-row > div {
height: 12px;
width: 100%;
display: flex;
}
td.pct-row > div > div {
height: 100%;
}
thead {
position: sticky;
top: 0;
z-index: 10;
background-color: #20222e;
}
thead th {
text-align: left;
padding: 10px 12px;
font-weight: 600;
border-bottom: 1px solid #4a4b56;
font-size: 0.95em;
letter-spacing: 0.03em;
}
.legend {
display: flex;
align-items: center;
}
.legend > div {
width: 0.95em;
height: 0.95em;
margin-right: 4px;
}
</style>
</head>
<body>
<div class="main-container">
<div class="floating-container">
<button class="btn collapse-btn">
<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" width="20"><path d="M15 19l-7-7 7-7"/></svg>
</button>
<button class="btn" id="zoom-to-fit-btn" aria-label="Fit graph">
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" width="20">
<path stroke-linecap="round" stroke-linejoin="round" d="M7.5 3.75H6A2.25 2.25 0 0 0 3.75 6v1.5M16.5 3.75H18A2.25 2.25 0 0 1 20.25 6v1.5m0 9V18A2.25 2.25 0 0 1 18 20.25h-1.5m-9 0H6A2.25 2.25 0 0 1 3.75 18v-1.5M15 12a3 3 0 1 1-6 0 3 3 0 0 1 6 0Z" />
</svg>
</button>
</div>
<div id="progress-message"></div>
<div class="container ctx-list-parent"><div class="ctx-list"></div></div>
<div class="view profiler"></div>
<div class="view disasm"></div>
<div class="view graph">
<svg id="graph-svg" preserveAspectRatio="xMidYMid meet">
<g id="render">
<g id="edges"></g>
<g id="nodes"></g>
<g id="edge-labels"></g> <!-- NOTE: this ensures edge labels are always on top -->
</g>
<defs>
<marker id="arrowhead" viewBox="0 -5 10 10" refX="10" refY="0" markerWidth="6" markerHeight="6" orient="auto">
<path d="M0,-5L10,0L0,5" fill="context-stroke"></path>
</marker>
</defs>
</svg>
</div>
<div class="container metadata-parent"><div class="metadata"></div></div>
</div>
<div id="tooltip" class="wrap"></div>
<script src="/js/index.js"></script>
</body>
</html>

769
tinygrad/viz/js/index.js Normal file
View File

@@ -0,0 +1,769 @@
// ** graph helpers
const displayGraph = (cls) => {
for (const e of document.getElementsByClassName("view")) e.style.display = e.classList.contains(cls) ? "flex" : "none";
}
const darkenHex = (h, p = 0) =>
`#${(
c = parseInt(h.slice(1), 16),
f = 1 - p / 100,
((c >> 16 & 255) * f | 0) << 16 |
((c >> 8 & 255) * f | 0) << 8 |
((c & 255) * f | 0)
).toString(16).padStart(6, '0')}`;
const ANSI_COLORS = ["#b3b3b3", "#ff6666", "#66b366", "#ffff66", "#6666ff", "#ff66ff", "#66ffff", "#ffffff"];
const parseColors = (name, defaultColor="#ffffff") => Array.from(name.matchAll(/(?:\u001b\[(\d+)m([\s\S]*?)\u001b\[0m)|([^\u001b]+)/g),
([_, code, colored_st, st]) => ({ st: colored_st ?? st, color: code != null ? ANSI_COLORS[(parseInt(code)-30+60)%60] : defaultColor }));
const rect = (s) => (typeof s === "string" ? document.querySelector(s) : s).getBoundingClientRect();
let timeout = null;
const updateProgress = ({ start }) => {
clearTimeout(timeout);
const msg = document.getElementById("progress-message");
msg.style.display = "none";
if (start) {
msg.innerText = "Rendering new graph...";
timeout = setTimeout(() => { msg.style.display = "block"; }, 2000);
}
}
// ** UOp graph
function intersectRect(r1, r2) {
const dx = r2.x-r1.x;
const dy = r2.y-r1.y;
if (dx === 0 && dy === 0) throw new Error("Invalid node coordinates, rects must not overlap");
const scaleX = dx !== 0 ? (r1.width/2)/Math.abs(dx) : Infinity;
const scaleY = dy !== 0 ? (r1.height/2)/Math.abs(dy) : Infinity;
const scale = Math.min(scaleX, scaleY);
return {x:r1.x+dx*scale, y:r1.y+dy*scale};
}
function addTags(root) {
root.selectAll("circle").data(d => [d]).join("circle").attr("r", 5);
root.selectAll("text").data(d => [d]).join("text").text(d => d).attr("dy", "0.35em");
}
let [workerUrl, worker] = [null, null];
async function initWorker() {
const resp = await Promise.all(["/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js","/js/worker.js"].map(u => fetch(u)));
workerUrl = URL.createObjectURL(new Blob([(await Promise.all(resp.map((r) => r.text()))).join("\n")], { type: "application/javascript" }));
}
function renderDag(graph, additions, recenter) {
// start calculating the new layout (non-blocking)
updateProgress({ start:true });
if (worker != null) worker.terminate();
worker = new Worker(workerUrl);
worker.postMessage({graph, additions});
worker.onmessage = (e) => {
displayGraph("graph");
updateProgress({ start:false });
const g = dagre.graphlib.json.read(e.data);
// draw nodes
const STROKE_WIDTH = 1.4;
d3.select("#graph-svg").on("click", () => d3.selectAll(".highlight").classed("highlight", false));
const nodes = d3.select("#nodes").selectAll("g").data(g.nodes().map(id => g.node(id)), d => d).join("g").attr("class", d => d.className ?? "node")
.attr("transform", d => `translate(${d.x},${d.y})`).classed("clickable", d => d.ref != null).on("click", (e,d) => {
if (d.ref != null) return setCtxWithHistory(d.ref);
const parents = g.predecessors(d.id);
const children = g.successors(d.id);
if (parents == null && children == null) return;
const src = [...parents, ...children, d.id];
nodes.classed("highlight", n => src.includes(n.id)).classed("child", n => children.includes(n.id));
const matchEdge = (v, w) => (v===d.id && children.includes(w)) ? "highlight child " : (parents.includes(v) && w===d.id) ? "highlight " : "";
d3.select("#edges").selectAll("path.edgePath").attr("class", e => matchEdge(e.v, e.w)+"edgePath");
d3.select("#edge-labels").selectAll("g.port").attr("class", (_, i, n) => matchEdge(...n[i].id.split("-"))+"port");
e.stopPropagation();
});
nodes.selectAll("rect").data(d => [d]).join("rect").attr("width", d => d.width).attr("height", d => d.height).attr("fill", d => d.color)
.attr("x", d => -d.width/2).attr("y", d => -d.height/2);
nodes.selectAll("g.label").data(d => [d]).join("g").attr("class", "label").attr("transform", d => {
const x = (d.width-d.padding*2)/2;
const y = (d.height-d.padding*2)/2+STROKE_WIDTH;
return `translate(-${x}, -${y})`;
}).selectAll("text").data(d => {
const ret = [[]];
for (const { st, color } of parseColors(d.label, defaultColor="initial")) {
for (const [i, l] of st.split("\n").entries()) {
if (i > 0) ret.push([]);
ret.at(-1).push({ st:l, color });
}
}
return [ret];
}).join("text").selectAll("tspan").data(d => d).join("tspan").attr("x", "0").attr("dy", 14).selectAll("tspan").data(d => d).join("tspan")
.attr("fill", d => darkenHex(d.color, 25)).text(d => d.st).attr("xml:space", "preserve");
addTags(nodes.selectAll("g.tag").data(d => d.tag != null ? [d] : []).join("g").attr("class", "tag")
.attr("transform", d => `translate(${-d.width/2+8}, ${-d.height/2+8})`).datum(e => e.tag));
// draw edges
const line = d3.line().x(d => d.x).y(d => d.y).curve(d3.curveBasis), edges = g.edges();
d3.select("#edges").selectAll("path.edgePath").data(edges).join("path").attr("class", "edgePath").attr("d", (e) => {
const edge = g.edge(e);
const points = edge.points.slice(1, edge.points.length-1);
points.unshift(intersectRect(g.node(e.v), points[0]));
points.push(intersectRect(g.node(e.w), points[points.length-1]));
return line(points);
}).attr("marker-end", "url(#arrowhead)");
addTags(d3.select("#edge-labels").selectAll("g").data(edges).join("g").attr("transform", (e) => {
// get a point near the end
const [p1, p2] = g.edge(e).points.slice(-2);
const dx = p2.x-p1.x;
const dy = p2.y-p1.y;
// normalize to the unit vector
const len = Math.sqrt(dx*dx + dy*dy);
const ux = dx / len;
const uy = dy / len;
// avoid overlap with the arrowhead
const offset = 17;
const x = p2.x - ux * offset;
const y = p2.y - uy * offset;
return `translate(${x}, ${y})`
}).attr("class", e => g.edge(e).label.type).attr("id", e => `${e.v}-${e.w}`).datum(e => g.edge(e).label.text));
if (recenter) document.getElementById("zoom-to-fit-btn").click();
};
}
// ** profiler graph
function formatTime(ts, dur=ts) {
if (dur<=1e3) return `${ts.toFixed(2)}us`;
if (dur<=1e6) return `${(ts*1e-3).toFixed(2)}ms`;
return `${(ts*1e-6).toFixed(2)}s`;
}
const formatUnit = (d, unit="") => d3.format(".3~s")(d)+unit;
const colorScheme = {TINY:["#1b5745", "#354f52", "#354f52", "#1d2e62", "#63b0cd"],
DEFAULT:["#2b2e39", "#2c2f3a", "#31343f", "#323544", "#2d303a", "#2e313c", "#343746", "#353847", "#3c4050", "#404459", "#444862", "#4a4e65"],
BUFFER:["#3A57B7","#5066C1","#6277CD","#7488D8","#8A9BE3","#A3B4F2"],
CATEGORICAL:["#ff8080", "#F4A261", "#C8F9D4", "#8D99AE", "#F4A261", "#ffffa2", "#ffffc0", "#87CEEB"],}
const cycleColors = (lst, i) => lst[i%lst.length];
const rescaleTrack = (source, tid, k) => {
for (const e of source.shapes) {
for (let i=0; i<e.y0.length; i++) {
e.y0[i] = e.y0[i]*k;
e.y1[i] = e.y1[i]*k;
}
}
const change = (source.height*k)-source.height;
const div = document.getElementById(tid);
div.style.height = rect(div).height+change+"px";
source.height = source.height*k;
return change;
}
const drawLine = (ctx, x, y, opts) => {
ctx.beginPath();
ctx.moveTo(x[0], y[0]);
ctx.lineTo(x[1], y[1]);
ctx.fillStyle = ctx.strokeStyle = opts?.color || "#f0f0f5";
ctx.stroke();
}
var data, focusedDevice, canvasZoom, zoomLevel = d3.zoomIdentity;
async function renderProfiler() {
displayGraph("profiler");
d3.select(".metadata").html("");
// layout once!
if (data != null) return updateProgress({ start:false });
const profiler = d3.select(".profiler").html("");
const buf = await (await fetch("/get_profile")).arrayBuffer();
const view = new DataView(buf);
let offset = 0;
const u8 = () => { const ret = view.getUint8(offset); offset += 1; return ret; }
const u32 = () => { const ret = view.getUint32(offset, true); offset += 4; return ret; }
const u64 = () => { const ret = new Number(view.getBigUint64(offset, true)); offset += 8; return ret; }
const f32 = () => { const ret = view.getFloat32(offset, true); offset += 4; return ret; }
const optional = (i) => i === 0 ? null : i-1;
const dur = u32(), tracePeak = u64(), indexLen = u32(), layoutsLen = u32();
const textDecoder = new TextDecoder("utf-8");
const { strings, dtypeSize, markers } = JSON.parse(textDecoder.decode(new Uint8Array(buf, offset, indexLen))); offset += indexLen;
// place devices on the y axis and set vertical positions
const [tickSize, padding] = [10, 8];
const deviceList = profiler.append("div").attr("id", "device-list").style("padding-top", tickSize+padding+"px");
const canvas = profiler.append("canvas").attr("id", "timeline").node();
// NOTE: scrolling via mouse can only zoom the graph
canvas.addEventListener("wheel", e => (e.stopPropagation(), e.preventDefault()), { passive:false });
const ctx = canvas.getContext("2d");
const canvasTop = rect(canvas).top;
// color by key (name/device)
const colorMap = new Map();
data = {tracks:new Map(), axes:{}};
const heightScale = d3.scaleLinear().domain([0, tracePeak]).range([4,maxheight=100]);
for (let i=0; i<layoutsLen; i++) {
const nameLen = view.getUint8(offset, true); offset += 1;
const k = textDecoder.decode(new Uint8Array(buf, offset, nameLen)); offset += nameLen;
const div = deviceList.append("div").attr("id", k).text(k).style("padding", padding+"px");
const { y:baseY, height:baseHeight } = rect(div.node());
const offsetY = baseY-canvasTop+padding/2;
const shapes = [], visible = [];
const EventTypes = {TIMELINE:0, MEMORY:1};
const eventType = u8(), eventsLen = u32();
if (eventType === EventTypes.TIMELINE) {
const levelHeight = baseHeight-padding;
const levels = [];
data.tracks.set(k, { shapes, visible, offsetY });
let colorKey, ref;
for (let j=0; j<eventsLen; j++) {
const e = {name:strings[u32()], ref:optional(u32()), st:u32(), dur:f32(), info:strings[u32()] || null};
// find a free level to put the event
let depth = levels.findIndex(levelEt => e.st >= levelEt);
const et = e.st+Math.trunc(e.dur);
if (depth === -1) {
depth = levels.length;
levels.push(et);
} else levels[depth] = et;
if (depth === 0) colorKey = e.name.split(" ")[0];
if (!colorMap.has(colorKey)) colorMap.set(colorKey, d3.rgb(cycleColors(colorScheme[k.split(":")[0]] ?? colorScheme.DEFAULT, colorMap.size)));
const base = colorMap.get(colorKey), s = Math.min(Math.pow(1/0.7, depth), 240 / Math.max(base.r, base.g, base.b));
const fillColor = d3.rgb(base.r*s, base.g*s, base.b*s).toString();
const label = parseColors(e.name).map(({ color, st }) => ({ color, st, width:ctx.measureText(st).width }));
if (e.ref != null) ref = {ctx:e.ref, step:0};
else if (ref != null) {
const start = ref.step>0 ? ref.step+1 : 0;
const stepIdx = ctxs[ref.ctx+1].steps.findIndex((s, i) => i >= start && s.name == e.name);
ref = stepIdx === -1 ? null : {ctx:ref.ctx, step:stepIdx};
}
const htmlLabel = label.map(({color, st}) => `<span style="color:${color}">${st}</span>`).join('');
const arg = { tooltipText:htmlLabel+"\n"+formatTime(e.dur)+(e.info != null ? "\n"+e.info : ""), ...ref };
// offset y by depth
shapes.push({x:e.st, y:levelHeight*depth, width:e.dur, height:levelHeight, arg, label, fillColor });
}
div.style("height", levelHeight*levels.length+padding+"px").style("pointerEvents", "none");
} else {
const peak = u64();
let x = 0, y = 0;
const buf_shapes = new Map(), temp = new Map();
const timestamps = [];
for (let j=0; j<eventsLen; j++) {
const alloc = u8(), ts = u32(), key = u32();
if (alloc) {
const dtype = strings[u32()], sz = u64(), nbytes = dtypeSize[dtype]*sz;
const shape = {x:[x], y:[y], dtype, sz, nbytes, key};
buf_shapes.set(key, shape); temp.set(key, shape);
timestamps.push(ts);
x += 1; y += nbytes;
} else {
const free = buf_shapes.get(key);
timestamps.push(ts);
x += 1; y -= free.nbytes;
free.x.push(x);
free.y.push(free.y.at(-1));
temp.delete(key);
for (const [k, v] of temp) {
if (k <= key) continue;
v.x.push(x, x);
v.y.push(v.y.at(-1), v.y.at(-1)-free.nbytes);
}
}
}
for (const [_, v] of temp) {
v.x.push(x);
v.y.push(v.y.at(-1));
}
timestamps.push(dur);
const height = heightScale(peak);
const yscale = d3.scaleLinear().domain([0, peak]).range([height, 0]);
for (const [num, {dtype, sz, nbytes, y, x:steps}] of buf_shapes) {
const x = steps.map(s => timestamps[s]);
const dur = x.at(-1)-x[0];
const arg = {tooltipText:`${dtype} len:${formatUnit(sz)}\n${formatUnit(nbytes, "B")}\nnum:${num}\nalive for ${formatTime(dur)}`};
shapes.push({ x, y0:y.map(yscale), y1:y.map(y0 => yscale(y0+nbytes)), arg, fillColor:cycleColors(colorScheme.BUFFER, shapes.length) });
}
data.tracks.set(k, { shapes, visible, offsetY, height, peak, scaleFactor:maxheight*4/height });
div.style("height", height+padding+"px").style("cursor", "pointer").on("click", (e) => {
const newFocus = e.currentTarget.id === focusedDevice ? null : e.currentTarget.id;
let offset = 0;
for (const [tid, track] of data.tracks) {
track.offsetY += offset;
if (tid === newFocus) offset += rescaleTrack(track, tid, track.scaleFactor);
else if (tid === focusedDevice) offset += rescaleTrack(track, tid, 1/track.scaleFactor);
}
data.axes.y = newFocus != null ? { domain:[0, (t=data.tracks.get(newFocus)).peak], range:[t.offsetY+t.height, t.offsetY], fmt:"B" } : null;
focusedDevice = newFocus;
return resize();
});
}
}
updateProgress({ start:false });
// draw events on a timeline
const dpr = window.devicePixelRatio || 1;
const ellipsisWidth = ctx.measureText("...").width;
function render(transform) {
zoomLevel = transform;
ctx.clearRect(0, 0, canvas.clientWidth, canvas.clientHeight);
// rescale to match current zoom
const xscale = d3.scaleLinear().domain([0, dur]).range([0, canvas.clientWidth]);
const visibleX = xscale.range().map(zoomLevel.invertX, zoomLevel).map(xscale.invert, xscale);
const st = visibleX[0], et = visibleX[1];
xscale.domain(visibleX);
// draw shapes
for (const [_, { offsetY, shapes, visible }] of data.tracks) {
visible.length = 0;
for (const e of shapes) {
// generic polygon
if (e.width == null) {
if (e.x[0]>et || e.x.at(-1)<st) continue;
const x = e.x.map(xscale);
ctx.beginPath();
ctx.moveTo(x[0], offsetY+e.y0[0]);
for (let i=1; i<x.length; i++) {
ctx.lineTo(x[i], offsetY+e.y0[i]);
visible.push({ x0:x[i-1], x1:x[i], y0:offsetY+e.y1[i-1], y1:offsetY+e.y0[i], arg:e.arg });
}
for (let i=x.length-1; i>=0; i--) ctx.lineTo(x[i], offsetY+e.y1[i]);
ctx.closePath();
ctx.fillStyle = e.fillColor; ctx.fill();
continue;
}
// contiguous rect
if (e.x>et || e.x+e.width<st) continue;
const x = xscale(e.x);
const y = offsetY+e.y;
const width = xscale(e.x+e.width)-x;
ctx.fillStyle = e.fillColor; ctx.fillRect(x, y, width, e.height);
visible.push({ y0:y, y1:y+e.height, x0:x, x1:x+width, arg:e.arg });
// add label
if (e.label == null) continue;
ctx.textAlign = "left";
ctx.textBaseline = "middle";
let labelX = x+2, labelWidth = 0;
const labelY = y+e.height/2;
for (const [i,l] of e.label.entries()) {
if (labelWidth+l.width+(i===e.label.length-1 ? 0 : ellipsisWidth)+2 > width) {
if (labelWidth !== 0) ctx.fillText("...", labelX, labelY);
break;
}
ctx.fillStyle = l.color;
ctx.fillText(l.st, labelX, labelY);
labelWidth += l.width;
labelX += l.width;
}
}
}
// draw axes
drawLine(ctx, xscale.range(), [0, 0]);
for (const tick of xscale.ticks()) {
// tick line
const x = xscale(tick);
drawLine(ctx, [x, x], [0, tickSize])
// tick label
ctx.textBaseline = "top";
ctx.fillText(formatTime(tick, dur), x+ctx.lineWidth+2, tickSize);
}
if (data.axes.y != null) {
drawLine(ctx, [0, 0], data.axes.y.range);
const yscale = d3.scaleLinear().domain(data.axes.y.domain).range(data.axes.y.range);
for (const tick of yscale.ticks()) {
const y = yscale(tick);
drawLine(ctx, [0, tickSize], [y, y]);
ctx.textBaseline = "middle";
ctx.fillText(formatUnit(tick, data.axes.y.fmt), tickSize+2, y);
}
}
// draw markers
for (const m of markers) {
const x = xscale(m.ts);
drawLine(ctx, [x, x], [0, canvas.clientHeight], { color:m.color });
ctx.fillText(m.name, x+2, 1);
}
}
function resize() {
const profiler = document.querySelector(".profiler");
const sideRect = rect("#device-list");
const width = profiler.clientWidth-(sideRect.width+padding), height = Math.round(sideRect.height);
if (canvas.width === width*dpr && canvas.height === height*dpr) return;
canvas.width = width*dpr;
canvas.height = height*dpr;
canvas.style.height = `${height}px`;
canvas.style.width = `${width}px`;
ctx.scale(dpr, dpr);
d3.select(canvas).call(canvasZoom.transform, zoomLevel);
}
canvasZoom = d3.zoom().filter(vizZoomFilter).scaleExtent([1, Infinity]).translateExtent([[0,0], [Infinity,0]]).on("zoom", e => render(e.transform));
d3.select(canvas).call(canvasZoom);
document.addEventListener("contextmenu", e => e.ctrlKey && e.preventDefault());
new ResizeObserver(([e]) => e.contentRect.width > 0 && resize()).observe(profiler.node());
function findRectAtPosition(x, y) {
let tid = null;
for (const k of data.tracks.keys()) {
const r = rect(document.getElementById(k));
if (y >= r.y && y <= r.y+r.height) { tid = k; break; }
}
if (tid == null) return;
const { top, left, width, height } = rect(canvas);
const X = ((x-left) * (canvas.width/width))/dpr;
const Y = ((y-top) * (canvas.height/height))/dpr;
for (const r of data.tracks.get(tid).visible) {
if (Y>=r.y0 && Y<=r.y1 && X>=r.x0 && X<=r.x1) return r.arg;
}
}
canvas.addEventListener("click", e => {
e.preventDefault();
const foundRect = findRectAtPosition(e.clientX, e.clientY);
if (foundRect?.step != null) return setCtxWithHistory(foundRect.ctx, foundRect.step);
});
canvas.addEventListener("mousemove", e => {
const foundRect = findRectAtPosition(e.clientX, e.clientY);
if (foundRect?.tooltipText != null) {
const tooltip = document.getElementById("tooltip");
tooltip.style.display = "block";
tooltip.style.left = (e.pageX+10)+"px";
tooltip.style.top = (e.pageY)+"px";
tooltip.innerHTML = foundRect.tooltipText;
} else tooltip.style.display = "none";
});
canvas.addEventListener("mouseleave", () => document.getElementById("tooltip").style.display = "none");
}
// ** zoom and recentering
const vizZoomFilter = e => (!e.ctrlKey || e.type === 'wheel' || e.type === 'mousedown') && !e.button && e.type !== 'dblclick';
const svgZoom = d3.zoom().filter(vizZoomFilter).on("zoom", (e) => d3.select("#render").attr("transform", e.transform));
d3.select("#graph-svg").call(svgZoom);
// zoom to fit into view
document.getElementById("zoom-to-fit-btn").addEventListener("click", () => {
const canvas = d3.select("#timeline");
if (!canvas.empty() && rect(canvas.node()).width !== 0) {
return canvas.call(canvasZoom.transform, d3.zoomIdentity);
}
const svg = d3.select("#graph-svg");
svg.call(svgZoom.transform, d3.zoomIdentity);
const mainRect = rect(".main-container");
const x0 = rect(".ctx-list-parent").right;
const x1 = rect(".metadata-parent").left;
const pad = 16;
const R = { x: x0+pad, y: mainRect.top+pad, width: (x1>0 ? x1-x0 : mainRect.width)-2*pad, height: mainRect.height-2*pad };
const r = rect("#render");
if (r.width === 0) return;
const scale = Math.min(R.width/r.width, R.height/r.height);
const [tx, ty] = [R.x+(R.width-r.width*scale)/2-r.left*scale, R.y+(R.height-r.height*scale)/2];
svg.call(svgZoom.transform, d3.zoomIdentity.translate(tx, ty).scale(scale));
});
// **** main VIZ interfacae
function codeBlock(st, language, { loc, wrap }={}) {
const code = document.createElement("code");
code.innerHTML = hljs.highlight(st, { language }).value;
code.className = "hljs";
const ret = document.createElement("pre");
if (wrap) ret.className = "wrap";
if (loc != null) {
const link = ret.appendChild(document.createElement("a"));
link.href = "vscode://file/"+loc.join(":");
link.textContent = `${loc[0].split("/").at(-1)}:${loc[1]}`+"\n\n";
}
ret.appendChild(code);
return ret;
}
function appendTd(tr, value, unit=null) {
const fmt = (typeof value === "number" && !Number.isInteger(value)) ? value.toFixed(2) : value;
tr.appendChild(document.createElement("td")).innerText = unit == "us" ? formatTime(value) : fmt+(unit ?? "");
}
function appendRow(table, name, value, unit=null, cls="main-row") {
const tr = table.appendChild(document.createElement("tr"));
tr.className = cls;
tr.appendChild(document.createElement("td")).innerText = name;
appendTd(tr, value, unit);
return tr;
}
function setActive(e) {
if (e == null) return;
e.classList.add("active");
requestAnimationFrame(() => e.scrollIntoView({ behavior: "auto", block: "nearest" }));
}
// ** hljs extra definitions for UOps and float4
hljs.registerLanguage("python", (hljs) => ({
...hljs.getLanguage("python"),
case_insensitive: false,
contains: [
{ begin: 'dtypes\\.[a-zA-Z_][a-zA-Z0-9_-]*(\\.[a-zA-Z_][a-zA-Z0-9_-]*)*' + '(?=[.\\s\\n[:,(])', className: "type" },
{ begin: 'dtypes\\.[a-zA-Z_][a-zA-Z0-9_-].vec*' + '(?=[.\\s\\n[:,(])', className: "type" },
{ begin: '[a-zA-Z_][a-zA-Z0-9_-]*\\.[a-zA-Z_][a-zA-Z0-9_-]*' + '(?=[.\\s\\n[:,()])', className: "operator" },
{ begin: '[A-Z][a-zA-Z0-9_]*(?=\\()', className: "section", ignoreEnd: true },
...hljs.getLanguage("python").contains,
]
}));
hljs.registerLanguage("cpp", (hljs) => ({
...hljs.getLanguage('cpp'),
contains: [{ begin: '\\b(?:float|half)[0-9]+\\b', className: 'type' }, ...hljs.getLanguage('cpp').contains]
}));
var ret = [];
var cache = {};
var ctxs = null;
const evtSources = [];
// VIZ displays graph rewrites in 3 levels, from bottom-up:
// rewrite: a single UOp transformation
// step: collection of rewrites
// context: collection of steps
const state = {currentCtx:-1, currentStep:0, currentRewrite:0, expandSteps:false};
function setState(ns) {
const { currentCtx:prevCtx, currentStep:prevStep } = state;
Object.assign(state, ns);
// update element styles if needed
document.getElementById(`ctx-${state.currentCtx}`)?.classList.toggle("expanded", state.expandSteps);
if (state.currentCtx !== prevCtx) {
document.getElementById(`ctx-${prevCtx}`)?.classList.remove("active", "expanded");
setActive(document.getElementById(`ctx-${state.currentCtx}`));
}
if (state.currentCtx !== prevCtx || state.currentStep !== prevStep) {
document.getElementById(`step-${prevCtx}-${prevStep}`)?.classList.remove("active");
setActive(document.getElementById(`step-${state.currentCtx}-${state.currentStep}`));
}
// re-render
main();
}
// set a new context and keep the old one in browser history
function setCtxWithHistory(newCtx, step=0) {
// NOTE: browser does a structured clone, passing a mutable object is safe.
history.replaceState(state, "");
history.pushState(state, "");
setState({ expandSteps:true, currentCtx:newCtx+1, currentStep:step, currentRewrite:0 });
}
window.addEventListener("popstate", (e) => {
if (e.state != null) setState(e.state);
});
async function main() {
// ** left sidebar context list
if (ctxs == null) {
ctxs = [{ name:"Profiler", steps:[] }];
for (const r of (await (await fetch("/ctxs")).json())) ctxs.push(r);
const ctxList = document.querySelector(".ctx-list");
for (const [i,{name, steps}] of ctxs.entries()) {
const ul = ctxList.appendChild(document.createElement("ul"));
ul.id = `ctx-${i}`;
const p = ul.appendChild(document.createElement("p"));
p.innerHTML = parseColors(name).map(c => `<span style="color: ${c.color}">${c.st}</span>`).join("");
p.onclick = () => {
setState(i === state.currentCtx ? { expandSteps:!state.expandSteps } : { expandSteps:true, currentCtx:i, currentStep:0, currentRewrite:0 });
}
for (const [j,u] of steps.entries()) {
const inner = ul.appendChild(document.createElement("ul"));
inner.id = `step-${i}-${j}`;
inner.innerText = `${u.name ?? u.loc[0].replaceAll("\\", "/").split("/").pop()+':'+u.loc[1]}`+(u.match_count ? ` - ${u.match_count}` : '');
inner.style.marginLeft = `${8*u.depth}px`;
inner.onclick = (e) => {
e.stopPropagation();
setState({ currentStep:j, currentCtx:i, currentRewrite:0 });
}
}
}
return setState({ currentCtx:-1 });
}
// ** center graph
const { currentCtx, currentStep, currentRewrite, expandSteps } = state;
if (currentCtx == -1) return;
const ctx = ctxs[currentCtx];
const step = ctx.steps[currentStep];
const ckey = step?.query;
// close any pending event sources
let activeSrc = null;
for (const e of evtSources) {
const url = new URL(e.url);
if (url.pathname+url.search !== ckey) e.close();
else if (e.readyState === EventSource.OPEN) activeSrc = e;
}
if (ctx.name === "Profiler") return renderProfiler();
if (workerUrl == null) await initWorker();
if (ckey in cache) {
ret = cache[ckey];
}
// ** Disassembly view
if (ckey.startsWith("/disasm")) {
if (!(ckey in cache)) cache[ckey] = ret = await (await fetch(ckey)).json();
displayGraph("disasm");
const root = document.createElement("div");
root.className = "raw-text";
const metadata = document.querySelector(".metadata");
metadata.innerHTML = "";
// detailed assembly view
if (ret.cols != null) {
const asm = root.appendChild(document.createElement("table"));
const thead = asm.appendChild(document.createElement("thead"));
for (const c of ret.cols) thead.appendChild(document.createElement("th")).innerText = c.title ?? c;
for (const r of ret.rows) {
const tr = asm.appendChild(document.createElement("tr"));
tr.className = "main-row code-row";
for (const [i,value] of r.entries()) {
// string format scalar values
if (!Array.isArray(value)) appendTd(tr, value);
// display arrays in a bar graph
else {
const segmentsTd = tr.appendChild(document.createElement("td"));
segmentsTd.className = "pct-row";
const usageBar = segmentsTd.appendChild(document.createElement("div"));
for (const [k, v, width] of value) {
const seg = usageBar.appendChild(document.createElement("div"));
seg.style.width = width+"%";
seg.title = `${ret.cols[i].labels[k]} ${v}`;
seg.style.background = cycleColors(colorScheme.CATEGORICAL, parseInt(k));
}
}
}
}
const summary = metadata.appendChild(document.createElement("table"));
for (const s of ret.summary) {
const tr = summary.appendChild(document.createElement("tr"));
tr.className = "main-row";
const td = tr.appendChild(document.createElement("td"));
const div = td.appendChild(document.createElement("div"));
div.className = "legend";
div.appendChild(document.createElement("div")).style.background = cycleColors(colorScheme.CATEGORICAL, s.idx);
div.appendChild(document.createElement("p")).textContent = s.label;
appendTd(tr, s.value);
}
} else root.appendChild(codeBlock(ret.src, "x86asm"));
return document.querySelector(".disasm").replaceChildren(root);
}
// ** UOp view (default)
// if we don't have a complete cache yet we start streaming rewrites in this step
if (!(ckey in cache) || (cache[ckey].length !== step.match_count+1 && activeSrc == null)) {
ret = [];
cache[ckey] = ret;
const eventSource = new EventSource(ckey);
evtSources.push(eventSource);
eventSource.onmessage = (e) => {
if (e.data === "END") return eventSource.close();
const chunk = JSON.parse(e.data);
ret.push(chunk);
// if it's the first one render this new rgaph
if (ret.length === 1) return main();
// otherwise just enable the graph selector
const ul = document.getElementById(`rewrite-${ret.length-1}`);
if (ul != null) ul.classList.remove("disabled");
};
}
if (ret.length === 0) return;
renderDag(ret[currentRewrite].graph, ret[currentRewrite].changed_nodes ?? [], currentRewrite === 0);
// ** right sidebar code blocks
const metadata = document.querySelector(".metadata");
const [code, lang] = ctx.fmt != null ? [ctx.fmt, "cpp"] : [ret[currentRewrite].uop, "python"];
metadata.replaceChildren(codeBlock(step.code_line, "python", { loc:step.loc, wrap:true }), codeBlock(code, lang, { wrap:false }));
// ** rewrite steps
if (step.match_count >= 1) {
const rewriteList = metadata.appendChild(document.createElement("div"));
rewriteList.className = "rewrite-list";
for (let s=0; s<=step.match_count; s++) {
const ul = rewriteList.appendChild(document.createElement("ul"));
ul.innerText = s;
ul.id = `rewrite-${s}`;
ul.onclick = () => setState({ currentRewrite:s });
ul.className = s > ret.length-1 ? "disabled" : s === currentRewrite ? "active" : "";
if (s > 0 && s === currentRewrite) {
const { upat, diff } = ret[s];
metadata.appendChild(codeBlock(upat[1], "python", { loc:upat[0], wrap:true }));
const diffCode = metadata.appendChild(document.createElement("pre")).appendChild(document.createElement("code"));
for (const line of diff) {
const span = diffCode.appendChild(document.createElement("span"));
span.style.color = line.startsWith("+") ? "#3aa56d" : line.startsWith("-") ? "#d14b4b" : "#f0f0f5";
span.innerText = line;
diffCode.appendChild(document.createElement("br"));
}
diffCode.className = "wrap";
}
}
}
}
// **** collapse/expand
let isCollapsed = false;
document.querySelector(".collapse-btn").addEventListener("click", (e) => {
isCollapsed = !isCollapsed;
document.querySelector(".main-container").classList.toggle("collapsed", isCollapsed);
e.currentTarget.blur();
e.currentTarget.style.transform = isCollapsed ? "rotate(180deg)" : "rotate(0deg)";
window.dispatchEvent(new Event("resize"));
});
// **** resizer
function appendResizer(element, { minWidth, maxWidth }, left=false) {
const handle = Object.assign(document.createElement("div"), { className: "resize-handle", style: left ? "right: 0" : "left: 0; margin-top: 0" });
element.appendChild(handle);
const resize = (e) => {
const change = e.clientX - element.dataset.startX;
let newWidth = ((Number(element.dataset.startWidth)+(left ? change : -change))/Number(element.dataset.containerWidth))*100;
element.style.width = `${Math.max(minWidth, Math.min(maxWidth, newWidth))}%`;
};
handle.addEventListener("mousedown", (e) => {
e.preventDefault();
element.dataset.startX = e.clientX;
element.dataset.containerWidth = rect(".main-container").width;
element.dataset.startWidth = element.getBoundingClientRect().width;
document.documentElement.addEventListener("mousemove", resize, false);
document.documentElement.addEventListener("mouseup", () => {
document.documentElement.removeEventListener("mousemove", resize, false);
element.style.userSelect = "initial";
}, { once: true });
});
}
appendResizer(document.querySelector(".ctx-list-parent"), { minWidth: 15, maxWidth: 50 }, left=true);
appendResizer(document.querySelector(".metadata-parent"), { minWidth: 20, maxWidth: 50 });
// **** keyboard shortcuts
document.addEventListener("keydown", (event) => {
const { currentCtx, currentStep, currentRewrite, expandSteps } = state;
// up and down change the step or context from the list
const changeStep = expandSteps && ctxs[currentCtx].steps?.length;
if (event.key == "ArrowUp") {
event.preventDefault();
if (changeStep) {
return setState({ currentRewrite:0, currentStep:Math.max(0, currentStep-1) });
}
return setState({ currentStep:0, currentRewrite:0, currentCtx:Math.max(0, currentCtx-1), expandSteps:false });
}
if (event.key == "ArrowDown") {
event.preventDefault();
if (changeStep) {
const totalUOps = ctxs[currentCtx].steps.length-1;
return setState({ currentRewrite:0, currentStep:Math.min(totalUOps, currentStep+1) });
}
return setState({ currentStep:0, currentRewrite:0, currentCtx:Math.min(ctxs.length-1, currentCtx+1), expandSteps:false });
}
// enter toggles focus on a single rewrite stage
if (event.key == "Enter") {
event.preventDefault()
if (currentCtx === -1) {
return setState({ currentCtx:0, expandSteps:true });
}
return setState({ expandSteps:!expandSteps });
}
// left and right go through rewrites in a single UOp
if (event.key == "ArrowLeft") {
event.preventDefault()
return setState({ currentRewrite:Math.max(0, currentRewrite-1) });
}
if (event.key == "ArrowRight") {
event.preventDefault()
const totalRewrites = ret.length-1;
return setState({ currentRewrite:Math.min(totalRewrites, currentRewrite+1) });
}
// space recenters the graph
if (event.key == " ") {
event.preventDefault()
document.getElementById("zoom-to-fit-btn").click();
}
});
main()

29
tinygrad/viz/js/worker.js Normal file
View File

@@ -0,0 +1,29 @@
const NODE_PADDING = 10;
const LINE_HEIGHT = 14;
const canvas = new OffscreenCanvas(0, 0);
const ctx = canvas.getContext("2d");
ctx.font = `${LINE_HEIGHT}px sans-serif`;
onmessage = (e) => {
const { graph, additions } = e.data;
const g = new dagre.graphlib.Graph({ compound: true });
g.setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
if (additions.length !== 0) g.setNode("addition", {label:"", className:"overlay", padding:0});
for (let [k, {label, src, ref, ...rest }] of Object.entries(graph)) {
// adjust node dims by label size (excluding escape codes) + add padding
let [width, height] = [0, 0];
for (line of label.replace(/\u001B\[(?:K|.*?m)/g, "").split("\n")) {
width = Math.max(width, ctx.measureText(line).width);
height += LINE_HEIGHT;
}
g.setNode(k, {width:width+NODE_PADDING*2, height:height+NODE_PADDING*2, padding:NODE_PADDING, label, ref, id:k, ...rest});
// add edges
const edgeCounts = {}
for (const [_, s] of src) edgeCounts[s] = (edgeCounts[s] || 0)+1;
for (const [port, s] of src) g.setEdge(s, k, { label: edgeCounts[s] > 1 ? {type:"tag", text:edgeCounts[s]} : {type:"port", text:port}});
if (additions.includes(parseInt(k))) g.setParent(k, "addition");
}
dagre.layout(g);
postMessage(dagre.graphlib.json.write(g));
self.close();
}

323
tinygrad/viz/serve.py Executable file
View File

@@ -0,0 +1,323 @@
#!/usr/bin/env python3
import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, socketserver, functools, codecs, io, struct
import subprocess, ctypes, pathlib
from contextlib import redirect_stdout
from decimal import Decimal
from http.server import BaseHTTPRequestHandler
from urllib.parse import parse_qs, urlparse
from typing import Any, TypedDict, Generator
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp
from tinygrad.uop.ops import TrackedGraphRewrite, UOp, Ops, printable, GroupOp, srender, sint, sym_infer
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device
from tinygrad.renderer import ProgramSpec
from tinygrad.dtype import dtypes
from tinygrad.codegen.opt import axis_colors
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_REG: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B",
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF",
Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500",
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D", Ops.REALIZE: "#C1C14D",
Ops.CHILDREN: "#80ffc0", Ops.CHILD: "#80fff0", Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e"}
# VIZ API
# ** Metadata for a track_rewrites scope
ref_map:dict[Any, int] = {}
traces:dict[int, tuple] = {}
def get_metadata(trace_bufs:list[tuple]) -> list[dict]:
ret = []
for keys,contexts,uop_fields in trace_bufs:
for k,v in zip(keys, contexts):
traces[i:=len(traces)] = (k, v, uop_fields)
steps = [{"name":s.name, "loc":s.loc, "depth":s.depth, "match_count":len(s.matches), "code_line":printable(s.loc),
"query":f"/ctxs?ctx={i}&idx={j}"} for j,s in enumerate(v)]
ret.append(r:={"name":k.display_name, "steps":steps})
# program spec metadata
if isinstance(k.ret, ProgramSpec):
steps.append({"name":"View Disassembly", "query":f"/disasm?ctx={i}"})
r["fmt"] = k.ret.src
for key in k.keys: ref_map[key] = i
return ret
# ** Complete rewrite details for a graph_rewrite call
class GraphRewriteDetails(TypedDict):
graph: dict # JSON serialized UOp for this rewrite step
uop: str # strigified UOp for this rewrite step
diff: list[str]|None # diff of the single UOp that changed
changed_nodes: list[int]|None # the changed UOp id + all its parents ids
upat: tuple[tuple[str, int], str]|None # [loc, source_code] of the matched UPat
def shape_to_str(s:tuple[sint, ...]): return "(" + ','.join(srender(x) for x in s) + ")"
def mask_to_str(s:tuple[tuple[sint, sint], ...]): return "(" + ','.join(shape_to_str(x) for x in s) + ")"
def uop_to_json(x:UOp) -> dict[int, dict]:
assert isinstance(x, UOp)
graph: dict[int, dict] = {}
excluded: set[UOp] = set()
for u in (toposort:=x.toposort()):
# always exclude DEVICE/CONST/UNIQUE
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE} and u is not x: excluded.add(u)
# only exclude CONST VIEW source if it has no other children in the graph
if u.op is Ops.CONST and len(u.src) != 0 and all(cr.op is Ops.CONST for c in u.src[0].children if (cr:=c()) is not None and cr in toposort):
excluded.update(u.src)
for u in toposort:
if u in excluded: continue
argst = codecs.decode(str(u.arg), "unicode_escape")
if u.op is Ops.VIEW:
argst = ("\n".join([f"{shape_to_str(v.shape)} / {shape_to_str(v.strides)}"+("" if v.offset == 0 else f" / {srender(v.offset)}")+
(f"\nMASK {mask_to_str(v.mask)}" if v.mask is not None else "") for v in unwrap(u.st).views]))
if u.op in GroupOp.Movement: argst = (mask_to_str if u.op in {Ops.SHRINK, Ops.PAD} else shape_to_str)(u.arg)
label = f"{str(u.op).split('.')[1]}{(chr(10)+word_wrap(argst.replace(':', ''))) if u.arg is not None else ''}"
if u.dtype != dtypes.void: label += f"\n{u.dtype}"
for idx,x in enumerate(u.src):
if x in excluded:
arg = f"{x.arg:g}" if x.op is Ops.CONST and dtypes.is_float(x.dtype) else f"{x.arg}"
label += f"\n{x.op.name}{idx} {arg}" + (f" {x.src[0].op}" if len(x.src) else "")
try:
if u.op not in {Ops.VIEW, Ops.BUFFER, Ops.KERNEL, Ops.ASSIGN, Ops.COPY, Ops.SINK, *GroupOp.Buffer} and u.st is not None:
label += f"\n{shape_to_str(u.shape)}"
elif len(rngs:=u.ranges):
label += f"\n({','.join([colored(str(x.arg[0]), axis_colors[x.arg[-1]]) for x in sorted(rngs, key=lambda x: x.arg[0:-1])])})"
except Exception:
label += "\n<ISSUE GETTING LABEL>"
if (ref:=ref_map.get(u.arg.ast) if u.op is Ops.KERNEL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}"
# NOTE: kernel already has metadata in arg
if TRACEMETA >= 2 and u.metadata is not None and u.op is not Ops.KERNEL: label += "\n"+repr(u.metadata)
graph[id(u)] = {"label":label, "src":[(i,id(x)) for i,x in enumerate(u.src) if x not in excluded], "color":uops_colors.get(u.op, "#ffffff"),
"ref":ref, "tag":repr(u.tag) if u.tag is not None else None}
return graph
@functools.cache
def _reconstruct(a:int, i:int):
op, dtype, src, arg, *rest = traces[i][2][a]
arg = type(arg)(_reconstruct(arg.ast, i), arg.metadata) if op is Ops.KERNEL else arg
return UOp(op, dtype, tuple(_reconstruct(s, i) for s in src), arg, *rest)
def get_details(ctx:TrackedGraphRewrite, i:int=0) -> Generator[GraphRewriteDetails, None, None]:
yield {"graph":uop_to_json(next_sink:=_reconstruct(ctx.sink, i)), "uop":str(next_sink), "changed_nodes":None, "diff":None, "upat":None}
replaces: dict[UOp, UOp] = {}
for u0_num,u1_num,upat_loc in tqdm(ctx.matches):
replaces[u0:=_reconstruct(u0_num, i)] = u1 = _reconstruct(u1_num, i)
try: new_sink = next_sink.substitute(replaces)
except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e))
yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":str(new_sink), "changed_nodes":[id(x) for x in u1.toposort() if id(x) in sink_json],
"diff":list(difflib.unified_diff(str(u0).splitlines(), str(u1).splitlines())), "upat":(upat_loc, printable(upat_loc))}
if not ctx.bottom_up: next_sink = new_sink
# encoder helpers
def enum_str(s, cache:dict[str, int]) -> int:
if (cret:=cache.get(s)) is not None: return cret
cache[s] = ret = len(cache)
return ret
def option(s:int|None) -> int: return 0 if s is None else s+1
# Profiler API
device_ts_diffs:dict[str, tuple[Decimal, Decimal]] = {}
def cpu_ts_diff(device:str, thread=0) -> Decimal: return device_ts_diffs.get(device, (Decimal(0),))[thread]
DevEvent = ProfileRangeEvent|ProfileGraphEntry|ProfilePointEvent
def flatten_events(profile:list[ProfileEvent]) -> Generator[tuple[Decimal, Decimal, DevEvent], None, None]:
for e in profile:
if isinstance(e, ProfileRangeEvent): yield (e.st+(diff:=cpu_ts_diff(e.device, e.is_copy)), (e.en if e.en is not None else e.st)+diff, e)
elif isinstance(e, ProfilePointEvent): yield (e.ts, e.ts, e)
elif isinstance(e, ProfileGraphEvent):
cpu_ts = []
for ent in e.ents: cpu_ts += [e.sigs[ent.st_id]+(diff:=cpu_ts_diff(ent.device, ent.is_copy)), e.sigs[ent.en_id]+diff]
yield (st:=min(cpu_ts)), (et:=max(cpu_ts)), ProfileRangeEvent(f"{e.ents[0].device.split(':')[0]} Graph", f"batched {len(e.ents)}", st, et)
for i,ent in enumerate(e.ents): yield (cpu_ts[i*2], cpu_ts[i*2+1], ent)
# normalize event timestamps and attach kernel metadata
def timeline_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:int, scache:dict[str, int]) -> bytes|None:
events:list[bytes] = []
exec_points:dict[str, dict] = {}
for st,et,dur,e in dev_events:
if isinstance(e, ProfilePointEvent) and e.name == "exec": exec_points[e.key] = e.arg
if dur == 0: continue
name, info = e.name, None
if (ref:=ref_map.get(name)) is not None:
name = ctxs[ref]["name"]
if isinstance(p:=traces[ref][0].ret, ProgramSpec) and (ei:=exec_points.get(p.name)) is not None:
info = f"{sym_infer(p.estimates.ops, ei['var_vals'])/(t:=dur*1e3):.2f} GFLOPS {sym_infer(p.estimates.mem, ei['var_vals'])/t:4.1f}"+ \
f"|{sym_infer(p.estimates.lds,ei['var_vals'])/t:.1f} GB/s\n{ei['metadata']}"
elif isinstance(e.name, TracingKey):
name = e.name.display_name
ref = next((v for k in e.name.keys if (v:=ref_map.get(k)) is not None), None)
events.append(struct.pack("<IIIfI", enum_str(name, scache), option(ref), st-start_ts, dur, enum_str(info or "", scache)))
return struct.pack("<BI", 0, len(events))+b"".join(events) if events else None
def mem_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:int, end_ts:int, peaks:list[int], dtype_size:dict[str, int],
scache:dict[str, int]) -> bytes|None:
peak, mem = 0, 0
temp:dict[int, int] = {}
events:list[bytes] = []
for st,_,_,e in dev_events:
if not isinstance(e, ProfilePointEvent): continue
if e.name == "alloc":
events.append(struct.pack("<BIIIQ", 1, int(e.ts)-start_ts, e.key, enum_str(e.arg["dtype"].name, scache), e.arg["sz"]))
dtype_size.setdefault(e.arg["dtype"].name, e.arg["dtype"].itemsize)
temp[e.key] = nbytes = e.arg["sz"]*e.arg["dtype"].itemsize
mem += nbytes
if mem > peak: peak = mem
if e.name == "free":
events.append(struct.pack("<BII", 0, int(e.ts)-start_ts, e.key))
mem -= temp.pop(e.key)
peaks.append(peak)
return struct.pack("<BIQ", 1, len(events), peak)+b"".join(events) if events else None
def get_profile(profile:list[ProfileEvent]) -> bytes|None:
# start by getting the time diffs
for ev in profile:
if isinstance(ev,ProfileDeviceEvent): device_ts_diffs[ev.device] = (ev.comp_tdiff, ev.copy_tdiff if ev.copy_tdiff is not None else ev.comp_tdiff)
# map events per device
dev_events:dict[str, list[tuple[int, int, float, DevEvent]]] = {}
markers:list[ProfilePointEvent] = []
start_ts:int|None = None
end_ts:int|None = None
for ts,en,e in flatten_events(profile):
dev_events.setdefault(e.device,[]).append((st:=int(ts), et:=int(en), float(en-ts), e))
if start_ts is None or st < start_ts: start_ts = st
if end_ts is None or et > end_ts: end_ts = et
if isinstance(e, ProfilePointEvent) and e.name == "marker": markers.append(e)
if start_ts is None: return None
# return layout of per device events
layout:dict[str, bytes|None] = {}
scache:dict[str, int] = {}
peaks:list[int] = []
dtype_size:dict[str, int] = {}
for k,v in dev_events.items():
v.sort(key=lambda e:e[0])
layout[k] = timeline_layout(v, start_ts, scache)
layout[f"{k} Memory"] = mem_layout(v, start_ts, unwrap(end_ts), peaks, dtype_size, scache)
ret = [b"".join([struct.pack("<B", len(k)), k.encode(), v]) for k,v in layout.items() if v is not None]
index = json.dumps({"strings":list(scache), "dtypeSize":dtype_size, "markers":[{"ts":int(e.ts-start_ts), **e.arg} for e in markers]}).encode()
return struct.pack("<IQII", unwrap(end_ts)-start_ts, max(peaks,default=0), len(index), len(ret))+index+b"".join(ret)
# ** Assembly analyzers
def get_llvm_mca(asm:str, mtriple:str, mcpu:str) -> dict:
target_args = [f"-mtriple={mtriple}", f"-mcpu={mcpu}"]
# disassembly output can include headers / metadata, skip if llvm-mca can't parse those lines
data = json.loads(subprocess.check_output(["llvm-mca","-skip-unsupported-instructions=parse-failure","--json","-"]+target_args, input=asm.encode()))
cr = data["CodeRegions"][0]
resource_labels = data["TargetInfo"]["Resources"]
rows:list = [[instr] for instr in cr["Instructions"]]
# add scheduler estimates
for info in cr["InstructionInfoView"]["InstructionList"]: rows[info["Instruction"]].append(info["Latency"])
# map per instruction resource usage
instr_usage:dict[int, dict[int, int]] = {}
for d in cr["ResourcePressureView"]["ResourcePressureInfo"]:
instr_usage.setdefault(i:=d["InstructionIndex"], {}).setdefault(r:=d["ResourceIndex"], 0)
instr_usage[i][r] += d["ResourceUsage"]
# last row is the usage summary
summary = [{"idx":k, "label":resource_labels[k], "value":v} for k,v in instr_usage.pop(len(rows), {}).items()]
max_usage = max([sum(v.values()) for i,v in instr_usage.items() if i<len(rows)], default=0)
for i,usage in instr_usage.items(): rows[i].append([[k, v, (v/max_usage)*100] for k,v in usage.items()])
return {"rows":rows, "cols":["Opcode", "Latency", {"title":"HW Resources", "labels":resource_labels}], "summary":summary}
def get_disassembly(ctx:list[str]):
if not isinstance(prg:=traces[int(ctx[0])][0].ret, ProgramSpec): return
lib = (compiler:=Device[prg.device].compiler).compile(prg.src)
with redirect_stdout(buf:=io.StringIO()): compiler.disassemble(lib)
disasm_str = buf.getvalue()
from tinygrad.runtime.support.compiler_cpu import llvm, LLVMCompiler
if isinstance(compiler, LLVMCompiler):
mtriple = ctypes.string_at(llvm.LLVMGetTargetMachineTriple(tm:=compiler.target_machine)).decode()
mcpu = ctypes.string_at(llvm.LLVMGetTargetMachineCPU(tm)).decode()
ret = get_llvm_mca(disasm_str, mtriple, mcpu)
else: ret = {"src":disasm_str}
return json.dumps(ret).encode()
# ** HTTP server
class Handler(BaseHTTPRequestHandler):
def do_GET(self):
ret, status_code, content_type = b"", 200, "text/html"
if (url:=urlparse(self.path)).path == "/":
with open(os.path.join(os.path.dirname(__file__), "index.html"), "rb") as f: ret = f.read()
elif self.path.startswith(("/assets/", "/js/")) and '/..' not in self.path:
try:
with open(os.path.join(os.path.dirname(__file__), self.path.strip('/')), "rb") as f: ret = f.read()
if url.path.endswith(".js"): content_type = "application/javascript"
if url.path.endswith(".css"): content_type = "text/css"
except FileNotFoundError: status_code = 404
elif (query:=parse_qs(url.query)):
if url.path == "/disasm": ret, content_type = get_disassembly(**query), "application/json"
else: return self.stream_json(get_details(traces[i:=int(query["ctx"][0])][1][int(query["idx"][0])], i))
elif url.path == "/ctxs": ret, content_type = json.dumps(ctxs).encode(), "application/json"
elif url.path == "/get_profile" and profile_ret: ret, content_type = profile_ret, "application/octet-stream"
else: status_code = 404
# send response
self.send_response(status_code)
self.send_header('Content-Type', content_type)
self.send_header('Content-Length', str(len(ret)))
self.end_headers()
return self.wfile.write(ret)
def stream_json(self, source:Generator):
try:
self.send_response(200)
self.send_header("Content-Type", "text/event-stream")
self.send_header("Cache-Control", "no-cache")
self.end_headers()
for r in source:
self.wfile.write(f"data: {json.dumps(r)}\n\n".encode("utf-8"))
self.wfile.flush()
self.wfile.write("data: END\n\n".encode("utf-8"))
# pass if client closed connection
except (BrokenPipeError, ConnectionResetError): return
# ** main loop
def reloader():
mtime = os.stat(__file__).st_mtime
while not stop_reloader.is_set():
if mtime != os.stat(__file__).st_mtime:
print("reloading server...")
os.execv(sys.executable, [sys.executable] + sys.argv)
time.sleep(0.1)
def load_pickle(path:pathlib.Path|None) -> list:
if path is None or not path.exists(): return []
with path.open("rb") as f: return pickle.load(f)
# NOTE: using HTTPServer forces a potentially slow socket.getfqdn
class TCPServerWithReuse(socketserver.TCPServer): allow_reuse_address = True
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--kernels', type=pathlib.Path, help='Path to kernels', default=pathlib.Path(temp("rewrites.pkl", append_user=True)))
parser.add_argument('--profile', type=pathlib.Path, help='Path to profile', default=pathlib.Path(temp("profile.pkl", append_user=True)))
args = parser.parse_args()
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
if s.connect_ex(((HOST:="http://127.0.0.1").replace("http://", ""), PORT:=getenv("PORT", 8000))) == 0:
raise RuntimeError(f"{HOST}:{PORT} is occupied! use PORT= to change.")
stop_reloader = threading.Event()
multiprocessing.current_process().name = "VizProcess" # disallow opening of devices
st = time.perf_counter()
print("*** viz is starting")
ctxs = get_metadata(load_pickle(args.kernels))
profile_ret = get_profile(load_pickle(args.profile))
server = TCPServerWithReuse(('', PORT), Handler)
reloader_thread = threading.Thread(target=reloader)
reloader_thread.start()
print(f"*** started viz on {HOST}:{PORT}")
print(colored(f"*** ready in {(time.perf_counter()-st)*1e3:4.2f}ms", "green"), flush=True)
if len(getenv("BROWSER", "")) > 0: webbrowser.open(f"{HOST}:{PORT}")
try: server.serve_forever()
except KeyboardInterrupt:
print("*** viz is shutting down...")
stop_reloader.set()