StructureBatch

A batch of protein structures.

This class provides a common interface for constructing and representing a batch of protein structures from several kinds of input.

A StructureBatch object can be initialized from:
  • A single PDB file or a list of PDB files: StructureBatch.from_pdb
  • A PDB identifier or a list of PDB identifiers: StructureBatch.from_pdb_id
  • Backbone or full-atom 3D coordinates: StructureBatch.from_xyz
  • Backbone orientations and translations: StructureBatch.from_backbone_orientations_translations
  • Dihedral angles: StructureBatch.from_dihedrals (TODO)

from_xyz(xyz, atom_mask=None, chain_idx=None, chain_ids=None, seq=None, **kwargs) classmethod

Initialize a StructureBatch from a 3D atom coordinate array.

Examples:

Initialize a StructureBatch object from a numpy array of 3D atom coordinates.

>>> import numpy as np
>>> batch_size, n_max_res, n_max_atoms = 2, 10, 25
>>> xyz = np.random.randn(batch_size, n_max_res, n_max_atoms, 3)
>>> sb = StructureBatch.from_xyz(xyz)

Parameters:

Name Type Description Default
xyz Union[np.ndarray, torch.Tensor]

Shape: (batch_size, num_residues, num_atoms, 3)

required
atom_mask Union[np.ndarray, torch.Tensor]

Shape: (batch_size, num_residues, num_atoms)

None
chain_idx Union[np.ndarray, torch.Tensor]

Chain index of each residue. Indices start from zero. Shape: (batch_size, num_residues)

None
chain_ids List[List[str]]

A list of unique chain IDs for each protein.

None
seq List[Dict[str, str]]

A list of dictionaries containing sequence information for each chain.

None

Returns:

Name Type Description
StructureBatch StructureBatch

A StructureBatch object.
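As an illustration of the input layout described above, the following sketch pads two proteins of different lengths into one coordinate array and builds the matching atom_mask. The helper name `pad_batch` and the use of zeros as padding values are assumptions made for this example; they are not part of StructureBatch.

```python
import numpy as np

def pad_batch(coords_list, n_max_atoms):
    """Stack per-protein arrays of shape (n_res, n_atoms, 3) into padded batch arrays.

    Hypothetical helper for illustration only.
    """
    batch_size = len(coords_list)
    n_max_res = max(c.shape[0] for c in coords_list)
    xyz = np.zeros((batch_size, n_max_res, n_max_atoms, 3), dtype=np.float32)
    atom_mask = np.zeros((batch_size, n_max_res, n_max_atoms), dtype=bool)
    for b, coords in enumerate(coords_list):
        n_res, n_atoms = coords.shape[:2]
        xyz[b, :n_res, :n_atoms] = coords       # real coordinates
        atom_mask[b, :n_res, :n_atoms] = True   # mark them as valid
    return xyz, atom_mask

rng = np.random.default_rng(0)
coords_list = [rng.normal(size=(8, 4, 3)), rng.normal(size=(10, 4, 3))]
xyz, atom_mask = pad_batch(coords_list, n_max_atoms=4)

print(xyz.shape)        # (batch_size, num_residues, num_atoms, 3)
print(atom_mask.shape)  # (batch_size, num_residues, num_atoms)
```

Arrays shaped like these match the xyz and atom_mask parameters documented above.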

from_pdb(pdb_path, **kwargs) classmethod

Initialize a StructureBatch from a PDB file or a list of PDB files.

Examples:

Initialize a StructureBatch object from a single PDB file,

>>> pdb_path = '1a0a.pdb'
>>> sb = StructureBatch.from_pdb(pdb_path)

or with a list of PDB files.

>>> pdb_paths = ['1a0a.pdb', '1a0b.pdb']
>>> sb = StructureBatch.from_pdb(pdb_paths)

Parameters:

Name Type Description Default
pdb_path Union[str, List[str]]

Path to a PDB file or a list of paths to PDB files.

required

Returns:

Name Type Description
StructureBatch StructureBatch

A StructureBatch object.

from_pdb_id(pdb_id, **kwargs) classmethod

Initialize a StructureBatch from a PDB ID or a list of PDB IDs.

Examples:

>>> pdb_id = "2ZIL"  # Human lysozyme
>>> sb = StructureBatch.from_pdb_id(pdb_id)
>>> xyz = sb.get_xyz()
>>> xyz.shape
torch.Size([1, 130, 15, 3])
>>> dihedrals, dihedral_mask = sb.backbone_dihedrals()
>>> dihedrals.shape
torch.Size([1, 130, 3])
>>> dihedral_mask.shape
torch.Size([1, 130, 3])
>>> dihedral_mask.sum()
tensor(3)

Parameters:

Name Type Description Default
pdb_id Union[str, List[str]]

A PDB identifier or a list of PDB identifiers.

required

Returns:

Name Type Description
StructureBatch StructureBatch

A StructureBatch object.

from_backbone_orientations_translations(orientations, translations, chain_idx=None, chain_ids=None, seq=None, include_cb=False, **kwargs) classmethod

Initialize a StructureBatch from an array of backbone orientations and translations.

Parameters:

Name Type Description Default
orientations Union[np.ndarray, torch.Tensor]

Shape: (batch_size, num_residues, 3, 3)

required
translations Union[np.ndarray, torch.Tensor]

Shape: (batch_size, num_residues, 3)

required
chain_idx Union[np.ndarray, torch.Tensor]

Chain index of each residue. Indices start from zero. Shape: (batch_size, num_residues)

None
chain_ids List[List[str]]

A list of unique chain IDs for each protein.

None
seq List[Dict[str, str]]

A list of dictionaries containing sequence information for each chain.

None
include_cb bool

Whether to include CB atoms when initializing. Defaults to False.

False

Returns:

Name Type Description
StructureBatch StructureBatch

A StructureBatch object.

from_dihedrals(dihedrals, chain_idx=None, chain_ids=None, **kwargs) classmethod

Initialize a StructureBatch from a dihedral angle array.

Parameters:

Name Type Description Default
dihedrals Union[np.ndarray, torch.Tensor]

Shape: (batch_size, num_residues, num_dihedrals)

required
chain_idx Union[np.ndarray, torch.Tensor]

Chain index of each residue. Indices start from zero. Shape: (batch_size, num_residues)

None
chain_ids List[List[str]]

A list of unique chain IDs for each protein.

None

get_local_xyz()

Return the coordinates of each atom in the local frame of each residue.

Returns:

Name Type Description
local_xyz torch.Tensor

Shape: (batch_size, num_residues, num_atoms_per_residue, 3)

get_atom_mask()

Return a boolean mask for valid atoms.

Returns:

Name Type Description
atom_mask torch.BoolTensor

Shape: (batch_size, num_residues, num_atoms)

get_residue_mask()

Return a boolean mask for valid residues.

Returns:

Name Type Description
residue_mask torch.BoolTensor

Shape: (batch_size, num_residues)

get_seq()

Return the amino acid sequence of proteins.

Returns:

Name Type Description
seq_dict List[Dict[str, str]]

A list of dictionaries containing sequence information for each chain.

get_seq_idx()

Return a tensor containing the integer representation of amino acid sequence of proteins.

Returns:

Name Type Description
seq_idx torch.LongTensor

A tensor containing the integer representation of amino acid sequence of proteins.

get_total_lengths()

Return the total sum of chain lengths for each protein.

Note

The count includes missing residues in the middle of a chain, but not missing residues at the beginning or end of a chain.

Returns:

Name Type Description
total_lengths torch.LongTensor

A tensor containing the total length of each protein. Shape: (batch_size,)

get_max_n_residues()

Return the number of residues in the longest protein in the batch.

Returns:

Name Type Description
max_n_residues int

The number of residues in the longest protein in the batch.

get_n_terminal_mask()

Return a boolean mask for the N-terminal residues.

Returns:

Type Description
torch.BoolTensor

A boolean tensor denoting N-terminal residues. True if N-terminal. Shape: (batch_size, num_residues)

get_c_terminal_mask()

Return a boolean mask for the C-terminal residues.

Returns:

Type Description
torch.BoolTensor

A boolean tensor denoting C-terminal residues. True if C-terminal. Shape: (batch_size, num_residues)

pairwise_distance_matrix()

Return the all-atom pairwise distance matrix between residues.

Info

Distances are measured in Angstroms.

Examples:

dist[:, :, :, 1, 1] gives the pairwise alpha-carbon distance matrix between residues, since index 1 corresponds to the alpha-carbon atom.

>>> structure_batch = StructureBatch.from_pdb("1a8o.pdb")
>>> dist, dist_mask = structure_batch.pairwise_distance_matrix()
>>> ca_dist = dist[:, :, :, 1, 1]  # 1 = CA_IDX

Returns:

Name Type Description
dist torch.FloatTensor

A tensor containing the all-atom pairwise distance matrix for each pair of residues. The distance between atom a of residue i and atom b of residue j in the protein at index batch_idx is given by dist[batch_idx, i, j, a, b]. Shape: (batch_size, num_residues, num_residues, max_n_atoms_per_residue, max_n_atoms_per_residue)

dist_mask torch.BoolTensor

A boolean tensor denoting which distances are valid. Shape: (batch_size, num_residues, num_residues, max_n_atoms_per_residue, max_n_atoms_per_residue)
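The indexing convention dist[batch_idx, i, j, a, b] can be reproduced with plain broadcasting. The sketch below is a NumPy stand-in written for illustration, not the library's implementation; the function name `all_atom_pairwise_distances` is hypothetical.

```python
import numpy as np

def all_atom_pairwise_distances(xyz, atom_mask):
    """xyz: (B, L, A, 3), atom_mask: (B, L, A). Returns dist and dist_mask of shape (B, L, L, A, A)."""
    # Place residue i / atom a on axes 1 and 3, residue j / atom b on axes 2 and 4.
    diff = xyz[:, :, None, :, None, :] - xyz[:, None, :, None, :, :]
    dist = np.linalg.norm(diff, axis=-1)
    # A distance is valid only if both atoms involved are valid.
    dist_mask = atom_mask[:, :, None, :, None] & atom_mask[:, None, :, None, :]
    return dist, dist_mask

rng = np.random.default_rng(0)
xyz = rng.normal(size=(1, 3, 2, 3))          # 1 protein, 3 residues, 2 atoms each
atom_mask = np.ones((1, 3, 2), dtype=bool)
atom_mask[0, 2, 0] = False                   # pretend one atom is missing

dist, dist_mask = all_atom_pairwise_distances(xyz, atom_mask)
ca_dist = dist[:, :, :, 1, 1]                # atom index 1 on both sides
print(dist.shape, ca_dist.shape)
```

Memory grows with num_residues squared times max_n_atoms_per_residue squared, so slicing to the atoms of interest before computing distances may be preferable for long chains.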

backbone_dihedrals()

Return the backbone dihedral angles phi, psi and omega for each residue.

Info

Dihedral angles are measured in radians and are in the range [-pi, pi].

For a quick reminder of the definition of the dihedral angles, refer to the accompanying figure (image: Dihedral; source: Fabian Fuchs).

Note

phi angles are not defined for the first residue (it needs a predecessor) and psi and omega angles are not defined for the last residue (they need successors). Those invalid angles can be filtered using the dihedral_mask tensor returned from the method.

Warning

Dihedral angles involving residues at chain breaks are not yet handled correctly.

Returns:

Name Type Description
dihedrals torch.FloatTensor

A tensor containing phi, psi and omega dihedral angles for each residue. Shape: (batch_size, num_residues, 3)

dihedral_mask torch.FloatTensor

A tensor containing a boolean mask for the dihedral angles. True if the corresponding dihedral angle is defined, False otherwise. Shape: (batch_size, num_residues, 3)
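To make the masking rule concrete, here is a minimal NumPy sketch of phi, psi and omega for a single unbroken chain. It assumes the first three atom slots are N, CA and C in that order (only the CA index of 1 is stated on this page), ignores padding and chain breaks, and uses hypothetical function names; it is not the library's implementation.

```python
import numpy as np

def dihedral(p0, p1, p2, p3):
    """Signed dihedral angle (radians, in [-pi, pi]) defined by four points along the last axis."""
    b0 = p0 - p1
    b1 = p2 - p1
    b2 = p3 - p2
    b1 = b1 / np.linalg.norm(b1, axis=-1, keepdims=True)
    # Components of b0 and b2 perpendicular to the central bond b1.
    v = b0 - (b0 * b1).sum(-1, keepdims=True) * b1
    w = b2 - (b2 * b1).sum(-1, keepdims=True) * b1
    x = (v * w).sum(-1)
    y = (np.cross(b1, v) * w).sum(-1)
    return np.arctan2(y, x)

def backbone_dihedrals_sketch(xyz):
    """xyz: (B, L, A, 3) with atoms 0, 1, 2 assumed to be N, CA, C."""
    n, ca, c = xyz[:, :, 0], xyz[:, :, 1], xyz[:, :, 2]
    batch_size, n_res = xyz.shape[:2]
    angles = np.zeros((batch_size, n_res, 3))
    mask = np.zeros((batch_size, n_res, 3), dtype=bool)
    # phi(i): C(i-1), N(i), CA(i), C(i); undefined for the first residue.
    angles[:, 1:, 0] = dihedral(c[:, :-1], n[:, 1:], ca[:, 1:], c[:, 1:])
    # psi(i): N(i), CA(i), C(i), N(i+1); undefined for the last residue.
    angles[:, :-1, 1] = dihedral(n[:, :-1], ca[:, :-1], c[:, :-1], n[:, 1:])
    # omega(i): CA(i), C(i), N(i+1), CA(i+1); undefined for the last residue.
    angles[:, :-1, 2] = dihedral(ca[:, :-1], c[:, :-1], n[:, 1:], ca[:, 1:])
    mask[:, 1:, 0] = True
    mask[:, :-1, 1:] = True
    return angles, mask

rng = np.random.default_rng(0)
xyz = rng.normal(size=(1, 5, 3, 3))
angles, mask = backbone_dihedrals_sketch(xyz)
valid_angles = angles[mask]   # keep only the defined angles
print(angles.shape, mask.sum())
```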

backbone_orientations(a1='N', a2='CA', a3='C')

Return the orientation of the backbone for each residue.

Parameters:

Name Type Description Default
a1 str

First atom used to determine backbone orientation. Defaults to 'N'.

'N'
a2 str

Second atom used to determine backbone orientation. Defaults to 'CA'.

'CA'
a3 str

Third atom used to determine backbone orientation. Defaults to 'C'.

'C'

Note

The backbone orientations are determined by Gram-Schmidt orthogonalization of the vectors a3 - a2 and a1 - a2: a3 - a2 gives the first basis vector, a1 - a2 - proj_{a3 - a2}(a1 - a2) gives the second, and the cross product of the first and second basis vectors gives the third.

Returns:

Name Type Description
bb_orientations torch.FloatTensor

A tensor containing the local reference backbone orientation for each residue.
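The Gram-Schmidt construction described in the note can be sketched in NumPy as follows. This is an illustration under stated assumptions, not the library's code: the function name `gram_schmidt_frames` is hypothetical, the basis vectors are normalized to unit length, and they are stacked as the columns of each 3x3 matrix (the library may use a different layout).

```python
import numpy as np

def gram_schmidt_frames(a1, a2, a3):
    """a1, a2, a3: arrays of shape (..., 3), e.g. N, CA, C coordinates. Returns (..., 3, 3)."""
    v1 = a3 - a2
    v2 = a1 - a2
    # First basis vector: direction of a3 - a2.
    e1 = v1 / np.linalg.norm(v1, axis=-1, keepdims=True)
    # Second basis vector: a1 - a2 minus its projection onto e1.
    u2 = v2 - (v2 * e1).sum(-1, keepdims=True) * e1
    e2 = u2 / np.linalg.norm(u2, axis=-1, keepdims=True)
    # Third basis vector: cross product of the first two.
    e3 = np.cross(e1, e2)
    return np.stack([e1, e2, e3], axis=-1)

rng = np.random.default_rng(0)
n, ca, c = rng.normal(size=(3, 2, 5, 3))   # three (batch_size, num_residues, 3) arrays
frames = gram_schmidt_frames(n, ca, c)
print(frames.shape)                         # (2, 5, 3, 3)
```

Because the third vector is the cross product of the first two, each frame is a proper rotation (orthonormal with determinant +1) as long as the three atoms are not collinear.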

backbone_translations(atom='CA')

Return the coordinate (translation) of a given backbone atom for each residue.

Note

Reference atom is set to the alpha-carbon (CA) by default.

Parameters:

Name Type Description Default
atom str

Type of atom used to determine backbone translation. Defaults to 'CA'.

'CA'

Returns:

Name Type Description
bb_translations torch.FloatTensor

xyz coordinates (translations) of the specified backbone atom. Shape: (batch_size, num_residues, 3)

pairwise_dihedrals(atoms_i, atoms_j)

Return a matrix of pairwise dihedral angles between residues, defined by two sets of atoms, one for each residue of the pair.

Parameters:

Name Type Description Default
atoms_i List[str]

List of atoms to be used for the first residue.

required
atoms_j List[str]

List of atoms to be used for the second residue.

required

Returns:

Name Type Description
pairwise_dihedrals torch.FloatTensor

A tensor containing pairwise dihedral angles between residues. Shape: (batch_size, num_residues, num_residues)

pairwise_planar_angles(atoms_i, atoms_j)

Return a matrix of pairwise planar angles between residues, defined by two sets of atoms, one for each residue of the pair.

Parameters:

Name Type Description Default
atoms_i List[str]

List of atoms to be used for the first residue.

required
atoms_j List[str]

List of atoms to be used for the second residue.

required

Returns:

Name Type Description
pairwise_planar_angles torch.FloatTensor

A tensor containing pairwise planar angles between residues. Shape: (batch_size, num_residues, num_residues)

translate(translation, atomwise=False)

Translate the structures by a given tensor of shape (batch_size, num_residues, 3) or (batch_size, 1, 3). Translation is performed residue-wise by default, but atomwise translation can be performed when atomwise=True. In that case, the translation tensor should have a shape of (batch_size, num_residues, num_atom, 3).

Parameters:

Name Type Description Default
translation torch.Tensor

Translation vector. Shape: (batch_size, num_residues, 3) if atomwise=False, (batch_size, num_residues, num_atom, 3) otherwise.

required
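The two accepted translation shapes differ only in how they broadcast against the coordinate array. The NumPy sketch below mirrors the shapes documented above; the function name `translate_sketch` is hypothetical and the code is not the library's implementation.

```python
import numpy as np

def translate_sketch(xyz, translation, atomwise=False):
    """xyz: (B, L, A, 3). translation: (B, L, 3) or (B, 1, 3); (B, L, A, 3) if atomwise."""
    if atomwise:
        return xyz + translation
    # Insert an atom axis so one vector per residue applies to all of its atoms.
    return xyz + translation[:, :, None, :]

xyz = np.zeros((2, 4, 3, 3))

# One shift per structure: shape (batch_size, 1, 3).
shift = np.array([[[1.0, 0.0, 0.0]], [[0.0, 2.0, 0.0]]])
moved = translate_sketch(xyz, shift)

# One shift per atom: shape (batch_size, num_residues, num_atom, 3).
per_atom = np.ones((2, 4, 3, 3))
moved_atomwise = translate_sketch(xyz, per_atom, atomwise=True)
print(moved.shape, moved_atomwise.shape)
```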

rotate(rotation)

Rotate the structures by a given rotation matrix of shape (batch_size, 3, 3) or (3, 3).

Parameters:

Name Type Description Default
rotation torch.Tensor

Rotation matrix. Shape: (batch_size, 3, 3) if a separate rotation is applied to each structure, (3, 3) if the same rotation is applied to all structures.

required
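A small NumPy sketch of the two rotation shapes follows. It assumes the column-vector convention x' = R x, which this page does not state, and the function name `rotate_sketch` is hypothetical.

```python
import numpy as np

def rotate_sketch(xyz, rotation):
    """xyz: (B, L, A, 3). rotation: (B, 3, 3) or (3, 3). Applies x' = R x to every atom."""
    if rotation.ndim == 2:
        # Share a single rotation across the whole batch.
        rotation = np.broadcast_to(rotation, (xyz.shape[0], 3, 3))
    return np.einsum('bij,blaj->blai', rotation, xyz)

theta = np.pi / 2
rot_z = np.array([
    [np.cos(theta), -np.sin(theta), 0.0],
    [np.sin(theta),  np.cos(theta), 0.0],
    [0.0,            0.0,           1.0],
])

xyz = np.zeros((2, 1, 1, 3))
xyz[..., 0] = 1.0                      # every atom sits at (1, 0, 0)

rotated_shared = rotate_sketch(xyz, rot_z)                          # same rotation for all
rotated_each = rotate_sketch(xyz, np.stack([rot_z, np.eye(3)]))     # one rotation per structure
print(rotated_shared[0, 0, 0], rotated_each[1, 0, 0])
```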

standardize(atom_mask=None, residue_mask=None)

Standardize the coordinates of the structures to have zero mean and unit standard deviation.

Parameters:

Name Type Description Default
atom_mask bool

Mask for atoms used for standardization. If None, all atoms are used. atom_mask and residue_mask cannot be specified at the same time. Shape: (batch_size, num_residues, num_atoms)

None
residue_mask bool

Mask for residues used for standardization. If None, all residues are used. atom_mask and residue_mask cannot be specified at the same time. Shape: (batch_size, num_residues)

None

unstandardize()

Recover the coordinates at original scale from the standardized coordinates.
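The round trip between standardize and unstandardize can be illustrated with a short NumPy sketch. The statistics used here (a per-structure mean vector and a single per-structure standard deviation, both computed over masked atoms) are an assumption, since this page does not specify them, and the function names are hypothetical.

```python
import numpy as np

def standardize_sketch(xyz, atom_mask):
    """xyz: (B, L, A, 3), atom_mask: (B, L, A). Returns standardized coords plus the statistics."""
    m = atom_mask[..., None].astype(xyz.dtype)                 # (B, L, A, 1)
    n = m.sum(axis=(1, 2), keepdims=True)                      # valid atoms per structure
    mean = (xyz * m).sum(axis=(1, 2), keepdims=True) / n       # (B, 1, 1, 3)
    var = (((xyz - mean) ** 2) * m).sum(axis=(1, 2, 3), keepdims=True) / (3 * n)
    std = np.sqrt(var)                                         # (B, 1, 1, 1)
    return (xyz - mean) / std, mean, std

def unstandardize_sketch(z, mean, std):
    """Invert standardize_sketch using the stored statistics."""
    return z * std + mean

rng = np.random.default_rng(0)
xyz = rng.normal(loc=5.0, scale=3.0, size=(2, 6, 4, 3))
atom_mask = np.ones((2, 6, 4), dtype=bool)
atom_mask[0, 4:] = False                                       # last two residues are padding

z, mean, std = standardize_sketch(xyz, atom_mask)
restored = unstandardize_sketch(z, mean, std)
print(np.allclose(restored, xyz))
```

Keeping the mean and standard deviation alongside the standardized coordinates is what makes the inverse step possible.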

center_of_mass()

Compute the center of mass of the structures.

Warning

Only alpha-carbon (CA) atoms are considered when computing the center of mass.

Returns:

Name Type Description
center_of_mass torch.Tensor

A tensor containing the center of mass of the structures. Shape: (batch_size, 3)

center_at(center=None)

Translate the whole structure so that the center of the alpha-carbon (CA) coordinates lies at the given 3D coordinates. If center is not specified, the structures are centered at the origin (considering only CA coordinates).

Parameters:

Name Type Description Default
center torch.Tensor

Coordinates of the center. Shape: (batch_size, 3) or (3,)

None
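The CA-based centering described for center_of_mass and center_at can be sketched in NumPy as below. It assumes an unweighted mean over valid CA atoms and a CA atom index of 1; the function names and the residue_mask argument are hypothetical and introduced only for this example.

```python
import numpy as np

CA_IDX = 1  # index of the alpha-carbon atom, as used elsewhere on this page

def ca_center_sketch(xyz, residue_mask):
    """xyz: (B, L, A, 3), residue_mask: (B, L). Returns the mean CA position, shape (B, 3)."""
    ca = xyz[:, :, CA_IDX]                               # (B, L, 3)
    w = residue_mask[..., None].astype(xyz.dtype)        # (B, L, 1)
    return (ca * w).sum(axis=1) / w.sum(axis=1)

def center_at_sketch(xyz, residue_mask, center=None):
    """Shift every atom so the CA center lands on `center` (origin if None)."""
    com = ca_center_sketch(xyz, residue_mask)
    if center is None:
        center = np.zeros(3)
    shift = np.broadcast_to(np.asarray(center, dtype=xyz.dtype), com.shape) - com
    return xyz + shift[:, None, None, :]

rng = np.random.default_rng(0)
xyz = rng.normal(loc=10.0, size=(2, 6, 4, 3))
residue_mask = np.ones((2, 6), dtype=bool)

centered = center_at_sketch(xyz, residue_mask)
moved = center_at_sketch(xyz, residue_mask, center=[1.0, 2.0, 3.0])
print(ca_center_sketch(centered, residue_mask).round(6))
```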

inter_residue_geometry()

Return a dictionary of inter-residue geometry features, as used to represent protein structure in trRosetta.

Returns:

Name Type Description
inter_residue_geometry Dict[str, torch.Tensor]

A dictionary containing inter-residue geometry tensors.