Building Prediction APIs in Python (Part 1): Series Introduction & Basic Example
Ok, so you’ve trained a model, but now what? All that work is meaningless if no one can use it. In some applications, you can simply create a batch job that will collect the records that need to be scored, prepare them, and then run the model. However, other applications require, or at least highly benefit from, a process for scoring models in real time. Ideally, we’d even like to share the ability to use the model with others through a REST API. While deploying a simple model is fairly easy, there are challenges in extending and iterating on this starting point.
In this series, we’ll take an incremental approach to building a prediction API. For this post, we’re going to build the simplest possible API that will allow us to score our model and return a prediction. Each subsequent post will then focus on how we can improve on what we’ve already implemented. There won’t be a set end to this series. As new topics come up, I’ll try to expand on what’s already been done — depending on how much time I have and the interest level of readers.
Intended Audience
Throughout this series, we’ll primarily focus on the software engineering aspects of building a model scoring platform and discuss the gray area between engineering and data science. I’m going to assume for the most part that readers have a model that they want to deploy or that they are capable of producing such a model. As such, we generally won’t cover the process of building predictive models. There are many great tutorials and books that cover this topic.
There are two main audiences for this series. The first is full-stack data scientists who want to deploy their model through an API, but who aren’t sure how to do it effectively. The second is software engineers or technical product managers who haven’t worked in the data science space and may benefit from understanding how model deployment differs from general software deployment.
Tools
We’re going to use Python 3 as the primary language. In particular, I will mostly be using Python 3.6. Flask will be used for the initial APIs. All models will be built using Scikit-Learn. If you’re not sure how to set up a development environment with these packages, I recommend starting with Anaconda, as it provides everything you need already set up.
While this tool set is fairly narrow in scope, my hope is that the topics are broadly applicable to other programming languages and API frameworks.
Basic Example
Now that we’ve set the groundwork for this series, let’s build a basic model. We’ll start with a random forest classifier built on the iris dataset.
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
# Older scikit-learn releases bundled joblib as sklearn.externals.joblib;
# that import path was removed in scikit-learn 0.23.
import joblib

# Grab the dataset from scikit-learn
X, y = datasets.load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3,
                                                    random_state=42)

# Build and train the model
model = RandomForestClassifier(random_state=101)
model.fit(X_train, y_train)
print("Score on the training set is: {:.2}"
      .format(model.score(X_train, y_train)))
print("Score on the test set is: {:.2}"
      .format(model.score(X_test, y_test)))

# Save the model
model_filename = 'iris-rf-v1.0.pkl'
print("Saving model to {}...".format(model_filename))
joblib.dump(model, model_filename)
Again, I’m going to skip discussing the model build process, so we’re not going to walk through this example in depth. The main thing to note is that once we’ve built the model, we use joblib to save (pickle) it to a file. While this works, there are some concerns and limitations with this approach. The scikit-learn documentation has a great article on this [http://scikit-learn.org/stable/modules/model_persistence.html], but the main issues are:
- Security: when you load a pickled object, you're essentially executing code stored in the file. Only load persisted objects that you trust.
- Portability: any shifts in versions of scikit-learn (or potentially dependencies) between the build environment and the production scoring environment will cause warnings at a minimum and potentially unexpected behavior, such as failures or prediction errors.
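One way to make the portability concern visible early is to record the library version next to the saved model and check it at load time. The sketch below is only an illustration under my own assumptions: the helper names save_model and load_model and the .meta.json sidecar file are hypothetical, and it uses the standard library's pickle rather than joblib so it runs without any extra packages.

```python
import json
import os
import pickle
import tempfile


def save_model(model, path, library_version):
    """Pickle `model` and record the library version in a JSON sidecar file."""
    with open(path, 'wb') as f:
        pickle.dump(model, f)
    with open(path + '.meta.json', 'w') as f:
        json.dump({'library_version': library_version}, f)


def load_model(path, current_version):
    """Load a pickled model, refusing if it was saved under another version."""
    with open(path + '.meta.json') as f:
        saved_version = json.load(f)['library_version']
    if saved_version != current_version:
        raise RuntimeError('Model saved with version {}, but running {}'
                           .format(saved_version, current_version))
    # Only unpickle files you trust: loading a pickle can execute code.
    with open(path, 'rb') as f:
        return pickle.load(f)


# Demo with a stand-in object and a made-up version string; with a real
# model you would pass the fitted estimator and sklearn.__version__.
demo_path = os.path.join(tempfile.mkdtemp(), 'demo-model.pkl')
save_model({'stand_in': 'model'}, demo_path, '0.19.1')
restored = load_model(demo_path, '0.19.1')
```

Failing loudly on a version mismatch is a deliberately strict choice for this sketch; logging a warning instead would be a reasonable alternative.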
We’ll set these concerns aside for now, but these may be topics that we explore in subsequent posts.
Building Our First API
Now, we’re going to build the simplest possible API using Flask. We’ll include a single endpoint /predict that will allow us to pass the feature values through query parameters.
from flask import Flask, request, jsonify
# Older scikit-learn releases bundled joblib as sklearn.externals.joblib;
# that import path was removed in scikit-learn 0.23.
import joblib

app = Flask(__name__)

# Load the model
MODEL = joblib.load('iris-rf-v1.0.pkl')
MODEL_LABELS = ['setosa', 'versicolor', 'virginica']


@app.route('/predict')
def predict():
    # Retrieve query parameters related to this request.
    sepal_length = request.args.get('sepal_length')
    sepal_width = request.args.get('sepal_width')
    petal_length = request.args.get('petal_length')
    petal_width = request.args.get('petal_width')

    # Our model expects a list of records
    features = [[sepal_length, sepal_width, petal_length, petal_width]]

    # Use the model to predict the class
    label_index = MODEL.predict(features)
    # Retrieve the iris name that is associated with the predicted class
    label = MODEL_LABELS[label_index[0]]
    # Create and send a response to the API caller
    return jsonify(status='complete', label=label)


if __name__ == '__main__':
    app.run(debug=True)
We’re loading our model globally as MODEL. We also have a list of labels (MODEL_LABELS)¹ that correspond to the integer values that MODEL.predict() will output. We created our API endpoint using the @app.route('/predict') decorator and defined the predict() function to handle requests sent to this endpoint. Calls to this endpoint are expected to contain 4 parameters which correspond to our features²: sepal_length, sepal_width, petal_length, and petal_width. For example, a call would look something like this:
http://127.0.0.1:5000/predict?sepal_length=5&sepal_width=3.1&petal_length=2.5&petal_width=1.2
Flask handles parsing the URL query string and adds the parameters to request.args, which has an API similar to a Python dict. Using get will retrieve a value if the key exists, and will return a default value otherwise. Since we didn't specify a default, None will be returned³. Also, by default, the values in args will be strings. In the environment used for this post, the model converts them to floats during scoring, but that behavior depends on the scikit-learn version, and newer releases reject string input, so converting the values explicitly is the safer habit.
The feature values are then packed into a nested list (list of lists). This is necessary because our model expects a list of records where each record is the length of the feature set. We’re only scoring a single record at a time which is why the outer list contains only one inner list (record).
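To make the string values and the nested-list shape concrete, here is a minimal sketch that parses the example URL with the standard library's urllib.parse instead of Flask, so it runs without a server. The helper name parse_features and the FEATURE_NAMES list are my own, for illustration only.

```python
from urllib.parse import parse_qs, urlsplit

# Feature order must match the order the model was trained on.
FEATURE_NAMES = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']


def parse_features(url):
    """Return a list containing one record of floats parsed from `url`."""
    # parse_qs maps each parameter name to a list of *string* values.
    query = parse_qs(urlsplit(url).query)
    record = []
    for name in FEATURE_NAMES:
        values = query.get(name)
        # Mirror request.args.get: a missing parameter becomes None.
        record.append(float(values[0]) if values else None)
    # The outer list holds records; we score a single record at a time.
    return [record]


features = parse_features(
    'http://127.0.0.1:5000/predict'
    '?sepal_length=5&sepal_width=3.1&petal_length=2.5&petal_width=1.2')
print(features)  # [[5.0, 3.1, 2.5, 1.2]]
```

Inside the Flask handler itself, request.args.get also accepts a type argument (for example, type=float) that performs the same conversion.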
We use MODEL.predict() to find the predicted class, which will be 0, 1, or 2. Finally, we can get the label by using the class as an index into MODEL_LABELS.
Running Our Prediction Service
Now, we can start our Flask server from the command line.
$ python model_scoring_basic.py
* Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)
* Restarting with stat
* Debugger is active!
* Debugger PIN: 428-676-378
From our browser, we can make a test request using the example above.
We could also use requests to do the same thing.
>>> import requests
>>> response = requests.get('http://127.0.0.1:5000/predict?sepal_length=5&sepal_width=3.1&petal_length=2.5&petal_width=1.2')
>>> print(response.text)
{
  "label": "versicolor",
  "status": "complete"
}
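Typing the query string by hand gets error-prone as the number of features grows. As a small sketch, the standard library's urlencode can build the same URL from a dict; the helper name build_predict_url is hypothetical and BASE_URL assumes the local development server shown above.

```python
from urllib.parse import urlencode

# Assumes the Flask development server from this post is running locally.
BASE_URL = 'http://127.0.0.1:5000/predict'


def build_predict_url(sepal_length, sepal_width, petal_length, petal_width):
    """Build the /predict URL with properly encoded query parameters."""
    params = {
        'sepal_length': sepal_length,
        'sepal_width': sepal_width,
        'petal_length': petal_length,
        'petal_width': petal_width,
    }
    return BASE_URL + '?' + urlencode(params)


url = build_predict_url(5, 3.1, 2.5, 1.2)
print(url)
```

With requests, you can skip building the URL entirely by passing the same dict through the params argument of requests.get.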
What’s Ahead
In 20–30 lines of code, we’ve built a simple model and created an API that will accept requests and return a prediction. However, there are several areas where we can and should improve on what we’ve done here. Exploration of these topics will be handled in later posts, but here’s a quick preview:
- Error Handling: Currently, we’re not doing any error handling when we score the model. What happens when we receive bad data or data is missing (no sepal_length is provided)? What kind of feedback will the caller receive if there's a failure?
- Automated Tests: Do we have any bugs in our code? Rather than manually running a few requests, is it possible to automate this process?
- Extensibility: This API is only handling a single, global model. Should we create another one-off API for our next model? What if we create a new version of our current model? Should there be a separate API for each version?
- Data Collection: No data are collected from this API. How are we going to track how well our model is performing? Do we know when errors are being made? How long does it take to handle a request? If we start collecting data, where should we store it?
- Deployment: This is running locally, but we need to put it on a server. How do we scale? Should we go serverless? How do we handle load testing?
- Feature Engineering: This API expects the requester to have already prepared any features, but often we need to apply some transformations to the data before they’re ready to be consumed by the model. How and where should we do this?
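As a small taste of the error handling item, here is one possible sketch of validating the raw parameters before they reach the model. It is not how later posts necessarily do it; the function name validate_features and the error message wording are assumptions made for illustration.

```python
FEATURE_NAMES = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']


def validate_features(params):
    """Return (record, errors) for a dict of raw query parameters.

    `record` is a list of floats when every feature is valid, else None.
    `errors` is a list of human-readable messages for the caller.
    """
    record, errors = [], []
    for name in FEATURE_NAMES:
        raw = params.get(name)
        if raw is None:
            errors.append('{} is missing'.format(name))
            continue
        try:
            record.append(float(raw))
        except ValueError:
            errors.append('{} is not a number: {!r}'.format(name, raw))
    return (record if not errors else None), errors


record, errors = validate_features(
    {'sepal_width': '3.1', 'petal_length': '2.5', 'petal_width': 'abc'})
print(errors)
```

A handler could return these messages with a 400 status instead of letting the model fail on None or on a non-numeric string.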
REFERENCES