Slide 13
Slide 13 text
13
Transformer Pruning について
教師あり・教師なしの手法を実装
from textpruner import TransformerPruner, TransformerPruningConfig
transformer_pruning_config = TransformerPruningConfig(
target_ffn_size=1536,
target_num_of_heads=6,
pruning_method="iterative", # or "masks"
n_iters=2,
use_logits=False, # True for unsupervised mode
)
pruner = TransformerPruner(
model,
transformer_pruning_config=transformer_pruning_config,
)
# torch_ds = hf_ds.set_format(type='torch')
# dataloader = torch.utils.data.DataLoader(torch_ds, batch_size=BATCH_SIZE)
pruner.prune(dataloader=dataloader, save_model=True)
# dataloader: PyTorch dataloader
● attention head の数,FFN のサイズの枝刈りが可能
以下を準備
Ø transformers 形式の
モデル
Ø Pytorch の dataloader
引数について
Ø pruning_method
iterative: IS に基づく
枝刈りを反復
masks: 指定した箇所
を枝刈りする
Ø use_logits
True: 教師なし枝刈り
False: 教師あり枝刈り