Bootstrap your own Handler: How and why to create custom handlers for PyTorch’s TorchServe

TorchServe is a great tool to deploy trained PyTorch models; there is no denying that. But, as with any relatively new project, it is still building a community around it to help with the more niche aspects of its implementation. As part of that community, we would like to contribute.

So today, we will be discussing how to develop advanced custom handlers with PyTorch’s TorchServe. We will also review the process of saving your PyTorch model with torch-model-archiver and how to include all the new artifacts created along the way.

We embarked on this journey specifically because, as great as the included inference handlers can be, you will need to tailor them at some point to fit your needs.

In our case, the needs were twofold. We wanted a super minimal API in terms of data processing, and we also needed it to be very lean in terms of external dependencies. This all meant moving logic into TorchServe so that it returns more ready-to-use results.

TorchServe also handles GPU support out of the box. So, if your deployment is GPU-enabled, moving pre- and post-processing into the handler lets those steps take advantage of the GPU as well, which can speed up your whole pipeline, not just inference.

But enough with the whys; let’s go over an example to get the ball rolling.

Code example

This blog post is structured top-down. We find it is the easiest way to understand most systems while still getting to know all the minutiae.

In our example, we have a U2Net model that we use for a background removal task. Say that our model is currently served with the default ImageSegmenter handler. But we want a lean API, so rather than having it return the predicted mask, we need the handler to do the actual background removal. Moreover, suppose we tweaked the pre- and post-processing steps. The only way to deploy all this is with a custom handler.

Here is our custom handler’s structure:

u2net
├── handler.py # Contains our custom handler, extending the `BaseHandler` and overriding most of its functions.
├── model_requirements.txt # The list of the external dependencies we want. Formatted in the classic requirements.txt style.
└── preprocessing.py # Houses the transformations used in the training process, which will be applied before inference.
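The post doesn't show preprocessing.py, so here is a hypothetical sketch of what Rescale and Normalize could look like. It assumes they loosely follow the reference U2Net training code: resize to a 320x320 square, scale by the image maximum, then standardize with ImageNet mean and std. Both classes act on a single PIL image, so they can be chained with torchvision's Compose as in handler.py.

```python
# preprocessing.py: hypothetical sketch; the real file is not shown in the post.
import numpy as np
from PIL import Image

# ImageNet statistics (assumed to match what the model was trained with).
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


class Rescale:
    """Resize a PIL image to an `output_size` x `output_size` RGB square."""

    def __init__(self, output_size: int):
        self.output_size = output_size

    def __call__(self, image: Image.Image) -> Image.Image:
        return image.convert("RGB").resize(
            (self.output_size, self.output_size), Image.BILINEAR
        )


class Normalize:
    """Turn a PIL image into a normalized float32 array in CHW layout."""

    def __call__(self, image: Image.Image) -> np.ndarray:
        arr = np.asarray(image, dtype=np.float32)
        arr = arr / max(arr.max(), 1e-6)  # scale to [0, 1] by the image's max value
        arr = (arr - MEAN) / STD          # per-channel standardization (HWC broadcast)
        return arr.transpose(2, 0, 1).astype(np.float32)  # HWC -> CHW
```

Returning a NumPy array keeps this file free of a torch dependency; the handler can turn each result into a tensor with torch.as_tensor.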

And inside handler.py you will find:

import base64
import io
import time

import numpy as np
import torch
from PIL import Image
from torchvision.transforms import Compose
from ts.torch_handler.base_handler import BaseHandler

# torch-model-archiver places --extra-files next to handler.py, so this import works.
from preprocessing import Normalize, Rescale


class U2Net(BaseHandler):

    image_processing = Compose(
        [
            Rescale(320),
            Normalize(),
        ]
    )

    def _norm_pred(self, d):
        # Min-max normalize the predicted map to [0, 1].
        ma = torch.max(d)
        mi = torch.min(d)
        dn = (d - mi) / (ma - mi)
        return dn

    def basic_cutout(self, img, mask):
        # Resize the predicted mask back to the original image size.
        u2net_mask = Image.fromarray(mask).resize(img.size, Image.LANCZOS)

        result = img.convert("RGBA")
        # putalpha works in place and returns None, so return the image itself.
        result.putalpha(u2net_mask.convert("L"))
        return result

    def preprocess(self, images):
        # Apply the training-time transforms to each image and stack them into a batch.
        # Assumes Rescale/Normalize return a CHW array or tensor for each image.
        tensors = [torch.as_tensor(self.image_processing(image)) for image in images]
        return torch.stack(tensors).float().to(self.device)

    def postprocess(self, images, output):
        # U2Net returns a tuple of side outputs; the first one is the fused mask.
        preds = output[0][:, 0, :, :]
        results = []
        for image, pred in zip(images, preds):
            predict_np = self._norm_pred(pred).cpu().detach().numpy()
            mask = (predict_np * 255).astype(np.uint8)
            # Return raw RGBA bytes instead of .tolist() (see the optimization below).
            results.append(self.basic_cutout(image, mask).tobytes())

        # TorchServe expects a list with one result per request in the batch.
        return results

    def load_images(self, data):
        images = []

        for row in data:
            # Compat layer: normally the envelope should just return the data
            # directly, but older versions of Torchserve didn't have envelope.
            image = row.get("data") or row.get("body")
            if isinstance(image, str):
                # The image was sent as a base64-encoded string.
                image = base64.b64decode(image)

            # the image is sent as bytesarray
            image = Image.open(io.BytesIO(image))
            images.append(image)

        return images

    def handle(self, data, context):
        """Entry point for handler. Usually takes the data from the input request and
           returns the predicted outcome for the input.
           We change that by adding a new step to the postprocess function to already
           return the cutout.

        Args:
            data (list): The input data that needs to be made a prediction request on.
            context (Context): It is a JSON Object containing information pertaining to
                               the model artefacts parameters.

        Returns:
            list : Returns the data input with the cutout applied.
        """
        start_time = time.time()

        self.context = context
        metrics = self.context.metrics

        images = self.load_images(data)
        data_preprocess = self.preprocess(images)

        if not self._is_explain():
            output = self.inference(data_preprocess)
            output = self.postprocess(images, output)
        else:
            output = self.explain_handle(data_preprocess, data)

        stop_time = time.time()
        metrics.add_time(
            "HandlerTime", round((stop_time - start_time) * 1000, 2), None, "ms"
        )
        return output

We will get into the specifics of handler.py's structure in the next section. Let's focus on the U2Net class for now. There are a few nifty tricks in here.

This class is the one we tell TorchServe to use when handling a U2Net request. It mainly controls two things: model initialization and model inference.

Model initialization is already taken care of by BaseHandler, so we are covered on that front. If you have other special requirements on model startup, you can customize initialization by overriding the initialize method. Take a look at BaseHandler’s implementation for more details.

When TorchServe receives a request for model inference, it will call the handle function. And here is where the actual refactoring begins.

The handle function receives the request's data and context and decides what to do with them.

In our case, we need the handler to:

  1. Load the images into PIL format.
  2. Preprocess them with our custom transformations.
  3. Get the mask prediction from the model (the forward pass is also taken care of by BaseHandler).
  4. Finally, apply the mask to the original image and return the result.
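Steps 3 and 4 don't depend on TorchServe, so they can be tried in isolation. The sketch below uses NumPy and PIL only: a fake prediction map stands in for the model output, gets min-max normalized, and becomes the image's alpha channel. `norm_pred` and `cutout` are hypothetical helper names. One detail worth checking in your own code is that PIL's `putalpha` modifies the image in place and returns `None`, so the image itself must be returned.

```python
import numpy as np
from PIL import Image


def norm_pred(pred: np.ndarray) -> np.ndarray:
    """Min-max normalize a raw prediction map to [0, 1]."""
    lo, hi = pred.min(), pred.max()
    return (pred - lo) / max(hi - lo, 1e-8)  # guard against a constant map


def cutout(image: Image.Image, pred: np.ndarray) -> Image.Image:
    """Use a (H', W') prediction map as the alpha channel of `image`."""
    mask = (norm_pred(pred) * 255).astype(np.uint8)
    # Resize the low-resolution mask back to the original image size.
    alpha = Image.fromarray(mask).resize(image.size, Image.LANCZOS)
    result = image.convert("RGBA")  # convert() already returns a new image
    result.putalpha(alpha)          # in place; putalpha returns None
    return result


if __name__ == "__main__":
    original = Image.new("RGB", (200, 100), (10, 20, 30))
    fake_pred = np.random.rand(320, 320).astype(np.float32)  # stand-in for model output
    out = cutout(original, fake_pred)
    print(out.mode, out.size)  # RGBA (200, 100)
```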

Sounds straightforward enough. But there are a few caveats about how data gets in and out of the handler that we need to address, plus an optimization that may be useful when working with large outputs.

Firstly, the received images (the data argument) arrive as a list with one dictionary per request, and each image sits in its dictionary under the “data” key (or “body”, depending on how it was sent). The image itself may be raw bytes or a base64-encoded string. So be sure to unpack that structure and turn your image into something useful before moving forward.
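To make that structure concrete, here is a stdlib-only sketch of what the data argument can look like and how the compat logic in load_images unpacks it. The payload values are made up and `extract_image_bytes` is a hypothetical helper; the "data"/"body" keys are the ones load_images checks.

```python
import base64


def extract_image_bytes(data: list) -> list:
    """Return the raw image bytes for every request in a TorchServe batch."""
    images = []
    for row in data:
        # Each request in the batch is a dict; the payload sits under "data" or "body".
        payload = row.get("data") or row.get("body")
        if isinstance(payload, str):
            # Some clients send base64-encoded text instead of raw bytes.
            payload = base64.b64decode(payload)
        images.append(bytes(payload))
    return images


# A hypothetical batch of two requests: one raw upload, one base64 string.
batch = [
    {"body": b"\x89PNG...raw bytes..."},
    {"data": base64.b64encode(b"\xff\xd8...jpeg bytes...").decode("ascii")},
]
print([len(b) for b in extract_image_bytes(batch)])
```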

Then, once you hand your results back to TorchServe to send the response, it runs two checks: the results must be inside a list, and there must be as many results as there were requests in the batch. Make sure you account for both. We had to figure this one out by digging through TorchServe’s model service module, since we didn’t want to do batch inference for this model and, because of the following optimization, we weren’t returning a list either.
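Those two checks can be mirrored in a few lines, which is handy in a unit test before packaging the handler. This is a simplified re-creation for illustration only, not TorchServe’s actual code, and `check_handler_output` is a hypothetical name.

```python
def check_handler_output(batch: list, output) -> None:
    """Raise if `output` would fail the two checks described above."""
    if not isinstance(output, list):
        raise TypeError(f"handler must return a list, got {type(output).__name__}")
    if len(output) != len(batch):
        raise ValueError(
            f"handler returned {len(output)} results for a batch of {len(batch)}"
        )


batch = [{"body": b"image-1"}]
check_handler_output(batch, [b"cutout-bytes"])  # passes: a list with one item

try:
    check_handler_output(batch, b"cutout-bytes")  # bare bytes, not wrapped in a list
except TypeError as err:
    print(err)  # handler must return a list, got bytes
```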

Lastly, the optimization. The default handlers’ post-processing steps call .tolist() before sending the result back, which is an easy way to always return a list. But our result was an image, and converting a large array into nested Python lists is slow, because every element becomes a separate Python object. The .tolist() call was taking a while for us, so instead we converted the result into bytes (.tobytes()) and wrapped that in a list.

This simple optimization almost doubled this endpoint’s throughput. The reason is that .tolist() has to create a new Python object for every element and rebuild the nested structure, while .tobytes() just copies the raw memory buffer the data already occupies into a single bytes object.
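The difference is easy to observe with NumPy alone. The sketch below uses a random array as a stand-in for an RGBA cutout; it measures only the conversion itself, not any downstream serialization, and timings will vary by machine, so no numbers are claimed here. It also shows that the bytes round-trip losslessly as long as the receiver knows the dtype and shape.

```python
import timeit

import numpy as np

# A hypothetical RGBA image-sized array, standing in for a cutout result.
arr = np.random.randint(0, 256, size=(512, 512, 4), dtype=np.uint8)

as_list = arr.tolist()    # nested lists with one Python int per element
as_bytes = arr.tobytes()  # one contiguous copy of the underlying buffer

# The bytes round-trip losslessly as long as dtype and shape are known.
restored = np.frombuffer(as_bytes, dtype=np.uint8).reshape(arr.shape)
assert np.array_equal(restored, arr)

t_list = timeit.timeit(arr.tolist, number=3)
t_bytes = timeit.timeit(arr.tobytes, number=3)
print(f"tolist: {t_list:.3f}s  tobytes: {t_bytes:.3f}s")
```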

In our example, we send an image and receive another one of the same size with its background removed. The response is just the raw RGBA pixel data, with no size or format information attached, so all we need to do to load the final image is keep the original image’s size and call:

Image.frombytes("RGBA", self.image_size, res.content)
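A fuller client-side sketch follows. It assumes the model is registered as u2net, TorchServe runs locally on its default inference port 8080, and the requests package is installed; `remove_background` and `bytes_to_image` are hypothetical helpers. The size check guards against silently rebuilding an image from a mismatched payload.

```python
import io

from PIL import Image


def bytes_to_image(payload: bytes, size: tuple) -> Image.Image:
    """Rebuild an RGBA image from raw pixel bytes; `size` is (width, height)."""
    expected = size[0] * size[1] * 4  # 4 bytes per RGBA pixel
    if len(payload) != expected:
        raise ValueError(f"expected {expected} bytes for {size}, got {len(payload)}")
    return Image.frombytes("RGBA", size, payload)


def remove_background(path: str, url: str = "http://localhost:8080/predictions/u2net"):
    import requests  # imported lazily so bytes_to_image works without it

    with open(path, "rb") as f:
        raw = f.read()
    image_size = Image.open(io.BytesIO(raw)).size  # remember the original size
    res = requests.post(url, data=raw)
    res.raise_for_status()
    return bytes_to_image(res.content, image_size)
```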

As a side note, if you use custom modules like our preprocessing one, you can import them as if they were in the same root directory as the handler, because that is where torch-model-archiver places them (more on that in a second). Like so:

from preprocessing import Normalize, Rescale

And that is it! Create a class that inherits from BaseHandler (or VisionHandler, for that matter), and write your own handle function to fit your use case, keeping in mind the specifics of how TorchServe interacts with it.

In the following section, we will work on getting all this new and improved code into a .mar file that can get registered through the Management API.

Choosing your entry point

Okay, you made your custom handler class, now what? As explained in the official docs, you need to tell TorchServe where your custom entry point is. The entry point takes care of model initialization (on startup and scale-up) and inference. In our example, it is the U2Net class.

But your custom handler can have either a module-level entry point or a class-level one.

We use a class-level entry point to keep everything nicely structured as complexity grows. If you would rather use a module-level one, check the corresponding section of the official docs for pointers.

When creating your model archive with torch-model-archiver, add the --handler option pointing to your Python file; for us, that is handler.py. This file should define only one class. In our example it is called U2Net, but you can name it whatever you want. If it is not the only class, you will get an "Expected only one class in custom service code or a function entry point" error. If you need more classes, move them to a separate file and import them as explained in the previous section.
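The rule concerns classes defined in the handler module; classes it imports, such as BaseHandler or the transforms from preprocessing, are not counted against it. The sketch below illustrates that kind of check with inspect. It is a simplified illustration, not TorchServe’s loader: `classes_defined_in` and `make_module` are hypothetical helpers, and throwaway modules are built in memory instead of loading handler.py.

```python
import inspect
import types


def classes_defined_in(module: types.ModuleType) -> list:
    """Names of classes whose __module__ is this module, i.e. defined, not imported."""
    return [
        name
        for name, member in inspect.getmembers(module, inspect.isclass)
        if member.__module__ == module.__name__
    ]


def make_module(name: str, source: str) -> types.ModuleType:
    """Build a throwaway module from source code, for illustration only."""
    module = types.ModuleType(name)
    exec(source, module.__dict__)
    return module


good = make_module("handler_ok", "from collections import OrderedDict\nclass U2Net: pass\n")
bad = make_module("handler_bad", "class U2Net: pass\nclass Helper: pass\n")

print(classes_defined_in(good))  # ['U2Net']
print(classes_defined_in(bad))   # ['Helper', 'U2Net']
```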

Remember, TorchServe expects your handler class to define handle(self, data, context) to take care of the complete inference process: preprocessing, inference, post-processing, and capturing any metrics or logs you may need along the way.

If you get any other odd errors, check TorchServe’s model loader module if the error happens on load, or the Service class if it happens during prediction.

Here’s an example torch-model-archiver command applying everything we discussed:

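torch-model-archiver --model-name u2net --version 1.0.0 --serialized-file u2net.pth --handler handler.py --extra-files preprocessing.py --requirements-file model_requirements.txt

Note that BaseHandler can only initialize the model for you if it knows how to load the serialized file. If u2net.pth is an eager-mode state_dict, also pass --model-file pointing at the file that defines the U2Net architecture; without it, BaseHandler expects a serialized TorchScript model.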
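Once the archive exists, it can be registered through the Management API. A minimal stdlib sketch, assuming TorchServe’s default management port 8081 and that u2net.mar already sits in the model store TorchServe was started with; `registration_url` and `register_model` are hypothetical helpers.

```python
import urllib.parse
import urllib.request


def registration_url(mar_file: str, base: str = "http://localhost:8081",
                     initial_workers: int = 1) -> str:
    """Build the POST /models URL for registering a .mar from the model store."""
    query = urllib.parse.urlencode(
        {"url": mar_file, "initial_workers": initial_workers, "synchronous": "true"}
    )
    return f"{base}/models?{query}"


def register_model(mar_file: str) -> str:
    """Send the registration request and return TorchServe's response body."""
    request = urllib.request.Request(registration_url(mar_file), method="POST")
    with urllib.request.urlopen(request) as response:
        return response.read().decode("utf-8")


print(registration_url("u2net.mar"))
# http://localhost:8081/models?url=u2net.mar&initial_workers=1&synchronous=true
```

Splitting URL construction from the request keeps the first part testable without a running server.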

Setting up for success

So, we went over the structure needed for the custom inference handler. All that is left is to package the extra files we use to keep our handler class alone in its file, and to tell TorchServe to install the external dependencies we need.

For the extra files, add the --extra-files option when creating your model archive. It takes a comma-separated list of paths to the files the handler class uses.
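Since --extra-files takes a single comma-separated value, a small build script can keep it tidy as files accumulate. This sketch only assembles the argument list for the torch-model-archiver command shown earlier and runs it with subprocess; `archiver_args` and `build_archive` are hypothetical helpers, and the file names are the ones from our example.

```python
import subprocess


def archiver_args(model_name: str, version: str, weights: str, handler: str,
                  extra_files: list, requirements: str = "") -> list:
    """Assemble a torch-model-archiver command line as an argument list."""
    args = [
        "torch-model-archiver",
        "--model-name", model_name,
        "--version", version,
        "--serialized-file", weights,
        "--handler", handler,
    ]
    if extra_files:
        args += ["--extra-files", ",".join(extra_files)]  # one comma-separated value
    if requirements:
        args += ["--requirements-file", requirements]
    return args


def build_archive(*args, **kwargs) -> None:
    """Run torch-model-archiver; raises CalledProcessError if it fails."""
    subprocess.run(archiver_args(*args, **kwargs), check=True)


cmd = archiver_args("u2net", "1.0.0", "u2net.pth", "handler.py",
                    ["preprocessing.py"], "model_requirements.txt")
print(" ".join(cmd))
```

Calling build_archive with the same arguments would actually produce u2net.mar, provided torch-model-archiver is on the PATH.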

If your model needs specific dependencies installed, TorchServe handles that natively. Define your model’s requirements.txt file and pass its path with the --requirements-file option. Then, add this line to your TorchServe config.properties file to enable per-model dependency installation:

install_py_dep_per_model=true

This lets the Java frontend know there is a requirements list to install when setting up or updating the model. If the installation fails, it reports the failure with a "Custom pip package installation failed for {}" error. You can find the code that does this in TorchServe’s frontend source.

Summary

To wrap up, let’s review what we covered: we looked at how to set up a custom handler class, saw how TorchServe works with it, built the .mar file with everything it needs, and got the TorchServe environment ready to receive these new models.

So, if your models could benefit from a custom pipeline, a lighter API, better tracking inside your TorchServe deployment, or anything in between, give it a go!

We love TorchServe and hope this walkthrough helps anyone who wants to make the most of it, whether as an introduction or as a way to set up their models in an orderly fashion in a fraction of the time.

If you liked this article, check out the other entries on our blog.

See you around!
