Spark ML hyperparameter tuning
Recently, while going through the Udemy course on Spark to refresh my knowledge of this awesome tool, I noticed that the examples given in the course do not always produce sensible results. Although the instructor acknowledged this, there was no digging into the root cause of the problem, which, of course, I couldn't leave uninvestigated :).
(Photo: Chris Bloom @ Flickr.com)
Building a movie recommender system
The example I'm going to talk about in this blog post is quite simple: building a movie recommender system from the MovieLens data set of 100k ratings. The example uses 2 files from the data set: one associates each movie ID with its name, and the other contains a list of user ID, movie ID, rating, and timestamp entries. The Spark schemas look like the following:
import org.apache.spark.sql.types._

// Schema for the movie names file (movie ID -> title)
val moviesNamesSchema = new StructType()
  .add("movieID", IntegerType, nullable = true)
  .add("movieTitle", StringType, nullable = true)

// Create schema when reading u.data
val moviesSchema = new StructType()
  .add("userID", IntegerType, nullable = true)
  .add("movieID", IntegerType, nullable = true)
  .add("rating", IntegerType, nullable = true)
  .add("timestamp", LongType, nullable = true)
The ratings are read into a data frame and then passed to one of Spark's built-in recommendation algorithms, ALS (Alternating Least Squares matrix factorization). The names data frame is simply converted to an array and used later to display the results while evaluating the model.
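For reference, the loading step described above could look roughly like the sketch below. It relies on a few assumptions the post does not state: the files are the standard MovieLens 100k `u.data` (tab-separated) and `u.item` (pipe-separated, ISO-8859-1 encoded), they live in a hypothetical `ml-100k/` directory, and the names are collected into a `Map` rather than an array for easier lookup.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types._

object LoadMovieLens {
  // Same schemas as above
  val moviesNamesSchema: StructType = new StructType()
    .add("movieID", IntegerType, nullable = true)
    .add("movieTitle", StringType, nullable = true)

  val moviesSchema: StructType = new StructType()
    .add("userID", IntegerType, nullable = true)
    .add("movieID", IntegerType, nullable = true)
    .add("rating", IntegerType, nullable = true)
    .add("timestamp", LongType, nullable = true)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("LoadMovieLens")
      .master("local[*]")
      .getOrCreate()

    // Hypothetical location of the extracted MovieLens 100k files
    val dataDir = if (args.nonEmpty) args(0) else "ml-100k"

    // u.data: tab-separated userID, movieID, rating, timestamp (assumed layout)
    val ratings: DataFrame = spark.read
      .option("sep", "\t")
      .schema(moviesSchema)
      .csv(s"$dataDir/u.data")

    // u.item: pipe-separated; only the first two columns match the schema
    val names: Map[Int, String] = spark.read
      .option("sep", "|")
      .option("encoding", "ISO-8859-1")
      .schema(moviesNamesSchema)
      .csv(s"$dataDir/u.item")
      .collect()
      .map(row => row.getInt(0) -> row.getString(1))
      .toMap

    println(s"ratings: ${ratings.count()}, movies: ${names.size}")
    spark.stop()
  }
}
```

Because `u.item` has more columns than the two-field schema, Spark's default PERMISSIVE CSV mode drops the extra tokens on each line, so no custom parsing is needed for the titles.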
The ALS recommendation engine was constructed initially with the following parameters:
import org.apache.spark.ml.recommendation.ALS

val als = new ALS()
  .setMaxIter(5)
  .setRegParam(0.01)
  .setUserCol("userID")
  .setItemCol("movieID")
  .setRatingCol("rating")
val model = als.fit(ratings)
The training data was modified by adding a new user with a few rated movies, so that we could later get recommendations for this user and check whether they make sense.
I have added 3 entries for the user with ID 0 and the following movies (movie ID, title):
50, Star Wars (1977); 172, The Empire Strikes Back (1980); 133, Gone With the Wind (1939)
The first 2 movies have a rating of 4 and the last one is rated
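A minimal sketch of adding the test user, assuming the ratings DataFrame uses the `moviesSchema` columns. `withTestUser` is a hypothetical helper name, and the rating of 1 for Gone With the Wind is only a placeholder, since its actual value is not given here.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

object AddTestUser {
  // Hypothetical helper: appends hand-made ratings for user 0 to the training data
  def withTestUser(ratings: DataFrame): DataFrame = {
    val spark = ratings.sparkSession
    import spark.implicits._
    val now = System.currentTimeMillis() / 1000 // seconds, like the u.data timestamps
    val testUser = Seq(
      (0, 50, 4, now),  // Star Wars (1977)
      (0, 172, 4, now), // The Empire Strikes Back (1980)
      (0, 133, 1, now)  // Gone With the Wind (1939): placeholder rating (assumed)
    ).toDF("userID", "movieID", "rating", "timestamp")
    // unionByName lines columns up by name rather than by position
    ratings.unionByName(testUser)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("AddTestUser")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._
    // Two stand-in rows instead of the full ratings file
    val ratings = Seq((196, 242, 3, 881250949L), (186, 302, 3, 891717742L))
      .toDF("userID", "movieID", "rating", "timestamp")
    val all = withTestUser(ratings)
    all.show()
    spark.stop()
  }
}
```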
Output with the hyperparameters above:
(Angel Baby (1995),4.753925)
(Microcosmos: Le peuple de l'herbe (1996),4.681522)
(Boys, Les (1997),4.5429416)
(Faust (1994),4.189102)
(Friday (1995),4.162115)
(Unzipped (1995),4.1620374)
(Last Dance (1996),4.03131)
(Star Wars (1977),4.009129)
(Fresh (1994),3.9773624)
(Empire Strikes Back, The (1980),3.969892)
As we can easily see, the output has little in common with what we would expect: we rated the Sci-Fi movies highly, yet ended up with recommendations for completely unrelated movies such as Angel Baby.
Assuming the data makes sense, and that people who rate Star Wars highly also rate similar movies highly, we should be able to produce much better recommendations.
Let's use Spark ML’s built-in mechanism
Looks like we need a better model! Instead of tuning the hyperparameters by hand and rebuilding the model every time we want to check the output, we can let Spark ML’s built-in mechanism do that for us automatically. We will use the popular CrossValidator as our model selection tool. To do that, we first need to define 3 basic ingredients:
- Estimator: the algorithm we want to tune (in our case ALS)
- ParamMap(s): the parameter combinations we want to evaluate to find the best model
- Evaluator: the metric used to score each candidate model (here, the RMSE between predicted and actual ratings); the CrossValidator uses it to compare every combination from the param maps and pick the best model.
To get the best model, we remove the hyperparameters we had set by hand and let the CrossValidator find the best combination for us.
With these 3 components defined, using the CrossValidator is pretty simple:
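import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

val als = new ALS()
  .setNonnegative(true)
  .setImplicitPrefs(false)
  .setColdStartStrategy("drop") // drop NaN predictions for IDs unseen in a training fold, otherwise RMSE is NaN
  .setUserCol("userID")
  .setItemCol("movieID")
  .setRatingCol("rating")

// 4 x 4 = 16 combinations of regularization and rank
val paramGrid = new ParamGridBuilder()
  .addGrid(als.regParam, Array(.01, .05, .1, .15))
  .addGrid(als.rank, Array(10, 50, 100, 150))
  .build()

val evaluator = new RegressionEvaluator()
  .setMetricName("rmse")
  .setLabelCol("rating")
  .setPredictionCol("prediction")

val cv = new CrossValidator()
  .setEstimator(als)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(5)

val model = cv.fit(ratings)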
As you can see, instead of calling fit directly on our ALS estimator, we call it on the CrossValidator instance to get the model.
The returned CrossValidatorModel holds the best model found, which we can simply cast back to the ALSModel type:
val bestModel = model.bestModel.asInstanceOf[ALSModel]
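Besides the best model, the returned `CrossValidatorModel` keeps the average metric of every parameter combination in `avgMetrics` (in the same order as `getEstimatorParamMaps`), which shows how sensitive the results are to `rank` and `regParam`. The sketch below prints them. To keep it runnable on its own, it fits on small synthetic ratings with a reduced 2 x 2 grid and 3 folds; these are assumptions, not the post's setup, and `fitOnSyntheticData` is a hypothetical helper.

```scala
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.recommendation.{ALS, ALSModel}
import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel, ParamGridBuilder}
import org.apache.spark.sql.SparkSession

object InspectCrossValidation {
  // Prints each parameter combination with its RMSE averaged over the folds, lowest first
  def report(cvModel: CrossValidatorModel): Unit = {
    cvModel.getEstimatorParamMaps
      .zip(cvModel.avgMetrics)
      .sortWith((x, y) => x._2 < y._2)
      .foreach { case (params, rmse) => println(f"RMSE $rmse%.4f for $params") }
    val best = cvModel.bestModel.asInstanceOf[ALSModel]
    println(s"Best rank: ${best.rank}")
  }

  // Fits a reduced CrossValidator on synthetic ratings (stand-in for the real data)
  def fitOnSyntheticData(spark: SparkSession): CrossValidatorModel = {
    import spark.implicits._
    // Every one of 30 hypothetical users rates every one of 20 movies
    val ratings = (for {
      u <- 1 to 30
      m <- 1 to 20
    } yield (u, m, (u * 7 + m * 3) % 5 + 1)).toDF("userID", "movieID", "rating")

    val als = new ALS()
      .setColdStartStrategy("drop") // avoid NaN RMSE for IDs unseen in a training fold
      .setUserCol("userID")
      .setItemCol("movieID")
      .setRatingCol("rating")
    val paramGrid = new ParamGridBuilder()
      .addGrid(als.regParam, Array(0.01, 0.1))
      .addGrid(als.rank, Array(5, 10))
      .build()
    val evaluator = new RegressionEvaluator()
      .setMetricName("rmse")
      .setLabelCol("rating")
      .setPredictionCol("prediction")
    new CrossValidator()
      .setEstimator(als)
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(3)
      .fit(ratings)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("InspectCrossValidation")
      .master("local[*]")
      .getOrCreate()
    report(fitOnSyntheticData(spark))
    spark.stop()
  }
}
```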
After acquiring the best model, we proceed as usual to get the recommendations for our user.
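One way to get the top movies for user 0 from an `ALSModel` is `recommendForUserSubset` (available since Spark 2.3). The post does not show its own recommendation code, so treat this as a sketch: `topN` is a hypothetical helper, and the training data and titles in `main` are synthetic.

```scala
import org.apache.spark.ml.recommendation.{ALS, ALSModel}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, explode}

object RecommendForUser {
  // Hypothetical helper: top-n (title, predicted rating) pairs for one user, best first.
  // ALS does not filter out movies the user has already rated.
  def topN(spark: SparkSession, model: ALSModel, userId: Int, n: Int,
           names: Map[Int, String]): Seq[(String, Float)] = {
    import spark.implicits._
    // The column name must match the model's user column
    val users = Seq(userId).toDF("userID")
    model.recommendForUserSubset(users, n) // one row per user: array of (movieID, rating)
      .select(explode(col("recommendations")).as("rec"))
      .select(col("rec.movieID"), col("rec.rating"))
      .as[(Int, Float)]
      .collect()
      .toSeq
      .sortWith((x, y) => x._2 > y._2)
      .map { case (movieId, score) => (names.getOrElse(movieId, s"movie $movieId"), score) }
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("RecommendForUser")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Synthetic stand-in ratings and titles
    val ratings = (for {
      u <- 0 to 20
      m <- 1 to 15
      if (u + m) % 3 != 0
    } yield (u, m, (u + m) % 5 + 1)).toDF("userID", "movieID", "rating")
    val names = (1 to 15).map(m => m -> s"Movie #$m").toMap

    val model = new ALS()
      .setRank(5)
      .setUserCol("userID")
      .setItemCol("movieID")
      .setRatingCol("rating")
      .fit(ratings)

    topN(spark, model, 0, 5, names).foreach(println)
    spark.stop()
  }
}
```

Because ALS scores every movie, including those the user has already rated, it is expected that Star Wars and The Empire Strikes Back can appear in the recommendation lists for user 0.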
After our simple tuning, we get:
Top 10 recommendations for user ID 0:
(Star Wars (1977),3.9488063)
(Empire Strikes Back, The (1980),3.8318543)
(Return of the Jedi (1983),3.6589277)
(Raiders of the Lost Ark (1981),3.6525774)
(Wrong Trousers, The (1993),3.426427)
(Princess Bride, The (1987),3.4248486)
(Casablanca (1942),3.3963377)
(Close Shave, A (1995),3.371017)
(Shawshank Redemption, The (1994),3.3682468)
(Usual Suspects, The (1995),3.3641634)
which looks much better than the initial recommendations.
Alongside CrossValidator, the Spark ML toolbox also offers TrainValidationSplit, which works similarly but evaluates each combination of parameters only once, on a single train/validation split, instead of once per fold. According to the Spark documentation, this makes it less expensive, but its results are less reliable when the training dataset is not sufficiently large.
import org.apache.spark.ml.tuning.TrainValidationSplit

val trainValidationSplit = new TrainValidationSplit()
  .setEstimator(als)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setTrainRatio(0.8) // 80% of the data is used for training, the remaining 20% for validation
  .setParallelism(2)  // evaluate up to 2 parameter settings in parallel
val model = trainValidationSplit.fit(ratings)
After introducing this change and running the program again, we end up with exactly the same results as with CrossValidator, but much faster.
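The speedup is roughly what counting model fits predicts. Both tools train on 80% of the data per fit here (5 folds vs. a `trainRatio` of 0.8), and both refit the winning combination on the full dataset at the end, so with the grid above of $|G| = 4 \cdot 4 = 16$ (regParam, rank) combinations and $k = 5$ folds:

$$
\text{CrossValidator: } k \cdot |G| + 1 = 5 \cdot 16 + 1 = 81 \text{ ALS fits}
$$

$$
\text{TrainValidationSplit: } |G| + 1 = 16 + 1 = 17 \text{ ALS fits}
$$

That is about 4.8 times fewer fits; in addition, the TrainValidationSplit above runs up to 2 fits in parallel, while the CrossValidator uses the default parallelism of 1, so actual wall-clock gains also depend on the available resources.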
Apache Spark and Scala courses I recommend
If you want to learn more about Spark in general, you can try the Udemy course yourself (available at: Apache Spark with Scala — Hands On with Big Data!) or my favourite place for any kind of course, Rock The Jvm, which has a lot of great content.