def parse_args() -> argparse.Namespace:
    """Build the command-line parser for this script and parse ``sys.argv``.

    Hyperparameters sent by the client are passed as command-line arguments
    to the script. Only the input-data and model-directory options are
    declared here; any other argument on the command line is ignored
    instead of raising an error.

    Returns:
        A namespace with the attributes ``model_dir``, ``sm_model_dir``,
        ``train`` and ``test``. ``model_dir`` is ``None`` unless given on
        the command line; the other three fall back to the value of their
        environment variable, or ``None`` when that variable is unset.
    """
    cli = argparse.ArgumentParser()

    # The only option that has no environment-variable fallback.
    cli.add_argument('--model_dir', type=str)

    # Input data and model directories: (option, environment variable that
    # supplies its default). Registration order matches the original.
    env_backed_options = (
        ('--sm_model_dir', 'SM_MODEL_DIR'),
        ('--train', 'SM_CHANNEL_TRAIN'),
        ('--test', 'SM_CHANNEL_TEST'),
    )
    for option, env_name in env_backed_options:
        cli.add_argument(option, type=str, default=os.environ.get(env_name))

    # parse_known_args tolerates extra arguments; the leftovers are discarded.
    known, _unrecognized = cli.parse_known_args()
    return known