Machine Learning with .NET: Model Serialization

This is part five in a series on Machine Learning with .NET.

So far in this series, we’ve looked at training a model at runtime. Even in the prior post, where we looked at predictions, we first had to train a model. With a Naive Bayes model against a small data set, that’s not an onerous task. It takes maybe a second or two the first time we instantiate our objects, and we can keep the prediction engine around for the lifetime of our application, so we are able to amortize the cost fairly effectively.

Suppose, however, that we have a mighty oak of a neural network which we have trained over the past two months. It finally completed training and we now have a set of weights. Obviously, we don’t want to retrain this if we can avoid it, so we need to find a way to serialize our model and save it to disk somewhere so that we can re-use it later.

Saving a Model to Disk

Saving a model to disk is quite easy. Here’s the method I created to do just that:

public void SaveModel(MLContext mlContext, ITransformer model, string modelPath)
{
	using (var stream = File.Create(modelPath))
	{
		mlContext.Model.Save(model, null, stream);
	}
}

We create a file stream and use the built-in Save() method to write our data to that stream. Here is an example of me calling that method:

string modelPath = "C:\\Temp\\BillsModel.mdl";
bmt.SaveModel(mlContext, model, modelPath);

This generates a binary .mdl file which contains enough information to reconstitute our model and generate predictions off of it.
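If you would rather not manage the stream yourself, Save() also has an overload that takes a file path. Here is a minimal sketch of that variant; the ModelFiles class and both of its method names are hypothetical helpers of mine, and the temp-directory path is just one way to avoid hard-coding C:\Temp:

```csharp
using System.IO;
using Microsoft.ML;

public static class ModelFiles
{
	// Hypothetical helper: builds a model path under the current user's temp directory.
	public static string TempModelPath(string fileName) =>
		Path.Combine(Path.GetTempPath(), fileName);

	// Saves through the file-path overload, so ML.NET creates and closes the file itself.
	public static void SaveModelToPath(MLContext mlContext, ITransformer model, string modelPath)
	{
		// Passing null for the input schema mirrors the stream-based call above.
		mlContext.Model.Save(model, null, modelPath);
	}
}
```

Calling it would look something like ModelFiles.SaveModelToPath(mlContext, model, ModelFiles.TempModelPath("BillsModel.mdl")).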

Loading a Model from Disk

Here’s where things get annoying again. If you have any kind of custom mapping, you need to do a bit of extra work, as I mentioned in my rant.
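As a reminder of what that extra work looks like: for a custom mapping to survive a save and reload, it generally has to be exposed through a factory class tagged with a contract name. What follows is a rough sketch only. The class names, the property names, and the "QuarterbackMapping" contract name are all hypothetical; the real QBCustomMappings and PointsCustomMappings classes in this series will differ.

```csharp
using System;
using Microsoft.ML;
using Microsoft.ML.Transforms;

// Hypothetical input and output shapes for the mapping.
public class QuarterbackInput
{
	public string Quarterback { get; set; }
}

public class QuarterbackOutput
{
	public float IsJoshAllen { get; set; }
}

// The attribute ties this factory to a contract name (hypothetical here).
[CustomMappingFactoryAttribute("QuarterbackMapping")]
public class QuarterbackMappingFactory : CustomMappingFactory<QuarterbackInput, QuarterbackOutput>
{
	// The actual translation: a simple string-to-number mapping.
	public static void Map(QuarterbackInput input, QuarterbackOutput output)
	{
		output.IsJoshAllen = input.Quarterback == "Josh Allen" ? 1f : 0f;
	}

	// ML.NET calls this at load time to get the mapping delegate back.
	public override Action<QuarterbackInput, QuarterbackOutput> GetMapping() => Map;

	// Builds the estimator for the training pipeline, using the same contract name.
	public static CustomMappingEstimator<QuarterbackInput, QuarterbackOutput> CreateEstimator(MLContext mlContext)
	{
		return mlContext.Transforms.CustomMapping<QuarterbackInput, QuarterbackOutput>(
			Map, contractName: "QuarterbackMapping");
	}
}
```

The contract name is what gets written into the saved model. At load time, ML.NET looks through the registered assemblies for a factory carrying that name, which is why the registration step matters.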

When it comes time to deserialize the model, we need to register the assemblies that contain our custom mappings. I have two custom mappings but they’re both in the same assembly, so strictly speaking I only need to register that assembly through one of them. The code below registers through both anyway: should I migrate these mappings into separate assemblies later, each assembly would need to be registered at least once. I’m not sure yet which is the better practice: register the assembly for every custom mapping regardless of whether you need to, or go back and modify calling code later. What I do know is that it’s a bit annoying when all I really wanted was a simple string or integer translation. Anyhow, here’s the code:

mlContext.ComponentCatalog.RegisterAssembly(typeof(QBCustomMappings).Assembly);
mlContext.ComponentCatalog.RegisterAssembly(typeof(PointsCustomMappings).Assembly);

Once I’ve registered those assemblies, I can reconstitute the model using a method call:

var newModel = bmt.LoadModel(mlContext, modelPath);

That method call wraps the following code:

public ITransformer LoadModel(MLContext mlContext, string modelPath)
{
	ITransformer loadedModel;
	using (var stream = File.OpenRead(modelPath))
	{
		DataViewSchema dvs;
		loadedModel = mlContext.Model.Load(stream, out dvs);
	}

	return loadedModel;
}

Load() reads the stream and returns the model as an ITransformer. It also hands back, through an out parameter, a DataViewSchema describing the model’s input, which I don’t need here. I get back my completed model and can use it as if I had just finished training. In fact, I have a test case in my GitHub repo which compares a freshly-trained model against a saved and reloaded model to ensure function calls work and the outputs are the same:

private string GenerateOutcome(PredictionEngineBase<RawInput, Prediction> pe)
{
	return pe.Predict(new RawInput
	{
		Game = 0,
		Quarterback = "Josh Allen",
		Location = "Home",
		NumberOfPointsScored = 17,
		TopReceiver = "Robert Foster",
		TopRunner = "Josh Allen",
		NumberOfSacks = 0,
		NumberOfDefensiveTurnovers = 0,
		MinutesPossession = 0,
		Outcome = "WHO KNOWS?"
	}).Outcome;
}

[Test()]
public void SaveAndLoadModel()
{
	string modelPath = "C:\\Temp\\BillsModel.mdl";
	bmt.SaveModel(mlContext, model, modelPath);

	// Register the assembly that contains 'QBCustomMappings' with the ComponentCatalog
	// so it can be found when loading the model.
	mlContext.ComponentCatalog.RegisterAssembly(typeof(QBCustomMappings).Assembly);
	mlContext.ComponentCatalog.RegisterAssembly(typeof(PointsCustomMappings).Assembly);

	var newModel = bmt.LoadModel(mlContext, modelPath);

	var newPredictor = mlContext.Model.CreatePredictionEngine<RawInput, Prediction>(newModel);
	var po = GenerateOutcome(predictor);
	var npo = GenerateOutcome(newPredictor);

	Assert.IsNotNull(newModel);
	Assert.AreEqual(po, npo);
}

Results, naturally, align.

Conclusion

In today’s post, we looked at the ability to serialize and deserialize models. We looked at saving the model to disk, although we could save to some other location, as the Save() method on mlContext.Model requires merely a Stream and not a FileStream. Saving and loading models is straightforward as long as you’ve done all of the prerequisite work around custom mappings (or avoided them altogether).
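As a footnote to the point about Stream: here is a minimal sketch of round-tripping a model through a MemoryStream instead of a file, which is the sort of thing you might do before pushing the bytes to a database or blob store. The ModelBytes class and its method names are hypothetical, and I have not added this to the repo.

```csharp
using System.IO;
using Microsoft.ML;

public static class ModelBytes
{
	// Serializes the model into an in-memory buffer and returns the raw bytes.
	public static byte[] ToBytes(MLContext mlContext, ITransformer model)
	{
		using (var stream = new MemoryStream())
		{
			// Same Save() call as the file version, just a different Stream.
			mlContext.Model.Save(model, null, stream);
			return stream.ToArray();
		}
	}

	// Rebuilds a model from bytes produced by ToBytes().
	public static ITransformer FromBytes(MLContext mlContext, byte[] bytes)
	{
		using (var stream = new MemoryStream(bytes))
		{
			// The schema comes back through the out parameter; it is discarded here.
			return mlContext.Model.Load(stream, out DataViewSchema _);
		}
	}
}
```

The same caveat applies as with files: any assemblies containing custom mappings still need to be registered before calling FromBytes().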
