Saturday, June 8, 2019

Learning middleware patterns in Golang

Middlewares are often applied using a method-wrapping pattern (the chain of responsibility design pattern). To better understand how to do this in Go, I started with the Handler interface and created a HandlerFunc that mimics the HandlerFunc in https://golang.org/src/net/http/server.go, except that I've replaced the args with (context string). This allowed me to trace the order in which middlewares access the current context.

func (r *Request) Context() context.Context

Some useful links that helped guide me.
https://www.calhoun.io/why-cant-i-pass-this-function-as-an-http-handler/
https://www.alexedwards.net/blog/making-and-using-middleware
https://github.com/justinas/alice Ended with an implementation that resembles this. Thx!

Update July 2019: Middleware args changed to Handler instead of HandlerFunc

package main

import "fmt"

// HandlerFunc, ServeHTTP match the go standard libs except
// (w ResponseWriter, r *Request) has been replaced by (context string).
// We treat this as a buffer that we can read from and add values to.
// Analogous to reading GET/POST args from Request and adding
// information to Request.Context()
// https://golang.org/src/net/http/server.go

// HandlerFunc implements the Handler interface because it matches the signature,
// meaning it has a ServeHTTP method with the same argument types

// Handler is anything that can serve a request. In this experiment the
// request is represented by a single context string rather than the
// (ResponseWriter, *Request) pair used by net/http.
type Handler interface {
    // ServeHTTP handles the request described by context.
    ServeHTTP(context string)
}

// HandlerFunc is an adapter that lets an ordinary function with the
// signature func(context string) be used as a Handler.
type HandlerFunc func(context string)

// ServeHTTP satisfies Handler by invoking the underlying function with the
// given context.
func (fn HandlerFunc) ServeHTTP(context string) {
    fn(context)
}

func baseHandler(h Handler) Handler {
    fmt.Println("Before return baseHandler")
    return HandlerFunc(func(context string) {
        fmt.Println("Before baseHandler")
        context = context + " base"
        h.ServeHTTP(context) // call ServeHTTP on the original handler
        fmt.Println("After baseHandler")
    })
}

func first(h Handler) Handler {
    fmt.Println("Before return first")
    return HandlerFunc(func(context string) {
        fmt.Println("Before first")
        context = context + " first"
        h.ServeHTTP(context) // call ServeHTTP on the original handler
        fmt.Println("After first")
    })
}

func second(h Handler) Handler {
    fmt.Println("Before return second")
    return HandlerFunc(func(context string) {
        fmt.Println("Before second")
        context = context + " second"
        h.ServeHTTP(context) // call ServeHTTP on the original handler
        fmt.Println("After second")
    })
}

// IndexEndPoint is the final handler in the chain. It prints the context
// string it receives, which shows the order in which the middlewares ran.
func IndexEndPoint(s string) {
    const label = "Index EndPoint: "
    fmt.Println(label, s)
}

// Middleware takes a Handler and returns a Handler. The middlewares in this
// file return a wrapper that does work before and after calling the Handler
// they were given.
type Middleware func(Handler) Handler

// MiddlewareStack holds an ordered list of middlewares that can be applied
// to an endpoint.
type MiddlewareStack struct {
    // middlewares is stored in the order supplied to NewMiddlewareStack.
    middlewares []Middleware
}

func NewMiddlewareStack(middlewares ...Middleware) MiddlewareStack {
    return MiddlewareStack{middlewares: middlewares}
}

// The middleware wrap pattern eg. second(first(baseHandler(IndexEndPoint))
// means you need to find the deepest method and work backwards -
// baseHandler, then first, then second.
// This implementation stores the middlewares in an array and can mutate the
// values beginning with the lowest to highest index; which has some
// readability benefits.
func (ms *MiddlewareStack) EndPoint(endPoint HandlerFunc) Handler {
    var h Handler

    // first middlware in array can access the context first
    for i := len(ms.middlewares) - 1; i >= 0; i-- {
        mw := ms.middlewares[i]
        // for _, mw := range ms.middlewares {
        if h == nil {
            h = mw(endPoint)
        } else {
            h = mw(h)
        }
    }

    return h
}

func main() {

    // middleware function wrapping
    // Output: Index EndPoint: start second first base
    f := second(first(baseHandler(HandlerFunc(IndexEndPoint))))
    f.ServeHTTP("start")

    /*
        // array of middleware
        // Another version of above, but storing in an array
        middleWares := []MiddleWare{baseHandler, first, second}

        var hFunc HandlerFunc
        for _, mw := range middleWares {
            if hFunc == nil {
                hFunc = mw(IndexEndPoint)
            } else {
                hFunc = mw(hFunc)
            }
        }

        hFunc.ServeHTTP("start")
    */

    // middleware struct
    // Index EndPoint: start base first second
    middlewareStack := NewMiddlewareStack(baseHandler, first, second)

    hFunc := middlewareStack.EndPoint(IndexEndPoint)

    hFunc.ServeHTTP("start")
}