Deploying ML models in Clojure

Kira Howe's 2024 article about the current state of ML in Clojure prominently features the Tribuo library by Oracle Labs and the Clojure wrapper for Tribuo. Tribuo integrates XGBoost, ONNX Runtime, and TensorFlow-Java. However, the TensorFlow bindings for Java look a bit verbose (see e.g. the MNIST example).

Another approach is to train the model in Python, export it to the ONNX format, and then use the ONNX runtime directly to perform inference in Clojure. There is a recent tutorial on using ONNX models from Clojure; however, it only deals with tabular data.

Training

The following example uses PyTorch to train a traditional CNN classifier on the well-known MNIST dataset (the dataset can be obtained here). The implementation performs the following steps:

  • A class for reading MNIST images and labels is implemented.
  • A CNN model using two convolutional layers and two fully connected layers is implemented and dropout regularization is applied.
  • The training and test data is loaded as batches.
  • The cross entropy loss function and an Adam optimizer are instantiated. Note that learning rate and dropout are hyperparameters which need to be tuned.
  • The training loop performs prediction, loss computation, backpropagation, and optimization step.
  • The test loop accumulates and displays the prediction accuracy on the test set.
  • After 25 epochs, the model is exported to the ONNX format.
import numpy as np
import torch
from torch import nn
from torch import onnx
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset


class MNISTData(Dataset):

    def __init__(self, images_file_name, labels_file_name):
        """Read MNIST images and labels from specified files"""
        super(MNISTData, self).__init__()
        # Read images (skip magic, length, height, and width integers)
        self.images = np.fromfile(images_file_name, dtype=np.uint8)[16:].reshape(-1, 28, 28)
        # Read labels (skip magic and length integer)
        self.labels = np.fromfile(labels_file_name, dtype=np.uint8)[8:]

    def __len__(self):
        """Return the number of images (or labels) in the dataset"""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the scaled image and the one-hot encoded label at the specified index"""
        image = torch.from_numpy(self.images[idx]).to(torch.float) / 255.0
        label = torch.zeros(10)
        label[int(self.labels[idx])] = 1
        return image, label


class MNISTNet(nn.Module):

    def __init__(self):
        """Construct network with 2 convolutional layers and 2 fully connected layers"""
        super(MNISTNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d(p=0.2)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Perform forward pass of network and return unnormalised scores (logits)"""
        x = x.view(-1, 1, 28, 28)
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.2, training=self.training)
        # No softmax here, because nn.CrossEntropyLoss applies log-softmax itself
        return self.fc2(x)


def main():
    train_data = MNISTData('data/train-images-idx3-ubyte', 'data/train-labels-idx1-ubyte')
    test_data = MNISTData('data/t10k-images-idx3-ubyte', 'data/t10k-labels-idx1-ubyte')

    train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=64)

    model = MNISTNet()
    loss = nn.CrossEntropyLoss()
    # Adam optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(25):
        model.train()  # enable dropout during training
        for x, y in train_loader:
            pred = model(x)
            l = loss(pred, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()

        model.eval()  # disable dropout during evaluation
        correct = 0
        total = 0
        with torch.no_grad():
            for x, y in test_loader:
                pred = model(x).argmax(dim=1)
                correct += (pred == y.argmax(dim=1)).sum().item()
                total += len(y)
        print('Accuracy: {}'.format(correct / total))

    # Save model as ONNX, appending a softmax layer so that the output is a probability vector
    torch.onnx.export(nn.Sequential(model, nn.Softmax(dim=1)),
                      (torch.randn((1, 1, 28, 28), dtype=torch.float),),
                      'mnist.onnx',
                      input_names=['input'],
                      output_names=['output'])


if __name__ == '__main__':
    main()
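The MNISTData class simply skips the 16-byte header of the image file and the 8-byte header of the label file. Below is a minimal sketch that reads those big-endian IDX headers with the standard struct module instead, so that a wrong or truncated file is detected early. The function names are hypothetical. The magic numbers 2051 and 2049 are the IDX values for unsigned-byte data with 3 and 1 dimensions, which is what the MNIST files use.

```python
import struct

IMAGES_MAGIC = 2051  # 0x00000803: unsigned byte data, 3 dimensions
LABELS_MAGIC = 2049  # 0x00000801: unsigned byte data, 1 dimension


def parse_idx_images(data: bytes):
    """Return (rows, cols, images) where each image is a bytes object of rows*cols pixels."""
    magic, count, rows, cols = struct.unpack('>IIII', data[:16])
    if magic != IMAGES_MAGIC:
        raise ValueError(f'unexpected magic number {magic}')
    size = rows * cols
    body = data[16:]
    if len(body) != count * size:
        raise ValueError('pixel data does not match header')
    return rows, cols, [body[i * size:(i + 1) * size] for i in range(count)]


def parse_idx_labels(data: bytes):
    """Return the list of labels stored in an IDX label file."""
    magic, count = struct.unpack('>II', data[:8])
    if magic != LABELS_MAGIC:
        raise ValueError(f'unexpected magic number {magic}')
    labels = list(data[8:])
    if len(labels) != count:
        raise ValueError('label data does not match header')
    return labels


if __name__ == '__main__':
    # Synthetic file with two 2x2 images instead of the real 28x28 MNIST data
    images = struct.pack('>IIII', IMAGES_MAGIC, 2, 2, 2) + bytes([0, 255, 128, 64, 1, 2, 3, 4])
    labels = struct.pack('>II', LABELS_MAGIC, 2) + bytes([7, 3])
    rows, cols, digits = parse_idx_images(images)
    print(rows, cols, len(digits), parse_idx_labels(labels))
```

With real data, you would pass the result of reading 'data/t10k-images-idx3-ubyte' in binary mode.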

Inference

The model file mnist.onnx can now be used for inference in Clojure. The deps.edn file specifies the ONNX runtime and the cljfx library:

{:deps {com.microsoft.onnxruntime/onnxruntime {:mvn/version "1.20.0"}
        cljfx/cljfx {:mvn/version "1.9.3"}}
 :paths ["."]
 :aliases {:infer {:main-opts ["-m" "infer.core"]}}}

The file infer/core.clj contains the code to run inference on the model. It defines the following functions for inference:

  • read-digit - Read a 28*28 gray-scale byte block from the MNIST dataset
  • feature-scaling - Scale byte features to the [0, 1] floating-point range. Note that Clojure byte arrays contain signed values, which need to be converted to unsigned values!
  • argmax - Return the index of the maximum value of a one-dimensional probability vector.
  • inference - Convert a byte array to an ONNX tensor with batch size and number of channels both being 1, run inference, and return the argmax of the probability vector.
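To make the signed-byte pitfall concrete, here is a minimal Python sketch that mirrors byte->ubyte, feature-scaling, and the reduce-based argmax. The Python names are hypothetical ports for illustration only. JVM bytes are simulated with struct.unpack('b', ...), which interprets a raw byte as a signed value in the same way a Java byte array does.

```python
import struct
from functools import reduce


def byte_to_ubyte(b: int) -> int:
    """Map a signed byte (-128..127, as stored in a JVM byte array) to 0..255."""
    return b if b >= 0 else b + 256


def feature_scaling(digit):
    """Scale signed bytes to floats in the range [0, 1]."""
    return [byte_to_ubyte(b) / 255.0 for b in digit]


def argmax(values):
    """Index of the first largest value, using a reduce over (index, value) pairs."""
    best_index, _ = reduce(
        lambda best, item: item if item[1] > best[1] else best,
        enumerate(values),
        (0, values[0]),
    )
    return best_index


if __name__ == '__main__':
    # The pixel value 200 is stored as the signed byte -56 on the JVM
    signed = struct.unpack('b', bytes([200]))[0]
    print(signed, byte_to_ubyte(signed))  # -56 200
    print(argmax([0.1, 0.7, 0.2]))        # 1
```

Without the conversion, bright pixels (values above 127) would become negative features, and the network would receive inputs it never saw during training.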

Furthermore, the digit->image function uses the idea shown in James Thompson's Gist to convert a byte array to a JavaFX image in order to display it. The remaining code displays a small JavaFX GUI showing random images from the MNIST test data together with the inference result.

(ns infer.core
    (:require [clojure.java.io :as io]
              [cljfx.api :as fx])
    (:import [java.io ByteArrayOutputStream ByteArrayInputStream]
             [java.nio FloatBuffer]
             [javafx.application Platform]
             [ai.onnxruntime OrtEnvironment OrtSession OnnxTensor]))

(def environment (OrtEnvironment/getEnvironment))

(def mnist (-> environment (.createSession "mnist.onnx")))

(defn read-digit
  "Read a 28*28 gray-scale byte block from the MNIST dataset."
  [n]
  (with-open [in (io/input-stream "data/t10k-images-idx3-ubyte")]
    ;; skip the 16-byte header and the n preceding images
    (.skip in (+ 16 (* n 28 28)))
    (.readNBytes in (* 28 28))))

(defn byte->ubyte
  "Convert byte to unsigned byte"
  [b]
  (if (>= b 0) b (+ b 256)))

(defn feature-scaling
  "Scale features to [0, 1] range"
  [digit]
  (float-array (map #(/ (byte->ubyte %) 255.0) digit)))

(defn argmax
  "Return the index of the maximum value in the array"
  [arr]
  (first
    (reduce (fn [[result maximum] [index value]] (if (> value maximum) [index value] [result maximum]))
            [0 (first arr)]
            (map vector (range) arr))))

(defn inference
  "Run inference on a digit image"
  [digit]
  (let [scaled       (feature-scaling digit)
        input-buffer (FloatBuffer/wrap scaled)]
    ;; close the native input tensor and the result object after use
    (with-open [input-tensor (OnnxTensor/createTensor environment input-buffer (long-array [1 1 28 28]))
                outputs      (.run mnist {"input" input-tensor})]
      (let [output-tensor (.get (.get outputs "output"))
            output-buffer (.getFloatBuffer output-tensor)
            result        (float-array 10)]
        (.get output-buffer result)
        (argmax result)))))

(defn digit->image
  "Convert a 28*28 byte array to JavaFX image"
  [data]
  (let [image  (java.awt.image.BufferedImage. 28 28 java.awt.image.BufferedImage/TYPE_BYTE_GRAY)
        raster (.getRaster image)
        out    (ByteArrayOutputStream.)]
    (.setDataElements raster 0 0 28 28 data)
    (javax.imageio.ImageIO/write image "png" out)
    (.flush out)
    (javafx.scene.image.Image. (ByteArrayInputStream. (.toByteArray out)))))

(def app-state (atom {:index (rand-int 10000)}))

(defn event-handler
  "Update application state with random index"
  [& _args]
  (swap! app-state assoc :index (rand-int 10000)))

(defn display-image
  "Image display for cljfx GUI"
  [{:keys [image]}]
  {:fx/type :image-view
   :fit-width 256
   :fit-height 256
   :image image})

(defn next-button
  "Next button for cljfx GUI"
  [_]
  {:fx/type :button
   :text "Next"
   :on-action event-handler})

(defn root
  "Main window for cljfx GUI"
  [{:keys [index]}]
  (let [digit  (read-digit index)
        result (inference digit)]
    {:fx/type :stage
     :showing true
     :title "MNIST"
     :scene {:fx/type :scene
             :root {:fx/type :v-box
                    :padding 3
                    :spacing 5
                    :children [{:fx/type display-image :image (digit->image digit)}
                               {:fx/type :h-box
                                :padding 3
                                :spacing 5
                                :children [{:fx/type next-button}
                                           {:fx/type :label :text (str "result = " result)}]}]}}}))

(def renderer
  "Renderer for cljfx GUI"
  (fx/create-renderer
   :middleware (fx/wrap-map-desc assoc :fx/type root)))

(defn -main [& args]
  (Platform/setImplicitExit true)
  (fx/mount-renderer app-state renderer))

Here is a screenshot of the inference GUI:

inference GUI screenshot

GPU usage

For the MNIST example, a CPU is sufficient for training and inference. For larger models, one usually needs a GPU.

In PyTorch, one can use the .to method to move models and tensors to the GPU. For inference in Clojure, one needs to depend on onnxruntime_gpu instead of onnxruntime. Furthermore, one needs to select a GPU device when creating a session:

; ...
; requires [ai.onnxruntime OrtSession$SessionOptions] in the :import clause of the ns form
(def device-id 0)
(def options (OrtSession$SessionOptions.))
(.addCUDA options device-id)
(def environment (OrtEnvironment/getEnvironment))

(def mnist (-> environment (.createSession "mnist.onnx" options)))
; ...

Conclusion

The ONNX runtime allows you to train models using PyTorch and deploy them in Clojure applications. There are also TensorFlow-Java bindings; however, they are more verbose. Hopefully, the Clojure Tribuo bindings will eventually provide a more concise API for implementing and training ML models.

When using byte arrays in Clojure to represent images, one needs to convert the values to unsigned bytes in order to obtain correct results. In the example, we also used feature scaling for faster convergence during training.

Also see github.com/wedesoft/clojure-onnx for source code.

Enjoy!

The SOLID principles illustrated using Clojure code examples

Here is a short introduction to the SOLID software design principles explained using Clojure code examples.

A 20-minute video of the presentation is available.

You can get the SOLID Clojure slides here: solid-clojure.pdf

If Python is your language of choice, you can get the slides here: solid-python.pdf

See github.com/wedesoft/solid for source code of slides.

Any suggestions and comments are welcome.

Performance comparison of Clojure, Ruby, and Python

Introduction


A fair performance comparison of implementations in different programming languages is difficult. Ideally, one measures the performance for a wide range of algorithms and programming tasks. However, each language has different application areas where it performs best.

In the following, we compare the performance of Clojure, Ruby, and Python on the factorial function. While the results do not allow general statements about the performance of each language implementation, they give us an idea of the individual strengths and weaknesses.

The measurements were performed using an AMD Ryzen 7 4700U with 16 GB RAM and 8 cores.

  • For the Clojure measurements I used the Criterium benchmarking library.
  • For the Ruby measurements I used the benchmark module of the standard library.
  • For the Python measurements I used the Pyperf module. See this Gist for the benchmarking source code.

Implementations

There are different ways to implement the factorial function.

Recursive

In functional programming the recursive implementation of the factorial function is the most common.

The Clojure implementation is as follows.

(defn factorial [n]
  (if (zero? n)
    1
    (*' n (factorial (dec n)))))

Note that we use *' instead of *, because *' promotes the result to a big integer when it overflows a long, which is required if we want to be able to handle large numbers.

The Ruby implementation is as follows.

def factorial n
  return 1 if n <= 1
  n * factorial(n - 1)
end

Finally the Python implementation is as follows.

def factorial(n):
    if n == 0:
        return 1
    else:
        return n * factorial(n - 1)

Loop

If we introduce an additional variable, we can implement the factorial using a loop.

In Clojure we can use tail recursion as follows.

(defn factorial [n]
  (loop [result 1N n n]
    (if (zero? n)
      result
      (recur (*' result n) (dec n)))))

Note that we need to initialise the result with 1N (big number type) instead of 1, because the rebinding of the result variable does not allow a dynamic type change.

In Ruby we can implement factorial using a while loop.

def factorial n
  result = 1
  while n.positive?
    result *= n
    n -= 1
  end
  result
end

In Python we can implement factorial using a while loop as well.

def factorial(n):
    result = 1
    while n > 0:
        result *= n
        n -= 1
    return result

Reduce

Using the higher-order function reduce we can implement the factorial as follows.

In Clojure:

(defn factorial [n] (reduce *' (range 1 (inc n))))

In Ruby:

def factorial n
  1.upto(n).reduce :*
end

And in Python:

import operator
from functools import reduce

def factorial(n):
    return reduce(operator.mul, range(1, n + 1), 1)

Unchecked integer math

One can use unchecked integers and type annotations in Clojure if it is known that the result is not going to exceed the 64-bit integer range.

(set! *unchecked-math* true)
(defn factorial [^long n] (if (zero? n) 1 (* n (factorial (dec n)))))
(set! *unchecked-math* false)

A similar approach in Python for improving performance is to use the Cython compiler. Here the method is implemented in a dialect of Python which uses static typing.

def factorial(int n):
    # 64-bit C integers without overflow check: only valid for n <= 20
    cdef long long i, ret
    ret = 1
    for i in range(2, n + 1):
        ret *= i
    return ret

Ruby only has the RubyInline library, which requires reimplementing the method in C.

require 'inline'

class Factorial
  inline do |builder|
    builder.c "
        long factorial(int max) {
          long i = max, result = 1;
          while (i >= 2) { result *= i--; }
          return result;
        }"
  end
end

Other implementations

In Python, one can instead use the factorial function from the math module.

In Clojure one can apply the multiplication function to a range of numbers, since multiplication in Clojure can take an arbitrary number of arguments.

(defn factorial [n] (apply *' (range 1 (inc n))))

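Of the three languages, only Clojure offers a parallel reduce in its core library. For computing factorials, we can use the fold function from clojure.core.reducers. Here we split the task into two chunks. Note that fold only runs in parallel on vectors and hash maps (other collections fall back to a sequential reduce), so the range needs to be converted to a vector.

(require '[clojure.core.reducers :as r])

(defn factorial [n] (r/fold (quot n 2) *' *' (vec (range 1 (inc n)))))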
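A sequential Python model of what fold does with these arguments may help here. The vector is split in halves until a chunk has at most n elements. Each chunk is reduced with reducef, starting from (combinef), and the partial results are merged with combinef. Because (*') returns 1, *' can serve as both functions. This sketch is my simplified reading of clojure.core.reducers, not its actual implementation: it processes the halves one after the other instead of on a fork/join pool. The max() guard for small n is an addition of the sketch.

```python
import operator
from functools import reduce


def multiply(*args):
    """Variadic multiplication like Clojure's *': with no arguments it returns the identity 1."""
    return reduce(operator.mul, args, 1)


def fold(n, combinef, reducef, v):
    """Sequential model of fold on a vector (the real fold runs both halves in parallel)."""
    if len(v) <= n:
        # Reduce a small chunk, starting from the identity element (combinef)
        return reduce(reducef, v, combinef())
    middle = len(v) // 2
    return combinef(fold(n, combinef, reducef, v[:middle]),
                    fold(n, combinef, reducef, v[middle:]))


def factorial(n):
    # Chunk size n // 2 splits the range into two chunks; max() guards against n < 2
    return fold(max(1, n // 2), multiply, multiply, list(range(1, n + 1)))
```

The splitting and merging is pure overhead for a task as small as factorial of 20 or 100, which is consistent with the fold timings below.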

Finally, if the input argument is known at compile time, one can use a macro in Clojure. This will obviously perform much better than all the other implementations.

(defmacro factorial-macro [n]
  `(fn [] (*' ~@(range 1 (inc n)))))
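Before comparing timings, it is worth checking that all variants compute the same result. Below is a small Python sketch that checks the Python implementations against math.factorial and prints a rough time per call. It uses the standard timeit module as a simpler substitute for the pyperf setup used for the measurements, so its absolute numbers will differ from the tables below (among other things, the lambda adds call overhead).

```python
import math
import operator
import timeit
from functools import reduce


def factorial_recursive(n):
    return 1 if n == 0 else n * factorial_recursive(n - 1)


def factorial_loop(n):
    result = 1
    while n > 0:
        result *= n
        n -= 1
    return result


def factorial_reduce(n):
    # The initial value 1 makes factorial_reduce(0) return 1 instead of raising TypeError
    return reduce(operator.mul, range(1, n + 1), 1)


# Variants under test; math.factorial serves as the reference implementation
IMPLEMENTATIONS = {
    'recursive': factorial_recursive,
    'loop': factorial_loop,
    'reduce': factorial_reduce,
    'math library': math.factorial,
}


def check(arguments=range(101)):
    """Raise AssertionError if any variant disagrees with math.factorial."""
    for n in arguments:
        expected = math.factorial(n)
        for name, function in IMPLEMENTATIONS.items():
            assert function(n) == expected, f'{name} is wrong for n={n}'


def time_per_call(function, n, number=10000):
    """Best-of-5 average time per call in nanoseconds (includes the lambda call overhead)."""
    best = min(timeit.repeat(lambda: function(n), number=number, repeat=5))
    return best / number * 1e9


if __name__ == '__main__':
    check()
    for n in (20, 100):
        for name, function in IMPLEMENTATIONS.items():
            print(f'{name:>12} n={n:3d}: {time_per_call(function, n):9.1f} ns')
```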

Factorial of 20

First we compared the performance of computing the factorial of 20.

implementation      Clojure 1.12.0   Ruby 3.4.1   Python 3.13.1
recursive           104 ns           538 ns       957 ns
loop                164 ns           665 ns       717 ns
reduce              116 ns           1512 ns      718 ns
unchecked integer   44.4 ns          77.6 ns      41.6 ns
fold                6211 ns          n/a          n/a
math library        n/a              n/a          45.4 ns
apply               178 ns           n/a          n/a
macro               0.523 ns         n/a          n/a

The Clojure implementation runs on the JVM, and its performance for the recursive, loop, and reduce implementations of factorial is the best. Forcing fold, which is a parallel version of reduce, to split the work into two chunks makes it much slower (6211 ns).

Note that the recursive implementation in Ruby is faster than the Python one. This may be due to the YJIT just-in-time compiler built into the Ruby interpreter.

Also note that the loop implementation in Python is faster than the recursive implementation. Surprisingly, the reduce implementation in Python performs about as well as the loop implementation.

The factorial implementation of the Python math library is very fast. The Python math library implementation uses a lookup table for arguments up to 20.

Since the factorial of 20 fits into a 64-bit integer, one can use unchecked integers in Clojure to get a fast implementation. The resulting implementation is slightly faster than the Python math library implementation. Note that the factorial implementation of the Python math library does not perform unchecked 64-bit integer math. In Python, one can use the Cython dialect to use the C compiler and unchecked math with 64-bit integers. Finally, one can use RubyInline to embed a C implementation in a Ruby program. The Cython implementation is the fastest for computing the factorial of 20.

Finally, using a macro (which, of the three languages, only Clojure supports) is faster than all other implementations, but it is obviously limited to cases where the function argument is known at compile time.
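Why is 20 the cut-off? The value 20! still fits into a signed 64-bit integer, but 21! does not. The short Python sketch below emulates the wrap-around of unchecked 64-bit two's complement multiplication, which is what unchecked long arithmetic does on the JVM. Python integers have arbitrary precision, so the sketch does the wrapping explicitly, and the helper names are hypothetical.

```python
import math

INT64_MIN = -2**63
INT64_MAX = 2**63 - 1


def wrap_int64(x: int) -> int:
    """Reduce an arbitrary integer to a signed 64-bit two's complement value."""
    return (x + 2**63) % 2**64 - 2**63


def unchecked_factorial(n: int) -> int:
    """Factorial with every multiplication wrapped to 64 bits, like unchecked math."""
    result = 1
    for i in range(2, n + 1):
        result = wrap_int64(result * i)
    return result


if __name__ == '__main__':
    for n in (20, 21):
        exact = math.factorial(n)
        print(n, exact <= INT64_MAX, unchecked_factorial(n) == exact)
```

For 21, the wrapped result is the negative number -4249290049419214848, so unchecked math silently produces a wrong answer beyond 20.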

Factorial of 100

When computing the factorial of arguments greater than 20, big integers are required. The following table shows the performance of computing the factorial of 100.

implementation      Clojure 1.12.0   Ruby 3.4.1   Python 3.13.1
recursive           3148 ns          27191 ns     5520 ns
loop                3434 ns          32426 ns     4430 ns
reduce              2592 ns          28851 ns     3790 ns
fold                16012 ns         n/a          n/a
math library        n/a              n/a          599 ns
apply               2624 ns          n/a          n/a
macro               3.70 ns          n/a          n/a

Again, the recursive, loop, and reduce implementations are fastest in Clojure. And again, using fold with two chunks performs much worse.

The Ruby implementations fare much worse for big numbers. Perhaps Ruby's big integer arithmetic is slower than that of Python and Clojure.

Overall, the implementation in the Python math library is the fastest candidate (unless we can use a Clojure macro, of course).

Conclusion

As stated before, one cannot generalize from this limited performance comparison. However, one can perhaps make the following observations:

  • The combination of Clojure and the JVM allows for better performance of dynamically typed programs.
  • The performance of numerical algorithms in Clojure can be further improved by using unchecked math. In this case the performance comes close to that of Cython.
  • Being able to call an AOT-compiled C implementation of a numerical algorithm still gives the best performance, as can be seen from the Python math library implementation of factorial. The existence of this function is perhaps a reflection of the fact that Python has the strongest support when it comes to numerical libraries.
  • Only Clojure supports macros, which allow results to be computed at compile time instead of at runtime if the arguments are known early.
  • Currently only Clojure supports parallel algorithms such as fold and pmap. However, they only offer performance benefits for larger tasks than the one tested here.

Any suggestions and comments are welcome.

Updates:

  • Replaced Numba implementation with Cython.
  • Added type hints to unchecked math Clojure implementation.
  • Added RubyInline implementation.

Because Python and Ruby bind methods late, there is a significant overhead when calling methods. In the following the identity function is tested in Clojure, Ruby, and Python. It looks like the JVM has even inlined the Clojure method, because the method invocation time is close to zero.

implementation      Clojure 1.12.0   Ruby 3.4.1   Python 3.13.1
identity function   0.0023 ns        42.5 ns      44.2 ns

Also see: https://clojure-goes-fast.com/

Minimal OpenGL example in C using GLEW and GLFW

OpenGL is a reasonably abstract API for doing 3D graphics. In the past I did an example of OpenGL using GLUT. However, GLUT is a bit outdated now, and a more modern alternative is GLFW. The example still uses GLEW to set up the OpenGL extensions.

This example is minimal and only uses a vertex shader and a fragment shader to get started with OpenGL. For an example also using tessellation and geometry shaders, see my short introduction to OpenGL.

Note that it is important to add code for retrieving error messages (as I have done below) in order to be able to develop the shaders.

As in my old example, the code draws a coloured triangle on the screen.

// Minimal OpenGL example using GLFW and GLEW
#include <math.h>
#include <stdio.h>
#include <GL/glew.h>
#include <GLFW/glfw3.h>

// Vertex shader source code:
// This shader takes in vertex positions and texture coordinates,
// passing them to the fragment shader.
const char *vertexSource = "#version 130\n\
in mediump vec3 point;\n\
in mediump vec2 texcoord;\n\
out mediump vec2 UV;\n\
void main()\n\
{\n\
  gl_Position = vec4(point, 1);\n\
  UV = texcoord;\n\
}";

// Fragment shader source code:
// This shader samples the color from a texture based on UV coordinates.
const char *fragmentSource = "#version 130\n\
in mediump vec2 UV;\n\
out mediump vec3 fragColor;\n\
uniform sampler2D tex;\n\
void main()\n\
{\n\
  fragColor = texture(tex, UV).rgb;\n\
}";

GLuint vao; // Vertex Array Object
GLuint vbo; // Vertex Buffer Object
GLuint idx; // Index Buffer Object
GLuint tex; // Texture
GLuint program; // Shader program
int width = 320; // Width of window in pixels
int height = 240; // Height of window in pixels

// Function to handle shader compile errors
void handleCompileError(const char *step, GLuint shader)
{
  GLint result = GL_FALSE;
  glGetShaderiv(shader, GL_COMPILE_STATUS, &result);
  if (result == GL_FALSE) {
    char buffer[1024];
    glGetShaderInfoLog(shader, 1024, NULL, buffer);
    if (buffer[0])
      fprintf(stderr, "%s: %s\n", step, buffer);
  };
}

// Function to handle shader program link errors
void handleLinkError(const char *step, GLuint program)
{
  GLint result = GL_FALSE;
  glGetProgramiv(program, GL_LINK_STATUS, &result);
  if (result == GL_FALSE) {
    char buffer[1024];
    glGetProgramInfoLog(program, 1024, NULL, buffer);
    if (buffer[0])
      fprintf(stderr, "%s: %s\n", step, buffer);
  };
}

// Vertex data:
// Each vertex has a position (x, y, z) and a texture coordinate (u, v)
GLfloat vertices[] = {
   0.5f,  0.5f,  0.0f, 1.0f, 1.0f,  // Top right
  -0.5f,  0.5f,  0.0f, 0.0f, 1.0f,  // Top left
  -0.5f, -0.5f,  0.0f, 0.0f, 0.0f   // Bottom left
};

// Indices for drawing the triangle
unsigned int indices[] = { 0, 1, 2 };

// Texture BGR data for a 2x2 texture
float pixels[] = {
  0.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f,
  1.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f
};

int main(int argc, char** argv)
{
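  // Initialize GLFW library.
  if (!glfwInit()) {
    fprintf(stderr, "Failed to initialize GLFW\n");
    return 1;
  }
  // Create a window.
  GLFWwindow *window = glfwCreateWindow(width, height, "minimal OpenGL example", NULL, NULL);
  if (!window) {
    fprintf(stderr, "Failed to create window\n");
    glfwTerminate();
    return 1;
  }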
  // Set current OpenGL context to window.
  glfwMakeContextCurrent(window);
  // Initialize GLEW library.
  glewInit();

  glViewport(0, 0, width, height);

  // Compile and check vertex shader.
  GLuint vertexShader = glCreateShader(GL_VERTEX_SHADER);
  glShaderSource(vertexShader, 1, &vertexSource, NULL);
  glCompileShader(vertexShader);
  handleCompileError("Vertex shader", vertexShader);

  // Compile and check fragment shader.
  GLuint fragmentShader = glCreateShader(GL_FRAGMENT_SHADER);
  glShaderSource(fragmentShader, 1, &fragmentSource, NULL);
  glCompileShader(fragmentShader);
  handleCompileError("Fragment shader", fragmentShader);

  // Link and check shader program.
  program = glCreateProgram();
  glAttachShader(program, vertexShader);
  glAttachShader(program, fragmentShader);
  glLinkProgram(program);
  handleLinkError("Shader program", program);

  // Create a vertex array object which serves as context for the
  // vertex buffer object and the index buffer object.
  glGenVertexArrays(1, &vao);
  glBindVertexArray(vao);

  // Initialize vertex buffer object with the vertex data.
  glGenBuffers(1, &vbo);
  glBindBuffer(GL_ARRAY_BUFFER, vbo);
  glBufferData(GL_ARRAY_BUFFER, sizeof(vertices), vertices, GL_STATIC_DRAW);

  // Initialize the index buffer object with the index data.
  glGenBuffers(1, &idx);
  glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, idx);
  glBufferData(GL_ELEMENT_ARRAY_BUFFER, sizeof(indices), indices, GL_STATIC_DRAW);

  // Set up layout of vertex buffer object.
  glVertexAttribPointer(glGetAttribLocation(program, "point"), 3, GL_FLOAT,
                        GL_FALSE, 5 * sizeof(float), (void *)0);
  glVertexAttribPointer(glGetAttribLocation(program, "texcoord"), 2, GL_FLOAT,
                        GL_FALSE, 5 * sizeof(float), (void *)(3 * sizeof(float)));

  // Enable depth testing using depth buffer.
  glEnable(GL_DEPTH_TEST);

  // Switch to the shader program.
  glUseProgram(program);

  // Enable the two variables of the vertex buffer layout.
  glEnableVertexAttribArray(glGetAttribLocation(program, "point"));
  glEnableVertexAttribArray(glGetAttribLocation(program, "texcoord"));

  // Initialize texture.
  glGenTextures(1, &tex);
  // Bind texture to first slot.
  glActiveTexture(GL_TEXTURE0);
  glBindTexture(GL_TEXTURE_2D, tex);
  // Set uniform texture in shader object to first texture.
  glUniform1i(glGetUniformLocation(program, "tex"), 0);
  // Load pixel data into texture.
  glTexImage2D(GL_TEXTURE_2D, 0, GL_RGB, 2, 2, 0, GL_BGR, GL_FLOAT, pixels);
  // Set texture wrapping mode and interpolation modes.
  glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT);
  glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
  glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
  // Initialize multiresolution layers.
  glGenerateMipmap(GL_TEXTURE_2D);

  // Loop until the user closes the window.
  while (!glfwWindowShouldClose(window)) {
    // Clear color buffer and depth buffer.
    glClearColor(0.0f, 0.0f, 0.0f, 0.0f);
    glClear(GL_COLOR_BUFFER_BIT|GL_DEPTH_BUFFER_BIT);
    // Switch to the shader program.
    glUseProgram(program);
    // Draw triangle(s).
    glDrawElements(GL_TRIANGLES, 3, GL_UNSIGNED_INT, (void *)0);
    // Swap front and back buffers.
    glfwSwapBuffers(window);
    // Poll for and process events.
    glfwPollEvents();
  };

  // Disable the two shader variables.
  glDisableVertexAttribArray(glGetAttribLocation(program, "point"));
  glDisableVertexAttribArray(glGetAttribLocation(program, "texcoord"));

  // Unbind and delete the texture.
  glBindTexture(GL_TEXTURE_2D, 0);
  glDeleteTextures(1, &tex);

  // Unbind and delete the index buffer object.
  glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, 0);
  glDeleteBuffers(1, &idx);

  // Unbind and delete the vertex buffer object.
  glBindBuffer(GL_ARRAY_BUFFER, 0);
  glDeleteBuffers(1, &vbo);

  // Unbind and delete the vertex array object.
  glBindVertexArray(0);
  glDeleteVertexArrays(1, &vao);

  // Unlink and delete the shader program.
  glDetachShader(program, vertexShader);
  glDetachShader(program, fragmentShader);
  glDeleteProgram(program);
  glDeleteShader(vertexShader);
  glDeleteShader(fragmentShader);

  // Set OpenGL context to NULL.
  glfwMakeContextCurrent(NULL);
  // Destroy window.
  glfwDestroyWindow(window);
  // Terminate GLFW.
  glfwTerminate();
  return 0;
}

The example uses the widely supported OpenGL version 3.0 (the shaders use GLSL version 1.30, which has the version tag 130). You can download, compile, and run the example as follows:

wget https://www.wedesoft.de/downloads/raw-opengl-glfw.c
gcc -o raw-opengl-glfw raw-opengl-glfw.c $(pkg-config --libs glfw3 glew)
./raw-opengl-glfw
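The two glVertexAttribPointer calls describe an interleaved buffer. Each vertex occupies 5 floats, so the stride is 5 * sizeof(float) bytes, and the texture coordinates start 3 floats into each vertex. Below is a small Python sketch that computes those numbers and reads an attribute back using the same stride and offset. It uses struct to mimic the packed GLfloat array, assumes 4-byte floats, and its helper names are hypothetical.

```python
import struct

# Interleaved vertex data from the C example: x, y, z, u, v per vertex
VERTICES = [
     0.5,  0.5, 0.0, 1.0, 1.0,  # top right
    -0.5,  0.5, 0.0, 0.0, 1.0,  # top left
    -0.5, -0.5, 0.0, 0.0, 0.0,  # bottom left
]
FLOAT_SIZE = struct.calcsize('f')  # sizeof(float), 4 bytes on common platforms
COMPONENTS = {'point': 3, 'texcoord': 2}  # attribute name -> number of floats


def layout(components):
    """Return the stride and the byte offset of each attribute in an interleaved buffer."""
    stride = sum(components.values()) * FLOAT_SIZE
    offsets, offset = {}, 0
    for name, count in components.items():
        offsets[name] = offset
        offset += count * FLOAT_SIZE
    return stride, offsets


def attribute(buffer, vertex, name, components=COMPONENTS):
    """Read one attribute of one vertex from the packed buffer using stride and offset."""
    stride, offsets = layout(components)
    start = vertex * stride + offsets[name]
    return struct.unpack_from(f'{components[name]}f', buffer, start)


if __name__ == '__main__':
    buffer = struct.pack(f'{len(VERTICES)}f', *VERTICES)
    print(layout(COMPONENTS))                 # (20, {'point': 0, 'texcoord': 12})
    print(attribute(buffer, 1, 'texcoord'))   # (0.0, 1.0)
```

The values 20 and 12 correspond to the stride 5 * sizeof(float) and the offset 3 * sizeof(float) passed to glVertexAttribPointer in the C code.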


Any feedback, comments, and suggestions are welcome.

Enjoy!

Getting started with the Jolt Physics Engine

Motivation

In the past I have experimented with sequential impulses to implement constraints (see part 1, part 2, part 3, part 4, part 5, part 6 of my rigid body physics series). I tried to combine Runge-Kutta integration with sequential impulses. However, it was difficult to prevent interpenetration of objects. Also, implementing a vehicle with wheels and suspension, where the weight ratio between the vehicle and the wheels was high, required a high number of iterations to stabilise. Finally, stacking of boxes turned out to be unstable.

In a GDC 2014 talk, Erin Catto showed sequential impulses and stable box stacking in the Box2D engine. Stacking of 2D boxes was made stable by solving for multiple impulses at the same time.

In 2022 Jorrit Rouwé released JoltPhysics, a physics engine for 3D rigid bodies which also uses sequential impulses. His GDC 2022 talk Architecting Jolt Physics for Horizon Forbidden West refers to Erin Catto's talk and discusses various performance optimisations developed in Jolt Physics.

In the following, I provide a few Jolt Physics example programs to demonstrate some of the capabilities of the physics engine.

Installing Jolt

Jolt Physics is a C++ library built using CMake. To compile with double precision, I invoked JoltPhysics/Build/cmake_linux_clang_gcc.sh as follows:

cd Build
./cmake_linux_clang_gcc.sh Release g++ -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DDOUBLE_PRECISION=ON \
    -DDEBUG_RENDERER_IN_DEBUG_AND_RELEASE=OFF -DPROFILER_IN_DEBUG_AND_RELEASE=OFF

A release build with g++ and installation is done as follows:

cd Linux_Release
make -j `nproc`
sudo make install
cd ../..

Next, you can have a look at JoltPhysics/HelloWorld/HelloWorld.cpp, which is a simple example of a sphere bouncing on a floor. The example shows how to implement the required layers and collision filters (e.g. stationary objects cannot collide with each other). Make sure to set the Trace function pointer so that you get useful warnings if something goes wrong.
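The collision filtering rule mentioned above fits into a few lines. Below is a Python sketch of the object layer pair rule as I understand it from HelloWorld.cpp, with two layers, NON_MOVING and MOVING. The Python names are hypothetical, and this is a model of the logic only, not Jolt's C++ API.

```python
from enum import IntEnum


class Layer(IntEnum):
    NON_MOVING = 0
    MOVING = 1


def should_collide(layer1: Layer, layer2: Layer) -> bool:
    """Object layer pair filter: static objects only collide with moving ones."""
    if layer1 == Layer.NON_MOVING:
        return layer2 == Layer.MOVING
    return True  # moving objects collide with everything


if __name__ == '__main__':
    for a in Layer:
        for b in Layer:
            print(f'{a.name:>10} vs {b.name:<10}: {should_collide(a, b)}')
```

Skipping NON_MOVING pairs means that the broad phase never has to test static geometry against other static geometry.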

Tumbling object in space

In this section we test the tumbling motion of a cuboid in space.

To compile a C++ program using Jolt, you need to use the same preprocessor definitions which were used to compile Jolt. If you have set up the Trace function, you will get a warning if the preprocessor definitions do not match.

Here is an example Makefile to compile and link a program with the release build of the Jolt library, GLFW, and GLEW.

CCFLAGS = -g -O3 -fPIC -Wall -Werror -DNDEBUG -DJPH_OBJECT_STREAM -DJPH_DOUBLE_PRECISION $(shell pkg-config --cflags glfw3 glew)
LDFLAGS = -flto=auto $(shell pkg-config --libs glfw3 glew) -lJolt

all: tumble

tumble: tumble.o
	g++ -o $@ $^ $(LDFLAGS)

clean:
	rm -f tumble *.o

.cc.o:
	g++ -c $(CCFLAGS) -o $@ $<

See Makefile for complete build code.

The core of the example creates a box shape with half extents a, b, and c (i.e. a cuboid of dimensions 2a×2b×2c) and sets the density to 1000.0. The convex radius used for approximating collision shapes needs to be much smaller than the object dimensions. The limit for the linear velocity is lifted and, most importantly, gyroscopic forces are enabled. Furthermore, linear and angular damping are set to zero. Finally, the body is created and added to the physics system, and its angular velocity is set to an interesting value. The code snippet is shown below:

float a = 1.0;
float b = 0.1;
float c = 0.5;
// ...
BoxShapeSettings body_shape_settings(Vec3(a, b, c));
body_shape_settings.mConvexRadius = 0.01;
body_shape_settings.SetDensity(1000.0);
body_shape_settings.SetEmbedded();
ShapeSettings::ShapeResult body_shape_result = body_shape_settings.Create();
ShapeRefC body_shape = body_shape_result.Get();
BodyCreationSettings body_settings(body_shape, RVec3(0.0, 0.0, 0.0), Quat::sIdentity(), EMotionType::Dynamic, Layers::MOVING);
body_settings.mMaxLinearVelocity = 10000.0;
body_settings.mApplyGyroscopicForce = true;
body_settings.mLinearDamping = 0.0;
body_settings.mAngularDamping = 0.0;
Body *body = body_interface.CreateBody(body_settings);
body_interface.AddBody(body->GetID(), EActivation::Activate);
body_interface.SetLinearVelocity(body->GetID(), Vec3(0.0, 0.0, 0.0));
body_interface.SetAngularVelocity(body->GetID(), Vec3(0.3, 0.0, 5.0));

Here is a video showing the result of the simulation. As one can see, Jolt is able to simulate a tumbling motion without deterioration.

See tumble.cc for full source code.

Stack of cuboids

In this section we test the falling motion of a stack of cuboids. Three cuboids are created, and their initial positions are staggered in the x and z directions to get a more interesting result. For i = 0, 1, 2 the cuboids are created as follows:

BoxShapeSettings body_shape_settings(Vec3(0.5 * a, 0.5 * b, 0.5 * c));
body_shape_settings.mConvexRadius = 0.01;
body_shape_settings.SetDensity(1000.0);
body_shape_settings.SetEmbedded();
ShapeSettings::ShapeResult body_shape_result = body_shape_settings.Create();
ShapeRefC body_shape = body_shape_result.Get();
BodyCreationSettings body_settings(body_shape, RVec3(i * 0.4, 0.2 + i * 0.2, -i * 0.3), Quat::sIdentity(), EMotionType::Dynamic, Layers::MOVING);
body_settings.mMaxLinearVelocity = 10000.0;
body_settings.mApplyGyroscopicForce = true;
body_settings.mLinearDamping = 0.0;
body_settings.mAngularDamping = 0.0;
Body *body = body_interface.CreateBody(body_settings);
body->SetFriction(0.5);
body->SetRestitution(0.3f);
body_interface.AddBody(body->GetID(), EActivation::Activate);

Furthermore a ground shape is created. Note that for simplicity I only created one object layer. If the ground were composed of multiple convex objects, a separate static layer should be created and used.

BoxShapeSettings ground_shape_settings(Vec3(3.0, 0.1, 3.0));
ground_shape_settings.mConvexRadius = 0.01;
ground_shape_settings.SetEmbedded();
ShapeSettings::ShapeResult ground_shape_result = ground_shape_settings.Create();
ShapeRefC ground_shape = ground_shape_result.Get();
BodyCreationSettings ground_settings(ground_shape, RVec3(0.0, -0.5, 0.0), Quat::sIdentity(), EMotionType::Static, Layers::MOVING);
Body *ground = body_interface.CreateBody(ground_settings);
ground->SetFriction(0.5);
body_interface.AddBody(ground->GetID(), EActivation::DontActivate);

Note that the bodies need to be activated for the simulation to take place.

body_interface.ActivateBody(body->GetID());

The simulation is run by repeatedly calling the Update method on the physics system.

const int cCollisionSteps = 1;
physics_system.Update(dt, cCollisionSteps, &temp_allocator, &job_system);

The following video shows the result of the simulation.

See stack.cc for full source code.

For a more challenging demo of this type, see the Stable Box Stacking demo video by Jorrit Rouwé.

Double pendulum

The double pendulum is created using the HingeConstraintSettings class. There are two hinges: one between the base and the upper arm of the pendulum, and one between the upper arm and the lower arm. Besides the hinge axis, the physics library also requires the initialisation of a normal axis, i.e. a vector perpendicular to the hinge axis.

HingeConstraintSettings hinge1;
hinge1.mPoint1 = hinge1.mPoint2 = RVec3(0.0, 0.5, 0);
hinge1.mHingeAxis1 = hinge1.mHingeAxis2 = Vec3::sAxisZ();
hinge1.mNormalAxis1 = hinge1.mNormalAxis2 = Vec3::sAxisY();
physics_system.AddConstraint(hinge1.Create(*base, *upper));

HingeConstraintSettings hinge2;
hinge2.mPoint1 = hinge2.mPoint2 = RVec3(a, 0.5, 0);
hinge2.mHingeAxis1 = hinge2.mHingeAxis2 = Vec3::sAxisZ();
hinge2.mNormalAxis1 = hinge2.mNormalAxis2 = Vec3::sAxisY();
physics_system.AddConstraint(hinge2.Create(*upper, *lower));

The following video shows the result.

See pendulum.cc for full source code.

Suspension

Another test case is a prismatic joint with a suspension constraint. The prismatic joint is created using the SliderConstraintSettings class. The suspension is created using a soft distance constraint. The code snippet is shown below:

SliderConstraintSettings slider_settings;
slider_settings.mAutoDetectPoint = true;
slider_settings.SetSliderAxis(Vec3::sAxisY());
physics_system.AddConstraint(slider_settings.Create(*boxes[0], *boxes[1]));

DistanceConstraintSettings distance_settings;
distance_settings.mPoint1 = RVec3(0.0, 0.0, 0.0);
distance_settings.mPoint2 = RVec3(0.0, 0.4, 0.0);
distance_settings.mLimitsSpringSettings.mDamping = 0.1f;
distance_settings.mLimitsSpringSettings.mStiffness = 1.0f;
physics_system.AddConstraint(distance_settings.Create(*boxes[0], *boxes[1]));

The video shows the result of running this simulation.

See suspension.cc for full source code.

Wheeled vehicle

Jolt comes with a specialised implementation for simulating wheeled vehicles (there is even one for tracked vehicles). The vehicle API allows placing the wheels and setting the minimum and maximum suspension length. The angular damping of the wheels can be set to zero. There are also longitudinal and lateral friction curves for the wheels, which I haven't modified. Finally there is a vehicle controller object for setting motor, steering angle, brakes, and hand brake.

RefConst<Shape> car_shape = new BoxShape(Vec3(half_vehicle_width, half_vehicle_height, half_vehicle_length));
BodyCreationSettings car_body_settings(car_shape, RVec3::sZero(), Quat::sIdentity(), EMotionType::Dynamic, Layers::MOVING);
car_body_settings.mOverrideMassProperties = EOverrideMassProperties::CalculateInertia;
car_body_settings.mMassPropertiesOverride.mMass = 1500.0f;
car_body_settings.mLinearDamping = 0.0;
car_body_settings.mAngularDamping = 0.0;

VehicleConstraintSettings vehicle;

WheelSettingsWV *w1 = new WheelSettingsWV;
w1->mPosition = Vec3(0.0f, -0.9f * half_vehicle_height, half_vehicle_length - 1.0f * wheel_radius);
w1->mSuspensionMinLength = wheel_radius;
w1->mSuspensionMaxLength = 2 * wheel_radius;
w1->mAngularDamping = 0.0f;
w1->mMaxSteerAngle = 0.0f; // max_steering_angle;
w1->mMaxHandBrakeTorque = 0.0f;
w1->mRadius = wheel_radius;
w1->mWidth = wheel_width;

WheelSettingsWV *w2 = new WheelSettingsWV;
w2->mPosition = Vec3(half_vehicle_width, -0.9f * half_vehicle_height, -half_vehicle_length + 1.0f * wheel_radius);
// ...

WheelSettingsWV *w3 = new WheelSettingsWV;
w3->mPosition = Vec3(-half_vehicle_width, -0.9f * half_vehicle_height, -half_vehicle_length + 1.0f * wheel_radius);
// ...

vehicle.mWheels = {w1, w2, w3};

WheeledVehicleControllerSettings *controller = new WheeledVehicleControllerSettings;
vehicle.mController = controller;

Body *car_body = body_interface.CreateBody(car_body_settings);
body_interface.AddBody(car_body->GetID(), EActivation::Activate);
VehicleConstraint *constraint = new VehicleConstraint(*car_body, vehicle);
VehicleCollisionTester *tester = new VehicleCollisionTesterRay(Layers::MOVING);
constraint->SetVehicleCollisionTester(tester);
physics_system.AddConstraint(constraint);
physics_system.AddStepListener(constraint);

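WheeledVehicleController *vehicle_controller = static_cast<WheeledVehicleController *>(constraint->GetController());
vehicle_controller->SetDriverInput(0.0f, 0.0f, 0.0f, 0.0f); // forward, right, brake, hand brake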

A vehicle dropping on the ground with horizontal speed is shown in the following video.

Note that the wheels in this video have a high moment of inertia. This can be corrected by reducing the wheel inertia as follows:

w1->mInertia = 0.01;
w2->mInertia = 0.01;
w3->mInertia = 0.01;

See vehicle.cc for full source code.

Enjoy!

Update:

To prevent tunnelling when fast objects are colliding, you can switch the motion quality to linear cast instead of discrete:

body_settings.mMotionQuality = EMotionQuality::LinearCast;

Update:

Note that there is also a mMinVelocityForRestitution setting in the physics settings: if two bodies collide with a relative velocity below this threshold (the default is 1.0 m/s), the collision is treated as inelastic.