Diffusing XYZ coordinates with Gaussian noise
Many diffusion models for protein structure generation require a set of atom coordinates to be diffused with Gaussian noises having a predefined variance schedule, which finally results in a randomized set of coordinates distributed according to 3D Gaussian distribution.
This tutorial shows how to use the StructureBatch
object to generate a set of diffused coordinates.
In [1]:
Copied!
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import torch
import protstruc as ps
from protstruc.general import ATOM
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import torch
import protstruc as ps
from protstruc.general import ATOM
In [2]:
Copied!
import math
def cosine_variance_schedule(T, s=8e-3, beta_max=0.999):
# cosine variance schedule
# T: total timesteps
# s: small offset to prevent beta from being too small
# beta_max: to prevent singularities at the end of the diffusion process
t = torch.arange(T + 1) # 0, 1, ..., T
f_t = torch.cos((t / T + s) / (1 + s) * math.pi / 2.0).square()
alpha_bar = f_t / f_t[0]
beta = torch.cat(
[
torch.tensor([0.0]),
torch.clip(1 - alpha_bar[1:] / alpha_bar[:-1], min=1e-5, max=beta_max),
]
)
alpha = 1 - beta
sched = {
"alpha": alpha,
"alpha_bar": alpha_bar,
"alpha_bar_sqrt": alpha_bar.sqrt(),
"one_minus_alpha_bar_sqrt": (1 - alpha_bar).sqrt(),
"beta": beta,
}
return sched
import math
def cosine_variance_schedule(T, s=8e-3, beta_max=0.999):
# cosine variance schedule
# T: total timesteps
# s: small offset to prevent beta from being too small
# beta_max: to prevent singularities at the end of the diffusion process
t = torch.arange(T + 1) # 0, 1, ..., T
f_t = torch.cos((t / T + s) / (1 + s) * math.pi / 2.0).square()
alpha_bar = f_t / f_t[0]
beta = torch.cat(
[
torch.tensor([0.0]),
torch.clip(1 - alpha_bar[1:] / alpha_bar[:-1], min=1e-5, max=beta_max),
]
)
alpha = 1 - beta
sched = {
"alpha": alpha,
"alpha_bar": alpha_bar,
"alpha_bar_sqrt": alpha_bar.sqrt(),
"one_minus_alpha_bar_sqrt": (1 - alpha_bar).sqrt(),
"beta": beta,
}
return sched
In [3]:
Copied!
pdb_id = '4EOT'
sb = ps.StructureBatch.from_pdb_id(pdb_id)
prt_idx = 0
atom_idx = ATOM.CA
fig = plt.figure(figsize=(14, 5))
ax1 = fig.add_subplot(131, projection='3d')
ax2 = fig.add_subplot(132, projection='3d')
ax3 = fig.add_subplot(133)
sb.standardize()
T = 300
sched = cosine_variance_schedule(T=T, s=8e-3, beta_max=0.999)
xyz0 = sb.get_xyz()
bsz, n_res = xyz0.shape[:2]
# initial coordinates (t=0)
xyz = xyz0
ims = []
for t in range(T):
# sample a noised structure from N( sqrt(1-b_{t}) * x_{t-1}, b_{t} * I).
xyz = torch.sqrt(1 - sched['beta'][t]) * xyz
xyz += torch.sqrt(sched['beta'][t]) * torch.randn(bsz, n_res, 1, 3)
im1 = ax1.scatter(
xyz[prt_idx, :, atom_idx, 0].numpy(),
xyz[prt_idx, :, atom_idx, 1].numpy(),
xyz[prt_idx, :, atom_idx, 2].numpy(),
c='C1'
)
im2, = ax2.plot(
xyz[prt_idx, :, atom_idx, 0].numpy(),
xyz[prt_idx, :, atom_idx, 1].numpy(),
xyz[prt_idx, :, atom_idx, 2].numpy(),
c='C1'
)
# histogram of x coordinates
_, _, im3 = ax3.hist(xyz[prt_idx, :, atom_idx, 0], bins=33, fc='C1')
# axes title
t = ax3.text(0.5, 1.01, f't={t}', ha='center', va='bottom', transform=ax1.transAxes)
# histogram patches (im3) is already a list, so just concatenate it
ims.append([im1, im2, t] + list(im3))
ani = animation.ArtistAnimation(fig, ims, interval=100, blit=True, repeat_delay=1000)
ani.save(f'animations/{pdb_id}_diffusion.gif')
plt.clf() # not showing the results after this cell
pdb_id = '4EOT'
sb = ps.StructureBatch.from_pdb_id(pdb_id)
prt_idx = 0
atom_idx = ATOM.CA
fig = plt.figure(figsize=(14, 5))
ax1 = fig.add_subplot(131, projection='3d')
ax2 = fig.add_subplot(132, projection='3d')
ax3 = fig.add_subplot(133)
sb.standardize()
T = 300
sched = cosine_variance_schedule(T=T, s=8e-3, beta_max=0.999)
xyz0 = sb.get_xyz()
bsz, n_res = xyz0.shape[:2]
# initial coordinates (t=0)
xyz = xyz0
ims = []
for t in range(T):
# sample a noised structure from N( sqrt(1-b_{t}) * x_{t-1}, b_{t} * I).
xyz = torch.sqrt(1 - sched['beta'][t]) * xyz
xyz += torch.sqrt(sched['beta'][t]) * torch.randn(bsz, n_res, 1, 3)
im1 = ax1.scatter(
xyz[prt_idx, :, atom_idx, 0].numpy(),
xyz[prt_idx, :, atom_idx, 1].numpy(),
xyz[prt_idx, :, atom_idx, 2].numpy(),
c='C1'
)
im2, = ax2.plot(
xyz[prt_idx, :, atom_idx, 0].numpy(),
xyz[prt_idx, :, atom_idx, 1].numpy(),
xyz[prt_idx, :, atom_idx, 2].numpy(),
c='C1'
)
# histogram of x coordinates
_, _, im3 = ax3.hist(xyz[prt_idx, :, atom_idx, 0], bins=33, fc='C1')
# axes title
t = ax3.text(0.5, 1.01, f't={t}', ha='center', va='bottom', transform=ax1.transAxes)
# histogram patches (im3) is already a list, so just concatenate it
ims.append([im1, im2, t] + list(im3))
ani = animation.ArtistAnimation(fig, ims, interval=100, blit=True, repeat_delay=1000)
ani.save(f'animations/{pdb_id}_diffusion.gif')
plt.clf() # not showing the results after this cell
MovieWriter ffmpeg unavailable; using Pillow instead.
<Figure size 1400x500 with 0 Axes>
The animation below shows that the coordinates of Ca atoms gradually reaches to the Gaussian distribution.