Slide 33
Slide 33 text
5. エッジAI推論のサンプル – 訓練スクリプト作成
33
def save(model: torch.nn.modules.Module, path: str) -> None:
    """Serialize the entire ``model`` object into directory ``path``.

    The file is named ``model-YYYY-mm-dd_HH-MM-SS.pt``, using the current
    local time (``datetime.now()``) as the suffix. Two calls within the same
    second therefore target the same file name.

    Args:
        model: The PyTorch module to persist. The whole pickled module is
            written, not just its ``state_dict``.
        path: Directory to write into. It is only joined with the file name,
            never created, so it is assumed to exist already.
    """
    stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    target_file = os.path.join(path, 'model-' + stamp + '.pt')
    # Save the full module rather than model.state_dict(): according to the
    # original author, SageMaker compilation fails on a bare state_dict.
    torch.save(model, target_file)
def parse_args() -> argparse.Namespace:
    """Parse the script's command-line options.

    ``--model_dir`` has no explicit default, so it is ``None`` when not
    passed. ``--sm_model_dir``, ``--train`` and ``--test`` default to the
    ``SM_MODEL_DIR``, ``SM_CHANNEL_TRAIN`` and ``SM_CHANNEL_TEST``
    environment variables, read at call time. These are presumably set by
    the SageMaker training container; confirm against the job setup. Any
    option not registered here is silently ignored rather than rejected.

    Returns:
        The namespace of recognized options.
    """
    cli = argparse.ArgumentParser()
    # Hyperparameters sent by the client arrive as command-line arguments.
    # Input data and model directories follow.
    cli.add_argument('--model_dir', type=str)
    env_backed_flags = (
        ('--sm_model_dir', 'SM_MODEL_DIR'),
        ('--train', 'SM_CHANNEL_TRAIN'),
        ('--test', 'SM_CHANNEL_TEST'),
    )
    for flag, env_name in env_backed_flags:
        cli.add_argument(flag, type=str, default=os.environ.get(env_name))
    # parse_known_args keeps unrecognized arguments out of the namespace
    # instead of raising an error for them.
    recognized, _leftover = cli.parse_known_args()
    return recognized
training.py