The following is a relatively short document describing how Oxford Nanopore Technologies' program for consensus calling of sequencing data, medaka, functions internally. We will demonstrate the core functionality required to process alignment data, how it is presented to a recurrent neural network, and how a consensus sequence is formed.
Before anything else we will create and set a working directory:
from epi2melabs import ping
tutorial_name = "medaka_walkthrough"
pinger = ping.Pingu()
# create a work directory and move into it
working_dir = '/epi2melabs/{}/'.format(tutorial_name)
!mkdir -p "$working_dir"
%cd "$working_dir"
/epi2melabs/medaka_walkthrough
As input the core medaka algorithm accepts sequencing reads aligned to an assembly sequence. If you have run the medaka_consensus pipeline you will have given as input an assembly sequence and your sequencing data; the pipeline simply runs minimap2 to calculate alignments of the reads to the assembly.
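For reference, an equivalent alignment step could be performed by hand roughly as follows (the file names and flags here are illustrative rather than the pipeline's exact invocation):
!minimap2 -ax map-ont draft_assembly.fasta reads.fastq.gz \
    | samtools sort -o calls_to_draft.bam - \
    && samtools index calls_to_draft.bam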
For the purposes of this demonstration we will download pre-aligned data from an R9.4.1 MinION sequencing run:
!wget https://ont-research.s3-eu-west-1.amazonaws.com/datasets/r941_zymo/references.fasta \
&& wget https://ont-research.s3-eu-west-1.amazonaws.com/labs_resources/misc/saureus.bam \
&& wget https://ont-research.s3-eu-west-1.amazonaws.com/labs_resources/misc/saureus.bam.bai \
&& wget https://ont-research.s3-eu-west-1.amazonaws.com/labs_resources/misc/saureus_canu.fasta
The downloaded saureus.bam file contains alignments of sequencing reads to the downloaded saureus_canu.fasta assembly. The depth of sequencing has been reduced to around 150-fold coverage of the genome.
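As a quick sanity check of that figure (this snippet is illustrative and not part of medaka) we can estimate the mean coverage ourselves using pysam:
import pysam
# total aligned bases divided by contig length gives approximate coverage
bam = pysam.AlignmentFile('saureus.bam')
contig = 'tig00000061'  # the contig analysed in the rest of this walkthrough
aligned_bases = sum(
    read.query_alignment_length for read in bam.fetch(contig))
print("approx. coverage: {:.0f}x".format(
    aligned_bases / bam.get_reference_length(contig)))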
The first step of medaka's calculation is to parse the alignment data into a base counts table ready for input to the neural network. In this section we explore the functions responsible for doing this, how exactly the counting is performed, and what the results may represent.
At the heart of medaka resides a straight-forward base-counting procedure. From the alignment data comparing sequencing reads to the reference sequence a pileup is created, much like the display an alignment viewer such as IGV would show. The pileup is summarised by counting the different base types contained within its columns. The function responsible for this counting exercise is called pileup_counts and lives in the features module:
from medaka.features import pileup_counts
help(pileup_counts)
Help on function pileup_counts in module medaka.features:

pileup_counts(region, bam, dtype_prefixes=None, region_split=100000, workers=8, tag_name=None, tag_value=None, keep_missing=False, num_qstrat=1, weibull_summation=False, read_group=None)
    Create pileup counts feature array for region.

    :param region: `medaka.common.Region` object
    :param bam: .bam file with alignments.
    :param dtype_prefixes: prefixes for query names which to separate counts.
        If `None` (or of length 1), counts are not split.
    :param region_split: largest region to process in single thread.
        Regions are processed in parallel and stitched before being returned.
    :param workers: worker threads for calculating pileup.
    :param tag_name: two letter tag name by which to filter reads.
    :param tag_value: integer value of tag for reads to keep.
    :param keep_missing: whether to keep reads when tag is missing.
    :param num_qstrat: number of layers for qscore stratification.
    :param weibull_summation: use a Weibull partial-counts approach,
        requires 'WL' and 'WK' float-array tags.

    :returns: iterator of tuples
        (pileup counts array, reference positions, insertion positions)
        Multiple chunks are returned if there are discontinuities in
        positions caused e.g. by gaps in coverage.
The pileup_counts function above has various arguments, most of which are advanced options and are not used within the default operation of medaka. To create a counts matrix we call the function with a samtools-style region string (medaka uses 0-based, end-exclusive co-ordinates) and a filepath to our alignment file:
from timeit import default_timer as now
from medaka.common import Region
t0 = now()
region = Region.from_string('tig00000061:0-1499707')
bam_file = 'saureus.bam'
pileup_data = pileup_counts(region, bam_file)
pileup_data = pileup_data[0] # implementation detail that need not trouble us
counts, positions = pileup_data
t1 = now()
print("{:.2f}s to form pileup counts.".format(t1 - t0))
1.66s to form pileup counts.
The pileup_counts function returned two structures. The second of these is a positions table, which records which pileup columns correspond to reference positions and which are caused by inserted bases in one or more reads:
display(positions)
array([( 0, 0), ( 1, 0), ( 2, 0), ..., (1499704, 0), (1499705, 0), (1499706, 0)], dtype=[('major', '<i8'), ('minor', '<i8')])
The minor field in the above array distinguishes reference and insertion columns: it takes the value 0 for a reference position and counts upwards through any following insertion columns. The major field keeps track of the reference base co-ordinate.
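To make the indexing concrete, the snippet below (an illustration, not part of medaka) locates the insertion columns and prints the entries around the first of them; note the major index repeating while minor counts upward:
import numpy as np
# indices of pileup columns arising from insertions in one or more reads
insertion_cols = np.where(positions['minor'] > 0)[0]
print("{} of {} pileup columns are insertions.".format(
    len(insertion_cols), len(positions)))
# the first insertion column together with its immediate neighbours
first = insertion_cols[0]
print(positions[first - 1:first + 2])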
The base counts themselves from the alignment pileup are stored separately:
display(counts.shape)
display(counts)
(3508694, 10)
array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 4, 0, 0],
       ...,
       [0, 0, 0, ..., 2, 0, 0],
       [0, 0, 0, ..., 2, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=uint64)
The matrix is of shape (# pileup columns, 10): each row of the matrix corresponds to the counts of bases and gaps in one pileup column (yes, the rows and columns get confusing). There are 10 entries, one each for the four base types and gap, multiplied by two as reads on the forward and reverse strands are counted separately. The ordering of the entries is given by:
from medaka.features import libmedaka
ffi, lib = libmedaka.ffi, libmedaka.lib
plp_bases = lib.plp_bases
codes = ffi.string(plp_bases).decode()
display(','.join(codes))
'a,c,g,t,A,C,G,T,d,D'
in which lower-case letters denote reverse-strand counts (upper-case, forward) and 'd' and 'D' count deletions. A point of note is that this counting strategy in itself makes a distinction between bases which are deleted in reads with respect to the reference sequence and bases which are deleted in reads with respect to other reads (the bases in the other reads being insertions with respect to the reference). Previous versions of medaka have performed a symmetrization here by adding in deletion counts for all reads that span a pileup column, whether that pileup column is a reference position (minor=0) or an insertion column (minor>0).
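To make this ordering concrete, we can label the counts of a single pileup column, for example the third column (index 2), which the display above showed to contain non-zero counts:
# pair each count in one pileup column with its base/strand code
print(dict(zip(codes, counts[2])))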
After obtaining the base-counts matrix produced in the section above, medaka performs a normalization of the counts. Across the pileup columns, all count vectors sharing the same major position index are normalized by the total count for the column with minor=0 (the reference position). This choice of normalization accounts for the lack of symmetry described above, and for the fact that whilst consensus insertions are typically rare, isolated insertions may still occur within any one read spanning two reference positions. Typically there are up to around three pileup columns for every input reference position.
Ordinarily this normalization is performed in a post-processing method of the CountsFeatureEncoder class; for the purposes of exposition, the operation under normal behaviour is:
import numpy as np
minor_inds = np.where(positions['minor'] > 0)
major_pos_at_minor_inds = positions['major'][minor_inds]
major_ind_at_minor_inds = np.searchsorted(
positions['major'], major_pos_at_minor_inds, side='left')
depth = np.sum(counts, axis=1)
depth[minor_inds] = depth[major_ind_at_minor_inds]
feature_array = counts / np.maximum(1, depth).reshape((-1, 1))
display(feature_array)
array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 1., 0., 0.],
       ...,
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])
The normalization is across all bases and is not split by strand; splitting the normalization by strand would potentially lose important information with respect to strand bias and relative errors. The plot below visualizes the final input to the neural network used in medaka.
# Plot feature array (click play)
import aplanat
from bokeh.plotting import figure
from bokeh.models import Range1d
# select just a region to plot
reg = slice(3925,3960)
pdata = feature_array[reg].transpose()
ppos = positions[reg]
# create a figure
p = figure(
title="Base counts",
plot_height=300, plot_width=600)
img = np.zeros(pdata.shape, dtype=np.uint32)
view = img.view(dtype=np.uint8).reshape(pdata.shape + (4,))
# set row colours: A blue, C red, G green, T yellow
cols = {'A': {2}, 'C':{0}, 'G':{1}, 'T':{0,1}}
for k, v in cols.items():
where = [x.upper() == k for x in codes]
for i in v:
view[where,:,i] = 255
# use data as transparency mask
view[:,:,3] = np.minimum(255, 8*255 * pdata)
p.x_range.range_padding = p.y_range.range_padding = 0
p.image_rgba(image=[img], x=0, y=0, dw=pdata.shape[1], dh=pdata.shape[0])
ylabels = np.arange(0.5,10.5)
p.yaxis.ticker = ylabels
p.yaxis.major_label_overrides = dict(zip(ylabels, codes))
p.y_range = Range1d(
start=0, end=10,
bounds=(0, 10))
xlabels = np.arange(0.5, pdata.shape[1])
p.xaxis.ticker = xlabels
p.xaxis.major_label_overrides = dict(zip(
xlabels, ('{}.{}'.format(x['major'], x['minor']) for x in ppos)
))
p.xaxis.major_label_orientation = 3.14/2
aplanat.show(p, background="#F4F4F4")
Having counted bases in an alignment pileup, medaka proceeds to analyse these counts using a recurrent neural network (RNN). A full discussion of such algorithms is beyond the scope of this walkthrough; this section demonstrates their use in calculating a consensus sequence from the base counts array. When medaka is used as a variant caller, different methods are used.
from pkg_resources import resource_filename
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, GRU, Bidirectional
from medaka.labels import BaseLabelScheme
# parameters of the model
gru_size = 128
time_steps, feature_len = (1000, counts.shape[1])
symbols = BaseLabelScheme.symbols # [-, A, C, G, T]
num_classes = len(symbols)
# build the model
model = Sequential(name='medaka')
input_shape = (time_steps, feature_len)
for i in range(2):
gru = GRU(gru_size, return_sequences=True, name="gru_{}".format(i))
model.add(Bidirectional(gru, input_shape=input_shape))
model.add(Dense(
num_classes, activation='softmax', name='classify',
input_shape=(time_steps, 2 * gru_size)))
model.summary()
Model: "medaka"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
bidirectional_2 (Bidirection (None, 1000, 256)         107520
_________________________________________________________________
bidirectional_3 (Bidirection (None, 1000, 256)         296448
_________________________________________________________________
classify (Dense)             (None, 1000, 5)           1285
=================================================================
Total params: 405,253
Trainable params: 405,253
Non-trainable params: 0
_________________________________________________________________
The model takes the counts matrix as input and outputs, for each corresponding column of the counts matrix, a set of five scores. The five scores express the probability that the consensus sequence should contain one of the four bases A, C, G, or T, or a gap ('-') character at the pileup column under consideration.
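As a quick illustration of this output structure (a check for exposition only, using all-zero dummy input) we can run the untrained model and inspect its predictions:
# one batch of dummy input: (batch size, pileup columns, features)
dummy = np.zeros((1, time_steps, feature_len), dtype=np.float32)
scores = model.predict(dummy)
print(scores.shape)        # (1, 1000, 5): five scores per pileup column
print(scores[0, 0].sum())  # softmax output, so each score vector sums to 1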
In order to make predictions using the RNN, medaka splits the normalized counts array into overlapping chunks before processing by the model. Chunking the array allows for more efficient parallel computation, while overlapping mitigates edge-effects at the boundaries of chunks.
As mentioned above, medaka has a somewhat elaborate system for managing data chunks; for the purposes of exposition, the code below implements a simple chunking and batching of the data.
from functools import partial
from medaka.common import sliding_window, grouper
# load an actual medaka model
from medaka import models
model = models.open_model("r941_min_high_g360").load_model()
# create a function to perform windowing on an array
overlap = 200
window = partial(
sliding_window,
window=time_steps, step=time_steps - overlap, axis=0)
# run the network on input data
def get_predictions(data, batch_size=40):
for batch in grouper(data, batch_size=batch_size):
batch = np.stack(batch)
results = model.predict_on_batch(batch)
yield from results
t0 = now()
predictions = get_predictions(window(feature_array))
seq_chunks = list()
for pred in predictions:
# remove half the overlapping region of chunks
pred = pred[overlap // 2:-overlap // 2]
# find the most likely base at each position and form the sequence
mp = np.argmax(pred, -1)
seq = ''.join((symbols[x] for x in mp))
seq = seq.replace('*', '')
seq_chunks.append(seq)
sequence = ''.join(seq_chunks)
t1 = now()
print("{:.2f} to run predictions".format(t1 - t0))
print("Total sequence length: {}.".format(len(sequence)))
113.92s to run predictions.
Total sequence length: 1500539.
The code performs a simple undoing of the overlapping before stitching the consensus sequence pieces back together. This is sufficient to obtain results here; the full medaka implementation also keeps track of the positions array to ensure the sequence stitching is performed correctly with respect to the original input reference sequence.
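As a sketch of what such position-aware stitching might look like (an illustration under simplifying assumptions, not medaka's actual implementation), two overlapping chunks could be spliced at a shared (major, minor) co-ordinate rather than at a fixed array offset:
def splice_chunks(chunk_a, pos_a, chunk_b, pos_b):
    """Join two overlapping prediction chunks at a common pileup position.

    `pos_a` and `pos_b` are the slices of the `positions` array
    corresponding to each chunk.
    """
    # choose a cut point inside the overlap of the two chunks
    cut = pos_a[len(pos_a) - overlap // 2]
    # keep chunk_a up to and including the cut position...
    keep_a = (pos_a['major'] < cut['major']) | (
        (pos_a['major'] == cut['major']) & (pos_a['minor'] <= cut['minor']))
    # ...and chunk_b strictly after it
    keep_b = (pos_b['major'] > cut['major']) | (
        (pos_b['major'] == cut['major']) & (pos_b['minor'] > cut['minor']))
    return np.concatenate([chunk_a[keep_a], chunk_b[keep_b]])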
output = "output.fasta"
with open(output, 'w') as fh:
fh.write(">seq\n{}\n".format(sequence))
for fname in ("saureus_canu.fasta", output):
print("Analysing: {}.".format(fname))
!assess_assembly -r references.fasta -i "$fname" 2>/dev/null
print("\n")
Analysing: saureus_canu.fasta.
Writing list of indels 100 bases and longer to assm_indel_ge100.txt.
#  Percentage Errors
     name     mean      q10      q50      q90
  err_ont   0.133%   0.056%   0.068%   0.290%
  err_bal   0.133%   0.056%   0.068%   0.290%
     iden   0.004%   0.000%   0.001%   0.024%
      del   0.117%   0.049%   0.059%   0.166%
      ins   0.012%   0.003%   0.006%   0.046%

#  Q Scores
     name    mean     q10     q50     q90
  err_ont   28.76   32.54   31.65   25.38
  err_bal   28.76   32.54   31.65   25.38
     iden   44.39     inf   50.00   36.20
      del   29.30   33.10   32.29   27.79
      ins   39.20   45.23   42.22   33.41

All done, output written to assm_stats.txt, assm_summ.txt and assm_indel_ge100.txt


Analysing: output.fasta.
Writing list of indels 100 bases and longer to assm_indel_ge100.txt.
#  Percentage Errors
     name     mean      q10      q50      q90
  err_ont   0.014%   0.008%   0.015%   0.023%
  err_bal   0.014%   0.008%   0.015%   0.023%
     iden   0.001%   0.000%   0.000%   0.003%
      del   0.010%   0.005%   0.010%   0.016%
      ins   0.004%   0.001%   0.003%   0.007%

#  Q Scores
     name    mean     q10     q50     q90
  err_ont   38.40   40.97   38.24   36.47
  err_bal   38.40   40.97   38.24   36.47
     iden   50.30     inf     inf   46.02
      del   40.21   42.60   40.00   37.95
      ins   43.98   48.24   44.56   41.55

All done, output written to assm_stats.txt, assm_summ.txt and assm_indel_ge100.txt
In this short walkthrough we have examined some of the internals of Oxford Nanopore Technologies' medaka program, which performs GPU-accelerated consensus calculations from aligned sequencing data. The public medaka codebase implements various alternative forms of the algorithms presented here, including run-length compression and support for multiple data types. Hopefully this guide will prove useful to anyone wishing to implement algorithms similar to those in medaka.