# TREEQN
def transition(zl):
    """Apply the residual transition of every action to a batch of states.

    For each action ``a`` this computes ``zl + tanh(zl @ W[a] + b)`` and
    stacks the results along a new action axis, all in a single einsum.

    Args:
        zl: Batch of latent states, indexed ``[batch, k]`` by the einsum
            equation.

    Returns:
        Tensor laid out as [batch_size x num_actions x hidden_dimension].

    Note:
        ``W`` and ``b`` are not parameters; they are read from the enclosing
        scope. The einsum equation indexes ``W`` as ``[action, k, i]``, and
        the residual addition requires ``i`` to match the last dimension of
        ``zl``. ``b`` must broadcast against the ``[batch, action, i]``
        result -- its exact shape is not visible here (TODO confirm).
    """
    # One linear map per action: (b, k) x (a, k, i) -> (b, a, i).
    pre_activation = torch.einsum("bk,aki->bai", zl, W) + b
    # unsqueeze(1) inserts the action axis so the residual term broadcasts
    # over all actions.
    # -- [batch_size x num_actions x hidden_dimension]
    return zl.unsqueeze(1) + torch.tanh(pre_activation)
# NOTE(review): this redefines a `transition` function that already appears
# verbatim earlier in the file; this later definition is the one that takes
# effect. Confirm whether both copies are needed.
def transition(zl):
    """Compute the next latent state for every action in one batched step.

    Each action ``a`` has its own weight slice ``W[a]``; the update is the
    residual form ``zl + tanh(zl @ W[a] + b)``, evaluated for all actions
    at once and stacked along a new action axis.

    Args:
        zl: Batch of latent states, indexed ``[batch, k]`` by the einsum
            equation.

    Returns:
        Tensor laid out as [batch_size x num_actions x hidden_dimension].

    Note:
        ``W`` and ``b`` come from the enclosing scope rather than from the
        argument list. ``W`` is indexed ``[action, k, i]`` by the einsum
        equation, and ``i`` must equal the last dimension of ``zl`` for the
        residual sum to broadcast. The shape of ``b`` is not visible here;
        it only needs to broadcast against ``[batch, action, i]``
        (TODO confirm).
    """
    # Contract the state dimension k against each action's weight matrix:
    # (b, k) x (a, k, i) -> (b, a, i).
    projected = torch.einsum("bk,aki->bai", zl, W)
    # Add the action axis to zl so the skip connection applies per action.
    # -- [batch_size x num_actions x hidden_dimension]
    return zl.unsqueeze(1) + torch.tanh(projected + b)