TensorFlow Extended (TFX) in action: build a production-ready deep learning pipeline

Sergios Karagiannakos on 2021-04-22 · 5 mins

In this tutorial, we will explore TensorFlow Extended (TFX). TFX was developed by Google as an end-to-end platform for deploying production ML pipelines. Here we will see how to build one from scratch, exploring the built-in components that cover the entire machine learning lifecycle, from research and development to training and deployment.

But first, let’s start with some basic concepts and terminology to make sure that we are all on the same page.

I highly recommend the ML Pipelines on Google Cloud course by the Google Cloud team or the Advanced Deployment Scenarios with TensorFlow course by DeepLearning.ai if you prefer to build these skills through a complete course.

TFX glossary

Components are the building blocks of a pipeline and are the ones that perform all the work. Components can be used intact or can be overridden with our own code.

The metadata store is the single source of truth for all components. It contains three things:

  • Artifacts and their properties: these can be trained models, data, metrics

  • Execution records of components and pipelines

  • Metadata about the workflow (order of components, inputs, outputs, etc)
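To make the three kinds of records more tangible, here is a toy in-memory sketch. Everything in it (Artifact, Execution, InMemoryMetadataStore, producer_of) is a hypothetical name chosen for illustration and does not mirror the real ML Metadata API; it only shows how artifacts and execution records can be linked so that lineage questions become simple lookups.

```python
from dataclasses import dataclass, field


@dataclass
class Artifact:
    name: str
    type: str  # e.g. 'Examples', 'Model', 'Metrics'
    properties: dict = field(default_factory=dict)


@dataclass
class Execution:
    component: str
    inputs: list   # names of consumed artifacts
    outputs: list  # names of produced artifacts


class InMemoryMetadataStore:
    """Toy stand-in for a metadata store (hypothetical, not the MLMD API)."""

    def __init__(self):
        self.artifacts = {}
        self.executions = []

    def record(self, component, inputs, outputs):
        # Store the produced artifacts and the execution that created them.
        for artifact in outputs:
            self.artifacts[artifact.name] = artifact
        self.executions.append(
            Execution(component, list(inputs), [a.name for a in outputs]))

    def producer_of(self, artifact_name):
        """Return the component whose execution produced the artifact."""
        for execution in self.executions:
            if artifact_name in execution.outputs:
                return execution.component
        return None


store = InMemoryMetadataStore()
store.record('ExampleGen', [], [Artifact('examples', 'Examples')])
store.record('Trainer', ['examples'], [Artifact('model', 'Model', {'steps': 100})])
print(store.producer_of('model'))
```

Because every execution records its inputs and outputs, the order of components and the origin of any artifact can be reconstructed from the store alone.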

A TFX pipeline is a portable implementation of an ML workflow, composed of component instances and input parameters.

Orchestrators are systems that execute TFX pipelines. They are basically platforms to author, schedule, and monitor workflows. They usually represent a pipeline as a Directed Acyclic Graph and make sure that each job (or a worker) is executed at the correct time with the correct input.

Popular orchestrators that work with TFX include Apache Airflow, Apache Beam, and Kubeflow Pipelines.

Based on the different stages of the machine learning lifecycle, TFX provides a set of different components with standard functionality. These components can be overridden. For example, we may want to extend their functionality. They can also be replaced by entirely new ones. In most cases, though, the built-in components will take most of us a long way down the road.

Let’s do a quick walkthrough of all of them, starting with data loading and ending with deployment. Note that we will not dive deep into the code, because it involves many new libraries and packages that most readers will be unfamiliar with.

The whole point is to give you an overview of TFX and its modules and help you understand why we need such end-to-end solutions.

Data Ingestion

The first phase of the ML development process is data loading. The ExampleGen component ingests data into a TFX pipeline by converting different types of data into tf.Example records stored in the TFRecord format, which TFX supports natively. Sample code can be found below:

from tfx.proto import example_gen_pb2
from tfx.components import ImportExampleGen

input_config = example_gen_pb2.Input(splits=[
    example_gen_pb2.Input.Split(name='train', pattern='train/*'),
    example_gen_pb2.Input.Split(name='eval', pattern='test/*')
])
example_gen = ImportExampleGen(
    input_base=data_root, input_config=input_config)

ImportExampleGen is a special type of ExampleGen that receives a data path and a configuration describing how to handle our data. In this case, we split the data into training and evaluation sets, read from the train/ and test/ subdirectories respectively.
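To make the split configuration concrete, here is a minimal pure-Python sketch of how glob-style patterns such as train/* and test/* could map files to named splits. It is only an illustration: assign_splits is a hypothetical helper, the file names are made up, and the real matching logic lives inside TFX.

```python
from fnmatch import fnmatch


def assign_splits(paths, split_patterns):
    """Group relative file paths by the first split pattern they match.

    Hypothetical helper for illustration; not part of TFX.
    """
    splits = {name: [] for name in split_patterns}
    for path in paths:
        for name, pattern in split_patterns.items():
            if fnmatch(path, pattern):
                splits[name].append(path)
                break  # a file belongs to at most one split
    return splits


# Made-up file names, relative to the input base directory.
files = [
    'train/part-0.tfrecord',
    'train/part-1.tfrecord',
    'test/part-0.tfrecord',
]
splits = assign_splits(files, {'train': 'train/*', 'eval': 'test/*'})
print(splits)
```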

Data Validation

The next step is to explore our data, visualize it and validate it for possible inaccuracies and anomalies.

The StatisticsGen component generates a set of useful statistics describing our data distribution. As you can see, it receives the output of ExampleGen:

from tfx.components import StatisticsGen
statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])
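For intuition about what "statistics describing our data distribution" means, the following plain-Python sketch computes a few per-feature numbers over a tiny, made-up dataset. The function compute_statistics and the example records are assumptions for illustration; the real component computes far richer statistics at scale.

```python
from statistics import mean


def compute_statistics(examples):
    """Compute simple per-feature statistics over a list of dict examples."""
    feature_names = sorted({name for ex in examples for name in ex})
    stats = {}
    for name in feature_names:
        # Treat absent keys and None values as missing.
        values = [ex[name] for ex in examples if ex.get(name) is not None]
        entry = {'count': len(values), 'missing': len(examples) - len(values)}
        if values and all(isinstance(v, (int, float)) for v in values):
            entry.update(min=min(values), max=max(values), mean=mean(values))
        stats[name] = entry
    return stats


# Made-up records standing in for ingested examples.
examples = [
    {'age': 30, 'city': 'Athens'},
    {'age': 40, 'city': None},
    {'age': 50},
]
stats = compute_statistics(examples)
print(stats)
```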

TensorFlow Data Validation (TFDV) is a library in the TFX ecosystem that, among other things, can help us visualize the statistics produced by StatisticsGen. It is used internally by StatisticsGen but can also be used as a standalone tool.

import tensorflow_data_validation as tfdv


The same library is also used by SchemaGen, which generates an initial schema for our data. The schema can of course be adjusted based on our domain knowledge, but it is a decent starting point.

from tfx.components import SchemaGen
schema_gen = SchemaGen(
    statistics=statistics_gen.outputs['statistics'], infer_feature_shape=True)

The schema and the statistics can now be used to perform data validation that catches outliers, anomalies, and errors in our dataset.

from tfx.components import ExampleValidator
example_validator = ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema'])
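The idea behind schema inference and validation can be sketched in a few lines of plain Python. This is a deliberate simplification under stated assumptions: infer_schema and find_anomalies are hypothetical helpers, the schema is inferred directly from examples rather than from statistics, and only types and presence are checked.

```python
def infer_schema(examples):
    """Infer each feature's Python type and whether it is always present."""
    schema = {}
    for name in sorted({key for ex in examples for key in ex}):
        values = [ex[name] for ex in examples if name in ex]
        schema[name] = {
            'type': type(values[0]),
            'required': len(values) == len(examples),
        }
    return schema


def find_anomalies(examples, schema):
    """Return (example index, feature, problem) tuples for schema violations."""
    anomalies = []
    for i, ex in enumerate(examples):
        for name, spec in schema.items():
            if name not in ex:
                if spec['required']:
                    anomalies.append((i, name, 'missing'))
            elif not isinstance(ex[name], spec['type']):
                anomalies.append((i, name, 'unexpected type'))
        for name in ex:
            if name not in schema:
                anomalies.append((i, name, 'unknown feature'))
    return anomalies


# Made-up data: the schema comes from clean data, then new data is checked.
train = [{'age': 30, 'label': 1}, {'age': 41, 'label': 0}]
schema = infer_schema(train)
new_data = [
    {'age': 'thirty', 'label': 1},
    {'age': 25},
    {'age': 52, 'label': 0, 'zip': 1},
]
anomalies = find_anomalies(new_data, schema)
print(anomalies)
```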

Feature Engineering

One of the most important steps in any ML pipeline is feature engineering. Basically, we preprocess our data so that it can be passed to our model. TFX provides the Transform component and the tensorflow_transform library to help us with the task. The transform step can be performed like this:

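from tfx.components import Transform
transform = Transform(
    examples=example_gen.outputs['examples'], schema=schema_gen.outputs['schema'],
    module_file=module_file)  # module_file: path to the file defining preprocessing_fn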

But that’s not the entire story.

We need to define our preprocessing functionality somehow. This is where the module_file argument comes in. The most common approach is to keep all of our transformations in a separate file. Essentially, we need to implement a preprocessing_fn function, which is the entry point for TFX.

Here is a sample I borrowed from the official TFX examples:

import tensorflow as tf

def preprocessing_fn(inputs):
    """tf.transform's callback function for preprocessing inputs."""
    outputs = {}
    # Decode each serialized PNG into a uint8 image tensor.
    image_features = tf.map_fn(
        lambda x: tf.io.decode_png(x[0], channels=3),
        inputs[_IMAGE_KEY],
        dtype=tf.uint8)
    image_features = tf.cast(image_features, tf.float32)
    image_features = tf.image.resize(image_features, [224, 224])
    image_features = tf.keras.applications.mobilenet.preprocess_input(
        image_features)
    outputs[_transformed_name(_IMAGE_KEY)] = image_features
    outputs[_transformed_name(_LABEL_KEY)] = inputs[_LABEL_KEY]
    return outputs

As you can see, this is ordinary TensorFlow and Keras code.
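One detail worth spelling out is the final preprocess_input call: for MobileNet it rescales pixel values from [0, 255] to [-1, 1]. A plain-Python sketch of that arithmetic follows; scale_pixel is a hypothetical name, and it operates on a single value rather than a tensor.

```python
def scale_pixel(value):
    """Map a pixel value in [0, 255] to [-1, 1]."""
    return value / 127.5 - 1.0


# Darkest, middle, and brightest pixel values.
scaled = [scale_pixel(v) for v in (0, 127.5, 255)]
print(scaled)
```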

Model training

Training the model is a vital part of the process and, contrary to what many people believe, it is not a one-time operation.

Models need to be retrained constantly to stay relevant and ensure the best possible accuracy in their results.

from tfx.dsl.components.base import executor_spec
from tfx.proto import trainer_pb2
from tfx.components.trainer.executor import GenericExecutor
from tfx.components import Trainer

trainer = Trainer(
    # ... (module file, executor spec, inputs, and train/eval args omitted)
    custom_config={'labels_path': labels_path})

As before, the training logic is in a separate module file. This time we have to implement the run_fn function, which normally defines the model and the training loop. Again borrowed from the official examples and stripped of some non-essential parts, here is an example:

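import tensorflow_transform as tft
from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs):
    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)
    # The _input_fn arguments are omitted here (see the official example).
    train_dataset = _input_fn(...)
    eval_dataset = _input_fn(...)
    model, base_model = _build_keras_model()
    # ... (training and export steps omitted)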

Note that _build_keras_model returns a vanilla tf.keras.Sequential model, while _input_fn returns a batched dataset of training examples and labels.

Check the official git repo for the full code. Also note that, with the right callbacks, we can use TensorBoard to visualize the training progress.

Model validation

Next in line is model validation. Once we train a model, we have to evaluate it and analyze its performance before we push it into production. TensorFlow Model Analysis (TFMA) is a library built for exactly this purpose. Note that the model has already been evaluated once during training.

This step intends to record evaluation metrics for future runs and compare them with previous models.

That way we can be sure that our current model is the best we have at the moment.

I will not go into the details of TFMA but here is some code for future reference:

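import tensorflow_model_analysis as tfma
eval_config = tfma.EvalConfig(
    model_specs=[tfma.ModelSpec(label_key='label_xf', model_type='tf_lite')],
    metrics_specs=[tfma.MetricsSpec(metrics=[tfma.MetricConfig(
        class_name='SparseCategoricalAccuracy',  # assumed metric (elided in the source)
        threshold=tfma.MetricThreshold(
            value_threshold=tfma.GenericValueThreshold(
                lower_bound={'value': 0.55}),
            # Change threshold will be ignored if there is no
            # baseline model resolved from MLMD (first run).
            change_threshold=tfma.GenericChangeThreshold(
                direction=tfma.MetricDirection.HIGHER_IS_BETTER,  # assumed
                absolute={'value': -1e-3})))])])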
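The two thresholds in the configuration can be read as a simple decision rule. The sketch below is a simplified illustration, not TFMA's implementation: passes_thresholds is a hypothetical function, it assumes a metric where higher is better, and it treats both bounds as inclusive.

```python
def passes_thresholds(candidate, baseline=None,
                      lower_bound=0.55, min_change=-1e-3):
    """Decide whether a candidate metric clears both thresholds.

    Simplified illustration (higher is better); not TFMA's implementation.
    """
    # Value threshold: the metric itself must reach the lower bound.
    if candidate < lower_bound:
        return False
    # First run: no baseline model, so the change threshold is skipped.
    if baseline is None:
        return True
    # Change threshold: the candidate may be at most slightly worse
    # than the baseline.
    return (candidate - baseline) >= min_change


print(passes_thresholds(0.60))                 # first run, above the bound
print(passes_thresholds(0.60, baseline=0.62))  # clearly worse than baseline
```

In other words, a new model has to be good enough in absolute terms and must not regress noticeably against the model that is currently considered the best.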

The important part is where we define our Evaluator component as a part of our pipeline:

from tfx.components import Evaluator
evaluator = Evaluator(
    # ... (examples, model, and baseline model inputs omitted)
    eval_config=eval_config)

Push the model

Once model validation succeeds, it’s time to push the model into production. This is the job of the Pusher component, which handles deployment according to our environment.

from tfx.components import Pusher
pusher = Pusher(
    # ... (model, model blessing, and push destination omitted)
)

Build a TFX pipeline

OK, so far we have defined a number of components that contain everything we need. But how do we tie them together? TFX pipelines are defined using the Pipeline class, which receives a list of components among other things.

from tfx.orchestration import metadata
from tfx.orchestration import pipeline

components = [
    example_gen, statistics_gen, schema_gen, example_validator, transform,
    trainer, model_resolver, evaluator, pusher
]
pipeline = pipeline.Pipeline(
    # ... (pipeline name, root, and metadata connection config omitted)
    components=components)

Component instances produce artifacts as outputs and typically depend on artifacts produced by upstream component instances as inputs. The order of execution of the components is determined by a Directed Acyclic Graph (DAG) based on each artifact's dependencies. Here is a typical TFX pipeline:

[Figure: kubeflow-pipelines. Source: Google Cloud Platform docs]
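To illustrate how artifact dependencies fix the execution order, here is a small sketch using Python's standard graphlib module (Python 3.9+). The dependency map loosely follows the components discussed in this article, but the exact edges, especially for the trainer and evaluator, are assumptions, and a real orchestrator derives them from the components' inputs and outputs.

```python
from graphlib import TopologicalSorter

# Each component maps to the upstream components whose artifacts it consumes.
# The edges are assumed for illustration.
dependencies = {
    'example_gen': set(),
    'statistics_gen': {'example_gen'},
    'schema_gen': {'statistics_gen'},
    'example_validator': {'statistics_gen', 'schema_gen'},
    'transform': {'example_gen', 'schema_gen'},
    'trainer': {'transform', 'schema_gen'},
    'evaluator': {'trainer'},
    'pusher': {'trainer', 'evaluator'},
}

# static_order() yields every component after all of its dependencies.
order = list(TopologicalSorter(dependencies).static_order())
print(order)
```

Several valid orders usually exist (for example, example_validator can run at any point after schema_gen), which is exactly what lets an orchestrator run independent components in parallel.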

Run a TFX pipeline

At last, we reach the part where we will run the pipeline. As we already said, pipelines are executed by an orchestrator, which will handle all the job scheduling and the networking. Here I chose Apache Beam using BeamDagRunner but the same principles are true for Kubeflow or Airflow.

from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner

if __name__ == '__main__':
    BeamDagRunner().run(pipeline)

Also, I should mention that similar commands can be executed from the command line using the TFX CLI.

It goes without saying that, in the vast majority of use cases, orchestrators like Apache Beam will run on cloud resources.

This means Beam will spin up cloud instances/workers and stream data through them. How exactly this happens depends on the environment and the pipeline.

Typical runners that execute Apache Beam pipelines include Spark, Flink, and Google Dataflow. On the other hand, frameworks like Kubeflow rely on Kubernetes. So one important job of MLOps engineers is to find the best environment for their needs.


End-to-end machine learning systems have gained a lot of attention over the past years. MLOps is starting to become more and more relevant as many different startups and frameworks have been born. One perfect example is TFX. I have to admit that building such pipelines is not an easy task and it requires a deep dive into the intricacies of TFX. But I think it’s one of the best tools we have in our arsenal at the moment. So the next time you want to deploy a machine learning model, maybe it’s worth giving it a shot.

As a side note, I once again recommend the ML Pipelines on Google Cloud course by the Google Cloud team or the Advanced Deployment Scenarios with TensorFlow course by DeepLearning.ai.


* Disclosure: Please note that some of the links above might be affiliate links, and at no additional cost to you, we will earn a commission if you decide to make a purchase after clicking through.