Thursday 28 December 2023

Eyes on tokenize

I was writing a tokenizer for SMILES and came across a recent paper by the IBM Research team on reaction standardisation which contained a description of their tokenization method. It uses a regex:

(\%\([0-9]{3}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\||\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])

Looking at the source code, here's an adapted version of how this is applied:

import re

SMILES_TOKENIZER_PATTERN = r"(\%\([0-9]{3}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\||\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
SMILES_REGEX = re.compile(SMILES_TOKENIZER_PATTERN)
def split_on_regexp(smi):
    """
    >>> split_on_regexp("Cl11%11%(111)C[C@@H](Br)I")
    ['Cl', '1', '1', '%11', '%(111)', 'C', '[C@@H]', '(', 'Br', ')', 'I']
    """
    return SMILES_REGEX.findall(smi)
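
One property of findall worth keeping in mind: characters that match none of the alternatives are silently skipped rather than reported. A minimal sketch of a guard against that, assuming a round-trip check is acceptable (the name split_checked is my own and not from the paper's code):

```python
import re

SMILES_TOKENIZER_PATTERN = r"(\%\([0-9]{3}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\||\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
SMILES_REGEX = re.compile(SMILES_TOKENIZER_PATTERN)

def split_checked(smi):
    """Tokenize with the regex, but fail loudly if any character was dropped.

    split_checked is a hypothetical helper for illustration only.
    """
    tokens = SMILES_REGEX.findall(smi)
    # findall skips over text that matches no alternative, so joining the
    # tokens only reproduces the input if every character was consumed.
    if "".join(tokens) != smi:
        raise ValueError(f"could not tokenize all of {smi!r}")
    return tokens
```

For example, the 'a' in "CaC" matches nothing in the pattern, so plain findall returns just the two 'C' tokens, whereas the checked version raises.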

My approach is to tokenize directly on a single pass through the string, and so I was curious how it compared...

Tokenize v1

The above code takes 7.7s to tokenize 1M SMILES from ChEMBL (Python 3.10 on Linux in a VM). "Surely a hand-rolled tokenizer will do better?" says I:

def tokenize_v1(smi):
    """
    >>> tokenize_v1("Cl11%11%(111)C[C@@H](Br)I")
    ['Cl', '1', '1', '%11', '%(111)', 'C', '[C@@H]', '(', 'Br', ')', 'I']
    """
    tokens = []
    i = 0
    N = len(smi)
    while i < N:
        x = smi[i]
        if x == 'C' and i+1<N and smi[i+1]=='l':
            tokens.append("Cl")
            i += 1
        elif x == 'B' and i+1<N and smi[i+1]=='r':
            tokens.append("Br")
            i += 1
        elif x=='[':
            j = i+1
            while smi[j] != ']':
                j += 1
            tokens.append(smi[i:j+1])
            i += j-i
        elif x == '%':
            if smi[i+1] == '(':
                j = i
                while smi[j] != ')':
                    j += 1
                tokens.append(smi[i:j+1])
                i += j-i
            else:
                tokens.append(smi[i:i+3])
                i += 2
        else:
            tokens.append(x)
        i += 1
    return tokens

10.4s, I'm afraid. Well, we can't give up without a fight. Time to optimise.

Tokenize v2

Consider that 'C' is the most common character but will be among the slowest to handle due to the check for 'Cl' (and then the subsequent 'if' statements). How about we check for the 'l' in 'Cl' instead of the 'C'? This works because, outside square brackets, 'l' and 'r' only ever appear as the second letter of 'Cl' and 'Br'. The previous token will be incorrect ('C' or 'B'), but we can just correct it:

def tokenize_v2(smi):
    """
    >>> tokenize_v2("Cl11%11%(111)C[C@@H](Br)I")
    ['Cl', '1', '1', '%11', '%(111)', 'C', '[C@@H]', '(', 'Br', ')', 'I']
    """
    tokens = []
    i = 0
    N = len(smi)
    while i < N:
        x = smi[i]
        if x == 'l':
            tokens[-1] = 'Cl'
        elif x == 'r':
            tokens[-1] = 'Br'
        elif x=='[':
            j = i+1
            while smi[j] != ']':
                j += 1
            tokens.append(smi[i:j+1])
            i += j-i
        elif x == '%':
            if smi[i+1] == '(':
                j = i
                while smi[j] != ')':
                    j += 1
                tokens.append(smi[i:j+1])
                i += j-i
            else:
                tokens.append(smi[i:i+3])
                i += 2
        else:
            tokens.append(x)
        i += 1
    return tokens

Down to 9.8s. Not a major step forward, but it enables the next optimisation...

Tokenize v3

Optimising parsing for SMILES simply boils down to making 'C' fast even if everything else is slowed down, as the overall average will still be faster. With this in mind, let's minimise the 'if' statements that 'C' needs to go through by bundling all of them into a single test up-front:

chars = set('[lr%')
def tokenize_v3(smi):
    """
    >>> tokenize_v3("Cl11%11%(111)C[C@@H](Br)I")
    ['Cl', '1', '1', '%11', '%(111)', 'C', '[C@@H]', '(', 'Br', ')', 'I']
    """
    tokens = []
    i = 0
    N = len(smi)
    while i < N:
        x = smi[i]
        if x not in chars:
            tokens.append(x)
            i += 1
        else:
            if x == 'l':
                tokens[-1] = 'Cl'
                i += 1
            elif x == 'r':
                tokens[-1] = 'Br'
                i += 1
            elif x=='[':
                j = i+1
                while smi[j] != ']':
                    j += 1
                tokens.append(smi[i:j+1])
                i += j-i + 1
            else: # %
                if smi[i+1] == '(':
                    j = i
                    while smi[j] != ')':
                        j += 1
                    tokens.append(smi[i:j+1])
                    i += j-i + 1
                else:
                    tokens.append(smi[i:i+3])
                    i += 3
    return tokens

Which comes in at 6.1s, a modest improvement on the regex's 7.7s. But the story does not end here...

Python 3.12 vs Python 3.10

The Python 3.10 I was using earlier is the system Python on Ubuntu 22.04, but Python 3.12 is available via Conda. More recent Pythons contain significant speed-ups, but it appears that these have a greater effect on pure Python code than on regex handling, which is presumably already done by optimised C code. The associated timings are 7.2s (regex), 5.6s (v1), 4.9s (v2) and 4.1s (v3).
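
To repeat this sort of comparison under a different interpreter, something along the following lines could be used. This is only a sketch: the time_tokenizer helper, the repeats parameter and the tiny sample list are my own assumptions for illustration, not the harness behind the numbers above.

```python
import time

def time_tokenizer(tokenize, smiles, repeats=3):
    """Return the best wall-clock time in seconds over `repeats` passes.

    `tokenize` is any callable taking a SMILES string; `smiles` is a list
    of strings that is already in memory, so file I/O is not timed.
    """
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        for smi in smiles:
            tokenize(smi)
        best = min(best, time.perf_counter() - start)
    return best

if __name__ == "__main__":
    # Stand-in data and a stand-in tokenizer (list() just splits a string
    # into characters); swap in the real SMILES and tokenizer functions.
    sample = ["Cl11%11%(111)C[C@@H](Br)I", "c1ccccc1O"] * 1000
    print(f"{time_tokenizer(list, sample):.3f}s")
```

Taking the best of several passes is a common way to reduce noise from whatever else the machine is doing, which matters more inside a VM.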

PyPy

I always like to keep an eye on PyPy, which is a drop-in replacement for Python 3 except that it cannot be used with Open Babel or RDKit (though see this and this). With PyPy, the difference is even greater: 6.6s (regex), 1.7s (v1), 0.8s (v2), 1.0s (v3). Surprisingly, v2 comes out ahead of v3 here. My guess is that PyPy realises that it can use a switch statement for v2 but doesn't realise this for v3.

Conclusion

Even the slowest of the speeds above isn't going to affect most applications, but hopefully the discussion around optimisations and Python versions is of interest. Ultimately my own preference is to avoid regexes unless necessary, as I find them difficult to read and check for errors, though others may prefer the one-line simplicity of a carefully-crafted regex.
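
For those who do prefer the regex, Python's verbose mode can make it somewhat easier to check. The following is a sketch of how the same pattern might be laid out; the grouping of the single-character tokens into character classes and the name SMILES_VERBOSE are my own choices rather than anything from the paper, and I believe it matches the same tokens as the original.

```python
import re

# With re.VERBOSE, whitespace and '#' comments outside character classes
# are ignored, so each alternative can sit on its own annotated line.
SMILES_VERBOSE = re.compile(r"""
    %\([0-9]{3}\)             # ring closure written as %(nnn)
  | \[[^\]]+\]                # bracket atom, e.g. [C@@H]
  | Br? | Cl?                 # B, Br, C, Cl
  | [NOSPFI]                  # remaining single-letter atoms
  | [bcnosp]                  # aromatic atoms
  | >>?                       # reaction separators > and >>
  | %[0-9]{2}                 # two-digit ring closure, e.g. %11
  | [0-9]                     # single-digit ring closure
  | [|().=\#\-+\\/:~@?*$]     # bonds, branches and other one-character tokens
""", re.VERBOSE)

def split_on_verbose_regexp(smi):
    """Tokenize using the verbose form of the pattern."""
    return SMILES_VERBOSE.findall(smi)
```

The order of the alternatives only matters where two of them can start with the same character, which is why %(nnn) is still listed before %nn.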

The full code is available as a gist.
