Tile 運算與基本操作 (Tile Operations and Primitives)

重點總覽

項目 重點
Element-wise 算術 tile 支援標準逐元素算術;形狀相容但不同時,小的會 broadcast 對齊後再運算
Broadcasting 規則 遵循 NumPy 語義:scalar 複製、singleton(長度1) 拉伸、低 rank 對齊到高 rank 的 trailing 維度(缺的 leading 維度當 singleton)
Broadcasting 例外 兩個對應維度都非 singleton 且不相等 → ill-formed
Arithmetic operators 逐元素套用、產生 broadcast 形狀的新 tile;型別不同時偏好保留較多資訊者
Scalar-tile 型別差異 scalar 需 narrow 時:Python 提升型別、C++ 視為 ill-formed
Tile primitives factory、load/store、算術皆屬語言內建 primitive;compiler 映射到硬體(含 tensor cores
Matrix Multiply matmul (a @ b) 與 mma (a @ b + acc);mma 的 accumulator 跨 K-tile 累加
GEMM 慣用法 無論輸入精度都用 FP32 累加,store 時 cast 成輸出型別
Reductions 把 tile 收斂成 scalar 或一列 scalar;Python 預設丟掉 reduced 軸、C++ 永遠保留
Scans 沿軸的累積(running)版本,例如 prefix-sum/cumsum,輸出維度與輸入相同
Transpose / Permute 重排軸但不動資料;transpose 換前兩軸、permute 任意重排
Element-wise Selection tile 版條件式;Python ct.where(cond,x,y)、C++ ct::select(cond,lhs,rhs)
Mathematical Functions ct namespace 提供逐元素數學函式,回傳同形狀 tile,也能作用於 scalar

Element-wise Arithmetic and Broadcasting

Tile 支援標準的逐元素 (element-wise) 算術。當兩個運算元形狀相容但不同時,較小者會先 broadcast 對齊到較大者,再執行運算。


Broadcasting 規則

Broadcasting 遵循 NumPy 語義,三條規則:

  1. Scalar 複製:scalar 被複製到整個 tile 的每個元素。
  2. Singleton 拉伸:長度為 1 的維度(singleton)被拉伸以匹配另一運算元對應維度的長度。
  3. Rank 對齊:較低 rank 的運算元會對齊到較高 rank 運算元的尾端 (trailing) 維度,缺少的前導 (leading) 維度視為 singleton(即 rank promotion)。
例外(ill-formed)

若兩個對應維度都非 singleton 且彼此不相等,該運算為 ill-formed(無法成立)。例如 (3)(4) 無法 broadcast。

下例同時觸發「singleton 拉伸」與「rank promotion」:rank-2 的 8x2 先被 rank-promote 成 1x8x2,再與 rank-3 的 4x1x2 broadcast 到共同形狀 4x8x2

auto x = ct::iota<ct::tile<int, ct::shape<8, 2>>>();     // 8x2   (rank 2)
auto y = ct::iota<ct::tile<int, ct::shape<4, 1, 2>>>();  // 4x1x2 (rank 3)
auto z = x + y;  // x 先 promote 成 1x8x2,再 broadcast 成 4x8x2
x = ct.full((8, 2), 3, dtype=ct.int32)     # 8x2   (rank 2)
y = ct.full((4, 1, 2), 5, dtype=ct.int32)  # 4x1x2 (rank 3)
z = x + y  # x 先 promote 成 1x8x2,再 broadcast 成 4x8x2

對齊與拉伸的過程:

   x: (8, 2)            -- rank 2
   ↓ rank promotion (補 leading singleton)
   x: (1, 8, 2)
   y: (4, 1, 2)         -- rank 3
  ─────────────────────────────────────
  逐維對齊 (trailing 對齊):
    dim0:  1 vs 4   -> 拉伸 1 → 4
    dim1:  8 vs 1   -> 拉伸 1 → 8
    dim2:  2 vs 2   -> 相等
  ─────────────────────────────────────
  broadcast shape = (4, 8, 2)

Arithmetic Operators

所有支援的算術運算子都逐元素套用於 tile,並產生一個 broadcast 形狀的新 tile。scalar 與 tile 結合時,scalar 被 broadcast 到每個元素。

當運算元型別不同時,偏好保留較多資訊(精度或範圍較大)的型別

語言 scalar 需 narrow 時的行為
Python 把結果提升 (promote) 成能同時容納兩者的型別
C++ 視該表達式為 ill-formed(編譯拒絕)
using i32x8 = ct::tile<int, ct::shape<8>>;
i32x8 x = ct::full<i32x8>(3);
x + 2;    // OK       - int 字面值與 int tile 元素型別相符
x + 2.5;  // ill-formed - 2.5 會 narrow 成 int
x = ct.full((8,), 3, dtype=ct.int32)
x + 2     # int32   - int 字面值與 int32 tile dtype 相符
x + 2.5   # float32 - 結果被 promote 以同時容納兩者
實務建議

盡量用 tile 元素型別寫 scalar 字面值(如 float tile 就寫 2.0f),要不同精度時用顯式轉換。這些規則在 kernel 內對 loaded tile 同樣適用。

kernel 內逐元素運算(2.0f 與 float tile 相符,無 narrowing;scalar 先 broadcast,再逐元素 +):

auto x = aView.load(bx);
auto y = bView.load(bx);
auto z = 2.0f * x + y;   // scalar broadcast → 逐元素乘加
cView.store(z, bx);
x = ct.load(a, index=(bid,), shape=(TILE,))
y = ct.load(b, index=(bid,), shape=(TILE,))
z = 2.0 * x + y          # 2.0 為鬆散型別 float 常數,配 float tile 仍為 float
ct.store(c, index=(bid,), tile=z)
控制 rounding / subnormal

若需要明確控制 rounding modesubnormal handling,CUDA Tile API 提供可接受這些參數的數學函式(如 ct.add / ct::add),而非用運算子。


Tile Primitives 概觀

Factory 函式(建立 tile)、load/store、以及逐元素算術,全都是 tile primitive:屬於語言本身的一部分。


Matrix Multiply

兩個 tile 的矩陣相乘是實作陣列矩陣乘法的基礎運算。CUDA Tile 提供兩種形式:

形式 寫法 說明
matmul(純乘) a @ b 單純矩陣相乘
mma(乘加) a @ b + acc accumulator 把部分積 (partial products) 從一個 K-tile 帶到下一個,適合 tiled matmul 的內層迴圈

慣用 GEMM 模式:無論輸入精度為何,都以 FP32 累加,store 時再 cast 成輸出元素型別。

using f32_acc = ct::tile<float, ct::shape<32, 32>>;
auto acc = ct::full<f32_acc>(0.0f);              // FP32 accumulator
std::size_t num_k = (K + tk - 1) / tk;           // = ceil(K / tk)
for (auto k : ct::irangesize_t{0}, num_k) {
  acc = ct::mma(aView.load_masked(bx, k),        // 部分 K-tile zero-pad
                bView.load_masked(k, by),
                acc);                            // acc += a @ b
}
cView.store_masked(acc, bx, by);                 // 丟棄 OOB 邊緣 lane
acc = ct.full((tm, tn), 0, dtype=ct.float32)     # FP32 accumulator
for k in range(num_k):
    a = ct.load(A, index=(bx, k), shape=(tm, tk),
                padding_mode=ct.PaddingMode.ZERO) # 部分 K-tile zero-pad
    b = ct.load(B, index=(k, by), shape=(tk, tn),
                padding_mode=ct.PaddingMode.ZERO)
    acc = ct.mma(a, b, acc)                       # acc += a @ b
ct.store(C, index=(bx, by), tile=acc.astype(C.dtype))  # cast + store
 K-loop 結構 (M×K · K×N → M×N):
   for k in 0 .. ceil(K/tk)-1:
        ┌────────┐   ┌────────┐
        │ A[bx,k]│ @ │ B[k,by]│  -> 部分積累加進 acc (FP32)
        └────────┘   └────────┘
   迴圈結束: store_masked / store(astype) 寫回 C,cast 成輸出型別

Reductions and Scans

Reduction 把一個 tile 收斂成 scalar 或一列 scalar。典型用途:softmax 的分母、layer norm 的 mean/variance、attention scoring 的 max。

最該記住的一點是結果的形狀

語言 reduced 軸的處理
Python 預設丟掉 (drop) 該軸;傳 keepdims=True 才保留為長度 1
C++ 永遠保留該軸,維持 tile 的 rank

下例對 2x4 tile 沿 axis 1 做 sum:

using i32x2x4 = ct::tile<int, ct::shape<2, 4>>;
auto x = ct::iota<i32x2x4>();        // [[0,1,2,3],[4,5,6,7]]
auto row_sums = ct::sum(x, 1_ic);    // shape (2,1) 軸保留; [[6],[22]]
x   = ct.arange(8, dtype=ct.int32).reshape((2, 4))  # [[0,1,2,3],[4,5,6,7]]
s   = ct.sum(x, axis=1)                  # shape (2,)  軸丟掉; [6, 22]
s_k = ct.sum(x, axis=1, keepdims=True)   # shape (2,1) 軸保留; [[6],[22]]
 input 2x4:  [[0,1,2,3],
              [4,5,6,7]]   --- sum 沿 axis=1 --->

   C++           : shape (2,1)  [[6],[22]]   (永遠保留)
   Python 預設   : shape (2,)   [6, 22]      (丟掉軸)
   Python keepdims: shape (2,1) [[6],[22]]   (保留為長度 1)

Scan 是 reduction 的「running(累積)」版本,沿某軸產生累積結果。例如 prefix-sum (cumsum) 輸出維度與輸入相同,每個位置的值是沿指定軸到該位置(含)為止所有元素之和。


Transpose and Permutation

兩個相關 primitive 重排 tile 的軸但不觸碰資料

常用於 tile 邏輯佈局需改變之處:把 matmul 運算元 materialize 成轉置、attention block 中交換 row/column、或在 broadcast 前對齊軸。

語言 transpose permute
Python ct.transpose(x):rank-2 換兩軸;更高 rank 需給 axis0/axis1 ct.permute(x, axes):傳軸索引 tuple
C++ ct::transpose(x):交換前兩維(trailing 維度保留) ct::permute(x, map):傳 ct::dimension_map 描述新順序
using t2d = ct::tile<int, ct::shape<2, 4>>;
auto tx = ct::iota<t2d>();
auto ty = ct::transpose(tx);                              // shape (4, 2)
auto tz = ct::iota<ct::tile<int, ct::shape<2, 2, 2>>>();
auto tw = ct::permute(tz, ct::dimension_map{2_ic, 0_ic, 1_ic}); // (0,1,2)->(2,0,1)
tx = ct.arange(8, dtype=ct.int32).reshape((2, 4))
ty = ct.transpose(tx)              # shape (4, 2)
tz = ct.arange(8, dtype=ct.int32).reshape((2, 2, 2))
tw = ct.permute(tz, (2, 0, 1))     # 軸 (0,1,2) -> (2,0,1)

Element-wise Selection

逐元素 selection 是 tile 版的條件式:給一個 boolean tile 與兩個運算元 tile,每個輸出元素依對應的 boolean 從其一挑選

auto cond = ct::iota<ct::tile<int, ct::shape<4>>>() < 2;  // {T, T, F, F}
auto t = ct::full<ct::tile<float, ct::shape<4>>>( 1.0f);
auto f = ct::full<ct::tile<float, ct::shape<4>>>(-1.0f);
auto r = ct::select(cond, t, f);                          // {1, 1, -1, -1}
cond    = ct.arange(4, dtype=ct.int32) < 2       # [T, T, F, F]
x_true  = ct.full((4,), 1.0,  dtype=ct.float32)
x_false = ct.full((4,), -1.0, dtype=ct.float32)
result  = ct.where(cond, x_true, x_false)        # [1, 1, -1, -1]
cond 為 true 取第一個運算元

兩語言皆「cond 為 true → 取 x/lhs;false → 取 y/rhs」,與 NumPy where 一致。


Mathematical Functions

ct namespace 提供常見的逐元素數學運算函式。每個函式對輸入 tile 逐元素套用、回傳同形狀 tile,且也能作用於 tile code 中的 scalar

類別 函式
基本算術 add, sub, mul
除法 truediv, floordiv, cdiv, mod
冪/指對數 pow, exp, exp2, log, log2
開根 sqrt, rsqrt
三角 sin, cos, tan
雙曲 sinh, cosh, tanh
極值/取負 minimum, maximum, negative
取整 floor, ceil
函式 vs 運算子

用函式版(如 ct::add)而非運算子,可在需要時傳入 rounding mode / subnormal 等參數;完整清單見 cuTile Python 與 CUDA Tile C++ 的 Math Operations API reference。


考試/測驗重點

情境/關鍵字 答案
broadcasting 遵循什麼語義 NumPy 語義:scalar 複製、singleton 拉伸、低 rank 對齊 trailing 維度
兩維都非 singleton 且不相等 ill-formed(無法 broadcast)
8x24x1x2 相加結果形狀 8x2 先 promote 成 1x8x2,再 broadcast 成 4x8x2
缺少的維度補在哪 補在 leading(前導),視為 singleton(rank promotion)
int + float 結果型別 float(保留較多資訊)
int16 + int32 結果型別 int32
int tile + 2.5 在 C++ ill-formed(2.5 會 narrow 成 int)
int32 tile + 2.5 在 Python float32(結果被 promote)
matmul vs mma matmul = a @ bmma = a @ b + acc(accumulator 跨 K-tile 累加)
GEMM 慣用累加精度 一律 FP32 累加,store 時 cast 成輸出型別
K-loop 迭代次數 ceil(K / tk),等價 (K + tk - 1) / tk
部分 K-tile 怎麼處理 load 時 zero-pad(Python PaddingMode.ZERO、C++ .load_masked()
部分 M/N 邊緣 tile 怎麼處理 store 端 OOB-discard(Python ct.store、C++ .store_masked()
reduction 軸:C++ vs Python C++ 永遠保留軸Python 預設丟掉keepdims=True 才保留
2x4 沿 axis 1 sum 的 Python 預設形狀 (2,)(軸丟掉);C++ 為 (2,1)
scan / cumsum 的輸出維度 與輸入相同,每位置 = 沿軸到該位置(含)之累積
transpose 換哪些軸 前兩個軸(C++ trailing 維度保留)
permute 怎麼指定順序 Python 傳軸索引 tuple;C++ 傳 ct::dimension_map
selection:Python vs C++ 寫法 ct.where(cond, x, y) vs ct::select(cond, lhs, rhs)
selection 的 cond 形狀 broadcast 到運算元形狀,運算元型別需相容
tensor cores 何時用到 compiler 把 tile primitive(尤其 matmul/mma)映射時,於可用硬體自動採用