How medaka works

The following is a short document describing how medaka, Oxford Nanopore Technologies' program for calling consensus sequences from sequencing data, works internally. We demonstrate the core functionality required to process alignment data, how the processed data are presented to a recurrent neural network, and how a consensus sequence is formed.

Getting started

Before anything else we will create and set a working directory:

In [1]:
from epi2melabs import ping
tutorial_name = "medaka_walkthrough"
pinger = ping.Pingu()

# create a work directory and move into it
working_dir = '/epi2melabs/{}/'.format(tutorial_name)
!mkdir -p "$working_dir"
%cd "$working_dir"
/epi2melabs/medaka_walkthrough

Medaka's input

As input, the core medaka algorithm accepts sequencing reads aligned to an assembly (draft) sequence. If you have run the medaka_consensus pipeline, you will have provided an assembly sequence and your sequencing data as input; the pipeline simply runs minimap2 to align the reads to the assembly.

For the purposes of this demonstration we will download pre-aligned data from an R9.4.1 MinION sequencing run:

In [ ]:
!wget https://ont-research.s3-eu-west-1.amazonaws.com/datasets/r941_zymo/references.fasta \
    && wget https://ont-research.s3-eu-west-1.amazonaws.com/labs_resources/misc/saureus.bam \
    && wget https://ont-research.s3-eu-west-1.amazonaws.com/labs_resources/misc/saureus.bam.bai \
    && wget https://ont-research.s3-eu-west-1.amazonaws.com/labs_resources/misc/saureus_canu.fasta

The downloaded saureus.bam file contains alignments of sequencing reads to the downloaded saureus_canu.fasta. The depth of sequencing has been reduced to around 150-fold coverage of the genome.
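As a quick sanity check on the stated depth, mean coverage can be estimated as the total reference span of the alignments divided by the contig length. The sketch below assumes pysam is installed (it is not otherwise used in this walkthrough); mean_depth and bam_mean_depth are hypothetical helpers, not part of medaka, and tig00000061 is the contig analysed later in this walkthrough.

```python
import os


def mean_depth(aligned_spans, reference_length):
    """Mean coverage: total reference bases spanned by alignments / reference length."""
    return sum(aligned_spans) / reference_length


def bam_mean_depth(bam_path, contig):
    """Estimate the mean depth over `contig` from an indexed BAM file."""
    import pysam  # assumed installed; imported lazily so mean_depth works without it

    with pysam.AlignmentFile(bam_path, "rb") as bam:
        spans = (
            read.reference_length
            for read in bam.fetch(contig)
            # skip secondary alignments, which would count a read twice
            if not (read.is_unmapped or read.is_secondary)
        )
        return mean_depth(spans, bam.get_reference_length(contig))


# Only runs once the walkthrough's files have been downloaded.
if os.path.exists("saureus.bam") and os.path.exists("saureus.bam.bai"):
    print("{:.0f}-fold".format(bam_mean_depth("saureus.bam", "tig00000061")))
```

Supplementary alignments are kept because they cover different parts of the reference from the primary alignment of the same read.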

Diving in: counting bases

The first step of medaka's calculation is to parse the alignment data into a base counts table ready for input to the neural network. In this section we explore the functions responsible for doing this, how exactly counting is performed and what the results may represent.

Pileup interface

At the heart of medaka is a straightforward base-counting procedure. From the alignments of sequencing reads to the reference sequence, medaka builds a pileup, much like the view an alignment viewer such as IGV displays.

The pileup is summarised by counting the different base types contained within each of its columns. The function responsible for this counting exercise is called pileup_counts, in the features module:

In [3]:
from medaka.features import pileup_counts
help(pileup_counts)
Help on function pileup_counts in module medaka.features:

pileup_counts(region, bam, dtype_prefixes=None, region_split=100000, workers=8, tag_name=None, tag_value=None, keep_missing=False, num_qstrat=1, weibull_summation=False, read_group=None)
    Create pileup counts feature array for region.
    
    :param region: `medaka.common.Region` object
    :param bam: .bam file with alignments.
    :param dtype_prefixes: prefixes for query names which to separate counts.
        If `None` (or of length 1), counts are not split.
    :param region_split: largest region to process in single thread.
        Regions are processed in parallel and stitched before being returned.
    :param workers: worker threads for calculating pileup.
    :param tag_name: two letter tag name by which to filter reads.
    :param tag_value: integer value of tag for reads to keep.
    :param keep_missing: whether to keep reads when tag is missing.
    :param num_qstrat: number of layers for qscore stratification.
    :param weibull_summation: use a Weibull partial-counts approach,
        requires 'WL' and 'WK' float-array tags.
    
    :returns: iterator of tuples
        (pileup counts array, reference positions, insertion positions)
        Multiple chunks are returned if there are discontinuities in
        positions caused e.g. by gaps in coverage.

The pileup_counts function above has various arguments, most of which are advanced options not used in medaka's default operation. To create a counts matrix we call the function with a medaka Region object, parsed here from a Samtools-style region string (name:start-end), and the path to our alignment file. Note that medaka interprets the co-ordinates as 0-based and end-exclusive, unlike Samtools' 1-based, inclusive convention:

In [4]:
from timeit import default_timer as now
from medaka.common import Region

t0 = now()
region = Region.from_string('tig00000061:0-1499707')
bam_file = 'saureus.bam'
pileup_data = pileup_counts(region, bam_file)
pileup_data = pileup_data[0]  # implementation detail that need not trouble us
counts, positions = pileup_data
t1 = now()
print("{:.2f}s to form pileup counts.".format(t1 - t0))
1.66s to form pileup counts.
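The difference between the two co-ordinate conventions is easy to get wrong, so a small sketch may help. parse_region and to_samtools are hypothetical helpers written for illustration, not medaka's Region API: they split a 0-based, end-exclusive region string and convert it to the equivalent 1-based, inclusive Samtools string.

```python
def parse_region(region_string):
    """Split 'name:start-end' (0-based, end-exclusive) into its parts.

    A bare contig name returns (name, None, None), meaning the whole contig.
    """
    name, sep, span = region_string.rpartition(":")
    if not sep:
        return region_string, None, None
    start, end = (int(x) for x in span.split("-"))
    return name, start, end


def to_samtools(name, start, end):
    """Convert 0-based half-open co-ordinates to a 1-based inclusive string."""
    return "{}:{}-{}".format(name, start + 1, end)


reg_name, reg_start, reg_end = parse_region("tig00000061:0-1499707")
print(reg_end - reg_start)                         # -> 1499707 reference positions
print(to_samtools(reg_name, reg_start, reg_end))   # -> tig00000061:1-1499707
```

Because the end is exclusive, the last reference position in this region is 1499706.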

The counts matrix

The pileup_counts function returned two structures. The second of these is a positions table, which records which pileup columns are reference positions and which arise from bases inserted in one or more reads:

In [5]:
display(positions)
array([(      0, 0), (      1, 0), (      2, 0), ..., (1499704, 0),
       (1499705, 0), (1499706, 0)],
      dtype=[('major', '<i8'), ('minor', '<i8')])

The minor field in the above array distinguishes reference columns from insertion columns: it is 0 for a reference position and counts upwards (1, 2, ...) through the insertion columns that follow it. The major field records the reference co-ordinate to which each column belongs.
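To make the major/minor scheme concrete, the hypothetical helper below builds a positions array with the same dtype as the one displayed above, for a toy reference of four bases where reads carry up to two inserted bases after reference position 1. The number of insertion columns after each position is simply passed in here; presumably in a real pileup it is set by the longest insertion any read has at that position.

```python
import numpy as np


def build_positions(ref_len, insertion_columns):
    """Build a (major, minor) positions array for a toy pileup.

    insertion_columns maps a reference position to the number of insertion
    columns that follow it (hypothetical input for illustration).
    """
    rows = []
    for major in range(ref_len):
        rows.append((major, 0))  # the reference column itself
        for minor in range(1, insertion_columns.get(major, 0) + 1):
            rows.append((major, minor))  # insertion columns after it
    return np.array(rows, dtype=[("major", "<i8"), ("minor", "<i8")])


toy_positions = build_positions(4, {1: 2})
print(toy_positions.tolist())
# -> [(0, 0), (1, 0), (1, 1), (1, 2), (2, 0), (3, 0)]
print(np.count_nonzero(toy_positions["minor"] == 0))  # -> 4 reference columns
```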

The base counts themselves from the alignment pileup are stored separately:

In [6]:
display(counts.shape)
display(counts)
(3508694, 10)
array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 4, 0, 0],
       ...,
       [0, 0, 0, ..., 2, 0, 0],
       [0, 0, 0, ..., 2, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=uint64)

The matrix has shape (# pileup columns, 10): each row of the matrix holds the counts of bases and deletions in one pileup column (yes, the rows and columns get confusing). The 10 entries are one each for the four base types and for deletions, multiplied by two because reads on the forward and reverse strands are counted separately. The ordering of the entries is given by:

In [7]:
from medaka.features import libmedaka
ffi, lib = libmedaka.ffi, libmedaka.lib
plp_bases = lib.plp_bases
codes = ffi.string(plp_bases).decode()
display(','.join(codes))
'a,c,g,t,A,C,G,T,d,D'

in which lower-case letters denote reverse-strand counts (upper case, forward) and 'd' and 'D' count deletions. Note that this counting strategy itself distinguishes between bases deleted in reads with respect to the reference sequence and bases deleted in reads with respect to other reads (the bases in the other reads being insertions with respect to the reference): only the former are counted as deletions, so a read lacking a base that other reads insert contributes nothing to that insertion column. Previous versions of medaka performed a symmetrization here, adding deletion counts for all reads that span a pileup column, whether that pileup column is a reference position (minor=0) or an insertion column (minor>0).
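The counting itself can be sketched in a few lines. The toy function below is a simplification, not medaka's implementation (which is implemented in C and accessed through the libmedaka bindings used above): it takes reads already expressed as gapped strings against the reference, ignores insertions, and follows the column order shown above with lower case for reverse-strand reads. Mapping reverse-strand deletions to 'd' and forward-strand deletions to 'D' is an assumption that follows the same case convention.

```python
import numpy as np

CODES = "acgtACGTdD"  # column order reported by plp_bases above


def toy_pileup_counts(reads, ref_len):
    """Count bases and deletions per reference column, split by strand.

    Each read is (start, gapped_seq, is_reverse), where gapped_seq is already
    aligned to the reference with '-' marking a deleted reference base.
    Insertions are ignored in this simplified sketch.
    """
    counts = np.zeros((ref_len, len(CODES)), dtype=np.uint64)
    for start, seq, is_reverse in reads:
        for offset, base in enumerate(seq):
            symbol = "D" if base == "-" else base.upper()
            if is_reverse:
                symbol = symbol.lower()  # lower case = reverse strand
            counts[start + offset, CODES.index(symbol)] += 1
    return counts


toy_reads = [(0, "ACGT", False), (1, "C-T", True), (0, "AC", True)]
toy_counts = toy_pileup_counts(toy_reads, ref_len=4)
print(toy_counts[1])  # -> [0 2 0 0 0 1 0 0 0 0]: one forward C, two reverse c
```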

Normalization

Having obtained the base-counts matrix, medaka normalizes the counts. All count vectors (pileup columns) sharing the same major position are divided by the total count of the column with minor=0, that is, the reference position. This choice of normalization accounts for the asymmetry described above, and for the fact that, whilst consensus insertions are typically rare, isolated insertions may still occur within any one read spanning two reference positions. There can be on average up to around three pileup columns for every input reference position; in this dataset there are about 2.3 (3,508,694 columns for 1,499,707 reference positions).

Ordinarily this normalization is performed in a post-processing method of the CountsFeatureEncoder class. For the purposes of exposition, the operation under normal behaviour is:

In [8]:
import numpy as np
minor_inds = np.where(positions['minor'] > 0)
major_pos_at_minor_inds = positions['major'][minor_inds]
major_ind_at_minor_inds = np.searchsorted(
    positions['major'], major_pos_at_minor_inds, side='left')

depth = np.sum(counts, axis=1)
depth[minor_inds] = depth[major_ind_at_minor_inds]

feature_array = counts / np.maximum(1, depth).reshape((-1, 1))
display(feature_array)
array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 1., 0., 0.],
       ...,
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])
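The same normalization can be checked on a toy pileup small enough to inspect by hand. The sketch below wraps the code above in a hypothetical normalize_counts function and applies it to four spanning reads (two per strand), one of which carries an extra forward-strand A after reference position 1; the toy arrays are invented for illustration.

```python
import numpy as np

CODES = "acgtACGTdD"  # column order reported by plp_bases above


def normalize_counts(counts, positions):
    """Divide each pileup column by the depth of its reference (minor == 0) column."""
    depth = counts.sum(axis=1)
    minor_inds = np.flatnonzero(positions["minor"] > 0)
    # the first index with the same major position is the reference column
    ref_inds = np.searchsorted(
        positions["major"], positions["major"][minor_inds], side="left")
    depth[minor_inds] = depth[ref_inds]
    return counts / np.maximum(1, depth).reshape((-1, 1))


example_positions = np.array(
    [(0, 0), (1, 0), (1, 1), (2, 0)],
    dtype=[("major", "<i8"), ("minor", "<i8")])
example_counts = np.zeros((len(example_positions), len(CODES)), dtype=np.uint64)
example_counts[0, [CODES.index("A"), CODES.index("a")]] = 2
example_counts[1, [CODES.index("C"), CODES.index("c")]] = 2
example_counts[2, CODES.index("A")] = 1  # the single inserted base
example_counts[3, CODES.index("G")] = 2
example_counts[3, CODES.index("g")] = 1
example_counts[3, CODES.index("d")] = 1  # one reverse read deletes position 2

normalized = normalize_counts(example_counts, example_positions)
print(normalized.sum(axis=1).tolist())  # -> [1.0, 1.0, 0.25, 1.0]
```

The insertion column sums to 0.25 rather than 1: one in four spanning reads carries the inserted base, and the remaining reads are not recorded as deletions there. Because the denominator combines both strands, the forward and reverse fractions (0.5 each in the first column) still show how evenly each strand supports a call.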

The normalization is across all bases and is not split by strand; splitting it by strand could lose important information about strand bias and relative errors. The plot below visualizes the final input to the neural network used in medaka.

In [18]:
# Plot a section of the feature array
import aplanat
from bokeh.plotting import figure
from bokeh.models import Range1d


# select just a region to plot
reg = slice(3925,3960)
pdata = feature_array[reg].transpose()
ppos = positions[reg]
# create a figure
p = figure(
    title="Base counts",
    plot_height=300, plot_width=600)

img = np.zeros(pdata.shape, dtype=np.uint32)
view = img.view(dtype=np.uint8).reshape(pdata.shape + (4,))
# set row colours: A blue, C red, G green, T yellow
cols = {'A': {2}, 'C':{0}, 'G':{1}, 'T':{0,1}}
for k, v in cols.items():
    where = [x.upper() == k for x in codes]
    for i in v:
        view[where,:,i] = 255
# use data as transparency mask
view[:,:,3] = np.minimum(255, 8*255 * pdata)

p.x_range.range_padding = p.y_range.range_padding = 0
p.image_rgba(image=[img], x=0, y=0, dw=pdata.shape[1], dh=pdata.shape[0])
ylabels = np.arange(0.5,10.5)
p.yaxis.ticker = ylabels
p.yaxis.major_label_overrides = dict(zip(ylabels, codes))
p.y_range = Range1d(
    start=0, end=10,
    bounds=(0, 10))
xlabels = np.arange(0.5, pdata.shape[1])
p.xaxis.ticker = xlabels
p.xaxis.major_label_overrides = dict(zip(
    xlabels, ('{}.{}'.format(x['major'], x['minor']) for x in ppos)
))
p.xaxis.major_label_orientation = 3.14/2
aplanat.show(p, background="#F4F4F4")

The neural network

Having counted bases in an alignment pileup, medaka analyses these counts using a Recurrent Neural Network (RNN). A full discussion of such algorithms is beyond the scope of this walkthrough; this section demonstrates their use in calculating a consensus sequence from the base counts array. When medaka is used as a variant caller, different methods are used.

The model

To construct a consensus sequence, medaka uses a multi-layer bidirectional RNN, defined using the Keras API in TensorFlow. The following code is adapted from medaka's models module and has been simplified to show only the essential parts:

In [13]:
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, GRU, Bidirectional

from medaka.labels import BaseLabelScheme

# parameters of the model
gru_size = 128
time_steps, feature_len = (1000, counts.shape[1])
symbols = BaseLabelScheme.symbols  # gap '*', then A, C, G, T
num_classes = len(symbols)

# build the model
model = Sequential(name='medaka')
input_shape = (time_steps, feature_len)
for i in range(2):
    gru = GRU(gru_size, return_sequences=True, name="gru_{}".format(i))
    model.add(Bidirectional(gru, input_shape=input_shape))
model.add(Dense(
    num_classes, activation='softmax', name='classify',
    input_shape=(time_steps, 2 * gru_size)))

model.summary()
Model: "medaka"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
bidirectional_2 (Bidirection (None, 1000, 256)         107520    
_________________________________________________________________
bidirectional_3 (Bidirection (None, 1000, 256)         296448    
_________________________________________________________________
classify (Dense)             (None, 1000, 5)           1285      
=================================================================
Total params: 405,253
Trainable params: 405,253
Non-trainable params: 0
_________________________________________________________________
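The parameter counts in the summary can be reproduced by hand, which is a useful check that the model is wired as described. A Keras GRU layer has three gates, each with an input kernel, a recurrent kernel and, with reset_after=True (assumed here; it is the TensorFlow 2 default), separate input and recurrent bias vectors, giving 3 × (d_in·h + h·h + 2h) parameters for input width d_in and h units. Bidirectional doubles this, and the Dense layer has one weight per input-output pair plus one bias per class. The helper functions are hypothetical, not part of medaka or Keras.

```python
def gru_params(input_dim, units, reset_after=True):
    """Trainable parameters of one Keras GRU layer (three gates)."""
    biases = 2 * units if reset_after else units  # input + recurrent biases
    return 3 * (input_dim * units + units * units + biases)


def bidirectional_gru_params(input_dim, units):
    """A Bidirectional wrapper holds one forward and one backward GRU."""
    return 2 * gru_params(input_dim, units)


def dense_params(input_dim, outputs):
    """Weights plus one bias per output."""
    return input_dim * outputs + outputs


n_features, units, n_classes = 10, 128, 5
layer1 = bidirectional_gru_params(n_features, units)     # input: 10 count features
layer2 = bidirectional_gru_params(2 * units, units)      # input: both directions
classify_params = dense_params(2 * units, n_classes)
print(layer1, layer2, classify_params, layer1 + layer2 + classify_params)
# -> 107520 296448 1285 405253
```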

The model takes the count matrix as input and outputs, for each column of the counts matrix, a set of five scores. Produced by the final softmax layer, the scores express the probability that the consensus sequence should contain one of the four bases A, C, G or T, or a gap ('*', meaning no base), at the pileup column under consideration.
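Turning these scores into sequence amounts to a per-column argmax followed by removal of gap predictions. The toy sketch below assumes the label order gap ('*') followed by A, C, G, T, consistent with the seq.replace('*', '') in the prediction code later in this walkthrough; decode and the score values are illustrative only.

```python
import numpy as np

SYMBOLS = "*ACGT"  # assumed label order: gap first, then the four bases


def decode(scores, symbols=SYMBOLS, gap="*"):
    """Pick the highest-scoring symbol per pileup column and drop gaps."""
    best = np.argmax(scores, axis=-1)
    return "".join(symbols[i] for i in best).replace(gap, "")


scores = np.array([
    [0.01, 0.95, 0.02, 0.01, 0.01],  # A
    [0.90, 0.04, 0.03, 0.02, 0.01],  # gap, e.g. an insertion column most reads lack
    [0.02, 0.01, 0.03, 0.90, 0.04],  # G
    [0.05, 0.05, 0.80, 0.05, 0.05],  # C
])
print(decode(scores))  # -> AGC
```

A gap wins in columns where most reads have no base, which is how insertion columns supported by few reads, and draft bases that most reads delete, drop out of the consensus.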

Making predictions

To make predictions using the RNN, medaka splits the normalized counts array into overlapping chunks before processing them with the model. Chunking allows efficient, batched parallel computation, while overlapping the chunks mitigates edge effects at their boundaries.

The full medaka program has a somewhat elaborate system for managing data chunks. For the purposes of exposition, the code below implements a simple chunking and batching of the data.

In [14]:
from functools import partial
from medaka.common import sliding_window, grouper

# load an actual medaka model
from medaka import models
model = models.open_model("r941_min_high_g360").load_model()

# create a function to perform windowing on an array
overlap = 200
window = partial(
    sliding_window,
    window=time_steps, step=time_steps - overlap, axis=0)

# run the network on input data
def get_predictions(data, batch_size=40):
    for batch in grouper(data, batch_size=batch_size):
        batch = np.stack(batch)
        results = model.predict_on_batch(batch)
        yield from results

t0 = now()
predictions = get_predictions(window(feature_array))
seq_chunks = list()
for pred in predictions:
    # remove half the overlapping region of chunks
    pred = pred[overlap // 2:-overlap // 2]
    # find the most likely base at each position and form the sequence
    mp = np.argmax(pred, -1)
    seq = ''.join((symbols[x] for x in mp))
    seq = seq.replace('*', '')
    seq_chunks.append(seq)
sequence = ''.join(seq_chunks)
t1 = now()
print("{:.2f} to run predictions".format(t1 - t0))
print("Total sequence length: {}.".format(len(sequence)))
113.92 to run predictions
Total sequence length: 1500539.

The code simply trims half of the overlap from each end of every chunk before joining the consensus pieces back together. This is sufficient to obtain results here; the full medaka implementation also tracks the positions array to ensure the sequence is stitched correctly with respect to the original input reference sequence.
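The simple trimming above discards overlap // 2 columns at the very start and end of the contig and, depending on how sliding_window handles a final partial window, may miss trailing columns too. One way to avoid both, sketched below with hypothetical helpers (this is not medaka's own stitching code), is to add a final window flush with the end of the array and to switch from one window to the next at the midpoint of their overlap. An identity "model", whose output for each chunk is just that chunk's column indices, makes it easy to check that every column appears exactly once.

```python
import numpy as np


def window_starts(n_columns, win_len, step):
    """Start indices of windows of length win_len taken every `step` columns.

    A final window flush with the end of the array is added when the regular
    grid would otherwise leave trailing columns uncovered.
    """
    starts = list(range(0, max(n_columns - win_len, 0) + 1, step))
    if starts[-1] + win_len < n_columns:
        starts.append(n_columns - win_len)
    return starts


def stitch(chunks, starts, win_len, n_columns):
    """Join per-window outputs, switching windows at the midpoint of each overlap."""
    pieces, begin = [], 0
    for i, (chunk, start) in enumerate(zip(chunks, starts)):
        if i + 1 < len(starts):
            end = (starts[i + 1] + start + win_len) // 2  # middle of the overlap
        else:
            end = n_columns  # the last window runs to the end of the array
        pieces.append(chunk[begin - start:end - start])
        begin = end
    return np.concatenate(pieces)


# Identity "model": a correct stitch must reproduce 0..n_cols-1 exactly.
n_cols, win_len, step_len = 2500, 1000, 800
toy_data = np.arange(n_cols)
starts = window_starts(n_cols, win_len, step_len)
chunks = [toy_data[s:s + win_len] for s in starts]
stitched = stitch(chunks, starts, win_len, n_cols)
print(starts, np.array_equal(stitched, toy_data))  # -> [0, 800, 1500] True
```

Splitting at the midpoint keeps each column from the window in which it lies furthest from an edge, which is the point of overlapping in the first place.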

Checking our results

We can write out the full consensus sequence derived above and compare it to the truth sequence by using the assess_assembly program from the pomoxis package. By also examining the original draft sequence, we can see the improvement in quality from medaka:

In [15]:
output = "output.fasta"
with open(output, 'w') as fh:
    fh.write(">seq\n{}\n".format(sequence))
for fname in ("saureus_canu.fasta", output):
    print("Analysing: {}.".format(fname))
    !assess_assembly -r references.fasta -i "$fname" 2>/dev/null
    print("\n")
Analysing: saureus_canu.fasta.
Writing list of indels 100 bases and longer to assm_indel_ge100.txt.
#  Percentage Errors
  name     mean     q10      q50      q90   
err_ont   0.133%   0.056%   0.068%   0.290% 
err_bal   0.133%   0.056%   0.068%   0.290% 
   iden   0.004%   0.000%   0.001%   0.024% 
    del   0.117%   0.049%   0.059%   0.166% 
    ins   0.012%   0.003%   0.006%   0.046% 

#  Q Scores
  name     mean      q10      q50      q90  
err_ont   28.76    32.54    31.65    25.38  
err_bal   28.76    32.54    31.65    25.38  
   iden   44.39      inf    50.00    36.20  
    del   29.30    33.10    32.29    27.79  
    ins   39.20    45.23    42.22    33.41  

All done, output written to assm_stats.txt, assm_summ.txt and assm_indel_ge100.txt


Analysing: output.fasta.
Writing list of indels 100 bases and longer to assm_indel_ge100.txt.
#  Percentage Errors
  name     mean     q10      q50      q90   
err_ont   0.014%   0.008%   0.015%   0.023% 
err_bal   0.014%   0.008%   0.015%   0.023% 
   iden   0.001%   0.000%   0.000%   0.003% 
    del   0.010%   0.005%   0.010%   0.016% 
    ins   0.004%   0.001%   0.003%   0.007% 

#  Q Scores
  name     mean      q10      q50      q90  
err_ont   38.40    40.97    38.24    36.47  
err_bal   38.40    40.97    38.24    36.47  
   iden   50.30      inf      inf    46.02  
    del   40.21    42.60    40.00    37.95  
    ins   43.98    48.24    44.56    41.55  

All done, output written to assm_stats.txt, assm_summ.txt and assm_indel_ge100.txt
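The Q scores reported by assess_assembly are the Phred-scaled form of the percentage errors, Q = -10 log10(p) for an error fraction p, so a Q score of inf means no errors were observed. The small sketch below (phred is a hypothetical helper) reproduces the draft's mean err_ont Q score from its percentage and estimates the improvement from medaka.

```python
import math


def phred(error_rate):
    """Phred-scaled quality for an error rate given as a fraction."""
    return math.inf if error_rate == 0 else -10 * math.log10(error_rate)


draft_err, medaka_err = 0.133 / 100, 0.014 / 100  # mean err_ont from the tables
print(round(phred(draft_err), 2))   # -> 28.76, matching the draft's table entry
print(round(phred(medaka_err), 2))  # -> 38.54; the table's 38.40 presumably uses the unrounded error
print(round(draft_err / medaka_err, 1))  # -> 9.5, i.e. roughly 9.5-fold fewer errors
```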


Remarks

In this short walkthrough we have examined some of the internals of medaka, Oxford Nanopore Technologies' program for GPU-accelerated consensus calculation from aligned sequencing data. The public medaka codebase implements various alternative forms of the algorithms presented here, including run-length compression and support for multiple datatypes. Hopefully this guide will prove useful to anyone wishing to implement algorithms similar to those in medaka.