| import copy
|
| import json
|
| import math
|
| import numpy as np
|
| import pandas as pd
|
| import torch
|
| from scipy.spatial import cKDTree
|
| from rdkit import Chem
|
| from rdkit.Chem import RWMol
|
| from rdkit.Chem import Draw, AllChem
|
| from rdkit.Chem import rdDepictor
|
| import matplotlib.pyplot as plt
|
| import re
|
|
|
| from typing import List
|
| import matplotlib.pyplot as plt
|
| from matplotlib.patches import Rectangle, Circle
|
|
|
|
|
| COLORS = {
|
| u'c': '0.0,0.75,0.75', u'b': '0.0,0.0,1.0', u'g': '0.0,0.5,0.0', u'y': '0.75,0.75,0',
|
| u'k': '0.0,0.0,0.0', u'r': '1.0,0.0,0.0', u'm': '0.75,0,0.75'
|
| }
|
|
|
|
|
| def view_box_center(bond_bbox,heavy_centers):
|
| fig, ax = plt.subplots(figsize=(10, 10))
|
|
|
| for box in bond_bbox:
|
| x1, y1, x2, y2 = box
|
| width = x2 - x1
|
| height = y2 - y1
|
| rect = Rectangle((x1, y1), width, height, linewidth=1, edgecolor='blue', facecolor='none')
|
| ax.add_patch(rect)
|
|
|
|
|
| for center in heavy_centers:
|
| x, y = center
|
| circle = Circle((x, y), radius=5, edgecolor='red', facecolor='none', linewidth=1)
|
| ax.add_patch(circle)
|
|
|
|
|
| x_min = min(bond_bbox[:, 0].min(), heavy_centers[:, 0].min()) - 10
|
| x_max = max(bond_bbox[:, 2].max(), heavy_centers[:, 0].max()) + 10
|
| y_min = min(bond_bbox[:, 1].min(), heavy_centers[:, 1].min()) - 10
|
| y_max = max(bond_bbox[:, 3].max(), heavy_centers[:, 1].max()) + 10
|
| ax.set_xlim(x_min, x_max)
|
| ax.set_ylim(y_min, y_max)
|
|
|
|
|
| ax.set_title("Boxes and Centers")
|
| ax.set_xlabel("X")
|
| ax.set_ylabel("Y")
|
|
|
| plt.gca().set_aspect('equal', adjustable='box')
|
| plt.grid(True, linestyle='--', alpha=0.7)
|
|
|
| def molIDX(mol):
|
| for i, atom in enumerate(mol.GetAtoms()):
|
| atom.SetAtomMapNum(i)
|
|
|
| return mol
|
|
|
| def molIDX_del(mol):
|
| for i, atom in enumerate(mol.GetAtoms()):
|
| atom.SetAtomMapNum(0)
|
| print(i)
|
| return mol
|
| from det_engine import ABBREVIATIONS
|
|
|
|
|
|
|
| def Val_extract_atom_info(error_message):
|
| """
|
| 从错误信息中提取 atomid, atomType 和 valence。
|
| :param error_message: 错误信息字符串
|
| :return: (atomid, atomType, valence) 元组
|
| """
|
|
|
| pattern = r"Explicit valence for atom # (\d+) (\w), (\d+)"
|
| pattern2 =r"Explicit valence for atom # (\d+) (\w) "
|
|
|
| if not isinstance(error_message, type('strs')):
|
| error_message=str(error_message)
|
| match = re.search(pattern, error_message)
|
| match2 = re.search(pattern2, error_message)
|
| if match:
|
|
|
| atomid = int(match.group(1))
|
| atomType = match.group(2)
|
| valence = int(match.group(3))
|
| return atomid, atomType, valence
|
| elif match2:
|
| atomid = int(match2.group(1))
|
| atomType = match2.group(2)
|
|
|
| return atomid, atomType, None
|
|
|
| else:
|
| raise ValueError("无法从错误信息中提取原子信息")
|
|
|
| def calculate_charge_adjustment(atom_symbol, current_valence):
|
| """
|
| 计算需要调整的电荷,根据反馈的原子符号和当前价态。
|
| :param atom_symbol: 原子符号(如 "C")
|
| :param current_valence: 当前价态(如 5)
|
| :return: 需要添加的电荷数(正数表示负电荷,负数表示正电荷)
|
| """
|
| if atom_symbol not in VALENCES:
|
| raise ValueError(f"未知的原子符号: {atom_symbol}")
|
|
|
|
|
| max_valence = max(VALENCES[atom_symbol])
|
| if current_valence is None:
|
| current_valence=max_valence
|
|
|
| if current_valence > max_valence:
|
|
|
| charge_adjustment = current_valence - max_valence
|
| return charge_adjustment
|
| else:
|
|
|
| return 0
|
|
|
| from rdkit.Chem import rdchem, RWMol, CombineMols
|
|
|
| def expandABB(mol,ABBREVIATIONS, placeholder_atoms):
|
| mols = [mol]
|
|
|
|
|
| for idx in sorted(placeholder_atoms.keys(), reverse=True):
|
| group = placeholder_atoms[idx]
|
|
|
| submol = Chem.MolFromSmiles(ABBREVIATIONS[group].smiles)
|
| submol_rw = RWMol(submol)
|
| anchor_atom_idx = 0
|
|
|
| new_mol = RWMol(mol)
|
|
|
| placeholder_idx = idx
|
|
|
| neighbors = [nb.GetIdx() for nb in new_mol.GetAtomWithIdx(placeholder_idx).GetNeighbors()]
|
|
|
| bonds_to_remove = []
|
| for bond in new_mol.GetBonds():
|
| if bond.GetBeginAtomIdx() == placeholder_idx or bond.GetEndAtomIdx() == placeholder_idx:
|
| bonds_to_remove.append((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()))
|
| for bond in bonds_to_remove:
|
| new_mol.RemoveBond(bond[0], bond[1])
|
|
|
| new_mol.RemoveAtom(placeholder_idx)
|
|
|
| new_neighbors = []
|
| for neighbor in neighbors:
|
| if neighbor < placeholder_idx:
|
| new_neighbors.append(neighbor)
|
| else:
|
| new_neighbors.append(neighbor - 1)
|
|
|
| new_mol = RWMol(CombineMols(new_mol, submol_rw))
|
|
|
|
|
| new_anchor_idx = new_mol.GetNumAtoms() - len(submol_rw.GetAtoms()) + anchor_atom_idx
|
|
|
|
|
| for neighbor in new_neighbors:
|
|
|
| new_mol.AddBond(neighbor, new_anchor_idx, Chem.BondType.SINGLE)
|
| a1=new_mol.GetAtomWithIdx(neighbor)
|
| a2=new_mol.GetAtomWithIdx(new_anchor_idx)
|
| a1.SetNumRadicalElectrons(0)
|
| a2.SetNumRadicalElectrons(0)
|
|
|
| mol = new_mol
|
| mols.append(mol)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| Chem.SanitizeMol(mols[-1])
|
|
|
| modified_smiles = Chem.MolToSmiles(mols[-1])
|
|
|
| return mols[-1], modified_smiles
|
|
|
|
|
| def output_to_smiles(output,idx_to_labels,bond_labels,result):
|
|
|
| x_center = (output["boxes"][:, 0] + output["boxes"][:, 2]) / 2
|
| y_center = (output["boxes"][:, 1] + output["boxes"][:, 3]) / 2
|
|
|
| center_coords = torch.stack((x_center, y_center), dim=1)
|
|
|
| output = {'bbox': output["boxes"].to("cpu").numpy(),
|
| 'bbox_centers': center_coords.to("cpu").numpy(),
|
| 'scores': output["scores"].to("cpu").numpy(),
|
| 'pred_classes': output["labels"].to("cpu").numpy()}
|
|
|
|
|
| atoms_list, bonds_list,charge = bbox_to_graph_with_charge(output,
|
| idx_to_labels=idx_to_labels,
|
| bond_labels=bond_labels,
|
| result=result)
|
| smiles, mol= mol_from_graph_with_chiral(atoms_list, bonds_list,charge)
|
| abc=[atoms_list, bonds_list,charge ]
|
|
|
| if isinstance(smiles, type(None)):
|
| print(f"get atoms_list problems")
|
|
|
| elif isinstance(atoms_list,type(None)):
|
| print(f"get atoms_list problems")
|
|
|
|
|
|
|
| return abc,smiles,mol,output
|
|
|
|
|
| def output_to_smiles2(output,idx_to_labels,bond_labels,result):
|
|
|
| x_center = (output["boxes"][:, 0] + output["boxes"][:, 2]) / 2
|
| y_center = (output["boxes"][:, 1] + output["boxes"][:, 3]) / 2
|
|
|
| center_coords = torch.stack((x_center, y_center), dim=1)
|
|
|
| output = {'bbox': output["boxes"].to("cpu").numpy(),
|
| 'bbox_centers': center_coords.to("cpu").numpy(),
|
| 'scores': output["scores"].to("cpu").numpy(),
|
| 'pred_classes': output["labels"].to("cpu").numpy()}
|
|
|
|
|
| atoms_list, bonds_list,charge = bbox_to_graph_with_charge(output,
|
| idx_to_labels=idx_to_labels,
|
| bond_labels=bond_labels,
|
| result=result)
|
| smiles, mol= mol_from_graph_with_chiral(atoms_list, bonds_list,charge)
|
| abc=[atoms_list, bonds_list,charge ]
|
| if isinstance(smiles, type(None)):
|
| print(f"get atoms_list problems")
|
|
|
| elif isinstance(atoms_list,type(None)):
|
| print(f"get atoms_list problems")
|
|
|
|
|
|
|
| return abc,smiles,mol,output
|
|
|
|
|
|
|
| def bbox_to_graph(output, idx_to_labels, bond_labels,result):
|
|
|
|
|
| atoms_mask = np.array([True if ins not in bond_labels else False for ins in output['pred_classes']])
|
|
|
|
|
| atoms_list = [idx_to_labels[a] for a in output['pred_classes'][atoms_mask]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| atoms_list = pd.DataFrame({'atom': atoms_list,
|
| 'x': output['bbox_centers'][atoms_mask, 0],
|
| 'y': output['bbox_centers'][atoms_mask, 1]})
|
|
|
|
|
| for idx, row in atoms_list.iterrows():
|
| if row.atom[-1] != '0':
|
| if row.atom[-2] != '-':
|
| overlapping = atoms_list[atoms_list.atom.str.startswith(row.atom[:-1])]
|
| else:
|
| overlapping = atoms_list[atoms_list.atom.str.startswith(row.atom[:-2])]
|
|
|
| kdt = cKDTree(overlapping[['x', 'y']])
|
| dists, neighbours = kdt.query([row.x, row.y], k=2)
|
| if dists[1] < 7:
|
| atoms_list.drop(overlapping.index[neighbours[1]], axis=0, inplace=True)
|
|
|
| bonds_list = []
|
|
|
|
|
| for bbox, bond_type, score in zip(output['bbox'][np.logical_not(atoms_mask)],
|
| output['pred_classes'][np.logical_not(atoms_mask)],
|
| output['scores'][np.logical_not(atoms_mask)]):
|
|
|
|
|
| if idx_to_labels[bond_type] in ['-','SINGLE', 'NONE', 'ENDUPRIGHT', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT']:
|
| _margin = 5
|
| else:
|
| _margin = 8
|
|
|
|
|
| anchor_positions = (bbox + [_margin, _margin, -_margin, -_margin]).reshape([2, -1])
|
| oposite_anchor_positions = anchor_positions.copy()
|
| oposite_anchor_positions[:, 1] = oposite_anchor_positions[:, 1][::-1]
|
|
|
|
|
|
|
| anchor_positions = np.concatenate([anchor_positions, oposite_anchor_positions])
|
|
|
|
|
| atoms_pos = atoms_list[['x', 'y']].values
|
| kdt = cKDTree(atoms_pos)
|
| dists, neighbours = kdt.query(anchor_positions, k=1)
|
|
|
|
|
| if np.argmin((dists[0] + dists[1], dists[2] + dists[3])) == 0:
|
|
|
| begin_idx, end_idx = neighbours[:2]
|
| else:
|
|
|
| begin_idx, end_idx = neighbours[2:]
|
|
|
|
|
| if begin_idx != end_idx:
|
| bonds_list.append((begin_idx, end_idx, idx_to_labels[bond_type], idx_to_labels[bond_type], score))
|
| else:
|
| continue
|
|
|
| return atoms_list, bonds_list
|
|
|
|
|
| def calculate_distance(coord1, coord2):
|
|
|
| return math.sqrt((coord1[0] - coord2[0])**2 + (coord1[1] - coord2[1])**2)
|
|
|
| def assemble_atoms_with_charges(atom_list, charge_list):
|
| used_charge_indices=set()
|
| atom_list = atom_list.reset_index(drop=True)
|
|
|
| kdt = cKDTree(atom_list[['x','y']])
|
| for i, charge in charge_list.iterrows():
|
| if i in used_charge_indices:
|
| continue
|
| charge_=charge['charge']
|
|
|
| dist, idx_atom=kdt.query([charge_list.x[i],charge_list.y[i]], k=1)
|
|
|
| if idx_atom not in atom_list.index:
|
| print(f"Warning: idx_atom {idx_atom} is out of range for atom_list.")
|
| continue
|
| atom_str = atom_list.iloc[idx_atom]['atom']
|
| if atom_str=='*':
|
| atom_=atom_str + charge_
|
| else:
|
| try:
|
| atom_ = re.findall(r'[A-Za-z*]+', atom_str)[0] + charge_
|
| except Exception as e:
|
| print(atom_str,charge_,charge_list)
|
| print(f"@assemble_atoms_with_charges\n {e}\n{atom_list}")
|
| atom_=atom_str + charge_
|
| atom_list.loc[idx_atom,'atom']=atom_
|
|
|
| return atom_list
|
|
|
|
|
|
|
| def assemble_atoms_with_charges2(atom_list, charge_list, max_distance=10):
|
| used_charge_indices = set()
|
|
|
| for idx, atom in atom_list.iterrows():
|
| atom_coord = atom['x'],atom['y']
|
| atom_label = atom['atom']
|
| closest_charge = None
|
| min_distance = float('inf')
|
|
|
| for i, charge in charge_list.iterrows():
|
| if i in used_charge_indices:
|
| continue
|
|
|
| charge_coord = charge['x'],charge['y']
|
| charge_label = charge['charge']
|
|
|
| distance = calculate_distance(atom_coord, charge_coord)
|
|
|
| if distance <= max_distance and distance < min_distance:
|
| closest_charge = charge
|
| min_distance = distance
|
|
|
|
|
| if closest_charge is not None:
|
| if closest_charge['charge'] == '1':
|
| charge_ = '+'
|
| else:
|
| charge_ = closest_charge['charge']
|
| atom_ = atom['atom'] + charge_
|
|
|
|
|
| atom_list.loc[idx,'atom'] = atom_
|
| used_charge_indices.add(tuple(charge))
|
|
|
| else:
|
|
|
| atom_list.loc[idx,'atom'] = atom['atom'] + '0'
|
|
|
| return atom_list
|
|
|
|
|
|
|
| def bbox_to_graph_with_charge(output, idx_to_labels, bond_labels,result):
|
|
|
| bond_labels_pre=bond_labels
|
|
|
| atoms_mask = np.array([True if ins not in bond_labels and ins not in charge_labels else False for ins in output['pred_classes']])
|
|
|
| try:
|
|
|
|
|
| atoms_list = [idx_to_labels[a] for a in output['pred_classes'][atoms_mask]]
|
| if isinstance(atoms_list, pd.Series) and atoms_list.empty:
|
| return None, None, None
|
| else:
|
| atoms_list = pd.DataFrame({'atom': atoms_list,
|
| 'x': output['bbox_centers'][atoms_mask, 0],
|
| 'y': output['bbox_centers'][atoms_mask, 1],
|
| 'bbox': output['bbox'][atoms_mask].tolist() ,
|
| 'scores': output['scores'][atoms_mask].tolist(),
|
| })
|
| except Exception as e:
|
| print(output['pred_classes'][atoms_mask].dtype,output['pred_classes'][atoms_mask])
|
| print(e)
|
| print(idx_to_labels)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| charge_mask = np.array([True if ins in charge_labels else False for ins in output['pred_classes']])
|
| charge_list = [idx_to_labels[a] for a in output['pred_classes'][charge_mask]]
|
| charge_list = pd.DataFrame({'charge': charge_list,
|
| 'x': output['bbox_centers'][charge_mask, 0],
|
| 'y': output['bbox_centers'][charge_mask, 1],
|
| 'scores': output['scores'][charge_mask],
|
|
|
| })
|
|
|
|
|
| try:
|
| atoms_list['atom'] = atoms_list['atom']+'0'
|
| except Exception as e:
|
| print(e)
|
| print(atoms_list['atom'],'atoms_list["atom"] @@ adding 0 ')
|
|
|
|
|
| if len(charge_list) > 0:
|
| atoms_list = assemble_atoms_with_charges(atoms_list,charge_list)
|
|
|
|
|
|
|
|
|
|
|
| for idx, row in atoms_list.iterrows():
|
| if row.atom[-1] != '0':
|
| try:
|
| if row.atom[-2] != '-':
|
| overlapping = atoms_list[atoms_list.atom.str.startswith(row.atom[:-1])]
|
| except Exception as e:
|
| print(row.atom,"@rin case atoms with sign gets detected two times")
|
| print(e)
|
| else:
|
| overlapping = atoms_list[atoms_list.atom.str.startswith(row.atom[:-2])]
|
|
|
| kdt = cKDTree(overlapping[['x', 'y']])
|
| dists, neighbours = kdt.query([row.x, row.y], k=2)
|
| if dists[1] < 7:
|
| atoms_list.drop(overlapping.index[neighbours[1]], axis=0, inplace=True)
|
|
|
| bonds_list = []
|
|
|
|
|
| bond_mask=np.logical_not(atoms_mask) & np.logical_not(charge_mask)
|
| for bbox, bond_type, score in zip(output['bbox'][bond_mask],
|
| output['pred_classes'][bond_mask],
|
| output['scores'][bond_mask]):
|
|
|
|
|
| if len(idx_to_labels)==23:
|
| if idx_to_labels[bond_type] in ['-','SINGLE', 'NONE', 'ENDUPRIGHT', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT']:
|
| _margin = 5
|
| else:
|
| _margin = 8
|
| elif len(idx_to_labels)==30:
|
| _margin=0
|
| elif len(idx_to_labels)==24:
|
| _margin=0
|
|
|
| anchor_positions = (bbox + [_margin, _margin, -_margin, -_margin]).reshape([2, -1])
|
| oposite_anchor_positions = anchor_positions.copy()
|
| oposite_anchor_positions[:, 1] = oposite_anchor_positions[:, 1][::-1]
|
|
|
|
|
|
|
| anchor_positions = np.concatenate([anchor_positions, oposite_anchor_positions])
|
|
|
|
|
| atoms_pos = atoms_list[['x', 'y']].values
|
| kdt = cKDTree(atoms_pos)
|
| dists, neighbours = kdt.query(anchor_positions, k=1)
|
|
|
|
|
| if np.argmin((dists[0] + dists[1], dists[2] + dists[3])) == 0:
|
|
|
| begin_idx, end_idx = neighbours[:2]
|
| else:
|
|
|
| begin_idx, end_idx = neighbours[2:]
|
|
|
|
|
| if begin_idx != end_idx:
|
| if bond_type in bond_labels:
|
| bonds_list.append((begin_idx, end_idx, idx_to_labels[bond_type], idx_to_labels[bond_type], score))
|
| else:
|
| print(f'this box may be charges box not bonds {[bbox, bond_type, score ]}')
|
| else:
|
| continue
|
|
|
|
|
| return atoms_list, bonds_list,charge_list
|
|
|
| def parse_atom(node):
|
| s10 = [str(x) for x in range(10)]
|
|
|
| if 'other' in node:
|
| a = '*'
|
| if '-' in node or '+' in node:
|
| fc = -1 if node[-1] == '-' else 1
|
| else:
|
| fc = int(node[-2:]) if node[-2:] in s10 else 0
|
| elif node[-1] in s10:
|
| if '-' in node or '+' in node:
|
| fc = -1 if node[-1] == '-' else 1
|
| a = node[:-1]
|
| else:
|
| a = node[:-1]
|
| fc = int(node[-1])
|
| elif node[-1] == '+':
|
| a = node[:-1]
|
| fc = 1
|
| elif node[-1] == '-':
|
| a = node[:-1]
|
| fc = -1
|
| else:
|
| a = node
|
| fc = 0
|
| return a, fc
|
|
|
|
|
|
|
| def iou_(box1, box2):
|
| """
|
| 计算两个框的 IoU(Intersection over Union)。
|
| 参数:
|
| box1, box2: [x1, y1, x2, y2] 格式的框坐标
|
|
|
| 返回:
|
| float: IoU 值
|
| """
|
| x1 = max(box1[0], box2[0])
|
| y1 = max(box1[1], box2[1])
|
| x2 = min(box1[2], box2[2])
|
| y2 = min(box1[3], box2[3])
|
|
|
| intersection = max(0, x2 - x1) * max(0, y2 - y1)
|
| area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
| area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
| union = area1 + area2 - intersection
|
| return intersection / union if union > 0 else 0
|
|
|
|
|
| def calculate_iou(bbox1, bbox2):
|
|
|
| x_min1, y_min1, x_max1, y_max1 = bbox1
|
| x_min2, y_min2, x_max2, y_max2 = bbox2
|
|
|
|
|
| x_min_inter = max(x_min1, x_min2)
|
| y_min_inter = max(y_min1, y_min2)
|
| x_max_inter = min(x_max1, x_max2)
|
| y_max_inter = min(y_max1, y_max2)
|
|
|
|
|
| inter_width = max(0, x_max_inter - x_min_inter)
|
| inter_height = max(0, y_max_inter - y_min_inter)
|
| inter_area = inter_width * inter_height
|
|
|
|
|
| area1 = (x_max1 - x_min1) * (y_max1 - y_min1)
|
| area2 = (x_max2 - x_min2) * (y_max2 - y_min2)
|
|
|
|
|
| union_area = area1 + area2 - inter_area
|
|
|
|
|
| iou = inter_area / union_area if union_area > 0 else 0
|
|
|
|
|
| result = []
|
| if iou == 0:
|
| result.append("无重叠")
|
| elif iou > 0:
|
| result.append("有重叠")
|
| if iou == 1:
|
| result.append("完全重合")
|
| elif inter_area == area2:
|
| result.append("bbox1 包含 bbox2")
|
| elif inter_area == area1:
|
| result.append("bbox2 包含 bbox1")
|
|
|
| return iou, result, inter_area, union_area
|
|
|
| def adjust_bbox1(large_bbox, small_bbox, bond_bbox):
|
|
|
|
|
| x_min_l, y_min_l, x_max_l, y_max_l = large_bbox
|
| x_min_s, y_min_s, x_max_s, y_max_s = small_bbox
|
| x_min_b, y_min_b, x_max_b, y_max_b = bond_bbox
|
| scaled_box= max([x_min_l,x_min_s,x_min_b]),max([y_min_l,y_min_s,y_min_b]),x_max_l, y_max_l
|
| return large_bbox
|
|
|
|
|
|
|
|
|
|
|
|
|
| def nms_per_class(labels, boxes, scores, iou_thresh=0.5):
|
| """
|
| 对每个类别应用 NMS,保留得分最高的框。
|
| 参数:
|
| labels: numpy array,类别标签
|
| boxes: numpy array,框坐标 [x1, y1, x2, y2]
|
| scores: numpy array,得分
|
| iou_thresh: float,IoU 阈值
|
| 返回:
|
| dict: 筛选后的输出
|
| """
|
|
|
| unique_labels = np.unique(labels)
|
| kept_indices = []
|
| for label in unique_labels:
|
|
|
| class_mask = labels == label
|
| class_indices = np.where(class_mask)[0]
|
| class_boxes = boxes[class_mask]
|
| class_scores = scores[class_mask]
|
|
|
|
|
| order = np.argsort(class_scores)[::-1]
|
| class_boxes = class_boxes[order]
|
| class_scores = class_scores[order]
|
| class_indices = class_indices[order]
|
|
|
|
|
| keep = []
|
| while len(class_scores) > 0:
|
|
|
| keep.append(class_indices[0])
|
| if len(class_scores) == 1:
|
| break
|
|
|
|
|
| ious = np.array([calculate_iou(class_boxes[0], box) for box in class_boxes[1:]])
|
|
|
| keep_mask = ious < iou_thresh
|
| class_boxes = class_boxes[1:][keep_mask]
|
| class_scores = class_scores[1:][keep_mask]
|
| class_indices = class_indices[1:][keep_mask]
|
|
|
| kept_indices.extend(keep)
|
|
|
|
|
| kept_indices = np.array(kept_indices)
|
| return {
|
| 'labels': labels[kept_indices],
|
| 'boxes': boxes[kept_indices],
|
| 'scores': scores[kept_indices]
|
| }
|
|
|
|
|