---
title: "clustra: clustering trajectories"
author: "George Ostrouchov, Hanna Gerlovin, and David Gagnon"
date: "`r Sys.Date()`"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{clustra: clustering trajectories}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r, setup, echo = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
start_knit = proc.time()
data.table::setDTthreads(1) # manage data.table threads
```

The **clustra** package was built to **clus**ter longitudinal **tra**jectories (time series) on a common time axis. For example, a number of individuals are started on a specific drug regimen and their blood pressure data is collected for a varying amount of time before and after the start of the medication. Observations can be unequally spaced, trajectories can be of unequal length, and they may overlap only partially.

Clustering proceeds by an EM algorithm that alternates between fitting a thin plate spline (TPS) to the combined responses within each cluster (M-step) and reassigning cluster membership based on the nearest fitted spline (E-step). The fitting is done with the **mgcv** package function `bam`, which scales well to very large data sets.

For this vignette, we begin by generating a data set with the `gen_traj_data()` function. Given its parameters, the function generates groups of `id`s (their sizes given by the vector `n_id`) and, for each `id`, a random number of observations drawn from the `Poisson`($\lambda =$ `m_obs`) distribution plus 3. The 3 additional observations guarantee one before the intervention at time `start`, one at the intervention time 0, and one after the intervention at time `end`. The `start` time is `Uniform`(`s_range`) and the `end` time is `Uniform`(`e_range`). The remaining observation times are drawn from `Uniform`(`start`, `end`). The time units are arbitrary and depend on your application. Three trajectory forms are implemented so far: sine, sigmoid, and constant.
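
The sampling of observation times described above can be sketched in a few lines of base R. This is only an illustration of the description, not the code inside `gen_traj_data()`: the helper name `sample_times()` is hypothetical, and details such as rounding of times are assumptions left out here.

```r
# Hypothetical helper (not part of clustra): draw observation times for one id
# following the description of gen_traj_data() given in the text.
sample_times = function(m_obs, s_range, e_range) {
  start = runif(1, s_range[1], s_range[2])  # one time before the intervention
  end = runif(1, e_range[1], e_range[2])    # one time after the intervention
  n_extra = rpois(1, lambda = m_obs)        # Poisson count of remaining times
  extra = runif(n_extra, start, end)        # remaining times on (start, end)
  sort(c(start, 0, end, extra))             # time 0 is the intervention
}

set.seed(1)
times = sample_times(m_obs = 25, s_range = c(-365, -14),
                     e_range = c(0.5 * 365, 2 * 365))
length(times)  # Poisson draw plus the 3 guaranteed observations
range(times)   # spans from start (negative) to end (positive)
```

Because every `id` gets its own `start`, `end`, and count, the resulting trajectories differ in length and overlap only partially, which is the setting the clustering has to handle.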

The code below generates the data and shows its first few observations. The `mc` variable sets core use and is passed to the `mccores` parameter throughout the rest of the vignette. By default, 1 core is assigned. Parallel sections are implemented with `parallel::mclapply()`, so on Unix and Mac platforms it is recommended to use the full number of cores available for faster performance. Default initialization of the clusters is set to `"random"` (see the `clustra` help file for the other option, `"distant"`). We also set the seed for reproducibility.

```{r gendata}
library(clustra)
mc = 2 # If running on a unix or a Mac platform, increase up to # cores
if (.Platform$OS.type == "windows") mc = 1
init = "random"
set.seed(12345)
data = gen_traj_data(n_id = c(500, 1000, 1500, 2000), types = c(2, 1, 3, 2),
                     intercepts = c(70, 130, 120, 130), m_obs = 25,
                     s_range = c(-365, -14), e_range = c(0.5*365, 2*365),
                     noise = c(0, 15))
head(data)
```

The histogram shows the distribution of generated lengths. The short ones will be the most difficult to cluster correctly.

Select a few random `id`s and show their scatterplots.

```{r plotdata, fig.width = 7, fig.height = 9}
plot_sample(data[id %in% sample(unique(data[, id]), 9)], group = "true_group")
```

Next, cluster the trajectories. Set `k = 4` (we will consider selection of `k` later), set the spline maximum degrees of freedom `maxdf` to 10, and set `conv` to a maximum of 10 iterations with convergence declared when 0 changes occur. `mccores` sets the number of cores to use in various components of the code. Note that multiple cores do not work on Windows operating systems, where `mccores` should be set to 1 (the default). In the code that follows, we use verbose output to get information from each iteration.


```{r clustra4}
set.seed(12345)
cl4 = clustra(data, k = 4, maxdf = 10, conv = c(10, 0), mccores = mc,
              verbose = TRUE)
```

Each iteration displays components of the M-step and the E-step followed by its duration in seconds, the number of classification changes in the E-step, the current counts in each cluster, and the deviance.

Next, plot the raw data (a sample if there are more than 10,000 points) with the resulting spline fits, colored by cluster.

```{r smooths, fig.width = 7, fig.height = 7}
plot_smooths(data, group = NULL)
plot_smooths(data, cl4$tps)
```

The Rand index for comparing with `true_group` is

```{r rand4}
MixSim::RandIndex(cl4$data_group, data[, true_group])
```

AR stands for the adjusted Rand index, which adjusts for random agreement. An AR of 0.827 against the true groups used to generate the data is quite good, considering that the short series are easily misclassified and that k-means-type algorithms often find a local minimum.

Let's double the error standard deviation in data generation and repeat.

```{r gen2_clustra, fig.width = 7, fig.height = 9}
set.seed(12345)
data2 = gen_traj_data(n_id = c(500, 1000, 1500, 2000), types = c(2, 1, 3, 2),
                      intercepts = c(70, 130, 120, 130), m_obs = 25,
                      s_range = c(-365, -14), e_range = c(60, 2*365),
                      noise = c(0, 30))
plot_sample(data2[id %in% sample(unique(data2[, id]), 9)], group = "true_group")
cl4a = clustra(data2, k = 4, maxdf = 10, conv = c(10, 0), mccores = mc,
               verbose = TRUE)
MixSim::RandIndex(cl4a$data_group, data2[, true_group])
```

This time the AR of 0.815 is lower but still respectable. The clustering recovers the trajectory means quite well, as we see in the following plots. The first is without cluster colors (obtained by setting `group = NULL`) and shows the mass of points; the second shows the cluster means with cluster colors.


```{r smooths2, fig.width = 7, fig.height = 7}
plot_smooths(data2, group = NULL)
plot_smooths(data2, cl4a$tps)
```

The average silhouette value is one way to select the number of clusters, and a silhouette plot allows a deeper evaluation (Rousseeuw 1987). A silhouette requires distances between individual subjects, which are not available here without fitting a separate trajectory model for each subject `id`, because subjects are sampled unequally. As a proxy, we use subject distances to cluster mean trajectories in the `clustra_sil()` function. The structure returned from the `clustra()` function contains the matrix `loss`, which has all the information needed to construct these proxy silhouette plots.

The function `clustra_sil()` performs clustering for a number of `k` values and outputs information for the silhouette plots that are displayed next. For faster processing, we relax the convergence criterion in `conv` to 1% of changes (instead of the 0 used earlier) and allow at most 7 iterations. We use the first data set, generated with `noise = c(0, 15)`.

```{r sil, fig.width = 7}
set.seed(12345)
sil = clustra_sil(data, kv = c(2, 3, 4, 5), mccores = mc, maxdf = 10,
                  conv = c(7, 1), verbose = TRUE)
lapply(sil, plot_silhouette)
```

The plots for 3 and 4 clusters give the best Average Width. Usually we take the larger one, 4, which is supported here also by the minimum AIC and BIC scores. We also note that the final deviance drops substantially at 4 clusters and barely moves when 5 clusters are fit, further corroborating the choice `k = 4`.

If we don't want to recluster the data, we can directly reuse a previous clustra run and produce a silhouette plot for it, as we now do for the run on the data with doubled error standard deviation above (results in `cl4a`).

```{r sil2, fig.width = 7}
sil = clustra_sil(cl4a)
lapply(sil, plot_silhouette)
```

Another way to select the number of clusters is to use the Rand index to compare clusterings from different random starts and different numbers of clusters.
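
For orientation, the adjusted Rand index between two label vectors can be computed from their contingency table. The sketch below is a minimal base R version of the standard formula, shown only to make the quantity concrete; `adjusted_rand()` is a hypothetical helper name, not a clustra or MixSim function, and it assumes the two labelings are not both trivial (for example, both a single cluster), in which case the ratio is undefined.

```r
# Hypothetical helper (not part of clustra or MixSim): adjusted Rand index
# between two label vectors of equal length.
adjusted_rand = function(x, y) {
  stopifnot(length(x) == length(y))
  comb2 = function(n) n * (n - 1) / 2     # number of pairs among n items
  tab = table(x, y)                       # contingency table of the labelings
  index = sum(comb2(tab))                 # pairs placed together in both
  a = sum(comb2(rowSums(tab)))            # pairs placed together in x
  b = sum(comb2(colSums(tab)))            # pairs placed together in y
  expected = a * b / comb2(length(x))     # chance-expected value of index
  maximum = (a + b) / 2                   # largest attainable value of index
  (index - expected) / (maximum - expected)
}

# Identical partitions with permuted labels agree perfectly
adjusted_rand(c(1, 1, 2, 2), c(2, 2, 1, 1))  # 1
```

Since the index depends only on which pairs are grouped together, it is unaffected by how clusters are numbered, which is what makes it suitable for comparing runs from different random starts.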
When we replicate clustering with different random seeds, the "replicability" is an indicator of how stable the results are for a given `k`, the number of clusters. For this demonstration, we look at `k = c(2, 3, 4, 5)`, with 10 replicates for each `k`. To run this long-running chunk, set `eval = TRUE`.

```{r rand_plot, fig.width = 7, fig.height=7, eval = FALSE}
set.seed(12345)
ran = clustra_rand(data, k = c(2, 3, 4, 5), mccores = mc, replicates = 10,
                   maxdf = 10, conv = c(7, 1), verbose = TRUE)
rand_plot(ran)
```

The plot shows the AR similarity level between all pairs of 40 clusterings (10 random starts for each of 2, 3, 4, and 5 clusters). It is difficult to distinguish between the 3, 4, and 5 results, but the 4 result has the largest block of complete agreement.

Here, we can try running clustra with the `"distant"` starts option. The sequential selection of above-median length series that are most distant from previous selections introduces less initial variability. To run this long-running chunk, set `eval = TRUE`.

```{r rand_plot_d, fig.width = 7, fig.height=7, eval = FALSE}
set.seed(12345)
ran = clustra_rand(data, k = c(2, 3, 4, 5), starts = "distant", mccores = mc,
                   replicates = 10, maxdf = 10, conv = c(7, 1), verbose = TRUE)
rand_plot(ran)
```

In this case, `k = 4` comes with complete agreement between the 10 starts.

Another possible evaluation of the number of clusters is to first ask clustra for a large number of clusters, evaluate the cluster centers on a common set of time points, and feed the resulting matrix to a hierarchical clustering function. We ask for 40 clusters on the `data2` data set but actually get back only 17 because several become empty or too small for `maxdf`. Below, the `hclust()` function clusters the 17 resulting cluster means, each evaluated on 100 time points.


```{r hclust, fig.width = 7, fig.height = 7}
set.seed(12345)
cl30 = clustra(data2, k = 40, maxdf = 10, conv = c(10, 0), mccores = mc,
               verbose = TRUE)
# Evaluate a fitted cluster spline on new time points
gpred = function(tps, newdata)
  as.numeric(mgcv::predict.bam(tps, newdata, type = "response",
                               newdata.guaranteed = TRUE))
# One row per cluster mean, evaluated on a common grid of 100 time points
resp = do.call(rbind, lapply(cl30$tps, gpred, newdata = data.frame(
  time = seq(min(data2$time), max(data2$time), length.out = 100))))
plot(hclust(dist(resp)))
```

The dendrogram clearly indicates 4 clusters. When we use `starts = "distant"`, the selected distant starts are more likely to persist into a nearby local minimum, retaining the full 40 specified clusters.

```{r hclust_d, fig.width = 7, fig.height = 7}
set.seed(12345)
cl30 = clustra(data2, k = 40, starts = "distant", maxdf = 10, conv = c(10, 0),
               mccores = mc, verbose = TRUE)
# Evaluate a fitted cluster spline on new time points
gpred = function(tps, newdata)
  as.numeric(mgcv::predict.bam(tps, newdata, type = "response",
                               newdata.guaranteed = TRUE))
# One row per cluster mean, evaluated on a common grid of 100 time points
resp = do.call(rbind, lapply(cl30$tps, gpred, newdata = data.frame(
  time = seq(min(data2$time), max(data2$time), length.out = 100))))
plot(hclust(dist(resp)))
```

Here again (if we consider cluster 24 an outlier) we get 4 clusters.

```{r finish}
cat("clustra vignette run time:\n")
print(proc.time() - start_knit)
```