Marcel Blijleven

Published on

Discriminating Unions and Type Narrowing

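A discriminating union (also known as a discriminated or tagged union) is a union of types whose members can be told apart by one or more of their properties, the discriminator(s). A discriminator is usually a property with a literal type, and checking it, for example with a type guard, narrows the union down to a single member.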

In this post I’ll show some examples of type narrowing in Python before moving on to an example that uses discriminating unions in Pydantic to parse .env files.

Basic example

Take a look at the following code, which defines two data structures. Both have the property security_protocol, which is the discriminator. Their other properties differ and are not shared between the two.

from pathlib import Path
from typing import Literal, TypeAlias


class SSLAuthentication:
    security_protocol: Literal["SSL"]
    certificate_location: Path


class SASLSSLAuthentication:
    security_protocol: Literal["SASL_SSL"]
    sasl_mechanism: Literal["PLAIN", "GSSAPI"]


Authentication: TypeAlias = SSLAuthentication | SASLSSLAuthentication

It’s easy to narrow down the type using the built-in isinstance function, which also gives you the correct code completions for the narrowed type.

def do_something(auth: Authentication) -> str:
    if isinstance(auth, SSLAuthentication):
        return auth.certificate_location.name
    elif isinstance(auth, SASLSSLAuthentication):
        return auth.sasl_mechanism

But sometimes you want to check the type based on something else, for example a property value. Whether equality narrowing on Authentication.security_protocol gives you narrowed type hints depends on your tooling. Pyright (and Pylance, which is built on it) narrows a union when you compare a Literal-typed attribute, much like TypeScript does, but not every type checker or IDE does this. With one that doesn't, writing the following function gives you no narrowed type hints: when you dot into auth inside the first branch, you’ll still see sasl_mechanism, which is not a property of SSLAuthentication.

def do_something(auth: Authentication) -> str:
    if auth.security_protocol == "SSL":
        return auth.certificate_location.name
    elif auth.security_protocol == "SASL_SSL":
        return auth.sasl_mechanism

Type guards

This is where a user-defined type guard is useful. TypeGuard was introduced in Python 3.10 via PEP 647. When a function has return type TypeGuard[int] and returns True, the type checker treats the argument passed to it as an int. TypeGuard only narrows in the True branch; when the function returns False, the argument keeps its original type. By using user-defined type guards in your conditions, you get proper code completion and type hinting when narrowing down a type by property values.

from typing import TypeGuard


def is_ssl(auth: Authentication) -> TypeGuard[SSLAuthentication]:
    return auth.security_protocol == "SSL"


def is_sasl(auth: Authentication) -> TypeGuard[SASLSSLAuthentication]:
    return auth.security_protocol == "SASL_SSL"


def do_something_better(auth: Authentication) -> str:
    if is_ssl(auth):
        return auth.certificate_location.name
    elif is_sasl(auth):
        return auth.sasl_mechanism

This becomes even more useful when we split the SASLSSLAuthentication class by sasl_mechanism into two new classes.

class PlainAuthentication:
    security_protocol: Literal["SASL_SSL"]
    sasl_mechanism: Literal["PLAIN"]


class GSSAPIAuthentication:
    security_protocol: Literal["SASL_SSL"]
    sasl_mechanism: Literal["GSSAPI"]


Authentication: TypeAlias = (
    SSLAuthentication | PlainAuthentication | GSSAPIAuthentication
)

def is_plain(auth: Authentication) -> TypeGuard[PlainAuthentication]:
    return auth.security_protocol == "SASL_SSL" and auth.sasl_mechanism == "PLAIN"


def is_gssapi(auth: Authentication) -> TypeGuard[GSSAPIAuthentication]:
    return auth.security_protocol == "SASL_SSL" and auth.sasl_mechanism == "GSSAPI"

Discriminating unions

Pydantic offers a simpler way to do this with discriminated unions: you specify a discriminator on a field. This reduces the amount of code you have to write yourself, which makes it easier to maintain should you ever add more authentication methods.

There are two ways of using a discriminator with Pydantic. The simplest is to pass the name of the property to the discriminator keyword argument of Field.

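from typing import TypeAlias
from pydantic import BaseModel, Field

# The authentication classes need to subclass pydantic.BaseModel here.
Authentication: TypeAlias = SSLAuthentication | PlainAuthentication | GSSAPIAuthentication

class Client(BaseModel):
    auth: Authentication = Field(..., discriminator="security_protocol")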
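
To see why a single discriminator can't cover this union, here is a hand-rolled sketch of what any discriminator lookup has to do: map each literal value to exactly one class. build_lookup is a hypothetical helper for illustration, not Pydantic's implementation, and the classes are dataclasses so the sketch only needs the standard library.

```python
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, get_args, get_type_hints


@dataclass
class SSLAuthentication:
    security_protocol: Literal["SSL"]
    certificate_location: Path


@dataclass
class PlainAuthentication:
    security_protocol: Literal["SASL_SSL"]
    sasl_mechanism: Literal["PLAIN"]


@dataclass
class GSSAPIAuthentication:
    security_protocol: Literal["SASL_SSL"]
    sasl_mechanism: Literal["GSSAPI"]


def build_lookup(*classes: type, field: str) -> dict[Any, type]:
    """Map each Literal value of `field` to the single class that declares it."""
    lookup: dict[Any, type] = {}
    for cls in classes:
        # get_args(Literal["SSL"]) == ("SSL",)
        for value in get_args(get_type_hints(cls)[field]):
            if value in lookup:
                raise TypeError(
                    f"{value!r} for {field!r} maps to both "
                    f"{lookup[value].__name__} and {cls.__name__}"
                )
            lookup[value] = cls
    return lookup


# Works: each security_protocol value identifies exactly one class.
print(build_lookup(SSLAuthentication, PlainAuthentication, field="security_protocol"))

# Works too: sasl_mechanism tells the two SASL_SSL classes apart.
print(build_lookup(PlainAuthentication, GSSAPIAuthentication, field="sasl_mechanism"))

# Fails: "SASL_SSL" is declared by two classes, so one key can't choose between them.
try:
    build_lookup(
        SSLAuthentication, PlainAuthentication, GSSAPIAuthentication,
        field="security_protocol",
    )
except TypeError as exc:
    print(exc)
```

Splitting the lookup into two levels, first on security_protocol and then on sasl_mechanism, is exactly what a nested discriminator does.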

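This only discriminates by one property, though, and PlainAuthentication and GSSAPIAuthentication both use "SASL_SSL", so there is no way to tell them apart. In fact, Pydantic refuses to build this model: when the class is defined, it raises an error saying that the value 'SASL_SSL' for discriminator 'security_protocol' is mapped to multiple choices. The second way of using a discriminating union, with Annotated and Discriminator, solves this by nesting a second discriminator.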

from typing import Annotated, TypeAlias
from pydantic import BaseModel, ConfigDict, Discriminator

SASLAuthentications: TypeAlias = PlainAuthentication | GSSAPIAuthentication


class Client(BaseModel):
    auth: Annotated[
        SSLAuthentication | Annotated[
            SASLAuthentications,
            Discriminator("sasl_mechanism")
        ],
        Discriminator("security_protocol")
    ]

    model_config = ConfigDict(
        extra="ignore"
    )

This is especially useful when using Pydantic to validate user data, for example from a .env file using Pydantic Settings. You won’t have to write logic that checks values yourself to decide which authentication method to instantiate; Pydantic does this automatically.

Env file example

Take a look at the following script. It defines a Settings model with a property auth whose type is a discriminating union. When instantiated, the Settings model reads the .env file and populates auth based on the values in that file. Below the script you’ll find three sets of input and output, demonstrating that this simple code handles several different user settings without any extra logic on your part.

from pathlib import Path
from typing import TypeAlias, Literal, Annotated

from pydantic import BaseModel, Discriminator
from pydantic_settings import BaseSettings, SettingsConfigDict


class SSLAuthentication(BaseModel):
    security_protocol: Literal["SSL"]
    certificate_location: Path
    whatami: str = "I am SSLAuthentication"


class PlainAuthentication(BaseModel):
    security_protocol: Literal["SASL_SSL"]
    sasl_mechanism: Literal["PLAIN"]
    whatami: str = "I am PlainAuthentication"


class GSSAPIAuthentication(BaseModel):
    security_protocol: Literal["SASL_SSL"]
    sasl_mechanism: Literal["GSSAPI"]
    whatami: str = "I am GSSAPIAuthentication"


SASLSSLAuthentication: TypeAlias = PlainAuthentication | GSSAPIAuthentication


class Settings(BaseSettings):
    title: str
    auth: Annotated[
        SSLAuthentication | Annotated[
            SASLSSLAuthentication,
            Discriminator("sasl_mechanism")
        ],
        Discriminator("security_protocol")
    ]

    model_config = SettingsConfigDict(
        case_sensitive=False,
        env_file=".env",
        env_file_encoding="utf-8",
        env_nested_delimiter="__",
        env_prefix="blog_",
        extra="ignore",
    )


if __name__ == "__main__":
    settings = Settings()
    print(settings.model_dump_json(indent=4))

Env SSL

BLOG_TITLE="Discriminating Unions"
BLOG_AUTH__SECURITY_PROTOCOL="SSL"
BLOG_AUTH__CERTIFICATE_LOCATION="/home/ca.crt"

Output SSL

{
    "title": "Discriminating Unions",
    "auth": {
        "security_protocol": "SSL",
        "certificate_location": "/home/ca.crt",
        "whatami": "I am SSLAuthentication"
    }
}

Env SASL_SSL, Plain

BLOG_TITLE="Discriminating Unions"
BLOG_AUTH__SECURITY_PROTOCOL="SASL_SSL"
BLOG_AUTH__SASL_MECHANISM="PLAIN"

Output SASL_SSL, Plain

{
    "title": "Discriminating Unions",
    "auth": {
        "security_protocol": "SASL_SSL",
        "sasl_mechanism": "PLAIN",
        "whatami": "I am PlainAuthentication"
    }
}

Env SASL_SSL, GSSAPI

BLOG_TITLE="Discriminating Unions"
BLOG_AUTH__SECURITY_PROTOCOL="SASL_SSL"
BLOG_AUTH__SASL_MECHANISM="GSSAPI"

Output SASL_SSL, GSSAPI

{
    "title": "Discriminating Unions",
    "auth": {
        "security_protocol": "SASL_SSL",
        "sasl_mechanism": "GSSAPI",
        "whatami": "I am GSSAPIAuthentication"
    }
}
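
To make env_prefix, env_nested_delimiter and case_sensitive=False more concrete, here is a rough standard-library approximation of how the flat variables from the SSL example become the nested dict that Pydantic then validates against the discriminating union. It assumes the .env file has already been read into a dict; nest_env is a hypothetical helper, and the real pydantic-settings implementation covers many more cases.

```python
from typing import Any


def nest_env(env: dict[str, str], prefix: str = "blog_", delimiter: str = "__") -> dict[str, Any]:
    """Approximate env_prefix + env_nested_delimiter with case_sensitive=False."""
    result: dict[str, Any] = {}
    for raw_key, value in env.items():
        key = raw_key.lower()  # case_sensitive=False
        if not key.startswith(prefix):
            continue  # variables without the prefix are ignored
        *parents, leaf = key[len(prefix):].split(delimiter)
        target = result
        for part in parents:
            # BLOG_AUTH__X and BLOG_AUTH__Y end up in the same nested "auth" dict.
            target = target.setdefault(part, {})
        target[leaf] = value
    return result


env = {
    "BLOG_TITLE": "Discriminating Unions",
    "BLOG_AUTH__SECURITY_PROTOCOL": "SSL",
    "BLOG_AUTH__CERTIFICATE_LOCATION": "/home/ca.crt",
    "HOME": "/home",
}
print(nest_env(env))
# {'title': 'Discriminating Unions', 'auth': {'security_protocol': 'SSL', 'certificate_location': '/home/ca.crt'}}
```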

That’s it for now!