Tuesday, 14 January 2025

Release of partialsmiles v2.0 - A validating parser for partial SMILES

A while back, I wrote a library called partialsmiles to parse and validate prefixes of SMILES strings. I've just updated it with a minor enhancement (it now returns the full state of the parser, not just the molecule) and thought I would show an example of typical use in the context of generative models.

For a change, let's not use neural nets. Using molecules from ChEMBL converted to random SMILES, I looked at all runs of 8 tokens and counted up the frequency of occurrence of the following token. For example, following Cc1ccccc, '1' was observed 340389 times, '2' 582 times (etc.); following COc1ccc(, 'c' was observed 33288 times, then 'C', 'N', 'O', 'Cl', etc. To handle the start, I padded with 8 start tokens.
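The counting step can be sketched as follows; this is a minimal stand-in using plain dictionaries and character-level tokens (the real tokenizer also handles multi-character tokens like 'Cl' and '[C@H]'), not the actual preprocessing used for ChEMBL:

```python
from collections import defaultdict

SOS, EOS = "^", "$"
PREFIX_LENGTH = 8

def count_ngrams(token_lists, prefix_length=PREFIX_LENGTH):
    """Count how often each token follows each run of `prefix_length` tokens."""
    counts = defaultdict(lambda: defaultdict(int))
    for tokens in token_lists:
        # Pad the start with SOS tokens and mark the end with EOS
        padded = [SOS] * prefix_length + tokens + [EOS]
        for j in range(prefix_length, len(padded)):
            context = tuple(padded[j - prefix_length:j])
            counts[context][padded[j]] += 1
    return counts

# Tiny illustration with two pre-tokenized strings (real input would be ChEMBL):
counts = count_ngrams([list("Cc1ccccc1"), list("Cc1ccccc2")])
print(dict(counts[tuple("Cc1ccccc")]))  # {'1': 1, '2': 1}
```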

Once these frequencies are in hand, we can use them to power a generative model where we sample from the distribution of frequencies at each step. Just like with a neural net, we can apply a temperature factor; this determines to what extent we exaggerate or downplay the differences between frequencies. A value of 1.0 means use the frequencies/probabilities as provided; an extreme value like 10 flattens the distribution so that all tokens with non-zero probability become nearly equally likely, while the other extreme of 0.1 sharpens it so that the most likely token is almost always picked.
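Concretely, temperature scaling raises each probability to the power 1/T and renormalizes, which is what the `np.power(probs, 1.0/temperature)` line does in the code below. A small sketch of the effect:

```python
def apply_temperature(probs, temperature):
    """Raise each probability to 1/T and renormalize."""
    scaled = [p ** (1.0 / temperature) for p in probs]
    total = sum(scaled)
    return [s / total for s in scaled]

probs = [0.7, 0.2, 0.1]
print(apply_temperature(probs, 1.0))   # essentially unchanged
print(apply_temperature(probs, 10.0))  # flattened towards uniform
print(apply_temperature(probs, 0.1))   # sharpened towards the most likely token
```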

The rationale for partialsmiles is that we can perform validity checks while generating the tokens, rather than waiting until we complete the SMILES string (i.e. when we generate an END token). If, at some point, the partial SMILES string turns out to be invalid, we could just discard the string; this would speed things up but not change the result. An alternative, which I show below, is to avoid sampling tokens that would lead to an invalid SMILES; this increases the number of valid SMILES generated, making the process more efficient. Here I do this by setting the frequencies of those tokens to 0.
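The masking step amounts to zeroing the frequency of any disallowed token before sampling. Schematically, with a toy stand-in predicate in place of the partialsmiles check (`is_valid_prefix` here is purely illustrative, not the library's API):

```python
def mask_frequencies(freqs, tokens, prefix, is_valid_prefix):
    """Zero the frequency of any token whose addition makes the prefix invalid.
    Tokens with zero frequency are skipped, as there is no point checking them."""
    return [f if f == 0 or is_valid_prefix(prefix + t) else 0
            for f, t in zip(freqs, tokens)]

def is_valid_prefix(smi):
    # Toy stand-in rule: forbid the '2' ring-closure digit (illustration only;
    # the real check is partialsmiles' ParseSmiles with partial=True)
    return "2" not in smi

print(mask_frequencies([10, 5, 3], ["c", "1", "2"], "Cc1ccccc", is_valid_prefix))
# [10, 5, 0]
```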

If we use the code below to generate 1000 SMILES strings and use a temperature factor of 1.0, the number of valid strings increases from 230 to 327 when partialsmiles is used. A smaller increase is observed for a temperature factor of 1.2 (205 to 247) and larger for 0.8 (263 to 434).

I note in passing that it's possible to force the model to avoid adding an END token unless it results in a valid full SMILES string (see PREVENT_FULL_INVALID_STRING below). I'm afraid that doesn't work out very well. Some really long SMILES strings are generated that just keep going until the monkeys typing them eventually close the ring or parenthesis or whatever the problem is. Some things you just gotta let go, or patch up afterwards.

Oh, you want to see some of the generated molecules? No, you don't; you really don't :-) There's a reason neural nets are popular. :-) You can find the associated code on GitHub.

import pickle
import tqdm
import numpy as np
import partialsmiles as ps

np.random.seed(1)

PREVENT_FULL_INVALID_STRING = False

SOS, EOS = range(2)

TOKENS = [
  '^', '$', # i.e. SOS, EOS
  'c', 'C', '(', ')', 'O', '1', '2', '=', 'N', 'n', '3', '[C@H]',
  '[C@@H]', '4', 'F', '-', 'S', '/', 'Cl', '[nH]', 's', 'o', '5', '[C@]', '#',
  '[C@@]', '\\', '[O-]', '[N+]', 'Br', '6', 'P', '7', '8', '9']

TOKEN_IDXS = list(range(len(TOKENS)))

def find_disallowed_tokens(seq, freqs):
    """Return the indices of tokens whose addition would make the partial SMILES invalid."""
    disallowed = set()
    smi = "".join(TOKENS[x] for x in seq)
    for i, x in enumerate(TOKENS[2:], 2): # skip the SOS and EOS tokens
        if freqs[i] > 0:
            try:
                ps.ParseSmiles(smi + x, partial=True)
            except ps.Error:
                disallowed.add(i)

    if PREVENT_FULL_INVALID_STRING:
        if freqs[EOS] > 0:
            try:
                ps.ParseSmiles(smi, partial=False)
            except ps.Error:
                disallowed.add(EOS)
    return disallowed

def generate(all_freqs, prefix_length, temperature, avoid_invalid):
    seq = [SOS] * prefix_length
    i = 0
    while True:
        # Pack the 8-token context into a single integer key, 6 bits per token,
        # leaving the lowest 6 bits free for the candidate next token
        idx = sum(seq[i+k] << (6*(k+1)) for k in range(8))
        freqs = [all_freqs.get(idx+j, 0) for j in TOKEN_IDXS]
        if avoid_invalid:
            disallowed_tokens = find_disallowed_tokens(seq[prefix_length:], freqs)
            freqs = [0 if j in disallowed_tokens else freq for j, freq in enumerate(freqs)]
        if all(freq == 0 for freq in freqs): # dead end (or unseen context)
            return []
        probs = np.array(freqs) / np.sum(freqs)
        p = np.power(probs, 1.0/temperature)
        q = p / np.sum(p)
        chosen = np.random.choice(TOKEN_IDXS, 1, p=q)
        seq.append(chosen[0])
        if chosen[0] == EOS:
            break
        i += 1
    return seq[prefix_length:-1]

if __name__ == "__main__":
    PREFIX_LENGTH = 8
    TEMPERATURE = 0.8
    AVOID_INVALID = True

    with open(f"../nbu/probabilities.{PREFIX_LENGTH}.2.pkl", "rb") as inp:
        all_freqs = pickle.load(inp)
        print("Loaded...")

    ofile = f"../nbu/out.{PREFIX_LENGTH}_{TEMPERATURE}_{AVOID_INVALID}.smi"
    with open(ofile, "w") as out:
        for i in tqdm.tqdm(range(1000)):
            tokens = generate(all_freqs, PREFIX_LENGTH, TEMPERATURE, AVOID_INVALID)
            smi = "".join(TOKENS[x] for x in tokens)
            out.write(f"{smi}\n")