Creating a custom choice field for enums in Django REST Framework
The problem with ChoiceField
In Django REST Framework, there is a field called ChoiceField which can be used when a field in your REST API accepts only a fixed set of choices. For example, let's imagine you were creating an API to convert colors to their respective RGB values. This is what your serializers.py file would look like:
from rest_framework.fields import ChoiceField
from rest_framework.serializers import Serializer


class ColorSerializer(Serializer):
    color = ChoiceField(choices=["RED", "GREEN", "BLUE"])
And here is views.py:
from typing import Optional, Tuple

from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.views import APIView

from core.serializers import ColorSerializer


class Color2RGBAPIView(APIView):
    def post(self, request: Request) -> Response:
        serializer = ColorSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        color = serializer.validated_data["color"]
        return Response({"rgb": self.get_rgb(color)})

    def get_rgb(self, color: str) -> Optional[Tuple[int, int, int]]:
        rgb = None
        match color:
            case "RED":
                rgb = (255, 0, 0)
            case "GREEN":
                rgb = (0, 255, 0)
            case "BLUE":
                rgb = (0, 0, 255)
        return rgb
While we're at it, let's write a test to make sure everything works as expected. Your tests.py file should look like this:
from django.test import TestCase
from django.urls import reverse
from rest_framework.status import HTTP_200_OK


class Color2RGBAPIViewTest(TestCase):
    def test_color_2_rgb(self):
        url = reverse("color2rgb")
        payload = {"color": "BLUE"}
        response = self.client.post(url, data=payload)
        self.assertEqual(response.status_code, HTTP_200_OK)
        self.assertDictEqual(response.json(), {"rgb": [0, 0, 255]})
Let's run the tests:
(venv) django-enum-choice-field $ python manage.py test
Found 1 test(s).
Creating test database for alias 'default'...
System check identified no issues (0 silenced).
.
----------------------------------------------------------------------
Ran 1 test in 0.003s
OK
Destroying test database for alias 'default'...
So tests pass. Awesome!
The code at this point can be found in the feature/add-color2rgb-api branch.
And this would work okay... until you decide that you want to replace the BLUE color with YELLOW. For that, you'll have to find the string literal "BLUE" everywhere in serializers.py, views.py, and tests.py and replace it with "YELLOW", and update its RGB value as well. Refactoring like this is risky and error-prone. If you misspell any occurrence, your code will fall apart. So what do we do?
Enums to the rescue
Create an enums.py file and put the following code in it:
from enum import Enum, auto


class Color(Enum):
    RED = auto()
    GREEN = auto()
    BLUE = auto()
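As a quick aside, the snippet below is a small standalone sketch of how this enum behaves, using only the standard library. auto() assigns integer values starting at 1, but the API in this article works with member names rather than values, so the name-based lookups are the ones that matter here.

```python
from enum import Enum, auto


class Color(Enum):
    RED = auto()
    GREEN = auto()
    BLUE = auto()


# auto() numbers the members 1, 2, 3 in definition order.
print(Color.BLUE.value)  # 3

# The name is the string the API accepts and returns.
print(Color.BLUE.name)  # BLUE

# Square brackets look a member up by name; a call looks it up by value.
print(Color["BLUE"] is Color.BLUE)  # True
print(Color(3) is Color.BLUE)  # True
```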
Then import it in serializers.py and refactor as follows:
from rest_framework.fields import ChoiceField
from rest_framework.serializers import Serializer

from core.enums import Color  # new line


class ColorSerializer(Serializer):
    color = ChoiceField(choices=Color._member_names_)  # changed line
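For reference, here is a standalone sketch of what _member_names_ contains, next to two equivalents built from the documented Enum API: iterating over the class, and the __members__ mapping. Either could be passed to choices instead. This is only an illustration of the options, not a claim about which one the project should use.

```python
from enum import Enum, auto


class Color(Enum):
    RED = auto()
    GREEN = auto()
    BLUE = auto()


# The underscore-prefixed attribute used in the serializer above.
names_private = list(Color._member_names_)

# Public alternative 1: iterate over the enum class.
names_by_iteration = [member.name for member in Color]

# Public alternative 2: keys of the __members__ mapping.
# Note: __members__ also lists aliases, which this enum does not have.
names_from_mapping = list(Color.__members__)

print(names_private)       # ['RED', 'GREEN', 'BLUE']
print(names_by_iteration)  # ['RED', 'GREEN', 'BLUE']
print(names_from_mapping)  # ['RED', 'GREEN', 'BLUE']
```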
And views.py:
from typing import Optional, Tuple

from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.views import APIView

from core.enums import Color  # new line
from core.serializers import ColorSerializer


class Color2RGBAPIView(APIView):
    def post(self, request: Request) -> Response:
        serializer = ColorSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        color = Color[serializer.validated_data["color"]]  # changed line
        return Response({"rgb": self.get_rgb(color)})

    def get_rgb(
        self,
        color: Color,  # changed line
    ) -> Optional[Tuple[int, int, int]]:
        rgb = None
        match color:
            case Color.RED:  # changed line
                rgb = (255, 0, 0)
            case Color.GREEN:  # changed line
                rgb = (0, 255, 0)
            case Color.BLUE:  # changed line
                rgb = (0, 0, 255)
        return rgb
Lastly, you can update tests.py to use the newly created enum:
from django.test import TestCase
from django.urls import reverse
from rest_framework.status import HTTP_200_OK

from core.enums import Color  # new line


class Color2RGBAPIViewTest(TestCase):
    def test_color_2_rgb(self):
        url = reverse("color2rgb")
        payload = {"color": Color.BLUE.name}  # changed line
        response = self.client.post(url, data=payload)
        self.assertEqual(response.status_code, HTTP_200_OK)
        self.assertDictEqual(response.json(), {"rgb": [0, 0, 255]})
Now, if you misspell a color, the linter in your IDE will flag it with squiggly lines, and your tests will also fail.
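The same safety net exists at runtime. In this standalone sketch (the misspelling BLEU is just an illustrative example), a wrong attribute raises AttributeError and a wrong name lookup raises KeyError, so a typo fails loudly instead of slipping through as a string that never matches.

```python
from enum import Enum, auto


class Color(Enum):
    RED = auto()
    GREEN = auto()
    BLUE = auto()


errors = []

try:
    Color.BLEU  # misspelled attribute access
except AttributeError as exc:
    errors.append(type(exc).__name__)

try:
    Color["BLEU"]  # misspelled lookup by name
except KeyError as exc:
    errors.append(type(exc).__name__)

print(errors)  # ['AttributeError', 'KeyError']
```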
If you fix all the attribute names and run the tests again:
(venv) django-enum-choice-field $ python manage.py test
Found 1 test(s).
Creating test database for alias 'default'...
System check identified no issues (0 silenced).
.
----------------------------------------------------------------------
Ran 1 test in 0.003s
OK
Destroying test database for alias 'default'...
Sure enough, everything still works as before.
The code at this point can be found in the improvement/add-color-enum branch.
Can we do even better?
Everything looks okay at this point but there are two minor problems.
First, we are using the private attribute _member_names_ of the Color enum in serializers.py. If we had many ChoiceField fields, repeating it everywhere would look ugly.
Second, notice this line in views.py:
color = Color[serializer.validated_data["color"]]
The serializer.validated_data["color"] part returns a str, and then I am using the Color[] syntax to convert it to a Color enum member. This is probably okay if you have it in a couple of places, but if you plan to use enums a lot, this syntax will become annoying.
What if your serializer could work some magic and provide you with the enum directly instead of returning a string and having you convert it to an enum explicitly?
Let's try that, shall we?
Creating custom EnumChoiceField
So the idea is to inherit from Django REST Framework's ChoiceField and create a custom field called EnumChoiceField which can do this magic for us. Create a new file called fields.py:
from enum import Enum
from typing import Type

from rest_framework.fields import ChoiceField


class EnumChoiceField(ChoiceField):
    error_messages = {
        "invalid_enum_class": "enum must be a subclass of builtin Enum class",
    }

    def __init__(self, enum: Type[Enum], **kwargs):
        if not issubclass(enum, Enum):
            self.fail("invalid_enum_class")
        self.enum = enum
        # The member names ("RED", "GREEN", ...) become the accepted choices.
        super().__init__(enum._member_names_, **kwargs)

    def to_internal_value(self, data):
        # Validate the incoming string, then look the member up by name.
        return self.enum[super().to_internal_value(data)]

    def to_representation(self, value):
        # Serialize a member back to its name.
        return value.name
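To see the conversion in isolation, here is a framework-free sketch of the round trip the two overridden methods perform. It deliberately leaves out Django REST Framework: the functions name_to_member and member_to_name are hypothetical stand-ins for to_internal_value and to_representation, and the ValueError stands in for the validation error that ChoiceField would raise on an invalid choice.

```python
from enum import Enum, auto
from typing import Type


class Color(Enum):
    RED = auto()
    GREEN = auto()
    BLUE = auto()


def name_to_member(enum: Type[Enum], data: str) -> Enum:
    """Stand-in for to_internal_value: check the choice, then look up by name."""
    if data not in enum.__members__:
        raise ValueError(f'"{data}" is not a valid choice.')
    return enum[data]


def member_to_name(value: Enum) -> str:
    """Stand-in for to_representation: turn a member back into its name."""
    return value.name


color = name_to_member(Color, "BLUE")
print(color is Color.BLUE)    # True
print(member_to_name(color))  # BLUE
```

The point of the design is that the string only exists at the edges of the API; everything between the serializer and the response works with Color members.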
And then have your serializers.py look something like this:
from rest_framework.serializers import Serializer

from core.enums import Color
from core.fields import EnumChoiceField  # new line


class ColorSerializer(Serializer):
    color = EnumChoiceField(enum=Color)  # updated line
Now we're not using private attributes anymore in our serializer.
And finally this line in views.py
color = Color[serializer.validated_data["color"]]
can be refactored as:
color = serializer.validated_data["color"]
The color variable will already be a Color member, so you don't have to convert the string explicitly. Neat, isn't it?
Let's run the tests again to make sure nothing is broken.
(venv) django-enum-choice-field $ python manage.py test
Found 1 test(s).
Creating test database for alias 'default'...
System check identified no issues (0 silenced).
.
----------------------------------------------------------------------
Ran 1 test in 0.003s
OK
Destroying test database for alias 'default'...
And all tests pass!
The code at this point can be found in the improvement/add-enum-choice-field branch.
To use EnumChoiceField in your projects, you can simply copy what's in fields.py and use it as I have in serializers.py.
All the code is available on GitHub.