Background
Go supports functions as a first-class concept, and it's pretty common to pass a function to inject custom behavior. This can have a performance impact.
In the ideal case, like in the example below, the compiler can inline everything, resulting in optimal performance:
```go
package background

import "fmt"

func Example() {
	v := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
	apply(v, func(n int) int { return n % 2 })
	fmt.Println(v)
	// Output:
	// [1 0 1 0 1 0 1 0 1 0]
}

func apply(v []int, fn func(n int) int) {
	for i, n := range v {
		v[i] = fn(n)
	}
}
```
The loop in `apply` is inlined into `Example`, as is the function we're passing in as a parameter. The resulting code in the loop is free of function calls (there is no `CALL` instruction):
```asm
...
MOVD ZR, R3
JMP loop_cond
loop:
MOVD (R0)(R3<<3), R4
ADD R4>>63, R4, R5
AND $-2, R5, R5
SUB R5, R4, R4
ADD $1, R3, R5
NOP
MOVD R4, (R0)(R3<<3)
MOVD R5, R3
loop_cond:
CMP $10, R3
BLT loop
...
```
However, in many cases the function being called can't be inlined because it's too complex (or because we prevent the compiler from inlining it), and Go then needs to make an indirect function call from within the loop. Let's use `//go:noinline` to forbid the compiler from inlining `apply` and compare the generated assembly code:
```diff
 package background
 
 import "fmt"
 
 func Example() {
 	v := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
 	apply(v, func(n int) int { return n % 2 })
 	fmt.Println(v)
 	// Output:
 	// [1 0 1 0 1 0 1 0 1 0]
 }
 
+//go:noinline
 func apply(v []int, fn func(n int) int) {
 	for i, n := range v {
 		v[i] = fn(n)
 	}
 }
```
In this case, `apply` (and with it the loop) is not inlined into `Example`, and we clearly see an indirect function call. The `CALL` instruction is the function call, and it's indirect because its target is a register (`R1`) instead of a label.
```asm
...
MOVD R3, fn+24(FP)
MOVD R1, v+8(FP)
MOVD R0, v(FP)
PCDATA $3, $2
MOVD ZR, R2
JMP loop_cond
loop:
MOVD R2, i-8(SP)
MOVD (R0)(R2<<3), R0
MOVD (R3), R1
MOVD R3, R26
PCDATA $1, $0
CALL (R1)
MOVD i-8(SP), R1
MOVD v(FP), R2
MOVD R0, (R2)(R1<<3)
ADD $1, R1, R1
MOVD R2, R0
MOVD fn+24(FP), R3
MOVD R1, R2
MOVD v+8(FP), R1
loop_cond:
CMP R2, R1
BGT loop
...
```
In many cases, the difference between the two is unnoticeable. Always inlining such a function would definitely not be the right decision, and if the function is truly hot, Go may still inline it when building with profile-guided optimization (PGO).
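To check whether the indirect call matters for a particular workload, the two variants can be benchmarked side by side. The program below is a sketch, not a measurement from this post: `applyInlined` and `applyNoinline` are hypothetical names for the two versions of `apply` shown above, and it calls `testing.Benchmark` (after `testing.Init`) so that it runs with a plain `go run`. Results vary with the machine and Go version; building with `-gcflags=-m` makes the compiler report which calls it inlined.

```go
package main

import (
	"fmt"
	"testing"
)

// applyInlined is small enough for the compiler to inline, together with
// the closure passed in as fn.
func applyInlined(v []int, fn func(n int) int) {
	for i, n := range v {
		v[i] = fn(n)
	}
}

// applyNoinline is the same loop, but the directive keeps it from being
// inlined, so fn is invoked through an indirect call on every iteration.
//
//go:noinline
func applyNoinline(v []int, fn func(n int) int) {
	for i, n := range v {
		v[i] = fn(n)
	}
}

func main() {
	// Registers the testing flags; needed when calling testing.Benchmark
	// outside of `go test`.
	testing.Init()

	data := make([]int, 1024)
	for i := range data {
		data[i] = i
	}

	inlined := testing.Benchmark(func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			applyInlined(data, func(n int) int { return n % 2 })
		}
	})
	noinline := testing.Benchmark(func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			applyNoinline(data, func(n int) int { return n % 2 })
		}
	})
	fmt.Println("inlined: ", inlined)
	fmt.Println("noinline:", noinline)
}
```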
The Problem
I ran into this while investigating optimization opportunities for znkr.io/diff. I had started out with a generic implementation of the diffing algorithm that works for any type, using a signature like:
```go
func Diff[T any](x, y []T, eq func(a, b T) bool) Edits[T]
```
However, the implementation of a diffing algorithm is rather complex, so the function was never inlined. After a number of rounds of optimization, I ended up with a changed API and an implementation that reduced all diffs of `comparable` types to running the algorithm on `int`.[1]
```go
func Diff[T comparable](x, y []T) Edits[T]
func DiffFunc[T any](x, y []T, eq func(a, b T) bool) Edits[T]
```
The problem was that the `Diff` function still used an underlying algorithm for `any` and supplied an `eq` function that was never inlined.
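The structure described above can be sketched with a toy algorithm in place of the real diff. All names here (`commonPrefix`, `intern`, `CommonPrefix`, `CommonPrefixFunc`) are hypothetical and not znkr.io/diff's API: the `comparable` entry point maps every distinct element to an `int` ID (the idea from the footnote) and then still passes an `eq` closure into the shared generic core. This toy core is small enough that the compiler may well inline it; a real diff core is not, and then every comparison stays an indirect call.

```go
package main

import "fmt"

// commonPrefix is a toy stand-in for the real algorithm: it works for any
// element type and compares elements only through eq.
func commonPrefix[T any](x, y []T, eq func(a, b T) bool) int {
	n := 0
	for n < len(x) && n < len(y) && eq(x[n], y[n]) {
		n++
	}
	return n
}

// intern maps every distinct element of x and y to a small int ID, so that
// equal elements get equal IDs across both inputs.
func intern[T comparable](x, y []T) (xi, yi []int) {
	ids := make(map[T]int)
	id := func(v T) int {
		n, ok := ids[v]
		if !ok {
			n = len(ids)
			ids[v] = n
		}
		return n
	}
	xi = make([]int, len(x))
	for i, v := range x {
		xi[i] = id(v)
	}
	yi = make([]int, len(y))
	for i, v := range y {
		yi[i] = id(v)
	}
	return xi, yi
}

// CommonPrefixFunc is the flexible entry point with a caller-supplied eq.
func CommonPrefixFunc[T any](x, y []T, eq func(a, b T) bool) int {
	return commonPrefix(x, y, eq)
}

// CommonPrefix handles any comparable type by working on int IDs, but it
// still hands an eq closure to the shared generic core.
func CommonPrefix[T comparable](x, y []T) int {
	xi, yi := intern(x, y)
	return commonPrefix(xi, yi, func(a, b int) bool { return a == b })
}

func main() {
	x := []string{"a", "b", "c"}
	y := []string{"a", "b", "d"}
	fmt.Println(CommonPrefix(x, y))                                               // 2
	fmt.Println(CommonPrefixFunc(x, y, func(a, b string) bool { return a == b })) // 2
}
```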
My hypothesis was that I could improve performance by making sure that the `eq` function was inlined. However, I didn't know how to do that. I tried to validate the hypothesis with a simple hack that duplicated the implementation and specialized it by hand. To my surprise, that hack reduced the runtime by up to 40%. So clearly, that's a worthwhile optimization!
Demo
Unfortunately, the diff implementation is a bit too complicated to be a good example for this blog post. Instead, let's use a very simple (but inefficient) Quicksort implementation (which incidentally has a recursive structure similar to the diff implementation I am using):
```go
package problem

func Sort[T any](v []T, less func(a, b T) bool) {
	if len(v) <= 1 {
		return
	}
	pivot := v[len(v)-1]
	i := 0
	for j := range v {
		if less(v[j], pivot) {
			v[i], v[j] = v[j], v[i]
			i++
		}
	}
	// Move the pivot to its final position i and recurse on both sides of it.
	// Excluding the pivot (v[i+1:]) keeps repeated values from recursing forever.
	v[i], v[len(v)-1] = v[len(v)-1], v[i]
	Sort(v[:i], less)
	Sort(v[i+1:], less)
}
```

Manually specializing this function for `int` is straightforward:

```diff
 package problem
 
-func Sort[T any](v []T, less func(a, b T) bool) {
+func SortInt(v []int) {
 	if len(v) <= 1 {
 		return
 	}
 	pivot := v[len(v)-1]
 	i := 0
 	for j := range v {
-		if less(v[j], pivot) {
+		if v[j] < pivot {
 			v[i], v[j] = v[j], v[i]
 			i++
 		}
 	}
 	// Move the pivot to its final position i and recurse on both sides of it.
 	// Excluding the pivot (v[i+1:]) keeps repeated values from recursing forever.
 	v[i], v[len(v)-1] = v[len(v)-1], v[i]
-	Sort(v[:i], less)
-	Sort(v[i+1:], less)
+	SortInt(v[:i])
+	SortInt(v[i+1:])
 }
```
Comparing the runtime of these two functions using a simple benchmark (sorting slices of 100 random numbers) on my M1 MacBook Pro also shows a runtime reduction of roughly 40%:

```
BenchmarkSortInt-10    435818    2658 ns/op
BenchmarkSort-10       272185    4382 ns/op
```
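A benchmark along those lines can be sketched as a standalone program. The input (100 values from `math/rand` with a fixed seed) and the re-copy into a scratch buffer on every iteration are assumptions rather than the author's exact setup; the copy is timed for both variants, so it slightly dilutes the difference. `Sort` and `SortInt` are repeated from above so the program compiles on its own.

```go
package main

import (
	"fmt"
	"math/rand"
	"testing"
)

// Sort is the generic version from above.
func Sort[T any](v []T, less func(a, b T) bool) {
	if len(v) <= 1 {
		return
	}
	pivot := v[len(v)-1]
	i := 0
	for j := range v {
		if less(v[j], pivot) {
			v[i], v[j] = v[j], v[i]
			i++
		}
	}
	v[i], v[len(v)-1] = v[len(v)-1], v[i]
	Sort(v[:i], less)
	Sort(v[i+1:], less)
}

// SortInt is the hand-specialized version from above.
func SortInt(v []int) {
	if len(v) <= 1 {
		return
	}
	pivot := v[len(v)-1]
	i := 0
	for j := range v {
		if v[j] < pivot {
			v[i], v[j] = v[j], v[i]
			i++
		}
	}
	v[i], v[len(v)-1] = v[len(v)-1], v[i]
	SortInt(v[:i])
	SortInt(v[i+1:])
}

func main() {
	// Needed because testing.Benchmark is called outside of `go test`.
	testing.Init()

	// 100 pseudo-random values with a fixed seed, so both runs see the same input.
	r := rand.New(rand.NewSource(1))
	data := make([]int, 100)
	for i := range data {
		data[i] = r.Int()
	}
	buf := make([]int, len(data))

	// Each iteration re-copies the unsorted input; the copy is timed too.
	generic := testing.Benchmark(func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			copy(buf, data)
			Sort(buf, func(x, y int) bool { return x < y })
		}
	})
	specialized := testing.Benchmark(func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			copy(buf, data)
			SortInt(buf)
		}
	})
	fmt.Println("Sort:   ", generic)
	fmt.Println("SortInt:", specialized)
}
```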
The "Solution"
I tried a number of ideas but found no way to ensure the passed-in function was inlined. The alternative of maintaining two copies of the same algorithm didn't sound very appealing. It felt like I had to either reduce the API surface by removing the `DiffFunc` option, take a performance hit to have both, or maintain two versions of the same algorithm. I really wished for specialization that would allow me to write the algorithm once and change aspects of it for `comparable` types.
I liked none of these options, and I was despairing about which one to pick when it hit me: I could implement specialization myself!
Go has excellent support for refactoring Go code thanks to the go/ast package. We can use that to build a code generator that performs the manual specialization described above automatically:
```go
package main

import (
	"bytes"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"os"
	"slices"

	"golang.org/x/tools/go/ast/astutil"
)

func main() {
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, "sort.go", nil, 0)
	if err != nil {
		fmt.Fprintf(os.Stderr, "error parsing input: %v\n", err)
		os.Exit(1)
	}

	file = astutil.Apply(file, func(c *astutil.Cursor) bool {
		// Declaration of Sort[T any]
		fd, ok := c.Node().(*ast.FuncDecl)
		if !ok || fd.Name.Name != "Sort" {
			return true
		}

		// Rename to SortInt and remove type parameters.
		fd.Name.Name = "SortInt"
		fd.Type.TypeParams = nil

		// Remove `less` from function parameters (SortInt shouldn't have a less parameter).
		fd.Type.Params.List = slices.DeleteFunc(fd.Type.Params.List, func(f *ast.Field) bool {
			if len(f.Names) != 1 {
				return false
			}
			return f.Names[0].Name == "less"
		})

		// Specialize all type parameters in the parameter list.
		for _, param := range fd.Type.Params.List {
			param.Type = astutil.Apply(param.Type, func(c *astutil.Cursor) bool {
				if ident, ok := c.Node().(*ast.Ident); ok && ident.Name == "T" {
					ident.Name = "int"
					return false
				}
				return true
			}, nil).(ast.Expr)
		}

		// Specialize the body by replacing Sort invocations with SortInt and less invocations with `<`.
		fd.Body = astutil.Apply(fd.Body, func(c *astutil.Cursor) bool {
			call, ok := c.Node().(*ast.CallExpr)
			if !ok {
				return true
			}
			fun, ok := call.Fun.(*ast.Ident)
			if !ok {
				return true
			}
			switch fun.Name {
			case "Sort":
				fun.Name = "SortInt"
				call.Args = slices.DeleteFunc(call.Args, func(arg ast.Expr) bool {
					name, ok := arg.(*ast.Ident)
					return ok && name.Name == "less"
				})
			case "less":
				if len(call.Args) != 2 {
					return true
				}
				c.Replace(&ast.BinaryExpr{
					X:  call.Args[0],
					Op: token.LSS,
					Y:  call.Args[1],
				})
			}
			return true
		}, nil).(*ast.BlockStmt)
		return false
	}, nil).(*ast.File)

	var buf bytes.Buffer
	if err := format.Node(&buf, fset, file); err != nil {
		fmt.Fprintf(os.Stderr, "error formatting result: %v\n", err)
		os.Exit(1)
	}
	if err := os.WriteFile("gen_sort_int.go", buf.Bytes(), 0o644); err != nil {
		fmt.Fprintf(os.Stderr, "error writing result: %v\n", err)
		os.Exit(1)
	}
}
```
This is overkill for a simple function like `Sort`, but keep in mind that this is a simple example by construction. The same principle can be applied to much more complicated functions or even, as in the case that triggered me to do this, whole types with multiple functions.
It's also fairly trivial to hook the specialization into a unit test to validate that the specialized version matches the output of the generator, and to use `//go:generate` to regenerate the specialized version.
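One way to set up that test is to move the transformation into a function that returns the generated source instead of writing it, so a test can compare the result with the checked-in file and a directive such as `//go:generate go run gen.go` (hypothetical file name) can rewrite it. The sketch below assumes that refactoring; `specialize`, `upToDate`, and the embedded `sortSrc` are hypothetical. To stay within the standard library it uses `ast.Inspect` instead of `astutil.Apply`, and since `ast.Inspect` cannot replace nodes, it only rewrites `less(a, b)` where it appears as an `if` condition, which is all `Sort` needs.

```go
package main

import (
	"bytes"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"slices"
)

// sortSrc stands in for the checked-in generic source (sort.go).
const sortSrc = `package problem

func Sort[T any](v []T, less func(a, b T) bool) {
	if len(v) <= 1 {
		return
	}
	pivot := v[len(v)-1]
	i := 0
	for j := range v {
		if less(v[j], pivot) {
			v[i], v[j] = v[j], v[i]
			i++
		}
	}
	v[i], v[len(v)-1] = v[len(v)-1], v[i]
	Sort(v[:i], less)
	Sort(v[i+1:], less)
}
`

func isLess(n ast.Node) bool {
	id, ok := n.(*ast.Ident)
	return ok && id.Name == "less"
}

// specialize (hypothetical) rewrites Sort[T any] in src into SortInt and
// returns the formatted file instead of writing it to disk.
func specialize(src []byte) ([]byte, error) {
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, "sort.go", src, 0)
	if err != nil {
		return nil, err
	}
	for _, decl := range file.Decls {
		fd, ok := decl.(*ast.FuncDecl)
		if !ok || fd.Name.Name != "Sort" {
			continue
		}
		fd.Name.Name, fd.Type.TypeParams = "SortInt", nil
		fd.Type.Params.List = slices.DeleteFunc(fd.Type.Params.List, func(f *ast.Field) bool {
			return len(f.Names) == 1 && isLess(f.Names[0])
		})
		ast.Inspect(fd, func(n ast.Node) bool {
			switch n := n.(type) {
			case *ast.Ident: // T -> int
				if n.Name == "T" {
					n.Name = "int"
				}
			case *ast.CallExpr: // Sort(x, less) -> SortInt(x)
				if id, ok := n.Fun.(*ast.Ident); ok && id.Name == "Sort" {
					id.Name = "SortInt"
					n.Args = slices.DeleteFunc(n.Args, func(a ast.Expr) bool { return isLess(a) })
				}
			case *ast.IfStmt: // if less(a, b) -> if a < b
				if call, ok := n.Cond.(*ast.CallExpr); ok && isLess(call.Fun) && len(call.Args) == 2 {
					n.Cond = &ast.BinaryExpr{X: call.Args[0], Op: token.LSS, Y: call.Args[1]}
				}
			}
			return true
		})
	}
	var buf bytes.Buffer
	if err := format.Node(&buf, fset, file); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}

// upToDate (hypothetical) reports whether checkedIn matches what specialize
// generates for src; a test would fail when it returns false.
func upToDate(src, checkedIn []byte) (bool, error) {
	gen, err := specialize(src)
	return err == nil && bytes.Equal(gen, checkedIn), err
}

func main() {
	gen, err := specialize([]byte(sortSrc))
	if err != nil {
		panic(err)
	}
	fmt.Print(string(gen))
}
```

In a real test, `checkedIn` would come from reading the generated file (for example with `os.ReadFile`), and a mismatch means `go generate` needs to be rerun.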
The full version of what I ended up using for the diffing algorithm is here.
Conclusion
I found a way to specialize code that doesn't require manual maintenance. Is it a good idea, though? I don't really know, but I can drop the specialized version, at the cost of performance, with only a few lines of code. If this turns out to be a bad idea, or if I or someone else finds a better solution, it's quite easy to remove or replace.
[1] The details don't matter here, but the idea is to assign a unique integer to every line in both inputs. This happens almost naturally when reducing the problem size by removing every line unique to either input. Both together significantly speed up a number of diffs.