
WebSocket chat using structured concurrency, Ox & Tapir


A chat application is often used to demonstrate various approaches to working with WebSockets. We'll follow suit; below is a quick tour of how such a service can be implemented in direct style using:

  • Java 21 for its built-in asynchronous runtime & Virtual Threads
  • the Ox library for structured concurrency and error handling
  • the Tapir library for defining and exposing endpoints.

Let's start!

Echo server

Before diving into the chat itself, let's look at a WebSocket echo server, which will help us familiarize ourselves with the basic building blocks. To run the code below, you'll need the tapir-netty-server-sync variant of Tapir, which uses Netty for low-level networking and supports the so-called direct style, where we can freely use blocking, synchronous calls (hence the "sync" in the name).

Here's the complete code, runnable using scala-cli:

//> using dep com.softwaremill.sttp.tapir::tapir-netty-server-sync:1.10.7

import sttp.tapir.*
import sttp.tapir.CodecFormat.*
import sttp.tapir.server.netty.sync.{Id, NettySyncServer, OxStreams}

// (1)
val wsEndpoint = endpoint.get
  .in("echo")
  .out(webSocketBody[String, TextPlain, String, TextPlain](OxStreams))

// (2)
val wsProcessor: OxStreams.Pipe[String, String] =
  requestStream => requestStream.map(msg => s"You said: $msg")

// (3)
val wsServerEndpoint =
  wsEndpoint.serverLogicSuccess[Id](_ => wsProcessor)

// (4)
@main def echoWsServer(): Unit =
  NettySyncServer().addEndpoint(wsServerEndpoint).startAndWait()

First, in (1), we create a description of the WebSocket endpoint that we want to expose, using regular code. In our case, it will handle GET /echo requests, and its output will be a WebSocket body, which expects text/plain string messages as input and produces text/plain string messages as output.

Secondly, we need to implement the logic of the endpoint. WebSockets in Tapir are stream-based, that is, you get a stream of incoming messages, and you have to provide a stream of outgoing messages. In (2), we define such a pipe.

OxStreams.Pipe[String, String] is an alias for Ox ?=> Channel[String] => Channel[String]. We'll deal with the Ox ?=> part later. For now, what's important is that we get a channel, which is a queue-like data structure. We map over each incoming element, returning a channel of transformed values, each of which will be sent back over the WebSocket.

In (3), we combine the endpoint description with the logic into a single data structure. Finally, we can start the server in (4), exposing the combined endpoint + logic.
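To make the "stream in, stream out" shape of the pipe from (2) concrete without any library, here is a minimal model in which a plain Scala Iterator stands in for Ox's channels. The Pipe alias and the simulate helper are hypothetical illustrations, not part of the Tapir or Ox API, and the Ox ?=> context is left out entirely:

```scala
// Hypothetical, library-free model of a WebSocket pipe: a function from the
// stream of incoming messages to the stream of outgoing messages.
// Iterator stands in for Ox channels; this is not the Tapir/Ox API.
type Pipe[A, B] = Iterator[A] => Iterator[B]

// The echo logic: transform each incoming message, lazily, one by one.
val echo: Pipe[String, String] =
  requestStream => requestStream.map(msg => s"You said: $msg")

// Feed a fixed list of "client" messages through a pipe, collecting replies.
def simulate(pipe: Pipe[String, String], incoming: List[String]): List[String] =
  pipe(incoming.iterator).toList
```

For example, simulate(echo, List("hi")) returns List("You said: hi"). In the real server, the incoming side is fed by the WebSocket and never has to be fully materialized.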

Unwinding the map

Instead of using the high-level .map channel transformation in requestStream.map(msg => s"You said: $msg"), we could have implemented it "manually". Let's try that.

We need to perform three operations:

  1. create the outgoing channel
  2. start a background thread, which loops: receives a message from the incoming channel, transforms it, and sends it to the outgoing one
  3. return this outgoing channel as the result, while keeping the background thread running

Translated into code, this takes the following form:

import ox.channels.Channel
import ox.fork

val wsProcessor: OxStreams.Pipe[String, String] =
  requestStream =>
    val outgoing = Channel.bufferedDefault[String] // (1)
    fork { // (2)
      while true do
        val msg = requestStream.receive()
        outgoing.send(s"You said: $msg")
    }
    outgoing // (3)

Creating the outgoing channel, receiving & sending in a loop, and returning the channel is probably self-explanatory, but let's examine the background thread creation in more detail. We use the fork method from Ox, which allows starting a daemon thread in the current concurrency scope, and returns immediately.
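For comparison, the same three steps can be sketched with nothing but java.util.concurrent: an ArrayBlockingQueue stands in for the channel, and a plain daemon Thread stands in for fork. This is an assumption-laden analogue, not what Ox does internally; in particular, nothing ties this thread's lifetime to any scope, so it keeps running until it is interrupted (on Java 21, Thread.ofVirtual() could be used instead of a platform thread).

```scala
import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue}

// Library-free analogue of the "unwound" echo pipe (an assumption, not Ox):
// blocking queues stand in for channels, a daemon thread stands in for fork.
def echoPipe(requestStream: BlockingQueue[String]): BlockingQueue[String] = {
  val outgoing = new ArrayBlockingQueue[String](16) // (1) the outgoing "channel"
  val worker = new Thread(() => {
    try {
      while (true) {
        val msg = requestStream.take()  // blocks until a message arrives
        outgoing.put(s"You said: $msg") // blocks while the buffer is full
      }
    } catch { case _: InterruptedException => () } // stop when interrupted
  })
  worker.setDaemon(true) // (2) background thread that won't keep the JVM alive
  worker.start()
  outgoing               // (3) returned immediately; the loop keeps running
}
```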

However, we don't define any concurrency scopes here: they are externally managed by the Tapir<->Ox integration. Each WebSocket gets its own scope, tied to the lifetime of the WebSocket. A reference to that scope, which allows creating forks, is passed through the Ox ?=> context parameter. Forks can only be started within a concurrency scope, and the compiler verifies that.

Moreover, as the lifetime of the scope is tied to that of the WebSocket, if the socket closes (e.g., the client quits abruptly), the concurrency scope will also end, interrupting any threads that have been created. Conversely, if there's an unhandled error (an exception is thrown) from any fork created within the concurrency scope, it will also end, closing the WebSocket.

Hence, we can safely start background threads: they can only be started within the definition of the WebSocket processing pipe, have a clear lifetime, and are limited in scope.
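The scope semantics described above can be approximated by a small, hypothetical MiniScope class that tracks every thread it starts and, when the scoped block ends, interrupts and joins whatever is still running. It is only a sketch of the idea, not Ox's implementation: unlike Ox, it does not end the scope when a fork throws, and it relies on forks reacting to interruption.

```scala
import scala.collection.mutable.ListBuffer

// Hypothetical, simplified model of a concurrency scope (not Ox's code).
final class MiniScope {
  private val forks = ListBuffer.empty[Thread]

  // Start a background thread whose lifetime is bound to this scope.
  def fork(body: => Unit): Thread = synchronized {
    val t = new Thread(() => body)
    t.setDaemon(true)
    forks += t
    t.start()
    t
  }

  // Interrupt every fork that is still running, then wait for all of them.
  def close(): Unit = {
    val all = synchronized(forks.toList)
    all.foreach(_.interrupt())
    all.foreach(_.join())
  }
}

// Run `body` with a fresh scope; no fork outlives this call.
def scoped[T](body: MiniScope => T): T = {
  val scope = new MiniScope
  try body(scope)
  finally scope.close()
}
```

Because close runs in a finally block, forks are interrupted whether the body returns normally or throws, which mirrors how a closed WebSocket ends its scope.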

Chatting

Now, we can proceed to the chat application itself.

First, let's describe the endpoint. We're going to use a text-based protocol as before, though we'll wrap incoming data in a Message data class, in case we'd like to migrate in the near future, e.g., to a JSON-based protocol that includes both the nickname and the message:

case class Message(v: String)
given Codec[String, Message, TextPlain] = 
  Codec.string.map(Message(_))(_.v)

val chatEndpoint = endpoint.get
  .in("chat")
  .out(webSocketBody[Message, TextPlain, Message, TextPlain](OxStreams))

We need a codec to support our custom data type (again, this could easily be changed to JSON or any other representation).
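The codec above derives a Message codec from the built-in string codec by supplying one function for each direction. The idea can be shown with a hypothetical, stripped-down MiniCodec (not Tapir's Codec, which additionally tracks the codec format, validation, and decoding failures):

```scala
// Hypothetical, minimal bidirectional codec; not Tapir's Codec type.
final case class MiniCodec[L, H](decode: L => H, encode: H => L) {
  // Derive a codec for a new high-level type from a pair of conversions.
  def map[H2](f: H => H2)(g: H2 => H): MiniCodec[L, H2] =
    MiniCodec(decode.andThen(f), g.andThen(encode))
}

case class Message(v: String)

// The identity codec plays the role of Codec.string here.
val stringCodec: MiniCodec[String, String] = MiniCodec[String, String](s => s, s => s)
val messageCodec: MiniCodec[String, Message] = stringCodec.map(Message(_))(_.v)
```

Switching to JSON would then amount to starting from a different low-level codec, while the endpoint description keeps using Message.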

Next, we'll need to implement the chat application's logic. Apart from the per-WebSocket OxStreams.Pipe[Message, Message], we'll also need a central component that keeps the current chat state.

That's an ideal use case for an actor. The actor will hold mutable state: the current list of chat members. This state will be protected and guaranteed to always be accessed serially by a single thread. The actor's logic should support a new member arriving and a member leaving, as well as handling a message that should be broadcast.

Once again, we'll use channels to communicate with member WebSockets. Here's the data structure we'll use to represent chat members:

import java.util.UUID
import ox.channels.*

type ChatMemberId = UUID

case class ChatMember(id: ChatMemberId, channel: Channel[Message])
object ChatMember:
  def create: ChatMember = ChatMember(
    UUID.randomUUID(), Channel.bufferedDefault[Message])

And the logic of the actor, a plain class with mutable state:

class ChatRoom:
  private var members: Map[ChatMemberId, ChatMember] = Map()

  def connected(m: ChatMember): Unit =
    members = members + (m.id -> m)
    println(s"Connected: ${m.id}, number of members: ${members.size}")

  def disconnected(m: ChatMember): Unit =
    members = members - m.id
    println(s"Disconnected: ${m.id}, number of members: ${members.size}")

  def incoming(message: Message): Unit =
    println(s"Broadcasting: ${message.v}")
    members = members.flatMap { (id, member) =>
      selectOrClosed(member.channel.sendClause(message), Default(())) match
        case member.channel.Sent() => Some((id, member))
        case _: ChannelClosed =>
          println(s"Channel of member $id closed, removing from members")
          None
        case DefaultResult(_) =>
          println(s"Buffer for member $id full, not sending message")
          Some((id, member))
    }

The only non-trivial part is broadcasting the message. For that, we perform a select, which tries to send the message to each member's channel. If this is not immediately possible (because the buffer is full), the Default clause is selected, and the message is dropped for that member. That way, we keep memory usage in check, even for members whose WebSockets get clogged. If a member's channel turns out to be closed, that member is removed from the chat state.
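The "try to send, or skip" behavior can be sketched without Ox by using BlockingQueue.offer, which enqueues an element if there is space and otherwise returns false immediately instead of blocking. The broadcast function below is a hypothetical analogue of the Default branch only; java.util.concurrent queues have no notion of being closed, so the ChannelClosed branch isn't modelled.

```scala
import java.util.concurrent.BlockingQueue

// Hypothetical, library-free sketch of a non-blocking broadcast: offer()
// plays the role of the select with a Default clause. It enqueues the message
// if there is room, and returns false immediately (no blocking) if not.
// Returns the ids of members for whom the message was dropped.
def broadcast[Id](members: Map[Id, BlockingQueue[String]], message: String): Set[Id] = {
  var dropped = Set.empty[Id]
  for ((id, queue) <- members) {
    if (!queue.offer(message)) {
      println(s"Buffer for member $id full, not sending message")
      dropped += id
    }
  }
  dropped
}
```

A member with a full buffer simply misses the message, while everyone else still receives it; the broadcasting thread never waits on a slow client.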

Of course, that's only the raw logic; we still need to create the chat actor. We can use the Actor object from Ox: val chatActor = Actor.create(new ChatRoom). This needs to be run within a concurrency scope, as a background thread is created, which executes the incoming actor operations one by one. The result, chatActor, is of type ActorRef[ChatRoom], which we can use to schedule operations to be performed (yes, underneath is yet another Channel).
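To illustrate what an actor reference does underneath, here is a hypothetical, simplified MiniActorRef: one daemon thread takes operations from a queue and applies them to the wrapped object one at a time, so that object's mutable state is only ever touched by that thread. This is a sketch under those assumptions, not Ox's Actor implementation (Ox additionally binds the actor's thread to a concurrency scope). The tell method mirrors the fire-and-forget calls used in this article; ask is added here only to make the state observable.

```scala
import java.util.concurrent.{CompletableFuture, LinkedBlockingQueue}
import scala.util.control.NonFatal

// Hypothetical, simplified actor reference (not Ox's Actor/ActorRef): a single
// daemon thread applies queued operations to `logic`, strictly one at a time.
final class MiniActorRef[T](logic: T) {
  private val ops = new LinkedBlockingQueue[T => Unit]()
  private val worker = new Thread(() => {
    try {
      while (true) {
        val op = ops.take() // blocks until an operation is scheduled
        op(logic)           // only this thread ever touches `logic`
      }
    } catch { case _: InterruptedException => () }
  })
  worker.setDaemon(true)
  worker.start()

  // Schedule an operation and return immediately (fire-and-forget).
  def tell(f: T => Unit): Unit = ops.put(f)

  // Schedule an operation and block until its result is available.
  def ask[U](f: T => U): U = {
    val result = new CompletableFuture[U]()
    ops.put { t =>
      try result.complete(f(t))
      catch { case NonFatal(e) => result.completeExceptionally(e) }
      ()
    }
    result.get()
  }

  def stop(): Unit = worker.interrupt()
}

// Example logic with unsynchronized mutable state, safe only inside the actor.
class Counter {
  private var n = 0
  def inc(): Unit = n += 1
  def get: Int = n
}
```

Since operations are queued in order, an ask scheduled after a series of tells observes all of their effects.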

Having the chatActor: ActorRef[ChatRoom], we can implement the WebSocket logic. Once again, we're using the fact that each WebSocket gets its own concurrency scope, and we start a background thread to pass any incoming messages to the central actor:

def chatProcessor(
    a: ActorRef[ChatRoom]): OxStreams.Pipe[Message, Message] =
  incoming => {
    val member = ChatMember.create // create id, member channel

    a.tell(_.connected(member)) // register channel

    fork {
      incoming.foreach { msg =>
        a.tell(_.incoming(msg)) // schedule broadcast in the actor
      }
    }

    releaseAfterScope {
      member.channel.done()
      a.tell(_.disconnected(member)) // ensure cleanup
    }

    member.channel // return the channel with outgoing messages
  }

The new feature that we're using here is specifying a resource that should be released when the scope ends, using releaseAfterScope. Whenever the WebSocket's concurrency scope ends (either because the WS gets closed by the client, or because there's a server-side exception), this code is guaranteed to run. In the cleanup, we mark the member's channel as done and deregister the member from the global chat state.
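The guarantee behind releaseAfterScope can be modelled with try/finally: release actions registered while the scope's body runs are executed when it ends, whether it returned normally or threw. The ReleaseScope class and its onClose method are hypothetical names, not Ox's API; running actions in reverse registration order is a design choice assumed here, and, unlike a production implementation, a failing action would prevent the remaining ones from running.

```scala
import scala.collection.mutable.ListBuffer

// Hypothetical sketch of scope-bound cleanup (not Ox's releaseAfterScope):
// actions registered via onClose run when the scope's body ends, normally or
// exceptionally, in reverse registration order.
final class ReleaseScope {
  private val releases = ListBuffer.empty[() => Unit]

  // Register a cleanup action; it is not run until the scope ends.
  def onClose(release: => Unit): Unit = releases.prepend(() => release)

  // Run all registered actions, most recently registered first.
  // Simplification: if one action throws, the remaining ones are skipped.
  def runReleases(): Unit = releases.foreach(release => release())
}

def withReleaseScope[T](body: ReleaseScope => T): T = {
  val scope = new ReleaseScope
  try body(scope)
  finally scope.runReleases()
}
```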

Both the fork and releaseAfterScope methods manipulate the concurrency scope, and they return almost immediately, allowing the chatProcessor function to complete and return the outgoing channel. Then, the Tapir<->Ox integration will read messages from that channel and send them over the WebSocket.

The last step is starting the actor and the server. Remember that we need to start the actor in a concurrency scope, and that scope needs to live as long as the server!

@main def chatWsServer(): Unit =
  supervised {
    val chatActor = Actor.create(new ChatRoom)
    val chatServerEndpoint = chatEndpoint
      .serverLogicSuccess[Id](_ => chatProcessor(chatActor))
    NettySyncServer().addEndpoint(chatServerEndpoint).startAndWait()
  }

The entire scala-cli runnable example (75 lines of code) can be found here. As a bonus, although it's not essential for this example, you could add AsyncAPI documentation, generated from the endpoint description we've created, using just 3 additional lines of code.

To compare the approach above with a purely functional one, you might be interested in another article on a similar topic, "Functional pancakes".

And that's it: we've got a concise, structured, and safe implementation of a WebSocket chat. You can find out more about Ox and Tapir in their documentation. And please tell us what you think in the comments!

