Stacking Lookup Tables in Logos


Lookup what?

Lookup tables are one of the tricks Logos employs to speed up hot loops in which all you do is count how many consecutive bytes in a sequence match a single pattern. It's a fairly straightforward optimization, pretty easy to grok, but to really appreciate the problem tables can introduce, let's start at the beginning.

Suppose you have a lexer that's trying to match an identifier in some programming language, one matching the regular expression pattern [a-zA-Z0-9_$]+. Rust makes this fairly easy with a match expression:

fn count_matching(source: &[u8]) -> usize {
    source
        .iter()
        .copied()
        .take_while(|&byte| {
            match byte {
                b'a'..=b'z' |
                b'A'..=b'Z' |
                b'0'..=b'9' |
                b'_' | b'$' => true,
                _ => false,
            }
        })
        .count()
}
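For a quick sanity check, here is a hypothetical usage of the function above (the inputs and counts are my own examples, not from the post):

```rust
fn count_matching(source: &[u8]) -> usize {
    source
        .iter()
        .copied()
        .take_while(|&byte| {
            matches!(byte, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_' | b'$')
        })
        .count()
}

fn main() {
    // "my_var1" is 7 identifier bytes; the space stops the scan.
    assert_eq!(count_matching(b"my_var1 = 10"), 7);
    // Empty input matches nothing.
    assert_eq!(count_matching(b""), 0);
    println!("ok");
}
```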

Easy, right? This is likely good enough for any code you might want to write by hand... but what about code you might not want to write by hand, the kind of code Logos generates for you?

First, let's look at the assembly code our loop produces (you can tell this blog post is about to get real):

count_matching:
        xor     eax, eax
        test    rsi, rsi                   # check length
        jne     .LBB0_1                    # jump if length isn't 0
        jmp     .LBB0_6                    # length was 0, jump to end
.LBB0_5:
        add     rax, 1                     # increment index
        cmp     rsi, rax                   # check bounds
        je      .LBB0_6                    # break
.LBB0_1:
        movzx   ecx, byte ptr [rdi + rax]  # load byte from slice
        lea     edx, [rcx - 48]
        cmp     dl, 10
        jb      .LBB0_5                    # loop
        mov     edx, ecx
        and     dl, -33
        add     dl, -65
        cmp     dl, 26
        jb      .LBB0_5                    # loop
        cmp     cl, 95
        je      .LBB0_5                    # loop
        cmp     cl, 36
        je      .LBB0_5                    # loop
.LBB0_6:
        ret

The compiler decided to put the index increment on top, since this code has 4 places from which it can loop. Curious, but it makes sense. If you have never read assembly before, here is the dumbest, 51%-correct heuristic you can use to judge what you are seeing: fewer instructions and fewer branches usually mean faster code.

4 branches to handle 3 different ranges plus two other values is not bad at all; at some level, optimizing compilers are indistinguishable from magic, or so the saying goes. But can we do better than the compiler?

The magic of lookup tables

#[no_mangle]
fn count_matching_lut(source: &[u8]) -> usize {
    static LUT: [bool; 256] = {
        let mut table = [false; 256];
        // We don't have loops in const expressions yet,
        // so I'm just putting the edges of the ranges in here,
        // you get the idea.
        table[b'a' as usize] = true;
        table[b'z' as usize] = true;
        table[b'A' as usize] = true;
        table[b'Z' as usize] = true;
        table[b'0' as usize] = true;
        table[b'9' as usize] = true;
        table[b'_' as usize] = true;
        table[b'$' as usize] = true;
        table
    };

    source
        .iter()
        .copied()
        .take_while(|&byte| {
            LUT[byte as usize]
        })
        .count()
}
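As an aside, the comment above has since been dated: loops in const contexts stabilized in Rust 1.46, so the full ranges can now be filled in directly. A sketch of the same function with a complete table (my variation, not from the post):

```rust
fn count_matching_lut_full(source: &[u8]) -> usize {
    static LUT: [bool; 256] = {
        let mut table = [false; 256];
        // `while` loops and `match` are allowed in const contexts
        // since Rust 1.46, so every byte of the ranges can be marked.
        let mut byte = 0usize;
        while byte < 256 {
            table[byte] = matches!(
                byte as u8,
                b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_' | b'$'
            );
            byte += 1;
        }
        table
    };

    source
        .iter()
        .copied()
        .take_while(|&byte| LUT[byte as usize])
        .count()
}

fn main() {
    // "my_var1" is 7 identifier bytes; the space stops the scan.
    assert_eq!(count_matching_lut_full(b"my_var1 = 10"), 7);
    println!("ok");
}
```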

So that's it, that's the whole trick. Instead of using a match expression, we do a single indexing operation into a static array [bool; 256]. Since the array holds 256 bools, and since we are indexing with a byte, the compiler can omit bounds checking here, so we don't even need any unsafe. Let's look at the assembly to see if this is really the improvement we expect it to be:

count_matching_lut:
        xor     eax, eax
        test    rsi, rsi                   # check length
        je      .LBB0_4                    # jump to end if length is 0
        lea     rcx, [rip + example::count_matching_lut::LUT]
.LBB0_2:
        movzx   edx, byte ptr [rdi + rax]  # load byte from slice
        cmp     byte ptr [rdx + rcx], 0    # table lookup
        je      .LBB0_4                    # break if false
        add     rax, 1
        cmp     rsi, rax
        jne     .LBB0_2                    # loop
.LBB0_4:
        ret

example::count_matching_lut::LUT:
        .asciz  "\000\000\000..." # our 256 bytes of a table

Fewer instructions, fewer branches, win-win! Right? There is that extra memory access to load the table, but since this is a hot loop, our odds of the table being in CPU cache are pretty good. But wait, if this is an obviously faster way of doing things (it is, but don't just trust me, benchmark it), why hasn't the compiler optimized the match expression into a lookup table?

As with almost all things in computers, lookup tables are a trade-off; they aren't free performance, they come with a price tag. While we have fewer instructions to execute, we also have to put an extra 256 bytes into the binary, and those bytes then have to be loaded into the L1 CPU cache to be effective. The L1 cache is tiny! We are talking kilobytes here. Sprinkling 256-byte tables across every branch is hardly ideal. Use enough of them, and soon our tables will start competing with the actual data we are trying to read. We'll start seeing cache misses, at which point it doesn't matter how many instructions we shaved off.

So here is the problem: Logos loves lookup tables, but can we do something better than naive tables of 256 bools to conserve our precious memory?

Packed bitsets to the rescue?

You, my dear clever reader, might have figured out by now that we are using only a single bit of information in every byte of our table. What if, instead of 256 bytes, we used a table that takes up only 256 bits? Let's try!

#[no_mangle]
fn count_matching_bitset(source: &[u8]) -> usize {
    static LUT: [u64; 4] = {
        let mut table = [0; 4];
        // Same as above, we don't bother defining full ranges.
        // You could use some bitset crate, underneath they all
        // do something very similar to this
        table[(b'a' / 64) as usize] |= 1 << (b'a' % 64);
        table[(b'z' / 64) as usize] |= 1 << (b'z' % 64);
        table[(b'A' / 64) as usize] |= 1 << (b'A' % 64);
        table[(b'Z' / 64) as usize] |= 1 << (b'Z' % 64);
        table[(b'0' / 64) as usize] |= 1 << (b'0' % 64);
        table[(b'9' / 64) as usize] |= 1 << (b'9' % 64);
        table[(b'_' / 64) as usize] |= 1 << (b'_' % 64);
        table[(b'$' / 64) as usize] |= 1 << (b'$' % 64);
        table
    };

    source
        .iter()
        .copied()
        .take_while(|&byte| {
            LUT[(byte / 64) as usize] & (1 << (byte % 64)) > 0
        })
        .count()
}

Doesn't look too bad. You could do this with any of [u64; 4], [u32; 8], [u16; 16], or [u8; 32]. There isn't much of a difference among them; I'm on x86_64, so I'm happy to use all the bits in my registers. Let's check out the output assembly for this:

count_matching_bitset:
        xor     eax, eax
        test    rsi, rsi                   # check length
        je      .LBB0_4                    # jump to end if length is 0
        lea     r8, [rip + example::count_matching_bitset::LUT]
.LBB0_2:
        movzx   ecx, byte ptr [rdi + rax]  # load byte from slice
        mov     rdx, rcx                   # copy the byte to compute index
        shr     rdx, 3                     # compiler optimized division to >> 3
        and     edx, 24                    # mask upper bits since post >>
        mov     rdx, qword ptr [rdx + r8]  # load a 64bit chunk from table
        bt      rdx, rcx                   # bit test into the chunk
        jae     .LBB0_4                    # break if bit is 0
        add     rax, 1
        cmp     rsi, rax
        jne     .LBB0_2                    # loop
.LBB0_4:
        ret

example::count_matching_bitset::LUT:
        .asciz  "\000\000\000..." # 32 bytes, better

Okay, okay, this doesn't look too bad. We've added some instructions, but we don't have any extra branches, so this should be comparable in performance, right? Unfortunately, this version isn't any faster than the naive one using the match expression, while still having all the drawbacks of a lookup table (albeit less of a strain on the cache). Verdict: 👎

Creative problem solving at 1am

John Cleese once gave an amazing talk about creativity and how being creative requires one to enter the "open mode" (seriously, listen to the talk). For me, letting myself contemplate a problem at night as I fall asleep is how I've solved many issues in my code. It's also a good excuse if someone asks why I sleep 10 hours a day - I'm not sleeping, I'm being creative!

This was one of those things. I already knew that rustc & LLVM can optimize byte & mask > 0 into a single instruction, so there had to be a way to (ab)use this to make the table use less memory.

And then it hit me.

I was trying to solve the wrong problem! I was trying to figure out how to make a lookup table use less memory, but that was not the problem I had: my problem is about tables, not a table. I was thinking too narrowly. Instead of trying to reduce the number of bits and bytes I use per table, what if I just accept having to use 256 bytes, but be smarter about using them?

Since my procedural macro has full control over the code it's producing, I can defer producing code for the tables and pack 8 separate tables into a single [u8; 256] array. The way this works is kind of like treating each byte as a pixel on a screen, and each bit as a subpixel with a different color: I can have one byte array that contains a red table, and a green table, and a blue table, and a... right, not a perfect metaphor. Anyway, you get the idea.
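Conceptually, the packing looks something like this. The helpers below (stack_tables, matches_pattern) are illustrative names of my own, not Logos's API - the real macro just emits the combined table as a literal:

```rust
/// Pack up to 8 byte-membership sets into one table: bit `i` of
/// `table[b]` records whether byte `b` belongs to pattern `i`.
fn stack_tables(patterns: &[&[u8]]) -> [u8; 256] {
    assert!(patterns.len() <= 8, "only 8 bits per byte to hand out");
    let mut table = [0u8; 256];
    for (i, pattern) in patterns.iter().enumerate() {
        for &byte in *pattern {
            table[byte as usize] |= 1 << i;
        }
    }
    table
}

/// Each pattern then tests only its own bit.
fn matches_pattern(table: &[u8; 256], byte: u8, pattern: usize) -> bool {
    table[byte as usize] & (1 << pattern) > 0
}

fn main() {
    // Two toy patterns sharing one table: bit 0 and bit 1.
    let table = stack_tables(&[b"abc" as &[u8], b"a1"]);
    assert!(matches_pattern(&table, b'a', 0) && matches_pattern(&table, b'a', 1));
    assert!(!matches_pattern(&table, b'1', 0) && matches_pattern(&table, b'1', 1));
    println!("ok");
}
```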

So I got up at 1am, came to the computer and quickly benchmarked code that looked something like this:

#[no_mangle]
fn count_matching_stacked(source: &[u8]) -> usize {
    static LUT: [u8; 256] = {
        let mut table = [0; 256];

        // Mocking different bits here, all those values
        // will produce non-zero when masked with 4.
        table[b'a' as usize] = 4;
        table[b'z' as usize] = 5;
        table[b'A' as usize] = 6;
        table[b'Z' as usize] = 7;
        table[b'0' as usize] = 4;
        table[b'9' as usize] = 5;
        table[b'_' as usize] = 6;
        table[b'$' as usize] = 7;
        table
    };

    source
        .iter()
        .copied()
        .take_while(|&byte| {
            LUT[byte as usize] & 4 > 0
        })
        .count()
}

Using 4 for the mask is pretty arbitrary here; as long as it's a power of two, it's all good. All I wanted to know was whether this is at least not slower than the naive table of bools. First, the assembly:

count_matching_stacked:
        xor     eax, eax
        test    rsi, rsi                   # check length
        je      .LBB0_4                    # jump to end if length is 0
        lea     rcx, [rip + example::count_matching_stacked::LUT]
.LBB0_2:
        movzx   edx, byte ptr [rdi + rax]  # load byte from slice
        test    byte ptr [rdx + rcx], 4    # test bit in lookup table
        je      .LBB0_4                    # break if 0
        add     rax, 1
        cmp     rsi, rax
        jne     .LBB0_2                    # loop
.LBB0_4:
        ret

example::count_matching_stacked::LUT:
        .asciz  "\000\000\000..." # 256 bytes again

Ok, this looks very similar. Let's clear the clutter and look at the actual hot part of the loop:

Indexing into [bool; 256]:

        movzx   edx, byte ptr [rdi + rax]  # load byte from slice
        cmp     byte ptr [rdx + rcx], 0    # table lookup
        je      .LBB0_4                    # break if false

Indexing with bit mask on a [u8; 256]:

        movzx   edx, byte ptr [rdi + rax]  # load byte from slice
        test    byte ptr [rdx + rcx], 4    # test bit in the lookup table
        je      .LBB0_4                    # break if 0

BINGO! We turned cmp into test! Both instructions use a table pointer and a constant, so they should be equivalent in speed. I ran the benchmarks, and indeed they were!
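Before trusting any benchmark, it's worth checking that the masked lookup agrees with the original match expression on every possible byte. A quick sketch of such a check (my own, not from the post; unlike the mocked-up table above, this one fills in the full ranges under bit 2, i.e. mask 4):

```rust
/// The original match-expression predicate.
fn matches_naive(byte: u8) -> bool {
    matches!(byte, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_' | b'$')
}

/// Build a stacked table where bit 2 (mask 4) is this pattern's slot.
fn build_stacked() -> [u8; 256] {
    let mut table = [0u8; 256];
    for byte in 0..=255u8 {
        if matches_naive(byte) {
            table[byte as usize] |= 4;
        }
    }
    table
}

fn main() {
    let lut = build_stacked();
    // The masked lookup must agree with the match expression everywhere.
    for byte in 0..=255u8 {
        assert_eq!(lut[byte as usize] & 4 > 0, matches_naive(byte));
    }
    println!("predicates agree on all 256 bytes");
}
```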

Stack them tables

Okay, I haven't made anything worse, but I haven't made anything better yet either. Now we get back to Mr. Cleese and enter the "closed mode". Stacking multiple tables into a single byte array turned out to be pretty easy; the actual implementation in Logos is a mere 81 LOC. Here is the actual generated code:

static COMPACT_TABLE_0: [u8; 256] = [
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 127, 127, 127, 127, 127,
    127, 127, 127, 127, 127, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    255, 0, 0, 0, 0, 255, 0, 223, 255, 255, 255, 239, 255, 255, 255, 253, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 191, 251, 255, 247, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
];

// ...

#[inline]
fn pattern0(byte: u8) -> bool {
    COMPACT_TABLE_0[byte as usize] & 1 > 0
}

// ...

#[inline]
fn pattern4(byte: u8) -> bool {
    COMPACT_TABLE_0[byte as usize] & 2 > 0
}
#[inline]
fn pattern5(byte: u8) -> bool {
    COMPACT_TABLE_0[byte as usize] & 4 > 0
}
#[inline]
fn pattern6(byte: u8) -> bool {
    COMPACT_TABLE_0[byte as usize] & 8 > 0
}
#[inline]
fn pattern7(byte: u8) -> bool {
    COMPACT_TABLE_0[byte as usize] & 16 > 0
}

// ...

#[inline]
fn pattern8(byte: u8) -> bool {
    COMPACT_TABLE_0[byte as usize] & 32 > 0
}
#[inline]
fn pattern9(byte: u8) -> bool {
    COMPACT_TABLE_0[byte as usize] & 64 > 0
}

// ...

#[inline]
fn pattern10(byte: u8) -> bool {
    COMPACT_TABLE_0[byte as usize] & 128 > 0
}

Bam! 8 separate pattern definitions, 1 byte array 🎉. Performance? Throughput went up by a few percent, which makes sense - we've increased cache locality and our chances of hitting the L1 cache.

I bet someone else did this exact optimization before me, probably sometime in 1973, and that's okay. It still feels good to solve a difficult problem.

Now to stack those jump tables...

Grab Logos from crates.io, or look at API documentation on docs.rs. All code syntax highlighting in this blog is done with Logos.