Tile 運算與基本操作 (Tile Operations and Primitives)
重點總覽
| 項目 | 重點 |
|---|---|
| Element-wise 算術 | tile 支援標準逐元素算術;形狀相容但不同時,小的會 broadcast 對齊後再運算 |
| Broadcasting 規則 | 遵循 NumPy 語義:scalar 複製、singleton(長度1) 拉伸、低 rank 對齊到高 rank 的 trailing 維度(缺的 leading 維度當 singleton) |
| Broadcasting 例外 | 兩個對應維度都非 singleton 且不相等 → ill-formed |
| Arithmetic operators | 逐元素套用、產生 broadcast 形狀的新 tile;型別不同時偏好保留較多資訊者 |
| Scalar-tile 型別差異 | scalar 需 narrow 時:Python 提升型別、C++ 視為 ill-formed |
| Tile primitives | factory、load/store、算術皆屬語言內建 primitive;compiler 映射到硬體(含 tensor cores) |
| Matrix Multiply | matmul (a @ b) 與 mma (a @ b + acc);mma 的 accumulator 跨 K-tile 累加 |
| GEMM 慣用法 | 無論輸入精度都用 FP32 累加,store 時 cast 成輸出型別 |
| Reductions | 把 tile 收斂成 scalar 或一列 scalar;Python 預設丟掉 reduced 軸、C++ 永遠保留 |
| Scans | 沿軸的累積(running)版本,例如 prefix-sum/cumsum,輸出維度與輸入相同 |
| Transpose / Permute | 重排軸但不動資料;transpose 換前兩軸、permute 任意重排 |
| Element-wise Selection | tile 版條件式;Python ct.where(cond,x,y)、C++ ct::select(cond,lhs,rhs) |
| Mathematical Functions | ct namespace 提供逐元素數學函式,回傳同形狀 tile,也能作用於 scalar |
Element-wise Arithmetic and Broadcasting
Tile 支援標準的逐元素 (element-wise) 算術。當兩個運算元形狀相容但不同時,較小者會先 broadcast 對齊到較大者,再執行運算。
- broadcast 是「邏輯上拉伸」,不會真正複製整塊資料。
- 對齊後的形狀稱為 broadcast shape,是結果 tile 的形狀。
Broadcasting 規則
Broadcasting 遵循 NumPy 語義,三條規則:
- Scalar 複製:scalar 被複製到整個 tile 的每個元素。
- Singleton 拉伸:長度為 1 的維度(singleton)被拉伸以匹配另一運算元對應維度的長度。
- Rank 對齊:較低 rank 的運算元會對齊到較高 rank 運算元的尾端 (trailing) 維度,缺少的前導 (leading) 維度視為 singleton(即 rank promotion)。
若兩個對應維度都非 singleton 且彼此不相等,該運算為 ill-formed(無法成立)。例如 (3) 與 (4) 無法 broadcast。
下例同時觸發「singleton 拉伸」與「rank promotion」:rank-2 的 8x2 先被 rank-promote 成 1x8x2,再與 rank-3 的 4x1x2 broadcast 到共同形狀 4x8x2。
auto x = ct::iota<ct::tile<int, ct::shape<8, 2>>>(); // 8x2 (rank 2)
auto y = ct::iota<ct::tile<int, ct::shape<4, 1, 2>>>(); // 4x1x2 (rank 3)
auto z = x + y; // x 先 promote 成 1x8x2,再 broadcast 成 4x8x2
x = ct.full((8, 2), 3, dtype=ct.int32) # 8x2 (rank 2)
y = ct.full((4, 1, 2), 5, dtype=ct.int32) # 4x1x2 (rank 3)
z = x + y # x 先 promote 成 1x8x2,再 broadcast 成 4x8x2
對齊與拉伸的過程:
x: (8, 2) -- rank 2
↓ rank promotion (補 leading singleton)
x: (1, 8, 2)
y: (4, 1, 2) -- rank 3
─────────────────────────────────────
逐維對齊 (trailing 對齊):
dim0: 1 vs 4 -> 拉伸 1 → 4
dim1: 8 vs 1 -> 拉伸 1 → 8
dim2: 2 vs 2 -> 相等
─────────────────────────────────────
broadcast shape = (4, 8, 2)
Arithmetic Operators
所有支援的算術運算子都逐元素套用於 tile,並產生一個 broadcast 形狀的新 tile。scalar 與 tile 結合時,scalar 被 broadcast 到每個元素。
當運算元型別不同時,偏好保留較多資訊(精度或範圍較大)的型別:
- Tile 結合 tile:結果為精度/範圍較大者:
int + float→floatint16 + int32→int32
- Scalar 結合 tile:
- 當 scalar 型別可精確表示於 tile 的元素型別(如整數字面值
2配 int tile、2.0f配 float tile)→ 直接以 tile 的元素型別運算。 - 當 scalar 需收窄 (narrow) 才能塞進 tile 的元素型別(如字面值
2.5配 int tile)→ 兩語言行為不同:
- 當 scalar 型別可精確表示於 tile 的元素型別(如整數字面值
| 語言 | scalar 需 narrow 時的行為 |
|---|---|
| Python | 把結果提升 (promote) 成能同時容納兩者的型別 |
| C++ | 視該表達式為 ill-formed(編譯拒絕) |
using i32x8 = ct::tile<int, ct::shape<8>>;
i32x8 x = ct::full<i32x8>(3);
x + 2; // OK - int 字面值與 int tile 元素型別相符
x + 2.5; // ill-formed - 2.5 會 narrow 成 int
x = ct.full((8,), 3, dtype=ct.int32)
x + 2 # int32 - int 字面值與 int32 tile dtype 相符
x + 2.5 # float32 - 結果被 promote 以同時容納兩者
盡量用 tile 元素型別寫 scalar 字面值(如 float tile 就寫 2.0f),要不同精度時用顯式轉換。這些規則在 kernel 內對 loaded tile 同樣適用。
kernel 內逐元素運算(2.0f 與 float tile 相符,無 narrowing;scalar 先 broadcast,再逐元素 +):
auto x = aView.load(bx);
auto y = bView.load(bx);
auto z = 2.0f * x + y; // scalar broadcast → 逐元素乘加
cView.store(z, bx);
x = ct.load(a, index=(bid,), shape=(TILE,))
y = ct.load(b, index=(bid,), shape=(TILE,))
z = 2.0 * x + y # 2.0 為鬆散型別 float 常數,配 float tile 仍為 float
ct.store(c, index=(bid,), tile=z)
若需要明確控制 rounding mode 或 subnormal handling,CUDA Tile API 提供可接受這些參數的數學函式(如 ct.add / ct::add),而非用運算子。
Tile Primitives 概觀
Factory 函式(建立 tile)、load/store、以及逐元素算術,全都是 tile primitive:屬於語言本身的一部分。
- 程式設計師以 tile 粒度 (granularity) 撰寫這些運算。
- compiler 負責把它們映射到硬體,在可用時包含 tensor cores。
- 以下小節介紹其他可用的 primitive。
Matrix Multiply
兩個 tile 的矩陣相乘是實作陣列矩陣乘法的基礎運算。CUDA Tile 提供兩種形式:
| 形式 | 寫法 | 說明 |
|---|---|---|
| matmul(純乘) | a @ b |
單純矩陣相乘 |
| mma(乘加) | a @ b + acc |
accumulator 把部分積 (partial products) 從一個 K-tile 帶到下一個,適合 tiled matmul 的內層迴圈 |
- 兩者都支援 2D 矩陣乘法與 3D batched 乘法,並可混合運算元與 accumulator 的資料型別(精度)。
- rank 與元素型別的限制見 API reference。
慣用 GEMM 模式:無論輸入精度為何,都以 FP32 累加,store 時再 cast 成輸出元素型別。
- Python:
ct.mma(a, b, acc),acc為 FP32 型別。 - C++:
ct::mma(a, b, acc),accumulator 為顯式 FP32 型別。 - K-loop 迭代
ceil(K / tk)次,以涵蓋 A 的右緣與 B 的下緣。 - 部分 K-tile 在 load 時 zero-pad(Python
PaddingMode.ZERO、C++.load_masked())。 - C 側部分 M/N 邊緣 tile 用 store 端 OOB-discard 處理(Python
ct.store、C++.store_masked())。
using f32_acc = ct::tile<float, ct::shape<32, 32>>;
auto acc = ct::full<f32_acc>(0.0f); // FP32 accumulator
std::size_t num_k = (K + tk - 1) / tk; // = ceil(K / tk)
for (auto k : ct::irangesize_t{0}, num_k) {
acc = ct::mma(aView.load_masked(bx, k), // 部分 K-tile zero-pad
bView.load_masked(k, by),
acc); // acc += a @ b
}
cView.store_masked(acc, bx, by); // 丟棄 OOB 邊緣 lane
acc = ct.full((tm, tn), 0, dtype=ct.float32) # FP32 accumulator
for k in range(num_k):
a = ct.load(A, index=(bx, k), shape=(tm, tk),
padding_mode=ct.PaddingMode.ZERO) # 部分 K-tile zero-pad
b = ct.load(B, index=(k, by), shape=(tk, tn),
padding_mode=ct.PaddingMode.ZERO)
acc = ct.mma(a, b, acc) # acc += a @ b
ct.store(C, index=(bx, by), tile=acc.astype(C.dtype)) # cast + store
K-loop 結構 (M×K · K×N → M×N):
for k in 0 .. ceil(K/tk)-1:
┌────────┐ ┌────────┐
│ A[bx,k]│ @ │ B[k,by]│ -> 部分積累加進 acc (FP32)
└────────┘ └────────┘
迴圈結束: store_masked / store(astype) 寫回 C,cast 成輸出型別
Reductions and Scans
Reduction 把一個 tile 收斂成 scalar 或一列 scalar。典型用途:softmax 的分母、layer norm 的 mean/variance、attention scoring 的 max。
最該記住的一點是結果的形狀:
| 語言 | reduced 軸的處理 |
|---|---|
| Python | 預設丟掉 (drop) 該軸;傳 keepdims=True 才保留為長度 1 |
| C++ | 永遠保留該軸,維持 tile 的 rank |
下例對 2x4 tile 沿 axis 1 做 sum:
using i32x2x4 = ct::tile<int, ct::shape<2, 4>>;
auto x = ct::iota<i32x2x4>(); // [[0,1,2,3],[4,5,6,7]]
auto row_sums = ct::sum(x, 1_ic); // shape (2,1) 軸保留; [[6],[22]]
x = ct.arange(8, dtype=ct.int32).reshape((2, 4)) # [[0,1,2,3],[4,5,6,7]]
s = ct.sum(x, axis=1) # shape (2,) 軸丟掉; [6, 22]
s_k = ct.sum(x, axis=1, keepdims=True) # shape (2,1) 軸保留; [[6],[22]]
input 2x4: [[0,1,2,3],
[4,5,6,7]] --- sum 沿 axis=1 --->
C++ : shape (2,1) [[6],[22]] (永遠保留)
Python 預設 : shape (2,) [6, 22] (丟掉軸)
Python keepdims: shape (2,1) [[6],[22]] (保留為長度 1)
Scan 是 reduction 的「running(累積)」版本,沿某軸產生累積結果。例如 prefix-sum (cumsum) 輸出維度與輸入相同,每個位置的值是沿指定軸到該位置(含)為止所有元素之和。
Transpose and Permutation
兩個相關 primitive 重排 tile 的軸但不觸碰資料:
- transpose:交換前兩個軸。
- permute:做任意順序的重排。
常用於 tile 邏輯佈局需改變之處:把 matmul 運算元 materialize 成轉置、attention block 中交換 row/column、或在 broadcast 前對齊軸。
| 語言 | transpose | permute |
|---|---|---|
| Python | ct.transpose(x):rank-2 換兩軸;更高 rank 需給 axis0/axis1 |
ct.permute(x, axes):傳軸索引 tuple |
| C++ | ct::transpose(x):交換前兩維(trailing 維度保留) |
ct::permute(x, map):傳 ct::dimension_map 描述新順序 |
using t2d = ct::tile<int, ct::shape<2, 4>>;
auto tx = ct::iota<t2d>();
auto ty = ct::transpose(tx); // shape (4, 2)
auto tz = ct::iota<ct::tile<int, ct::shape<2, 2, 2>>>();
auto tw = ct::permute(tz, ct::dimension_map{2_ic, 0_ic, 1_ic}); // (0,1,2)->(2,0,1)
tx = ct.arange(8, dtype=ct.int32).reshape((2, 4))
ty = ct.transpose(tx) # shape (4, 2)
tz = ct.arange(8, dtype=ct.int32).reshape((2, 2, 2))
tw = ct.permute(tz, (2, 0, 1)) # 軸 (0,1,2) -> (2,0,1)
Element-wise Selection
逐元素 selection 是 tile 版的條件式:給一個 boolean tile 與兩個運算元 tile,每個輸出元素依對應的 boolean 從其一挑選。
- 條件會被 broadcast 到運算元形狀;運算元型別必須相容。
- Python 寫成
ct.where(cond, x, y);C++ 寫成ct::select(cond, lhs, rhs)。
auto cond = ct::iota<ct::tile<int, ct::shape<4>>>() < 2; // {T, T, F, F}
auto t = ct::full<ct::tile<float, ct::shape<4>>>( 1.0f);
auto f = ct::full<ct::tile<float, ct::shape<4>>>(-1.0f);
auto r = ct::select(cond, t, f); // {1, 1, -1, -1}
cond = ct.arange(4, dtype=ct.int32) < 2 # [T, T, F, F]
x_true = ct.full((4,), 1.0, dtype=ct.float32)
x_false = ct.full((4,), -1.0, dtype=ct.float32)
result = ct.where(cond, x_true, x_false) # [1, 1, -1, -1]
兩語言皆「cond 為 true → 取 x/lhs;false → 取 y/rhs」,與 NumPy where 一致。
Mathematical Functions
ct namespace 提供常見的逐元素數學運算函式。每個函式對輸入 tile 逐元素套用、回傳同形狀 tile,且也能作用於 tile code 中的 scalar。
| 類別 | 函式 |
|---|---|
| 基本算術 | add, sub, mul |
| 除法 | truediv, floordiv, cdiv, mod |
| 冪/指對數 | pow, exp, exp2, log, log2 |
| 開根 | sqrt, rsqrt |
| 三角 | sin, cos, tan |
| 雙曲 | sinh, cosh, tanh |
| 極值/取負 | minimum, maximum, negative |
| 取整 | floor, ceil |
用函式版(如 ct::add)而非運算子,可在需要時傳入 rounding mode / subnormal 等參數;完整清單見 cuTile Python 與 CUDA Tile C++ 的 Math Operations API reference。
考試/測驗重點
| 情境/關鍵字 | 答案 |
|---|---|
| broadcasting 遵循什麼語義 | NumPy 語義:scalar 複製、singleton 拉伸、低 rank 對齊 trailing 維度 |
| 兩維都非 singleton 且不相等 | ill-formed(無法 broadcast) |
8x2 與 4x1x2 相加結果形狀 |
8x2 先 promote 成 1x8x2,再 broadcast 成 4x8x2 |
| 缺少的維度補在哪 | 補在 leading(前導),視為 singleton(rank promotion) |
int + float 結果型別 |
float(保留較多資訊) |
int16 + int32 結果型別 |
int32 |
int tile + 2.5 在 C++ |
ill-formed(2.5 會 narrow 成 int) |
int32 tile + 2.5 在 Python |
float32(結果被 promote) |
| matmul vs mma | matmul = a @ b;mma = a @ b + acc(accumulator 跨 K-tile 累加) |
| GEMM 慣用累加精度 | 一律 FP32 累加,store 時 cast 成輸出型別 |
| K-loop 迭代次數 | ceil(K / tk),等價 (K + tk - 1) / tk |
| 部分 K-tile 怎麼處理 | load 時 zero-pad(Python PaddingMode.ZERO、C++ .load_masked()) |
| 部分 M/N 邊緣 tile 怎麼處理 | store 端 OOB-discard(Python ct.store、C++ .store_masked()) |
| reduction 軸:C++ vs Python | C++ 永遠保留軸;Python 預設丟掉,keepdims=True 才保留 |
2x4 沿 axis 1 sum 的 Python 預設形狀 |
(2,)(軸丟掉);C++ 為 (2,1) |
| scan / cumsum 的輸出維度 | 與輸入相同,每位置 = 沿軸到該位置(含)之累積 |
| transpose 換哪些軸 | 前兩個軸(C++ trailing 維度保留) |
| permute 怎麼指定順序 | Python 傳軸索引 tuple;C++ 傳 ct::dimension_map |
| selection:Python vs C++ 寫法 | ct.where(cond, x, y) vs ct::select(cond, lhs, rhs) |
| selection 的 cond 形狀 | 會 broadcast 到運算元形狀,運算元型別需相容 |
| tensor cores 何時用到 | compiler 把 tile primitive(尤其 matmul/mma)映射時,於可用硬體自動採用 |