Some frequently used one-hot encoding methods in PyTorch
use scatter_ function
labels = [0, 1, 4, 7, 3, 2]
# scatter_ requires a LongTensor index with the same number of dimensions as
# the target, so the labels become a (len(labels), 1) column: one class id per row.
one_hot = torch.zeros(len(labels), 8).scatter_(dim=1, index=torch.LongTensor(labels).view(-1, 1), value=1)
use index_select() function
labels = [0, 1, 4, 7, 3, 2]
index = torch.eye(8)
# index_select requires the indices as a LongTensor; a plain Python list is rejected.
# Row i of the identity matrix is the one-hot vector for class i.
one_hot = torch.index_select(index, dim=0, index=torch.LongTensor(labels))
use Embedding module
# 8 classes -> 8-dimensional vectors; the weight table is then overwritten with
# the identity matrix so that looking up index i returns the one-hot vector of i.
emb = nn.Embedding(num_embeddings=8, embedding_dim=8)
emb.weight.data = torch.eye(emb.num_embeddings)
then we can get
# The indices must be a single nested list (a 2x2 LongTensor), not two separate list arguments.
emb(Variable(torch.LongTensor([[1, 2], [3, 4]])))
Variable containing:
(0,.,.) =
0 1 0 0 0 0 0 0
0 0 1 0 0 0 0 0
(1,.,.) =
0 0 0 1 0 0 0 0
0 0 0 0 1 0 0 0
create a module
class One_Hot(nn.Module):
    """Module that maps integer class indices to one-hot vectors.

    The lookup table is a ``depth x depth`` identity matrix; row ``i`` is the
    one-hot vector for class ``i``.

    NOTE: ``self.ones`` is a plain tensor attribute (not a parameter or a
    registered buffer), so it is not part of ``state_dict()`` and is not moved
    by ``.cuda()`` / ``.to()`` on the module.
    """

    def __init__(self, depth):
        """Create the lookup table.

        Args:
            depth: number of classes, i.e. the length of each one-hot vector.
        """
        super(One_Hot, self).__init__()
        self.depth = depth
        # Plain torch.eye: the former `torch.sparse.torch.eye` only resolved
        # because the torch.sparse module happens to import torch.
        self.ones = torch.eye(depth)

    def forward(self, X_in):
        """Return the one-hot encoding of ``X_in``.

        Args:
            X_in: tensor (or Variable) of class indices; it is cast to long,
                and every value must be a valid row index into the
                ``depth x depth`` table.

        Returns:
            For 0-D/1-D input, the result of ``index_select`` on the table
            (one row per index). For N-D input (N > 1), a tensor of shape
            ``X_in.size() + (depth,)``.
        """
        indices = X_in.long().data
        if indices.dim() <= 1:
            return Variable(self.ones.index_select(0, indices))
        # index_select only accepts a vector of indices, so flatten the input,
        # look the rows up, then restore the original leading dimensions.
        rows = self.ones.index_select(0, indices.contiguous().view(-1))
        out_shape = tuple(indices.size()) + (self.depth,)
        return Variable(rows.view(out_shape))

    def __repr__(self):
        """Return ``ClassName(depth)``."""
        return "{}({})".format(self.__class__.__name__, self.depth)