| | import torch
|
| |
|
def gather(x, indices):
    """Select entries of ``x`` along dim 0 for every value in ``indices``.

    ``indices`` is flattened in row-major order, so the result holds one
    leading entry per index followed by the remaining dimensions of ``x``.

    Args:
        x: tensor to read from.
        indices: integer tensor of positions along dim 0 of ``x``.

    Returns:
        Tensor of shape ``(indices.numel(), *x.shape[1:])``.
    """
    # A single vectorized lookup replaces the Python loop + torch.cat. Unlike
    # the loop, it copes with an empty index tensor (torch.cat rejects an
    # empty list) and with non-contiguous indices (which .view() cannot
    # flatten).
    flat = indices.reshape(-1).to(device=x.device, dtype=torch.long)
    return x[flat]
|
| |
|
def gather_nd(x, indices):
    """Gather slices of ``x`` addressed by multi-dimensional coordinates.

    The last dimension of ``indices`` holds a coordinate into the leading
    dimensions of ``x``; each coordinate selects the sub-tensor of ``x`` that
    remains after those leading dimensions are indexed.

    Args:
        x: tensor to read from.
        indices: integer tensor whose last dimension has length ``K`` with
            ``1 <= K <= x.dim()``.

    Returns:
        Tensor of shape ``indices.shape[:-1] + x.shape[K:]``.
    """
    depth = indices.shape[-1]
    newshape = indices.shape[:-1] + x.shape[depth:]

    coords = indices.reshape(-1, depth).to(device=x.device, dtype=torch.long)
    # Index every leading dimension at once. The previous per-row torch.cat
    # could not handle depth == x.dim(): each lookup was then a 0-dim tensor,
    # which torch.cat refuses to concatenate. It also failed on empty indices.
    out = x[tuple(coords.unbind(dim=1))]

    return out.reshape(newshape)
|
| |
|
def gen_node_indices(size_list):
    """Build (graph id, position-in-graph) indices for batched graph nodes.

    Args:
        size_list: iterable of per-graph node counts (each is cast to int).

    Returns:
        Tuple ``(indices, node_num, node_range)`` where ``node_num`` repeats
        each graph's ordinal once per node, ``node_range`` counts from 0
        within each graph, and ``indices`` stacks the two as columns.
    """
    counts = [int(size) for size in size_list]

    graph_ids = [graph for graph, count in enumerate(counts) for _ in range(count)]
    positions = [offset for count in counts for offset in range(count)]

    node_num = torch.tensor(graph_ids)
    node_range = torch.tensor(positions)
    return torch.stack((node_num, node_range), dim=1), node_num, node_range
|
| |
|
def segment_max(x, size_list):
    """Return the maximum over dim 0 of each consecutive segment of ``x``.

    ``x`` is split along dim 0 into chunks whose lengths are given by
    ``size_list`` (each cast to int); one reduced entry per chunk is stacked
    into the result.
    """
    lengths = list(map(int, size_list))
    maxima = []
    for segment in x.split(lengths):
        maxima.append(segment.max(dim=0).values)
    return torch.stack(maxima)
|
| |
|
def segment_sum(x, size_list):
    """Return the sum over dim 0 of each consecutive segment of ``x``.

    ``x`` is split along dim 0 into chunks whose lengths are given by
    ``size_list`` (each cast to int); one reduced entry per chunk is stacked
    into the result.
    """
    lengths = list(map(int, size_list))
    totals = []
    for segment in x.split(lengths):
        totals.append(segment.sum(dim=0))
    return torch.stack(totals)
|
| |
|
def segment_softmax(gate, size_list):
    """Apply a softmax over dim 0 independently within each segment of ``gate``.

    ``gate`` is split along dim 0 into consecutive chunks whose lengths are
    given by ``size_list``; within each chunk the values are normalised so
    that they sum to (approximately) one along dim 0.

    Args:
        gate: tensor of scores; any number of trailing dimensions is accepted.
        size_list: iterable of segment lengths (each cast to int, matching
            ``segment_max`` / ``segment_sum``).

    Returns:
        Tensor with the same shape as ``gate``.
    """
    sizes = [int(n) for n in size_list]

    pieces = []
    for segment in torch.split(gate, sizes):
        # Subtract the per-segment maximum before exponentiating so that large
        # scores do not overflow; broadcasting keeps this correct for 1-D and
        # N-D gates alike (the former repeat(n, 1) expansion forced a 2-D
        # layout and silently mis-broadcast 1-D input to (N, N)).
        weights = torch.exp(segment - segment.max(dim=0).values)
        # The small constant guards against division by zero.
        pieces.append(weights / (weights.sum(dim=0) + 1e-16))

    return torch.cat(pieces, dim=0)
|
| |
|
def pad_V(V, max_n):
    """Zero-pad a 2-D tensor along dim 0 up to ``max_n`` rows.

    Args:
        V: tensor of shape ``(N, C)``.
        max_n: target number of rows.

    Returns:
        ``V`` itself when ``max_n <= N``; otherwise a new ``(max_n, C)`` tensor
        whose extra rows are zero, with the same dtype and device as ``V``.
    """
    N, C = V.shape
    if max_n > N:
        # new_zeros inherits V's dtype and device. A bare torch.zeros is always
        # float32 on the CPU, which changed the dtype of padded integer/half
        # inputs and failed outright for tensors living on another device.
        V = torch.cat([V, V.new_zeros(max_n - N, C)], dim=0)
    return V
|
| |
|
def pad_A(A, max_n):
    """Zero-pad the first and last dimensions of a 3-D tensor up to ``max_n``.

    Args:
        A: tensor unpacked as ``(N, L, _)``. The padding arithmetic requires
            the last dimension to equal ``N`` as well (presumably an
            adjacency tensor of shape (nodes, edge types, nodes) -- confirm
            against the caller).
        max_n: target size for the first and last dimensions.

    Returns:
        ``A`` itself when ``max_n <= N``; otherwise a new
        ``(max_n, L, max_n)`` tensor padded with zeros, with the same dtype
        and device as ``A``.
    """
    N, L, _ = A.shape
    if max_n > N:
        extra = max_n - N
        # new_zeros keeps A's dtype and device; torch.zeros would always be
        # float32 on the CPU, altering the dtype of non-float inputs and
        # breaking tensors that live on another device.
        A = torch.cat([A, A.new_zeros(N, L, extra)], dim=-1)  # widen last dim
        A = torch.cat([A, A.new_zeros(extra, L, max_n)], dim=0)  # add new rows

    return A
|
| |
|
def pad_prot(P, max_n):
    """Zero-pad a 1-D tensor to length ``max_n`` and cast it to ``IntTensor``.

    The cast is applied whether or not padding was needed, so the result is
    always a ``torch.IntTensor``.
    """
    (length,) = P.shape
    missing = max_n - length
    if missing > 0:
        tail = torch.zeros(missing)
        P = torch.cat([P, tail], dim=0)
    return P.type(torch.IntTensor)
|
| |
|
def create_batch(input, pad=False, device=torch.device('cpu')):
    """Collate per-sample dicts into a single batch dict.

    Every sample must provide the keys 'V', 'A', 'G', 'mol_size',
    'subgraph_size', 'label', 'index' and 'smiles'.

    Args:
        input: iterable of sample dicts.
        pad: when true, each 'V' and 'A' is padded with ``pad_V`` / ``pad_A``
            to the largest ``V.shape[0]`` in the batch before stacking.
        device: device the stacked tensors are moved to.

    Returns:
        Dict with the same keys. 'V', 'A', 'subgraph_size' and 'label' are
        stacked along a new dim 0, 'mol_size' is concatenated along dim 0, and
        'index' / 'smiles' are plain lists. 'G' is stacked when the first
        sample's 'G' is not None; otherwise it is the list of raw values.
    """
    samples = list(input)
    node_feats = [s['V'] for s in samples]
    adjacency = [s['A'] for s in samples]
    graph_feats = [s['G'] for s in samples]
    mol_sizes = [s['mol_size'] for s in samples]
    subgraph_sizes = [s['subgraph_size'] for s in samples]
    labels = [s['label'] for s in samples]
    indices = [s['index'] for s in samples]
    smiles = [s['smiles'] for s in samples]

    if graph_feats[0] is not None:
        graph_feats = torch.stack(graph_feats, dim=0).to(device)

    if pad:
        # Bring every sample up to the largest node count in the batch.
        widest = max(v.shape[0] for v in node_feats)
        node_feats = [pad_V(v, widest) for v in node_feats]
        adjacency = [pad_A(a, widest) for a in adjacency]

    return {
        'V': torch.stack(node_feats, dim=0).to(device),
        'A': torch.stack(adjacency, dim=0).to(device),
        'G': graph_feats,
        'mol_size': torch.cat(mol_sizes, dim=0).to(device),
        'subgraph_size': torch.stack(subgraph_sizes, dim=0).to(device),
        'label': torch.stack(labels, dim=0).to(device),
        'index': indices,
        'smiles': smiles,
    }
|
| |
|
def create_mol_protein_batch(input, pad=False, device=torch.device('cpu'), pr=True):
    """Collate molecule/protein sample dicts into a single batch dict.

    Every sample must provide the keys 'V', 'A', 'G', 'mol_size',
    'subgraph_size', 'protein_seq', 'protein', 'label', 'index' and 'smiles';
    'fp' is optional.

    Args:
        input: iterable of sample dicts.
        pad: when true, 'V' / 'A' are padded (``pad_V`` / ``pad_A``) to the
            largest ``V.shape[0]`` and 'protein_seq' (``pad_prot``) to the
            longest sequence, then stacked and moved to ``device``. When
            false, 'V', 'A', 'subgraph_size' and 'protein_seq' are returned
            as lists of tensors with a leading dimension of 1 added, left on
            their original device.
        device: device the stacked tensors are moved to.
        pr: when true, print the padding sizes (only relevant with ``pad``).

    Returns:
        Dict with keys 'V', 'A', 'G', 'fp', 'mol_size', 'subgraph_size',
        'protein_seq', 'label', 'index', 'smiles', 'protein'. 'fp' is the
        stack of all supplied fingerprints on ``device``, or None when no
        sample carries one. 'label' is stacked and flattened to 1-D. 'G' is
        left as the raw list when the first sample's 'G' is None.
    """
    vl, al, gsl, msl, ssl = [], [], [], [], []
    prot, seq, lbl, idxs, smis, fpl = [], [], [], [], [], []

    for d in input:
        vl.append(d['V'])
        al.append(d['A'])
        gsl.append(d['G'])
        msl.append(d['mol_size'])
        ssl.append(d['subgraph_size'])
        prot.append(d['protein_seq'])
        seq.append(d['protein'])
        lbl.append(d['label'])
        idxs.append(d['index'])
        smis.append(d['smiles'])
        if 'fp' in d:
            fpl.append(d['fp'])

    if gsl[0] is not None:
        if pad:
            gsl = torch.stack(gsl, dim=0).to(device)
        else:
            gsl = [torch.unsqueeze(g, 0) for g in gsl]

    # Built before branching so that both return paths can use it. It used to
    # be assigned only on the padded path, so pad=False raised NameError.
    fpt = torch.stack(fpl, dim=0).to(device) if fpl else None

    if pad:
        max_n = max(v.shape[0] for v in vl)
        if pr:
            print('\tPadding V to max_n:', max_n)
        vl1 = [pad_V(v, max_n) for v in vl]

        if pr:
            print('\tPadding A to max_n:', max_n)
        al1 = [pad_A(a, max_n) for a in al]

        max_prot = max(p.shape[0] for p in prot)
        if pr:
            print('\tPadding protein_seq to max_n:', max_prot)
        prot1 = [pad_prot(p, max_prot) for p in prot]

        v_out = torch.stack(vl1, dim=0).to(device)
        a_out = torch.stack(al1, dim=0).to(device)
        ss_out = torch.stack(ssl, dim=0).to(device)
        prot_out = torch.stack(prot1, dim=0).to(device)
    else:
        # Variable-sized samples cannot be stacked without padding; hand back
        # per-sample tensors with a leading dimension of 1 instead.
        v_out = [torch.unsqueeze(v, 0) for v in vl]
        a_out = [torch.unsqueeze(a, 0) for a in al]
        ss_out = [torch.unsqueeze(s, 0) for s in ssl]
        prot_out = [torch.unsqueeze(p, 0) for p in prot]

    return {'V': v_out,
            'A': a_out,
            'G': gsl,
            'fp': fpt,
            'mol_size': torch.cat(msl, dim=0).to(device),
            'subgraph_size': ss_out,
            'protein_seq': prot_out,
            'label': torch.stack(lbl, dim=0).view(-1).to(device),
            'index': idxs,
            'smiles': smis,
            'protein': seq}
|
| |
|
def create_mol_protein_fp_batch(input, pad=False, device=torch.device('cpu'), pr=True):
    """Collate fingerprint/protein sample dicts into a single batch dict.

    Every sample must provide the keys 'fp', 'protein_seq', 'label', 'index'
    and 'smiles'.

    Args:
        input: iterable of sample dicts.
        pad: when true, each 'protein_seq' is padded with ``pad_prot`` to the
            longest sequence in the batch, and 'fp' / 'protein_seq' are
            stacked and moved to ``device``. When false, both are returned as
            lists of tensors with a leading dimension of 1 added.
        device: device the stacked tensors are moved to.
        pr: when true, print the padding length (only relevant with ``pad``).

    Returns:
        Dict with keys 'fp', 'protein_seq', 'label', 'index', 'smiles';
        'label' is stacked and flattened to 1-D in both modes.
    """
    samples = list(input)
    fingerprints = [s['fp'] for s in samples]
    sequences = [s['protein_seq'] for s in samples]
    labels = [s['label'] for s in samples]
    indices = [s['index'] for s in samples]
    smiles = [s['smiles'] for s in samples]

    if not pad:
        return {
            'fp': [torch.unsqueeze(f, 0) for f in fingerprints],
            'protein_seq': [torch.unsqueeze(p, 0) for p in sequences],
            'label': torch.stack(labels, dim=0).view(-1).to(device),
            'index': indices,
            'smiles': smiles,
        }

    longest = max(p.shape[0] for p in sequences)
    if pr:
        print('\tPadding protein_seq to max_n:', longest)
    padded = [pad_prot(p, longest) for p in sequences]

    return {
        'fp': torch.stack(fingerprints, dim=0).to(device),
        'protein_seq': torch.stack(padded, dim=0).to(device),
        'label': torch.stack(labels, dim=0).view(-1).to(device),
        'index': indices,
        'smiles': smiles,
    }