I found this excellent book about implementing a ray tracer in C++: raytracing.github.io/books/RayTracingInOneWeekend.html. As you may have discerned from the title, I’ve attempted to implement this ray tracer in Python. Here’s the source code: github.com/a-vinod/ray-tracer2

The astute among you may realize that this is a fool’s errand. The program’s runtime scales as O(height*width*aa_samples*recursion_depth), so of course this will be slow in Python! But it was still very interesting to see how quickly I hit the limits of a reasonable wait between iterations of generating an image. Also, it was worth learning about ray tracing without the pleasure of debugging segfaults.
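
To put that bound in numbers, here is a back-of-the-envelope sketch. The function name is hypothetical; the 400x225 size and 10 samples match the profiling run later in this post, and 50 is the depth limit passed to ray_color() in render().

```python
def ray_budget(width: int, height: int, aa_samples: int, max_depth: int) -> tuple[int, int]:
    """Return (primary rays, rough upper bound on ray_color calls) for one render."""
    # One primary ray per anti-aliasing sample per pixel
    primary = width * height * aa_samples
    # Each primary ray can bounce up to about max_depth times
    upper_bound = primary * max_depth
    return primary, upper_bound


primary, upper_bound = ray_budget(width=400, height=225, aa_samples=10, max_depth=50)
print(f"{primary:,} primary rays, up to about {upper_bound:,} ray_color calls")
```

In practice most rays presumably miss or terminate after a bounce or two, so the real call count sits far below that bound (the profile later shows roughly 2.3 million ray_color calls).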

As for packages, I’m limiting myself to numpy for math and cv2 for image import/export. I also used some convenience packages that I’ll go into later. Yes, this isn’t 🍦 Python; there’s some chocolate sauce too.

High-Level Program Flow

ray_tracer2/camera.py

def render(self, world: World) -> np.ndarray:
    ...
        for y in range(self.image_height):
            for x in range(self.image_width):
                ...
                for aa in range(self.aa_samples):
                    ...
                    color += self.ray_color(ray, world, 50)
                colors[y][x] = 255*color/self.aa_samples

This render() function is the top-level nested loop that populates the 2D numpy colors array with the RGB values for each pixel.

The ray_color() function traces the ray originating from the camera to each sampled point in each pixel in the viewport of the world. When the ray hits something in the world, the World::hit(Ray) function returns a HitRecord that has some information about the hit including the material to determine if/how the subsequent ray scatters.

World::hit(Ray) iterates through the different objects in the world (currently only spheres) and invokes Sphere::hit(Ray) on each one to compute the intersection of the ray and the sphere.
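
As a rough illustration of that loop, here is a minimal closest-hit sketch. It is not the repository’s code: FixedDistance and closest_hit are hypothetical names, and the sketch assumes each object’s hit() returns a distance t (or None on a miss) rather than a HitRecord.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple


@dataclass
class FixedDistance:
    """Hypothetical stand-in for a Sphere: always reports a hit at distance t."""
    name: str
    t: float

    def hit(self, ray: object, tmin: float, tmax: float) -> Optional[float]:
        # A real Sphere would solve the ray/sphere quadratic here
        return self.t if tmin < self.t < tmax else None


def closest_hit(objects: Sequence[FixedDistance], ray: object,
                tmin: float, tmax: float) -> Optional[Tuple[FixedDistance, float]]:
    """Return the nearest object hit by the ray and its distance, or None."""
    closest_t = tmax
    closest_obj = None
    for obj in objects:
        # Shrinking the upper bound rejects anything behind the current best hit
        t = obj.hit(ray, tmin, closest_t)
        if t is not None:
            closest_t, closest_obj = t, obj
    if closest_obj is None:
        return None
    return closest_obj, closest_t


scene = [FixedDistance("far", 5.0), FixedDistance("near", 2.0), FixedDistance("behind", -1.0)]
print(closest_hit(scene, ray=None, tmin=0.001, tmax=1_000_000))
```

Passing the current best distance as tmax is the design choice worth noting: every object still gets tested, but only a strictly closer hit can replace the record.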

ray_tracer2/hittable.py

    def hit(self, ray: Ray, tmin: float, tmax: float) -> (HitRecord, bool):
        """
        x^2 + y^2 + z^2 = r^2
        (Cx-x)^2 + (Cy-y)^2 + (Cz-z)^2 = r^2
        Linear algebrify this
        [Cx-x,Cy-y,Cz-z] . [Cx-x,Cy-y,Cz-z] = r^2
            C = [Cx,Cy,Cz]
            P = [x,y,z]
        (C-P).(C-P) = r^2
            P = origin + t*direction (our ray)
              = Q + t*d
        (C-(Q + t*d)).(C-(Q + t*d)) = r^2
        (-t*d + (C-Q)).(-t*d + (C-Q)) = r^2
        (d.d)*t^2 - 2*t*(d.(C-Q)) + (C-Q).(C-Q) = r^2
        (d.d)*t^2 + (-2*d).(C-Q)*t + (C-Q).(C-Q)-r^2 = 0
            a = d.d
            b = -2*(d.(C-Q))
            c = (C-Q).(C-Q)-r^2
        Then apply the quadratic formula to get t, which is the quantity
        scaling the direction of the ray. This effectively tells us the
        distances with respect to the origin that the ray intersects with
        this sphere!
        """
        a = np.dot(ray.direction, ray.direction)
        b = -2.0*np.dot(ray.direction, self.center - ray.origin)
        c = np.dot(self.center - ray.origin, self.center -
                   ray.origin)-(self.radius*self.radius)

        discriminant = b*b - 4*a*c
        ...
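
The rest of hit() is elided above. As a standalone sketch of how the quadratic could be finished, the hypothetical function below (nearest_sphere_t is not a name from the repository) uses the same a, b, c and returns the nearest root inside (tmin, tmax), or None on a miss. It assumes numpy arrays for the vectors.

```python
import math
from typing import Optional

import numpy as np


def nearest_sphere_t(origin: np.ndarray, direction: np.ndarray, center: np.ndarray,
                     radius: float, tmin: float, tmax: float) -> Optional[float]:
    """Smallest t in (tmin, tmax) where origin + t*direction meets the sphere."""
    oc = center - origin                       # C - Q in the derivation
    a = np.dot(direction, direction)
    b = -2.0 * np.dot(direction, oc)
    c = np.dot(oc, oc) - radius * radius

    discriminant = b * b - 4 * a * c
    if discriminant < 0:
        return None                            # no real roots: the ray misses

    sqrt_d = math.sqrt(discriminant)
    # a > 0, so the "minus" root is the nearer one; try it first
    for t in ((-b - sqrt_d) / (2 * a), (-b + sqrt_d) / (2 * a)):
        if tmin < t < tmax:
            return float(t)
    return None


t = nearest_sphere_t(origin=np.zeros(3), direction=np.array([0.0, 0.0, -1.0]),
                     center=np.array([0.0, 0.0, -1.0]), radius=0.5,
                     tmin=0.001, tmax=1_000_000)
print(t)  # 0.5: the ray enters the sphere half a unit from the origin
```

Falling back to the farther root matters when the ray starts inside the sphere, where the nearer root is negative.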

The world is built in the main function. This is actually one of my favorite parts of the system. Once objects (like Spheres) are defined, you can compose a world to your liking. It would be interesting to take this further and make a simple domain-specific scripting language to compose the world programmatically. The compiler’s backend format could just be a Python list that looks like the one in main.

ray_tracer2/main.py

def main(output_image: str, image_width: int, anti_aliasing: int):
    world = World(
        hittable_list=[
            Sphere(center=[0.0, -100.5, -1.0], radius=100.0,
                   material=Lambertian(albedo=[0.8, 0.8, 0.0])),
            Sphere(center=[0.0, 0.0, -1.2], radius=0.5,
                   material=Lambertian(albedo=[0.1, 0.2, 0.5])),
            Sphere(center=[-1.0, 0.0, -1.0], radius=0.5,
                   material=Metal(albedo=[0.8, 0.8, 0.8], fuzz=0.3)),
            Sphere(center=[1.0, 0.0, -1.0], radius=0.5,
                   material=Metal(albedo=[0.8, 0.6, 0.2], fuzz=1.0)),
        ],
        tmin=0.001,
        tmax=1000000
    )

This was the first time I used Python’s abstract base class module, abc. Taking an object-oriented approach certainly made the implementation more enjoyable, for example as I added new types of surface materials. Later in the book, there will be new types of hittables like rectangular prisms.

@dataclass
class Hittable(ABC):
    material: 'Material'

    @abstractmethod
    def hit(self, ray: Ray, tmin: float, tmax: float) -> (HitRecord, bool):
        return


@dataclass
class Sphere(Hittable):
    center: np.ndarray
    radius: float

    def hit(self, ray: Ray, tmin: float, tmax: float) -> (HitRecord, bool):
        ...

You may have also noticed I’m a fan of dataclasses. Decorating a class with @dataclass automatically generates an __init__ method whose parameters initialize the member variables. It also nudges you to specify the type of each member variable, since fields are declared through type annotations.
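
Here is a small self-contained sketch of how @dataclass and ABC combine. Shape and Square are hypothetical toy classes, not part of the ray tracer; they only mirror the Hittable/Sphere layout.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class Shape(ABC):
    name: str

    @abstractmethod
    def area(self) -> float:
        ...


@dataclass
class Square(Shape):
    side: float

    def area(self) -> float:
        return self.side * self.side


# The generated __init__ takes the base-class fields first, then the subclass fields
sq = Square(name="tile", side=2.0)
print(sq)          # Square(name='tile', side=2.0)
print(sq.area())   # 4.0

# The abstract base cannot be instantiated because area() is not implemented
try:
    Shape(name="abstract")
except TypeError as err:
    print("refused:", type(err).__name__)
```

One thing to keep in mind with this layout: once a base-class field has a default value, every subclass field after it needs a default too, unless keyword-only fields are used.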

Performance

The images started taking noticeably longer to generate, especially after adding anti-aliasing. So let’s profile the code.

Here are my CPU specs from lscpu:

  Model name:             11th Gen Intel(R) Core(TM) i5-1145G7 @ 2.60GHz
    CPU family:           6
    Model:                140
    Thread(s) per core:   2
    Core(s) per socket:   4
    Socket(s):            1

Let’s try generating a 400x225 image with 10 AA samples, using Python’s cProfile module and running the program with only one process:

$ python3 -m cProfile -o cprofile.log ray-tracer2.py

And this Python script to view the output:

import pstats
p = pstats.Stats('cprofile.log')
p.sort_stats('tottime').print_stats(50)

         132493209 function calls (131083200 primitive calls) in 181.877 seconds

   Ordered by: internal time
   List reduced from 1551 to 50 due to restriction <50>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
  9225760   68.206    0.000   76.258    0.000 ray_tracer2/hittable.py:33(hit)
  1455329   17.573    0.000   37.015    0.000 ray_tracer2/utils.py:3(random_unit_vec)
    90000    9.651    0.000  181.106    0.002 ray_tracer2/camera.py:49(render_px)
  2872704    7.461    0.000   12.666    0.000 .venv/lib/python3.12/site-packages/numpy/linalg/_linalg.py:2566(norm)
  3673241    6.314    0.000    6.314    0.000 {method 'reduce' of 'numpy.ufunc' objects}
   900000    5.894    0.000   24.630    0.000 .venv/lib/python3.12/site-packages/numpy/lib/_arraypad_impl.py:545(pad)
   566264    5.216    0.000   22.605    0.000 ray_tracer2/material.py:21(scatter)
  3330462    4.644    0.000    4.644    0.000 ray_tracer2/ray.py:10(trace)
   851111    4.603    0.000    9.125    0.000 ray_tracer2/ray.py:13(colorize_miss)
  3673241    4.569    0.000   11.514    0.000 .venv/lib/python3.12/site-packages/numpy/_core/fromnumeric.py:69(_wrapreduction)
  2306440    4.390    0.000   80.648    0.000 ray_tracer2/hittable.py:86(hit)
2306509/900000    4.246    0.000  145.980    0.000 ray_tracer2/camera.py:32(ray_color)
   900000    4.102    0.000    5.799    0.000 .venv/lib/python3.12/site-packages/numpy/lib/_arraypad_impl.py:86(_pad_simple)
  1800000    3.730    0.000    7.646    0.000 .venv/lib/python3.12/site-packages/numpy/lib/_arraypad_impl.py:470(_as_pairs)
  3673241    3.542    0.000   15.732    0.000 .venv/lib/python3.12/site-packages/numpy/_core/fromnumeric.py:2255(sum)
   889065    3.251    0.000   29.325    0.000 ray_tracer2/material.py:38(scatter)
   900000    2.738    0.000    3.503    0.000 .venv/lib/python3.12/site-packages/numpy/lib/_arraypad_impl.py:129(_set_pad_area)
 30475039    2.532    0.000    2.532    0.000 .venv/lib/python3.12/site-packages/numpy/_core/multiarray.py:750(dot)
  2872705    2.343    0.000    2.343    0.000 {method 'dot' of 'numpy.ndarray' objects}
  4402304    2.102    0.000    2.102    0.000 {built-in method numpy.array}
  4672704    1.522    0.000    1.522    0.000 {method 'ravel' of 'numpy.ndarray' objects}
  4672704    1.107    0.000    1.107    0.000 {built-in method numpy.asarray}
   900000    0.772    0.000    0.772    0.000 .venv/lib/python3.12/site-packages/numpy/lib/_arraypad_impl.py:58(_view_roi)
  1800000    0.765    0.000    0.765    0.000 .venv/lib/python3.12/site-packages/numpy/lib/_arraypad_impl.py:33(_slice_at_axis)
  2872704    0.735    0.000    1.021    0.000 .venv/lib/python3.12/site-packages/numpy/linalg/_linalg.py:128(isComplexType)
  5745684    0.706    0.000    0.706    0.000 {built-in method builtins.issubclass}
  3808230    0.697    0.000    0.697    0.000 {built-in method math.sqrt}
  3689370    0.678    0.000    0.678    0.000 {built-in method builtins.isinstance}
   900000    0.649    0.000    1.457    0.000 .venv/lib/python3.12/site-packages/numpy/_core/fromnumeric.py:51(_wrapfunc)
  3673390    0.631    0.000    0.631    0.000 {method 'items' of 'dict' objects}
  1800000    0.627    0.000    0.627    0.000 .venv/lib/python3.12/site-packages/numpy/lib/_arraypad_impl.py:109(<genexpr>)
   900000    0.606    0.000    0.606    0.000 {method 'round' of 'numpy.ndarray' objects}
   900000    0.606    0.000    2.063    0.000 .venv/lib/python3.12/site-packages/numpy/_core/fromnumeric.py:3360(round)
   900001    0.538    0.000    0.538    0.000 {built-in method numpy.empty}
  1800000    0.531    0.000    0.531    0.000 .venv/lib/python3.12/site-packages/numpy/lib/_arraypad_impl.py:120(<genexpr>)
  3673241    0.480    0.000    0.480    0.000 .venv/lib/python3.12/site-packages/numpy/_core/fromnumeric.py:2250(_sum_dispatcher)
   900000    0.425    0.000    0.425    0.000 {method 'astype' of 'numpy.ndarray' objects}
  2355329    0.368    0.000    0.368    0.000 <string>:2(__init__)
  2872704    0.354    0.000    0.354    0.000 .venv/lib/python3.12/site-packages/numpy/linalg/_linalg.py:2562(_norm_dispatcher)
      2/1    0.289    0.145  181.985  181.985 ray_tracer2/camera.py:68(render)
...

That’s over 3 minutes for this tiny thing!

I sorted by tottime, the total time spent in a given function excluding time spent in calls to sub-functions. This isolates the time spent in the logic within a function’s own scope. As expected, the compute-heavy hit() function is at the top of the list.
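
To make the tottime/cumtime distinction concrete, here is a toy sketch with two hypothetical functions, inner and outer. It reads the numbers through Stats.get_stats_profile(), which needs Python 3.9 or newer.

```python
import cProfile
import pstats


def inner(n: int) -> int:
    """Does the actual work, so its tottime is large."""
    total = 0
    for i in range(n):
        total += i * i
    return total


def outer() -> int:
    """Mostly delegates, so its tottime is tiny but its cumtime is large."""
    total = 0
    for _ in range(20):
        total += inner(20_000)
    return total


profiler = cProfile.Profile()
profiler.enable()
outer()
profiler.disable()

# func_profiles maps a function name to its ncalls, tottime, cumtime, ...
profiles = pstats.Stats(profiler).get_stats_profile().func_profiles
for name in ("outer", "inner"):
    fp = profiles[name]
    print(f"{name}: tottime={fp.tottime:.4f}s cumtime={fp.cumtime:.4f}s")
```

outer should show almost no tottime but a cumtime that covers all of inner, which is the same pattern as render_px and render in the listing above.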

The random_unit_vec function is a (distant) second. This may be due to its iterative nature:

def random_unit_vec() -> np.ndarray:
    """
    Generate a random unit vector with (0,0,0) origin and |vec|=1.
    """
    sample = np.random.rand(3)*2 - 1
    ss = np.sum(np.square(sample))
    while (ss > 1) or (ss < 1e-160):
        sample = np.random.rand(3)*2 - 1
        ss = np.sum(np.square(sample))

    return sample/np.linalg.norm(sample)
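
If the retry loop does turn out to matter, one alternative is to normalize a vector of independent standard-normal samples, which is also uniformly distributed over the unit sphere. This is a different technique from the rejection sampling above, not what the repository does, and I have not measured whether it is faster in this program. The function name is hypothetical.

```python
import numpy as np

rng = np.random.default_rng()


def random_unit_vec_gaussian() -> np.ndarray:
    """Random direction on the unit sphere from normalized Gaussian samples."""
    while True:
        sample = rng.standard_normal(3)
        norm = np.sqrt(np.dot(sample, sample))
        # Only a guard against a vanishingly unlikely near-zero sample;
        # in practice the first draw is always accepted
        if norm > 1e-12:
            return sample / norm


v = random_unit_vec_gaussian()
print(v, np.linalg.norm(v))
```

A cheaper change to the original function would be to reuse ss and divide by np.sqrt(ss) instead of calling np.linalg.norm on the same sample again.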

We see np.linalg.norm near the top, which is used by hit().

  2872704    7.461    0.000   12.666    0.000 .venv/lib/python3.12/site-packages/numpy/linalg/_linalg.py:2566(norm)

Adding multiprocessing to Python programs is low-hanging fruit that I love to pick. Paired with tqdm, it provides a great way to speed up embarrassingly parallel applications with a nice progress bar.

See commit b55d037 for a simple example of how to do so.

Of course, this is not free. The overhead of using multiprocessing, even with only 1 process, can be non-negligible. That’s why I added a bypass for single-process runs in commit 4660793.
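
The general shape of that pattern, as a minimal sketch rather than a copy of either commit: render_row and render_rows are hypothetical names, and the per-row work is a placeholder for the real pixel loop.

```python
from multiprocessing import Pool


def render_row(y: int) -> list[int]:
    """Hypothetical stand-in for rendering one image row."""
    return [(x * y) % 256 for x in range(8)]


def render_rows(height: int, processes: int = 1) -> list[list[int]]:
    """Render all rows, skipping the pool entirely for single-process runs."""
    rows = range(height)
    if processes <= 1:
        # Bypass: no worker start-up, no pickling of arguments and results
        return [render_row(y) for y in rows]
    with Pool(processes=processes) as pool:
        # imap yields rows in order; wrapping it in tqdm(..., total=height)
        # would add the progress bar
        return list(pool.imap(render_row, rows, chunksize=4))


if __name__ == "__main__":
    image = render_rows(16, processes=1)
    print(len(image), "rows rendered without a pool")
```

Calling render_rows(16, processes=8) should return the same rows. Keep that call under the __main__ guard, since some platforms start workers by re-importing the module.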

With 8 workers, the runtime drops to less than 1 minute! This is still only about a 3.6x improvement (181.9 s down to 50.2 s), which goes to show the overhead.

$ python3 -m cProfile -o cprofile.log ray-tracer2.py -p 8
...
         741496 function calls (735497 primitive calls) in 50.173 seconds
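
For reference, the speedup and per-worker efficiency follow directly from the two wall-clock totals reported by cProfile above:

```python
serial_s = 181.877    # single-process run
parallel_s = 50.173   # run with -p 8
workers = 8

speedup = serial_s / parallel_s
efficiency = speedup / workers   # fraction of ideal linear scaling
print(f"speedup: {speedup:.2f}x, efficiency: {efficiency:.0%}")
```

The lscpu output shows 4 physical cores with 2 threads each, so the 8 workers share 4 cores; that likely accounts for part of the gap alongside process overhead. Note also that cProfile only measures the process it runs in, which is probably why the function-call count in the second run is so much smaller.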