Slide 40
class SimpleClassifier(Model):
    def forward(self,
                text: TextFieldTensors,
                label: torch.Tensor) -> Dict[str, torch.Tensor]:
        # Shape: (batch_size, num_tokens, embedding_dim)
        embedded_text = self.embedder(text)
        # Shape: (batch_size, num_tokens)
        mask = util.get_text_field_mask(text)
        # Shape: (batch_size, encoding_dim)
        encoded_text = self.encoder(embedded_text, mask)
        # Shape: (batch_size, num_labels)
        logits = self.classifier(encoded_text)
        # Shape: (batch_size, num_labels); normalized over the label dimension
        probs = torch.nn.functional.softmax(logits, dim=-1)
        # Shape: () (a scalar: the mean loss over the batch)
        loss = torch.nn.functional.cross_entropy(logits, label)
        # Update the running accuracy with this batch
        self.accuracy(logits, label)
        return {'loss': loss, 'probs': probs}

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        # Needed so that 'accuracy' appears in the evaluation results
        return {'accuracy': self.accuracy.get_metric(reset)}
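The last lines of forward() produce the probabilities and a scalar loss. As a dependency-free sketch of what those two calls compute, the hypothetical helpers below reimplement softmax over the label dimension and cross-entropy with mean reduction in plain Python for a toy batch of two examples and two labels. They are for illustration only and are not PyTorch or AllenNLP functions.

```python
import math
from typing import List


def softmax(row: List[float]) -> List[float]:
    # Subtract the row maximum for numerical stability; the result sums to 1
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits: List[List[float]], labels: List[int]) -> float:
    # Mean over the batch of -log p(gold label), i.e. a "mean" reduction
    losses = [-math.log(softmax(row)[y]) for row, y in zip(logits, labels)]
    return sum(losses) / len(losses)


logits = [[2.0, 0.5], [0.1, 1.2]]  # toy batch: (batch_size=2, num_labels=2)
labels = [0, 1]
probs = [softmax(row) for row in logits]  # one probability row per example
loss = cross_entropy(logits, labels)      # a single float, like the scalar loss
```

Subtracting the row maximum before exponentiating leaves the probabilities unchanged but avoids overflow when logits are large.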
AllenNLP Training & Prediction: Evaluating the Model (2)
● Run evaluation (only) with the allennlp evaluate command
○ Evaluation is also performed at the end of the allennlp train command
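To illustrate what "evaluation only" means, here is a minimal sketch under stated assumptions: the model is a hypothetical toy (ToyModel) rather than an AllenNLP Model, each batch is a pair of predicted and gold labels, and the reported loss is the average of the per-batch losses. The loop runs forward passes without any parameter updates and merges the averaged loss with the model's metrics, giving a dict with the same keys as the result shown on this slide. The real allennlp evaluate command also loads the trained model archive and the evaluation data, which this sketch omits.

```python
from typing import Dict, Iterable, List, Tuple

# Hypothetical toy batch: (predicted labels, gold labels)
Batch = Tuple[List[int], List[int]]


class ToyModel:
    """Hypothetical stand-in for a Model: forward() returns a loss, get_metrics() reports accuracy."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def forward(self, batch: Batch) -> Dict[str, float]:
        preds, gold = batch
        hits = sum(p == g for p, g in zip(preds, gold))
        # Accumulate counts across batches, like a running metric
        self.correct += hits
        self.total += len(gold)
        # Toy "loss": fraction of mistakes in this batch
        return {"loss": 1.0 - hits / len(gold)}

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        acc = self.correct / self.total if self.total else 0.0
        if reset:
            self.correct = 0
            self.total = 0
        return {"accuracy": acc}


def evaluate(model: ToyModel, batches: Iterable[Batch]) -> Dict[str, float]:
    # Evaluation only: forward passes, no parameter updates
    losses = [model.forward(batch)["loss"] for batch in batches]
    metrics = model.get_metrics(reset=True)
    # Average the per-batch losses and report them next to the metrics
    metrics["loss"] = sum(losses) / len(losses) if losses else 0.0
    return metrics


result = evaluate(ToyModel(), [([0, 1, 1], [0, 1, 0]), ([1, 0], [1, 0])])
print(result)
```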
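from typing import Dict

import torch
from allennlp.data import TextFieldTensors, Vocabulary
from allennlp.models import Model
from allennlp.modules import Seq2VecEncoder, TextFieldEmbedder
from allennlp.nn import util
from allennlp.training.metrics import CategoricalAccuracy

@Model.register('simple_classifier')
class SimpleClassifier(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 embedder: TextFieldEmbedder,
                 encoder: Seq2VecEncoder):
        super().__init__(vocab)
        self.embedder = embedder
        self.encoder = encoder
        num_labels = vocab.get_vocab_size("labels")
        # Linear layer mapping the encoding to one logit per label
        self.classifier = torch.nn.Linear(
            encoder.get_output_dim(), num_labels)
        # Running accuracy, accumulated across batches
        self.accuracy = CategoricalAccuracy()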
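CategoricalAccuracy keeps running counts across batches, which is why forward() calls self.accuracy(logits, label) on every batch and the final number covers the whole dataset. The class below is a hypothetical plain-Python stand-in (SimpleCategoricalAccuracy, not the AllenNLP class) that counts top-1 hits and returns their ratio from get_metric(reset), clearing the counts when reset is true. Ties are resolved to the lowest label index, which is an assumption of this sketch.

```python
from typing import List


class SimpleCategoricalAccuracy:
    """Hypothetical stand-in: counts top-1 hits across calls until reset."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def __call__(self, logits: List[List[float]], labels: List[int]) -> None:
        for row, gold in zip(logits, labels):
            # argmax over the label dimension; ties go to the lowest index
            pred = max(range(len(row)), key=lambda i: row[i])
            self.correct += int(pred == gold)
            self.total += 1

    def get_metric(self, reset: bool = False) -> float:
        value = self.correct / self.total if self.total else 0.0
        if reset:
            self.correct = 0
            self.total = 0
        return value


acc = SimpleCategoricalAccuracy()
acc([[2.0, 0.5], [0.1, 1.2]], [0, 0])  # batch 1: one hit, one miss
acc([[0.3, 0.9]], [1])                 # batch 2: one hit
value = acc.get_metric(reset=True)     # 2 hits out of 3, then counts are cleared
```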
The evaluation results are saved in JSON format to the specified
serialization_dir (with the allennlp train command) or
output_file (with the allennlp evaluate command)
{'accuracy': 0.855, 'loss': 0.3686505307257175}
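The saved results are ordinary JSON, so they can be loaded with Python's standard json module. A minimal sketch, assuming a hypothetical file name metrics.json in a temporary directory and reusing the numbers from the result above: it writes the dict and reads it back. The result line above uses Python dict notation (single quotes); in the JSON file itself the keys are double-quoted.

```python
import json
import os
import tempfile

# Numbers from the example result; the file name below is hypothetical
metrics = {"accuracy": 0.855, "loss": 0.3686505307257175}

with tempfile.TemporaryDirectory() as out_dir:
    path = os.path.join(out_dir, "metrics.json")  # hypothetical output path
    with open(path, "w") as f:
        json.dump(metrics, f, indent=2)
    # Reading the file back yields an equal dict
    with open(path) as f:
        loaded = json.load(f)

print(loaded["accuracy"])
```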