# vacation
#
# Generative sketch: draws a noise level and two binary factors once, then for
# each of NUM_TEAMS iterations draws a team factor and a noisy `days` value
# centred on points * team_factor * po_factor * sm_factor.
#
# NOTE(review): the original text was collapsed onto a single line, so the
# extent of the loop body is reconstructed. It is assumed that `mean`/`days`
# are computed per iteration -- confirm against the author's intent.

# Smallest scale passed to Normal. The categorical draw below can be 0, and a
# zero scale violates Normal's positive-scale constraint.
MIN_STD = 1e-6
NUM_TEAMS = 9

# Categorical returns an index in {0, 1, 2, 3}; it is used directly as the
# standard deviation of `days`.
std = pyro.sample(
    "std", dist.Categorical(torch.tensor([0.25, 0.25, 0.25, 0.25]))
)
print("std", std)

points = 20

po = pyro.sample("po", dist.Bernoulli(torch.tensor([0.7])))
sm = pyro.sample("sm", dist.Bernoulli(torch.tensor([0.5])))

# Each Bernoulli draw is a one-element tensor, so its truth value is defined.
po_factor = 0.8 if po else 1.2
sm_factor = 0.6 if sm else 1
print("po", po_factor)
print("sm", sm_factor)

# Clamp once, outside the loop: index 0 becomes an (effectively) noise-free
# draw instead of an invalid Normal; indices 1-3 are unchanged.
days_scale = std.float().clamp(min=MIN_STD)

for i in range(NUM_TEAMS):
    # Sample-site names must be unique within a trace, so index them by
    # iteration rather than reusing one name for every draw.
    team_factor = pyro.sample(f"team_factor_{i}", dist.Uniform(1.0, 2.0))
    print("team_factor", team_factor)

    mean = points * team_factor * po_factor * sm_factor
    days = pyro.sample(f"days_{i}", dist.Normal(mean, days_scale))
    print(days)