
Tensoring the "technical" interview


Ah yes, LeetCode Easy, where even the brightest of minds have struggled. But you should be able to overcome this, sleep deprived or not, if only with the help of a certain energy drink (which you frequently wish would sponsor you).

You start (as any good piece of code should) with a dash of boilerplate:

import jax.numpy as jnp
from jax.nn import one_hot
from jax import jit, lax

The problem, inspired as it is, compels you to bind the laws of time to the medium of mathematics:

A = jnp.zeros((7, 7, 10), dtype="int32")
A = A.at[:, 1, 0].set(one_hot(2, 7, dtype="int32"))
A = A.at[:, 1, 1].set(one_hot(2, 7, dtype="int32"))
A = A.at[:, 1, 2].set(one_hot(3, 7, dtype="int32"))
for i in range(10):
    A = A.at[:, 2, i].set(one_hot(4, 7, dtype="int32"))
    A = A.at[:, 5, i].set(one_hot(6, 7, dtype="int32"))
for i in range(6):
    A = A.at[:, 4, i].set(one_hot(5, 7, dtype="int32"))
for i in range(4):
    A = A.at[:, 3, i].set(one_hot(4, 7, dtype="int32"))

The language of time, cryptic to some, but regular (finite too, though it matters little right now).
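Read back, the `.at[...].set(...)` calls say that `A[:, s, d]` is the one-hot encoding of the state reached from state `s` on digit `d`, and every entry left at zero is a dead end (state 0 is never entered, so no explicit reject state is needed). The state labels in the sketch below (start, hour tens 0/1, hour tens 2, and so on) are my own interpretation, not the original's; the sketch rebuilds `A` from that readable table, via a hypothetical `build_table` helper, and checks it against the construction above (with `dtype="int32"` passed to `one_hot` so no float-to-int cast happens).

```python
import jax.numpy as jnp
from jax.nn import one_hot

# The construction from above: A[next_state, state, digit].
A = jnp.zeros((7, 7, 10), dtype="int32")
A = A.at[:, 1, 0].set(one_hot(2, 7, dtype="int32"))
A = A.at[:, 1, 1].set(one_hot(2, 7, dtype="int32"))
A = A.at[:, 1, 2].set(one_hot(3, 7, dtype="int32"))
for i in range(10):
    A = A.at[:, 2, i].set(one_hot(4, 7, dtype="int32"))
    A = A.at[:, 5, i].set(one_hot(6, 7, dtype="int32"))
for i in range(6):
    A = A.at[:, 4, i].set(one_hot(5, 7, dtype="int32"))
for i in range(4):
    A = A.at[:, 3, i].set(one_hot(4, 7, dtype="int32"))

# Interpretation (labels are assumptions): state -> [(allowed digits, next state)].
TRANSITIONS = {
    1: [(range(0, 2), 2), (range(2, 3), 3)],  # start: hour tens is 0-1, or 2
    2: [(range(0, 10), 4)],                   # after 0/1: any hour units digit
    3: [(range(0, 4), 4)],                    # after 2: hour units digit 0-3
    4: [(range(0, 6), 5)],                    # minute tens digit 0-5
    5: [(range(0, 10), 6)],                   # minute units digit; 6 accepts
}

def build_table(transitions, n_states=7, n_digits=10):
    """Hypothetical helper: turn the readable table into a 0/1 tensor."""
    B = jnp.zeros((n_states, n_states, n_digits), dtype="int32")
    for state, edges in transitions.items():
        for digits, nxt in edges:
            for d in digits:
                B = B.at[nxt, state, d].set(1)  # one allowed move
    return B

B = build_table(TRANSITIONS)
print(bool(jnp.array_equal(A, B)))  # True
```

The table makes the regularity visible: five live states and a fixed, acyclic path of length four, which is also why the language is finite.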

Conjure an entry point, capturing the essence of the input string, but leave the main code unwritten for now:

def solve(inp):
    chrs = jnp.array([ord(inp[i]) for i in (0, 1, 3, 4)], dtype="int32")
    return impl(chrs)
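`solve` keeps only the characters at indices 0, 1, 3 and 4, which assumes a fixed `hh:mm` layout and skips the colon. `impl` then turns each kept character into a length-10 vector over digits: one-hot for a known digit, all ones for `?` ("any digit will do"). An eager sketch of that encoding, using a hypothetical `encode` helper that is not part of the original:

```python
import jax.numpy as jnp
from jax.nn import one_hot

def encode(inp):
    """Hypothetical helper mirroring impl's per-character encoding."""
    rows = []
    for i in (0, 1, 3, 4):  # same positions solve reads; index 2 is ':'
        c = inp[i]
        if c == "?":
            rows.append(jnp.ones(10, dtype="int32"))  # wildcard: every digit allowed
        else:
            rows.append(one_hot(ord(c) - ord("0"), 10, dtype="int32"))
    return jnp.stack(rows)  # shape (4, 10)

vs = encode("?5:00")
print(vs.sum(axis=1).tolist())  # [10, 1, 1, 1]
```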

Then finish it, needing just a single nontrivial control flow construct:

@jit
def impl(chrs):
    vs = lax.map(
        lambda sym: lax.select(
            sym != ord("?"),
            one_hot(sym - ord("0"), 10, dtype="int32"),
            jnp.ones(10, dtype="int32"),
        ),
        chrs,
    )
    return lax.fori_loop(
        0, vs.shape[0], lambda i, w: w @ A @ vs[i], one_hot(1, 7, dtype="int32")
    )[6]
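Each loop iteration is one automaton step run on counts instead of booleans: `w[s]` holds the number of ways to be in state `s` so far, and `(w @ A @ v)[n]` sums `w[s] * A[n, s, d] * v[d]` over states `s` and digits `d`. The `einsum` spelling below is an equivalent rewrite added for illustration, not the original code; the sketch checks the two agree and traces the counts for `"??:??"`, ending at 1440 = 24 × 60, one per minute of the day.

```python
import jax.numpy as jnp
from jax.nn import one_hot

# Transition tensor as built above: A[next_state, state, digit].
A = jnp.zeros((7, 7, 10), dtype="int32")
A = A.at[:, 1, 0].set(one_hot(2, 7, dtype="int32"))
A = A.at[:, 1, 1].set(one_hot(2, 7, dtype="int32"))
A = A.at[:, 1, 2].set(one_hot(3, 7, dtype="int32"))
for i in range(10):
    A = A.at[:, 2, i].set(one_hot(4, 7, dtype="int32"))
    A = A.at[:, 5, i].set(one_hot(6, 7, dtype="int32"))
for i in range(6):
    A = A.at[:, 4, i].set(one_hot(5, 7, dtype="int32"))
for i in range(4):
    A = A.at[:, 3, i].set(one_hot(4, 7, dtype="int32"))

def step_matmul(w, v):
    # The loop body from impl: (w @ A) has shape (7, 10), then @ v gives (7,).
    return w @ A @ v

def step_einsum(w, v):
    # Same contraction written out: sum over current state s and digit d.
    return jnp.einsum("s,nsd,d->n", w, A, v)

w = one_hot(1, 7, dtype="int32")    # one path, sitting in start state 1
wild = jnp.ones(10, dtype="int32")  # encoding of '?': every digit allowed
for _ in range(4):                  # "??:??" has four wildcard positions
    assert jnp.array_equal(step_matmul(w, wild), step_einsum(w, wild))
    w = step_matmul(w, wild)
    print(w.tolist())
# [0, 0, 2, 1, 0, 0, 0]
# [0, 0, 0, 0, 24, 0, 0]
# [0, 0, 0, 0, 0, 144, 0]
# [0, 0, 0, 0, 0, 0, 1440]
```

Reading `[6]` at the end, as `impl` does, picks out the accepting state's count.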

Think of a million parallel executions of the kernel, sharded across thousands of GPUs, all doing precisely what they excel at: linear algebra.
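Tongue in cheek as that is, the batching half of the picture is easy to try: `jax.vmap` maps the per-string kernel over a leading batch axis, so a whole batch of inputs goes through one compiled call. The sketch below assumes fixed-length `hh:mm` strings, uses a hypothetical `encode_batch` helper, and cross-checks against a brute-force count over all 1440 times (also my addition, not part of the original). Sharding across devices is not shown.

```python
import itertools

import jax
import jax.numpy as jnp
from jax import jit, lax
from jax.nn import one_hot

A = jnp.zeros((7, 7, 10), dtype="int32")
A = A.at[:, 1, 0].set(one_hot(2, 7, dtype="int32"))
A = A.at[:, 1, 1].set(one_hot(2, 7, dtype="int32"))
A = A.at[:, 1, 2].set(one_hot(3, 7, dtype="int32"))
for i in range(10):
    A = A.at[:, 2, i].set(one_hot(4, 7, dtype="int32"))
    A = A.at[:, 5, i].set(one_hot(6, 7, dtype="int32"))
for i in range(6):
    A = A.at[:, 4, i].set(one_hot(5, 7, dtype="int32"))
for i in range(4):
    A = A.at[:, 3, i].set(one_hot(4, 7, dtype="int32"))

@jit
def impl(chrs):
    # The kernel from above, unchanged.
    vs = lax.map(
        lambda sym: lax.select(
            sym != ord("?"),
            one_hot(sym - ord("0"), 10, dtype="int32"),
            jnp.ones(10, dtype="int32"),
        ),
        chrs,
    )
    return lax.fori_loop(
        0, vs.shape[0], lambda i, w: w @ A @ vs[i], one_hot(1, 7, dtype="int32")
    )[6]

batched = jax.jit(jax.vmap(impl))  # (batch, 4) char codes -> (batch,) counts

def encode_batch(strings):
    """Hypothetical helper: solve's character selection, stacked per string."""
    return jnp.array(
        [[ord(s[i]) for i in (0, 1, 3, 4)] for s in strings], dtype="int32"
    )

def brute_force(s):
    # Reference answer: count times 00:00..23:59 that match the pattern.
    return sum(
        all(p in ("?", c) for p, c in zip(s, f"{h:02d}:{m:02d}"))
        for h, m in itertools.product(range(24), range(60))
    )

inputs = ["?5:00", "0?:0?", "??:??", "2?:?0"]
print(batched(encode_batch(inputs)).tolist())  # [2, 100, 1440, 24]
print([brute_force(s) for s in inputs])        # [2, 100, 1440, 24]
```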

Surely this is the solution they had in mind?