Files
maths-cs-ai-compendium-zh/images/mha_gqa_mqa.svg
T

113 lines
9.4 KiB
XML
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 800 260" width="800" height="260" role="img" aria-labelledby="mha-gqa-mqa-title" aria-describedby="mha-gqa-mqa-desc" font-family="Arial, sans-serif" text-anchor="middle">
  <title id="mha-gqa-mqa-title">MHA vs GQA vs MQA: KV Head Sharing</title>
  <desc id="mha-gqa-mqa-desc">Three side-by-side panels, each showing a row of 8 query (Q) heads above the key/value (KV) sets they connect to. Multi-Head (MHA): 8 Q heads, 8 KV sets, one per head; KV cache 8× per layer. Grouped-Query (GQA): 8 Q heads, 2 KV sets, four heads sharing each set; KV cache 2× per layer (4× smaller). Multi-Query (MQA): 8 Q heads all sharing 1 KV set; KV cache 1× per layer (8× smaller). Summary: for a 70B model (80 layers, 64 heads) at sequence length 32K in FP16, MHA: 40 GB, GQA (8 groups): 5 GB, MQA: 0.6 GB.</desc>
  <!--
    Layout: 800x260 canvas, three columns separated by dashed dividers at
    x=267 and x=533, plus a summary box along the bottom.
    In every column the Q-head boxes sit at y=78 (bottom edge y=96), the KV
    boxes at y=118, and each connector runs from the bottom-centre of a
    Q box to the top edge of the KV box it uses.

    Styling: font-family and text-anchor are set once on the root element;
    colours, strokes and font sizes shared by sibling shapes are set on the
    enclosing <g> and inherited. rx is geometry, not an inherited property,
    so it is repeated on every rect.

    Accessibility: <title> and <desc> are the text alternative for the whole
    graphic. Keep them in sync with the visible labels when editing.
  -->

  <text x="400" y="22" font-size="14" font-weight="bold" fill="#333">MHA vs GQA vs MQA: KV Head Sharing</text>

  <!-- Dashed column dividers -->
  <g stroke="#ccc" stroke-width="1" stroke-dasharray="4,3">
    <line x1="267" y1="35" x2="267" y2="245"/>
    <line x1="533" y1="35" x2="533" y2="245"/>
  </g>

  <!-- MHA: one KV set per query head -->
  <g fill="#e74c3c">
    <text x="134" y="52" font-size="12" font-weight="bold">Multi-Head (MHA)</text>
    <text x="134" y="66" font-size="9">8 Q heads, 8 KV sets</text>
  </g>
  <!-- Row labels; only drawn once, in the left margin of the first column -->
  <g font-size="8" fill="#666">
    <text x="35" y="88">Q heads:</text>
    <text x="35" y="128">KV sets:</text>
  </g>
  <!-- Q heads -->
  <g fill="#3498db" fill-opacity="0.3" stroke="#3498db" stroke-width="1">
    <rect x="60" y="78" width="22" height="18" rx="3"/>
    <rect x="84" y="78" width="22" height="18" rx="3"/>
    <rect x="108" y="78" width="22" height="18" rx="3"/>
    <rect x="132" y="78" width="22" height="18" rx="3"/>
    <rect x="156" y="78" width="22" height="18" rx="3"/>
    <rect x="180" y="78" width="22" height="18" rx="3"/>
    <rect x="204" y="78" width="22" height="18" rx="3"/>
    <rect x="228" y="78" width="22" height="18" rx="3"/>
  </g>
  <!-- KV sets (one per head) -->
  <g fill="#e74c3c" fill-opacity="0.3" stroke="#e74c3c" stroke-width="1">
    <rect x="60" y="118" width="22" height="18" rx="3"/>
    <rect x="84" y="118" width="22" height="18" rx="3"/>
    <rect x="108" y="118" width="22" height="18" rx="3"/>
    <rect x="132" y="118" width="22" height="18" rx="3"/>
    <rect x="156" y="118" width="22" height="18" rx="3"/>
    <rect x="180" y="118" width="22" height="18" rx="3"/>
    <rect x="204" y="118" width="22" height="18" rx="3"/>
    <rect x="228" y="118" width="22" height="18" rx="3"/>
  </g>
  <!-- Vertical connectors: each Q head to the KV set directly below it -->
  <g stroke="#666" stroke-width="0.8">
    <line x1="71" y1="96" x2="71" y2="118"/>
    <line x1="95" y1="96" x2="95" y2="118"/>
    <line x1="119" y1="96" x2="119" y2="118"/>
    <line x1="143" y1="96" x2="143" y2="118"/>
    <line x1="167" y1="96" x2="167" y2="118"/>
    <line x1="191" y1="96" x2="191" y2="118"/>
    <line x1="215" y1="96" x2="215" y2="118"/>
    <line x1="239" y1="96" x2="239" y2="118"/>
  </g>
  <text x="134" y="155" font-size="9" fill="#e74c3c">KV cache: 8× per layer</text>

  <!-- GQA: query heads split into two groups of four, one KV set per group -->
  <g fill="#f39c12">
    <text x="400" y="52" font-size="12" font-weight="bold">Grouped-Query (GQA)</text>
    <text x="400" y="66" font-size="9">8 Q heads, 2 KV sets</text>
  </g>
  <!-- Q heads; the wider gap between x=368 and x=404 separates the two groups -->
  <g fill="#3498db" fill-opacity="0.3" stroke="#3498db" stroke-width="1">
    <rect x="296" y="78" width="22" height="18" rx="3"/>
    <rect x="320" y="78" width="22" height="18" rx="3"/>
    <rect x="344" y="78" width="22" height="18" rx="3"/>
    <rect x="368" y="78" width="22" height="18" rx="3"/>
    <rect x="404" y="78" width="22" height="18" rx="3"/>
    <rect x="428" y="78" width="22" height="18" rx="3"/>
    <rect x="452" y="78" width="22" height="18" rx="3"/>
    <rect x="476" y="78" width="22" height="18" rx="3"/>
  </g>
  <!-- KV sets (2 shared), each centred under its group of four Q heads -->
  <g fill="#e74c3c">
    <g fill-opacity="0.3" stroke="#e74c3c" stroke-width="1.5">
      <rect x="320" y="118" width="46" height="18" rx="3"/>
      <rect x="428" y="118" width="46" height="18" rx="3"/>
    </g>
    <g font-size="7">
      <text x="343" y="131">KV₁</text>
      <text x="451" y="131">KV₂</text>
    </g>
  </g>
  <!-- Connectors: 4 heads → KV1, 4 heads → KV2 -->
  <g stroke="#666" stroke-width="0.8">
    <line x1="307" y1="96" x2="335" y2="118"/>
    <line x1="331" y1="96" x2="340" y2="118"/>
    <line x1="355" y1="96" x2="346" y2="118"/>
    <line x1="379" y1="96" x2="351" y2="118"/>
    <line x1="415" y1="96" x2="443" y2="118"/>
    <line x1="439" y1="96" x2="448" y2="118"/>
    <line x1="463" y1="96" x2="454" y2="118"/>
    <line x1="487" y1="96" x2="459" y2="118"/>
  </g>
  <text x="400" y="155" font-size="9" fill="#f39c12">KV cache: 2× per layer (4× smaller)</text>

  <!-- MQA: all query heads share a single KV set -->
  <g fill="#27ae60">
    <text x="666" y="52" font-size="12" font-weight="bold">Multi-Query (MQA)</text>
    <text x="666" y="66" font-size="9">8 Q heads, 1 KV set</text>
  </g>
  <!-- Q heads -->
  <g fill="#3498db" fill-opacity="0.3" stroke="#3498db" stroke-width="1">
    <rect x="570" y="78" width="22" height="18" rx="3"/>
    <rect x="594" y="78" width="22" height="18" rx="3"/>
    <rect x="618" y="78" width="22" height="18" rx="3"/>
    <rect x="642" y="78" width="22" height="18" rx="3"/>
    <rect x="666" y="78" width="22" height="18" rx="3"/>
    <rect x="690" y="78" width="22" height="18" rx="3"/>
    <rect x="714" y="78" width="22" height="18" rx="3"/>
    <rect x="738" y="78" width="22" height="18" rx="3"/>
  </g>
  <!-- Single KV set, drawn with a heavier outline than the GQA ones -->
  <rect x="642" y="118" width="46" height="18" rx="3" fill="#e74c3c" fill-opacity="0.3" stroke="#e74c3c" stroke-width="2"/>
  <text x="665" y="131" font-size="7" font-weight="bold" fill="#e74c3c">KV</text>
  <!-- All eight connectors converge on the one KV box -->
  <g stroke="#666" stroke-width="0.8">
    <line x1="581" y1="96" x2="655" y2="118"/>
    <line x1="605" y1="96" x2="658" y2="118"/>
    <line x1="629" y1="96" x2="661" y2="118"/>
    <line x1="653" y1="96" x2="664" y2="118"/>
    <line x1="677" y1="96" x2="667" y2="118"/>
    <line x1="701" y1="96" x2="670" y2="118"/>
    <line x1="725" y1="96" x2="673" y2="118"/>
    <line x1="749" y1="96" x2="676" y2="118"/>
  </g>
  <text x="666" y="155" font-size="9" fill="#27ae60">KV cache: 1× per layer (8× smaller)</text>

  <!-- Bottom summary -->
  <!--
    NOTE(review): the setup line does not state the head dimension. Assuming
    head_dim = 128 (TODO confirm), counting both K and V gives
    2 × 80 layers × 64 heads × 128 × 32768 tokens × 2 bytes = 80 GiB for MHA,
    about twice the 40 GB shown. The three figures are consistent with each
    other (÷8 and ÷64), so they may count only one of K or V. Verify against
    the accompanying text before changing them, and update <desc> to match.
  -->
  <rect x="50" y="175" width="700" height="60" rx="8" fill="#f5f5f5" stroke="#ddd" stroke-width="1"/>
  <g font-size="10">
    <text x="400" y="195" fill="#333">For a 70B model (80 layers, 64 heads) at sequence length 32K in FP16:</text>
    <g font-weight="bold">
      <text x="200" y="215" fill="#e74c3c">MHA: 40 GB</text>
      <text x="400" y="215" fill="#f39c12">GQA (8 groups): 5 GB</text>
      <text x="600" y="215" fill="#27ae60">MQA: 0.6 GB</text>
    </g>
  </g>
</svg>