StructureBatch
A batch of protein structures.
This class provides an interface for initializing and representing a batch of protein structures using several types of representations.
A StructureBatch object can be initialized with:
- A single PDB file or a list of PDB files: StructureBatch.from_pdb
- A PDB identifier or a list of PDB identifiers: StructureBatch.from_pdb_id
- Backbone or full-atom 3D coordinates: StructureBatch.from_xyz
- Backbone orientations and translations: StructureBatch.from_backbone_orientations_translations
- Dihedral angles: StructureBatch.from_dihedrals (TODO)
from_xyz(xyz, atom_mask=None, chain_idx=None, chain_ids=None, seq=None, **kwargs)
classmethod
Initialize a StructureBatch from a 3D atom coordinate array.
Examples:
Initialize a StructureBatch object from a NumPy array of 3D atom coordinates.
>>> import numpy as np
>>> batch_size, n_max_res, n_max_atoms = 2, 10, 25
>>> xyz = np.random.randn(batch_size, n_max_res, n_max_atoms, 3)
>>> sb = StructureBatch.from_xyz(xyz)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
xyz | Union[np.ndarray, torch.Tensor] | Shape: (batch_size, num_residues, num_atoms, 3) | required |
atom_mask | Union[np.ndarray, torch.Tensor] | Shape: (batch_size, num_residues, num_atoms) | None |
chain_idx | Union[np.ndarray, torch.Tensor] | Chain index for each residue, starting from zero. Defaults to None. Shape: (batch_size, num_residues) | None |
chain_ids | List[List[str]] | A list of unique chain IDs for each protein. | None |
seq | List[Dict[str, str]] | A list of dictionaries containing sequence information for each chain. | None |
Returns:
Name | Type | Description |
---|---|---|
StructureBatch | StructureBatch | A StructureBatch object. |
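The optional arguments are easier to picture with concrete arrays. The following is a minimal sketch that only builds inputs with the documented shapes using NumPy. The padding scheme (a False mask for absent atoms), the two-chain layout, and the idea that each `seq` dictionary maps a chain ID to its sequence are illustrative assumptions, not confirmed behavior of the library.

```python
import numpy as np

batch_size, n_max_res, n_max_atoms = 2, 10, 25

# Random coordinates standing in for real atom positions.
rng = np.random.default_rng(0)
xyz = rng.standard_normal((batch_size, n_max_res, n_max_atoms, 3))

# Assumption: True marks atoms that are present. Here the second structure
# is treated as having only 8 residues, so its last 2 residues are padding.
atom_mask = np.ones((batch_size, n_max_res, n_max_atoms), dtype=bool)
atom_mask[1, 8:] = False

# Zero-based chain index per residue: structure 0 has two chains of
# 5 residues each, structure 1 has a single chain.
chain_idx = np.zeros((batch_size, n_max_res), dtype=np.int64)
chain_idx[0, 5:] = 1

# One list of unique chain IDs per structure.
chain_ids = [["A", "B"], ["A"]]

# Assumption: each dictionary maps a chain ID to that chain's sequence.
seq = [{"A": "MKTAY", "B": "GSHMA"}, {"A": "MKTAYIAK"}]

# These would then be passed as
# StructureBatch.from_xyz(xyz, atom_mask=atom_mask, chain_idx=chain_idx,
#                         chain_ids=chain_ids, seq=seq)
```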
from_pdb(pdb_path, **kwargs)
classmethod
Initialize a StructureBatch from a PDB file or a list of PDB files.
Examples:
Initialize a StructureBatch object from a single PDB file,
>>> pdb_path = '1a0a.pdb'
>>> sb = StructureBatch.from_pdb(pdb_path)
or with a list of PDB files.
>>> pdb_paths = ['1a0a.pdb', '1a0b.pdb']
>>> sb = StructureBatch.from_pdb(pdb_paths)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pdb_path | Union[str, List[str]] | Path to a PDB file or a list of paths to PDB files. | required |
Returns:
Name | Type | Description |
---|---|---|
StructureBatch | StructureBatch | A StructureBatch object. |
from_pdb_id(pdb_id, **kwargs)
classmethod
Initialize a StructureBatch from a PDB ID or a list of PDB IDs.
Examples:
>>> pdb_id = "2ZIL" # Human lysozyme
>>> sb = StructureBatch.from_pdb_id(pdb_id)
>>> xyz = sb.get_xyz()
>>> xyz.shape
torch.Size([1, 130, 15, 3])
>>> dihedrals, dihedral_mask = sb.backbone_dihedrals()
>>> dihedrals.shape
torch.Size([1, 130, 3])
>>> dihedral_mask.shape
torch.Size([1, 130, 3])
>>> dihedral_mask.sum()
tensor(3)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pdb_id | Union[str, List[str]] | A PDB identifier or a list of PDB identifiers. | required |
Returns:
Name | Type | Description |
---|---|---|
StructureBatch | StructureBatch | A StructureBatch object. |
from_backbone_orientations_translations(orientations, translations, chain_idx=None, chain_ids=None, seq=None, include_cb=False, **kwargs)
classmethod
Initialize a StructureBatch from an array of backbone orientations and translations.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
orientations | Union[np.ndarray, torch.Tensor] | Shape: (batch_size, num_residues, 3, 3) | required |
translations | Union[np.ndarray, torch.Tensor] | Shape: (batch_size, num_residues, 3) | required |
chain_idx | Union[np.ndarray, torch.Tensor] | Chain index for each residue, starting from zero. Defaults to None. Shape: (batch_size, num_residues) | None |
chain_ids | List[List[str]] | A list of unique chain IDs for each protein. | None |
seq | List[Dict[str, str]] | A list of dictionaries containing sequence information for each chain. | None |
include_cb | bool | Whether to include CB atoms when initializing. Defaults to False. | False |
Returns:
Name | Type | Description |
---|---|---|
StructureBatch | StructureBatch | A StructureBatch object. |
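As a rough illustration of inputs with the documented shapes, the sketch below generates random proper rotation matrices and translations with NumPy. The helper `random_rotations` is hypothetical and not part of the library; it only shows what a valid (batch_size, num_residues, 3, 3) orientation array could look like.

```python
import numpy as np

def random_rotations(batch_size, num_residues, seed=0):
    """Return random proper rotation matrices of shape (B, L, 3, 3)."""
    rng = np.random.default_rng(seed)
    mats = rng.standard_normal((batch_size * num_residues, 3, 3))
    # The Q factor of a QR decomposition is an orthogonal matrix.
    q = np.stack([np.linalg.qr(m)[0] for m in mats])
    # Flip one column where det(Q) = -1 so that every matrix is a rotation.
    q[:, :, 0] *= np.sign(np.linalg.det(q))[:, None]
    return q.reshape(batch_size, num_residues, 3, 3)

batch_size, num_residues = 2, 10
orientations = random_rotations(batch_size, num_residues)
translations = np.random.default_rng(1).standard_normal((batch_size, num_residues, 3))

# These would then be passed as
# StructureBatch.from_backbone_orientations_translations(orientations, translations)
```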
from_dihedrals(dihedrals, chain_idx=None, chain_ids=None, **kwargs)
classmethod
Initialize a StructureBatch from a dihedral angle array.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dihedrals | Union[np.ndarray, torch.Tensor] | Shape: (batch_size, num_residues, num_dihedrals) | required |
chain_idx | Union[np.ndarray, torch.Tensor] | Chain index for each residue, starting from zero. Defaults to None. Shape: (batch_size, num_residues) | None |
chain_ids | List[List[str]] | A list of unique chain IDs for each protein. | None |
get_local_xyz()
Return the coordinates of each atom in the local frame of each residue.
Returns:
Name | Type | Description |
---|---|---|
local_xyz | torch.Tensor | Shape: (batch_size, num_residues, num_atoms_per_residue, 3) |
get_atom_mask()
Return a boolean mask for valid atoms.
Returns:
Name | Type | Description |
---|---|---|
atom_mask | torch.BoolTensor | Shape: (batch_size, num_residues, num_atoms) |
get_residue_mask()
Return a boolean mask for valid residues.
Returns:
Name | Type | Description |
---|---|---|
residue_mask | torch.BoolTensor | Shape: (batch_size, num_residues) |
get_seq()
Return the amino acid sequence of proteins.
Returns:
Name | Type | Description |
---|---|---|
seq_dict | List[Dict[str, str]] | A list of dictionaries containing sequence information for each chain. |
get_seq_idx()
Return a tensor containing the integer representation of amino acid sequence of proteins.
Returns:
Name | Type | Description |
---|---|---|
seq_idx | torch.LongTensor | A tensor containing the integer representation of amino acid sequence of proteins. |
get_total_lengths()
Return the total sum of chain lengths for each protein.
Note
The total includes missing residues in the middle of a chain, but not missing residues at the beginning or end of a chain.
Returns:
Name | Type | Description |
---|---|---|
total_lengths | torch.LongTensor | A tensor containing the total length of each protein. Shape: (batch_size,) |
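The counting rule in the note can be illustrated with a small worked example. This is a sketch of the rule as stated, assuming each chain is described by the residue numbers that were actually observed; the helper name and the input format are hypothetical, and the library may derive lengths differently.

```python
def chain_length_with_internal_gaps(observed_residue_numbers):
    """Length spanned from the first to the last observed residue."""
    first = min(observed_residue_numbers)
    last = max(observed_residue_numbers)
    return last - first + 1

# Chain A: residues 1-2 and 9-10 are unresolved termini (not counted),
# residues 5-6 are an internal gap (counted).
chain_a = [3, 4, 7, 8]
# Chain B: fully resolved.
chain_b = [1, 2, 3, 4, 5]

total_length = sum(chain_length_with_internal_gaps(c) for c in (chain_a, chain_b))
```

Here chain A contributes 6 (residues 3 through 8, including the two missing internal residues) and chain B contributes 5.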
get_max_n_residues()
Return the number of residues in the longest protein in the batch.
Returns:
Name | Type | Description |
---|---|---|
max_n_residues | int | The number of residues in the longest protein in the batch. |
get_n_terminal_mask()
Return a boolean mask for the N-terminal residues.
Returns:
Type | Description |
---|---|
torch.BoolTensor | A boolean tensor denoting N-terminal residues. |
get_c_terminal_mask()
Return a boolean mask for the C-terminal residues.
Returns:
Type | Description |
---|---|
torch.BoolTensor | A boolean tensor denoting C-terminal residues. |
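One way to think about the terminal masks is in terms of per-residue chain indices: a residue is N-terminal if the previous position belongs to a different chain (or does not exist), and C-terminal if the next position does. The NumPy sketch below follows that idea; it is an assumption about the semantics, not the library's implementation, and the function name is hypothetical.

```python
import numpy as np

def terminal_masks(chain_idx, residue_mask):
    """Return (n_terminal_mask, c_terminal_mask), each of shape (B, L)."""
    valid = residue_mask.astype(bool)
    # Padding positions get chain index -1 so they never match a real chain.
    idx = np.where(valid, chain_idx, -1)
    # Chain index of the previous / next position, with -1 past the ends.
    prev_idx = np.pad(idx, ((0, 0), (1, 0)), constant_values=-1)[:, :-1]
    next_idx = np.pad(idx, ((0, 0), (0, 1)), constant_values=-1)[:, 1:]
    n_term = valid & (idx != prev_idx)
    c_term = valid & (idx != next_idx)
    return n_term, c_term

chain_idx = np.array([[0, 0, 0, 1, 1], [0, 0, 0, 0, 0]])
residue_mask = np.array([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])
n_term, c_term = terminal_masks(chain_idx, residue_mask)
```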
pairwise_distance_matrix()
Return the all-atom pairwise distance matrix between residues.
Info
Distances are measured in Angstroms.
Examples:
dist[:, :, :, 1, 1] gives the pairwise alpha-carbon distance matrix between residues, as index 1 corresponds to the alpha-carbon atom.
>>> structure_batch = StructureBatch.from_pdb("1a8o.pdb")
>>> dist = structure_batch.pairwise_distance_matrix()
>>> ca_dist = dist[:, :, :, 1, 1] # 1 = CA_IDX
Returns:
Name | Type | Description |
---|---|---|
dist | torch.FloatTensor | A tensor containing an all-atom pairwise distance matrix for each pair of residues. A distance between atom |
dist_mask | torch.BoolTensor | A boolean tensor denoting which distances are valid. Shape: (batch_size, num_residues, num_residues, max_n_atoms_per_residue, max_n_atoms_per_residue) |
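The five-dimensional layout of the result can be reproduced with plain broadcasting. The following NumPy sketch is an independent illustration of that layout (not the library's code), assuming a distance is valid only when both atoms are present in the atom mask.

```python
import numpy as np

def pairwise_distance_matrix(xyz, atom_mask):
    """xyz: (B, L, A, 3), atom_mask: (B, L, A) -> dist, mask: (B, L, L, A, A)."""
    # dist[b, i, j, a, c] is the distance between atom a of residue i
    # and atom c of residue j in structure b.
    diff = xyz[:, :, None, :, None, :] - xyz[:, None, :, None, :, :]
    dist = np.linalg.norm(diff, axis=-1)
    # Assumption: a distance is valid only if both atoms are present.
    mask = atom_mask[:, :, None, :, None] & atom_mask[:, None, :, None, :]
    return dist, mask

rng = np.random.default_rng(0)
xyz = rng.standard_normal((1, 4, 3, 3))
atom_mask = np.ones((1, 4, 3), dtype=bool)
dist, dist_mask = pairwise_distance_matrix(xyz, atom_mask)
ca_dist = dist[:, :, :, 1, 1]  # index 1 = alpha-carbon, as in the example above
```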
backbone_dihedrals()
Return the backbone dihedral angles phi, psi and omega for each residue.
Info
Dihedral angles are measured in radians and are in the range [-pi, pi].
For a quick reminder of the definition of the dihedral angles, refer to the following image:
[Image: definition of the backbone dihedral angles. Source: Fabian Fuchs]
Note
phi angles are not defined for the first residue (it needs a predecessor), and psi and omega angles are not defined for the last residue (they need successors).
These invalid angles can be filtered out using the dihedral_mask tensor returned by the method.
Warning
Dihedral angles involving residues at chain breaks are not yet handled correctly.
Returns:
Name | Type | Description |
---|---|---|
dihedrals | torch.FloatTensor | A tensor containing |
dihedral_mask | torch.FloatTensor | A tensor containing a boolean mask for the dihedral angles. |
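For readers who want to see how phi, psi and omega follow from backbone coordinates, here is a minimal NumPy sketch for a single unbroken chain. It uses the standard four-point dihedral formula with `arctan2`, which yields values in [-pi, pi]. The function names, the (phi, psi, omega) column order, and the convention that True marks a defined angle are assumptions for illustration; the library's own mask convention may differ.

```python
import numpy as np

def dihedral(p0, p1, p2, p3):
    """Signed dihedral angle in radians for points of shape (..., 3)."""
    b0 = p0 - p1
    b1 = p2 - p1
    b2 = p3 - p2
    b1 = b1 / np.linalg.norm(b1, axis=-1, keepdims=True)
    # Components of b0 and b2 perpendicular to the central bond b1.
    v = b0 - np.sum(b0 * b1, axis=-1, keepdims=True) * b1
    w = b2 - np.sum(b2 * b1, axis=-1, keepdims=True) * b1
    x = np.sum(v * w, axis=-1)
    y = np.sum(np.cross(b1, v) * w, axis=-1)
    return np.arctan2(y, x)

def backbone_dihedrals(n, ca, c):
    """phi, psi, omega for one chain; n, ca, c have shape (L, 3)."""
    num_residues = n.shape[0]
    angles = np.zeros((num_residues, 3))
    mask = np.zeros((num_residues, 3), dtype=bool)
    angles[1:, 0] = dihedral(c[:-1], n[1:], ca[1:], c[1:])      # phi
    angles[:-1, 1] = dihedral(n[:-1], ca[:-1], c[:-1], n[1:])   # psi
    angles[:-1, 2] = dihedral(ca[:-1], c[:-1], n[1:], ca[1:])   # omega
    mask[1:, 0] = True     # phi needs a predecessor
    mask[:-1, 1:] = True   # psi and omega need a successor
    return angles, mask
```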
backbone_orientations(a1='N', a2='CA', a3='C')
Return the orientation of the backbone for each residue.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
a1 | str | First atom used to determine backbone orientation. Defaults to 'N'. | 'N' |
a2 | str | Second atom used to determine backbone orientation. Defaults to 'CA'. | 'CA' |
a3 | str | Third atom used to determine backbone orientation. Defaults to 'C'. | 'C' |
Note
The backbone orientations are determined by applying Gram-Schmidt orthogonalization to the vectors a3 - a2 and a1 - a2.
Note that a3 - a2 forms the first basis vector, and a1 - a2 - proj_{a3 - a2}(a1 - a2) forms the second basis vector. The third basis vector is formed by taking the cross product of the first and second basis vectors.
Returns:
Name | Type | Description |
---|---|---|
bb_orientations | torch.FloatTensor | A tensor containing the local reference backbone orientation for each residue. |
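The Gram-Schmidt construction described in the note can be sketched in a few lines of NumPy. This is an illustration under two assumptions that the text does not state: each basis vector is normalized to unit length, and the three vectors are stacked as the columns of the 3x3 matrix. The function name is hypothetical.

```python
import numpy as np

def backbone_orientations(n, ca, c):
    """n, ca, c: (..., 3) coordinates for a1='N', a2='CA', a3='C'."""
    v1 = c - ca   # a3 - a2
    v2 = n - ca   # a1 - a2
    # First basis vector: direction of a3 - a2.
    e1 = v1 / np.linalg.norm(v1, axis=-1, keepdims=True)
    # Second basis vector: a1 - a2 minus its projection onto a3 - a2.
    u2 = v2 - np.sum(v2 * e1, axis=-1, keepdims=True) * e1
    e2 = u2 / np.linalg.norm(u2, axis=-1, keepdims=True)
    # Third basis vector: cross product of the first two.
    e3 = np.cross(e1, e2)
    # Assumption: basis vectors form the columns of each orientation matrix.
    return np.stack([e1, e2, e3], axis=-1)

rng = np.random.default_rng(0)
n, ca, c = rng.standard_normal((3, 2, 10, 3))
orientations = backbone_orientations(n, ca, c)  # shape (2, 10, 3, 3)
```

With normalized vectors and a cross product for the third axis, each matrix is a proper rotation (orthonormal with determinant +1).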
backbone_translations(atom='CA')
Return the coordinate (translation) of a given backbone atom for each residue.
Note
Reference atom is set to the alpha-carbon (CA) by default.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
atom | str | Type of atom used to determine backbone translation. Defaults to 'CA'. | 'CA' |
Returns:
Name | Type | Description |
---|---|---|
bb_translations | torch.FloatTensor | xyz coordinates (translations) of the specified backbone atom. Shape: (batch_size, num_residues, 3) |
pairwise_dihedrals(atoms_i, atoms_j)
Return a matrix representing a pairwise dihedral angle between residues defined by two sets of atoms, one for each side of the residue.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
atoms_i | List[str] | List of atoms to be used for the first residue. | required |
atoms_j | List[str] | List of atoms to be used for the second residue. | required |
Returns:
Name | Type | Description |
---|---|---|
pairwise_dihedrals | torch.FloatTensor | A tensor containing pairwise dihedral angles between residues. Shape: (batch_size, num_residues, num_residues) |
pairwise_planar_angles(atoms_i, atoms_j)
Return a matrix of pairwise planar angles between residues defined by two sets of atoms, one for each side of the residue.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
atoms_i | List[str] | List of atoms to be used for the first residue. | required |
atoms_j | List[str] | List of atoms to be used for the second residue. | required |
Returns:
Name | Type | Description |
---|---|---|
pairwise_planar_angles | torch.FloatTensor | A tensor containing pairwise planar angles between residues. Shape: (batch_size, num_residues, num_residues) |
translate(translation, atomwise=False)
Translate the structures by a given tensor of shape (batch_size, num_residues, 3) or (batch_size, 1, 3). Translation is performed residue-wise by default, but atom-wise translation can be performed when atomwise=True.
In that case, the translation tensor should have a shape of (batch_size, num_residues, num_atoms, 3).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
translation | torch.Tensor | Translation vector. Shape: (batch_size, num_residues, 3) if | required |
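The shape rules above map directly onto array broadcasting. The NumPy sketch below is a stand-alone illustration of those rules, not the method's actual implementation; it returns a new array rather than modifying a batch object.

```python
import numpy as np

def translate(xyz, translation, atomwise=False):
    """xyz: (B, L, A, 3). Returns translated coordinates of the same shape."""
    if atomwise:
        # translation: (B, L, A, 3), one vector per atom.
        return xyz + translation
    # translation: (B, L, 3) or (B, 1, 3). Adding an atom axis lets every
    # atom in a residue share that residue's translation vector.
    return xyz + translation[:, :, None, :]

xyz = np.zeros((2, 4, 5, 3))
per_structure = translate(xyz, np.ones((2, 1, 3)))                  # same shift everywhere
per_residue = translate(xyz, np.arange(24.0).reshape(2, 4, 3))      # one shift per residue
per_atom = translate(xyz, np.ones((2, 4, 5, 3)), atomwise=True)     # one shift per atom
```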
rotate(rotation)
Rotate the structures by a given rotation matrix of shape (batch_size, 3, 3).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
rotation | torch.Tensor | Rotation matrix. Shape: (batch_size, 3, 3) if a rotation is applied structure-by-structure, or (3, 3) if the same rotation is to be applied to all structures. | required |
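The two accepted shapes can be handled with a single `einsum` each. The sketch below is an independent NumPy illustration; whether the library multiplies coordinates as R @ x (assumed here) or x @ R is not stated in this page.

```python
import numpy as np

def rotate(xyz, rotation):
    """xyz: (B, L, A, 3); rotation: (B, 3, 3) or (3, 3)."""
    # Assumption: each coordinate x is treated as a column vector and mapped to R @ x.
    if rotation.ndim == 2:
        # Same rotation for every structure in the batch.
        return np.einsum("ij,blaj->blai", rotation, xyz)
    # One rotation per structure.
    return np.einsum("bij,blaj->blai", rotation, xyz)

# 90-degree rotation about the z-axis.
rot_z = np.array([[0.0, -1.0, 0.0],
                  [1.0, 0.0, 0.0],
                  [0.0, 0.0, 1.0]])
xyz = np.array([1.0, 0.0, 0.0]).reshape(1, 1, 1, 3)
rotated_shared = rotate(xyz, rot_z)          # (3, 3) input
rotated_batched = rotate(xyz, rot_z[None])   # (1, 3, 3) input
```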
standardize(atom_mask=None, residue_mask=None)
Standardize the coordinates of the structures to have zero mean and unit standard deviation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
atom_mask | bool | Mask for atoms used for standardization. If None, all atoms are used. | None |
residue_mask | bool | Mask for residues used for standardization. If None, all residues are used. | None |
unstandardize()
Recover the coordinates at original scale from the standardized coordinates.
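The standardize and unstandardize pair can be pictured as a masked affine transform and its inverse. The sketch below is one possible reading, assuming statistics are pooled over all masked atoms in the batch, with a per-axis mean and a single scalar standard deviation; the library may compute them differently (for example per structure), and the function signatures here are hypothetical.

```python
import numpy as np

def standardize(xyz, atom_mask):
    """Return standardized coordinates plus the statistics needed to undo it."""
    valid = xyz[atom_mask]            # (n_valid_atoms, 3)
    mean = valid.mean(axis=0)         # per-axis mean, shape (3,)
    std = (valid - mean).std()        # single scalar scale
    return (xyz - mean) / std, mean, std

def unstandardize(xyz_standardized, mean, std):
    """Invert standardize() using the stored mean and scale."""
    return xyz_standardized * std + mean

rng = np.random.default_rng(0)
xyz = rng.standard_normal((2, 6, 4, 3)) * 5.0 + 10.0
atom_mask = np.ones((2, 6, 4), dtype=bool)
atom_mask[1, 4:] = False  # treat the last residues of structure 1 as padding

xyz_std, mean, std = standardize(xyz, atom_mask)
xyz_back = unstandardize(xyz_std, mean, std)
```

Keeping the mean and scale around is what makes the original coordinates recoverable afterwards.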
center_of_mass()
Compute the center of mass of the structures.
Warning
Only CA (alpha-carbon) atoms are considered when computing the center of mass.
Returns:
Name | Type | Description |
---|---|---|
center_of_mass | torch.Tensor | A tensor containing the center of mass of the structures. Shape: (batch_size, 3) |
center_at(center=None)
Translate the whole structure so that the center of the CA atom coordinates is at the given 3D coordinates. If center is not specified, the structures (considering only CA coordinates) are centered at the origin.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
center | torch.Tensor | Coordinates of the center. Shape: (batch_size, 3) or (3,) | None |
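Centering amounts to computing the mean CA coordinate of each structure and shifting every atom by the difference to the target. The NumPy sketch below illustrates this under the simplifying assumption that all residues are valid (no padding mask is applied); `CA_IDX = 1` follows the atom index used in the pairwise_distance_matrix example, and the function is a stand-alone illustration rather than the method itself.

```python
import numpy as np

CA_IDX = 1  # index of the alpha-carbon atom

def center_at(xyz, center=None):
    """xyz: (B, L, A, 3); center: (B, 3), (3,) or None (origin)."""
    # Mean CA coordinate per structure, shape (B, 3).
    ca_center = xyz[:, :, CA_IDX, :].mean(axis=1)
    target = np.zeros(3) if center is None else np.asarray(center, dtype=float)
    # Broadcasting handles both a shared (3,) and a per-structure (B, 3) target.
    shift = target - ca_center
    return xyz + shift[:, None, None, :]

rng = np.random.default_rng(0)
xyz = rng.standard_normal((2, 8, 4, 3)) + 5.0
at_origin = center_at(xyz)
at_point = center_at(xyz, center=[1.0, 2.0, 3.0])
```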
inter_residue_geometry()
Return a dictionary of inter-residue geometry, which is used to represent protein structures in trRosetta.
Returns:
Name | Type | Description |
---|---|---|
inter_residue_geometry | Dict[str, torch.Tensor] | A dictionary containing inter-residue geometry tensors. |