Fast operations on scikit-learn decision trees with numba

The title is a bit wordy. But that's what this post is about.

To start with, you might be wondering why someone would want to operate on a decision tree from inside numba in the first place. After all, the scikit-learn implementation of trees uses Cython, which should be providing something close to native C speeds for tree traversals.

In our case, we were already using numba in one bottleneck, and the next set of hot spots in that same bottleneck revolved around interacting with decision trees. Specifically, we end up calling the tree object's apply method

    cdef inline np.ndarray _apply_dense(self, object X):
        """Finds the terminal region (=leaf node) for each sample in X."""

        # Check input
        if not isinstance(X, np.ndarray):
            raise ValueError("X should be in np.ndarray format, got %s"
                             % type(X))

        if X.dtype != DTYPE:
            raise ValueError("X.dtype should be np.float32, got %s" % X.dtype)

        # Extract input
        cdef const DTYPE_t[:, :] X_ndarray = X
        cdef SIZE_t n_samples = X.shape[0]

        # Initialize output
        cdef np.ndarray[SIZE_t] out = np.zeros((n_samples,), dtype=np.intp)
        cdef SIZE_t* out_ptr = <SIZE_t*> out.data

        # Initialize auxiliary data-structure
        cdef Node* node = NULL
        cdef SIZE_t i = 0

        with nogil:
            for i in range(n_samples):
                node = self.nodes
                # While node not a leaf
                while node.left_child != _TREE_LEAF:
                    # ... and node.right_child != _TREE_LEAF:
                    if X_ndarray[i, node.feature] <= node.threshold:
                        node = &self.nodes[node.left_child]
                    else:
                        node = &self.nodes[node.right_child]

                out_ptr[i] = <SIZE_t>(node - self.nodes)  # node offset

        return out

somewhere on the order of 10 million or 100 million times in a row, and then do some postprocessing and groupby aggregations on the result.

Initially, the naïve solution we had ran these calls in serial, and all was well. As we scaled up the amount of data we were handling, the runtime became a bottleneck and we were waiting days for the computation to finish. A second, earlier attempt chunked the input data and ran an outer loop over chunks in parallel, but this caused too much memory pressure (mostly due to lots of transient data copies 😬) and nodes started getting killed with exit code 137 (out of memory).

We wanted to do something a little more clever and parallelize the apply calls and groupby aggregations in the inner loops of the algorithm, but since the interface exposed by scikit-learn is in Python, we would need to re-acquire the GIL and make a standard Python function call for each of these apply calls.

Numba does have support for calling functions that are written in C (or in Cython). And scikit-learn provides pxd files for their cythonized classes all across the repository:

cdef class Tree:
    # The Tree object is a binary tree structure constructed by the
    # TreeBuilder. The tree structure is used for predictions and
    # feature importances.

    # Input/Output layout
    cdef public SIZE_t n_features        # Number of features in X
    cdef SIZE_t* n_classes               # Number of classes in y[:, k]
    cdef public SIZE_t n_outputs         # Number of outputs in y
    cdef public SIZE_t max_n_classes     # max(n_classes)

    # Inner structures: values are stored separately from node structure,
    # since size is determined at runtime.
    cdef public SIZE_t max_depth         # Max depth of the tree
    cdef public SIZE_t node_count        # Counter for node IDs
    cdef public SIZE_t capacity          # Capacity of tree, in terms of nodes
    cdef Node* nodes                     # Array of nodes
    cdef double* value                   # (capacity, n_outputs, max_n_classes) array of values
    cdef SIZE_t value_stride             # = n_outputs * max_n_classes

    # Methods
    cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf,
                          SIZE_t feature, double threshold, double impurity,
                          SIZE_t n_node_samples,
                          double weighted_n_node_samples) nogil except -1
    cdef int _resize(self, SIZE_t capacity) nogil except -1
    cdef int _resize_c(self, SIZE_t capacity=*) nogil except -1

    cdef np.ndarray _get_value_ndarray(self)
    cdef np.ndarray _get_node_ndarray(self)

    cpdef np.ndarray predict(self, object X)

    cpdef np.ndarray apply(self, object X)
    cdef np.ndarray _apply_dense(self, object X)
    cdef np.ndarray _apply_sparse_csr(self, object X)

    cpdef object decision_path(self, object X)
    cdef object _decision_path_dense(self, object X)
    cdef object _decision_path_sparse_csr(self, object X)

    cpdef compute_feature_importances(self, normalize=*)

so in theory this should have been doable, but we weren't able to get it working properly. In addition, we had been toying with the idea of rewriting some parts of the tree traversal logic, and calling the compiled code directly would have ruled that out.

Luckily for us, scikit-learn trees expose their internal structure as a collection of numpy arrays that represent the set of lefthand child indexes, the set of righthand child indexes, the index of the feature used for splitting any node, and the threshold for that feature. So, for example, you may have the following:

lefthand = [1, -1, -1]
righthand = [2, -1, -1]
feature = [3, -2, -2]
threshold = [1.5, -2.0, -2.0]

which tells you that node 0 (the root of the tree, but this logic works for any node) has its lefthand child at index 1 and its righthand child at index 2. It split on the fourth feature in your features table (index 3), and every row with a value less than or equal to 1.5 went into the left child. The -1s in the child arrays indicate a leaf or terminal node; the feature and threshold entries at a leaf are just placeholders.
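To make the array layout concrete, here is a minimal sketch that walks the toy tree above for a single row in plain Python. The helper name `find_leaf` and the two sample rows are made up for illustration. The leaf entries of `feature` and `threshold` use -2, which is the placeholder scikit-learn itself stores there.

```python
import numpy as np

# Toy tree: node 0 splits on feature 3 at 1.5, nodes 1 and 2 are leaves.
lefthand = np.array([1, -1, -1])
righthand = np.array([2, -1, -1])
feature = np.array([3, -2, -2])             # -2 is a placeholder at leaves
threshold = np.array([1.5, -2.0, -2.0])     # -2.0 is a placeholder at leaves


def find_leaf(row, lefthand, righthand, feature, threshold):
    """Return the index of the leaf that a single row ends up in."""
    node = 0
    while lefthand[node] != -1:  # -1 in the child arrays marks a leaf
        if row[feature[node]] <= threshold[node]:
            node = lefthand[node]
        else:
            node = righthand[node]
    return int(node)


row_low = np.array([0.0, 0.0, 0.0, 1.0])   # feature 3 <= 1.5, goes left
row_high = np.array([0.0, 0.0, 0.0, 2.0])  # feature 3 > 1.5, goes right
leaf_low = find_leaf(row_low, lefthand, righthand, feature, threshold)
leaf_high = find_leaf(row_high, lefthand, righthand, feature, threshold)
```

Here `leaf_low` is 1 and `leaf_high` is 2, the indexes of the left and right leaves.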

Now, we never have just one tree, so one problem is that we have to do this for lots of trees at the same time. Numba supports typed lists as an experimental feature, but passing in a typed list of numpy arrays resulted in compiler errors when we tried it. So, instead, we took the structure of each tree and packed them into forest-wide numpy arrays like so:

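n_trees = len(forest.estimators_)
# Pad every tree to the size of the largest one; -1 fills the unused slots.
max_nodes = max(est.tree_.node_count for est in forest.estimators_)
children_left = np.full((n_trees, max_nodes), -1, dtype=np.int64)
for i, est in enumerate(forest.estimators_):
    tree = est.tree_
    children_left[i, :tree.node_count] = tree.children_left
    # repeat this for everything else
    ...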

which makes it easy to pass them into numba functions. And, now that we have the tree structure separate from the cython class, we are free to reimplement the apply function as we see fit. Here's how we did it, with a bit of added parallelism thanks to numba:

@numba.njit([
    # like twenty lines of type signatures here
    ...
], nogil=True, parallel=True)
def _apply(X, children_left, children_right, feature, threshold):
    index_size = X.shape[0]
    num_trees = children_left.shape[0]
    result = np.empty((index_size, num_trees), dtype=np.uint16)
    for j in numba.prange(num_trees):
        for i in range(index_size):
            node = 0
            while True:
                if X[i, feature[j, node]] <= threshold[j, node]:
                    candidate_node = children_left[j, node]
                else:
                    candidate_node = children_right[j, node]
                if candidate_node < 0:
                    break
                else:
                    node = candidate_node
            result[i, j] = node
    return result

The end result of this effort is that a calculation that used to take a small number of days now completes in a small number of hours (frequently less than a single hour). This is mostly not because we reimplemented scikit-learn's apply method in numba -- the big performance gains came from optimizations on other parts of the bottleneck -- but a lot of people use scikit-learn, and so maybe this will be helpful to someone else. For more examples of using numba to make stuff go fast (like pandas), please see this earlier post on rewriting a groupby call in numba.

Thanks to the numba developers for building awesome things, and thanks to the scikit-learn community for their stunningly high quality documentation.