Programming a deep learning model is not easy (I'm not going to lie), but testing one is even harder. That's why most of the TensorFlow and PyTorch code out there does not include unit tests. But when your code is going to live in a production environment, making sure that it actually does what it is intended to do should be a priority. After all, machine learning code is no different from any other software.
Note that this post is the third part of the Deep Learning in Production course, where we discover how to convert a notebook into production-ready code that can be served to millions of users.
In this article, we are going to focus on how to properly test machine learning code, analyze some best practices for writing unit tests, and present a number of example cases where testing is kind of a necessity. We will start with why we need tests in our code, then do a quick catch-up on the basics of testing in Python, and then go over a number of practical, real-life scenarios.
Why we need unit testing
When developing a neural network, most of us don't care about catching all possible exceptions, finding all corner cases, or debugging every single function. We just need to see our model fit. And then we just need to increase its accuracy until it reaches an acceptable point. That's all good, but what happens when the model is deployed on a server and used in an actual public-facing application? Most likely it will crash, because some users may send wrong data or because of some silent bug that messes up our data preprocessing pipeline. We might even discover that our model was in fact corrupted all this time.
This is where unit tests come into play: they prevent all these things before they even occur. Unit tests are tremendously useful because they:
Find software bugs early
Debug our code
Ensure that the code does what it is supposed to do
Simplify the refactoring process
Speed up the integration process
Act as documentation
Don't tell me that you don't want at least some of the above. Sure, testing can take up a lot of your precious time, but it's 100% worth it. You will see why in a bit.
But what exactly is a unit test?
Basics of unit testing
In simple terms, unit testing is just a function calling another function (or a unit) and checking if the values returned match the expected output. Let's see an example using our UNet model to make it clearer.
If you haven't followed the series, you can find the code in our GitHub repo.
In a few words, we took an official TensorFlow Google Colab notebook that performs image segmentation, and we are trying to convert it into highly optimized, production-ready code (check the first two parts here and here).
So we have this simple function that normalizes an image by dividing all the pixels by 255 (it also shifts the mask labels down by one).
```python
def _normalize(self, input_image, input_mask):
    """ Normalise input image
    Args:
        input_image (tf.image): The input image
        input_mask (int): The image mask
    Returns:
        input_image (tf.image): The normalized input image
        input_mask (int): The new image mask
    """
    input_image = tf.cast(input_image, tf.float32) / 255.0
    input_mask -= 1
    return input_image, input_mask
```
To make sure that it does exactly what it is supposed to do, we can write another function that calls "_normalize" and checks its result. It will look something like this.
```python
import numpy as np

def test_normalize(self):
    # Method of the test class; self.unet is created in setUp()
    input_image = np.array([[1., 1.], [1., 1.]])
    input_mask = 1
    expected_image = np.array([[0.00392157, 0.00392157], [0.00392157, 0.00392157]])
    # _normalize returns an (image, mask) tuple, so unpack it first
    result, result_mask = self.unet._normalize(input_image, input_mask)
    # The expected values are rounded, so compare with a tolerance
    self.assertTrue(np.allclose(expected_image, result))
    self.assertEqual(result_mask, 0)
```
The "test_normalize" function creates a fake input image, calls the function with that image as an argument, and then makes sure that the result is equal to the expected image. "assertEqual" (older code often uses its alias "assertEquals", which was deprecated and then removed in Python 3.12) is a special method coming from Python's unittest package (more on that in a sec), and it does exactly what its name suggests: it asserts that the two values are equal. Note that you can also use a plain assert statement like the one below, but the built-in assertion methods have their advantages, such as more informative failure messages.
assert expected_image == result
That’s it. That’s unit testing. Tests can be used on both very small functions and bigger complex functionalities across different modules.
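To try this without the rest of the UNet code, here is a minimal, self-contained sketch. It assumes a hypothetical NumPy function `normalize` standing in for the model's `_normalize` method (the real one uses `tf.cast` and lives on the model class), and it picks inputs whose normalized values are exact, so no rounding tolerance has to be guessed.

```python
import unittest

import numpy as np


def normalize(input_image, input_mask):
    """Hypothetical NumPy stand-in for UNet._normalize."""
    # Scale pixel values from [0, 255] down to [0, 1]
    input_image = np.asarray(input_image, dtype=np.float32) / 255.0
    # Shift the mask labels down by one, like `input_mask -= 1`
    input_mask = input_mask - 1
    return input_image, input_mask


class NormalizeTest(unittest.TestCase):

    def test_normalize_scales_pixels(self):
        image = np.array([[0.0, 255.0], [255.0, 0.0]])
        result_image, _ = normalize(image, 1)
        # 0 / 255 == 0 and 255 / 255 == 1, so the expected values are exact
        expected = np.array([[0.0, 1.0], [1.0, 0.0]])
        np.testing.assert_allclose(result_image, expected)

    def test_normalize_shifts_mask(self):
        _, result_mask = normalize(np.zeros((2, 2)), 1)
        self.assertEqual(result_mask, 0)
```

Saved as, say, `normalize_test.py`, it can be run with `python -m unittest normalize_test`.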
Unit tests in Python
Before we see some more examples, I'd like to do a quick catch-up on how Python supports unit testing.
The main test framework/runner that comes with Python's standard library is unittest. Unittest is pretty straightforward to use, and it has only two requirements: put your tests into a class and use its special assert methods. A simple example can be found below:
```python
import unittest


class UnetTest(unittest.TestCase):

    def test_normalize(self):
        ...  # test body elided


if __name__ == '__main__':
    unittest.main()
```
Some things to notice here:
We have our test class, which includes a "test_normalize" function as a method. In general, test functions are named with "test" as a prefix followed by the name of the function they test. (This is a convention, but it also enables unittest's autodiscovery functionality, which is the ability of the library to automatically detect all unit tests within a project or a module so you don't have to run them one by one.)
To run unit tests, we call the “unittest.main()” function which discovers all tests within the module, runs them and prints their output.
Our UnetTest class inherits from the "unittest.TestCase" class. This class helps us set up test cases with different inputs because it comes with "setUp()" and "tearDown()" methods. In setUp() we can define inputs that can be accessed by all tests, and in tearDown() we can clean them up (see the snippet in the next section). This is helpful because all tests should run independently and generally they can't share information. Well, now they can, at least when it comes to setup code: setUp() runs before every single test and tearDown() after it, so each test still gets a fresh copy of the inputs.
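The sketch below, with hypothetical test names and inputs, shows what this means in practice: setUp() runs again before every test, so a change one test makes to its inputs never leaks into another, and tearDown() removes a temporary directory after each test.

```python
import os
import shutil
import tempfile
import unittest


class SetUpTearDownTest(unittest.TestCase):

    def setUp(self):
        # Runs before *each* test, so every test gets its own fresh inputs
        self.tmp_dir = tempfile.mkdtemp()
        self.batch = [1.0, 2.0, 3.0]

    def tearDown(self):
        # Runs after *each* test (even a failing one) once setUp succeeded
        shutil.rmtree(self.tmp_dir)

    def test_mutating_inputs(self):
        self.batch.append(4.0)
        self.assertEqual(len(self.batch), 4)

    def test_inputs_are_fresh(self):
        # Unaffected by the append in the other test, because setUp ran again
        self.assertEqual(len(self.batch), 3)
        self.assertTrue(os.path.isdir(self.tmp_dir))
```

The order in which the two tests run does not matter here; that independence is exactly what setUp() and tearDown() buy us.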
Two other powerful frameworks are pytest and nose, which follow pretty much the same principles. I suggest playing with them a little before you decide what suits you best. I personally use pytest most of the time because it feels a bit simpler and it supports a few nice-to-have things like fixtures and test parameterization (which I'm not gonna go into in detail here; you can check the official docs for more). But honestly, the difference isn't that big, so you should be fine with either of them.
Tests in TensorFlow: tf.test
But here I'm going to discuss another, less-known one. Since we use TensorFlow to program our model, we can take advantage of "tf.test", which is an extension of unittest that contains assertions tailored to TensorFlow code (yup, I was shocked when I found that out too). In that case, our code morphs into this:
```python
import tensorflow as tf


class UnetTest(tf.test.TestCase):

    def setUp(self):
        super(UnetTest, self).setUp()
        ...  # setup elided

    def tearDown(self):
        pass

    def test_normalize(self):
        ...  # test body elided


if __name__ == '__main__':
    tf.test.main()
```
It works exactly the same way, with the caveat that we need to call "super().setUp()" inside "setUp", which enables "tf.test" to do its magic. Pretty cool, huh?
Another super important topic you should be aware of is mocking and mock objects. Mocking classes and functions is super common when writing Java, for example, but in Python it is very underutilized. Mocking makes it very easy to replace complex logic or heavy dependencies with dummy objects when testing code. By dummy objects, we refer to simple, easy-to-code objects that have the same structure as our real objects but contain fake or useless data. In our case, a dummy object might be a 2D tensor of all ones that mimics an actual image (just like the "input_image" in the first code snippet).
Mocking also helps us control the code's behavior and simulate expensive calls. Let's look at an example using our UNet once again.
Let's assume that we want to make sure that the data preprocessing step is correct and that our code splits the data and creates the training and test datasets as it should (a very common test case). Here is the code we want to test:
```python
def load_data(self):
    """ Loads and Preprocess data """
    self.dataset, self.info = DataLoader().load_data(self.config.data)
    self._preprocess_data()

def _preprocess_data(self):
    """ Splits into training and test and set training parameters"""
    train = self.dataset['train'].map(self._load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    test = self.dataset['test'].map(self._load_image_test)

    self.train_dataset = train.cache().shuffle(self.buffer_size).batch(self.batch_size).repeat()
    self.train_dataset = self.train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    self.test_dataset = test.batch(self.batch_size)

def _load_image_train(self, datapoint):
    """ Loads and preprocess a single training image """
    input_image = tf.image.resize(datapoint['image'], (self.image_size, self.image_size))
    input_mask = tf.image.resize(datapoint['segmentation_mask'], (self.image_size, self.image_size))

    if tf.random.uniform(()) > 0.5:
        input_image = tf.image.flip_left_right(input_image)
        input_mask = tf.image.flip_left_right(input_mask)

    input_image, input_mask = self._normalize(input_image, input_mask)

    return input_image, input_mask

def _load_image_test(self, datapoint):
    """ Loads and preprocess a single test image"""
    input_image = tf.image.resize(datapoint['image'], (self.image_size, self.image_size))
    input_mask = tf.image.resize(datapoint['segmentation_mask'], (self.image_size, self.image_size))

    input_image, input_mask = self._normalize(input_image, input_mask)

    return input_image, input_mask
```
There's no need to dive very deep into the code, but what it actually does is split the data and do some shuffling, resizing, and batching. So we want to test this code. Everything is nice and well, except for that freaking loading function:
self.dataset, self.info = DataLoader().load_data(self.config.data)
Are we supposed to load the entire data every time we run a single unit test? Absolutely not. Therefore, we could mock that function to return a dummy dataset instead of calling the real one. Mocking to the rescue.
We can do that with unittest's mock object library, unittest.mock. It provides a "Mock()" class to create a mock object directly and a "patch()" decorator that replaces an object imported into the module we test (such as a class or a function) with a mock object for the duration of the test. As the difference is not trivial to understand, I'll leave a link to an amazing article at the end for those who want extra details.
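The difference is easier to see in code. This minimal sketch uses a hypothetical `DataLoader` class and a hypothetical function under test, `count_training_examples`. It first builds a fake loader directly with `Mock()`, then uses `patch.object()` (a variant of `patch()` that takes the target object instead of a dotted import path) as a context manager to swap out the real `load_data` method only for the duration of the `with` block.

```python
from unittest.mock import Mock, patch


class DataLoader:
    """Hypothetical stand-in for an expensive data loader."""

    def load_data(self, config):
        raise RuntimeError("this would load the entire dataset")


def count_training_examples(config):
    """Hypothetical code under test that depends on DataLoader."""
    dataset = DataLoader().load_data(config)
    return len(dataset["train"])


# 1) Mock(): create a fake object directly and hand it to the code yourself
fake_loader = Mock()
fake_loader.load_data.return_value = {"train": [0, 1, 2]}
print(len(fake_loader.load_data("cfg")["train"]))  # prints 3
fake_loader.load_data.assert_called_once_with("cfg")

# 2) patch(): temporarily replace the real method wherever the code under
#    test looks it up; the original is restored when the block exits
with patch.object(DataLoader, "load_data", return_value={"train": [0, 1]}) as mock_load:
    n_train = count_training_examples("cfg")

mock_load.assert_called_once_with("cfg")
print(n_train)  # prints 2
```

In a real test module you would more often pass a dotted path string to `patch()`, as the UNet test further down does with 'model.unet.DataLoader.load_data'.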
For those who aren't familiar, a decorator is simply a function that wraps another function to extend its functionality. Once we declare the wrapper function, we can annotate other functions with it to enhance them. See the @patch below? That's a decorator that wraps "test_load_data" with the "patch" function. For more information, follow the link at the end of the post.
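As a minimal sketch of the idea, here is a hypothetical `count_calls` decorator that wraps a function and counts how often it runs. `@patch(...)` works in the same spirit: it wraps the test function, swaps in the mock before the test runs, passes the mock in as an extra argument, and restores the original afterwards.

```python
import functools


def count_calls(func):
    """Hypothetical decorator that counts how many times `func` is called."""
    @functools.wraps(func)  # keep the wrapped function's name and docstring
    def wrapper(*args, **kwargs):
        wrapper.calls += 1
        return func(*args, **kwargs)

    wrapper.calls = 0
    return wrapper


@count_calls  # equivalent to: scale_pixel = count_calls(scale_pixel)
def scale_pixel(pixel):
    """Scale a pixel value from [0, 255] to [0, 1]."""
    return pixel / 255.0


scale_pixel(255)
scale_pixel(0)
print(scale_pixel.__name__, scale_pixel.calls)  # prints "scale_pixel 2"
```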
By using the “patch()” decorator we get this:
```python
from unittest.mock import patch

@patch('model.unet.DataLoader.load_data')
def test_load_data(self, mock_data_loader):
    # dummy_load_data (not shown) returns a small fake dataset
    mock_data_loader.side_effect = dummy_load_data
    shape = tf.TensorShape([None, self.unet.image_size, self.unet.image_size, 3])

    self.unet.load_data()
    mock_data_loader.assert_called()

    # element_spec is an (image, mask) tuple, so check the image spec
    self.assertItemsEqual(self.unet.train_dataset.element_spec[0].shape, shape)
    self.assertItemsEqual(self.unet.test_dataset.element_spec[0].shape, shape)
```
I can tell that you are amazed by this. Don’t try to hide it.
Before we see some specific testing use cases in machine learning, I would like to mention another important aspect: coverage. By coverage, we mean how much of our code is actually exercised by unit tests.
Coverage is an invaluable metric that can help us write better unit tests, discover which areas our tests don’t exercise, find new test cases, and ensure the quality of our tests. You can simply check your coverage like this:
- Install the coverage package
$ conda install coverage
- Run your test file through coverage
$ coverage run -m unittest /home/aisummer/PycharmProjects/Deep-Learning-Production-Course/model/tests/unet_test.py
- Print the results
$ coverage report -m /home/aisummer/PycharmProjects/Deep-Learning-Production-Course/model/tests/unet_test.py
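```
Name                       Stmts   Miss  Cover   Missing
--------------------------------------------------------
model/tests/unet_test.py      35      1    97%   52
```
This says that 97% of the statements in the reported file are covered: there are 35 statements in total and we missed just 1 of them. The Missing column tells us which lines of code still need coverage, here line 52 (how handy!). Keep in mind that this particular report only covers the test file itself; to see how much of the model code is exercised, run "coverage report -m" without a file argument, which lists every measured file.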
Test example cases
I think it's time to explore some of the different deep learning scenarios and parts of the codebase where unit testing can be incredibly useful. Well, I'm not gonna write the code for every single one of them, but I think it's important to outline a few use cases.
We already discussed one of them. Ensuring that our data has the right format is critical. A few others I can think of are:
Ensure that our data has the right format (yes, I put it here again for completeness)
Ensure that the training labels are correct
Test our complex processing steps, such as image manipulation
Assert data completeness, quality, and errors
Test the distribution of the features
Run a training step and compare the weights before and after to ensure that they are updated
Check that our loss function can actually be used on our data
Ensure that our metrics (e.g. accuracy, precision, and recall) are above a threshold when iterating over different architectures
Run speed/benchmark tests on training, and compare training and validation metrics to catch possible overfitting
Of course, cross-validation can take the form of a unit test
Check that the model's layers actually stack together
Check that the model's output has the correct shape
Actually, let's program the last one to prove to you how simple it is:
```python
def test_output_size(self):
    shape = (1, self.unet.image_size, self.unet.image_size, 3)
    image = tf.ones(shape)
    self.unet.build()
    self.assertEqual(self.unet.model.predict(image).shape, shape)
```
That's it. Defining the expected shape, constructing a dummy input, building the model, and running a prediction is all it takes. Not so bad for such a useful test, right? You see, unit tests don't have to be complex. Sometimes a few lines of code can save us from a lot of trouble. Trust me. But we also shouldn't go to the other extreme and test every single thing imaginable. That is a huge time sink. We need to find a balance.
I am confident that you can come up with many many more test scenarios when developing your own models. This was just to give a rough idea of the different areas you can focus on.
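Another item from the list above, comparing the weights before and after a training step, is just as short to test. As a sketch, the example below swaps the UNet for a hypothetical linear-regression model written in NumPy (`train_step` and `mse` are illustrative names, not part of the repo) trained with plain gradient descent, and it checks that one step changes every weight and lowers the loss. With a Keras model, the same idea would mean snapshotting `model.trainable_variables` before and after one training step.

```python
import unittest

import numpy as np


def train_step(weights, x, y, lr=0.1):
    """Hypothetical single gradient-descent step for linear regression (MSE)."""
    preds = x @ weights
    # Gradient of mean((x @ w - y) ** 2) with respect to w
    grad = 2.0 * x.T @ (preds - y) / len(x)
    return weights - lr * grad


def mse(weights, x, y):
    """Mean squared error of the hypothetical linear model."""
    return float(np.mean((x @ weights - y) ** 2))


class TrainingStepTest(unittest.TestCase):

    def setUp(self):
        rng = np.random.default_rng(0)
        self.x = rng.normal(size=(8, 3))
        self.y = self.x @ np.array([1.0, -2.0, 0.5])
        self.weights = np.zeros(3)

    def test_weights_are_updated(self):
        new_weights = train_step(self.weights, self.x, self.y)
        # Every weight should have moved away from its initial value
        self.assertTrue(np.all(new_weights != self.weights))

    def test_loss_decreases(self):
        new_weights = train_step(self.weights, self.x, self.y)
        self.assertLess(mse(new_weights, self.x, self.y),
                        mse(self.weights, self.x, self.y))
```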
Something that I deliberately avoided mentioning is integration and acceptance tests. These kinds of tests are very powerful tools that aim to test how well our system integrates with other systems. If you have an application with many services or client/server interaction, acceptance tests are the go-to tool to make sure that everything works as expected at a higher level.
Later in the course, when we deploy our model on a server, we will absolutely need to write some acceptance tests, as we want to be certain that the model returns what the user/client expects, in the form they expect it. As we iterate over our application while it is live and served to users, we can't have a failure due to some silly bug (remember the reliability principle from the first article?). Acceptance tests help us avoid exactly these kinds of things.
So let’s leave them for now and deal with them when the time comes. To make sure that you will be notified when the next part of this course is out, you can subscribe to our newsletter.
Unit tests are indeed an invaluable tool in our arsenal, especially when building complex deep learning models. I mean, I can think of a million things that can go wrong in machine learning apps. Although writing good tests can be hard as well as time-consuming, it is something that you shouldn't neglect. My laziness has come back to bite me more times than I can count, so I decided to always write them from now on. But again, we always need to find a balance.
However, unit testing is only one of the ways to make our code production-ready. To make sure that our original notebook can be used reliably in a deployment environment, we have to do a couple more things. So far we have talked about system design for deep learning and best practices for writing Python deep learning code. Next on our list is adding logging to our codebase and learning how to debug our TensorFlow code. Can't wait.
See you then...
programiz.com, Python Decorators
wikipedia.org, Code coverage
toptal.com, An Introduction to Mocking in Python
realpython.com, Getting Started With Testing in Python