Accurate prediction of prognostic outcomes is central to delivering quality care. However, concerns about the reliability of the methodology have limited the use of deep learning models to improve caregiving in healthcare. Therefore, models that are robust, support high-throughput prediction of clinical outcomes (e.g., survival), and are reliable and generalizable to larger populations remain in high demand. There has been growing and persistent interest in modeling survival data with deep learning to leverage the promise these models offer. This is not surprising given the significance of the healthcare sector, where we are often interested in understanding, for instance, the role of specific differentially expressed genes in prognosis or, more generally, the likely impact of a given treatment regimen on patient outcomes. In both cases, we must make timely and effective decisions to improve patient outcomes.
Analyzing time-to-event data is a distinctive problem because the outcome of interest may comprise whether or not an event has occurred (a binary outcome) and/or the time at which the event occurs (a continuous outcome) (Feng and Zhao 2021). The problem is further complicated by incomplete information on the survival outcome of interest, known as censored data.¹ Because of censoring, classical analysis methods such as logistic regression cannot be applied directly to survival data.
Additionally, models based on the Weibull distribution make restrictive assumptions, including a parametric form for the time-to-event distribution. The semi-parametric Cox proportional hazards (PH) model (Burden and Winkler 2008) also relies on assumptions. A crucial one is the proportional hazards assumption: “the effect of a unit increase in a covariate is multiplicative with respect to the hazard rate.” However, the outcome of interest is not always a hazard rate; it may be a probability. Moreover, the PH assumption becomes hard to justify when we have a substantial number of covariates, because each of these covariates would need to satisfy it. These methods have also been shown to perform poorly when the underlying model is misspecified (Feng and Zhao 2021).
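As a brief illustration in standard notation (not taken from the original post), the Cox PH model and the multiplicative effect described in the quote above can be written as:

\[ h(t \mid X) = h_0(t)\exp(\beta_1 X_1 + \cdots + \beta_p X_p), \qquad \frac{h(t \mid X_j = x + 1,\ X_{-j})}{h(t \mid X_j = x,\ X_{-j})} = \exp(\beta_j) \]

Here \(h_0(t)\) is an unspecified baseline hazard and \(X_{-j}\) denotes the other covariates, held fixed. The hazard ratio does not depend on \(t\), which is exactly what the PH assumption requires of every covariate. For example, \(\beta_j = 0.5\) implies that a one-unit increase in \(X_j\) multiplies the hazard by \(e^{0.5} \approx 1.65\) at all times.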
So, how can we model survival data effectively? This post reviews an extension of artificial neural networks (ANNs) that is based on the Cox PH model and trained with Bayesian learning. In particular, we use a 2-layer feed-forward ANN trained using Bayesian inference to model survival outcomes and compare this model to the more traditional Cox proportional hazards model. Compared with these traditional models, we expect the ANN trained using Bayesian inference to perform better.
First, we introduce neural networks in a general context before discussing how they are used to model survival data and how Bayesian inference has been incorporated into these models to enhance their predictive capacity. Next, we present an application of the Bayesian learning artificial neural network (BLNN), using the R package of the same name, to model the effect of identified differentially expressed genes on the survival of patients with primary bladder cancer. Lastly, we compare this model to a more traditional neural network for illustrative purposes and point to an extension with Python code.
Given the attention this deep learning method has received in recent times (Hastie et al. 2009), we first give a simplified picture of what a neural network is. Neural networks have been defined as:
“Computer systems with interconnected nodes designed like neurons to mimic the human brain in terms of intelligence. These networks use algorithms to discover hidden data structures and patterns, correlations, clusters, and classify them, learn and improve over time.”
The idea is to take in simple functions as inputs and allow these functions to build upon each other. The models are flexible enough to learn non-linear relationships rather than having them prescribed, as kernels or transformations do. A neural network takes in an input vector of \(p\) features, \(X=(X_1, X_2, \cdots , X_p)\), and builds a non-linear function to forecast an outcome variable, \(Y\). While other flexible statistical models such as Bayesian additive regression trees (BART) and random forests exist, neural networks have a structure that distinguishes them from these methods. Figure 1 shows a feed-forward neural network with an input layer of 4 input features, \(X=(X_1, \cdots, X_4)\), a single hidden layer with 5 nodes, \(A_1, \cdots, A_5\), that apply a non-linear activation function, an output layer, \(f(X)\), and the desired outcome, \(Y\).
The arrows show that the input layer feeds into each node in the hidden layer, which in turn feeds into the output layer and finally the outcome, always in a “forward” direction; hence the name “feed-forward.” A general neural network model with a single hidden layer has the form:
\[\begin{align} f(X) & = \beta_0 + \sum_{k=1}^{K} \beta_k h_k (X)\\ & = \beta_0 + \sum_{k=1}^{K} \beta_k g(w_{k0} + \sum_{j=1}^{p}w_{kj}X_j) \end{align}\]
In the first modeling step, the \(K\) activations in the hidden layer are computed as functions of the features in the input layer as follows:
\[\begin{align} A_k = h_k(X) &= g(w_{k0} + \sum_{j=1}^{p}w_{kj}X_j) \end{align}\] where \(g(z)\) is the specified activation function. The \(K\) activations from the hidden layer then feed into the output layer, so that we have:
\[\begin{align} f(X) & = \beta_0 + \sum_{k=1}^{K} \beta_k A_k \end{align}\]
where \(K = 5\) in Figure 1. The parameters \(\beta_0, \cdots, \beta_K\), as well as \(w_{10}, \cdots, w_{Kp}\), are estimated from the data. Many options exist for the activation function \(g(z)\).² The non-linearity of \(g(z)\) allows the model to capture both complex non-linear structures and interaction effects.
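To make the two equations above concrete, here is a minimal base-R sketch of the forward pass for a single-hidden-layer network. The function name `forward_pass` and its arguments are hypothetical (they do not come from any package), and the sigmoid default for \(g(z)\) is just one possible choice.

```r
# Hypothetical helper: forward pass of a single-hidden-layer network,
# f(X) = beta0 + sum_k beta_k * g(w_k0 + sum_j w_kj * X_j)
forward_pass <- function(x, W, w0, beta, beta0,
                         g = function(z) 1 / (1 + exp(-z))) {
  # x:     numeric vector of p input features
  # W:     K x p matrix of input-to-hidden weights w_kj
  # w0:    length-K vector of hidden-layer biases w_k0
  # beta:  length-K vector of hidden-to-output weights beta_k
  # beta0: output bias
  A <- g(w0 + as.vector(W %*% x))  # hidden activations A_k = h_k(X)
  beta0 + sum(beta * A)            # output f(X)
}

# Example with the dimensions of Figure 1: 4 inputs and 5 hidden nodes
set.seed(1)
p <- 4
K <- 5
W <- matrix(rnorm(K * p), nrow = K, ncol = p)
w0 <- rnorm(K)
beta <- rnorm(K)
x <- c(0.5, -1.2, 0.3, 2.0)
forward_pass(x, W, w0, beta, beta0 = 0.1)
```

In practice the weights are estimated from data rather than drawn at random; the random values here only exercise the computation.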
The BLNN implementation of Bayesian inference in artificial neural networks is based on the Cox proportional hazards (PH) neural network model described by Sharaf and Tsokos (2015). The idea is to build a predictive model for survival using a neural network with \(K\) outputs, where \(K\) is the number of time intervals. Using this neural network architecture, Mani et al. estimated a hazard function in which each individual \(i\), observed up to interval \(t\) (the interval in which the event or censoring occurred), has a training vector of hazard probabilities \((h_{ik})\) defined as: \[ h_{ik}= \begin{cases} 0 & \textrm{if } 1\leq k < t \\ 1 & \textrm{if } t \leq k \leq K \textrm{ and event} = 1 \\ \frac{r_k}{n_k} & \textrm{if } t \leq k \leq K \textrm{ and event} = 0 \end{cases} \] That is, \(h_{ik}=0\) for the intervals before \(t\), in which the patient was still event-free; \(h_{ik}=1\) for the intervals from \(t\) onward if the event occurred in interval \(t\); and \(h_{ik}=\frac{r_k}{n_k}\) for the intervals from \(t\) onward if the subject was censored (e.g., lost to follow-up) in interval \(t\). Here \(\frac{r_k}{n_k}\) is the Kaplan-Meier (KM) hazard estimate for interval \(k\), where \(r_k\) is the number of events due to the risk factor of interest in interval \(k\) and \(n_k\) is the number at risk in interval \(k\). The neural network uses the logistic sigmoid activation function, defined as: \[ \Phi (x) = \frac{1}{1+e^{-x}} \]
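The case definition of \(h_{ik}\) can be sketched in a few lines of base R. The helper `hazard_targets` is hypothetical (it is not part of BLNN), and the event and at-risk counts below are illustrative numbers only.

```r
# Hypothetical helper: training vector of hazard targets h_ik for one individual.
# t_int:     interval (1..K) in which the event or censoring occurred
# event:     1 if the event was observed, 0 if the subject was censored
# km_hazard: length-K vector of Kaplan-Meier hazards r_k / n_k
hazard_targets <- function(t_int, event, km_hazard) {
  K <- length(km_hazard)
  h <- numeric(K)          # h_ik = 0 for intervals k < t
  after <- seq(t_int, K)   # intervals t <= k <= K
  h[after] <- if (event == 1) 1 else km_hazard[after]
  h
}

# KM hazards r_k / n_k from illustrative counts
r <- c(2, 3, 1, 4)        # events in each interval
n <- c(20, 18, 15, 14)    # number at risk in each interval
km <- r / n

hazard_targets(t_int = 3, event = 1, km_hazard = km)  # 0 0 1 1
hazard_targets(t_int = 2, event = 0, km_hazard = km)  # 0, then r_k / n_k for k = 2..4
```

A censored subject therefore contributes soft targets equal to the population hazard from the censoring interval onward, rather than a hard 0 or 1.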
The weights for this network are obtained by minimizing the cross-entropy loss function. Figure 2 shows the architecture of a feed-forward neural network based on the Cox PH model, with an input layer consisting of \(p\) covariates and a bias term, and a single hidden layer with \(H\) nodes plus a bias term. Lastly, an output layer with \(K\) units learns to estimate each individual's hazard probability in each time interval. As before, information flows forward from the input layer to the hidden layer and then to the output layer, which is where the “feed-forward” name comes from. The hazard estimates from this network are then converted to survival estimates using the survival function:
\[ S(t_k)=\prod_{l=1}^k (1-h(t_l)) \]
where \(h(t_l)\) is the estimated hazard for interval \(l\), and the product runs over the disjoint intervals \(l = 1, \ldots, k\) up to and including interval \(k\).
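The two formulas in this step, the cross-entropy loss used for training and the conversion from hazards to survival, can be sketched as follows. Both helpers are hypothetical illustrations, not BLNN functions, and the predicted hazards are made-up values.

```r
# Hypothetical helper: cross-entropy between target hazards y (e.g., the h_ik
# vector defined above) and predicted hazards h_hat in (0, 1), summed over intervals
cross_entropy <- function(y, h_hat, eps = 1e-12) {
  h_hat <- pmin(pmax(h_hat, eps), 1 - eps)  # guard against log(0)
  -sum(y * log(h_hat) + (1 - y) * log(1 - h_hat))
}

# Hypothetical helper: S(t_k) = prod_{l=1}^{k} (1 - h(t_l)) for every k at once
survival_from_hazard <- function(h) cumprod(1 - h)

h_hat <- c(0.10, 0.20, 0.25, 0.50)   # illustrative predicted hazards
survival_from_hazard(h_hat)          # values: 0.90, 0.72, 0.54, 0.27
cross_entropy(c(0, 0, 1, 1), h_hat)  # loss for an event in interval 3
```

Using `cumprod()` returns the whole survival curve in one call, one value per interval.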
This post focuses on inference using a two-layer feed-forward artificial neural network. Specifically, we describe the Bayesian learning neural networks implemented by Sharaf et al. (2020) on the neural network-based Cox proportional hazards model described above. When neural networks are trained with conventional methods, the aim is to find a local minimum of the error function, an approach that makes model selection rather difficult. Additionally, as described by Hastie et al. (2009), neural networks are prone to overfitting: the model performs extremely well on the training data but fails to generalize to resampled or unseen data. Overfitting has been linked to the large number of weights in these models, which leads them to overfit at the global minimum of the error function \(R\) (Lawrence, Giles, and Tsoi 1997; Hastie et al. 2009). According to Burden and Winkler (2008):
“Bayesian regularized artificial neural networks (BRANNs) are more robust than standard backpropagation nets and can reduce or eliminate the need for lengthy cross-validation.”
In the Bayesian context, the idea is to use prior information about the distribution of a parameter of interest, \(\theta\), update this information using the sample data, and obtain a posterior distribution for \(\theta\). BLNN represents the Hamiltonian energy, \(H(w, p)= U(w)+K(p)\), as a joint probability distribution over the network's weights \(w\) and momentum \(p\). Assuming independence between \(w\) and \(p\), this joint probability is defined as:
\[ P(w, p) = \left(\frac{1}{Z} \exp\left(-\frac{U(w)}{Z}\right)\right)\left(\frac{1}{T} \exp\left(-\frac{K(p)}{T}\right)\right) \] where \(U(w)=-\log[p(w)L(w \mid D)]\) is the potential energy, i.e., the negative log of the (unnormalized) posterior distribution; \(p(w)\) is the prior on the weights; \(L(w \mid D)\) is the likelihood function given the data \(D\); \(K(p)=\sum_{i=1}^{d}p_i^2/(2m_i)\) is the kinetic energy, which corresponds (up to a constant) to the negative log-density of a normal distribution with mean zero and a variance-covariance matrix with diagonal elements \(M=(m_1, \cdots, m_d)\); and \(Z\) and \(T\) are the normalizing constants.
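To show how the pieces of \(H(w, p) = U(w) + K(p)\) fit together, here is a hypothetical sketch for a model with Gaussian errors and a Gaussian prior on the weights. The precisions `tau_err` and `tau_w` (inverse variances) are assumptions for illustration, the residuals stand in for the network errors \(y - f(x; w)\) and are passed in directly to keep the sketch self-contained, and additive constants are dropped.

```r
# Hypothetical sketch: energies "up to a constant" for Gaussian errors and a
# Gaussian prior on the weights (not the BLNN implementation).
potential_U <- function(residuals, w, tau_err, tau_w) {
  neg_log_lik   <- 0.5 * tau_err * sum(residuals^2)  # -log L(w | D)
  neg_log_prior <- 0.5 * tau_w * sum(w^2)            # -log p(w)
  neg_log_lik + neg_log_prior                        # U(w)
}

# K(p) = sum_i p_i^2 / (2 m_i), with diagonal masses m_i
kinetic_K <- function(p, m) sum(p^2 / (2 * m))

hamiltonian <- function(residuals, w, p, m, tau_err, tau_w) {
  potential_U(residuals, w, tau_err, tau_w) + kinetic_K(p, m)
}

w <- c(1, 2)
residuals <- c(0.5, -0.5)
p <- c(1, -1)
m <- c(0.5, 0.5)
hamiltonian(residuals, w, p, m, tau_err = 2, tau_w = 1)  # 0.5 + 2.5 + 2 = 5
```

The prior term \(\frac{\tau_w}{2}\sum w^2\) has the same form as a weight-decay penalty, which is one way to see why a larger prior precision shrinks the weights.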
The algorithm is summarized as below:
Details about the implementation of this method can be found here.
Sharaf et al. (2020) use the no-U-turn sampler (NUTS), an extension of Hamiltonian Monte Carlo (HMC) that reduces the dependence on the number-of-steps parameter used in HMC while retaining HMC's efficiency in generating independent samples. The ANN is trained using both HMC and NUTS with dual averaging. The negative log-likelihood is replaced by the network errors, and backpropagation is used to compute the gradients. Network errors and weights are assumed to be normally distributed with mean \(\mu\) but with a non-constant variance \(\sigma^2\). The variance of the prior is governed by the precision parameter, \(\tau = \frac{1}{\sigma^2}\); these precisions are the hyperparameters, which are either fixed at fine-tuned values or re-estimated based on historical data. The list of hyperparameters allowed in the BLNN implementation is discussed elsewhere (Sharaf et al. 2020). The following section provides a sample application of BLNNs to real-world data.
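For intuition about what HMC does, here is a minimal base-R sketch of a single HMC transition (leapfrog integration followed by a Metropolis accept/reject step) with unit temperature, applied to a toy two-dimensional standard normal "posterior." This is a generic textbook HMC step, not the BLNN or NUTS code; NUTS additionally chooses the number of leapfrog steps adaptively instead of fixing `L`, and the step size and step count here are assumptions tuned to the toy target.

```r
# Minimal HMC sketch (not the BLNN/NUTS implementation).
# U: potential energy, grad_U: its gradient, m: diagonal momentum masses.
hmc_step <- function(w, U, grad_U, m, eps = 0.1, L = 10) {
  p <- rnorm(length(w), mean = 0, sd = sqrt(m))  # momentum ~ N(0, diag(m))
  K <- function(p) sum(p^2 / (2 * m))            # kinetic energy
  H_current <- U(w) + K(p)
  w_new <- w
  p_new <- p - eps / 2 * grad_U(w_new)           # half step for momentum
  for (i in seq_len(L)) {
    w_new <- w_new + eps * p_new / m             # full step for position
    if (i < L) p_new <- p_new - eps * grad_U(w_new)
  }
  p_new <- p_new - eps / 2 * grad_U(w_new)       # final half step for momentum
  H_proposed <- U(w_new) + K(p_new)
  # accept with probability min(1, exp(H_current - H_proposed))
  if (log(runif(1)) < H_current - H_proposed) w_new else w
}

# Toy target: standard normal over two weights, i.e. a Gaussian prior with
# precision tau = 1, so U(w) = sum(w^2) / 2
U <- function(w) sum(w^2) / 2
grad_U <- function(w) w
set.seed(42)
m <- rep(1/2, 2)      # same momentum mass value as m1 in the BLNN code below
w <- c(3, -3)         # deliberately poor starting values
draws <- matrix(NA_real_, nrow = 2000, ncol = 2)
for (s in seq_len(2000)) {
  w <- hmc_step(w, U, grad_U, m)
  draws[s, ] <- w
}
kept <- draws[-(1:200), ]   # discard warmup draws
colMeans(kept)              # should be close to 0
apply(kept, 2, var)         # should be close to 1
```

In a BLNN model, `U` would be the network's potential energy (errors plus prior terms) and its gradient would come from backpropagation.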
In our analyses, we use data consisting of 256 samples profiled with the Illumina Human-6 expression BeadChip Version 2.0 to identify differentially expressed genes (DEGs), and we use Bayesian neural networks to examine how these genes affect survival in patients with primary bladder cancer. The data include 165 primary bladder cancer samples and nine normal samples, downloaded from the Gene Expression Omnibus (GEO) (Kim et al. 2010; Okutse and Nyongesa 2021). Before we begin this demonstration, we first load the Bioconductor packages used in the analysis. Details on installing these packages can be found elsewhere.⁴
After loading the required packages, we download the data directly into our working environment and prepare them for analysis: we apply a log transformation, create a design matrix, and fit linear models to identify differentially expressed genes using the empirical Bayes statistics for differential expression (eBayes) in the limma package (Ritchie et al. 2015).
# load the Bioconductor packages used below
library(GEOquery)  # getGEO()
library(Biobase)   # fvarLabels(), exprs()
library(limma)     # lmFit(), makeContrasts(), contrasts.fit(), eBayes(), topTable()
#load the data from the GEO
gset <- getGEO("GSE13507", GSEMatrix = TRUE, AnnotGPL = TRUE)
if (length(gset) > 1) idx <- grep("GPL6102", attr(gset, "names")) else idx <- 1
gset <- gset[[idx]]
# make proper column names to match toptable
fvarLabels(gset) <- make.names(fvarLabels(gset))
# group names for all samples
gsms <- paste0("0000000000XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX",
               "XXXXXXXXXXXXXXXXXX22222222222222222222222222222222",
               "22222222222222222222222222222222222222222222222222",
               "22222222222222222222222222222222222222222222222222",
               "222222222222222222222222222222222XXXXXXXXXXXXXXXXX",
               "XXXXXX")
sml <- c()
for (i in 1:nchar(gsms)) { sml[i] <- substr(gsms, i, i) }
# eliminate samples marked as "X"
sel <- which(sml != "X")
sml <- sml[sel]
gset <- gset[ , sel]
# log2 transform
exprs(gset) <- log2(exprs(gset))
# set up the data and proceed with analysis
sml <- paste("G", sml, sep = "")  # set group names
fl <- as.factor(sml)
gset$description <- fl
design <- model.matrix(~ description + 0, gset)
colnames(design) <- levels(fl)
fit <- lmFit(gset, design)
cont.matrix <- makeContrasts(G2-G0, levels = design)
fit2 <- contrasts.fit(fit, cont.matrix)
fit2 <- eBayes(fit2, 0.01)
tT <- topTable(fit2, adjust = "fdr", sort.by = "B", number = Inf)
# we can then save the list of the top 1000 DEGs
saveRDS(tT[1:1000, ], "deg.RDS")
For our exploratory data analysis, we start by presenting the top 10 up-regulated and top 10 down-regulated genes in Table 5.1. CDC20 is one of the most significantly up-regulated genes between normal and primary bladder cancer samples (log fold change = 0.472, average expression = 3.45, adjusted p < 0.001).
library(knitr)       # kable()
library(kableExtra)  # kable_styling() and the %>% pipe
#top 10 upregulated genes
upregulated <- tT[which(tT$logFC > 0), ][1:10, ]
upregulated <- subset(upregulated, select = c("Gene.ID", "Gene.symbol", "logFC", "AveExpr", "adj.P.Val", "B"))
##getting top 10 downregulated genes
downreg <- tT[which(tT$logFC < 0), ][1:10, ]
downreg <- subset(downreg, select = c("Gene.ID", "Gene.symbol", "logFC", "AveExpr", "adj.P.Val", "B"))
deg <- rbind(upregulated, downreg)
rownames(deg) <- NULL
deg %>% kable(format = "html",
              caption = "Top 10 up and down regulated genes in primary bladder cancer. The first 10 rows represent upregulated genes",
              col.names = c("Gene ID", "Gene Symbol", "logFC", "Average Expression", "Adjusted P-value", "B")) %>%
  kable_styling(full_width = FALSE, latex_options = c("HOLD_position", "striped", "scale_down"), position = "left")
| Gene ID | Gene Symbol | logFC | Average Expression | Adjusted P-value | B |
|---|---|---|---|---|---|
| 991 | CDC20 | 0.4715744 | 3.448492 | 0e+00 | 15.80790 |
| 5373 | PMM2 | 0.1710619 | 3.338349 | 0e+00 | 15.27138 |
| 128239 | IQGAP3 | 0.4125383 | 3.340340 | 0e+00 | 14.95455 |
|  |  | 0.2268436 | 3.233782 | 0e+00 | 13.79589 |
| 2932 | GSK3B | 0.1896154 | 3.377058 | 1e-07 | 12.82440 |
| 51537 | MTFP1 | 0.2168117 | 3.393753 | 2e-07 | 12.03233 |
| 51203 | NUSAP1 | 0.3719761 | 3.344562 | 3e-07 | 11.28112 |
| 7153 | TOP2A | 0.4546483 | 3.341922 | 4e-07 | 11.09297 |
|  |  | 0.0471829 | 2.866252 | 5e-07 | 10.82223 |
| 2810 | SFN | 0.3100817 | 3.493720 | 8e-07 | 10.18284 |
| 9890 | PLPPR4 | -0.3804426 | 2.891050 | 0e+00 | 50.66289 |
|  |  | -0.0622618 | 2.808016 | 0e+00 | 42.95524 |
| 54360 | CYTL1 | -0.2690733 | 2.886378 | 0e+00 | 41.75689 |
| 11126 | CD160 | -0.1475023 | 2.853721 | 0e+00 | 38.93153 |
| 121601 | ANO4 | -0.1705039 | 2.827650 | 0e+00 | 36.38865 |
| 5126 | PCSK2 | -0.1268787 | 2.823028 | 0e+00 | 34.23361 |
| 9745 | ZNF536 | -0.2280635 | 2.898368 | 0e+00 | 34.04707 |
| 55022 | PID1 | -0.3029122 | 2.899585 | 0e+00 | 33.04130 |
|  |  | -0.2550758 | 2.837756 | 0e+00 | 32.74756 |
| 10699 | CORIN | -0.1867473 | 2.868771 | 0e+00 | 32.25715 |
In Figure 5.1, we visualize the differentially expressed genes in this sample using a mean-difference plot that highlights genes that were significantly up- or down-regulated.
Additionally, we perform a basic functional enrichment analysis using Gene Ontology (GO) terms and a subset of the DEGs (n = 300 genes) to identify the pathways in which these statistically significant differentially expressed genes are enriched. We use a basic barplot to visualize the most common GO terms. Figure 5.2 shows that these genes are involved in a number of biological functions, including synapse organization, junction assembly, and neural migration.
#barplot of the enrichment results (edo) showing only 10 enrichment categories
barplot(edo, showCategory=10, cex.names=3)
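Over-representation analysis, which GO enrichment tools such as clusterProfiler's `enrichGO()` perform, rests on a hypergeometric test: given a background of genes, how surprising is it that this many of our DEGs fall into one GO term? The base-R sketch below runs that calculation for a single term; all counts except the 300 DEGs are hypothetical.

```r
# Hypothetical counts for a single GO term (illustrative, not from this analysis)
N <- 20000  # genes in the background universe
M <- 150    # background genes annotated to the GO term
n <- 300    # DEGs submitted to the enrichment analysis
k <- 10     # DEGs annotated to the GO term

expected <- n * M / N   # genes expected in the term by chance (2.25)
# P(X >= k) when drawing n genes without replacement from the universe
p_value <- phyper(k - 1, M, N - M, n, lower.tail = FALSE)
fold_enrichment <- k / expected

c(expected = expected, fold_enrichment = fold_enrichment, p_value = p_value)
```

Because many GO terms are tested at once, the resulting p-values would normally be adjusted, for example with `p.adjust(p, method = "BH")`.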
Because training neural network models with Bayesian learning and simulation methods such as the NUTS algorithm is computationally intensive, we restrict our analyses to a small subset of the genetic information available in the dataset. We load a sample of this genetic data (n = 150) for analysis using BLNN.
#load the bayesian learning package
library(BLNN)
#load the saved training data files
trainy <- readRDS("trainy.RDS")
trainx <- readRDS("trainxx.RDS")
First, we build an object of class BLNN using BLNN_Build(), declaring the number of input features in the input layer (ncol(trainx)), the number of nodes in the output layer (nout), the number of nodes in the hidden layer, the activation function, the cost function, the output function, and the weights for the input features, bias terms, and hidden nodes. Details on all parameters accepted as inputs by the function can be found here. The hyper-parameters used here are selected arbitrarily and are re-estimated through an evidence procedure during model training.
The created BLNN object is then trained either with or without evidence. We first fit a neural network with the nnet package as a baseline and compare its performance with a BLNN model trained with evidence and another BLNN model with the same architecture trained without evidence. We use two chains in the learning process to explore the parameter space from different starting values. For each chain, the initial weights passed to BLNN_Train() are drawn from a normal distribution with mean 0 and standard deviation 1/sqrt(n.par), where n.par is the number of network weights.
#set the hyperparameters; change these to evaluate
#number of network weights
n.par <- length(BLNN_GetWts(survObj))
#number of desired chains
chains <- 2
#initial weight values for each chain
initials <- lapply(1:chains, function(i) rnorm(n.par, 0, 1/sqrt(n.par)))
#using nnet as baseline
library(nnet)
nnetBasesline <- nnet(trainx, trainy, size = 75, maxit = 1000, MaxNWts = 15000)
nnetPredictions <- predict(nnetBasesline)
NET.abserror <- sum(abs(trainy - nnetPredictions)) #0.0736
NET.error <- nnetBasesline$value #7.97425e-05
First, we train our feed-forward neural network without evidence by passing BLNN_Train() the BLNN object, the data frame of covariates, the response vector \(y\), and the number of iterations, and by using parallelization to leverage multi-core computing. Additional arguments to this function can be found elsewhere.⁵
#mass (variance) of the momentum variables
m1 <- rep(1/2, n.par)
#training the model
trainx <- scale(trainx)
survNUTS <- BLNN_Train(survObj, x = trainx, y = trainy,
                       iter = 2000, thin = 10, warmup = 400,
                       init = initials, chains = chains,
                       parallel = TRUE, cores = 2,
                       algorithm = "NUTS", evidence = FALSE,
                       control = list(adapt_delta = 0.99, momentum_mass = m1,
                                      stepsize = 1, gamma = 0.05, to = 100, useDA = TRUE,
                                      max_treedepth = 20))
We check that the values of Rhat are close to one and that the effective sample sizes are reasonably large (a minimum of 50 for each parameter) before updating each of our networks with the newly sampled parameters.
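Rhat compares the variability between chains with the variability within chains; values near one suggest the chains are sampling the same distribution. The base-R helper below is a hypothetical sketch of the classic (non-split) Gelman-Rubin statistic for a single parameter; it is not necessarily how BLNN computes its diagnostics.

```r
# Hypothetical helper: classic Gelman-Rubin Rhat for one parameter.
# chains: matrix with one column per chain and one row per (post-warmup) draw
rhat <- function(chains) {
  n <- nrow(chains)
  chain_means <- colMeans(chains)
  B <- n * var(chain_means)             # between-chain variance
  W <- mean(apply(chains, 2, var))      # average within-chain variance
  var_hat <- (n - 1) / n * W + B / n    # pooled estimate of posterior variance
  sqrt(var_hat / W)
}

set.seed(7)
mixed <- cbind(rnorm(1000), rnorm(1000))            # two chains, same target
stuck <- cbind(rnorm(1000), rnorm(1000, mean = 3))  # chains in different regions
rhat(mixed)  # close to 1
rhat(stuck)  # well above 1
```

Because of sampling noise, Rhat for well-mixed chains can land slightly below one, so the practical check is that it is close to one rather than strictly below it.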
The next neural network model is trained on the same data, but with the evidence option enabled so that the model hyper-parameters are re-estimated using historical data or, if historical data are not available, the matrix of covariates \(x\).
#bayesian learning with evidence used in re-estimating hyper-parameters
survNUTS.ev <- BLNN_Train(survObj, x = trainx, y = trainy,
                          iter = 5000, thin = 10, warmup = 400,
                          init = initials, chains = chains,
                          parallel = TRUE, cores = 2,
                          algorithm = "NUTS", evidence = TRUE,
                          control = list(adapt_delta = 0.99, momentum_mass = m1,
                                         stepsize = 5, gamma = 0.05, to = 100, useDA = TRUE,
                                         max_treedepth = 20))
After training, we update the neural network models with the values sampled during the learning process using the BLNN_Update() function. If evidence was used in training, the hyper-parameters are updated as well.
#updating the parameters after the learning process
##update the no evidence neural network
survNUTS <- BLNN_Update(survObj, survNUTS)
##update the evidence neural network
survNUTS.ev <- BLNN_Update(survObj, survNUTS.ev)
We then use the updated models to make predictions on the training data with the BLNN_Predict() function.
#making predictions using these models
##predictions using no evidence
survpred <- BLNN_Predict(survNUTS, trainx, trainy)
##predictions using evidence
survpred.ev <- BLNN_Predict(survNUTS.ev, trainx, trainy)
To evaluate our models, we create a table of evaluation metrics for all the models trained in this post, for illustrative purposes. Table 6.1 reports the errors associated with each model.
#Model evaluations
##extract the total error of each model
errors <- c(survpred$Errors$Total, survpred.ev$Errors$Total, nnetBasesline$value)
#print out the model evaluations
OutTab <- data.frame(errors)
rownames(OutTab) <- c("NUTS (no evidence)", "NUTS (with evidence)", "NNET")
saveRDS(OutTab, "outTab.RDS")
write.csv(OutTab, "OutTab.csv")
OutTab %>% kable(format = "html", caption = "Artificial Neural Network Model Comparisons") %>%
  kable_styling(full_width = FALSE, latex_options = c("HOLD_position", "striped", "scale_down"), position = "left")
| Model | Error |
|---|---|
| NUTS (no evidence) | 1.00E-04 |
| NUTS (with evidence) | 7.97E-07 |
| NNET | 7.97E-05 |
In this extension post, we explored Bayesian learning in the context of artificial neural networks for modeling survival data using microarray gene expression data. In particular, we presented the main ideas behind why artificial neural networks trained with Bayesian learning are an alternative for modeling survival outcomes, given the complexity that censored data introduce. The post gave a simplified introduction to neural networks, focusing on the feed-forward neural network that we later applied to the application problem. We then turned to neural network approaches for survival data, discussing Cox PH-based ANNs and how this model has been improved to learn hyper-parameters from a posterior distribution derived from prior information (evidence) and the data.
For brevity and for demonstration purposes, we present the errors associated with each model trained in this post and show that using evidence reduces the total error on the training data (see Table 6.1). However, a major issue we encountered while training these models was the long computation time, even with parallel processing. While the NUTS and HMC algorithms employed by the BLNN package are promising, these sampling methods take an exorbitant amount of time to run, especially when the hyperparameters are re-estimated using historical data (evidence). Sharaf et al. (2020) note that “Bayesian learning for ANN requires longer learning time when compared with conventional algorithms such as BFGS or gradient descent,” which held true in our implementation.
In an extension to this post, we show an implementation of Bayesian learning for neural networks in Python using a simulated dataset. The method includes estimating the uncertainty associated with survival estimates, which is of great value in predictive modeling. The post can be found here. After modeling and identifying the genes that affect survival in patients with primary bladder cancer, the genes can be annotated and categorized by their enrichment level as high, medium, or low, and survival curves can be plotted to understand how these different expression profiles affect survival.
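Survival curves by expression group are usually Kaplan-Meier estimates; in practice one would call `survival::survfit(Surv(time, status) ~ group)`. Because the estimator is short, here is a hypothetical base-R sketch of the same calculation, with made-up survival times and a made-up high/low expression grouping for a single gene.

```r
# Hypothetical helper: Kaplan-Meier estimate S(t) = prod (1 - d_j / n_j)
# time: follow-up times; status: 1 = event observed, 0 = censored
km_estimate <- function(time, status) {
  event_times <- sort(unique(time[status == 1]))
  n_risk  <- vapply(event_times, function(t) sum(time >= t), integer(1))
  n_event <- vapply(event_times, function(t) sum(time == t & status == 1), integer(1))
  data.frame(time = event_times, n_risk = n_risk, n_event = n_event,
             surv = cumprod(1 - n_event / n_risk))
}

# Illustrative data: months of follow-up for patients grouped by the
# (hypothetical) expression level of one gene
time   <- c(5, 8, 12, 12, 20, 3, 4, 6, 9, 15)
status <- c(1, 0, 1, 1, 0, 1, 1, 1, 0, 1)
group  <- rep(c("low", "high"), each = 5)

by_group <- lapply(split(data.frame(time, status), group),
                   function(d) km_estimate(d$time, d$status))
by_group
```

Plotting `surv` against `time` as a step function for each group gives the survival curves described above; the per-interval ratios `n_event / n_risk` are the KM hazards \(r_k / n_k\) used earlier to build the training targets.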
All data used in this extension post are available online. The cleaned data analysis files used in this post, including the R script generated from the .Rmd file using knitr::purl(), can be found in our GitHub repository.⁶
¹ Censoring refers to a situation in survival analysis where the actual time to event is unknown, for reasons such as loss to follow-up, withdrawal, or an unknown exact event time. In right censoring, the event of interest occurs after the end of the study, whereas in left censoring, the event occurs before the onset of the study. In interval censoring, the actual survival time is known only to lie within some interval.
² Common activation functions include the sigmoid, often used for outputs in classification problems; the rectified linear unit (ReLU), widely used in hidden layers; tanh; and leaky ReLU.
https://github.com/okutse/bayesian-networks, https://github.com/nlee100/bayesian-neuralnetworks