schnetpack

Tips

version > 1.0.0

Checkpoint (>= 2.0.0)

There are two places where checkpoint files are saved during training.

  1. pl.Trainer automatically saves a checkpoint file for resuming training. It contains the model information (representation, output_modules, ...) plus the Lightning training state (learning rate, epochs, ...).
  2. spk.train.ModelCheckpoint saves the best model so far (the entire model) via spk.task.AtomisticTask().save_model(). It contains only the model information.

Both files can be loaded with checkpoint = torch.load(FILE_PATH, map_location=...). The model can be accessed from either type of checkpoint file, as shown below.

# Saving and loading weights
model = spk.model.NeuralNetworkPotential(parameters)
torch.save(model.state_dict(), 'model_weights.pth')

new_model = spk.model.NeuralNetworkPotential(parameters)
new_model.load_state_dict(torch.load('model_weights.pth'))


# Saving and loading the entire model (used in schnetpack)
model = spk.model.NeuralNetworkPotential(parameters)
torch.save(model, 'model.pth')

new_model = torch.load('model.pth')

# Getting the model from the trainer checkpoint
checkpoint = torch.load(ckpt_path)

# keys of checkpoint : ['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'hparams_name', 'hyper_parameters']

# checkpoint['hyper_parameters'] has all parameters of a spk.task.AtomisticTask class
# keys: ['model', 'outputs', 'optimizer_cls', 'optimizer_args', 'scheduler_cls', 'scheduler_args', 'scheduler_monitor', 'warmup_steps']

new_model = checkpoint['hyper_parameters']['model']
# checkpoint['hyper_parameters']['model'] is a spk.model.NeuralNetworkPotential class
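
Since spk.task.AtomisticTask is a pl.LightningModule, the same checkpoint can in principle also be restored through the standard Lightning interface; a minimal sketch, assuming all hyperparameters were stored in the checkpoint:

# restore the full task, then take its model attribute
task = spk.task.AtomisticTask.load_from_checkpoint(ckpt_path)
new_model = task.model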

By default, all parameters in the loaded model new_model are trainable.
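
To fine-tune only part of the network, gradients can be switched off per submodule. A minimal sketch that freezes the representation (attribute names as in the snippet below):

# freeze the representation; only the remaining modules stay trainable
for param in new_model.representation.parameters():
    param.requires_grad = False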

Loading parameters of a model

# load the entire model saved by ModelCheckpoint
model = torch.load(model_path, map_location='cpu')

representation = model.representation

# sizes and trainability of all representation parameters
for param in representation.parameters():
    print(param.size(), param.requires_grad)

# the same, with parameter names
for name, param in representation.named_parameters():
    print(name, param.size(), param.requires_grad)

Tensorboard
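
Assuming the default PyTorch Lightning TensorBoardLogger writes its event files under lightning_logs/, the training curves can be monitored with:

tensorboard --logdir=lightning_logs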

version <= 1.0.0

Training

When training on energies and forces, the energy targets should be preconditioned with the per-atom mean and standard deviation of the energy. The atomref values used below are single-atom reference energies; see further down for one way to store them as metadata of training.db.

import numpy as np
import schnetpack as spk
from schnetpack import AtomsData

# loading training data
training_data = AtomsData('training.db')

# generate training, validation and test data
train, val, test = spk.train_test_split(
        data=training_data,
        num_train=950,
        num_val=50
    )

train_loader = spk.AtomsLoader(train, batch_size=100, shuffle=True)

# calculate means and stddevs
properties = ["energy", "forces"]
#atomrefs = training_data.get_atomref(properties)

# single-atom reference energies (PBE0 def2-SVP), indexed by atomic number
atomrefs = np.zeros((100, 1))
atomrefs[8] = -2038.811331364917    # O
atomrefs[12] = -5439.27566785787    # Mg
atomrefs[14] = -7868.226280776057   # Si
atomrefs = {'energy': atomrefs}

means, stddevs = train_loader.get_statistics('energy', divide_by_atoms=True, single_atom_ref=atomrefs)

energy_model = spk.atomistic.Atomwise(
    n_in=30,  # must match the representation's n_atom_basis
    property='energy',
    mean=means["energy"],
    stddev=stddevs["energy"],
    atomref=atomrefs["energy"],
    derivative='forces',
    negative_dr=True
)
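
The output module is then combined with a representation into a complete model. A minimal sketch, assuming a SchNet representation whose n_atom_basis matches n_in above:

schnet = spk.representation.SchNet(
    n_atom_basis=30, n_filters=30, n_interactions=3,
    cutoff=4.0, n_gaussians=20
)
model = spk.AtomisticModel(representation=schnet, output_modules=energy_model)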

You can also store the atomref values as metadata while preparing training.db.

import os

import numpy as np
from ase.units import Ha, Bohr
from schnetpack import AtomsData

# working for FieldSchNet and radical cations
def prepare_training_db(trajs, saving_path='./'):
    # parse properties as a list of dictionaries, one per structure
    properties_list = []
    for atoms in trajs:
        # All properties need to be stored as numpy arrays.
        # Note: The shape for scalars should be (1,), not ()
        # Note: GPUs work best with float32 data
        # Energies, forces and positions are converted to atomic units.
        energy = np.array([atoms.get_potential_energy()], dtype=np.float32) / Ha
        forces = atoms.get_forces().astype(np.float32) / (Ha / Bohr)
        dipole_moment = atoms.get_dipole_moment().astype(np.float32) / Bohr

        properties_list.append(dict(
            energy=energy,
            forces=forces,
            dipole_moment=dipole_moment,
            charges=np.array([1], dtype=np.float32),  # total charge of the radical cation
            electric_field=np.zeros(3, dtype=np.float32),
        ))

        # convert positions to Bohr in place before the systems are added below
        pos = atoms.get_positions() / Bohr
        atoms.set_positions(pos)
    db_path = os.path.join(saving_path, 'training.db')
    if os.path.exists(db_path):
        os.remove(db_path)
    new_dataset = AtomsData(db_path, available_properties=['energy', 'forces', 'dipole_moment', 'charges', 'electric_field'])
    new_dataset.add_systems(trajs, properties_list)

    # single-atom reference energies (B3LYP def2-TZVP), given in eV,
    # indexed by atomic number
    atomrefs = np.zeros((100, 1))
    atomrefs[1] = -13.572067490656195   # H
    atomrefs[6] = -1029.6286015373034   # C
    atomrefs /= Ha  # convert eV -> Hartree to match the stored energies

    metadata = {
        'atref_labels' : ['energy'],
        'atomrefs' : atomrefs.tolist()
    }
    new_dataset.set_metadata(metadata)
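
A hypothetical usage example, assuming the reference trajectories are stored in an ASE-readable file named trajs.traj:

from ase.io import read

trajs = read('trajs.traj', index=':')  # hypothetical input file
prepare_training_db(trajs, saving_path='./')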

MD

Currently only a primitive neighbor list is implemented; it cannot handle periodic boundary conditions and does not scale optimally for large systems.

Converting HDF5 to ASE

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import argparse
import numpy as np
from ase import Atoms, units
from ase.io import write
from ase.calculators.singlepoint import SinglePointCalculator
from schnetpack.md.utils import HDF5Loader, MDUnits

def get_ase_atoms_from_spk_hdf5(hdf5_data, indices):
    if isinstance(indices, int):
        indices = [indices]

    # forces and dipoles are left in the MD internal units here
    force_unit = 1
    dipole_unit = 1

    # Internal units (MD internal -> ASE internal)
    energy_unit = units.kJ / units.mol
    length_unit = units.nm
    mass_unit = 1.0  # 1 Dalton in ASE reference frame
    # Derived units (MD internal -> ASE internal)
    time_unit = length_unit * np.sqrt(mass_unit / energy_unit)
    velocity_unit = length_unit / time_unit

    atomic_numbers = hdf5_data.get_property('_atomic_numbers')
    positions = hdf5_data.get_property('_positions') * MDUnits.internal2unit("a")
    energies = hdf5_data.get_property('energy') * energy_unit
    forces = hdf5_data.get_property('forces') * force_unit
    dipoles = hdf5_data.get_property('dipole_moment') * dipole_unit
    temperatures = hdf5_data.get_temperature()
    velocities = hdf5_data.get_velocities() * velocity_unit

    trajs = []
    for idx in indices:
        atoms = Atoms(numbers=atomic_numbers, positions=positions[idx])
        atoms.set_velocities(velocities[idx]) 
        calc = SinglePointCalculator(atoms, energy=energies[idx][0], forces=forces[idx], dipole=dipoles[idx])
        atoms.calc = calc
        #print(temperatures[idx], atoms.get_temperature())
        trajs.append(atoms)

    if len(indices) == 1:
        return trajs[0]
    else:
        return trajs


parser = argparse.ArgumentParser(description='Extracting ase trajectories from hdf5')
parser.add_argument('fname', type=str, default='simulation.hdf5', help='fname of hdf5 database')
parser.add_argument('--n_skip', type=int, default=40000, help='number of initial entries to skip')
parser.add_argument('--n_selected', type=int, default=2000, help='number of selected trajectories')
parser.add_argument('-idx', '--indices_file', type=str, default='spk-hdf5-to-ase-indices.txt', help='indices of selected trajectories')
parser.add_argument('--overwrite', action='store_true', help='overwrite the indices file')

args = parser.parse_args()


data = HDF5Loader(args.fname, args.n_skip)
n_entries = data.entries

if os.path.exists(args.indices_file) and not args.overwrite:
    print(f'reading indices from {args.indices_file}')
    indices = np.loadtxt(args.indices_file, dtype=int, ndmin=1)
    if len(indices) < args.n_selected:
        rest_indices = [index for index in range(n_entries) if index not in indices]
        added_indices = np.random.choice(rest_indices, args.n_selected-len(indices), False)
        indices = np.append(indices, added_indices)
        np.savetxt(args.indices_file, indices, '%.d', header=f'n_skip={args.n_skip}')
else:
    if args.n_selected == 1:
        indices = [n_entries-1]
    else:
        indices = np.random.choice(range(n_entries), args.n_selected, False)
    np.savetxt(args.indices_file, indices, '%.d', header=f'n_skip={args.n_skip}')


print(f'number of loaded entries: {n_entries}')
trajs = get_ase_atoms_from_spk_hdf5(data, indices)
write(f'spk-ase-trajs-from-hdf5-{len(indices)}.traj', trajs)
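
Assuming the script is saved as hdf5_to_ase.py, a typical invocation looks like:

python hdf5_to_ase.py simulation.hdf5 --n_skip 40000 --n_selected 2000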