The clustra package was built to cluster medical trajectories (time series) related by an intervention time. This vignette is intended to reproduce the results of an associated paper “clustra: A multi-platform k-means clustering algorithm for analysis of longitudinal trajectories in large electronic health records data.”
To run the paper version, set chunk = figs and full_data = TRUE in the first code chunk below. It takes about 70 minutes with mc = 5 on a Quad-Core I7, and requires Internet access to get the full data from GitHub. The default for CRAN submission is limited to a sample of about 1/8 of the full data set to satisfy CRAN limits on package data size and runs a smaller subset of the code chunks (takes about 40 seconds).
Additional plots to explain the methodology and examine the data can be produced by setting chunk = all in the chunk below. This takes over 2 hours with the same system configuration.
Below is the code that controls which code chunks are evaluated and how many cores to use.
knitr::opts_knit$set(
collapse = TRUE,
comment = "#>"
)
cran = list(LibsData = TRUE, plot_true = TRUE, plot_raw = TRUE, run5 = TRUE, run5_v2 = FALSE, iter_plots = TRUE, plot5_side_by_side = FALSE, fig1 = FALSE, comp_true = FALSE, run10 = FALSE, plot10 = FALSE, run2 = FALSE, rand2 = FALSE, silplot = FALSE, silplot2 = FALSE, rand_plot = FALSE, hclust40 = FALSE, SAS_R_data = FALSE, Rand_SAS_R = FALSE, true_group = FALSE)
figs = list(LibsData = TRUE, plot_true = FALSE, plot_raw = FALSE, run5 = TRUE, run5_v2 = TRUE, iter_plots = FALSE, plot5_side_by_side = FALSE, fig1 = TRUE, comp_true = TRUE, run10 = TRUE, plot10 = TRUE, run2 = TRUE, rand2 = TRUE, silplot = TRUE, silplot2 = FALSE, rand_plot = TRUE, hclust40 = FALSE, SAS_R_data = TRUE, Rand_SAS_R = TRUE, true_group = TRUE)
all = list(LibsData = TRUE, plot_true = TRUE, plot_raw = TRUE, run5 = TRUE, run5_v2 = TRUE, iter_plots = TRUE, plot5_side_by_side = TRUE, fig1 = TRUE, comp_true = TRUE, run10 = TRUE, plot10 = TRUE, run2 = TRUE, rand2 = TRUE, silplot = TRUE, silplot2 = TRUE, rand_plot = TRUE, hclust40 = TRUE, SAS_R_data = TRUE, Rand_SAS_R = TRUE, true_group = TRUE)
chunk = cran # choose which version of the vignette to run
full_data = FALSE # choose if full 80,000 id data or sample of 10,000 ids
mc = 2
# If running on a Unix or a Mac platform, set to number of available cores.
# This vignette will experience notable speedup up to about 10 cores.
# On Intel chips with hyperthreading, up to 2x available cores can be used.
if (.Platform$OS.type == "windows") mc = 1
data.table::setDTthreads(1) # manage data.table threads (increase with full_data)
We start by loading the clustra package and getting the data. We also set the output path for plots to tempdir(), which should be changed to a user directory if the plots are needed beyond the Knit output.
library(clustra)
start_time = deltime()
print(paste("Number of cores being used: mc parameter =", mc))
## [1] "Number of cores being used: mc parameter = 2"
repo = "https://github.com/MVP-CHAMPION/"
repo_sas = paste0(repo, "clustra-SAS/raw/main/")
if(full_data) {
url = paste0(repo_sas, "bp_data/simulated_data_27June2023.csv.gz")
data = data.table::fread(url)
} else {
data = bp10k
}
data$true_group = data$group
head(data)
## id group time response true_group
## <int> <int> <int> <num> <int>
## 1: 18 3 -21 176.9072 3
## 2: 18 3 22 154.3234 3
## 3: 18 3 324 107.2318 3
## 4: 18 3 353 130.8007 3
## 5: 18 3 373 137.3623 3
## 6: 18 3 391 130.8484 3
plot_path = paste0(tempdir(), "/") # output dir for plots
cat("Resulting plots are saved in", plot_path, "directory\n")
## Resulting plots are saved in /tmp/RtmpLpw3Gx/ directory
Select a few random ids and print their scatterplots colored by true_group. This only displays 9 individuals. The data was generated based on five clusters. As the data used to simulate this synthetic dataset is extremely noisy, it would be difficult to ascertain any patterns by looking at a few subjects.
Now plot a sample of 20,000 points from the raw data, colored by the true (simulated) group assignment. As we can see, there is a lot of noise, and it would be difficult to identify the clusters of trajectories by looking at the individual points alone. A greater density of points is seen at time 0, where all trajectories have a data point.
Next, cluster the trajectories. Set k = 5, the maximum spline degrees of freedom to 30, and the maximum number of iterations to 20 (the first element of conv). Initial cluster assignments are the default “random”. We print out full information for the iterations, including the elapsed time.
set.seed(12345)
cl5 = clustra(data, k = 5, maxdf = 30, conv = c(20, 0.5), mccores = mc, verbose = TRUE)
## Start counts 2001 2004 2016 2010 1969 Starts time: 0
## 1 (M-123)-1.05 (E-12345)-0.8 Changes: 7797 Counts: 1257 820 3529 1099 3295 Deviance: 46903674
## 2 (M-123)-0.46 (E-12345)-0.56 Changes: 2413 Counts: 1373 1774 2712 1450 2691 Deviance: 34739540
## 3 (M-123)-0.47 (E-12345)-0.51 Changes: 1405 Counts: 1384 2495 2156 1549 2416 Deviance: 33236379
## 4 (M-123)-0.32 (E-12345)-0.75 Changes: 938 Counts: 1407 2925 1814 1621 2233 Deviance: 32517385
## 5 (M-123)-0.47 (E-12345)-0.52 Changes: 645 Counts: 1466 3080 1633 1762 2059 Deviance: 32187814
## 6 (M-123)-0.33 (E-12345)-0.57 Changes: 741 Counts: 1536 3031 1549 2011 1873 Deviance: 32013822
## 7 (M-123)-0.32 (E-12345)-0.84 Changes: 929 Counts: 1624 2932 1486 2299 1659 Deviance: 31769678
## 8 (M-123)-0.46 (E-12345)-0.56 Changes: 898 Counts: 1728 2851 1439 2497 1485 Deviance: 31317772
## 9 (M-123)-0.33 (E-12345)-0.81 Changes: 517 Counts: 1857 2846 1401 2490 1406 Deviance: 30949472
## 10 (M-123)-0.47 (E-12345)-0.55 Changes: 220 Counts: 1923 2859 1386 2459 1373 Deviance: 30831685
## 11 (M-123)-0.32 (E-12345)-0.59 Changes: 82 Counts: 1952 2869 1382 2434 1363 Deviance: 30811168
## 12 (M-123)-0.33 (E-12345)-0.83 Changes: 37 Counts: 1971 2876 1380 2415 1358 Deviance: 30807849
## AIC:424.11 BIC:1627.1 edf:119.97 Total time: -13.36 converged
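The convergence status in the log above can be checked by hand. The sketch below assumes that conv = c(maxit, minpct) stops the iterations once fewer than minpct percent of ids change cluster, or after maxit iterations; the exact comparison used inside clustra() is an assumption here, and stop_rule() is a hypothetical helper, not part of the package. The Changes values are copied from the log, and the 10,000 ids are the sum of the start counts.

```r
# Hypothetical helper: apply the assumed stopping rule to a sequence of
# per-iteration change counts.
stop_rule = function(changes, n_id, conv) {
  pct = 100 * changes / n_id          # percent of ids changing cluster
  hit = which(pct < conv[2])          # iterations below the threshold
  first = if (length(hit) > 0) hit[1] else Inf
  if (first <= conv[1]) {
    list(iter = first, status = "converged")
  } else {
    list(iter = conv[1], status = "max-iter")
  }
}

# "Changes:" column copied from the verbose log of the k = 5 run
changes = c(7797, 2413, 1405, 938, 645, 741, 929, 898, 517, 220, 82, 37)
n_id = 10000

stop_rule(changes, n_id, conv = c(20, 0.5))       # iteration 12, "converged"
stop_rule(changes[1:10], n_id, conv = c(10, 0.5)) # iteration 10, "max-iter"
```

Iteration 11 moves 82 ids (0.82%), which is still above 0.5%, while iteration 12 moves only 37 ids (0.37%), which agrees with the "converged" status printed after 12 iterations.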
# What happens if we cut the iterations at 5 maximum?
set.seed(12345)
cl5.v2 = clustra(data, k = 5, maxdf = 30, conv = c(5, 0.5), mccores = mc)
Setting the option verbose = 2 produces plots of the cluster spline centers at every iteration. We also pass the ylim parameter because the y-axis extent of the data is much wider than the models.
set.seed(12345)
clustra(data, k = 5, maxdf = 30, conv = c(10, 0.5), mccores = mc, verbose = 2, ylim = c(110, 170))
## Start counts 2001 2004 2016 2010 1969 Starts time: -0.01
## 1 (M-123)-0.31 (E-12345)-0.74 Changes: 7797 Counts: 1257 820 3529 1099 3295 Deviance: 46903674
##
## 2 (M-123)-0.39 (E-12345)-1 Changes: 2413 Counts: 1373 1774 2712 1450 2691 Deviance: 34739540
##
## 3 (M-123)-0.55 (E-12345)-0.55 Changes: 1405 Counts: 1384 2495 2156 1549 2416 Deviance: 33236379
##
## 4 (M-123)-0.45 (E-12345)-0.75 Changes: 938 Counts: 1407 2925 1814 1621 2233 Deviance: 32517385
##
## 5 (M-123)-0.59 (E-12345)-0.73 Changes: 645 Counts: 1466 3080 1633 1762 2059 Deviance: 32187814
##
## 6 (M-123)-0.45 (E-12345)-1.12 Changes: 741 Counts: 1536 3031 1549 2011 1873 Deviance: 32013822
##
## 7 (M-123)-0.44 (E-12345)-0.53 Changes: 929 Counts: 1624 2932 1486 2299 1659 Deviance: 31769678
##
## 8 (M-123)-0.46 (E-12345)-1 Changes: 898 Counts: 1728 2851 1439 2497 1485 Deviance: 31317772
##
## 9 (M-123)-0.44 (E-12345)-0.52 Changes: 517 Counts: 1857 2846 1401 2490 1406 Deviance: 30949472
##
## 10 (M-123)-0.45 (E-12345)-0.84 Changes: 220 Counts: 1923 2859 1386 2459 1373 Deviance: 30831685
##
## AIC:425.11 BIC:1632.81 edf:120.44 Total time: -12.71 max-iter
Test producing two side-by-side plots.
plot_smooths(data, fits = cl5$tps, max.data = 30000)
m = matrix(c(0.1, .95, 0.1, .95), nrow = 1)
sc = split.screen(m)
screen(1)
sc = split.screen(c(1, 2))
screen(2)
plot_smooths(data, fits = cl5$tps, max.data = 0, ylim = c(110, 180))
screen(3)
plot_smooths(data, fits = cl5.v2$tps, max.data = 0, ylim = c(110, 180))
close.screen(all.screens = TRUE)
This is Figure 1 in the paper.
png(paste0(plot_path, "R_cl5_plot.png"), width = 480, height = 360)
plot_smooths(data, fits = cl5$tps, max.data = 0, ylim = c(110, 180))
dev.off()
knitr::include_graphics(paste0(plot_path, "R_cl5_plot.png"))
Compute the adjusted Rand index (the AR component of the MixSim::RandIndex() result) to compare each clustering with true_group and the two clusterings with each other.
MixSim::RandIndex(cl5$data_group, data[, true_group])$AR
MixSim::RandIndex(cl5.v2$data_group, data[, true_group])$AR
MixSim::RandIndex(cl5$data_group, cl5.v2$data_group)$AR
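For readers unfamiliar with the adjusted Rand index, the following is a minimal base R sketch of the standard Hubert and Arabie formula computed from a contingency table. The function name adj_rand() is hypothetical and is not how MixSim implements it; the sketch only illustrates that the index ignores label names and is corrected for chance agreement.

```r
# Hypothetical helper: adjusted Rand index of two label vectors.
adj_rand = function(x, y) {
  tab = table(x, y)
  # number of unordered pairs that can be formed within each count
  pairs = function(v) { v = as.numeric(v); sum(v * (v - 1) / 2) }
  same_both = pairs(as.vector(tab))   # pairs together in both labelings
  same_x = pairs(rowSums(tab))        # pairs together in x
  same_y = pairs(colSums(tab))        # pairs together in y
  expected = same_x * same_y / pairs(length(x))  # chance agreement
  (same_both - expected) / ((same_x + same_y) / 2 - expected)
}

adj_rand(c(1, 1, 2, 2), c(2, 2, 1, 1))  # same partition, labels swapped: 1
adj_rand(c(1, 1, 2, 2), c(1, 2, 1, 2))  # worse than chance: -0.5
```

A value of 1 means the two partitions are identical up to relabeling, values near 0 mean agreement no better than chance, and negative values are possible.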
Next, we try requesting too many clusters in the noisy data.
set.seed(12345)
t = deltime()
cl10 = clustra(data, k = 10, maxdf = 30, conv = c(50, 0.5), mccores = mc)
deltimeT(t, paste("run time on", mc, "cores "))
Next, plot the resulting fit. This is Figure 2 in the paper.
png(paste0(plot_path, "R_cl10_plot.png"), width = 480, height = 360)
plot_smooths(data, cl10$tps, max.data = 0, ylim = c(110, 180))
dev.off()
knitr::include_graphics(paste0(plot_path, "R_cl10_plot.png"))
MixSim::RandIndex(cl10$data_group, data[, true_group])$AR
Next, we look at what the clusters would look like with 2 groups and plot the result. This is Figure 3 in the paper.
set.seed(12345)
t = deltime()
cl2 = clustra(data, k = 2, maxdf = 30, conv = c(20,0.5), mccores=mc)
deltimeT(t, paste(" run time on", mc, "cores "))
cat(paste("Number of iterations completed:",cl2$iterations))
cat(paste("Number of individuals changing groups in last iteration:",cl2$changes))
png(paste0(plot_path, "R_cl2_plot.png"), width = 480, height = 360)
plot_smooths(data, cl2$tps, max.data = 0, ylim = c(110, 180))
dev.off()
knitr::include_graphics(paste0(plot_path, "R_cl2_plot.png"))
MixSim::RandIndex(cl2$data_group, data[, true_group])$AR
Compute the Adjusted Rand Index between the various results.
MixSim::RandIndex(cl5$data_group,cl2$data_group)$AR
MixSim::RandIndex(cl5$data_group,cl10$data_group)$AR
MixSim::RandIndex(cl10$data_group,cl2$data_group)$AR
Average silhouette value is a way to select the number of clusters, and a silhouette plot provides a way for a deeper evaluation (Rousseeuw 1986). A true silhouette requires distances between individual trajectories, which is not possible here due to unequal trajectory sampling without fitting a separate model for each id. As a proxy for distance between points, we use trajectory distances to cluster mean spline trajectories in the clustra_sil() function. The structure returned from the clustra() function contains the matrix loss, which has all the information needed to construct these proxy silhouette plots. The function clustra_sil() performs clustering for a number of k values and outputs information for the silhouette plot that is displayed next.
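To make the proxy concrete, here is a small base R sketch under stated assumptions: loss is taken to be an n-by-k matrix holding each id's distance to each cluster mean trajectory, each id belongs to the cluster with the smallest loss, and the silhouette is (b - a) / max(a, b), with a the loss to the own cluster and b the loss to the nearest other cluster. The helper proxy_sil() is hypothetical, and the internals of clustra_sil() may differ.

```r
# Hypothetical helper: proxy silhouette from an id-by-cluster loss matrix.
proxy_sil = function(loss) {
  n = nrow(loss)
  own = max.col(-loss, ties.method = "first")  # cluster with smallest loss
  a = loss[cbind(seq_len(n), own)]             # loss to own cluster
  other = loss
  other[cbind(seq_len(n), own)] = Inf          # mask the own cluster
  b = apply(other, 1, min)                     # loss to nearest other cluster
  data.frame(cluster = own, silhouette = (b - a) / pmax(a, b))
}

# Toy loss matrix: 3 ids (rows) by 2 clusters (columns)
loss = rbind(c(1, 4),
             c(2, 2.5),
             c(6, 3))
proxy_sil(loss)  # silhouettes 0.75, 0.2, 0.5
```

The second id sits almost equally close to both cluster means, so its silhouette is small; ids like this are what a silhouette plot makes visible.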
These plots are Figures 4, 5, and 6 in the paper.
t = deltime()
sil = clustra_sil(cl5, mccores = mc, conv=c(20,0.5))
png(paste0(plot_path, "R_cl5_silplot.png"), width = 480, height = 360)
lapply(sil, plot_silhouette)
dev.off()
set.seed(12345)
sil2 = clustra_sil(cl2, mccores = mc, conv=c(20,0.5))
png(paste0(plot_path, "R_cl2_silplot.png"), width = 480, height = 360)
lapply(sil2, plot_silhouette)
dev.off()
set.seed(12345)
sil10 = clustra_sil(cl10, mccores = mc, conv=c(20,0.5))
png(paste0(plot_path, "R_cl10_silplot.png"), width = 480, height = 360)
lapply(sil10, plot_silhouette)
dev.off()
knitr::include_graphics(paste0(plot_path, "R_cl5_silplot.png"))
knitr::include_graphics(paste0(plot_path, "R_cl2_silplot.png"))
knitr::include_graphics(paste0(plot_path, "R_cl10_silplot.png"))
deltimeT(t, paste(" silhouettes on", mc, "cores "))
Here we produce silhouettes directly from the data, which runs clustra internally.
# Example running from the data step
t = deltime()
set.seed(12345)
sil = clustra_sil(data, k = c(2, 5, 10), mccores = mc, conv = c(50, 0.5))
lapply(sil, plot_silhouette)
deltimeT(t, paste("Silhouettes and clustra on", mc, "cores "))
Another way to select the number of clusters is the Rand Index comparing different random starts and different numbers of clusters. When we replicate clustering with different random seeds, the “replicability” is an indicator of how stable the results are for a given k, the number of clusters. For this demonstration, we look at k = c(2, 3, 4, 5, 6, 7, 8, 9, 10), and 10 replicates for each k. To run this chunk with Knit, set rand = TRUE at the top of this vignette (line 8).
t = deltime()
set.seed(12345)
ran = clustra_rand(data, k = seq(2, 10, 1), starts = "random", mccores = mc,
replicates = 10, conv=c(40,0.5), verbose = TRUE)
deltimeT(t, paste(" clustra_rand on", mc, "cores "))
rand_plot(ran, name = paste0(plot_path, "R_randplot.pdf")) # save pdf version
rand_plot(ran) # render png version
The plot shows the Adjusted Rand Index similarity level between all pairs of 90 clusterings (10 random starts for each of 2 to 10 clusters). The ten random starts agree the most for k = 2 and k = 5. From the deviance results shown during iterations, we also see that all of the k = 3 clusters are near the best deviance attainable even with k = 4. Among the k = 4 results, several converged to only three clusters that agree with k = 3 results.
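The replicability idea can be illustrated on toy data without clustra. The sketch below is only an analogy: it uses stats::kmeans() on simulated two-dimensional points with three true groups, runs three random starts for each k, and fills a matrix of pairwise adjusted Rand indices, which is the kind of matrix that rand_plot() displays. The data, the helper adj_rand(), and all object names are assumptions made for illustration.

```r
# Hypothetical helper: adjusted Rand index of two label vectors.
adj_rand = function(x, y) {
  tab = table(x, y)
  pairs = function(v) { v = as.numeric(v); sum(v * (v - 1) / 2) }
  same_both = pairs(as.vector(tab))
  same_x = pairs(rowSums(tab))
  same_y = pairs(colSums(tab))
  expected = same_x * same_y / pairs(length(x))
  (same_both - expected) / ((same_x + same_y) / 2 - expected)
}

set.seed(1)
# Toy data: three well-separated groups of 50 points in two dimensions
x = rbind(matrix(rnorm(100, mean = 0), ncol = 2),
          matrix(rnorm(100, mean = 4), ncol = 2),
          matrix(rnorm(100, mean = 8), ncol = 2))

# Three random starts for each k in 2:4
runs = list()
for (k in 2:4) {
  for (r in 1:3) {
    runs[[paste0("k", k, "_r", r)]] = kmeans(x, centers = k)$cluster
  }
}

# Pairwise adjusted Rand index between all 9 clusterings
n = length(runs)
sim = matrix(NA_real_, n, n, dimnames = list(names(runs), names(runs)))
for (i in seq_len(n)) {
  for (j in seq_len(n)) sim[i, j] = adj_rand(runs[[i]], runs[[j]])
}
round(sim, 2)
```

Blocks of high values among replicates with the same k indicate that different random starts arrive at the same partition, which is the pattern the text reads off the plot for k = 2 and k = 5.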
Another possible evaluation of the number of clusters is to first ask clustra for a large number of clusters, evaluate the cluster centers on a common set of time points, and feed the resulting matrix to a hierarchical clustering function. Below, we ask for 40 clusters and the hclust() function clusters the 40 resulting cluster means, each evaluated on 100 time points.
gpred = function(tps, newdata) {
as.numeric(mgcv::predict.bam(tps, newdata, type = "response",
newdata.guaranteed = TRUE))
}
set.seed(12345)
t = deltime()
cl40 = clustra(data, k = 40, maxdf = 30, conv = c(100,0.5), mccores = mc, verbose = TRUE)
timep = data.frame(time = seq(min(data$time), max(data$time), length.out = 100))
resp = do.call(rbind, lapply(cl40$tps, gpred, newdata = timep))
png(paste0(plot_path, "R_hclust_plot.png"), width = 720, height = 480)
plot(hclust(dist(resp)))
dev.off()
knitr::include_graphics(paste0(plot_path, "R_hclust_plot.png"))
deltimeT(t, paste("clustra k = 40, conv = c(100, 0.5) on", mc, "cores + hclust + dist "))
The cluster dendrogram indicates there are five clusters. Making the cut at a height of roughly 140 groups the 40 clusters into only five.
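Cutting a dendrogram can also be done programmatically with cutree(). The sketch below does not use the vignette's cl40 object; it simulates 40 hypothetical mean curves on 100 time points, built from five underlying curves plus small noise, and then cuts the tree either by the number of groups or by height. The cut height of 100 is specific to this toy example, just as 140 is specific to the vignette's data.

```r
set.seed(1)
time = seq(0, 1, length.out = 100)
base_level = c(110, 125, 140, 155, 170)  # five hypothetical underlying levels
truth = rep(1:5, each = 8)               # 40 curves, 8 per underlying curve

# Each row is one curve evaluated on the common time grid (40 x 100 matrix)
curves = t(sapply(truth, function(g) {
  base_level[g] + 5 * time + rnorm(length(time), sd = 1)
}))

hc = hclust(dist(curves))          # complete linkage on Euclidean distance
groups_k = cutree(hc, k = 5)       # cut by number of groups
groups_h = cutree(hc, h = 100)     # cut by height
table(truth, groups_k)             # each underlying curve maps to one group
```

Cutting by height is closer to reading the dendrogram by eye, while cutting by k is convenient once the number of groups has been decided.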
Look at the data output from the SAS clustering analysis. Since SAS and R produce different cluster labels, we need to relabel and replot the SAS data output, which is on the https://github.com/MVP-CHAMPION/clustra-SAS repository. The mapping algorithm in sas_r_cplot() below matches SAS and R pairs by distance, starting with the shortest distance pair, and then plots the pairs with matching colors.
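The matching step can be looked at in isolation. The following sketch applies the same greedy rule to a small, made-up distance matrix: take the closest remaining pair, record it, and remove its row and column from further consideration. The helper name greedy_match() and the matrix values are hypothetical.

```r
# Hypothetical helper: greedy one-to-one matching of rows to columns,
# starting with the closest pair.
greedy_match = function(m) {
  k = nrow(m)
  map = integer(k)
  for (i in seq_len(k)) {
    ind = arrayInd(which.min(m), dim(m))  # closest remaining (row, column)
    map[ind[1]] = ind[2]
    m[ind[1], ] = Inf                     # row is now matched
    m[, ind[2]] = Inf                     # column is now matched
  }
  map                                     # column matched to each row
}

# Rows play the role of SAS clusters, columns the role of R clusters
d = matrix(c(5, 1, 9,
             2, 8, 7,
             6, 4, 3), nrow = 3, byrow = TRUE)
greedy_match(d)  # 2 1 3
```

A greedy rule like this is simple and works well when most pairs match closely, but it is not guaranteed to minimize the total distance over all pairs.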
sas.file = function(file, rp = repo_sas) paste0(rp, "sas_results/", file)
sas_r_cplot = function(file, clustra_out) {
sas = cbind(haven::read_sas(sas.file(file)), source = 1)
r = data.frame(time = min(sas$time):max(sas$time))
n = nrow(r)
k = nrow(sas)/n
## below assumes same time in each group and groups are contiguous.
rpred = do.call(c, parallel::mclapply(clustra_out$tps, gpred, newdata = r,
mc.cores = mc))
## compute mapping from SAS to R using manhattan distance over time
rmap = vector("integer", k)
m = as.matrix(
dist(rbind(t(matrix(sas$pred, n)),
t(matrix(rpred, nrow = n))), "manhattan"))[1:k, (k + 1):(2*k)]
for(i in 1:k) { # assign starting with closest pair
ind = arrayInd(which.min(m), dim(m))
rmap[ind[1]] = ind[2]
m[ind[1], ] = Inf
m[, ind[2]] = Inf
}
r = cbind(r, group = sas$group, pred = rpred, source = 2)
plot(pred ~ time, data = sas, type = "n")
for(i in 1:k) {
sr = (1 + (i - 1)*n):(i*n)
rr = (1 + (rmap[i] - 1)*n):(rmap[i]*n)
lines(sas[sr, ]$time, sas[sr, ]$pred, lty = sas[sr, ]$source, col = i)
lines(r[rr, ]$time, r[rr, ]$pred, lty = r[rr, ]$source, col = i)
}
rmap # output the mapping from SAS clusters to R: R = rmap[SAS]
}
png(paste0(plot_path, "R_SAS_cl2_compare.png"), width = 480, height = 360)
map2 = sas_r_cplot("graphout_cl2.sas7bdat", cl2)
dev.off()
png(paste0(plot_path, "R_SAS_cl5_compare.png"), width = 480, height = 360)
map5 = sas_r_cplot("graphout_cl5.sas7bdat", cl5)
dev.off()
png(paste0(plot_path, "R_SAS_cl10_compare.png"), width = 480, height = 360)
map10 = sas_r_cplot("graphout_cl10.sas7bdat", cl10)
dev.off()
knitr::include_graphics(paste0(plot_path, "R_SAS_cl2_compare.png"))
knitr::include_graphics(paste0(plot_path, "R_SAS_cl5_compare.png"))
knitr::include_graphics(paste0(plot_path, "R_SAS_cl10_compare.png"))
The k = 2 and k = 5 cluster results are almost identical between SAS (solid lines, lty = 1 in the code above) and R (dashed lines, lty = 2). The 10-cluster result appears to keep the bottom cluster (yellow) from the 5-cluster result, and that cluster is nearly identical between SAS and R. The green clusters are also somewhat similar. The remaining 8 clusters seem to partition the data in a very unstable fashion that produces different results for different seeds.
Next, check the Rand index comparing R and SAS results:
MixSim::RandIndex(haven::read_sas(sas.file("rms_cl2.sas7bdat"))$newgroup, cl2$group)
MixSim::RandIndex(haven::read_sas(sas.file("rms_cl5.sas7bdat"))$newgroup, cl5$group)
MixSim::RandIndex(haven::read_sas(sas.file("rms_cl10.sas7bdat"))$newgroup, cl10$group)
Nearly perfect! The high adjusted Rand indices for 2 and 5 clusters indicate strong agreement between the SAS and R cluster assignments. This is also evident from the nearly identical lines in the plots (colors and labels have been remapped to align; in the plotting code, SAS lines use lty = 1, solid, and R lines use lty = 2, dashed). The pattern does not, however, hold for the 10-cluster results. Some of the patterns are consistent for the bounded clusters, but the middle clusters, which fluctuate the most, are inconsistent between SAS and R. This is to be expected: the R algorithm nears convergence to the minimum percent change, but some instability in the cluster assignment remains. We would also expect the Rand index to be lower, since the true underlying number of clusters is 5. This is consistent with the full Rand Index plot results above.
Now compare with true_group used to generate the data:
true_group = data[, mean(true_group),by=id]
MixSim::RandIndex(haven::read_sas(sas.file("rms_cl5.sas7bdat"))$newgroup, true_group$V1)
MixSim::RandIndex(true_group$V1, cl5$group)
A result of 0.58 is possibly due to the wiggle observed in the trajectories.
## clustra vignette run time -27.75 seconds