
Deploy a Deep Learning model as a web application using Flask and Tensorflow

Sergios Karagiannakos · 2020-11-05 · 9 mins
MLOps · Software · Tensorflow

Developing a state-of-the-art deep learning model has no real value if it can’t be applied in a real-world application. Don't get me wrong, research is awesome! But most of the time the ultimate goal is to use the research to solve a real-life problem. In the case of deep learning models, a vast majority of them are actually deployed as a web or mobile application. In the next couple of articles, this is exactly what we're gonna do:

We will take our image segmentation model, expose it via an API (using Flask) and deploy it in a production environment.

If you are new to this article series, here is a quick reminder: we took a simple Unet model from a Colab notebook that performs segmentation on an image, and we converted it into a full-size, highly optimized project. Now we will serve it to real users at scale. For more details, check out the previous article or our GitHub repository.

Our end goal is to have a fully functional service that can be called by clients/users and perform segmentation in real-time. My assumption here is that most of you are not very familiar with building client-server applications. That’s why I will try to explain some fundamental terms to give you a clear understanding of how to achieve that.

Here is our glossary:

  • Web service: any self-contained piece of software that makes itself available over the internet and uses a standard communication protocol such as HTTP.

  • Server: a computer program or device that provides a service to another computer program and its user, also known as the client.

  • Client-server: a programming model in which one program (the client) requests a service or resource from another program (the server).

  • API (Application Programming Interface): a set of definitions and functions that allows applications to access data and interact with external software components, operating systems, or microservices.

So let's pause for a moment and think about what we need to do. We first need some sort of “inferrer” class. In simple terms, an inferrer interacts with our TensorFlow model and computes the segmentation map. Then we need to build a web application to expose that functionality (API), and finally, we need to create a web service that allows clients to communicate with it and send their own images for prediction.

Shall we?

Deep Learning in Production Book 📖

Learn how to build, train, deploy, scale and maintain deep learning models. Understand ML infrastructure and MLOps using hands-on examples.

Learn more

Inferring a segmentation mask of a custom image

We trained the model using a custom training loop and then saved it using TensorFlow's built-in SavedModel functionality.

save_path = os.path.join(self.model_save_path, "unet")
tf.saved_model.save(self.model, save_path)

Our next steps in a nutshell: a) load the saved model, b) feed it with the user's image, and c) infer the segmentation mask.

I'd say that a good way to do that is to build an inferrer class! The class will load the model once on creation (to avoid loading it multiple times) and will have an inference method that returns the result of the model.

Remark: Don't forget that the user's image might not be in our desired format, so we will need to do some preprocessing on it before passing it to the model.

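import tensorflow as tf

class UnetInferrer:
    def __init__(self):
        self.saved_path = 'model_path_location'
        self.image_size = 128  # assumption: must match the input size used during training
        # Load the SavedModel once, when the inferrer is created
        self.model = tf.saved_model.load(self.saved_path)
        self.predict = self.model.signatures["serving_default"]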

Notably, TensorFlow uses a built-in SavedModel format that is optimized for serving the model in a web service. That's why we can't simply load it and call “model.predict()” as we would with a Keras model. The object we get back when loading a saved model contains a set of specific fields. More specifically, it has different graph metadata (called tag constants) and a set of fields that define the different input, output, and method names (signature constants). Moreover, most models, when saved, have a ‘serving_default’ signature key, which maps to a function that takes an input and predicts an output using the computational graph of the model. For that reason, we need to get that function from the model's signatures and use it as our predict function. To perform a prediction on an image, we can do something like this:

prediction = self.predict(input_image)['output_layer_name']

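In my example, the output layer name is “conv2d_transpose_4”. If you don't know the name of your output layer, printing self.predict.structured_outputs shows the available output keys. A good practice is to define custom names for the layers and the variables, but I'm guilty of not doing that in this project.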

To continue, let's define a preprocess method. We only need to resize the image and normalize its pixel values to the [0, 1] range.

def preprocess(self, image):
    # Resize to the model's input size and scale pixel values to [0, 1]
    image = tf.image.resize(image, (self.image_size, self.image_size))
    return tf.cast(image, tf.float32) / 255.0

Finally, we need an “infer” method that takes an image as an argument and returns the segmentation output. Note that the image won't be in a tf.Tensor format! So, we first need to convert it, then apply the preprocessing and then do the actual prediction.

def infer(self, image=None):
    # Convert the incoming array (or nested list) to a float32 tensor
    tensor_image = tf.convert_to_tensor(image, dtype=tf.float32)
    tensor_image = self.preprocess(tensor_image)
    # Add a batch dimension: (H, W, C) -> (1, H, W, C)
    shape = tensor_image.shape
    tensor_image = tf.reshape(tensor_image, [1, shape[0], shape[1], shape[2]])
    return self.predict(tensor_image)['conv2d_transpose_4']

Did you notice an extra step?

In general, images are 3-dimensional tensors (height, width, and the RGB channels), while the model expects a 4-dimensional tensor with a leading batch dimension. To get there, we reshape the image into a batch of size 1, which luckily is a piece of cake with TensorFlow.
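To make the shapes concrete, here is a NumPy-only sketch of the same three steps (resize, normalize, add a batch dimension). It is not the code our inferrer runs: it swaps tf.image.resize (bilinear by default) for a simple nearest-neighbour resize, and the function name preprocess_numpy and the default size of 128 are assumptions for illustration.

```python
import numpy as np

def preprocess_numpy(image: np.ndarray, size: int = 128) -> np.ndarray:
    """Resize (nearest neighbour), scale to [0, 1] and add a batch dimension.

    NumPy-only illustration; the article's inferrer uses tf.image.resize instead.
    """
    height, width = image.shape[:2]
    # Map every output row/column to a source row/column
    rows = (np.arange(size) * height / size).astype(int)
    cols = (np.arange(size) * width / size).astype(int)
    resized = image[rows][:, cols]                   # (size, size, C)
    normalized = resized.astype(np.float32) / 255.0  # values in [0, 1]
    return np.expand_dims(normalized, axis=0)        # (1, size, size, C)

if __name__ == "__main__":
    photo = np.random.default_rng(0).integers(0, 256, size=(300, 200, 3), dtype=np.uint8)
    print(preprocess_numpy(photo).shape)  # (1, 128, 128, 3)
```

np.expand_dims(x, axis=0) and a reshape to (1, *x.shape) give the same result; both only add the batch axis without copying pixel data around.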

Remember unit testing? Feel free to revisit it!

To be absolutely sure that our method is correct, we may want to implement a very simple but incredibly useful unit test.

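import unittest
import numpy as np
from PIL import Image
from executor.unet_inferrer import UnetInferrer

class MyTestCase(unittest.TestCase):
    def test_infer(self):
        # Load the sample image as a float32 NumPy array of shape (H, W, 3)
        image = np.asarray(Image.open('resources/yorkshire_terrier.jpg')).astype(np.float32)
        inferrer = UnetInferrer()
        output = inferrer.infer(image)
        # infer() adds a batch dimension, so the output batch size should be 1
        self.assertEqual(output.shape[0], 1)

if __name__ == '__main__':
    unittest.main()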

Nothing special here. We load a sample image of a cute Yorkshire terrier, convert it to a NumPy array, and use our infer method to make sure that everything works as expected.

We are now ready to create our web server with Flask. As we will see, calling it a web server is not entirely accurate, since Flask is not a web server.

Creating a web application using Flask

A vital question first. What is Flask?

Flask is a web application framework that enables us to build simple applications with minimal boilerplate code and a few out-of-the-box features. Flask is built on top of WSGI (Web Server Gateway Interface), a Python standard (PEP 3333) that describes how a web server communicates with web applications (more on WSGI in the next articles).
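To give a rough idea of what WSGI specifies, here is a minimal sketch using only the standard library's wsgiref module: a WSGI application is just a callable that receives the request environment and a start_response function, and returns an iterable of bytes. The greeting text is a placeholder; Flask builds a much richer version of this callable for us.

```python
from wsgiref.util import setup_testing_defaults

def application(environ, start_response):
    """A minimal WSGI application: the web server calls it once per request."""
    path = environ.get("PATH_INFO", "/")
    body = f"Hello from a WSGI app, you asked for {path}".encode("utf-8")
    start_response("200 OK", [
        ("Content-Type", "text/plain; charset=utf-8"),
        ("Content-Length", str(len(body))),
    ])
    return [body]  # an iterable of bytes

if __name__ == "__main__":
    # Simulate what a WSGI server does for a request to /infer
    environ = {"PATH_INFO": "/infer"}
    setup_testing_defaults(environ)  # fill in the remaining required WSGI keys

    def start_response(status, headers):
        print(status, headers)

    print(b"".join(application(environ, start_response)))
```

To actually serve it over HTTP for a quick experiment, wsgiref.simple_server.make_server("localhost", 8080, application).serve_forever() would do; like Flask's own server, that reference server is not meant for production.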

However, Flask's built-in server is not a fully functional web server and should not be used in production. A better approach is something like the uWSGI server that we will explore in the next article. Flask is, though, a perfect solution for developing a quick app and prototyping what our web server will look like.

Like all other web frameworks, Flask provides some basic features:

  • it helps us define different routes based on a URL so we can expose different functionalities

  • it exposes different endpoints

  • it comes with some nice-to-have extras like integrated support for unit testing, a built-in development server, a debugger, and HTML templating.

Ok, let's pause for a moment. Actually, now that I'm thinking about it, let's get back to the absolute basics and remember how modern web applications work.

Basics of modern web applications

In the client-server paradigm, the client sends a request to the server, the server processes that request by running a piece of software and returns a response back to the client. This simple sentence raises many questions.

How does the client know in what format the server expects the request (data)?

This is taken care of by the HTTP protocol. HTTP defines the communication between a server and a client, how messages are formatted and transmitted, and what action the server and the client need to take in response to various commands and request types.

What does a request look like?

An HTTP request has 4 basic components: a destination URL, a method, some metadata called headers, and optionally a request body.

  • The URL is the remote address (host, port, and path) under which the server's functionality lives. From the server's perspective, this is called a route.

  • The method is the type of the request. HTTP defines several methods; the four most common ones are GET, POST, PUT, and DELETE. Depending on the method, the server will run a different functionality under the hood.

  • Headers are metadata such as the date, the content type, the content length, and other information that the client and the server need in order to take action.

  • The body contains the actual data that we send over the web, such as our image.

Note that an HTTP response has a very similar format: instead of a URL and a method, it starts with a status code (e.g. 200 OK), followed by headers and an optional body.
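To see these components in code, here is a small sketch that builds (but does not send) a POST request with the standard library's urllib.request.Request. The URL and the tiny one-pixel payload are made up for illustration.

```python
import json
import urllib.request

# Hypothetical payload: a 1x1 RGB "image" as nested lists
payload = {"image": [[[255.0, 0.0, 0.0]]]}

request = urllib.request.Request(
    "http://localhost:8080/semantic-segmentation",  # destination URL
    data=json.dumps(payload).encode("utf-8"),       # body (raw bytes)
    headers={"Content-Type": "application/json"},   # headers (metadata)
    method="POST",                                  # method
)

if __name__ == "__main__":
    # Nothing is sent here; we only inspect the request object
    print(request.get_method(), request.host, request.selector)
    print(request.header_items())
    print(request.data)
```

Note that urllib normalizes header names with str.capitalize(), so the header is stored as "Content-type"; HTTP header names are case-insensitive, so servers treat both spellings the same.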

How does the client know where to send the request?

The URL alongside the method defines an endpoint, which is a point of entry on a server.

Route: localhost:8080/semantic-segmentation
Method: POST
Endpoint: POST localhost:8080/semantic-segmentation

Different methods with the same route define different endpoints and therefore have different functionalities. If all that makes sense, you should have understood by now that to communicate with the server all we need to do is send a POST request to the “localhost:8080/semantic-segmentation” URL.
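To make the idea of an endpoint concrete, here is a toy, framework-free sketch of a routing table keyed by (method, route). The route decorator and dispatch function are hypothetical and far simpler than what Flask actually does, but they show why POST and GET on the same route can run different code, and why servers answer 404 (unknown route) or 405 (known route, wrong method).

```python
from typing import Callable, Dict, Tuple

Handler = Callable[[dict], dict]
ROUTES: Dict[Tuple[str, str], Handler] = {}  # (method, route) -> handler

def route(path: str, methods: Tuple[str, ...] = ("GET",)):
    """Register the decorated function for each (method, path) endpoint."""
    def decorator(func: Handler) -> Handler:
        for method in methods:
            ROUTES[(method.upper(), path)] = func
        return func
    return decorator

@route("/semantic-segmentation", methods=("POST",))
def segment(body: dict) -> dict:
    # Placeholder: a real handler would run the segmentation model here
    return {"rows_received": len(body.get("image", []))}

@route("/semantic-segmentation", methods=("GET",))
def usage(body: dict) -> dict:
    return {"usage": "POST a JSON body with an 'image' key"}

def dispatch(method: str, path: str, body: dict) -> Tuple[int, dict]:
    """Find the endpoint for (method, path) and return (status code, response body)."""
    handler = ROUTES.get((method.upper(), path))
    if handler is not None:
        return 200, handler(body)
    if any(known_path == path for _, known_path in ROUTES):
        return 405, {"error": "method not allowed on this route"}
    return 404, {"error": "route not found"}

if __name__ == "__main__":
    print(dispatch("POST", "/semantic-segmentation", {"image": [[0], [0]]}))
    print(dispatch("DELETE", "/semantic-segmentation", {}))
```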

I believe you have a clear idea of the basics by now, so I will proceed with building our Flask application. If you are still unsure about something, I highly recommend digging deeper into how modern web applications work; I will provide a few links at the end.

Exposing the Deep Learning model using Flask

The first thing we need to do to create an app is to import Flask and create a new instance of it.

from flask import Flask, request
app = Flask(__name__)

To start the application, we can use the “run” method like this:

if __name__ == '__main__':
    app.run(host=HOST, port=PORT_NUMBER)

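HOST is the address the server listens on (localhost in our case, or 0.0.0.0 to accept connections on all network interfaces), and PORT_NUMBER is 8080, a port commonly used for development servers.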


Now we want to build our endpoint on a specific route. For example, we can have “0.0.0.0:8080/infer” as our route and use a POST method.

If we don't want to hardcode the route and would rather make it configurable for different environments, we can read it from the APP_ROOT environment variable and fall back to “/infer” when the variable isn't set.

import os
APP_ROOT = os.getenv('APP_ROOT', '/infer')  # falls back to '/infer' if APP_ROOT is unset
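The same idea can be extended to the other settings. Here is a minimal sketch, assuming hypothetical HOST and PORT environment variables alongside APP_ROOT; note that environment variables are always strings, so the port has to be converted to an int before it reaches app.run.

```python
import os
from typing import Mapping

def load_config(env: Mapping[str, str] = os.environ) -> dict:
    """Read server settings from environment variables, falling back to local defaults.

    HOST and PORT are hypothetical variable names chosen for this sketch.
    """
    return {
        "host": env.get("HOST", "0.0.0.0"),
        "port": int(env.get("PORT", "8080")),  # environment values are always strings
        "app_root": env.get("APP_ROOT", "/infer"),
    }

if __name__ == "__main__":
    print(load_config())
```

With this in place, something like app.run(host=config["host"], port=config["port"]) and @app.route(config["app_root"], methods=["POST"]) would pick up whatever the deployment environment sets. Passing a plain dict as env also makes the function easy to test.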

We also want to create an instance of our inferrer class outside of the endpoint, so we don't have to load the model on each request.

u_net = UnetInferrer()
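A module-level instance works well, but it loads the model as soon as the module is imported (for example, in every test that imports the app). A common alternative is to load the model lazily on the first request and cache it with functools.lru_cache. The sketch below uses a hypothetical ExpensiveModel class standing in for UnetInferrer, which just counts how often it is built.

```python
import functools

class ExpensiveModel:
    """Hypothetical stand-in for UnetInferrer that counts how often it is built."""
    loads = 0

    def __init__(self):
        # In the real inferrer this is where tf.saved_model.load(...) would run
        ExpensiveModel.loads += 1

    def infer(self, image):
        return image  # placeholder "prediction"

@functools.lru_cache(maxsize=None)
def get_inferrer() -> ExpensiveModel:
    """Build the model on first use, then return the cached instance."""
    return ExpensiveModel()

def handle_request(image):
    # Every request goes through the cached inferrer instead of reloading the model
    return get_inferrer().infer(image)

if __name__ == "__main__":
    for _ in range(3):
        handle_request([[0.0]])
    print(ExpensiveModel.loads)  # 1
```

Two caveats: lru_cache keeps its cache consistent across threads but does not stop two threads that arrive simultaneously on the very first request from both building the model, and a server running several worker processes will still load one copy per process.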

And our endpoint will look like this:

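from flask import jsonify

@app.route(APP_ROOT, methods=["POST"])
def infer():
    data = request.json
    image = data['image']
    # infer() returns a tf.Tensor, which Flask cannot serialize directly
    prediction = u_net.infer(image)
    return jsonify(prediction.numpy().tolist())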

Let's examine this a bit more carefully. The “app.route” decorator accepts the route and the allowed methods, and lets Flask know that this function should handle that endpoint. For any other endpoint you might want to create, you can follow the exact same pattern.

The request object is provided by Flask and contains the full HTTP request with all the components mentioned before. The request body is in JSON format, so we can easily get the image and feed it into our inferrer, which triggers the TensorFlow prediction. It's that simple. Every time users want the segmentation mask of an image, all they have to do is send a request to that specific endpoint, and they will get back a response.

Another cool feature of Flask is a very intuitive way to handle all the exceptions that might occur while our server is running.

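import traceback
from flask import jsonify
from werkzeug.exceptions import HTTPException
@app.errorhandler(Exception)
def handle_exception(e):
    if isinstance(e, HTTPException):
        return e  # keep Flask's default responses for 404, 405, ...
    return jsonify(stackTrace=traceback.format_exc()), 500  # Internal Server Error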

This handler is triggered whenever an unhandled exception occurs, and it returns the traceback of the failed Python code. Of course, Flask is far more powerful than this, but we can't possibly outline all of its features here, so I urge you to check out its documentation for more details.

Cool, our server is up and running. But are we sure that everything works perfectly? To be 100% certain, we should build a simple client that sends a request and examines the response. An ideal solution would be a simple UI in the browser where we upload an image and plot the segmented result, but that is out of the scope of this article.

Creating a client

Our client will be nothing more than a Python script which sends a request to the endpoint and displays the response. Python has an amazing library called “requests” that makes sending and receiving HTTP requests quite straightforward. As we did in our unit test, we will load an image from a local folder and then we will create a request object and send it to the server.

import numpy as np
import requests
from PIL import Image

ENDPOINT_URL = 'http://0.0.0.0:8080/infer'

def infer():
    image = np.asarray(Image.open('resources/yorkshire_terrier.jpg')).astype(np.float32)
    # NumPy arrays are not JSON serializable, so send a nested Python list instead
    data = {'image': image.tolist()}
    response = requests.post(ENDPOINT_URL, json=data)
    response.raise_for_status()  # raise an exception on 4xx/5xx responses
    print(response.json())

if __name__ == "__main__":
    infer()

After loading the image and converting it to a NumPy array, we create a JSON-serializable dictionary called data, send a POST request to our endpoint URL, and print the response.

Since JSON (the format of our request body) can't represent NumPy arrays or TensorFlow tensors, we have to convert the image into a Python list, which is JSON-compatible. For the same reason, the response will also contain a (nested) list.
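Here is a small sketch of that round trip, using a random 4x4 RGB array instead of the real photo: the client turns the array into nested lists and JSON text, and the receiving side rebuilds a NumPy array from it. Calling json.dumps on the array itself would raise a TypeError, which is exactly why .tolist() is needed.

```python
import json
import numpy as np

# A small random RGB "image" standing in for the real photo
image = np.random.default_rng(0).integers(0, 256, size=(4, 4, 3)).astype(np.float32)

# Client side: NumPy array -> nested Python lists -> JSON text
payload = json.dumps({"image": image.tolist()})

# Server side: JSON text -> nested lists -> NumPy array (or tf.Tensor) again
restored = np.asarray(json.loads(payload)["image"], dtype=np.float32)

if __name__ == "__main__":
    print(type(image.tolist()), restored.shape, np.array_equal(image, restored))
```

Sending every pixel as text is simple but verbose; it is perfectly fine for a prototype like ours.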

“response.raise_for_status()” is a little trick that raises an exception if the server returns an error status code (4xx or 5xx), so we can be sure that the rest of our program won't continue with a faulty response.
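If you'd like to see this error-handling behaviour without TensorFlow or Flask, the sketch below swaps in the standard library. A hypothetical StubHandler plays the server (it only echoes the image's height and width, and answers 400 Bad Request when the image key is missing), and post_json is a urllib-based client. Unlike requests, urllib raises urllib.error.HTTPError on 4xx/5xx responses by itself, which gives us the same safety net as raise_for_status().

```python
import json
import threading
import urllib.error
import urllib.request
from http.server import BaseHTTPRequestHandler, HTTPServer

# Talk to the local server directly, ignoring any HTTP proxy settings
_opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))


class StubHandler(BaseHTTPRequestHandler):
    """Hypothetical stand-in for our Flask app: reports the image size, no model."""

    def do_POST(self):
        length = int(self.headers.get("Content-Length", 0))
        body = json.loads(self.rfile.read(length) or b"{}")
        if "image" not in body:
            self.send_error(400, "missing 'image' key")  # 4xx: the client's fault
            return
        image = body["image"]
        reply = json.dumps({"height": len(image), "width": len(image[0])}).encode("utf-8")
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(reply)))
        self.end_headers()
        self.wfile.write(reply)

    def log_message(self, *args):
        pass  # silence the default per-request logging


def post_json(url: str, data: dict) -> dict:
    request = urllib.request.Request(
        url,
        data=json.dumps(data).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    # Raises urllib.error.HTTPError for 4xx/5xx, much like raise_for_status()
    with _opener.open(request, timeout=5) as response:
        return json.loads(response.read())


if __name__ == "__main__":
    server = HTTPServer(("127.0.0.1", 0), StubHandler)  # port 0: any free port
    threading.Thread(target=server.serve_forever, daemon=True).start()
    url = f"http://127.0.0.1:{server.server_port}/infer"
    print(post_json(url, {"image": [[[0.0, 0.0, 0.0]] * 3] * 2}))
    try:
        post_json(url, {"pixels": []})
    except urllib.error.HTTPError as err:
        print("request failed with status", err.code)
    server.shutdown()
    server.server_close()
```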

Since printing a multi-dimensional array isn't very practical, let's plot the predicted mask instead.

import matplotlib.pyplot as plt
import tensorflow as tf

def display(display_list):
    plt.figure(figsize=(15, 15))
    title = ['Input Image', 'Predicted Mask']
    for i in range(len(display_list)):
        plt.subplot(1, len(display_list), i + 1)
        plt.title(title[i])
        plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
        plt.axis('off')
    plt.show()
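Depending on how the model was built, the server may return raw per-class scores rather than a ready-made mask. Assuming an output of shape (1, height, width, num_classes), as is common for this kind of Unet, a hypothetical to_mask helper can pick the most likely class for each pixel before plotting:

```python
import numpy as np

def to_mask(prediction) -> np.ndarray:
    """Convert (1, H, W, num_classes) scores into an (H, W, 1) mask of class indices."""
    scores = np.asarray(prediction)
    mask = np.argmax(scores, axis=-1)  # most likely class per pixel -> (1, H, W)
    return mask[0][..., np.newaxis]    # drop the batch dim, keep a channel dim for plotting

if __name__ == "__main__":
    fake_prediction = np.zeros((1, 2, 2, 3))
    fake_prediction[0, 0, 0, 2] = 1.0  # top-left pixel favours class 2
    print(to_mask(fake_prediction.tolist())[..., 0])
```

Something like display([image, to_mask(prediction)]), where prediction is the decoded list from the response, would then show the input next to the mask.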

Matplotlib to the rescue. And there we have it. Don't be alarmed by the quality of the predicted mask. Nothing went wrong; I was just too lazy to wait for the model to be fully trained. If you have followed along and are willing to wait for the training to complete, you should be able to produce a much better segmentation mask.

Figure: Unet segmentation result (input image and predicted mask)

But the important thing is that everything works fine and both our server and client do what they are supposed to do. I don't know if you realized it, but we just created our web application MVP. A full deep learning-powered app. How cool is that?

Conclusion

In this article, we built a model inferrer, exposed it through a web application built with Flask, and wrote a client that sends a request to the server to predict the mask of a custom image. Unfortunately, we're not done yet. At the moment, our web server runs only locally, it uses Flask's development server, which is not optimized for production environments, and it can't handle many users at the same time.

In the next article, we're gonna see how to use uWSGI to create a high-performance, production-ready server and how to use a load balancer like Nginx to distribute the traffic equally across multiple processes so we can serve lots of users at the same time. If that sounds interesting, I hope to see you then.

Auf Wiedersehen...

References
