请问这个池化层的功能是如何实现的?

CSDN问答 2021-12-30 11:06:17 阅读数:171

功能 实现 请问 问这

这是一篇文献上附上的代码,文献的目的是根据深度学习探究材料的原子性质与其他性质的关系。
我截取这一部分的目的应该是想要把原子特征和晶体特征联系起来,但是我并不能理解这一部分的代码,谢谢大家

https://github.com/txie-93/cgcnn/blob/master/cgcnn/model.py完整代码在这里,目标部分在第168行到最后

def pooling(self, atom_fea, crystal_atom_idx): """ Pooling the atom features to crystal features N: Total number of atoms in the batch N0: Total number of crystals in the batch Parameters ---------- atom_fea: Variable(torch.Tensor) shape (N, atom_fea_len) Atom feature vectors of the batch crystal_atom_idx: list of torch.LongTensor of length N0 Mapping from the crystal idx to atom idx """ assert sum([len(idx_map) for idx_map in crystal_atom_idx]) ==\ atom_fea.data.shape[0] summed_fea = [torch.mean(atom_fea[idx_map], dim=0, keepdim=True) for idx_map in crystal_atom_idx] return torch.cat(summed_fea, dim=0)



采纳答案:

assert 是一个断言判断,主要代码是变量summed_fea ,它等同下面代码

summed_fea=[]for idx_map in crystal_atom_idx: r = torch.mean(atom_fea[idx_map], dim=0, keepdim=True) summed_fea.append(r)

这段代码要分清楚 torch.mean、torch.cat以及各个参数是什么就行了
如果对你有帮助,可以点击我这个回答右上方的【采纳】按钮,给我个采纳吗,谢谢


版权声明:本文为[CSDN问答]所创,转载请带上原文链接,感谢。 https://ask.csdn.net/questions/7619864