Copyright © Fixstars Group
問題解決事例 - Non Maximum Suppression(4)
● ONNXのカスタムオペレータとして出力
35
@parse_args('v', 'v', 'i', 'f', 'f')
def symbolic_efficient_nms_standard(g, boxes, scores, num_boxes, score_threshold, iou_threshold):
num_detections, detection_boxes, detection_scores, detection_classes = g.op('tensorrt::EfficientNMS_TRT',
boxes, scores,
outputs=4,
score_threshold_f=score_threshold,
iou_threshold_f=iou_threshold,
max_output_boxes_i=num_boxes,
background_class_i=-1,
score_activation_i=0,
box_coding_i=0,
)
return num_detections, detection_boxes, detection_scores, detection_classes
def forward(self, input):
…
torch.ops.load_library('custom_ops.so')
return torch.ops.custom_ops.efficient_nms_standard(transformed_anchors, confs, num_boxes, self.threshold, self.iou_threshold)
register_custom_op_symbolic('custom_ops::efficient_nms_standard‘, symbolic_efficient_nms_standard, 11)
「ドメイン名::プラグイン名」で
ONNXのカスタムオペレータ出力
PyTorchの推論(ONNXエクスポート)時
はカスタムオペレータのDLLを呼び出す