Sunday 5 July 2020

SMI to 2D MOL: Parallel processing with the CDK from Jython

At some point in a computational chemistry workflow, if you started with SMILES, you need to transition to the world of 3D. A slight niggle in the works is that some tools may not correctly read SMILES, or indeed may not read them at all. For this reason, it is useful to convert to 2D MOL files first. In this respect, the CDK has established itself as one of the go-to tools thanks to the efforts of John Mayfield (see, for example, CDK Depict and his RDKit talk on the topic).

The most direct way to use the CDK is from Java. However, I find the energy barrier to writing a Java program to be high, and so here I'll use Jython. Once Jython is installed, you just add cdk-2.3.jar to the CLASSPATH environment variable and you are good to go.
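For example, on Linux or macOS the setup might look like this (the jar location and script name here are assumptions; adjust them to wherever you put the jar):

```shell
# Add the CDK jar to the CLASSPATH (assuming it sits in the current directory)
export CLASSPATH="$PWD/cdk-2.3.jar:$CLASSPATH"

# A script can then be run with, e.g.:
#   jython smi2sdf.py
echo "$CLASSPATH"
</imports>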

Serial approach
I'll be reporting timings for reading 100K SMILES strings from a CSV file and converting them to an SDF. The baseline is the serial implementation below, which takes about 60s.
# Python
import csv
import time

# Java
import java
import org.openscience.cdk as cdk

sp = cdk.smiles.SmilesParser(cdk.silent.SilentChemObjectBuilder.getInstance())
TITLE = cdk.CDKConstants.TITLE

def calculate(smi, title):
    # Read SMILES
    mol = sp.parseSmiles(smi)

    # Set the title
    mol.setProperty(TITLE, title)

    # Do the SDG
    sdg = cdk.layout.StructureDiagramGenerator()
    sdg.generateCoordinates(mol)

    # Write SDF file
    writer = java.io.StringWriter()
    molwriter = cdk.io.SDFWriter(writer)
    molwriter.write(mol)
    molwriter.close() # flushes

    return writer.toString()

if __name__ == "__main__":
    INPUTFILE = "100000.csv"
    OUTPUTFILE = "out.sdf"

    t = time.time()
    with open(OUTPUTFILE, "w") as out:
        with open(INPUTFILE) as inp:
            reader = csv.reader(inp)
            for smi, _, title in reader:
                out.write(calculate(smi, title))
    print(time.time() - t)
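The `for smi, _, title in reader` loop implies that each CSV row has the SMILES in the first column and the title in the third, with the middle field ignored. A minimal pure-Python sketch of that assumed layout (the sample rows are made up for illustration):

```python
import csv
import io

# Two sample rows in the assumed layout: SMILES, ignored field, title
sample = "c1ccccc1,1,benzene\nCCO,2,ethanol\n"

rows = list(csv.reader(io.StringIO(sample)))

# Unpack each row the same way the script above does
pairs = [(smi, title) for smi, _, title in rows]
print(pairs)
```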
If we have millions of SMILES strings, a parallel approach can help. Unfortunately, Jython does not provide an implementation of the multiprocessing library so we need to do this the Java way...

Approach 1 - Using streams
The script below reads SMILES strings from a CSV file as a stream and passes them one at a time to multiple threads running in parallel, where each is converted to an SDF entry. As far as I can tell, the API doesn't provide any way to control the number of threads. The SDF entries are written to the output file in the original order if ".forEachOrdered" is used instead of ".forEach". There was a 4.5X speed-up, from 60s to 13s, on a machine with 12 physical cores (24 logical, due to hyperthreading). Surprisingly, timings for forEachOrdered() were about the same as for forEach().
# Python
import csv

# Java
import java
import org.openscience.cdk as cdk
from java.nio.file import Files, Paths
from java.util.function import Function, Consumer

sp = cdk.smiles.SmilesParser(cdk.silent.SilentChemObjectBuilder.getInstance())
TITLE = cdk.CDKConstants.TITLE

def smi2sdf(line):
    smi, _, title = next(csv.reader([line]))

    # Read SMILES
    mol = sp.parseSmiles(smi)

    # Set the title
    mol.setProperty(TITLE, title)

    # Do the SDG
    sdg = cdk.layout.StructureDiagramGenerator()
    sdg.generateCoordinates(mol)

    # Write SDF file
    writer = java.io.StringWriter()
    molwriter = cdk.io.SDFWriter(writer)
    molwriter.write(mol)
    molwriter.close() # flushes

    return writer.toString()

class Calculate(Function):
    def apply(self, text):
        return smi2sdf(text)

class Write(Consumer):
    def __init__(self, filename):
        self.mfile = open(filename, "w")
    def accept(self, text):
        self.mfile.write(text)
    def __del__(self):
        # note: relies on finalization to close the file
        self.mfile.close()

if __name__ == "__main__":
    INPUTFILE = "100000.csv"
    OUTPUTFILE = "out.sdf"

    calculate = Calculate()
    write = Write(OUTPUTFILE)

    Files.lines(Paths.get(INPUTFILE)).parallel().map(calculate).forEach(write)
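For comparison, the same map-then-write pipeline can be sketched in pure Python with concurrent.futures (which CPython users would reach for instead of Java streams). Here executor.map plays the role of forEachOrdered, as it yields results in input order; process_line is a hypothetical stand-in for smi2sdf:

```python
from concurrent.futures import ThreadPoolExecutor

def process_line(line):
    # Stand-in for smi2sdf(): do the per-line work here
    return line.upper()

lines = ["c1ccccc1,1,benzene", "CCO,2,ethanol"]

with ThreadPoolExecutor(max_workers=4) as pool:
    # map() yields results in input order, like forEachOrdered()
    results = list(pool.map(process_line, lines))

print(results)
```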
Approach 2 - Using a ThreadPool
Daniel Lowe suggested using a ThreadPool and provided example Java code showing that it ran faster than the streams approach. This was also the case in Jython, where a timing of 9.6s was obtained for 12 threads, a 6X speedup over the serial implementation. The upside of using a ThreadPool is that the number of threads can be controlled explicitly; it's worth noting that using 24 threads actually slowed things down to 10.2s - a reminder that "hyperthreading" is marketing BS. A potential downside is that there's no way (with this implementation at least) to preserve the order of the output.
# Python
import csv

# Java
import java
import org.openscience.cdk as cdk
import java.util.concurrent as conc

sp = cdk.smiles.SmilesParser(cdk.silent.SilentChemObjectBuilder.getInstance())
TITLE = cdk.CDKConstants.TITLE

def calculate(smi, title):

    # Read SMILES
    mol = sp.parseSmiles(smi)

    # Set the title
    mol.setProperty(TITLE, title)

    # Do the SDG
    sdg = cdk.layout.StructureDiagramGenerator()
    sdg.generateCoordinates(mol)

    # Write SDF file
    writer = java.io.StringWriter()
    molwriter = cdk.io.SDFWriter(writer)
    molwriter.write(mol)
    molwriter.close() # flushes

    return writer.toString()

class SmiToMol(java.lang.Runnable):

    def __init__(self, smi, title, writer):
        self.smi = smi
        self.title = title
        self.writer = writer

    def run(self):
        self.writer.write(calculate(self.smi, self.title))

class LimitedQueue(conc.LinkedBlockingQueue):
    serialVersionUID = 1

    def __init__(self, maxSize):
        conc.LinkedBlockingQueue.__init__(self, maxSize)

    def offer(self, e):
        # convert offer to 'put' to make it blocking
        try:
            self.put(e)
            return True
        except java.lang.InterruptedException:
            java.lang.Thread.currentThread().interrupt()
        return False

if __name__ == "__main__":
    INPUTFILE = "100000.csv"
    OUTPUTFILE = "out.threadpool.sdf"

    THREADS = 12
    executor = conc.ThreadPoolExecutor(THREADS, THREADS, 0, conc.TimeUnit.MILLISECONDS, LimitedQueue(THREADS * 2))
    with open(OUTPUTFILE, "w") as out:
        with open(INPUTFILE) as inp:
            reader = csv.reader(inp)
            for smi, _, title in reader:
                executor.execute(SmiToMol(smi, title, out))
        executor.shutdown()
        executor.awaitTermination(10000, conc.TimeUnit.SECONDS)
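The point of LimitedQueue is backpressure: by turning offer() into a blocking put(), the reader thread stalls once the work queue is full rather than flooding memory with pending tasks. The same idea can be sketched in pure Python with a bounded queue.Queue, whose put() likewise blocks when the queue is full (the sizes and the squaring "work" here are illustrative only):

```python
import queue
import threading

q = queue.Queue(maxsize=2)  # bounded queue: put() blocks when full
results = []

def worker():
    while True:
        item = q.get()
        if item is None:  # sentinel to stop the worker
            break
        results.append(item * item)  # stand-in for the real per-item work
        q.task_done()

t = threading.Thread(target=worker)
t.start()

for i in range(5):
    q.put(i)  # blocks whenever the producer gets more than 2 items ahead
q.put(None)
t.join()

print(results)
```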
Credits
Thanks to Daniel Lowe and John Mayfield for an interesting discussion about various approaches and what's going on under-the-hood.
