这是一篇文献上附上的代码,文献的目的是根据深度学习探究材料的原子性质与其他性质的关系。
我截取这一部分的目的应该是想要把原子特征和晶体特征联系起来,但是我并不能理解这一部分的代码,谢谢大家
https://github.com/txie-93/cgcnn/blob/master/cgcnn/model.py完整代码在这里,目标部分在第168行到最后
def pooling(self, atom_fea, crystal_atom_idx): """ Pooling the atom features to crystal features N: Total number of atoms in the batch N0: Total number of crystals in the batch Parameters ---------- atom_fea: Variable(torch.Tensor) shape (N, atom_fea_len) Atom feature vectors of the batch crystal_atom_idx: list of torch.LongTensor of length N0 Mapping from the crystal idx to atom idx """ assert sum([len(idx_map) for idx_map in crystal_atom_idx]) ==\ atom_fea.data.shape[0] summed_fea = [torch.mean(atom_fea[idx_map], dim=0, keepdim=True) for idx_map in crystal_atom_idx] return torch.cat(summed_fea, dim=0)