Custom Pre-Processing of Graphs
The chem_mat_data package provides datasets in an already pre-processed graph format. This format makes some opinionated choices about which node and edge features are included to represent information about each individual atom or bond in the molecular graph.
If this pre-defined format is not sufficient for a given dataset, it is possible to define a custom Processing class that implements a pre-processing pipeline converting the SMILES representations into graph structures.
Creating a MoleculeProcessing Subclass
To define a custom pre-processing structure, one can create a subclass of the MoleculeProcessing class and define the desired node and edge features by overriding the node_attributes and edge_attributes class properties. Both properties have to be dictionaries that define the node and edge features by providing, for each feature, a callback function or callable object that derives the desired property from the corresponding rdkit.Atom or rdkit.Bond object.
In the following example we define a customized processing class which only encodes a subset of atom types and only includes the mass of the atom as an additional feature. For the edge attributes we only encode the difference between single and double bonds.
from chem_mat_data.processing import MoleculeProcessing
from chem_mat_data.processing import OneHotEncoder, chem_prop, list_identity

# Has to inherit from MoleculeProcessing!
class CustomProcessing(MoleculeProcessing):

    node_attribute_map = {
        'mass': {
            # "chem_prop" is a wrapper function which will call the given
            # property method on the rdkit.Atom object - in this case the
            # GetMass() method - and pass the output to the transformation
            # function given as the second argument. "list_identity" means
            # that the value is simply converted to a list as it is.
            # Therefore, this configuration will result in outputs such as
            # [12.011], [15.999] etc. as parts of the overall feature vector.
            'callback': chem_prop('GetMass', list_identity),
            # Provide a human-readable description of what this section of
            # the node feature vector represents.
            'description': 'The mass of the atom'
        },
        'symbol': {
            # "OneHotEncoder" is a special callable class that can be used
            # to automatically define one-hot encodings. The object will
            # accept the output of the given chem prop - in this case the
            # GetSymbol action on the rdkit.Atom - and create an integer
            # one-hot vector according to the provided list. In this case,
            # the encoding will encode a carbon as [1, 0, 0, 0],
            # an oxygen as [0, 1, 0, 0] etc.
            'callback': chem_prop('GetSymbol', OneHotEncoder(
                ['C', 'O', 'N', 'S'],
                add_unknown=False,
                dtype=str,
            )),
            'description': 'One hot encoding of the atom type',
            'is_type': True,
            'encodes_symbol': True,
        },
    }
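    edge_attributes = {
        'type': {
            # Sketch, assuming OneHotEncoder casts the rdkit bond type to
            # int (SINGLE -> 1, DOUBLE -> 2) before matching the list.
            'callback': chem_prop('GetBondType', OneHotEncoder(
                [1, 2], add_unknown=False, dtype=int,
            )),
            'description': 'One hot encoding of the bond type',
        }
    }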
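To make the role of the attribute dictionaries more concrete, the following library-independent sketch shows how per-attribute callbacks could be concatenated into a single node feature vector. It does not reproduce the internals of MoleculeProcessing; FakeAtom, build_feature_vector and example_map are hypothetical names used only for illustration.

```python
from typing import Dict, List


class FakeAtom:
    """Hypothetical stand-in exposing two rdkit.Atom-like getters."""

    def __init__(self, symbol: str, mass: float) -> None:
        self._symbol = symbol
        self._mass = mass

    def GetSymbol(self) -> str:
        return self._symbol

    def GetMass(self) -> float:
        return self._mass


def build_feature_vector(atom, attribute_map: Dict[str, dict]) -> List[float]:
    """Call every callback in insertion order and concatenate the results."""
    features: List[float] = []
    for spec in attribute_map.values():
        features.extend(spec['callback'](atom))
    return features


symbols = ['C', 'O', 'N', 'S']
example_map = {
    'mass': {
        # Wraps the raw mass value in a list.
        'callback': lambda atom: [atom.GetMass()],
        'description': 'The mass of the atom',
    },
    'symbol': {
        # One-hot encodes the element symbol against the list above.
        'callback': lambda atom: [float(atom.GetSymbol() == s) for s in symbols],
        'description': 'One hot encoding of the atom type',
    },
}

carbon_features = build_feature_vector(FakeAtom('C', 12.011), example_map)
print(carbon_features)
```

In this sketch the order of the dictionary keys determines the layout of the resulting vector: the mass comes first, followed by the four one-hot entries.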
Applying the Custom Processing
After the custom processing class has been defined, it can be used in the same manner as the original MoleculeProcessing class to convert the SMILES string representations of the dataset into the graph representation by using the process method.
from rich.pretty import pprint
processing = CustomProcessing()
graph: dict = processing.process('CCCC')
pprint(graph)
Defining Custom Transformation Callbacks
As introduced in the previous example, the chem_prop wrapper can be used to pass the output of an rdkit.Atom or rdkit.Bond method to a transformation callback, which is then expected to return a list that will become part of the final node/edge feature vector.
The simplest usage is the list_identity transformation, which simply wraps the output value in a list as it is.
An alternative is to use the existing OneHotEncoder class to convert the output of a property getter method into an
integer one-hot encoded vector.
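As a rough illustration of the two building blocks described above, the sketch below re-implements the idea of a getter wrapper and a one-hot encoder in plain Python. The helpers getter_then, one_hot and Carbon are hypothetical and are not the package's chem_prop or OneHotEncoder; in particular, the handling of unknown values shown here is an assumption.

```python
from typing import Any, Callable, List, Sequence


def one_hot(values: Sequence[Any], add_unknown: bool = False) -> Callable[[Any], List[int]]:
    """Return an encoder mapping a value to an integer one-hot list.

    Assumption: with add_unknown=True an extra trailing slot marks values
    that are not in `values`; otherwise unknown values give all zeros.
    """
    def encode(value: Any) -> List[int]:
        vector = [int(value == v) for v in values]
        if add_unknown:
            vector.append(int(value not in values))
        return vector
    return encode


def getter_then(method_name: str, transform: Callable[[Any], list]) -> Callable[[Any], list]:
    """Call the named getter on the entity and pass the result to `transform`."""
    def callback(entity: Any) -> list:
        return transform(getattr(entity, method_name)())
    return callback


class Carbon:
    """Hypothetical stand-in for an rdkit.Atom."""

    def GetSymbol(self) -> str:
        return 'C'


encode_symbol = getter_then('GetSymbol', one_hot(['C', 'O', 'N', 'S']))
print(encode_symbol(Carbon()))  # [1, 0, 0, 0]
```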
It is also possible to define a completely custom callback to derive properties from the atom / bond objects directly. The callback function simply has to accept a single positional argument entity: Atom | Bond. The callback function has to return a list of numeric values which will be appended to the overall feature vector.
import rdkit.Chem as Chem
from rich.pretty import pprint
from typing import List
from chem_mat_data.processing import MoleculeProcessing

def custom_callback(atom: Chem.Atom) -> List[float]:
    # Mass multiplied with the formal charge. rdkit atoms expose the
    # charge through GetFormalCharge(); there is no GetCharge() method.
    return [atom.GetMass() * atom.GetFormalCharge()]

class CustomCallbackProcessing(MoleculeProcessing):

    node_attributes = {
        'mass_times_charge': {
            'callback': custom_callback,
            'description': 'atom mass multiplied with the formal charge',
        }
    }

processing = CustomCallbackProcessing()
graph = processing.process('CCCC')
pprint(graph)
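The same pattern should carry over to edge features. The sketch below defines a callback for bond objects using the rdkit.Chem.Bond methods IsInRing() and GetIsConjugated(). RingBond is a hypothetical stand-in so that the function can be exercised without rdkit, and registering the callback in the edge attribute dictionary is assumed to work like the node example above.

```python
from typing import List


def ring_and_conjugation(bond) -> List[float]:
    """Encode two boolean bond properties as 0.0 / 1.0 values."""
    return [float(bond.IsInRing()), float(bond.GetIsConjugated())]


class RingBond:
    """Hypothetical stand-in mimicking two rdkit.Chem.Bond methods."""

    def IsInRing(self) -> bool:
        return True

    def GetIsConjugated(self) -> bool:
        return False


bond_features = ring_and_conjugation(RingBond())
print(bond_features)  # [1.0, 0.0]
```

Inside a processing subclass, such a function would be referenced as the 'callback' entry of an edge attribute, next to a 'description' string.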