Creating a custom choice field for enums in Django REST Framework


The problem with ChoiceField

In Django REST Framework, there is a field called ChoiceField that you can use when a field in your REST API accepts only a fixed set of values. For example, let's imagine you were creating an API to convert colors to their respective RGB values. This is what your serializers.py file would look like:

from rest_framework.fields import ChoiceField
from rest_framework.serializers import Serializer


class ColorSerializer(Serializer):
    color = ChoiceField(choices=["RED", "GREEN", "BLUE"])
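With a flat list like this, ChoiceField accepts an input only if its string form matches one of the choices, and reports an invalid choice otherwise. The snippet below is a simplified, framework-free model of that check, not DRF's actual code. `validate_choice` and `InvalidChoice` are hypothetical names used only for illustration, and details such as `allow_blank` handling are left out.

```python
from typing import Any, Dict, List


class InvalidChoice(ValueError):
    """Hypothetical stand-in for DRF's "invalid_choice" validation error."""


def validate_choice(data: Any, choices: List[str]) -> str:
    # Map the string form of each choice back to the choice itself,
    # so the input "RED" matches the choice "RED".
    lookup: Dict[str, str] = {str(choice): choice for choice in choices}
    try:
        return lookup[str(data)]
    except KeyError:
        raise InvalidChoice(f'"{data}" is not a valid choice.') from None


COLORS = ["RED", "GREEN", "BLUE"]
print(validate_choice("BLUE", COLORS))  # BLUE
```

Note that the match is exact, so an input like "blue" would be rejected.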

And views.py:

from typing import Optional, Tuple

from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.views import APIView

from core.serializers import ColorSerializer


class Color2RGBAPIView(APIView):
    def post(self, request: Request) -> Response:
        serializer = ColorSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        color = serializer.validated_data["color"]
        return Response({"rgb": self.get_rgb(color)})

    def get_rgb(self, color: str) -> Optional[Tuple[int, int, int]]:
        rgb = None
        match color:
            case "RED":
                rgb = (255, 0, 0)
            case "GREEN":
                rgb = (0, 255, 0)
            case "BLUE":
                rgb = (0, 0, 255)
        return rgb

While we're at it, let's write a test to make sure everything works as expected. Your tests.py file should look like this:

from django.test import TestCase
from django.urls import reverse
from rest_framework.status import HTTP_200_OK


class Color2RGBAPIViewTest(TestCase):
    def test_color_2_rgb(self):
        url = reverse("color2rgb")
        payload = {"color": "BLUE"}
        response = self.client.post(url, data=payload)
        self.assertEqual(response.status_code, HTTP_200_OK)
        self.assertDictEqual(response.json(), {"rgb": [0, 0, 255]})
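Notice that get_rgb returns a tuple, yet the test compares against a list. JSON has no tuple type, so a tuple is encoded as a JSON array, and response.json() decodes that array back into a Python list. The standard-library json module shows the same round trip:

```python
import json

rgb = (0, 0, 255)  # what get_rgb returns for BLUE

encoded = json.dumps({"rgb": rgb})  # tuples serialize as JSON arrays
decoded = json.loads(encoded)       # arrays deserialize as Python lists

print(encoded)                          # {"rgb": [0, 0, 255]}
print(decoded == {"rgb": [0, 0, 255]})  # True
print(decoded == {"rgb": (0, 0, 255)})  # False: a list never equals a tuple
```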

Let's run the tests:

(venv) django-enum-choice-field $ python manage.py test
Found 1 test(s).
Creating test database for alias 'default'...
System check identified no issues (0 silenced).
.
----------------------------------------------------------------------
Ran 1 test in 0.003s

OK
Destroying test database for alias 'default'...

So tests pass. Awesome!

Code at this point can be found in the feature/add-color2rgb-api branch here.

This works fine... until you decide to replace the BLUE color with YELLOW. To do that, you'll have to find the string literal "BLUE" everywhere in the following files:

  • serializers.py

  • views.py

  • tests.py

and replace it with "YELLOW", updating its RGB value as well. Refactoring like this is risky and error-prone: a single misspelled occurrence raises no error anywhere, it just makes the code quietly misbehave. So what do we do?

Enums to the rescue

Create an enums.py file and put the following code in it:

from enum import Enum, auto


class Color(Enum):
    RED = auto()
    GREEN = auto()
    BLUE = auto()
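auto() lets Python choose the values: for a plain Enum it assigns consecutive integers starting at 1. The values don't really matter here, because the API only ever exposes member names. A quick check in a standalone script:

```python
from enum import Enum, auto


class Color(Enum):
    RED = auto()
    GREEN = auto()
    BLUE = auto()


# Each member has a name (the identifier) and a value (chosen by auto()).
for member in Color:
    print(member.name, member.value)
# RED 1
# GREEN 2
# BLUE 3
```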

Then import it in serializers.py and refactor as follows:

from rest_framework.fields import ChoiceField
from rest_framework.serializers import Serializer

from core.enums import Color  # new line


class ColorSerializer(Serializer):
    color = ChoiceField(choices=Color._member_names_)  # changed line

And views.py:

from typing import Optional, Tuple

from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.views import APIView

from core.enums import Color  # new line
from core.serializers import ColorSerializer


class Color2RGBAPIView(APIView):
    def post(self, request: Request) -> Response:
        serializer = ColorSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        color = Color[serializer.validated_data["color"]]  # changed line
        return Response({"rgb": self.get_rgb(color)})

    def get_rgb(
        self,
        color: Color,  # changed line
    ) -> Optional[Tuple[int, int, int]]:
        rgb = None
        match color:
            case Color.RED:  # changed line
                rgb = (255, 0, 0)
            case Color.GREEN:  # changed line
                rgb = (0, 255, 0)
            case Color.BLUE:  # changed line
                rgb = (0, 0, 255)
        return rgb

Lastly, update tests.py to use the newly created enum:

from django.test import TestCase
from django.urls import reverse
from rest_framework.status import HTTP_200_OK

from core.enums import Color  # new line


class Color2RGBAPIViewTest(TestCase):
    def test_color_2_rgb(self):
        url = reverse("color2rgb")
        payload = {"color": Color.BLUE.name}  # changed line
        response = self.client.post(url, data=payload)
        self.assertEqual(response.status_code, HTTP_200_OK)
        self.assertDictEqual(response.json(), {"rgb": [0, 0, 255]})

Now, if you misspell a color, your IDE's linter will flag it with a squiggly underline, and your tests will also fail.
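The failure is also loud at runtime: a misspelled member is an attribute that doesn't exist, so Python raises AttributeError at the point of use instead of silently producing a wrong result. A standalone illustration with a deliberate typo:

```python
from enum import Enum, auto


class Color(Enum):
    RED = auto()
    GREEN = auto()
    BLUE = auto()


try:
    payload = {"color": Color.BLEU.name}  # typo: BLEU instead of BLUE
except AttributeError as exc:
    print("AttributeError:", exc)
```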

After fixing the attribute names, run the tests again:

(venv) django-enum-choice-field $ python manage.py test
Found 1 test(s).
Creating test database for alias 'default'...
System check identified no issues (0 silenced).
.
----------------------------------------------------------------------
Ran 1 test in 0.003s

OK
Destroying test database for alias 'default'...

Sure enough, everything still works as before.

Code at this point can be found in the improvement/add-color-enum branch here.

Can we do even better?

Everything looks okay at this point, but there are two minor problems.

First, serializers.py relies on _member_names_, a private attribute of the Color enum. With many ChoiceFields in a project, repeating it everywhere gets ugly.
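There are public ways to get the member names: iterating over the enum, or reading its __members__ mapping. They differ only when the enum has aliases (a second name bound to an existing value). The CRIMSON alias below is hypothetical, added only to show that difference.

```python
from enum import Enum


class Color(Enum):
    RED = 1
    GREEN = 2
    BLUE = 3
    CRIMSON = 1  # hypothetical alias: same value as RED


# Iterating over an Enum yields canonical members only; aliases are skipped.
names_from_iteration = [member.name for member in Color]

# __members__ maps every name, aliases included, to its member.
names_from_members = list(Color.__members__)

print(names_from_iteration)  # ['RED', 'GREEN', 'BLUE']
print(names_from_members)    # ['RED', 'GREEN', 'BLUE', 'CRIMSON']
print(Color.__members__["CRIMSON"] is Color.RED)  # True
```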

Second, notice this line in views.py:

color = Color[serializer.validated_data["color"]]

serializer.validated_data["color"] returns a str, which I then convert to a Color member with the Color[] syntax. This is probably fine in a couple of places, but if you plan to use enums a lot, this syntax will become annoying.
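Color[...] looks a member up by its name, which is what the API sends; calling Color(...) looks it up by value instead. With auto() values the two are easy to mix up, as this standalone snippet shows:

```python
from enum import Enum, auto


class Color(Enum):
    RED = auto()
    GREEN = auto()
    BLUE = auto()


print(Color["BLUE"])  # Color.BLUE: square brackets look up by name
print(Color(3))       # Color.BLUE: calling the class looks up by value

# Mixing the two up raises instead of returning a wrong member.
try:
    Color("BLUE")  # "BLUE" is a name, not a value
except ValueError as exc:
    print("ValueError:", exc)

try:
    Color[3]  # 3 is a value, not a name
except KeyError as exc:
    print("KeyError:", exc)
```

In the view, the KeyError case can't occur, because ChoiceField has already rejected unknown names.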

What if your serializer could work some magic and give you the enum member directly, instead of returning a string that you have to convert explicitly?

Let's try that, shall we?

Creating custom EnumChoiceField

The idea is to inherit from Django REST Framework's ChoiceField and create a custom field called EnumChoiceField that does this magic for us. Create a new file called fields.py:

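from enum import Enum
from typing import Type

from rest_framework.fields import ChoiceField


class EnumChoiceField(ChoiceField):
    """A ChoiceField that accepts enum member names and yields enum members."""

    def __init__(self, enum: Type[Enum], **kwargs):
        # Passing something other than an Enum class is a programming error,
        # not invalid client input, so raise TypeError rather than self.fail().
        if not (isinstance(enum, type) and issubclass(enum, Enum)):
            raise TypeError("enum must be a subclass of the builtin Enum class")
        self.enum = enum
        # Iterating over an Enum yields its members (aliases excluded),
        # so the private _member_names_ attribute isn't needed.
        super().__init__(choices=[member.name for member in enum], **kwargs)

    def to_internal_value(self, data):
        # ChoiceField validates the name and returns it as a string;
        # convert it to the corresponding enum member.
        return self.enum[super().to_internal_value(data)]

    def to_representation(self, value):
        # Serialize an enum member back to its name.
        return value.name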
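Stripped of DRF, the field performs two conversions: incoming data goes from a name string to an enum member (after ChoiceField has checked the name), and outgoing values go from a member back to its name. The hypothetical helper functions below model just that round trip so it can run without a Django project; they are an illustration, not part of DRF.

```python
from enum import Enum, auto
from typing import Type, TypeVar

E = TypeVar("E", bound=Enum)


class Color(Enum):
    RED = auto()
    GREEN = auto()
    BLUE = auto()


def to_internal_value(enum: Type[E], data: str) -> E:
    # Mirrors EnumChoiceField.to_internal_value: name -> member.
    # In the real field, ChoiceField has already rejected unknown names.
    return enum[data]


def to_representation(value: Enum) -> str:
    # Mirrors EnumChoiceField.to_representation: member -> name.
    return value.name


color = to_internal_value(Color, "BLUE")
print(color)                     # Color.BLUE
print(to_representation(color))  # BLUE
```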

And then have your serializers.py look something like this:

from rest_framework.serializers import Serializer

from core.enums import Color
from core.fields import EnumChoiceField  # new line


class ColorSerializer(Serializer):
    color = EnumChoiceField(enum=Color)  # updated line

Now we're not using private attributes anymore in our serializer.

Finally, this line in views.py

color = Color[serializer.validated_data["color"]]

can be refactored as:

color = serializer.validated_data["color"]

The color variable is already a Color member, so you don't have to convert the string explicitly. Neat, isn't it?

Let's run the tests again to make sure nothing is broken.

(venv) django-enum-choice-field $ python manage.py test
Found 1 test(s).
Creating test database for alias 'default'...
System check identified no issues (0 silenced).
.
----------------------------------------------------------------------
Ran 1 test in 0.003s

OK
Destroying test database for alias 'default'...

And all tests pass!

Code at this point can be found in the improvement/add-enum-choice-field branch here.

To use EnumChoiceField in your own projects, simply copy fields.py and use the field as shown in serializers.py.

All the code is available on GitHub.