Example: Cluster analysis using Spark ML to predict cluster membership in the iris dataset

Adapted, with minor changes, from: https://spark.rstudio.com/

Load Packages

# Attach the core tidyverse packages; they supply %>%, filter(), select(),
# glimpse(), collect()-friendly verbs and ggplot() used further below.
library("tidyverse")

Installation

#install.packages("sparklyr")

# Or install the development version from GitHub
#devtools::install_github("rstudio/sparklyr")

Connecting to Spark

# sparklyr provides the R interface to Spark (connections, copy_to, ML).
library("sparklyr")
## 
## Attaching package: 'sparklyr'
## The following object is masked from 'package:purrr':
## 
##     invoke
# Open a connection to a Spark instance running locally; every later
# Spark operation in this document goes through the `sc` handle.
sc <- sparklyr::spark_connect(master = "local")

Copy Data from R into the Spark Cluster

# Upload R's built-in iris data frame to Spark as a table called "iris",
# replacing any table of that name left over from an earlier run. The printed
# result shows that the dotted column names (Sepal.Length, ...) arrive in
# Spark with underscores instead (Sepal_Length, ...).
iris_tbl <- copy_to(
  dest = sc,
  df = iris,
  name = "iris",
  overwrite = TRUE
)
iris_tbl
## # Source: spark<iris> [?? x 5]
##    Sepal_Length Sepal_Width Petal_Length Petal_Width Species
##           <dbl>       <dbl>        <dbl>       <dbl> <chr>  
##  1          5.1         3.5          1.4         0.2 setosa 
##  2          4.9         3            1.4         0.2 setosa 
##  3          4.7         3.2          1.3         0.2 setosa 
##  4          4.6         3.1          1.5         0.2 setosa 
##  5          5           3.6          1.4         0.2 setosa 
##  6          5.4         3.9          1.7         0.4 setosa 
##  7          4.6         3.4          1.4         0.3 setosa 
##  8          5           3.4          1.5         0.2 setosa 
##  9          4.4         2.9          1.4         0.2 setosa 
## 10          4.9         3.1          1.5         0.1 setosa 
## # ... with more rows

Inspect Data Set

# Show each column's type and first few values of the remote Spark table.
iris_tbl %>% glimpse()
## Observations: ??
## Variables: 5
## Database: spark_connection
## $ Sepal_Length <dbl> 5.1, 4.9, 4.7, 4.6, 5.0, 5.4, 4.6, 5.0, 4.4, 4.9,...
## $ Sepal_Width  <dbl> 3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1,...
## $ Petal_Length <dbl> 1.4, 1.4, 1.3, 1.5, 1.4, 1.7, 1.4, 1.5, 1.4, 1.5,...
## $ Petal_Width  <dbl> 0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1,...
## $ Species      <chr> "setosa", "setosa", "setosa", "setosa", "setosa",...

Simple Filtering Example Using dplyr

# dplyr verbs on a Spark table are translated to Spark SQL; this keeps only
# the rows whose sepal length is strictly below 5.
filter(iris_tbl, Sepal_Length < 5)
## # Source: spark<?> [?? x 5]
##    Sepal_Length Sepal_Width Petal_Length Petal_Width Species
##           <dbl>       <dbl>        <dbl>       <dbl> <chr>  
##  1          4.9         3            1.4         0.2 setosa 
##  2          4.7         3.2          1.3         0.2 setosa 
##  3          4.6         3.1          1.5         0.2 setosa 
##  4          4.6         3.4          1.4         0.3 setosa 
##  5          4.4         2.9          1.4         0.2 setosa 
##  6          4.9         3.1          1.5         0.1 setosa 
##  7          4.8         3.4          1.6         0.2 setosa 
##  8          4.8         3            1.4         0.1 setosa 
##  9          4.3         3            1.1         0.1 setosa 
## 10          4.6         3.6          1           0.2 setosa 
## # ... with more rows

Use Spark K-Means Clustering

# Fit a Spark ML k-means model with three clusters, using only the two petal
# measurements as features.
kmeans_model <- ml_kmeans(
  select(iris_tbl, Petal_Width, Petal_Length),
  formula = ~ Petal_Width + Petal_Length,
  k = 3
)

# Printing the fitted model shows the cluster centres and the within-set sum
# of squared errors.
kmeans_model
## K-means clustering with 3 clusters
## 
## Cluster centers:
##   Petal_Width Petal_Length
## 1    1.359259     4.292593
## 2    0.246000     1.462000
## 3    2.047826     5.626087
## 
## Within Set Sum of Squared Errors =  31.41289
# Assign every row of iris_tbl to a cluster and pull the result into R.
# k-means predicts a cluster id (0-based, see the table columns), not a
# species, so the cross-tabulation shows how the clusters line up with the
# known species labels.
predicted <- collect(ml_predict(kmeans_model, iris_tbl))

table(predicted[["Species"]], predicted[["prediction"]])
##             
##               0  1  2
##   setosa      0 50  0
##   versicolor 48  0  2
##   virginica   6  0 44
# Plot cluster membership.
# Each flower is a point coloured by its predicted cluster, and each cluster
# centre is a large "x" in a darker (muted) shade of the matching legend
# colour. Petal length is on the x axis and petal width on the y axis, as the
# axis labels state; every layer inherits that mapping from ggplot().
# The number of clusters is read from the fitted model, so the labels and
# centre colours stay correct if k is changed.
n_clusters <- nrow(kmeans_model$centers)

ml_predict(kmeans_model) %>%
  collect() %>%
  ggplot(aes(Petal_Length, Petal_Width)) +
  # `prediction` is 0-based; shift to 1-based so the legend reads Cluster 1..k
  geom_point(aes(col = factor(prediction + 1)), size = 2, alpha = 0.5) +
  # NOTE(review): assumes row i of `centers` is the centre of prediction i - 1,
  # so hue_pal()(k) gives the same colours scale_color_discrete() uses.
  geom_point(
    data = kmeans_model$centers,
    col = scales::muted(scales::hue_pal()(n_clusters)),
    shape = "x",
    size = 12
  ) +
  scale_color_discrete(
    name = "Predicted Cluster",
    labels = paste("Cluster", seq_len(n_clusters))
  ) +
  labs(
    x = "Petal Length",
    y = "Petal Width",
    title = "K-Means Clustering",
    subtitle = "Use Spark.ML to predict cluster membership with the iris dataset."
  )